diff --git a/Modules/Classification/CLCore/include/mitkAbstractClassifier.h b/Modules/Classification/CLCore/include/mitkAbstractClassifier.h index 614cdea47c..ab7d8351c6 100644 --- a/Modules/Classification/CLCore/include/mitkAbstractClassifier.h +++ b/Modules/Classification/CLCore/include/mitkAbstractClassifier.h @@ -1,192 +1,199 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkAbstractClassifier_h #define mitkAbstractClassifier_h #include #include // Eigen #include // STD Includes // MITK includes #include namespace mitk { class MITKCLCORE_EXPORT AbstractClassifier : public BaseData { public: mitkClassMacro(AbstractClassifier,BaseData) /// /// @brief Build a forest of trees from the training set (X, y). /// @param X, The training input samples. Matrix of shape = [n_samples, n_features] /// @param Y, The target values (class labels in classification, real numbers in regression). Matrix of shape = [n_samples, 1] /// virtual void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y) = 0; /// /// @brief Predict class for X. /// @param X, The input samples. /// @return The predicted classes. Y matrix of shape = [n_samples, 1] /// virtual Eigen::MatrixXi Predict(const Eigen::MatrixXd &X) = 0; /// /// @brief GetPointWiseWeightCopy /// @return return label matrix of shape = [n_samples , 1] /// Eigen::MatrixXi & GetLabels() { return m_OutLabel; } protected: Eigen::MatrixXi m_OutLabel; public: // * --------------- * // PointWiseWeight // * --------------- * /// /// @brief SupportsPointWiseWeight /// @return True if the classifier supports pointwise weighting else false /// virtual bool SupportsPointWiseWeight() = 0; /// /// @brief GetPointWiseWeightCopy /// @return Create and return a copy of W /// virtual Eigen::MatrixXd & GetPointWiseWeight() { return m_PointWiseWeight; } /// /// @brief SetPointWiseWeight /// @param W, The pointwise weights. W matrix of shape = [n_samples, 1] /// virtual void SetPointWiseWeight(const Eigen::MatrixXd& W) { this->m_PointWiseWeight = W; } /// /// @brief UsePointWiseWeight /// @param toggle weighting on/off /// virtual void UsePointWiseWeight(bool value) { this->m_IsUsingPointWiseWeight = value; } /// /// @brief IsUsingPointWiseWeight /// @return true if pointewise weighting is enabled. /// virtual bool IsUsingPointWiseWeight() { return this->m_IsUsingPointWiseWeight; } protected: Eigen::MatrixXd m_PointWiseWeight; bool m_IsUsingPointWiseWeight; // * --------------- * // PointWiseProbabilities // * --------------- * public: /// /// @brief SupportsPointWiseProbability /// @return True if the classifier supports pointwise class probability calculation else false /// virtual bool SupportsPointWiseProbability() = 0; /// /// @brief GetPointWiseWeightCopy /// @return return probability matrix /// virtual Eigen::MatrixXd & GetPointWiseProbabilities() { return m_OutProbability; } /// /// \brief UsePointWiseProbabilities /// \param value /// virtual void UsePointWiseProbability(bool value) { m_IsUsingPointWiseProbability = value; } /// /// \brief IsUsingPointWiseProbabilities /// \return /// virtual bool IsUsingPointWiseProbability() { return m_IsUsingPointWiseProbability; } protected: Eigen::MatrixXd m_OutProbability; bool m_IsUsingPointWiseProbability; - private: - void MethodForBuild(); +private: + void MethodForBuild(); public: + + void SetNthItems(const char *val, unsigned int idx); + std::string GetNthItems(unsigned int idx) const; + + void SetItemList(std::vector); + std::vector GetItemList() const; + #ifndef DOXYGEN_SKIP virtual void SetRequestedRegionToLargestPossibleRegion(){} virtual bool RequestedRegionIsOutsideOfTheBufferedRegion(){return true;} virtual bool VerifyRequestedRegion(){return false;} virtual void SetRequestedRegion(const itk::DataObject* /*data*/){} // Override virtual bool IsEmpty() const override { if(IsInitialized() == false) return true; const TimeGeometry* timeGeometry = const_cast(this)->GetUpdatedTimeGeometry(); if(timeGeometry == NULL) return true; return false; } #endif // Skip Doxygen }; } #endif //mitkAbstractClassifier_h diff --git a/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp b/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp index 4ea6f900db..94b4b0bce0 100644 --- a/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp +++ b/Modules/Classification/CLCore/src/mitkAbstractClassifier.cpp @@ -1,37 +1,58 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include -void mitk::AbstractClassifier::MethodForBuild() +void mitk::AbstractClassifier::SetNthItems(const char * val, unsigned int idx) { - //class A - //void F1 (print "1") - //virtual void F2 (print "1") + std::stringstream ss; + ss << "itemlist." << idx; + this->GetPropertyList()->SetStringProperty(ss.str().c_str(),val); +} - //class B : A - //void F1 (print "2") - //void F2 (print "2") - - - //A* var = new B; - //B* var2 = new B; - //A->F1() --> 1 - //B->F1() --> 2 +std::string mitk::AbstractClassifier::GetNthItems(unsigned int idx) const +{ + std::stringstream ss; + ss << "itemlist." << idx; + std::string val; + this->GetPropertyList()->GetStringProperty(ss.str().c_str(),val); + return val; +} + +void mitk::AbstractClassifier::SetItemList(std::vector list) +{ + for(unsigned int i = 0 ; i < list.size(); ++i) + this->SetNthItems(list[i].c_str(),i); +} - // A->F2() --> 2 // schau in dem Objekt welcher Typ vorhanden ist. -} \ No newline at end of file +std::vector mitk::AbstractClassifier::GetItemList() const +{ + std::vector result; + for(unsigned int idx = 0 ;; idx++) + { + std::stringstream ss; + ss << "itemlist." << idx; + if(this->GetPropertyList()->GetProperty(ss.str().c_str())) + { + std::string s; + this->GetPropertyList()->GetStringProperty(ss.str().c_str(),s); + result.push_back(s); + }else + break; + } + return result; +} diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h index ac98418594..b30989a17a 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h @@ -1,106 +1,89 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkVigraRandomForestClassifier_h #define mitkVigraRandomForestClassifier_h #include #include //#include #include #include namespace mitk { class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier { public: mitkClassMacro(VigraRandomForestClassifier,AbstractClassifier) itkFactorylessNewMacro(Self) itkCloneMacro(Self) - VigraRandomForestClassifier(); - VigraRandomForestClassifier(const VigraRandomForestClassifier & other) - { - this->m_RandomForest = other.m_RandomForest; - } + VigraRandomForestClassifier(); ~VigraRandomForestClassifier(); void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); Eigen::MatrixXi Predict(const Eigen::MatrixXd &X); Eigen::MatrixXi WeightedPredict(const Eigen::MatrixXd &X); bool SupportsPointWiseWeight(); bool SupportsPointWiseProbability(); void ConvertParameter(); - void SetRandomForest(const vigra::RandomForest & rf) - { - m_RandomForest = rf; - } - - const vigra::RandomForest & GetRandomForest() const - { - return m_RandomForest; - } + void SetRandomForest(const vigra::RandomForest & rf); + const vigra::RandomForest & GetRandomForest() const; void UsePointWiseWeight(bool); void SetMaximumTreeDepth(int); void SetMinimumSplitNodeSize(int); void SetPrecision(double); void SetSamplesPerTree(double); void UseSampleWithReplacement(bool); void SetTreeCount(int); void SetWeightLambda(double); void SetTreeWeights(Eigen::MatrixXd weights); void SetTreeWeight(int treeId, double weight); Eigen::MatrixXd GetTreeWeights() const; - void SetNthItems(const char *val, unsigned int idx); - std::string GetNthItem(unsigned int idx); - - void SetItemList(std::vector); - std::vector GetItemList(); - void PrintParameter(std::ostream &str = std::cout); private: // *------------------- // * THREADING // *------------------- static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); static ITK_THREAD_RETURN_TYPE PredictCallback(void *); struct TrainingData; struct PredictionData; struct EigenToVigraTransform; struct Parameter; Eigen::MatrixXd m_TreeWeights; Parameter * m_Parameter; vigra::RandomForest m_RandomForest; }; } #endif //mitkVigraRandomForestClassifier_h diff --git a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp index abd38d3c1c..2fd4c00f8d 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp @@ -1,513 +1,503 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes #include #include #include #include #include // Vigra includes #include #include // ITK include #include #include +#include typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; struct mitk::VigraRandomForestClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; struct mitk::VigraRandomForestClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, const DefaultSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel, const Parameter parameter) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel), m_Parameter(parameter) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; const DefaultSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; Parameter m_Parameter; }; struct mitk::VigraRandomForestClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, vigra::MultiArrayView<2, double> refProb) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), m_Probabilities(refProb) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; }; mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() :m_Parameter(nullptr) { - this->ConvertParameter(); + itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); + command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter); + this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command ); } mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() { } bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() { return true; } bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() { return true; } void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.onlineLearn(X,Y,0,true); } void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); DefaultSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.set_options().use_stratification(m_Parameter->Stratification); m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement); m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree); m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize); m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::auto_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; // Set Tree Weights to default m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1); m_TreeWeights.fill(1.0); } Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); std::auto_ptr data; data.reset( new PredictionData(m_RandomForest,X,Y,P)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } Eigen::MatrixXi mitk::VigraRandomForestClassifier::WeightedPredict(const Eigen::MatrixXd &X_in) { m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); int isSampleWeighted = m_RandomForest.options_.predict_weighted_; #pragma omp parallel for for(int row=0; row < vigra::rowCount(X); ++row) { vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); vigra::ArrayVector::const_iterator weights; //totalWeight == totalVoteCount! double totalWeight = 0.0; //Let each tree classify... for(int k=0; k::cast(totalWeight); } int erg; int maxCol = 0; for (int col=0;col m_OutProbability(row, maxCol)) maxCol = col; } m_RandomForest.ext_param_.to_classlabel(maxCol, erg); Y(row,0) = erg; } return m_OutLabel; } void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights) { m_TreeWeights = weights; } Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const { return m_TreeWeights; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process DefaultSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.set_options().use_stratification(data->m_Parameter.Stratification); rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement); rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree); rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } return NULL; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the 0th thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) numberOfRowsToCalculate += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); return NULL; } void mitk::VigraRandomForestClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults - if(!this->GetPropertyList()->Get("classifier.vigra-rf.usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet - // if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->Stratification)) + + MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter"; + if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; + if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; + if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; + if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; + if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; + if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; + if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; + if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; + if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet + // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults - if(!this->GetPropertyList()->Get("classifier.vigra-rf.usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) - str << "classifier.vigra-rf.usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; + if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) + str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else - str << "classifier.vigra-rf.usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; + str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.userandomsplit",this->m_Parameter->UseRandomSplit)) - str << "classifier.vigra-rf.userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; + if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) + str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else - str << "classifier.vigra-rf.userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; + str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treedepth",this->m_Parameter->TreeDepth)) - str << "classifier.vigra-rf.treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; + if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) + str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else - str << "classifier.vigra-rf.treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; + str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) - str << "classifier.vigra-rf.minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; + if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) + str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else - str << "classifier.vigra-rf.minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; + str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.precision",this->m_Parameter->Precision)) - str << "classifier.vigra-rf.precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; + if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) + str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else - str << "classifier.vigra-rf.precision\t\t" << this->m_Parameter->Precision << "\n"; + str << "precision\t\t" << this->m_Parameter->Precision << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplespertree",this->m_Parameter->SamplesPerTree)) - str << "classifier.vigra-rf.samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; + if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) + str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else - str << "classifier.vigra-rf.samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; + str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->SampleWithReplacement)) - str << "classifier.vigra-rf.samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; + if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) + str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else - str << "classifier.vigra-rf.samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; + str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.treecount",this->m_Parameter->TreeCount)) - str << "classifier.vigra-rf.treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; + if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) + str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else - str << "classifier.vigra-rf.treecount\t\t" << this->m_Parameter->TreeCount << "\n"; + str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n"; - if(!this->GetPropertyList()->Get("classifier.vigra-rf.lambda",this->m_Parameter->WeightLambda)) - str << "classifier.vigra-rf.lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; + if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) + str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else - str << "classifier.vigra-rf.lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; + str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; - // if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->Stratification)) + // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); - this->GetPropertyList()->SetBoolProperty("classifier.vigra-rf.usepointbasedweight",val); + this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val); } void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) { - this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.treedepth",val); + this->GetPropertyList()->SetIntProperty("treedepth",val); } void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) { - this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.minimalsplitnodesize",val); + this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val); } void mitk::VigraRandomForestClassifier::SetPrecision(double val) { - this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.precision",val); + this->GetPropertyList()->SetDoubleProperty("precision",val); } void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) { - this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.samplespertree",val); + this->GetPropertyList()->SetDoubleProperty("samplespertree",val); } void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) { - this->GetPropertyList()->SetBoolProperty("classifier.vigra-rf.samplewithreplacement",val); + this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val); } void mitk::VigraRandomForestClassifier::SetTreeCount(int val) { - this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.treecount",val); + this->GetPropertyList()->SetIntProperty("treecount",val); } void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) { - this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.lambda",val); + this->GetPropertyList()->SetDoubleProperty("lambda",val); } -void mitk::VigraRandomForestClassifier::SetNthItems(const char * val, unsigned int idx) +void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) { - std::stringstream ss; - ss << "classifier.vigra-rf.item." << idx; - this->GetPropertyList()->SetStringProperty(ss.str().c_str(),val); + m_TreeWeights(treeId,0) = weight; } -void mitk::VigraRandomForestClassifier::SetItemList(std::vector list) +void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest & rf) { - for(unsigned int i = 0 ; i < list.size(); ++i) - this->SetNthItems(list[i].c_str(),i); + this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth); + this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_); + this->SetTreeCount(rf.options().tree_count_); + this->SetSamplesPerTree(rf.options().training_set_proportion_); + this->UseSampleWithReplacement(rf.options().sample_with_replacement_); + this->m_RandomForest = rf; } -std::vector mitk::VigraRandomForestClassifier::GetItemList() +const vigra::RandomForest & mitk::VigraRandomForestClassifier::GetRandomForest() const { - std::vector result; - for(unsigned int idx = 0 ; idx < 100; idx++) - { - std::stringstream ss; - ss << "classifier.vigra-rf.item." << idx; - if(this->GetPropertyList()->GetProperty(ss.str().c_str())) - { - std::string s; - this->GetPropertyList()->GetStringProperty(ss.str().c_str(),s); - result.push_back(s); - } - } - return result; -} - -void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) -{ - m_TreeWeights(treeId,0) = weight; + return this->m_RandomForest; } diff --git a/Modules/Classification/CLVigraRandomForest/src/IO/mitkRandomForestIO.cpp b/Modules/Classification/CLVigraRandomForest/src/IO/mitkRandomForestIO.cpp index 3993bea7f2..020ae69da0 100644 --- a/Modules/Classification/CLVigraRandomForest/src/IO/mitkRandomForestIO.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/IO/mitkRandomForestIO.cpp @@ -1,212 +1,222 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef __mitkDecisionForestIO__cpp #define __mitkDecisionForestIO__cpp #include "mitkRandomForestIO.h" #include "itksys/SystemTools.hxx" //#include "mitkHDF5IOMimeTypes.h" #include "vigra/random_forest_hdf5_impex.hxx" #include #include #include "mitkVigraRandomForestClassifier.h" #include "mitkIOMimeTypes.h" #define GetAttribute(name,type)\ type name;\ hdf5_file.readAttribute(".",name,name); mitk::RandomForestFileIO::ConfidenceLevel mitk::RandomForestFileIO::GetReaderConfidenceLevel() const { std::string ext = itksys::SystemTools::GetFilenameLastExtension(this->GetLocalFileName().c_str()); bool is_loaded = vigra::rf_import_HDF5(m_rf, this->GetInputLocation()); return ext == ".forest" && is_loaded == true? IFileReader::Supported : IFileReader::Unsupported; } mitk::RandomForestFileIO::ConfidenceLevel mitk::RandomForestFileIO::GetWriterConfidenceLevel() const { mitk::VigraRandomForestClassifier::ConstPointer input = dynamic_cast(this->GetInput()); if (input.IsNull()) { return IFileWriter::Unsupported; }else{ return IFileWriter::Supported; } } mitk::RandomForestFileIO::RandomForestFileIO() : AbstractFileIO(mitk::VigraRandomForestClassifier::GetStaticNameOfClass()) { CustomMimeType customReaderMimeType(mitk::IOMimeTypes::DEFAULT_BASE_NAME() + ".forest"); std::string category = "Vigra Random Forest File"; customReaderMimeType.SetComment("Vigra Random Forest"); customReaderMimeType.SetCategory(category); customReaderMimeType.AddExtension("forest"); // this->AbstractFileIOWriter::SetRanking(100); this->AbstractFileWriter::SetMimeTypePrefix(mitk::IOMimeTypes::DEFAULT_BASE_NAME() + ".forest"); this->AbstractFileWriter::SetMimeType(customReaderMimeType); this->SetWriterDescription("Vigra Random Forest"); this->AbstractFileReader::SetMimeTypePrefix(mitk::IOMimeTypes::DEFAULT_BASE_NAME() + ".forest"); this->AbstractFileReader::SetMimeType(customReaderMimeType); this->SetReaderDescription("Vigra Random Forest"); // this->SetReaderDescription(mitk::DecisionForestIOMimeTypes::DECISIONFOREST_MIMETYPE_DESCRIPTION()); // this->SetWriterDescription(mitk::DecisionForestIOMimeTypes::DECISIONFOREST_MIMETYPE_DESCRIPTION()); this->RegisterService(); } mitk::RandomForestFileIO::RandomForestFileIO(const mitk::RandomForestFileIO& other) : AbstractFileIO(other) { } mitk::RandomForestFileIO::~RandomForestFileIO() {} std::vector > - mitk::RandomForestFileIO:: - Read() +mitk::RandomForestFileIO:: +Read() { mitk::VigraRandomForestClassifier::Pointer output = mitk::VigraRandomForestClassifier::New(); std::vector > result; if ( this->GetInputLocation().empty()) { MITK_ERROR << "Sorry, filename has not been set!"; return result; } else { const std::string& locale = "C"; const std::string& currLocale = setlocale( LC_ALL, NULL ); if ( locale.compare(currLocale)!=0 ) { try { setlocale(LC_ALL, locale.c_str()); } catch(...) { MITK_INFO << "Could not set locale " << locale; } } - // vigra::HDF5File hdf5_file; - // vigra::rf_import_HDF5(rf,hdf5_file,this->GetInputLocation()); output->SetRandomForest(m_rf); result.push_back(output.GetPointer()); - - auto treeWeight = output->GetTreeWeights(); - treeWeight.resize(m_rf.tree_count(),1); - vigra::MultiArrayView<2, double> W(vigra::Shape2(treeWeight.rows(),treeWeight.cols()),treeWeight.data()); - vigra::HDF5File hdf5_file(this->GetInputLocation() , vigra::HDF5File::Open); + hdf5_file.cd_mk("/_mitkOptions"); - hdf5_file.read("treeWeights",W); - output->SetTreeWeights(treeWeight); - hdf5_file.close(); - // if(!hdf5_file.existsAttribute(".","mitk")){ - // return result; - // }else{ - // GetAttribute(mitk_isMitkDecisionTree,std::string); - // if(mitk_isMitkDecisionTree.empty()) return result; + // --------------------------------------------------------- + // Read tree weights + if(hdf5_file.existsDataset("treeWeights")) + { + auto treeWeight = output->GetTreeWeights(); + treeWeight.resize(m_rf.tree_count(),1); + vigra::MultiArrayView<2, double> W(vigra::Shape2(treeWeight.rows(),treeWeight.cols()),treeWeight.data()); + hdf5_file.read("treeWeights",W); + output->SetTreeWeights(treeWeight); + } + // --------------------------------------------------------- - // GetAttribute(mitk_Modalities,std::string); - // std::vector strs; - // boost::split(strs, mitk_Modalities, boost::is_any_of("\t ,")); - // MITK_INFO << "Import Modalities: " << mitk_Modalities; - // output->SetModalities(strs); + // --------------------------------------------------------- + // Read itemList + if(hdf5_file.existsDataset("itemList")){ + std::string items_string; + hdf5_file.read("itemList",items_string); + auto itemlist = output->GetItemList(); + + std::string current_item = ""; + for(auto character : items_string) + { + if(character == ';'){ + // skip seperator and push back item + itemlist.push_back(current_item); + current_item.clear(); + }else{ + current_item = current_item + character; + } + } + output->SetItemList(itemlist); + } + // --------------------------------------------------------- - // } + hdf5_file.close(); return result; } } void mitk::RandomForestFileIO::Write() { mitk::BaseData::ConstPointer input = this->GetInput(); if (input.IsNull()) { MITK_ERROR <<"Sorry, input to NrrdDiffusionImageWriter is NULL!"; return; } if ( this->GetOutputLocation().empty() ) { MITK_ERROR << "Sorry, filename has not been set!"; return ; }else{ const std::string& locale = "C"; const std::string& currLocale = setlocale( LC_ALL, NULL ); if ( locale.compare(currLocale)!=0 ) { try { setlocale(LC_ALL, locale.c_str()); } catch(...) { MITK_INFO << "Could not set locale " << locale; } } mitk::VigraRandomForestClassifier::ConstPointer mitkDC = dynamic_cast(input.GetPointer()); //mitkDC->GetRandomForest() vigra::rf_export_HDF5(mitkDC->GetRandomForest(), this->GetOutputLocation()); vigra::HDF5File hdf5_file(this->GetOutputLocation() , vigra::HDF5File::Open); - auto treeWeight = mitkDC->GetTreeWeights(); hdf5_file.cd_mk("/_mitkOptions"); + + // Write tree weights + // --------------------------------------------------------- + auto treeWeight = mitkDC->GetTreeWeights(); vigra::MultiArrayView<2, double> W(vigra::Shape2(treeWeight.rows(),treeWeight.cols()),treeWeight.data()); hdf5_file.write("treeWeights",W); - hdf5_file.close(); + // --------------------------------------------------------- - // vigra::rf_import_HDF5(rf,hdf5_file,this->GetInputLocation()); - // output->SetRandomForest(rf); - // result.push_back(output.GetPointer()); + // Write itemList + // --------------------------------------------------------- + auto items = mitkDC->GetItemList(); + std::string item_stringlist; + for(auto entry : items) + item_stringlist = item_stringlist + entry + ";"; - // if(!hdf5_file.existsAttribute(".","mitk")){ - // return result; - // }else{ - // GetAttribute(mitk_isMitkDecisionTree,std::string); - // if(mitk_isMitkDecisionTree.empty()) return result; + hdf5_file.write("itemList",item_stringlist); + // --------------------------------------------------------- - // GetAttribute(mitk_Modalities,std::string); - // std::vector strs; - // boost::split(strs, mitk_Modalities, boost::is_any_of("\t ,")); - // MITK_INFO << "Import Modalities: " << mitk_Modalities; - // output->SetModalities(strs); - - // } + hdf5_file.close(); } } mitk::AbstractFileIO* mitk::RandomForestFileIO::IOClone() const { return new RandomForestFileIO(*this); } -#endif \ No newline at end of file +#endif