diff --git a/Modules/Classification/CLCore/include/mitkAbstractClassifier.h b/Modules/Classification/CLCore/include/mitkAbstractClassifier.h index 614cdea47c..ab7d8351c6 100644 --- a/Modules/Classification/CLCore/include/mitkAbstractClassifier.h +++ b/Modules/Classification/CLCore/include/mitkAbstractClassifier.h @@ -1,192 +1,199 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkAbstractClassifier_h #define mitkAbstractClassifier_h #include #include // Eigen #include // STD Includes // MITK includes #include namespace mitk { class MITKCLCORE_EXPORT AbstractClassifier : public BaseData { public: mitkClassMacro(AbstractClassifier,BaseData) /// /// @brief Build a forest of trees from the training set (X, y). /// @param X, The training input samples. Matrix of shape = [n_samples, n_features] /// @param Y, The target values (class labels in classification, real numbers in regression). Matrix of shape = [n_samples, 1] /// virtual void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) = 0; /// /// @brief Predict class for X. /// @param X, The input samples. /// @return The predicted classes. Y matrix of shape = [n_samples, 1] /// virtual Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) = 0; /// /// @brief GetPointWiseWeightCopy /// @return return label matrix of shape = [n_samples , 1] /// Eigen::MatrixXi & GetLabels() { return m_OutLabel; } protected: Eigen::MatrixXi m_OutLabel; public: // * --------------- * // PointWiseWeight // * --------------- * /// /// @brief SupportsPointWiseWeight /// @return True if the classifier supports pointwise weighting else false /// virtual bool SupportsPointWiseWeight() = 0; /// /// @brief GetPointWiseWeightCopy /// @return Create and return a copy of W /// virtual Eigen::MatrixXd & GetPointWiseWeight() { return m_PointWiseWeight; } /// /// @brief SetPointWiseWeight /// @param W, The pointwise weights. W matrix of shape = [n_samples, 1] /// virtual void SetPointWiseWeight(const Eigen::MatrixXd& W) { this->m_PointWiseWeight = W; } /// /// @brief UsePointWiseWeight /// @param toggle weighting on/off /// virtual void UsePointWiseWeight(bool value) { this->m_IsUsingPointWiseWeight = value; } /// /// @brief IsUsingPointWiseWeight /// @return true if pointewise weighting is enabled. /// virtual bool IsUsingPointWiseWeight() { return this->m_IsUsingPointWiseWeight; } protected: Eigen::MatrixXd m_PointWiseWeight; bool m_IsUsingPointWiseWeight; // * --------------- * // PointWiseProbabilities // * --------------- * public: /// /// @brief SupportsPointWiseProbability /// @return True if the classifier supports pointwise class probability calculation else false /// virtual bool SupportsPointWiseProbability() = 0; /// /// @brief GetPointWiseWeightCopy /// @return return probability matrix /// virtual Eigen::MatrixXd & GetPointWiseProbabilities() { return m_OutProbability; } /// /// \brief UsePointWiseProbabilities /// \param value /// virtual void UsePointWiseProbability(bool value) { m_IsUsingPointWiseProbability = value; } /// /// \brief IsUsingPointWiseProbabilities /// \return /// virtual bool IsUsingPointWiseProbability() { return m_IsUsingPointWiseProbability; } protected: Eigen::MatrixXd m_OutProbability; bool m_IsUsingPointWiseProbability; - private: - void MethodForBuild(); +private: + void MethodForBuild(); public: + + void SetNthItems(const char *val, unsigned int idx); + std::string GetNthItems(unsigned int idx) const; + + void SetItemList(std::vector); + std::vector GetItemList() const; + #ifndef DOXYGEN_SKIP virtual void SetRequestedRegionToLargestPossibleRegion(){} virtual bool RequestedRegionIsOutsideOfTheBufferedRegion(){return true;} virtual bool VerifyRequestedRegion(){return false;} virtual void SetRequestedRegion(const itk::DataObject* /*data*/){} // Override virtual bool IsEmpty() const override { if(IsInitialized() == false) return true; const TimeGeometry* timeGeometry = const_cast(this)->GetUpdatedTimeGeometry(); if(timeGeometry == NULL) return true; return false; } #endif // Skip Doxygen }; } #endif //mitkAbstractClassifier_h diff --git a/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp b/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp index 4ea6f900db..94b4b0bce0 100644 --- a/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp +++ b/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp @@ -1,37 +1,58 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include -void mitk::AbstractClassifier::MethodForBuild() +void mitk::AbstractClassifier::SetNthItems(const char * val, unsigned int idx) { - //class A - //void F1 (print "1") - //virtual void F2 (print "1") + std::stringstream ss; + ss << "itemlist." << idx; + this->GetPropertyList()->SetStringProperty(ss.str().c_str(),val); +} - //class B : A - //void F1 (print "2") - //void F2 (print "2") - - - //A* var = new B; - //B* var2 = new B; - //A->F1() --> 1 - //B->F1() --> 2 +std::string mitk::AbstractClassifier::GetNthItems(unsigned int idx) const +{ + std::stringstream ss; + ss << "itemlist." << idx; + std::string val; + this->GetPropertyList()->GetStringProperty(ss.str().c_str(),val); + return val; +} + +void mitk::AbstractClassifier::SetItemList(std::vector list) +{ + for(unsigned int i = 0 ; i < list.size(); ++i) + this->SetNthItems(list[i].c_str(),i); +} - // A->F2() --> 2 // schau in dem Objekt welcher Typ vorhanden ist. -} \ No newline at end of file +std::vector mitk::AbstractClassifier::GetItemList() const +{ + std::vector result; + for(unsigned int idx = 0 ;; idx++) + { + std::stringstream ss; + ss << "itemlist." << idx; + if(this->GetPropertyList()->GetProperty(ss.str().c_str())) + { + std::string s; + this->GetPropertyList()->GetStringProperty(ss.str().c_str(),s); + result.push_back(s); + }else + break; + } + return result; +} diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h index ac98418594..b30989a17a 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h @@ -1,106 +1,89 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkVigraRandomForestClassifier_h #define mitkVigraRandomForestClassifier_h #include #include //#include #include #include namespace mitk { class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier { public: mitkClassMacro(VigraRandomForestClassifier,AbstractClassifier) itkFactorylessNewMacro(Self) itkCloneMacro(Self) - VigraRandomForestClassifier(); - VigraRandomForestClassifier(const VigraRandomForestClassifier & other) - { - this->m_RandomForest = other.m_RandomForest; - } + VigraRandomForestClassifier(); ~VigraRandomForestClassifier(); void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); Eigen::MatrixXi Predict(const Eigen::MatrixXd &X); Eigen::MatrixXi WeightedPredict(const Eigen::MatrixXd &X); bool SupportsPointWiseWeight(); bool SupportsPointWiseProbability(); void ConvertParameter(); - void SetRandomForest(const vigra::RandomForest & rf) - { - m_RandomForest = rf; - } - - const vigra::RandomForest & GetRandomForest() const - { - return m_RandomForest; - } + void SetRandomForest(const vigra::RandomForest & rf); + const vigra::RandomForest & GetRandomForest() const; void UsePointWiseWeight(bool); void SetMaximumTreeDepth(int); void SetMinimumSplitNodeSize(int); void SetPrecision(double); void SetSamplesPerTree(double); void UseSampleWithReplacement(bool); void SetTreeCount(int); void SetWeightLambda(double); void SetTreeWeights(Eigen::MatrixXd weights); void SetTreeWeight(int treeId, double weight); Eigen::MatrixXd GetTreeWeights() const; - void SetNthItems(const char *val, unsigned int idx); - std::string GetNthItem(unsigned int idx); - - void SetItemList(std::vector); - std::vector GetItemList(); - void PrintParameter(std::ostream &str = std::cout); private: // *------------------- // * THREADING // *------------------- static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); static ITK_THREAD_RETURN_TYPE PredictCallback(void *); struct TrainingData; struct PredictionData; struct EigenToVigraTransform; struct Parameter; Eigen::MatrixXd m_TreeWeights; Parameter * m_Parameter; vigra::RandomForest m_RandomForest; }; } #endif //mitkVigraRandomForestClassifier_h diff --git a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp index abd38d3c1c..2fd4c00f8d 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp @@ -1,513 +1,503 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes #include #include #include #include #include // Vigra includes #include #include // ITK include #include #include +#include typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; struct mitk::VigraRandomForestClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; struct mitk::VigraRandomForestClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, const DefaultSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel, const Parameter parameter) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel), m_Parameter(parameter) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; const DefaultSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; Parameter m_Parameter; }; struct mitk::VigraRandomForestClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, vigra::MultiArrayView<2, double> refProb) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), m_Probabilities(refProb) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; }; mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() :m_Parameter(nullptr) { - this->ConvertParameter(); + itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); + command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter); + this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command ); } mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() { } bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() { return true; } bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() { return true; } void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.onlineLearn(X,Y,0,true); } void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); DefaultSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.set_options().use_stratification(m_Parameter->Stratification); m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement); m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree); m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize); m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::auto_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; // Set Tree Weights to default m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1); m_TreeWeights.fill(1.0); } Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); std::auto_ptr data; data.reset( new PredictionData(m_RandomForest,X,Y,P)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } Eigen::MatrixXi mitk::VigraRandomForestClassifier::WeightedPredict(const Eigen::MatrixXd &X_in) { m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); int isSampleWeighted = m_RandomForest.options_.predict_weighted_; #pragma omp parallel for for(int row=0; row < vigra::rowCount(X); ++row) { vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); vigra::ArrayVector::const_iterator weights; //totalWeight == totalVoteCount! double totalWeight = 0.0; //Let each tree classify... for(int k=0; k::cast(totalWeight); } int erg; int maxCol = 0; for (int col=0;col m_OutProbability(row, maxCol)) maxCol = col; } m_RandomForest.ext_param_.to_classlabel(maxCol, erg); Y(row,0) = erg; } return m_OutLabel; } void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights) { m_TreeWeights = weights; } Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const { return m_TreeWeights; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process DefaultSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.set_options().use_stratification(data->m_Parameter.Stratification); rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement); rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree); rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } return NULL; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the 0th thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) numberOfRowsToCalculate += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); return NULL; } void mitk::VigraRandomForestClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults - if(!this->GetPropertyList()->Get("classifier.vigra-rf.usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet - // if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->Stratification)) + + MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter"; + if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; + if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; + if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; + if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; + if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; + if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; + if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; + if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; + if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet + // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults - if(!this->GetPropertyList()->Get("classifier.vigra-rf.usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) - str << "classifier.vigra-rf.usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; + if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) + str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else - str << "classifier.vigra-rf.usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; + str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.userandomsplit",this->m_Parameter->UseRandomSplit)) - str << "classifier.vigra-rf.userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; + if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) + str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else - str << "classifier.vigra-rf.userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; + str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treedepth",this->m_Parameter->TreeDepth)) - str << "classifier.vigra-rf.treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; + if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) + str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else - str << "classifier.vigra-rf.treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; + str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) - str << "classifier.vigra-rf.minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; + if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) + str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else - str << "classifier.vigra-rf.minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; + str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.precision",this->m_Parameter->Precision)) - str << "classifier.vigra-rf.precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; + if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) + str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else - str << "classifier.vigra-rf.precision\t\t" << this->m_Parameter->Precision << "\n"; + str << "precision\t\t" << this->m_Parameter->Precision << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplespertree",this->m_Parameter->SamplesPerTree)) - str << "classifier.vigra-rf.samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; + if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) + str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else - str << "classifier.vigra-rf.samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; + str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->SampleWithReplacement)) - str << "classifier.vigra-rf.samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; + if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) + str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else - str << "classifier.vigra-rf.samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; + str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treecount",this->m_Parameter->TreeCount)) - str << "classifier.vigra-rf.treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; + if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) + str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else - str << "classifier.vigra-rf.treecount\t\t" << this->m_Parameter->TreeCount << "\n"; + str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.lambda",this->m_Parameter->WeightLambda)) - str << "classifier.vigra-rf.lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; + if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) + str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else - str << "classifier.vigra-rf.lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; + str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; - // if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->Stratification)) + // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); - this->GetPropertyList()->SetBoolProperty("classifier.vigra-rf.usepointbasedweight",val); + this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val); } void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) { - this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.treedepth",val); + this->GetPropertyList()->SetIntProperty("treedepth",val); } void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) { - this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.minimalsplitnodesize",val); + this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val); } void mitk::VigraRandomForestClassifier::SetPrecision(double val) { - this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.precision",val); + this->GetPropertyList()->SetDoubleProperty("precision",val); } void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) { - this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.samplespertree",val); + this->GetPropertyList()->SetDoubleProperty("samplespertree",val); } void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) { - this->GetPropertyList()->SetBoolProperty("classifier.vigra-rf.samplewithreplacement",val); + this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val); } void mitk::VigraRandomForestClassifier::SetTreeCount(int val) { - this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.treecount",val); + this->GetPropertyList()->SetIntProperty("treecount",val); } void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) { - this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.lambda",val); + this->GetPropertyList()->SetDoubleProperty("lambda",val); } -void mitk::VigraRandomForestClassifier::SetNthItems(const char * val, unsigned int idx) +void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) { - std::stringstream ss; - ss << "classifier.vigra-rf.item." << idx; - this->GetPropertyList()->SetStringProperty(ss.str().c_str(),val); + m_TreeWeights(treeId,0) = weight; } -void mitk::VigraRandomForestClassifier::SetItemList(std::vector list) +void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest & rf) { - for(unsigned int i = 0 ; i < list.size(); ++i) - this->SetNthItems(list[i].c_str(),i); + this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth); + this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_); + this->SetTreeCount(rf.options().tree_count_); + this->SetSamplesPerTree(rf.options().training_set_proportion_); + this->UseSampleWithReplacement(rf.options().sample_with_replacement_); + this->m_RandomForest = rf; } -std::vector mitk::VigraRandomForestClassifier::GetItemList() +const vigra::RandomForest & mitk::VigraRandomForestClassifier::GetRandomForest() const { - std::vector result; - for(unsigned int idx = 0 ; idx < 100; idx++) - { - std::stringstream ss; - ss << "classifier.vigra-rf.item." << idx; - if(this->GetPropertyList()->GetProperty(ss.str().c_str())) - { - std::string s; - this->GetPropertyList()->GetStringProperty(ss.str().c_str(),s); - result.push_back(s); - } - } - return result; -} - -void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) -{ - m_TreeWeights(treeId,0) = weight; + return this->m_RandomForest; }