diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp index 6e3cfec1b9..1fd8899327 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp @@ -1,624 +1,630 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_txx #define __itkMLBSTrackingFilter_txx #include #include #include +#include #include "itkMLBSTrackingFilter.h" #include #include #include #include #include "itkPointShell.h" #define _USE_MATH_DEFINES #include namespace itk { template< int ShOrder, int NumImageFeatures > MLBSTrackingFilter< ShOrder, NumImageFeatures > ::MLBSTrackingFilter() : m_PauseTracking(false) , m_AbortTracking(false) , m_FiberPolyData(NULL) , m_Points(NULL) , m_Cells(NULL) , m_AngularThreshold(0.7) , m_StepSize(0) , m_MaxLength(10000) , m_MinTractLength(20.0) , m_MaxTractLength(400.0) , m_SeedsPerVoxel(1) , m_RandomSampling(false) , m_SamplingDistance(-1) , m_DeflectionMod(1.0) , m_OnlyForwardSamples(true) , m_UseStopVotes(true) , m_NumberOfSamples(30) , m_NumPreviousDirections(1) , m_StoppingRegions(NULL) , m_SeedImage(NULL) , m_MaskImage(NULL) , m_AposterioriCurvCheck(false) , m_AvoidStop(true) , m_DemoMode(false) { this->SetNumberOfRequiredInputs(1); } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::BeforeThreadedGenerateData() { m_InputImage = const_cast(this->GetInput(0)); m_ForestHandler.InitForTracking(); m_FiberPolyData = PolyDataType::New(); m_Points = vtkSmartPointer< vtkPoints >::New(); m_Cells = vtkSmartPointer< vtkCellArray >::New(); std::vector< double > m_ImageSpacing; m_ImageSpacing.resize(3); m_ImageSpacing[0] = m_InputImage->GetSpacing()[0]; m_ImageSpacing[1] = m_InputImage->GetSpacing()[1]; m_ImageSpacing[2] = m_InputImage->GetSpacing()[2]; double minSpacing; if(m_ImageSpacing[0]GetNumberOfThreads(); i++) { PolyDataType poly = PolyDataType::New(); m_PolyDataContainer.push_back(poly); } if (m_StoppingRegions.IsNull()) { m_StoppingRegions = ItkUcharImgType::New(); m_StoppingRegions->SetSpacing( m_InputImage->GetSpacing() ); m_StoppingRegions->SetOrigin( m_InputImage->GetOrigin() ); m_StoppingRegions->SetDirection( m_InputImage->GetDirection() ); m_StoppingRegions->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_StoppingRegions->Allocate(); m_StoppingRegions->FillBuffer(0); } if (m_SeedImage.IsNull()) { m_SeedImage = ItkUcharImgType::New(); m_SeedImage->SetSpacing( m_InputImage->GetSpacing() ); m_SeedImage->SetOrigin( m_InputImage->GetOrigin() ); m_SeedImage->SetDirection( m_InputImage->GetDirection() ); m_SeedImage->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_SeedImage->Allocate(); m_SeedImage->FillBuffer(1); } if (m_MaskImage.IsNull()) { // initialize mask image m_MaskImage = ItkUcharImgType::New(); m_MaskImage->SetSpacing( m_InputImage->GetSpacing() ); m_MaskImage->SetOrigin( m_InputImage->GetOrigin() ); m_MaskImage->SetDirection( m_InputImage->GetDirection() ); m_MaskImage->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_MaskImage->Allocate(); m_MaskImage->FillBuffer(1); } else std::cout << "MLBSTrackingFilter: using mask image" << std::endl; if (m_AngularThreshold<0.0) m_AngularThreshold = 0.5*minSpacing; m_BuildFibersReady = 0; m_BuildFibersFinished = false; - m_Threads = 0; m_Tractogram.clear(); m_SamplingPointset = mitk::PointSet::New(); m_AlternativePointset = mitk::PointSet::New(); m_StartTime = std::chrono::system_clock::now(); std::cout << "MLBSTrackingFilter: Angular threshold: " << m_AngularThreshold << std::endl; std::cout << "MLBSTrackingFilter: Stepsize: " << m_StepSize << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Seeds per voxel: " << m_SeedsPerVoxel << std::endl; std::cout << "MLBSTrackingFilter: Max. sampling distance: " << m_SamplingDistance << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Deflection modifier: " << m_DeflectionMod << std::endl; std::cout << "MLBSTrackingFilter: Max. tract length: " << m_MaxTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Min. tract length: " << m_MinTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Use stop votes: " << m_UseStopVotes << std::endl; std::cout << "MLBSTrackingFilter: Only frontal samples: " << m_OnlyForwardSamples << std::endl; std::cout << "MLBSTrackingFilter: Starting streamline tracking using " << this->GetNumberOfThreads() << " threads." << std::endl; + this->SetNumberOfThreads(1); } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir) { // DON'T CHANGE DIR!! // vnl_matrix_fixed< double, 3, 3 > rot = m_FeatureImage->GetDirection().GetTranspose(); // dir = rot*dir; pos[0] += dir[0]*m_StepSize; pos[1] += dir[1]*m_StepSize; pos[2] += dir[2]*m_StepSize; } template< int ShOrder, int NumImageFeatures > bool MLBSTrackingFilter< ShOrder, NumImageFeatures > ::IsValidPosition(itk::Point &pos) { typename FeatureImageType::IndexType idx; m_InputImage->TransformPhysicalPointToIndex(pos, idx); if (!m_InputImage->GetLargestPossibleRegion().IsInside(idx) || m_MaskImage->GetPixel(idx)==0) return false; return true; } template< int ShOrder, int NumImageFeatures > double MLBSTrackingFilter< ShOrder, NumImageFeatures >::GetRandDouble(double min, double max) { return (double)(rand()%((int)(10000*(max-min))) + 10000*min)/10000; } template< int ShOrder, int NumImageFeatures > std::vector< vnl_vector_fixed > MLBSTrackingFilter< ShOrder, NumImageFeatures >::CreateDirections(int NPoints) { std::vector< double > theta; theta.resize(NPoints); std::vector< double > phi; phi.resize(NPoints); double C = sqrt(4*M_PI); phi[0] = 0.0; phi[NPoints-1] = 0.0; for(int i=0; i0 && i > pointshell; for(int i=0; i d; d[0] = cos(theta[i]) * cos(phi[i]); d[1] = cos(theta[i]) * sin(phi[i]); d[2] = sin(theta[i]); pointshell.push_back(d); } return pointshell; } template< int ShOrder, int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< ShOrder, NumImageFeatures >::GetNewDirection(itk::Point &pos, std::deque >& olddirs) { if (m_DemoMode) { m_SamplingPointset->Clear(); m_AlternativePointset->Clear(); } vnl_vector_fixed direction; direction.fill(0); ItkUcharImgType::IndexType idx; m_StoppingRegions->TransformPhysicalPointToIndex(pos, idx); if (m_StoppingRegions->GetLargestPossibleRegion().IsInside(idx) && m_StoppingRegions->GetPixel(idx)>0) return direction; if (m_MaskImage.IsNotNull() && ((m_MaskImage->GetLargestPossibleRegion().IsInside(idx) && m_MaskImage->GetPixel(idx)<=0) || !m_MaskImage->GetLargestPossibleRegion().IsInside(idx)) ) return direction; int candidates = 0; // number of directions with probability > 0 double w = 0; // weight of the direction predicted at each sampling point if (IsValidPosition(pos)) { direction = m_ForestHandler.Classify(pos, candidates, olddirs, m_AngularThreshold, w, m_MaskImage); // get direction proposal at current streamline position direction *= w; // HERE WE ARE WEIGHTING AGAIN EVEN THOUGH THE OUTPUT DIRECTIONS ARE ALREADY WEIGHTED!!! THE EFFECT OF THIS HAS YET TO BE EVALUATED. } vnl_vector_fixed olddir = olddirs.back(); std::vector< vnl_vector_fixed > probeVecs = CreateDirections(m_NumberOfSamples); itk::Point sample_pos; int alternatives = 1; int stop_votes = 0; int possible_stop_votes = 0; for (int i=0; i d; bool is_stop_voter = false; if (m_RandomSampling) { d[0] = GetRandDouble(); d[1] = GetRandDouble(); d[2] = GetRandDouble(); d.normalize(); d *= GetRandDouble(0,m_SamplingDistance); } else { d = probeVecs.at(i); double dot = dot_product(d, olddir); if (m_UseStopVotes && dot>0.7) { is_stop_voter = true; possible_stop_votes++; } else if (m_OnlyForwardSamples && dot<0) continue; d *= m_SamplingDistance; } sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_SamplingPointset->InsertPoint(i, sample_pos); candidates = 0; vnl_vector_fixed tempDir; tempDir.fill(0.0); if (IsValidPosition(sample_pos)) tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddirs, m_AngularThreshold, w, m_MaskImage); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) { direction += tempDir*w; } else if (m_AvoidStop && candidates==0 && olddir.magnitude()>0) // out of white matter { if (is_stop_voter) stop_votes++; double dot = dot_product(d, olddir); if (dot >= 0.0) // in front of plane defined by pos and olddir d = -d + 2*dot*olddir; // reflect else d = -d; // invert // look a bit further into the other direction sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_AlternativePointset->InsertPoint(alternatives, sample_pos); alternatives++; candidates = 0; vnl_vector_fixed tempDir; tempDir.fill(0.0); if (IsValidPosition(sample_pos)) tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddirs, m_AngularThreshold, w, m_MaskImage); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? { direction += d * m_DeflectionMod; // go into the direction of the white matter direction += tempDir*w; // go into the direction of the white matter direction at this location } } else if (is_stop_voter) stop_votes++; } if (direction.magnitude()>0.001 && (possible_stop_votes==0 || (float)stop_votes/possible_stop_votes<0.5) ) direction.normalize(); else direction.fill(0); return direction; } template< int ShOrder, int NumImageFeatures > double MLBSTrackingFilter< ShOrder, NumImageFeatures >::FollowStreamline(itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front) { vnl_vector_fixed zero_dir; zero_dir.fill(0.0); std::deque< vnl_vector_fixed > last_dirs; for (int i=0; ipush_front(pos); else fib->push_back(pos); if (m_AposterioriCurvCheck) { int curv = CheckCurvature(fib, front); // TODO: Move into classification ??? if (curv>0) { tractLength -= m_StepSize*curv; while (curv>0) { if (front) fib->pop_front(); else fib->pop_back(); curv--; } return tractLength; } } if (tractLength>m_MaxTractLength) return tractLength; } + +#pragma omp critical if (m_DemoMode) // CHECK: warum sind die samplingpunkte der streamline in der visualisierung immer einen schritt voras? { - m_Mutex.Lock(); m_BuildFibersReady++; m_Tractogram.push_back(*fib); BuildFibers(true); m_Stop = true; - m_Mutex.Unlock(); while (m_Stop){ } } last_dirs.push_back(dir); if (last_dirs.size()>m_NumPreviousDirections) last_dirs.pop_front(); dir = GetNewDirection(pos, last_dirs); while (m_PauseTracking){} if (dir.magnitude()<0.0001) return tractLength; } return tractLength; } template< int ShOrder, int NumImageFeatures > int MLBSTrackingFilter< ShOrder, NumImageFeatures >::CheckCurvature(FiberType* fib, bool front) { double m_Distance = 5; if (fib->size()<3) return 0; double dist = 0; std::vector< vnl_vector_fixed< float, 3 > > vectors; vnl_vector_fixed< float, 3 > meanV; meanV.fill(0); double dev = 0; if (front) { int c=0; while(distsize()-1) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c+1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==0) meanV += v; c++; } } else { int c=fib->size()-1; while(dist0) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c-1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==fib->size()-1) meanV += v; c--; } } meanV.normalize(); for (int c=0; c1.0) angle = 1.0; if (angle<-1.0) angle = -1.0; dev += acos(angle)*180/M_PI; } if (vectors.size()>0) dev /= vectors.size(); if (dev<30) return 0; else return vectors.size(); } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::ThreadedGenerateData(const InputImageRegionType ®ionForThread, ThreadIdType threadId) { - m_Mutex.Lock(); - m_Threads++; - m_Mutex.Unlock(); typedef ImageRegionConstIterator< ItkUcharImgType > MaskIteratorType; MaskIteratorType sit(m_SeedImage, regionForThread ); MaskIteratorType mit(m_MaskImage, regionForThread ); + std::vector< itk::Point > seedpoints; sit.GoToBegin(); mit.GoToBegin(); - itk::Point worldPos; while( !sit.IsAtEnd() ) { if (sit.Value()==0 || mit.Value()==0) { ++sit; ++mit; continue; } for (int s=0; s start; - unsigned int counter = 0; if (m_SeedsPerVoxel>1) { start[0] = index[0]+GetRandDouble(-0.5, 0.5); start[1] = index[1]+GetRandDouble(-0.5, 0.5); start[2] = index[2]+GetRandDouble(-0.5, 0.5); } else { start[0] = index[0]; start[1] = index[1]; start[2] = index[2]; } - // get staring position + itk::Point worldPos; m_SeedImage->TransformContinuousIndexToPhysicalPoint( start, worldPos ); + seedpoints.push_back(worldPos); + } + ++sit; + ++mit; + } - // get starting direction - int candidates = 0; - double prob = 0; - vnl_vector_fixed dir; dir.fill(0.0); - std::deque< vnl_vector_fixed > olddirs; - for (int i=0; i worldPos = seedpoints.at(i); + FiberType fib; + double tractLength = 0; + unsigned int counter = 0; + + // get starting direction + int candidates = 0; + double prob = 0; + vnl_vector_fixed dir; dir.fill(0.0); + std::deque< vnl_vector_fixed > olddirs; + for (int i=0; i void MLBSTrackingFilter< ShOrder, NumImageFeatures >::BuildFibers(bool check) { - if (m_BuildFibersReady::New(); vtkSmartPointer vNewLines = vtkSmartPointer::New(); vtkSmartPointer vNewPoints = vtkSmartPointer::New(); for (int i=0; i container = vtkSmartPointer::New(); FiberType fib = m_Tractogram.at(i); for (FiberType::iterator it = fib.begin(); it!=fib.end(); ++it) { vtkIdType id = vNewPoints->InsertNextPoint((*it).GetDataPointer()); container->GetPointIds()->InsertNextId(id); } vNewLines->InsertNextCell(container); } if (check) for (int i=0; iSetPoints(vNewPoints); m_FiberPolyData->SetLines(vNewLines); m_BuildFibersFinished = true; } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::AfterThreadedGenerateData() { MITK_INFO << "Generating polydata "; BuildFibers(false); MITK_INFO << "done"; m_EndTime = std::chrono::system_clock::now(); std::chrono::hours hh = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::minutes mm = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::seconds ss = std::chrono::duration_cast(m_EndTime - m_StartTime); mm %= 60; ss %= 60; MITK_INFO << "Tracking took " << hh.count() << "h, " << mm.count() << "m and " << ss.count() << "s"; } } #endif // __itkDiffusionQballPrincipleDirectionsImageFilter_txx diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h index c7cb965000..d553666b89 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h @@ -1,184 +1,182 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ /*=================================================================== This file is based heavily on a corresponding ITK filter. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_h_ #define __itkMLBSTrackingFilter_h_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // classification includes #include #include #include namespace itk{ /** * \brief Performes deterministic streamline tracking on the input tensor image. */ template< int ShOrder, int NumImageFeatures > class MLBSTrackingFilter : public ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > { public: typedef MLBSTrackingFilter Self; typedef SmartPointer Pointer; typedef SmartPointer ConstPointer; typedef ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > Superclass; typedef typename Superclass::InputImageType InputImageType; typedef typename Superclass::InputImageRegionType InputImageRegionType; typedef Image< Vector< float, NumImageFeatures > , 3 > FeatureImageType; /** Method for creation through the object factory. */ itkFactorylessNewMacro(Self) itkCloneMacro(Self) /** Runtime information support. */ itkTypeMacro(MLBSTrackingFilter, ImageToImageFilter) typedef itk::Image ItkUcharImgType; typedef itk::Image ItkDoubleImgType; typedef itk::Image ItkFloatImgType; typedef vtkSmartPointer< vtkPolyData > PolyDataType; typedef std::deque< itk::Point > FiberType; typedef std::vector< FiberType > BundleType; volatile bool m_PauseTracking; bool m_AbortTracking; bool m_BuildFibersFinished; int m_BuildFibersReady; volatile bool m_Stop; mitk::PointSet::Pointer m_SamplingPointset; mitk::PointSet::Pointer m_AlternativePointset; itkGetMacro( FiberPolyData, PolyDataType ) ///< Output fibers itkSetMacro( SeedImage, ItkUcharImgType::Pointer) ///< Seeds are only placed inside of this mask. itkSetMacro( MaskImage, ItkUcharImgType::Pointer) ///< Tracking is only performed inside of this mask image. itkSetMacro( SeedsPerVoxel, int) ///< One seed placed in the center of each voxel or multiple seeds randomly placed inside each voxel. itkSetMacro( StepSize, double) ///< Integration step size in mm itkSetMacro( MinTractLength, double ) ///< Shorter tracts are discarded. itkSetMacro( MaxTractLength, double ) ///< Streamline progression stops if tract is longer than specified. itkSetMacro( AngularThreshold, double ) ///< Probabilities for directions with larger angular deviation from previous direction is set to 0 itkSetMacro( SamplingDistance, double ) ///< Maximum distance of sampling points in mm itkSetMacro( UseStopVotes, bool ) ///< Frontal sampling points can vote for stopping the streamline even if the remaining sampling points keep pushing itkSetMacro( OnlyForwardSamples, bool ) ///< Don't use sampling points behind the current position in progression direction itkSetMacro( DeflectionMod, double ) ///< Deflection distance modifier itkSetMacro( StoppingRegions, ItkUcharImgType::Pointer) ///< Streamlines entering a stopping region will stop immediately itkSetMacro( DemoMode, bool ) itkSetMacro( NumberOfSamples, int ) ///< Number of neighborhood sampling points itkSetMacro( AposterioriCurvCheck, bool ) ///< Checks fiber curvature (angular deviation across 5mm) is larger than 30°. If yes, the streamline progression is stopped. itkSetMacro( AvoidStop, bool ) ///< Use additional sampling points to avoid premature streamline termination itkSetMacro( RandomSampling, bool ) ///< If true, the sampling points are distributed randomly around the current position, not sphericall in the specified sampling distance. itkSetMacro( NumPreviousDirections, int ) ///< How many "old" steps do we want to consider in our decision where to go next? void SetForestHandler( mitk::TrackingForestHandler fh ) ///< Stores random forest classifier and performs actual classification { m_ForestHandler = fh; } protected: MLBSTrackingFilter(); ~MLBSTrackingFilter() {} void CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir); ///< Calculate next integration step. double FollowStreamline(itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front); ///< Start streamline in one direction. bool IsValidPosition(itk::Point& pos); ///< Are we outside of the mask image? vnl_vector_fixed GetNewDirection(itk::Point& pos, std::deque< vnl_vector_fixed >& olddirs); ///< Determine new direction by sample voting at the current position taking the last progression direction into account. double GetRandDouble(double min=-1, double max=1); std::vector< vnl_vector_fixed > CreateDirections(int NPoints); void BeforeThreadedGenerateData() override; void ThreadedGenerateData( const InputImageRegionType &outputRegionForThread, ThreadIdType threadId) override; void AfterThreadedGenerateData() override; PolyDataType m_FiberPolyData; vtkSmartPointer m_Points; vtkSmartPointer m_Cells; BundleType m_Tractogram; double m_AngularThreshold; double m_StepSize; int m_MaxLength; double m_MinTractLength; double m_MaxTractLength; int m_SeedsPerVoxel; bool m_RandomSampling; double m_SamplingDistance; double m_DeflectionMod; bool m_OnlyForwardSamples; bool m_UseStopVotes; int m_NumberOfSamples; int m_NumPreviousDirections; - SimpleFastMutexLock m_Mutex; ItkUcharImgType::Pointer m_StoppingRegions; ItkUcharImgType::Pointer m_SeedImage; ItkUcharImgType::Pointer m_MaskImage; bool m_AposterioriCurvCheck; bool m_AvoidStop; - int m_Threads; bool m_DemoMode; void BuildFibers(bool check); int CheckCurvature(FiberType* fib, bool front); // decision forest mitk::TrackingForestHandler m_ForestHandler; typename InputImageType::Pointer m_InputImage; std::vector< PolyDataType > m_PolyDataContainer; std::chrono::time_point m_StartTime; std::chrono::time_point m_EndTime; private: }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkMLBSTrackingFilter.cpp" #endif #endif //__itkMLBSTrackingFilter_h_ diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp index 820dfb3081..4c3243ca0b 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp @@ -1,435 +1,434 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // Blueberry #include #include // Qmitk #include "QmitkMLBTView.h" #include "QmitkStdMultiWidget.h" // Qt #include #include #include #include #include #include #include #include #include #include #include #include #include #define _USE_MATH_DEFINES #include const std::string QmitkMLBTView::VIEW_ID = "org.mitk.views.mlbtview"; using namespace berry; QmitkMLBTView::QmitkMLBTView() : QmitkFunctionality() , m_Controls( 0 ) , m_MultiWidget( NULL ) { m_TrackingTimer = std::make_shared(this); m_LastLoadedForestName = "(none)"; } // Destructor QmitkMLBTView::~QmitkMLBTView() { } void QmitkMLBTView::CreateQtPartControl( QWidget *parent ) { // build up qt view, unless already done if ( !m_Controls ) { // create GUI widgets from the Qt Designer's .ui file m_Controls = new Ui::QmitkMLBTViewControls; m_Controls->setupUi( parent ); connect( m_Controls->m_StartTrainingButton, SIGNAL ( clicked() ), this, SLOT( StartTrainingThread() ) ); connect( &m_TrainingWatcher, SIGNAL ( finished() ), this, SLOT( OnTrainingThreadStop() ) ); connect( m_Controls->m_StartTrackingButton, SIGNAL ( clicked() ), this, SLOT( StartTrackingThread() ) ); connect( &m_TrackingWatcher, SIGNAL ( finished() ), this, SLOT( OnTrackingThreadStop() ) ); connect( m_Controls->m_SaveForestButton, SIGNAL ( clicked() ), this, SLOT( SaveForest() ) ); connect( m_Controls->m_LoadForestButton, SIGNAL ( clicked() ), this, SLOT( LoadForest() ) ); connect( m_TrackingTimer.get(), SIGNAL(timeout()), this, SLOT(BuildFibers()) ); connect( m_Controls->m_TimerIntervalBox, SIGNAL(valueChanged(int)), this, SLOT( ChangeTimerInterval(int) )); connect( m_Controls->m_DemoModeBox, SIGNAL(stateChanged(int)), this, SLOT( ToggleDemoMode(int) )); connect( m_Controls->m_PauseTrackingButton, SIGNAL ( clicked() ), this, SLOT( PauseTracking() ) ); connect( m_Controls->m_AbortTrackingButton, SIGNAL ( clicked() ), this, SLOT( AbortTracking() ) ); connect( m_Controls->m_AddTwButton, SIGNAL ( clicked() ), this, SLOT( AddTrainingWidget() ) ); connect( m_Controls->m_RemoveTwButton, SIGNAL ( clicked() ), this, SLOT( RemoveTrainingWidget() ) ); m_Controls->m_TrackingMaskImageBox->SetDataStorage(this->GetDataStorage()); m_Controls->m_TrackingSeedImageBox->SetDataStorage(this->GetDataStorage()); m_Controls->m_TrackingStopImageBox->SetDataStorage(this->GetDataStorage()); m_Controls->m_TrackingRawImageBox->SetDataStorage(this->GetDataStorage()); mitk::NodePredicateIsDWI::Pointer isDiffusionImage = mitk::NodePredicateIsDWI::New(); mitk::TNodePredicateDataType::Pointer isMitkImage = mitk::TNodePredicateDataType::New(); mitk::NodePredicateNot::Pointer noDiffusionImage = mitk::NodePredicateNot::New(isDiffusionImage); mitk::NodePredicateAnd::Pointer finalPredicate = mitk::NodePredicateAnd::New(isMitkImage, noDiffusionImage); mitk::NodePredicateProperty::Pointer isBinaryPredicate = mitk::NodePredicateProperty::New("binary", mitk::BoolProperty::New(true)); finalPredicate = mitk::NodePredicateAnd::New(finalPredicate, isBinaryPredicate); m_Controls->m_TrackingMaskImageBox->SetPredicate(finalPredicate); m_Controls->m_TrackingSeedImageBox->SetPredicate(finalPredicate); m_Controls->m_TrackingStopImageBox->SetPredicate(finalPredicate); m_Controls->m_TrackingRawImageBox->SetPredicate(isDiffusionImage); m_Controls->m_TrackingMaskImageBox->SetZeroEntryText("--"); m_Controls->m_TrackingSeedImageBox->SetZeroEntryText("--"); m_Controls->m_TrackingStopImageBox->SetZeroEntryText("--"); AddTrainingWidget(); UpdateGui(); } } void QmitkMLBTView::AddTrainingWidget() { std::shared_ptr tw = std::make_shared(); tw->SetDataStorage(this->GetDataStorage()); m_Controls->m_TwFrame->layout()->addWidget(tw.get()); m_TrainingWidgets.push_back(tw); } void QmitkMLBTView::RemoveTrainingWidget() { if(m_TrainingWidgets.size()>1) { m_TrainingWidgets.back().reset(); m_TrainingWidgets.pop_back(); } } void QmitkMLBTView::UpdateGui() { if (m_ForestHandler.IsForestValid()) { std::string label_text="Random forest available: "+m_LastLoadedForestName; m_Controls->statusLabel->setText( QString(label_text.c_str()) ); m_Controls->m_SaveForestButton->setEnabled(true); m_Controls->m_StartTrackingButton->setEnabled(true); } else { m_Controls->statusLabel->setText("Please load or train random forest!"); m_Controls->m_SaveForestButton->setEnabled(false); m_Controls->m_StartTrackingButton->setEnabled(false); } } void QmitkMLBTView::AbortTracking() { if (tracker.IsNotNull()) { tracker->m_AbortTracking = true; } } void QmitkMLBTView::PauseTracking() { if (tracker.IsNotNull()) { tracker->m_PauseTracking = !tracker->m_PauseTracking; } } void QmitkMLBTView::ChangeTimerInterval(int value) { m_TrackingTimer->setInterval(value); } void QmitkMLBTView::ToggleDemoMode(int state) { if (tracker.IsNotNull()) { tracker->SetDemoMode(m_Controls->m_DemoModeBox->isChecked()); tracker->m_Stop = false; } } void QmitkMLBTView::BuildFibers() { if (m_Controls->m_DemoModeBox->isChecked() && tracker.IsNotNull() && tracker->m_BuildFibersFinished) { vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); outFib->SetFiberColors(255,255,255); m_TractogramNode->SetData(outFib); m_SamplingPointsNode->SetData(tracker->m_SamplingPointset); m_AlternativePointsNode->SetData(tracker->m_AlternativePointset); mitk::RenderingManager::GetInstance()->RequestUpdateAll(); tracker->m_BuildFibersFinished = false; tracker->m_BuildFibersReady = 0; tracker->m_Stop = false; } } void QmitkMLBTView::LoadForest() { QString filename = QFileDialog::getOpenFileName(0, tr("Load Forest"), QDir::currentPath(), tr("HDF5 random forest file (*.rf)") ); if(filename.isEmpty() || filename.isNull()) return; m_ForestHandler.LoadForest( filename.toStdString() ); QFileInfo fi( filename ); m_LastLoadedForestName = QString( fi.baseName() + "." + fi.completeSuffix() ).toStdString(); UpdateGui(); } void QmitkMLBTView::StartTrackingThread() { m_TractogramNode = mitk::DataNode::New(); m_TractogramNode->SetName("MLBT Result"); m_TractogramNode->SetProperty("Fiber2DSliceThickness", mitk::FloatProperty::New(20)); m_TractogramNode->SetProperty("Fiber2DfadeEFX", mitk::BoolProperty::New(false)); m_TractogramNode->SetProperty("LineWidth", mitk::IntProperty::New(2)); m_TractogramNode->SetProperty("color",mitk::ColorProperty::New(0, 1, 1)); this->GetDataStorage()->Add(m_TractogramNode); m_SamplingPointsNode = mitk::DataNode::New(); m_SamplingPointsNode->SetName("SamplingPoints"); m_SamplingPointsNode->SetProperty("pointsize", mitk::FloatProperty::New(0.2)); m_SamplingPointsNode->SetProperty("color", mitk::ColorProperty::New(1,1,1)); mitk::PointSetShapeProperty::Pointer bla = mitk::PointSetShapeProperty::New(); bla->SetValue(8); m_SamplingPointsNode->SetProperty("Pointset.2D.shape", bla); m_SamplingPointsNode->SetProperty("Pointset.2D.distance to plane", mitk::FloatProperty::New(1.5)); m_SamplingPointsNode->SetProperty("point 2D size", mitk::FloatProperty::New(0.1)); m_SamplingPointsNode->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_SamplingPointsNode); m_AlternativePointsNode = mitk::DataNode::New(); m_AlternativePointsNode->SetName("AlternativePoints"); m_AlternativePointsNode->SetProperty("pointsize", mitk::FloatProperty::New(0.2)); m_AlternativePointsNode->SetProperty("color", mitk::ColorProperty::New(1,0,0)); m_AlternativePointsNode->SetProperty("Pointset.2D.shape", bla); m_AlternativePointsNode->SetProperty("Pointset.2D.distance to plane", mitk::FloatProperty::New(1.5)); m_AlternativePointsNode->SetProperty("point 2D size", mitk::FloatProperty::New(0.1)); m_AlternativePointsNode->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_AlternativePointsNode); QFuture future = QtConcurrent::run( this, &QmitkMLBTView::StartTracking ); m_TrackingWatcher.setFuture(future); m_TrackingThreadIsRunning = true; m_Controls->m_StartTrackingButton->setEnabled(false); m_TrackingTimer->start(m_Controls->m_TimerIntervalBox->value()); } void QmitkMLBTView::OnTrackingThreadStop() { m_TrackingThreadIsRunning = false; m_Controls->m_StartTrackingButton->setEnabled(true); vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); m_TractogramNode->SetData(outFib); m_TractogramNode->SetProperty("Fiber2DSliceThickness", mitk::FloatProperty::New(1)); if (m_Controls->m_DemoModeBox->isChecked()) { m_SamplingPointsNode->SetData(tracker->m_SamplingPointset); m_AlternativePointsNode->SetData(tracker->m_AlternativePointset); } tracker = NULL; m_TrackingTimer->stop(); mitk::RenderingManager::GetInstance()->RequestUpdateAll(); } void QmitkMLBTView::StartTracking() { if ( m_Controls->m_TrackingRawImageBox->GetSelectedNode().IsNull() || !m_ForestHandler.IsForestValid()) return; mitk::Image::Pointer dwi = dynamic_cast(m_Controls->m_TrackingRawImageBox->GetSelectedNode()->GetData()); m_ForestHandler.AddDwi(dwi); // int numThread = itk::MultiThreader::GetGlobalDefaultNumberOfThreads(); tracker = TrackerType::New(); tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi) ); tracker->SetDemoMode(m_Controls->m_DemoModeBox->isChecked()); if (m_Controls->m_DemoModeBox->isChecked()) tracker->SetNumberOfThreads(1); if (m_Controls->m_TrackingMaskImageBox->GetSelectedNode().IsNotNull()) { mitk::Image::Pointer mask = dynamic_cast(m_Controls->m_TrackingMaskImageBox->GetSelectedNode()->GetData()); ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); mitk::CastToItkImage(mask, itkMask); tracker->SetMaskImage(itkMask); } if (m_Controls->m_TrackingSeedImageBox->GetSelectedNode().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(m_Controls->m_TrackingSeedImageBox->GetSelectedNode()->GetData()); ItkUcharImgType::Pointer itkImg = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkImg); tracker->SetSeedImage(itkImg); } if (m_Controls->m_TrackingStopImageBox->GetSelectedNode().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(m_Controls->m_TrackingStopImageBox->GetSelectedNode()->GetData()); ItkUcharImgType::Pointer itkImg = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkImg); tracker->SetStoppingRegions(itkImg); } tracker->SetSeedsPerVoxel(m_Controls->m_NumberOfSeedsBox->value()); tracker->SetStepSize(m_Controls->m_TrackingStepSizeBox->value()); tracker->SetMinTractLength(m_Controls->m_MinLengthBox->value()); tracker->SetMaxTractLength(m_Controls->m_MaxLengthBox->value()); tracker->SetAposterioriCurvCheck(m_Controls->m_Curvcheck2->isChecked()); tracker->SetNumberOfSamples(m_Controls->m_NumSamplesBox->value()); tracker->SetAvoidStop(m_Controls->m_AvoidStop->isChecked()); tracker->SetForestHandler(m_ForestHandler); tracker->SetSamplingDistance(m_Controls->m_SamplingDistanceBox->value()); tracker->SetDeflectionMod(m_Controls->m_DeflectionModBox->value()); tracker->SetRandomSampling(m_Controls->m_RandomSampling->isChecked()); tracker->SetUseStopVotes(m_Controls->m_UseStopVotes->isChecked()); tracker->SetOnlyForwardSamples(m_Controls->m_OnlyForwardSamples->isChecked()); tracker->SetNumPreviousDirections(m_Controls->m_NumPrevDirs->value()); - tracker->SetNumberOfThreads(1); tracker->Update(); } void QmitkMLBTView::SaveForest() { if (!m_ForestHandler.IsForestValid()) { UpdateGui(); return; } QString filename = QFileDialog::getSaveFileName(0, tr("Save Forest"), QDir::currentPath()+"/forest.rf", tr("HDF5 random forest file (*.rf)") ); if(filename.isEmpty() || filename.isNull()) return; if(!filename.endsWith(".rf")) filename += ".rf"; m_ForestHandler.SaveForest( filename.toStdString() ); } void QmitkMLBTView::StartTrainingThread() { if (!this->IsTrainingInputValid()) { QMessageBox::warning(nullptr, "Training aborted", "Training could not be started. Not all necessary datasets were selected."); return; } QFuture future = QtConcurrent::run( this, &QmitkMLBTView::StartTraining ); m_TrainingWatcher.setFuture(future); m_Controls->m_StartTrainingButton->setEnabled(false); m_Controls->m_SaveForestButton->setEnabled(false); m_Controls->m_LoadForestButton->setEnabled(false); } void QmitkMLBTView::OnTrainingThreadStop() { m_Controls->m_StartTrainingButton->setEnabled(true); m_Controls->m_SaveForestButton->setEnabled(true); m_Controls->m_LoadForestButton->setEnabled(true); UpdateGui(); } void QmitkMLBTView::StartTraining() { std::vector< mitk::Image::Pointer > m_SelectedDiffImages; std::vector< mitk::FiberBundle::Pointer > m_SelectedFB; std::vector< ItkUcharImgType::Pointer > m_MaskImages; std::vector< ItkUcharImgType::Pointer > m_WhiteMatterImages; for (auto w : m_TrainingWidgets) { m_SelectedDiffImages.push_back(dynamic_cast(w->GetImage()->GetData())); m_SelectedFB.push_back(dynamic_cast(w->GetFibers()->GetData())); if (w->GetMask().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(w->GetMask()->GetData()); ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkMask); m_MaskImages.push_back(itkMask); } else m_MaskImages.push_back(nullptr); if (w->GetWhiteMatter().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(w->GetWhiteMatter()->GetData()); ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkMask); m_WhiteMatterImages.push_back(itkMask); } else m_WhiteMatterImages.push_back(nullptr); } m_ForestHandler.SetDwis(m_SelectedDiffImages); m_ForestHandler.SetTractograms(m_SelectedFB); m_ForestHandler.SetMaskImages(m_MaskImages); m_ForestHandler.SetWhiteMatterImages(m_WhiteMatterImages); m_ForestHandler.SetNumTrees(m_Controls->m_NumTreesBox->value()); m_ForestHandler.SetMaxTreeDepth(m_Controls->m_MaxDepthBox->value()); m_ForestHandler.SetGrayMatterSamplesPerVoxel(m_Controls->m_GmSamplingBox->value()); m_ForestHandler.SetSampleFraction(m_Controls->m_SampleFractionBox->value()); m_ForestHandler.SetStepSize(m_Controls->m_TrainingStepSizeBox->value()); m_ForestHandler.SetNumPreviousDirections(m_Controls->m_NumPrevDirs->value()); m_ForestHandler.StartTraining(); } void QmitkMLBTView::StdMultiWidgetAvailable (QmitkStdMultiWidget &stdMultiWidget) { m_MultiWidget = &stdMultiWidget; } void QmitkMLBTView::StdMultiWidgetNotAvailable() { m_MultiWidget = NULL; } void QmitkMLBTView::Activated() { } bool QmitkMLBTView::IsTrainingInputValid(void) const { for (auto widget : m_TrainingWidgets) { if (widget->GetImage().IsNull() || widget->GetFibers().IsNull()) { return false; } } return true; }