diff --git a/Modules/Classification/ClassifierVigraRandomForest/include/mitkThresholdSplit.h b/Modules/Classification/ClassifierVigraRandomForest/include/mitkThresholdSplit.h index d6c3085d34..643e26e1ec 100644 --- a/Modules/Classification/ClassifierVigraRandomForest/include/mitkThresholdSplit.h +++ b/Modules/Classification/ClassifierVigraRandomForest/include/mitkThresholdSplit.h @@ -1,82 +1,81 @@ #ifndef mitkThresholdSplit_h #define mitkThresholdSplit_h #include #include namespace mitk { template class ThresholdSplit: public vigra::SplitBase { public: ThresholdSplit(); // ThresholdSplit(const ThresholdSplit & other); void SetFeatureCalculator(TFeatureCalculator processor); TFeatureCalculator GetFeatureCalculator() const; void SetCalculatingFeature(bool calculate); bool IsCalculatingFeature() const; void UsePointBasedWeights(bool weightsOn); bool IsUsingPointBasedWeights() const; void UseRandomSplit(bool split) {m_UseRandomSplit = split;} bool IsUsingRandomSplit() const { return m_UseRandomSplit; } void SetPrecision(double value); double GetPrecision() const; void SetMaximumTreeDepth(int value); - virtual int GetMaximumTreeDepth(); - int GetConstMaximumTreeDepth() const; + virtual int GetMaximumTreeDepth() const; void SetWeights(vigra::MultiArrayView<2, double> weights); vigra::MultiArrayView<2, double> GetWeights() const; // From vigra::ThresholdSplit double minGini() const; int bestSplitColumn() const; double bestSplitThreshold() const; template void set_external_parameters(vigra::ProblemSpec const & in); template int findBestSplit(vigra::MultiArrayView<2, T, C> features, vigra::MultiArrayView<2, T2, C2> labels, Region & region, vigra::ArrayVector& childRegions, Random & randint); double region_gini_; private: // From vigra::ThresholdSplit typedef vigra::SplitBase SB; // splitter parameters (used by copy constructor) bool m_CalculatingFeature; bool m_UseWeights; bool m_UseRandomSplit; double m_Precision; int m_MaximumTreeDepth; TFeatureCalculator m_FeatureCalculator; vigra::MultiArrayView<2, double> m_Weights; // variabels to work with vigra::ArrayVector splitColumns; TColumnDecisionFunctor bgfunc; vigra::ArrayVector min_gini_; vigra::ArrayVector min_indices_; vigra::ArrayVector min_thresholds_; int bestSplitIndex; }; } #include <../src/Splitter/mitkThresholdSplit.cpp> #endif //mitkThresholdSplit_h diff --git a/Modules/Classification/ClassifierVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/ClassifierVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp index f3e7fd75d1..00b8e046c9 100644 --- a/Modules/Classification/ClassifierVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/ClassifierVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp @@ -1,433 +1,433 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes #include #include #include #include #include // Vigra includes #include #include // ITK include #include #include typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; struct mitk::VigraRandomForestClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; struct mitk::VigraRandomForestClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, const DefaultSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; const DefaultSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; }; struct mitk::VigraRandomForestClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, vigra::MultiArrayView<2, double> refProb) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), m_Probabilities(refProb) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; }; mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() :m_Parameter(nullptr) { this->ConvertParameter(); } mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() { } bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() { return true; } bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() { return true; } void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.onlineLearn(X,Y,0,true); } void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); DefaultSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::auto_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; } Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); std::auto_ptr data; data.reset( new PredictionData(m_RandomForest,X,Y,P)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process DefaultSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); - splitter.SetMaximumTreeDepth(data->m_Splitter.GetConstMaximumTreeDepth()); + splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } return ITK_THREAD_RETURN_VALUE; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the 0th thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) numberOfRowsToCalculate += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); //ITK_THREAD_RETURN_TYPE value = NULL; return ITK_THREAD_RETURN_VALUE; } void mitk::VigraRandomForestClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults if(!this->GetPropertyList()->Get("classifier.vigra-rf.usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; if(!this->GetPropertyList()->Get("classifier.vigra-rf.userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; if(!this->GetPropertyList()->Get("classifier.vigra-rf.treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; if(!this->GetPropertyList()->Get("classifier.vigra-rf.treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; if(!this->GetPropertyList()->Get("classifier.vigra-rf.minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; if(!this->GetPropertyList()->Get("classifier.vigra-rf.precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; if(!this->GetPropertyList()->Get("classifier.vigra-rf.lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet // if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults if(!this->GetPropertyList()->Get("classifier.vigra-rf.usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) str << "classifier.vigra-rf.usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else str << "classifier.vigra-rf.usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.userandomsplit",this->m_Parameter->UseRandomSplit)) str << "classifier.vigra-rf.userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else str << "classifier.vigra-rf.userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.treedepth",this->m_Parameter->TreeDepth)) str << "classifier.vigra-rf.treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else str << "classifier.vigra-rf.treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) str << "classifier.vigra-rf.minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else str << "classifier.vigra-rf.minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.precision",this->m_Parameter->Precision)) str << "classifier.vigra-rf.precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else str << "classifier.vigra-rf.precision\t\t" << this->m_Parameter->Precision << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplespertree",this->m_Parameter->SamplesPerTree)) str << "classifier.vigra-rf.samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else str << "classifier.vigra-rf.samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->SampleWithReplacement)) str << "classifier.vigra-rf.samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else str << "classifier.vigra-rf.samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.treecount",this->m_Parameter->TreeCount)) str << "classifier.vigra-rf.treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else str << "classifier.vigra-rf.treecount\t\t" << this->m_Parameter->TreeCount << "\n"; if(!this->GetPropertyList()->Get("classifier.vigra-rf.lambda",this->m_Parameter->WeightLambda)) str << "classifier.vigra-rf.lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else str << "classifier.vigra-rf.lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; // if(!this->GetPropertyList()->Get("classifier.vigra-rf.samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); this->GetPropertyList()->SetBoolProperty("classifier.vigra-rf.usepointbasedweight",val); } void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) { this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.treedepth",val); } void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) { this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.minimalsplitnodesize",val); } void mitk::VigraRandomForestClassifier::SetPrecision(double val) { this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.precision",val); } void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) { this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.samplespertree",val); } void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) { this->GetPropertyList()->SetBoolProperty("classifier.vigra-rf.samplewithreplacement",val); } void mitk::VigraRandomForestClassifier::SetTreeCount(int val) { this->GetPropertyList()->SetIntProperty("classifier.vigra-rf.treecount",val); } void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) { this->GetPropertyList()->SetDoubleProperty("classifier.vigra-rf.lambda",val); } void mitk::VigraRandomForestClassifier::SetNthItems(const char * val, unsigned int idx) { std::stringstream ss; ss << "classifier.vigra-rf.item." << idx; this->GetPropertyList()->SetStringProperty(ss.str().c_str(),val); } void mitk::VigraRandomForestClassifier::SetItemList(std::vector list) { for(unsigned int i = 0 ; i < list.size(); ++i) this->SetNthItems(list[i].c_str(),i); } std::vector mitk::VigraRandomForestClassifier::GetItemList() { std::vector result; for(unsigned int idx = 0 ; idx < 100; idx++) { std::stringstream ss; ss << "classifier.vigra-rf.item." << idx; if(this->GetPropertyList()->GetProperty(ss.str().c_str())) { std::string s; this->GetPropertyList()->GetStringProperty(ss.str().c_str(),s); result.push_back(s); } } return result; } diff --git a/Modules/Classification/ClassifierVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp b/Modules/Classification/ClassifierVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp index 2e8562928a..86a2f635a8 100644 --- a/Modules/Classification/ClassifierVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp +++ b/Modules/Classification/ClassifierVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp @@ -1,306 +1,298 @@ #ifndef mitkThresholdSplit_cpp #define mitkThresholdSplit_cpp #include template mitk::ThresholdSplit::ThresholdSplit() : m_CalculatingFeature(false), m_UseWeights(false), m_UseRandomSplit(false), m_Precision(0.0), m_MaximumTreeDepth(1000) { } //template //mitk::ThresholdSplit::ThresholdSplit(const ThresholdSplit & /*other*/)/*: // m_CalculatingFeature(other.IsCalculatingFeature()), // m_UseWeights(other.IsUsingPointBasedWeights()), // m_UseRandomSplit(other.IsUsingRandomSplit()), // m_Precision(other.GetPrecision()), // m_MaximumTreeDepth(other.GetMaximumTreeDepth()), // m_FeatureCalculator(other.GetFeatureCalculator()), // m_Weights(other.GetWeights())*/ //{ //} template void mitk::ThresholdSplit::SetFeatureCalculator(TFeatureCalculator processor) { m_FeatureCalculator = processor; } template TFeatureCalculator mitk::ThresholdSplit::GetFeatureCalculator() const { return m_FeatureCalculator; } template void mitk::ThresholdSplit::SetCalculatingFeature(bool calculate) { m_CalculatingFeature = calculate; } template bool mitk::ThresholdSplit::IsCalculatingFeature() const { return m_CalculatingFeature; } template void mitk::ThresholdSplit::UsePointBasedWeights(bool weightsOn) { m_UseWeights = weightsOn; bgfunc.UsePointWeights(weightsOn); } template bool mitk::ThresholdSplit::IsUsingPointBasedWeights() const { return m_UseWeights; } template void mitk::ThresholdSplit::SetPrecision(double value) { m_Precision = value; } template double mitk::ThresholdSplit::GetPrecision() const { return m_Precision; } template void mitk::ThresholdSplit::SetMaximumTreeDepth(int value) { m_MaximumTreeDepth = value; } template int -mitk::ThresholdSplit::GetMaximumTreeDepth() +mitk::ThresholdSplit::GetMaximumTreeDepth() const { return m_MaximumTreeDepth; } -template -int -mitk::ThresholdSplit::GetConstMaximumTreeDepth() const -{ - return m_MaximumTreeDepth; -} - - template void mitk::ThresholdSplit::SetWeights(vigra::MultiArrayView<2, double> weights) { m_Weights = weights; bgfunc.UsePointWeights(m_UseWeights); bgfunc.SetPointWeights(weights); } template vigra::MultiArrayView<2, double> mitk::ThresholdSplit::GetWeights() const { return m_Weights; } template double mitk::ThresholdSplit::minGini() const { return min_gini_[bestSplitIndex]; } template int mitk::ThresholdSplit::bestSplitColumn() const { return splitColumns[bestSplitIndex]; } template double mitk::ThresholdSplit::bestSplitThreshold() const { return min_thresholds_[bestSplitIndex]; } template template void mitk::ThresholdSplit::set_external_parameters(vigra::ProblemSpec const & in) { SB::set_external_parameters(in); bgfunc.set_external_parameters( SB::ext_param_); int featureCount_ = SB::ext_param_.column_count_; splitColumns.resize(featureCount_); for(int k=0; k template int mitk::ThresholdSplit::findBestSplit(vigra::MultiArrayView<2, T, C> features, vigra::MultiArrayView<2, T2, C2> labels, Region & region, vigra::ArrayVector& childRegions, Random & randint) { typedef typename Region::IndexIterator IndexIteratorType; if (m_CalculatingFeature) { // Do some very fance stuff here!! // This is not so simple as it might look! We need to // remember which feature has been used to be able to // use it for testing again!! // There, no Splitting class is used!! } bgfunc.UsePointWeights(m_UseWeights); bgfunc.UseRandomSplit(m_UseRandomSplit); vigra::detail::Correction::exec(region, labels); // Create initial class count. for(std::size_t i = 0; i < region.classCounts_.size(); ++i) { region.classCounts_[i] = 0; } double regionSum = 0; for (typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter) { double probability = 1.0; if (m_UseWeights) { probability = m_Weights(*iter, 0); } region.classCounts_[labels(*iter,0)] += probability; regionSum += probability; } region.classCountsIsValid = true; vigra::ArrayVector vec; // Is pure region? region_gini_ = bgfunc.LossOfRegion(labels, region.begin(), region.end(), region.classCounts()); if (region_gini_ <= m_Precision * regionSum) // Necessary to fix wrong calculation of Gini-Index { return this->makeTerminalNode(features, labels, region, randint); } // Randomize the order of columns for (int i = 0; i < SB::ext_param_.actual_mtry_; ++i) { std::swap(splitColumns[i], splitColumns[i+ randint(features.shape(1) - i)]); } // find the split with the best evaluation value bestSplitIndex = 0; double currentMiniGini = region_gini_; int numberOfTrials = features.shape(1); for (int k = 0; k < numberOfTrials; ++k) { bgfunc(columnVector(features, splitColumns[k]), labels, region.begin(), region.end(), region.classCounts()); min_gini_[k] = bgfunc.GetMinimumLoss(); min_indices_[k] = bgfunc.GetMinimumIndex(); min_thresholds_[k] = bgfunc.GetMinimumThreshold(); // removed classifier test section, because not necessary if (bgfunc.GetMinimumLoss() < currentMiniGini) { currentMiniGini = bgfunc.GetMinimumLoss(); childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0]; childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1]; childRegions[0].classCountsIsValid = true; childRegions[1].classCountsIsValid = true; bestSplitIndex = k; numberOfTrials = SB::ext_param_.actual_mtry_; } } //If only a small improvement, make terminal node... if(vigra::closeAtTolerance(currentMiniGini, region_gini_)) { return this->makeTerminalNode(features, labels, region, randint); } vigra::Node node(SB::t_data, SB::p_data); SB::node_ = node; node.threshold() = min_thresholds_[bestSplitIndex]; node.column() = splitColumns[bestSplitIndex]; // partition the range according to the best dimension vigra::SortSamplesByDimensions > sorter(features, node.column(), node.threshold()); IndexIteratorType bestSplit = std::partition(region.begin(), region.end(), sorter); // Save the ranges of the child stack entries. childRegions[0].setRange( region.begin() , bestSplit ); childRegions[0].rule = region.rule; childRegions[0].rule.push_back(std::make_pair(1, 1.0)); childRegions[1].setRange( bestSplit , region.end() ); childRegions[1].rule = region.rule; childRegions[1].rule.push_back(std::make_pair(1, 1.0)); return vigra::i_ThresholdNode; return 0; } //template //static void UpdateRegionCounts(TRegion & region, TRegionIterator begin, TRegionIterator end, TLabelHolder labels, TWeightsHolder weights) //{ // if(std::accumulate(region.classCounts().begin(), // region.classCounts().end(), 0.0) != region.size()) // { // RandomForestClassCounter< LabelT, // ArrayVector > // counter(labels, region.classCounts()); // std::for_each( region.begin(), region.end(), counter); // region.classCountsIsValid = true; // } //} // //template //static void exec(Region & region, LabelT & labels) //{ // if(std::accumulate(region.classCounts().begin(), // region.classCounts().end(), 0.0) != region.size()) // { // RandomForestClassCounter< LabelT, // ArrayVector > // counter(labels, region.classCounts()); // std::for_each( region.begin(), region.end(), counter); // region.classCountsIsValid = true; // } //} #endif //mitkThresholdSplit_cpp