diff --git a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp index 1b1ea5856f..150dd5417f 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp @@ -1,592 +1,592 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes #include #include #include #include #include // Vigra includes #include #include // ITK include #include #include #include typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; struct mitk::VigraRandomForestClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; struct mitk::VigraRandomForestClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, const DefaultSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel, const Parameter parameter) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel), m_Parameter(parameter) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; const DefaultSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; Parameter m_Parameter; }; struct mitk::VigraRandomForestClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, vigra::MultiArrayView<2, double> refProb, vigra::MultiArrayView<2, double> refTreeWeights) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), m_Probabilities(refProb), m_TreeWeights(refTreeWeights) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; vigra::MultiArrayView<2, double> m_TreeWeights; }; mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() :m_Parameter(nullptr) { itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter); this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command ); } mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() { } bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() { return true; } bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() { return true; } void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.onlineLearn(X,Y,0,true); } void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); DefaultSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.set_options().use_stratification(m_Parameter->Stratification); m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement); m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree); m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize); m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::unique_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; // Set Tree Weights to default m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1); m_TreeWeights.fill(1.0); } Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); // If no weights provided if(m_TreeWeights.rows() != m_RandomForest.tree_count()) { m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); m_TreeWeights.fill(1); } vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); std::unique_ptr data; data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); // If no weights provided if(m_TreeWeights.rows() != m_RandomForest.tree_count()) { m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); m_TreeWeights.fill(1); } vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); std::unique_ptr data; data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictWeightedCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights) { m_TreeWeights = weights; } Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const { return m_TreeWeights; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process DefaultSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.set_options().use_stratification(data->m_Parameter.Stratification); rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement); rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree); rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } - return nullptr; + return 0; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the last thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) { end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; } vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); - return nullptr; + return 0; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the last thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) { end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; } vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } VigraPredictWeighted(data, split_features,split_labels,split_probability); - return nullptr; + return 0; } void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P) { int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_; //#pragma omp parallel for for(int row=0; row < vigra::rowCount(X); ++row) { vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); vigra::ArrayVector::const_iterator weights; //totalWeight == totalVoteCount! double totalWeight = 0.0; //Let each tree classify... for(int k=0; km_RandomForest.options_.tree_count_; ++k) { //get weights predicted by single tree weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow); double numberOfLeafObservations = (*(weights-1)); //update votecount. for(int l=0; lm_RandomForest.ext_param_.class_count_; ++l) { // Either the original weights are taken or the tree is additional weighted by the number of Observations in the leaf node. double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted)); cur_w = cur_w * data->m_TreeWeights(k,0); P(row, l) += (int)cur_w; //every weight in totalWeight. totalWeight += cur_w; } } //Normalise votes in each row by total VoteCount (totalWeight for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l) { P(row, l) /= vigra::detail::RequiresExplicitCast::cast(totalWeight); } int erg; int maxCol = 0; for (int col=0;colm_RandomForest.class_count();++col) { if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol)) maxCol = col; } data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg); Y(row,0) = erg; } } void mitk::VigraRandomForestClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter"; if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else str << "precision\t\t" << this->m_Parameter->Precision << "\n"; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n"; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val); } void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) { this->GetPropertyList()->SetIntProperty("treedepth",val); } void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) { this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val); } void mitk::VigraRandomForestClassifier::SetPrecision(double val) { this->GetPropertyList()->SetDoubleProperty("precision",val); } void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) { this->GetPropertyList()->SetDoubleProperty("samplespertree",val); } void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) { this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val); } void mitk::VigraRandomForestClassifier::SetTreeCount(int val) { this->GetPropertyList()->SetIntProperty("treecount",val); } void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) { this->GetPropertyList()->SetDoubleProperty("lambda",val); } void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) { m_TreeWeights(treeId,0) = weight; } void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest & rf) { this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth); this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_); this->SetTreeCount(rf.options().tree_count_); this->SetSamplesPerTree(rf.options().training_set_proportion_); this->UseSampleWithReplacement(rf.options().sample_with_replacement_); this->m_RandomForest = rf; } const vigra::RandomForest & mitk::VigraRandomForestClassifier::GetRandomForest() const { return this->m_RandomForest; } diff --git a/Modules/IGT/TrackingDevices/mitkOptitrackTrackingDevice.cpp b/Modules/IGT/TrackingDevices/mitkOptitrackTrackingDevice.cpp index 6a8e80307e..5901534b3b 100644 --- a/Modules/IGT/TrackingDevices/mitkOptitrackTrackingDevice.cpp +++ b/Modules/IGT/TrackingDevices/mitkOptitrackTrackingDevice.cpp @@ -1,770 +1,770 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkOptitrackTrackingDevice.h" #include #ifdef MITK_USE_OPTITRACK_TRACKER /** * \brief API library header for Optitrack Systems */ #include //======================================================= // Static method: IsDeviceInstalled //======================================================= bool mitk::OptitrackTrackingDevice::IsDeviceInstalled() { return true; } //======================================================= // Constructor //======================================================= mitk::OptitrackTrackingDevice::OptitrackTrackingDevice() : mitk::TrackingDevice(), m_initialized(false) { // Set the MultiThread and Mutex this->m_MultiThreader = itk::MultiThreader::New(); this->m_ToolsMutex = itk::FastMutexLock::New(); //Set the mitk device information SetData(mitk::DeviceDataNPOptitrack); //Clear List of tools this->m_AllTools.clear(); } //======================================================= // Destructor //======================================================= mitk::OptitrackTrackingDevice::~OptitrackTrackingDevice() { MITK_DEBUG << "Deleting OptitrackTrackingDevice"; int result; // If device is in Tracking mode, stop the Tracking firts if (this->GetState() == mitk::TrackingDevice::Tracking) { MITK_DEBUG << "OptitrackTrackingDevice in Tracking State -> Stopping Tracking"; result = this->StopTracking(); if(result == NPRESULT_SUCCESS){ MITK_INFO << "OptitrackTrackingDevice Stopped"; } else { MITK_INFO << "Error during Stopping"; mitkThrowException(mitk::IGTException) << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(result); } } // If device is Ready, Close the connection to device and release the memory if (this->GetState() == mitk::TrackingDevice::Ready) { MITK_DEBUG << "OptitrackTrackingDevice in Ready State -> Closing the Connection"; result = this->CloseConnection(); if(result) { MITK_INFO << "OptitrackTrackingDevice Connection closed"; } else { MITK_DEBUG << "Error during Closing Connection"; mitkThrowException(mitk::IGTException) << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(result); } } // Set the device off m_initialized = false; // Change State to Setup this->SetState(mitk::TrackingDevice::Setup); MITK_DEBUG <<"OptitrackTrackingDevice deleted successfully"; } //======================================================= // OpenConnection //======================================================= bool mitk::OptitrackTrackingDevice::OpenConnection() { // Not initialize the system twice. if(!m_initialized) { MITK_DEBUG << "Initialize Optitrack Tracking System"; if( this->InitializeCameras() ) { m_initialized = true; // Set the initialized variable to true this->SetState(mitk::TrackingDevice::Ready); if(this->m_calibrationPath.empty()){ MITK_INFO << "Numer of connected cameras = " << TT_CameraCount(); MITK_WARN << "Attention: No calibration File defined !!"; return m_initialized; } else { this->LoadCalibration(); } } else { m_initialized = false; // Set the initialized variable to false this->SetState(mitk::TrackingDevice::Setup); // Set the State to Setup MITK_INFO << "Device initialization failed. Device is still in setup state"; mitkThrowException(mitk::IGTException) << "Device initialization failed. Device is still in setup state"; } } //this->LoadCalibration(); return m_initialized; } //======================================================= // InitializeCameras //======================================================= bool mitk::OptitrackTrackingDevice::InitializeCameras() { MITK_DEBUG << "Initialize Optitrack"; int result; result = TT_Initialize(); // Initialize the cameras if(result == NPRESULT_SUCCESS) { MITK_DEBUG << "Optitrack Initialization Succeed"; return true; } else { MITK_DEBUG << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(result); // If not succeed after OPTITRACK_ATTEMPTS times launch exception MITK_INFO << "Optitrack Tracking System cannot be initialized \n" << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(result); mitkThrowException(mitk::IGTException) << "Optitrack Tracking System cannot be initialized \n" << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(result); return false; } } //======================================================= // LoadCalibration //======================================================= bool mitk::OptitrackTrackingDevice::LoadCalibration() { MITK_DEBUG << "Loading System Calibration"; int resultLoadCalibration; // Check the file path if(this->m_calibrationPath.empty()){ MITK_INFO << "Calibration Path is empty"; mitkThrowException(mitk::IGTException) << "Calibration Path is empty"; return false; } // Once the system is ready and Initialized , a calibration file is loaded. if(this->m_initialized) { for( int i=OPTITRACK_ATTEMPTS; i>0; i--) { resultLoadCalibration = TT_LoadCalibration(this->m_calibrationPath.c_str()); if(resultLoadCalibration != NPRESULT_SUCCESS) { MITK_DEBUG << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(resultLoadCalibration); MITK_DEBUG << "Trying again..."; } else { MITK_DEBUG << "Calibration file has been loaded successfully"; return true; } } MITK_INFO << "System cannot load a calibration file"; mitkThrowException(mitk::IGTException) << "System cannot load a calibration file"; } else { MITK_INFO << "System is not ready for load a calibration file because it has not been initialized yet"; mitkThrowException(mitk::IGTException) << "System is not ready for load a calibration file because it has not been initialized yet"; return false; } // Never reach this point return false; } //======================================================= // SetCalibrationPath //======================================================= void mitk::OptitrackTrackingDevice::SetCalibrationPath(std::string calibrationPath){ MITK_DEBUG << "SetcalibrationPath"; MITK_DEBUG << calibrationPath; // Check the file path if(calibrationPath.empty()) { MITK_INFO << "Calibration Path is empty"; //mitkThrowException(mitk::IGTException) << "Calibration Path is empty"; return; } this->m_calibrationPath = calibrationPath; MITK_INFO << "Calibration Path has been updated to: " << this->m_calibrationPath; return; } //======================================================= // CloseConnection //======================================================= bool mitk::OptitrackTrackingDevice::CloseConnection() { MITK_DEBUG << "CloseConnection"; int resultStop, resultShutdown; if(m_initialized) // Close connection if the System was initialized first { if(this->GetState() == mitk::TrackingDevice::Tracking) { MITK_DEBUG << "Device state: Tracking -> Stoping the Tracking"; resultStop = this->StopTracking(); //Stop tracking on close } this->SetState(mitk::OptitrackTrackingDevice::Setup); for( int i=OPTITRACK_ATTEMPTS; i>0; i--) { TT_ClearTrackableList(); resultShutdown = TT_Shutdown(); if(resultShutdown == NPRESULT_SUCCESS) { MITK_DEBUG << "System has been Shutdown Correctly"; Sleep(2000); return true; } else { MITK_DEBUG << "System cannot ShutDown now. Trying again..."; } } MITK_INFO << "System cannot ShutDown now"; mitkThrowException(mitk::IGTException) << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(resultShutdown); return false; } else { MITK_INFO << "System has not been initialized. Close connection cannot be done"; mitkThrowException(mitk::IGTException) << "System has not been initialized. Close connection cannot be done"; return false; } return false; } //======================================================= // StartTracking //======================================================= bool mitk::OptitrackTrackingDevice::StartTracking() { MITK_DEBUG << "StartTracking"; bool resultIsTrackableTracked; if (this->GetState() != mitk::TrackingDevice::Ready) { MITK_INFO << "System is not in State Ready -> Cannot StartTracking"; mitkThrowException(mitk::IGTException) << "System is not in State Ready -> Cannot StartTracking"; return false; } this->SetState(mitk::TrackingDevice::Tracking); // Change the m_StopTracking Variable to false this->m_StopTrackingMutex->Lock(); this->m_StopTracking = false; this->m_StopTrackingMutex->Unlock(); /****************************************************************************** ############################################################################### TODO: check the timestamp from the Optitrack API ############################################################################### ******************************************************************************/ mitk::IGTTimeStamp::GetInstance()->Start(this); // Launch multiThreader using the Function ThreadStartTracking that executes the TrackTools() method m_ThreadID = m_MultiThreader->SpawnThread(this->ThreadStartTracking, this); // start a new thread that executes the TrackTools() method // Information for the user if(GetToolCount() == 0) MITK_INFO << "No tools are defined"; for ( int i = 0; i < GetToolCount(); ++i) // use mutexed methods to access tool container { resultIsTrackableTracked = TT_IsTrackableTracked(i); if(resultIsTrackableTracked) { MITK_DEBUG << "Trackable " << i << " is inside the Tracking Volume and it is Tracked"; } else { MITK_DEBUG << "Trackable " << i << " is not been tracked. Check if it is inside the Tracking volume"; } } return true; } //======================================================= // StopTracking //======================================================= bool mitk::OptitrackTrackingDevice::StopTracking() { MITK_DEBUG << "StopTracking"; if (this->GetState() == mitk::TrackingDevice::Tracking) // Only if the object is in the correct state { //Change the StopTracking value m_StopTrackingMutex->Lock(); // m_StopTracking is used by two threads, so we have to ensure correct thread handling m_StopTrackingMutex->Unlock(); this->SetState(mitk::TrackingDevice::Ready); } else { m_TrackingFinishedMutex->Unlock(); MITK_INFO << "System is not in State Tracking -> Cannot StopTracking"; mitkThrowException(mitk::IGTException) << "System is not in State Tracking -> Cannot StopTracking"; return false; } /****************************************************************************** ############################################################################### TODO: check the timestamp from the Optitrack API ############################################################################### ******************************************************************************/ mitk::IGTTimeStamp::GetInstance()->Stop(this); m_TrackingFinishedMutex->Unlock(); return true; } //======================================================= // ThreadStartTracking //======================================================= ITK_THREAD_RETURN_TYPE mitk::OptitrackTrackingDevice::ThreadStartTracking(void* pInfoStruct) { MITK_DEBUG << "ThreadStartTracking"; /* extract this pointer from Thread Info structure */ struct itk::MultiThreader::ThreadInfoStruct * pInfo = (struct itk::MultiThreader::ThreadInfoStruct*)pInfoStruct; if (pInfo == nullptr) { return ITK_THREAD_RETURN_VALUE; } if (pInfo->UserData == nullptr) { return ITK_THREAD_RETURN_VALUE; } OptitrackTrackingDevice *trackingDevice = static_cast(pInfo->UserData); if (trackingDevice != nullptr) { // Call the TrackTools function in this thread trackingDevice->TrackTools(); } else { mitkThrowException(mitk::IGTException) << "In ThreadStartTracking(): trackingDevice is nullptr"; } trackingDevice->m_ThreadID = -1; // reset thread ID because we end the thread here return ITK_THREAD_RETURN_VALUE; } //======================================================= // GetOptitrackTool //======================================================= mitk::OptitrackTrackingTool* mitk::OptitrackTrackingDevice::GetOptitrackTool( unsigned int toolNumber) const { MITK_DEBUG << "ThreadStartTracking"; OptitrackTrackingTool* t = nullptr; MutexLockHolder toolsMutexLockHolder(*m_ToolsMutex); // lock and unlock the mutex if(toolNumber < m_AllTools.size()) { t = m_AllTools.at(toolNumber); } else { MITK_INFO << "The tool numbered " << toolNumber << " does not exist"; mitkThrowException(mitk::IGTException) << "The tool numbered " << toolNumber << " does not exist"; } return t; } //======================================================= // GetToolCount //======================================================= unsigned int mitk::OptitrackTrackingDevice::GetToolCount() const { MITK_DEBUG << "GetToolCount"; MutexLockHolder lock(*m_ToolsMutex); // lock and unlock the mutex return ( int)(this->m_AllTools.size()); } //======================================================= // TrackTools //======================================================= void mitk::OptitrackTrackingDevice::TrackTools() { MITK_DEBUG << "TrackTools"; Point3D position; ScalarType t = 0.0; try { bool localStopTracking; // Because m_StopTracking is used by two threads, access has to be guarded by a mutex. To minimize thread locking, a local copy is used here this->m_StopTrackingMutex->Lock(); // update the local copy of m_StopTracking localStopTracking = this->m_StopTracking; /* lock the TrackingFinishedMutex to signal that the execution rights are now transfered to the tracking thread */ if (!localStopTracking) { m_TrackingFinishedMutex->Lock(); } this->m_StopTrackingMutex->Unlock(); while ((this->GetState() == mitk::TrackingDevice::Tracking) && (localStopTracking == false)) { // For each Tracked Tool update the position and orientation for ( int i = 0; i < GetToolCount(); ++i) // use mutexed methods to access tool container { OptitrackTrackingTool* currentTool = this->GetOptitrackTool(i); if(currentTool != nullptr) { currentTool->updateTool(); MITK_DEBUG << "Tool number " << i << " updated position"; } else { MITK_DEBUG << "Get data from tool number " << i << " failed"; mitkThrowException(mitk::IGTException) << "Get data from tool number " << i << " failed"; } } /* Update the local copy of m_StopTracking */ this->m_StopTrackingMutex->Lock(); localStopTracking = m_StopTracking; this->m_StopTrackingMutex->Unlock(); Sleep(OPTITRACK_FRAME_RATE); } // tracking ends if we pass this line m_TrackingFinishedMutex->Unlock(); // transfer control back to main thread } catch(...) { m_TrackingFinishedMutex->Unlock(); this->StopTracking(); mitkThrowException(mitk::IGTException) << "Error while trying to track tools. Thread stopped."; } } //======================================================= // SetCameraParams //======================================================= bool mitk::OptitrackTrackingDevice::SetCameraParams(int exposure, int threshold , int intensity, int videoType ) { MITK_DEBUG << "SetCameraParams"; if(this->m_initialized) { int num_cams = 0; int resultUpdate; bool resultSetCameraSettings; for( int i=OPTITRACK_ATTEMPTS; i>0; i--) { resultUpdate = TT_Update(); // Get Update for the Optitrack API if(resultUpdate == NPRESULT_SUCCESS) { MITK_DEBUG << "Update Succeed"; num_cams = TT_CameraCount(); i = 0; } else { MITK_DEBUG << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(resultUpdate); MITK_DEBUG << "Trying again..."; Sleep(30); } } // If no cameras are connected if(num_cams == 0) { MITK_DEBUG << "No cameras are connected to the device"; return false; mitkThrowException(mitk::IGTException) << "No cameras are connected to the device"; } for(int cam = 0; cam < num_cams; cam++) // for all connected cameras { for( int i=OPTITRACK_ATTEMPTS; i>0; i--) { resultUpdate = TT_Update(); // Get Update for the Optitrack API if(resultUpdate == NPRESULT_SUCCESS) { MITK_DEBUG << "Update Succeed for camera number " << cam; resultSetCameraSettings = TT_SetCameraSettings(cam,videoType,exposure,threshold,intensity); if(resultSetCameraSettings) { MITK_INFO << "Camera # "< System is not ready to perform the Camera Parameters Setting"; mitkThrowException(mitk::IGTException) << "System is not Initialized -> System is not ready to perform the Camera Parameters Setting"; return false; } return true; } //======================================================= // GetTool //======================================================= mitk::TrackingTool* mitk::OptitrackTrackingDevice::GetTool(unsigned int toolNumber) const { return static_cast(GetOptitrackTool(toolNumber)); } //======================================================= // AddToolByFileName //======================================================= bool mitk::OptitrackTrackingDevice::AddToolByDefinitionFile(std::string fileName) { bool resultSetToolByFileName; if(m_initialized) { OptitrackTrackingTool::Pointer t = OptitrackTrackingTool::New(); resultSetToolByFileName= t->SetToolByFileName(fileName); if(resultSetToolByFileName) { this->m_AllTools.push_back(t); MITK_INFO << "Added tool "<GetToolName()<< ". Tool vector size is now: "< Cannot Add tools"; mitkThrowException(mitk::IGTException) << "System is not Initialized -> Cannot Add tools"; return false; } } //======================================================= // IF Optitrack is not installed set functions to warnings //======================================================= #else //======================================================= // Static method: IsDeviceInstalled //======================================================= bool mitk::OptitrackTrackingDevice::IsDeviceInstalled() { return false; } //======================================================= // Constructor //======================================================= mitk::OptitrackTrackingDevice::OptitrackTrackingDevice() : mitk::TrackingDevice(), m_initialized(false) { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); } //======================================================= // Destructor //======================================================= mitk::OptitrackTrackingDevice::~OptitrackTrackingDevice() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); } //======================================================= // OpenConnection //======================================================= bool mitk::OptitrackTrackingDevice::OpenConnection() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // InitializeCameras //======================================================= bool mitk::OptitrackTrackingDevice::InitializeCameras() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // LoadCalibration //======================================================= bool mitk::OptitrackTrackingDevice::LoadCalibration() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // SetcalibrationPath //======================================================= void mitk::OptitrackTrackingDevice::SetCalibrationPath(std::string /*calibrationPath*/) { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); } //======================================================= // CloseConnection //======================================================= bool mitk::OptitrackTrackingDevice::CloseConnection() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // StartTracking //======================================================= bool mitk::OptitrackTrackingDevice::StartTracking() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // StopTracking //======================================================= bool mitk::OptitrackTrackingDevice::StopTracking() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // ThreadStartTracking //======================================================= ITK_THREAD_RETURN_TYPE mitk::OptitrackTrackingDevice::ThreadStartTracking(void*) { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); - return nullptr; + return 0; } //======================================================= // GetOptitrackTool //======================================================= mitk::OptitrackTrackingTool* mitk::OptitrackTrackingDevice::GetOptitrackTool(unsigned int) const { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return nullptr; } //======================================================= // GetToolCount //======================================================= unsigned int mitk::OptitrackTrackingDevice::GetToolCount() const { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return 0; } //======================================================= // TrackTools //======================================================= void mitk::OptitrackTrackingDevice::TrackTools() { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); } //======================================================= // SetCameraParams //======================================================= bool mitk::OptitrackTrackingDevice::SetCameraParams(int, int, int, int) { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } //======================================================= // GetTool //======================================================= mitk::TrackingTool* mitk::OptitrackTrackingDevice::GetTool(unsigned int) const { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return nullptr; } //======================================================= // AddToolByFileName //======================================================= bool mitk::OptitrackTrackingDevice::AddToolByDefinitionFile(std::string) { MITK_WARN("IGT") << "Error: " << mitk::OptitrackErrorMessages::GetOptitrackErrorMessage(100); return false; } #endif