diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h index b30989a17a..9eb3a1f270 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h @@ -1,89 +1,93 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkVigraRandomForestClassifier_h #define mitkVigraRandomForestClassifier_h #include #include //#include #include #include namespace mitk { class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier { public: mitkClassMacro(VigraRandomForestClassifier,AbstractClassifier) itkFactorylessNewMacro(Self) itkCloneMacro(Self) VigraRandomForestClassifier(); ~VigraRandomForestClassifier(); void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); Eigen::MatrixXi Predict(const Eigen::MatrixXd &X); - Eigen::MatrixXi WeightedPredict(const Eigen::MatrixXd &X); + Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X); + bool SupportsPointWiseWeight(); bool SupportsPointWiseProbability(); void ConvertParameter(); void SetRandomForest(const vigra::RandomForest & rf); const vigra::RandomForest & GetRandomForest() const; void UsePointWiseWeight(bool); void SetMaximumTreeDepth(int); void SetMinimumSplitNodeSize(int); void SetPrecision(double); void SetSamplesPerTree(double); void UseSampleWithReplacement(bool); void SetTreeCount(int); void SetWeightLambda(double); void SetTreeWeights(Eigen::MatrixXd weights); void SetTreeWeight(int treeId, double weight); Eigen::MatrixXd GetTreeWeights() const; void PrintParameter(std::ostream &str = std::cout); private: // *------------------- // * THREADING // *------------------- - static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); - static ITK_THREAD_RETURN_TYPE PredictCallback(void *); struct TrainingData; struct PredictionData; struct EigenToVigraTransform; struct Parameter; Eigen::MatrixXd m_TreeWeights; Parameter * m_Parameter; vigra::RandomForest m_RandomForest; + + static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); + static ITK_THREAD_RETURN_TYPE PredictCallback(void *); + static ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *); + static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P); }; } #endif //mitkVigraRandomForestClassifier_h diff --git a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp index 2fd4c00f8d..255189c2ab 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp @@ -1,503 +1,592 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes #include #include #include #include #include // Vigra includes #include #include // ITK include #include #include #include typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; struct mitk::VigraRandomForestClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; struct mitk::VigraRandomForestClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, const DefaultSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel, const Parameter parameter) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel), m_Parameter(parameter) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; const DefaultSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; Parameter m_Parameter; }; struct mitk::VigraRandomForestClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, - vigra::MultiArrayView<2, double> refProb) + vigra::MultiArrayView<2, double> refProb, + vigra::MultiArrayView<2, double> refTreeWeights) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), - m_Probabilities(refProb) + m_Probabilities(refProb), + m_TreeWeights(refTreeWeights) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; + vigra::MultiArrayView<2, double> m_TreeWeights; }; mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() :m_Parameter(nullptr) { itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter); this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command ); } mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() { } bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() { return true; } bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() { return true; } void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.onlineLearn(X,Y,0,true); } void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); DefaultSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.set_options().use_stratification(m_Parameter->Stratification); m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement); m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree); m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize); m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::auto_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; // Set Tree Weights to default m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1); m_TreeWeights.fill(1.0); } Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); + // If no weights provided + if(m_TreeWeights.rows() != m_RandomForest.tree_count()) + { + m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); + m_TreeWeights.fill(1); + } + + vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); + vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); std::auto_ptr data; - data.reset( new PredictionData(m_RandomForest,X,Y,P)); + data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } -Eigen::MatrixXi mitk::VigraRandomForestClassifier::WeightedPredict(const Eigen::MatrixXd &X_in) +Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in) { + // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); - vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); - vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); - vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); - - int isSampleWeighted = m_RandomForest.options_.predict_weighted_; -#pragma omp parallel for - for(int row=0; row < vigra::rowCount(X); ++row) + // If no weights provided + if(m_TreeWeights.rows() != m_RandomForest.tree_count()) { - vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); + m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); + m_TreeWeights.fill(1); + } - vigra::ArrayVector::const_iterator weights; - //totalWeight == totalVoteCount! - double totalWeight = 0.0; + vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); + vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); + vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); + vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); - //Let each tree classify... - for(int k=0; k data; + data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); - //update votecount. - for(int l=0; lSetSingleMethod(this->PredictWeightedCallback,data.get()); + threader->SingleMethodExecute(); - //Normalise votes in each row by total VoteCount (totalWeight - for(int l=0; l< m_RandomForest.ext_param_.class_count_; ++l) - { - P(row, l) /= vigra::detail::RequiresExplicitCast::cast(totalWeight); - } - int erg; - int maxCol = 0; - for (int col=0;col m_OutProbability(row, maxCol)) - maxCol = col; - } - m_RandomForest.ext_param_.to_classlabel(maxCol, erg); - Y(row,0) = erg; - } return m_OutLabel; } + + void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights) { m_TreeWeights = weights; } Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const { return m_TreeWeights; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process DefaultSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.set_options().use_stratification(data->m_Parameter.Stratification); rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement); rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree); rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } return NULL; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); - // the 0th thread takes the residuals - if(threadId == infoStruct->NumberOfThreads-1) numberOfRowsToCalculate += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; + // the last thread takes the residuals + if(threadId == infoStruct->NumberOfThreads-1) { + end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; + } vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); return NULL; } +ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg) +{ + // Get the ThreadInfoStruct + typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; + ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); + // assigne the thread id + const unsigned int threadId = infoStruct->ThreadID; + + // Get the user defined parameters containing all + // neccesary informations + PredictionData * data = (PredictionData *)(infoStruct->UserData); + unsigned int numberOfRowsToCalculate = 0; + + // Get number of rows to calculate + numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; + + unsigned int start_index = numberOfRowsToCalculate * threadId; + unsigned int end_index = numberOfRowsToCalculate * (threadId+1); + + // the last thread takes the residuals + if(threadId == infoStruct->NumberOfThreads-1) { + end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; + } + + vigra::MultiArrayView<2, double> split_features; + vigra::MultiArrayView<2, int> split_labels; + vigra::MultiArrayView<2, double> split_probability; + { + vigra::TinyVector lowerBound(start_index,0); + vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); + split_features = data->m_Feature.subarray(lowerBound,upperBound); + } + + { + vigra::TinyVector lowerBound(start_index,0); + vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); + split_labels = data->m_Label.subarray(lowerBound,upperBound); + } + + { + vigra::TinyVector lowerBound(start_index,0); + vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); + split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); + } + + VigraPredictWeighted(data, split_features,split_labels,split_probability); + + return NULL; +} + + +void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P) +{ + + int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_; +//#pragma omp parallel for + for(int row=0; row < vigra::rowCount(X); ++row) + { + vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); + + vigra::ArrayVector::const_iterator weights; + + //totalWeight == totalVoteCount! + double totalWeight = 0.0; + + //Let each tree classify... + for(int k=0; km_RandomForest.options_.tree_count_; ++k) + { + //get weights predicted by single tree + weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow); + double numberOfLeafObservations = (*(weights-1)); + + //update votecount. + for(int l=0; lm_RandomForest.ext_param_.class_count_; ++l) + { + // Either the original weights are taken or the tree is additional weighted by the number of Observations in the leaf node. + double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted)); + cur_w = cur_w * data->m_TreeWeights(k,0); + P(row, l) += (int)cur_w; + //every weight in totalWeight. + totalWeight += cur_w; + } + } + + //Normalise votes in each row by total VoteCount (totalWeight + for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l) + { + P(row, l) /= vigra::detail::RequiresExplicitCast::cast(totalWeight); + } + int erg; + int maxCol = 0; + for (int col=0;colm_RandomForest.class_count();++col) + { + if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol)) + maxCol = col; + } + data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg); + Y(row,0) = erg; + } +} + void mitk::VigraRandomForestClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter"; if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else str << "precision\t\t" << this->m_Parameter->Precision << "\n"; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n"; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val); } void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) { this->GetPropertyList()->SetIntProperty("treedepth",val); } void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) { this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val); } void mitk::VigraRandomForestClassifier::SetPrecision(double val) { this->GetPropertyList()->SetDoubleProperty("precision",val); } void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) { this->GetPropertyList()->SetDoubleProperty("samplespertree",val); } void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) { this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val); } void mitk::VigraRandomForestClassifier::SetTreeCount(int val) { this->GetPropertyList()->SetIntProperty("treecount",val); } void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) { this->GetPropertyList()->SetDoubleProperty("lambda",val); } void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) { m_TreeWeights(treeId,0) = weight; } void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest & rf) { this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth); this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_); this->SetTreeCount(rf.options().tree_count_); this->SetSamplesPerTree(rf.options().training_set_proportion_); this->UseSampleWithReplacement(rf.options().sample_with_replacement_); this->m_RandomForest = rf; } const vigra::RandomForest & mitk::VigraRandomForestClassifier::GetRandomForest() const { return this->m_RandomForest; } diff --git a/Modules/Classification/CLVigraRandomForest/test/mitkVigraRandomForestTest.cpp b/Modules/Classification/CLVigraRandomForest/test/mitkVigraRandomForestTest.cpp index 7a96c9104c..0f2d794f51 100644 --- a/Modules/Classification/CLVigraRandomForest/test/mitkVigraRandomForestTest.cpp +++ b/Modules/Classification/CLVigraRandomForest/test/mitkVigraRandomForestTest.cpp @@ -1,285 +1,329 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include "mitkIOUtil.h" #include "itkArray2D.h" #include #include #include #include #include #include #include class mitkVigraRandomForestTestSuite : public mitk::TestFixture { CPPUNIT_TEST_SUITE(mitkVigraRandomForestTestSuite ); + // MITK_TEST(Load_RandomForestBaseDataUsingIOUtil_shouldReturnTrue); + // MITK_TEST(Save_RandomForestBaseDataUsingIOUtil_shouldReturnTrue); + + // MITK_TEST(LoadWithMitkOptions_RandomForestBaseDataUsingIOUtil_shouldReturnTrue); + // MITK_TEST(SaveWithMitkOptions_RandomForestBaseDataUsingIOUtil_shouldReturnTrue); + MITK_TEST(TrainThreadedDecisionForest_MatlabDataSet_shouldReturnTrue); + MITK_TEST(PredictWeightedDecisionForest_SetWeightsToZero_shouldReturnTrue); MITK_TEST(TrainThreadedDecisionForest_BreastCancerDataSet_shouldReturnTrue); CPPUNIT_TEST_SUITE_END(); private: typedef Eigen::Matrix MatrixDoubleType; typedef Eigen::Matrix MatrixIntType; - Eigen::MatrixXd m_TrainingMatrixX; - Eigen::MatrixXi m_TrainingLabelMatrixY; - Eigen::MatrixXd m_TestXPredict; - Eigen::MatrixXi m_TestYPredict; + std::pair FeatureData_Cancer; + std::pair LabelData_Cancer; + + std::pair FeatureData_Matlab; + std::pair LabelData_Matlab; mitk::VigraRandomForestClassifier::Pointer classifier; public: + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ + + + void setUp() + { + FeatureData_Cancer = convertCSVToMatrix(GetTestDataFilePath("Classification/FeaturematrixBreastcancer.csv"),';',0.5,true); + LabelData_Cancer = convertCSVToMatrix(GetTestDataFilePath("Classification/LabelmatrixBreastcancer.csv"),';',0.5,false); + FeatureData_Matlab = convertCSVToMatrix(GetTestDataFilePath("Classification/FeaturematrixMatlab.csv"),';',0.5,true); + LabelData_Matlab = convertCSVToMatrix(GetTestDataFilePath("Classification/LabelmatrixMatlab.csv"),';',0.5,false); + classifier = mitk::VigraRandomForestClassifier::New(); + } + + void tearDown() + { + classifier = nullptr; + } + + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ + /* + Train the classifier with an exampledataset of mattlab. + Note: The included data are gaußan normaldistributed. + */ + void TrainThreadedDecisionForest_MatlabDataSet_shouldReturnTrue() + { + + auto & Features_Training = FeatureData_Matlab.first; + auto & Labels_Training = LabelData_Matlab.first; + + auto & Features_Testing = FeatureData_Matlab.second; + auto & Labels_Testing = LabelData_Matlab.second; + + /* Train the classifier, by giving trainingdataset for the labels and features. + The result in an colunmvector of the labels.*/ + classifier->Train(Features_Training,Labels_Training); + Eigen::MatrixXi classes = classifier->Predict(Features_Testing); + + /* Testing the matching between the calculated colunmvector and the result of the RandomForest */ + unsigned int testmatrix_rows = classes.rows(); + + unsigned int correctly_classified_rows = 0; + for(unsigned int i= 0; i < testmatrix_rows; i++){ + if(classes(i,0) == Labels_Testing(i,0)){ + correctly_classified_rows++; + } + } + + MITK_TEST_CONDITION(correctly_classified_rows == testmatrix_rows, "Matlab Data correctly classified"); + } + + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ + /* + Train the classifier with the dataset of breastcancer patients from the + LibSVM Libary + */ + void TrainThreadedDecisionForest_BreastCancerDataSet_shouldReturnTrue() + { + + auto & Features_Training = FeatureData_Cancer.first; + auto & Features_Testing = FeatureData_Cancer.second; + auto & Labels_Training = LabelData_Cancer.first; + auto & Labels_Testing = LabelData_Cancer.second; + + + /* Train the classifier, by giving trainingdataset for the labels and features. + The result in an colunmvector of the labels.*/ + classifier->Train(Features_Training,Labels_Training); + Eigen::MatrixXi classes = classifier->Predict(Features_Testing); + + /* Testing the matching between the calculated colunmvector and the result of the RandomForest */ + unsigned int maxrows = classes.rows(); + + bool isYPredictVector = false; + int count = 0; + + for(unsigned int i= 0; i < maxrows; i++){ + if(classes(i,0) == Labels_Testing(i,0)){ + isYPredictVector = true; + count++; + } + } + MITK_TEST_CONDITION(isIntervall(Labels_Testing,classes,98,99),"Testvalue of cancer data set is in range."); + } + + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ + + void PredictWeightedDecisionForest_SetWeightsToZero_shouldReturnTrue() + { + + auto & Features_Training = FeatureData_Matlab.first; + auto & Features_Testing = FeatureData_Matlab.second; + auto & Labels_Training = LabelData_Matlab.first; +// auto & Labels_Testing = LabelData_Matlab.second; + + classifier->Train(Features_Training,Labels_Training); + + // get weights type resize it and set all weights to zero + auto weights = classifier->GetTreeWeights(); + weights.resize(classifier->GetRandomForest().tree_count(),1); + weights.fill(0); + + classifier->SetTreeWeights(weights); + + // if all wieghts zero the missclassification rate mus be high + Eigen::MatrixXi classes = classifier->PredictWeighted(Features_Testing); + + /* Testing the matching between the calculated colunmvector and the result of the RandomForest */ + unsigned int maxrows = classes.rows(); + unsigned int count = 0; + + // check if all predictions are of class 1 + for(unsigned int i= 0; i < maxrows; i++) + if(classes(i,0) == 1) + count++; + + MITK_TEST_CONDITION( (count == maxrows) ,"Weighted prediction - weights applied (all weights = 0)."); + } + + + // ------------------------------------------------------------------------------------------------------ + // ------------------------------------------------------------------------------------------------------ /*Reading an file, which includes the trainingdataset and the testdataset, and convert the content of the file into an 2dim matrixpair. There are an delimiter, which separates the matrix into an trainingmatrix and testmatrix */ template std::pair,Eigen::Matrix >convertCSVToMatrix(const std::string &path, char delimiter,double range, bool isXMatrix) { typename itk::CSVArray2DFileReader::Pointer fr = itk::CSVArray2DFileReader::New(); fr->SetFileName(path); fr->SetFieldDelimiterCharacter(delimiter); fr->HasColumnHeadersOff(); fr->HasRowHeadersOff(); fr->Parse(); try{ fr->Update(); }catch(itk::ExceptionObject& ex){ cout << "Exception caught!" << std::endl; cout << ex << std::endl; } typename itk::CSVArray2DDataObject::Pointer p = fr->GetOutput(); unsigned int maxrowrange = p->GetMatrix().rows(); unsigned int c = p->GetMatrix().cols(); unsigned int percentRange = (unsigned int)(maxrowrange*range); if(isXMatrix == true){ Eigen::Matrix trainMatrixX(percentRange,c); Eigen::Matrix testMatrixXPredict(maxrowrange-percentRange,c); - for(int row = 0; row < percentRange; row++){ - for(int col = 0; col < c; col++){ + for(unsigned int row = 0; row < percentRange; row++){ + for(unsigned int col = 0; col < c; col++){ trainMatrixX(row,col) = p->GetData(row,col); } } - for(int row = percentRange; row < maxrowrange; row++){ - for(int col = 0; col < c; col++){ + for(unsigned int row = percentRange; row < maxrowrange; row++){ + for(unsigned int col = 0; col < c; col++){ testMatrixXPredict(row-percentRange,col) = p->GetData(row,col); } } return std::make_pair(trainMatrixX,testMatrixXPredict); } - else if(isXMatrix == false){ + else{ Eigen::Matrix trainLabelMatrixY(percentRange,c); Eigen::Matrix testMatrixYPredict(maxrowrange-percentRange,c); - for(int row = 0; row < percentRange; row++){ - for(int col = 0; col < c; col++){ + for(unsigned int row = 0; row < percentRange; row++){ + for(unsigned int col = 0; col < c; col++){ trainLabelMatrixY(row,col) = p->GetData(row,col); } } - for(int row = percentRange; row < maxrowrange; row++){ - for(int col = 0; col < c; col++){ + for(unsigned int row = percentRange; row < maxrowrange; row++){ + for(unsigned int col = 0; col < c; col++){ testMatrixYPredict(row-percentRange,col) = p->GetData(row,col); } } - return std::make_pair(trainLabelMatrixY,testMatrixYPredict); } + } /* Reading an csv-data and transfer the included datas into an matrix. */ template Eigen::Matrix readCsvData(const std::string &path, char delimiter) { typename itk::CSVArray2DFileReader::Pointer fr = itk::CSVArray2DFileReader::New(); fr->SetFileName(path); fr->SetFieldDelimiterCharacter(delimiter); fr->HasColumnHeadersOff(); fr->HasRowHeadersOff(); fr->Parse(); try{ fr->Update(); }catch(itk::ExceptionObject& ex){ cout << "Exception caught!" << std::endl; cout << ex << std::endl; } typename itk::CSVArray2DDataObject::Pointer p = fr->GetOutput(); unsigned int maxrowrange = p->GetMatrix().rows(); unsigned int maxcols = p->GetMatrix().cols(); Eigen::Matrix matrix(maxrowrange,maxcols); - for(int rows = 0; rows < maxrowrange; rows++){ - for(int cols = 0; cols < maxcols; cols++ ){ + for(unsigned int rows = 0; rows < maxrowrange; rows++){ + for(unsigned int cols = 0; cols < maxcols; cols++ ){ matrix(rows,cols) = p->GetData(rows,cols); } } return matrix; } /* Write the content of the array into an own csv-data in the following sequence: root.csv: 1 2 3 0 0 4 writen.csv: 1 1:2 2:3 3:0 4:0 5:4 */ template void writeMatrixToCsv(Eigen::Matrix paramMatrix,const std::string &path) { std::ofstream outputstream (path,std::ofstream::out); // 682 if(outputstream.is_open()){ for(int i = 0; i < paramMatrix.rows(); i++){ outputstream << paramMatrix(i,0); for(int j = 1; j < 11; j++){ outputstream << " " << j << ":" << paramMatrix(i,j); } outputstream << endl; } outputstream.close(); } else{ cout << "Unable to write into CSV" << endl; } } - /* - Train the classifier with an exampledataset of mattlab. - Note: The included data are gaußan normaldistributed. - */ - void TrainThreadedDecisionForest_MatlabDataSet_shouldReturnTrue() - { - /* Declarating an featurematrixdataset, the first matrix - of the matrixpair is the trainingmatrix and the second one is the testmatrix.*/ - std::pair matrixDouble; - matrixDouble = convertCSVToMatrix(GetTestDataFilePath("Classification/FeaturematrixMatlab.csv"),';',0.5,true); - m_TrainingMatrixX = matrixDouble.first; - m_TestXPredict = matrixDouble.second; - - /* The declaration of the labelmatrixdataset is equivalent to the declaration - of the featurematrixdataset.*/ - - std::pair matrixInt; - matrixInt = convertCSVToMatrix(GetTestDataFilePath("Classification/LabelmatrixMatlab.csv"),';',0.5,false); - m_TrainingLabelMatrixY = matrixInt.first; - m_TestYPredict = matrixInt.second; - classifier = mitk::VigraRandomForestClassifier::New(); - - /* Train the classifier, by giving trainingdataset for the labels and features. - The result in an colunmvector of the labels.*/ - classifier->Train(m_TrainingMatrixX,m_TrainingLabelMatrixY); - Eigen::MatrixXi classes = classifier->Predict(m_TestXPredict); - - /* Testing the matching between the calculated colunmvector and the result of the RandomForest */ - unsigned int maxrows = classes.rows(); - - bool isYPredictVector = false; - int count = 0; - - for(int i= 0; i < maxrows; i++){ - if(classes(i,0) == m_TestYPredict(i,0)){ - isYPredictVector = true; - count++; - } - } - MITK_INFO << 100*count/(double)(maxrows) << "%"; - MITK_TEST_CONDITION(isIntervall(m_TestYPredict,classes,97,99),"Testvalue is in range."); - } - // Method for intervalltesting template bool isIntervall(Eigen::Matrix expected, Eigen::Matrix actual, double lowrange, double toprange) { bool isInIntervall = false; int count = 0; unsigned int rowRange = expected.rows(); unsigned int colRange = expected.cols(); - for(int i = 0; i < rowRange; i++){ - for(int j = 0; j < colRange; j++){ + for(unsigned int i = 0; i < rowRange; i++){ + for(unsigned int j = 0; j < colRange; j++){ if(expected(i,j) == actual(i,j)){ count++; } } double valueOfMatch = 100*count/(double)(rowRange); if((lowrange <= valueOfMatch) && (toprange >= valueOfMatch)){ isInIntervall = true; } } return isInIntervall; } - /* - Train the classifier with the dataset of breastcancer patients from the - LibSVM Libary - */ - void TrainThreadedDecisionForest_BreastCancerDataSet_shouldReturnTrue() - { - /* Declarating an featurematrixdataset, the first matrix - of the matrixpair is the trainingmatrix and the second one is the testmatrix.*/ - std::pair matrixDouble; - matrixDouble = convertCSVToMatrix(GetTestDataFilePath("Classification/FeaturematrixBreastcancer.csv"),';',0.5,true); - m_TrainingMatrixX = matrixDouble.first; - m_TestXPredict = matrixDouble.second; - - /* The declaration of the labelmatrixdataset is equivalent to the declaration - of the featurematrixdataset.*/ - std::pair matrixInt; - matrixInt = convertCSVToMatrix(GetTestDataFilePath("Classification/LabelmatrixBreastcancer.csv"),';',0.5,false); - m_TrainingLabelMatrixY = matrixInt.first; - m_TestYPredict = matrixInt.second; - - classifier = mitk::VigraRandomForestClassifier::New(); - - /* Train the classifier, by giving trainingdataset for the labels and features. - The result in an colunmvector of the labels.*/ - classifier->Train(m_TrainingMatrixX,m_TrainingLabelMatrixY); - Eigen::MatrixXi classes = classifier->Predict(m_TestXPredict); - - /* Testing the matching between the calculated colunmvector and the result of the RandomForest */ - unsigned int maxrows = classes.rows(); - - bool isYPredictVector = false; - int count = 0; - for(int i= 0; i < maxrows; i++){ - if(classes(i,0) == m_TestYPredict(i,0)){ - isYPredictVector = true; - count++; - } - } - MITK_INFO << 100*count/(double)(maxrows) << "%"; - MITK_TEST_CONDITION(isIntervall(m_TestYPredict,classes,97,99),"Testvalue is in range."); - } - - void TestThreadedDecisionForest() - { - } }; -MITK_TEST_SUITE_REGISTRATION(mitkVigraRandomForest) \ No newline at end of file +MITK_TEST_SUITE_REGISTRATION(mitkVigraRandomForest)