diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp index 112d08a3de..8c34946e69 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp @@ -1,594 +1,605 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_txx #define __itkMLBSTrackingFilter_txx #include #include #include #include "itkMLBSTrackingFilter.h" #include #include #include #include #define _USE_MATH_DEFINES #include namespace itk { template< int ShOrder, int NumImageFeatures > MLBSTrackingFilter< ShOrder, NumImageFeatures > ::MLBSTrackingFilter() : m_PauseTracking(false) , m_AbortTracking(false) , m_FiberPolyData(NULL) , m_Points(NULL) , m_Cells(NULL) , m_AngularThreshold(0.7) , m_StepSize(0) , m_MaxLength(10000) , m_MinTractLength(20.0) , m_MaxTractLength(400.0) , m_SeedsPerVoxel(1) , m_RandomSampling(false) , m_SamplingDistance(-1) - , m_NumberOfSamples(50) + , m_DeflectionMod(1.0) + , m_OnlyForwardSamples(true) + , m_UseStopVotes(true) , m_StoppingRegions(NULL) , m_SeedImage(NULL) , m_MaskImage(NULL) , m_AposterioriCurvCheck(false) , m_RemoveWmEndFibers(false) , m_AvoidStop(true) , m_DemoMode(false) { this->SetNumberOfRequiredInputs(1); } -template< int ShOrder, int NumImageFeatures > -double MLBSTrackingFilter< ShOrder, NumImageFeatures > -::RoundToNearest(double num) { - return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); -} - template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::BeforeThreadedGenerateData() { m_InputImage = const_cast(this->GetInput(0)); m_ForestHandler.InitForTracking(); m_FiberPolyData = PolyDataType::New(); m_Points = vtkSmartPointer< vtkPoints >::New(); m_Cells = vtkSmartPointer< vtkCellArray >::New(); std::vector< double > m_ImageSpacing; m_ImageSpacing.resize(3); m_ImageSpacing[0] = m_InputImage->GetSpacing()[0]; m_ImageSpacing[1] = m_InputImage->GetSpacing()[1]; m_ImageSpacing[2] = m_InputImage->GetSpacing()[2]; double minSpacing; if(m_ImageSpacing[0]GetNumberOfThreads(); i++) { PolyDataType poly = PolyDataType::New(); m_PolyDataContainer.push_back(poly); } if (m_StoppingRegions.IsNull()) { m_StoppingRegions = ItkUcharImgType::New(); m_StoppingRegions->SetSpacing( m_InputImage->GetSpacing() ); m_StoppingRegions->SetOrigin( m_InputImage->GetOrigin() ); m_StoppingRegions->SetDirection( m_InputImage->GetDirection() ); m_StoppingRegions->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_StoppingRegions->Allocate(); m_StoppingRegions->FillBuffer(0); } if (m_SeedImage.IsNull()) { m_SeedImage = ItkUcharImgType::New(); m_SeedImage->SetSpacing( m_InputImage->GetSpacing() ); m_SeedImage->SetOrigin( m_InputImage->GetOrigin() ); m_SeedImage->SetDirection( m_InputImage->GetDirection() ); m_SeedImage->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_SeedImage->Allocate(); m_SeedImage->FillBuffer(1); } if (m_MaskImage.IsNull()) { // initialize mask image m_MaskImage = ItkUcharImgType::New(); m_MaskImage->SetSpacing( m_InputImage->GetSpacing() ); m_MaskImage->SetOrigin( m_InputImage->GetOrigin() ); m_MaskImage->SetDirection( m_InputImage->GetDirection() ); m_MaskImage->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_MaskImage->Allocate(); m_MaskImage->FillBuffer(1); } else std::cout << "MLBSTrackingFilter: using mask image" << std::endl; if (m_AngularThreshold<0.0) m_AngularThreshold = 0.5*minSpacing; m_BuildFibersReady = 0; m_BuildFibersFinished = false; m_Threads = 0; m_Tractogram.clear(); m_SamplingPointset = mitk::PointSet::New(); m_AlternativePointset = mitk::PointSet::New(); m_StartTime = std::chrono::system_clock::now(); std::cout << "MLBSTrackingFilter: Angular threshold: " << m_AngularThreshold << std::endl; std::cout << "MLBSTrackingFilter: Stepsize: " << m_StepSize << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Seeds per voxel: " << m_SeedsPerVoxel << std::endl; std::cout << "MLBSTrackingFilter: Max. sampling distance: " << m_SamplingDistance << " mm" << std::endl; - std::cout << "MLBSTrackingFilter: Number of samples: " << m_NumberOfSamples << std::endl; + std::cout << "MLBSTrackingFilter: Deflection modifier: " << m_DeflectionMod << std::endl; std::cout << "MLBSTrackingFilter: Max. tract length: " << m_MaxTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Min. tract length: " << m_MinTractLength << " mm" << std::endl; + std::cout << "MLBSTrackingFilter: Use stop votes: " << m_UseStopVotes << std::endl; + std::cout << "MLBSTrackingFilter: Only frontal samples: " << m_OnlyForwardSamples << std::endl; std::cout << "MLBSTrackingFilter: Starting streamline tracking using " << this->GetNumberOfThreads() << " threads." << std::endl; } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir) { // vnl_matrix_fixed< double, 3, 3 > rot = m_FeatureImage->GetDirection().GetTranspose(); // dir = rot*dir; dir *= m_StepSize; pos[0] += dir[0]; pos[1] += dir[1]; pos[2] += dir[2]; } template< int ShOrder, int NumImageFeatures > bool MLBSTrackingFilter< ShOrder, NumImageFeatures > ::IsValidPosition(itk::Point &pos) { typename FeatureImageType::IndexType idx; m_InputImage->TransformPhysicalPointToIndex(pos, idx); if (!m_InputImage->GetLargestPossibleRegion().IsInside(idx) || m_MaskImage->GetPixel(idx)==0) return false; return true; } template< int ShOrder, int NumImageFeatures > double MLBSTrackingFilter< ShOrder, NumImageFeatures >::GetRandDouble(double min, double max) { return (double)(rand()%((int)(10000*(max-min))) + 10000*min)/10000; } template< int ShOrder, int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< ShOrder, NumImageFeatures >::GetNewDirection(itk::Point &pos, vnl_vector_fixed& olddir) { if (m_DemoMode) { m_SamplingPointset->Clear(); m_AlternativePointset->Clear(); } vnl_vector_fixed direction; direction.fill(0); ItkUcharImgType::IndexType idx; m_StoppingRegions->TransformPhysicalPointToIndex(pos, idx); if (m_StoppingRegions->GetLargestPossibleRegion().IsInside(idx) && m_StoppingRegions->GetPixel(idx)>0) return direction; if (m_MaskImage.IsNotNull() && ((m_MaskImage->GetLargestPossibleRegion().IsInside(idx) && m_MaskImage->GetPixel(idx)<=0) || !m_MaskImage->GetLargestPossibleRegion().IsInside(idx)) ) return direction; if (olddir.magnitude()>0) olddir.normalize(); int candidates = 0; // number of directions with probability > 0 double w = 0; // weight of the direction predicted at each sampling point if (IsValidPosition(pos)) { direction = m_ForestHandler.Classify(pos, candidates, olddir, m_AngularThreshold, w, m_MaskImage); // get direction proposal at current streamline position direction *= w; // HERE WE ARE WEIGHTING AGAIN EVEN THOUGH THE OUTPUT DIRECTIONS ARE ALREADY WEIGHTED!!! THE EFFECT OF THIS HAS YET TO BE EVALUATED. } - itk::OrientationDistributionFunction< double, 50 > probeVecs; + itk::OrientationDistributionFunction< double, 30 > probeVecs; itk::Point sample_pos; int alternatives = 1; - for (int i=0; i d; - + bool is_stop_voter = false; if (m_RandomSampling) { d[0] = GetRandDouble(); d[1] = GetRandDouble(); d[2] = GetRandDouble(); d.normalize(); d *= GetRandDouble(0,m_SamplingDistance); } else { + double dot = dot_product(probeVecs.GetDirection(i), olddir); + if (m_UseStopVotes && dot>0.7) + { + is_stop_voter = true; + possible_stop_votes++; + } + else if (m_OnlyForwardSamples && dot<0) + continue; d = probeVecs.GetDirection(i)*m_SamplingDistance; } sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_SamplingPointset->InsertPoint(i, sample_pos); candidates = 0; vnl_vector_fixed tempDir; tempDir.fill(0.0); if (IsValidPosition(sample_pos)) tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, w, m_MaskImage); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) { direction += tempDir*w; } else if (m_AvoidStop && candidates==0 && olddir.magnitude()>0) // out of white matter { - double dot = dot_product(d, olddir); - if (dot >= 0.0) // in front of plane defined by pos and olddir - d = -d + 2*dot*olddir; // reflect - else - d = -d; // invert - - // look a bit further into the other direction - sample_pos[0] = pos[0] + d[0]; - sample_pos[1] = pos[1] + d[1]; - sample_pos[2] = pos[2] + d[2]; - if(m_DemoMode) - m_AlternativePointset->InsertPoint(alternatives, sample_pos); - alternatives++; - candidates = 0; - vnl_vector_fixed tempDir; tempDir.fill(0.0); - if (IsValidPosition(sample_pos)) - tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, w, m_MaskImage); // sample neighborhood - - if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? - { - direction += d; // go into the direction of the white matter - direction += tempDir*w; // go into the direction of the white matter direction at this location - } + if (is_stop_voter) + stop_votes++; + + double dot = dot_product(d, olddir); + if (dot >= 0.0) // in front of plane defined by pos and olddir + d = -d + 2*dot*olddir; // reflect + else + d = -d; // invert + + // look a bit further into the other direction + sample_pos[0] = pos[0] + d[0]; + sample_pos[1] = pos[1] + d[1]; + sample_pos[2] = pos[2] + d[2]; + if(m_DemoMode) + m_AlternativePointset->InsertPoint(alternatives, sample_pos); + alternatives++; + candidates = 0; + vnl_vector_fixed tempDir; tempDir.fill(0.0); + if (IsValidPosition(sample_pos)) + tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, w, m_MaskImage); // sample neighborhood + + if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? + { + direction += d * m_DeflectionMod; // go into the direction of the white matter + direction += tempDir*w; // go into the direction of the white matter direction at this location + } } + else if (is_stop_voter) + stop_votes++; } - if (direction.magnitude()>0.001) + if (direction.magnitude()>0.001 && (possible_stop_votes==0 || (float)stop_votes/possible_stop_votes<0.5) ) { direction.normalize(); olddir[0] = direction[0]; olddir[1] = direction[1]; olddir[2] = direction[2]; } else direction.fill(0); return direction; } template< int ShOrder, int NumImageFeatures > double MLBSTrackingFilter< ShOrder, NumImageFeatures >::FollowStreamline(itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front) { vnl_vector_fixed dirOld = dir; - dirOld = dir; for (int step=0; step< m_MaxLength/2; step++) { - // get new position CalculateNewPosition(pos, dir); // is new position inside of image and mask if (m_AbortTracking) // if not end streamline { return tractLength; } else // if yes, add new point to streamline { tractLength += m_StepSize; if (front) fib->push_front(pos); else fib->push_back(pos); if (m_AposterioriCurvCheck) { int curv = CheckCurvature(fib, front); // TODO: Move into classification ??? if (curv>0) { tractLength -= m_StepSize*curv; while (curv>0) { if (front) fib->pop_front(); else fib->pop_back(); curv--; } return tractLength; } } if (tractLength>m_MaxTractLength) return tractLength; } if (m_DemoMode) // CHECK: warum sind die samplingpunkte der streamline in der visualisierung immer einen schritt voras? { m_Mutex.Lock(); m_BuildFibersReady++; m_Tractogram.push_back(*fib); BuildFibers(true); m_Stop = true; m_Mutex.Unlock(); while (m_Stop){ } } dir = GetNewDirection(pos, dirOld); while (m_PauseTracking){} if (dir.magnitude()<0.0001) return tractLength; } return tractLength; } template< int ShOrder, int NumImageFeatures > int MLBSTrackingFilter< ShOrder, NumImageFeatures >::CheckCurvature(FiberType* fib, bool front) { double m_Distance = 5; if (fib->size()<3) return 0; double dist = 0; std::vector< vnl_vector_fixed< float, 3 > > vectors; vnl_vector_fixed< float, 3 > meanV; meanV.fill(0); double dev = 0; if (front) { int c=0; while(distsize()-1) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c+1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==0) meanV += v; c++; } } else { int c=fib->size()-1; while(dist0) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c-1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==fib->size()-1) meanV += v; c--; } } meanV.normalize(); for (int c=0; c1.0) angle = 1.0; if (angle<-1.0) angle = -1.0; dev += acos(angle)*180/M_PI; } if (vectors.size()>0) dev /= vectors.size(); if (dev<30) return 0; else return vectors.size(); } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::ThreadedGenerateData(const InputImageRegionType ®ionForThread, ThreadIdType threadId) { m_Mutex.Lock(); m_Threads++; m_Mutex.Unlock(); typedef ImageRegionConstIterator< ItkUcharImgType > MaskIteratorType; MaskIteratorType sit(m_SeedImage, regionForThread ); MaskIteratorType mit(m_MaskImage, regionForThread ); sit.GoToBegin(); mit.GoToBegin(); itk::Point worldPos; while( !sit.IsAtEnd() ) { if (sit.Value()==0 || mit.Value()==0) { ++sit; ++mit; continue; } for (int s=0; s start; unsigned int counter = 0; if (m_SeedsPerVoxel>1) { start[0] = index[0]+GetRandDouble(-0.5, 0.5); start[1] = index[1]+GetRandDouble(-0.5, 0.5); start[2] = index[2]+GetRandDouble(-0.5, 0.5); } else { start[0] = index[0]; start[1] = index[1]; start[2] = index[2]; } // get staring position m_SeedImage->TransformContinuousIndexToPhysicalPoint( start, worldPos ); // get starting direction int candidates = 0; double prob = 0; vnl_vector_fixed dirOld; dirOld.fill(0.0); vnl_vector_fixed dir; dir.fill(0.0); if (IsValidPosition(worldPos)) dir = m_ForestHandler.Classify(worldPos, candidates, dirOld, 0, prob, m_MaskImage); if (dir.magnitude()<0.0001) continue; // forward tracking tractLength = FollowStreamline(worldPos, dir, &fib, 0, false); fib.push_front(worldPos); if (m_RemoveWmEndFibers) { itk::Point check = fib.back(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } // backward tracking tractLength = FollowStreamline(worldPos, -dir, &fib, tractLength, true); counter = fib.size(); if (m_RemoveWmEndFibers) { itk::Point check = fib.front(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } if (tractLength void MLBSTrackingFilter< ShOrder, NumImageFeatures >::BuildFibers(bool check) { if (m_BuildFibersReady::New(); vtkSmartPointer vNewLines = vtkSmartPointer::New(); vtkSmartPointer vNewPoints = vtkSmartPointer::New(); for (int i=0; i container = vtkSmartPointer::New(); FiberType fib = m_Tractogram.at(i); for (FiberType::iterator it = fib.begin(); it!=fib.end(); ++it) { vtkIdType id = vNewPoints->InsertNextPoint((*it).GetDataPointer()); container->GetPointIds()->InsertNextId(id); } vNewLines->InsertNextCell(container); } if (check) for (int i=0; iSetPoints(vNewPoints); m_FiberPolyData->SetLines(vNewLines); m_BuildFibersFinished = true; } template< int ShOrder, int NumImageFeatures > void MLBSTrackingFilter< ShOrder, NumImageFeatures >::AfterThreadedGenerateData() { MITK_INFO << "Generating polydata "; BuildFibers(false); MITK_INFO << "done"; m_EndTime = std::chrono::system_clock::now(); std::chrono::hours hh = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::minutes mm = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::seconds ss = std::chrono::duration_cast(m_EndTime - m_StartTime); mm %= 60; ss %= 60; MITK_INFO << "Tracking took " << hh.count() << "h, " << mm.count() << "m and " << ss.count() << "s"; } } #endif // __itkDiffusionQballPrincipleDirectionsImageFilter_txx diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h index b7986bd18b..85ca9864d1 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h @@ -1,178 +1,181 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ /*=================================================================== This file is based heavily on a corresponding ITK filter. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_h_ #define __itkMLBSTrackingFilter_h_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // classification includes #include #include #include namespace itk{ /** * \brief Performes deterministic streamline tracking on the input tensor image. */ -template< int ShOrder=6, int NumImageFeatures=100 > +template< int ShOrder=6, int NumImageFeatures=28 > class MLBSTrackingFilter : public ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > { public: typedef MLBSTrackingFilter Self; typedef SmartPointer Pointer; typedef SmartPointer ConstPointer; typedef ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > Superclass; typedef typename Superclass::InputImageType InputImageType; typedef typename Superclass::InputImageRegionType InputImageRegionType; typedef Image< Vector< float, NumImageFeatures > , 3 > FeatureImageType; /** Method for creation through the object factory. */ itkFactorylessNewMacro(Self) itkCloneMacro(Self) /** Runtime information support. */ itkTypeMacro(MLBSTrackingFilter, ImageToImageFilter) typedef itk::Image ItkUcharImgType; typedef itk::Image ItkDoubleImgType; typedef itk::Image ItkFloatImgType; typedef vtkSmartPointer< vtkPolyData > PolyDataType; typedef std::deque< itk::Point > FiberType; typedef std::vector< FiberType > BundleType; volatile bool m_PauseTracking; bool m_AbortTracking; bool m_BuildFibersFinished; int m_BuildFibersReady; volatile bool m_Stop; mitk::PointSet::Pointer m_SamplingPointset; mitk::PointSet::Pointer m_AlternativePointset; itkGetMacro( FiberPolyData, PolyDataType ) ///< Output fibers itkSetMacro( SeedImage, ItkUcharImgType::Pointer) ///< Seeds are only placed inside of this mask. itkSetMacro( MaskImage, ItkUcharImgType::Pointer) ///< Tracking is only performed inside of this mask image. itkSetMacro( SeedsPerVoxel, int) ///< One seed placed in the center of each voxel or multiple seeds randomly placed inside each voxel. itkSetMacro( StepSize, double) ///< Integration step size in mm itkSetMacro( MinTractLength, double ) ///< Shorter tracts are discarded. itkSetMacro( MaxTractLength, double ) ///< Streamline progression stops if tract is longer than specified. itkSetMacro( AngularThreshold, double ) ///< Probabilities for directions with larger angular deviation from previous direction is set to 0 itkSetMacro( SamplingDistance, double ) ///< Maximum distance of sampling points in mm - itkSetMacro( NumberOfSamples, int ) ///< Number of sampling points + itkSetMacro( UseStopVotes, bool ) ///< Frontal sampling points can vote for stopping the streamline even if the remaining sampling points keep pushing + itkSetMacro( OnlyForwardSamples, bool ) ///< Don't use sampling points behind the current position in progression direction + itkSetMacro( DeflectionMod, double ) ///< Deflection distance modifier itkSetMacro( StoppingRegions, ItkUcharImgType::Pointer) ///< Streamlines entering a stopping region will stop immediately itkSetMacro( DemoMode, bool ) itkSetMacro( RemoveWmEndFibers, bool ) ///< Checks if fiber ending is located in the white matter. If this is the case, the streamline is discarded. itkSetMacro( AposterioriCurvCheck, bool ) ///< Checks fiber curvature (angular deviation across 5mm) is larger than 30°. If yes, the streamline progression is stopped. itkSetMacro( AvoidStop, bool ) ///< Use additional sampling points to avoid premature streamline termination itkSetMacro( RandomSampling, bool ) ///< If true, the sampling points are distributed randomly around the current position, not sphericall in the specified sampling distance. - void SetForestHandler( mitk::TrackingForestHandler fh ) ///< Stores random forest classifier and performs actual classification + void SetForestHandler( mitk::TrackingForestHandler fh ) ///< Stores random forest classifier and performs actual classification { m_ForestHandler = fh; } protected: MLBSTrackingFilter(); ~MLBSTrackingFilter() {} void CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir); ///< Calculate next integration step. double FollowStreamline(itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front); ///< Start streamline in one direction. bool IsValidPosition(itk::Point& pos); ///< Are we outside of the mask image? vnl_vector_fixed GetNewDirection(itk::Point& pos, vnl_vector_fixed& olddir); ///< Determine new direction by sample voting at the current position taking the last progression direction into account. double GetRandDouble(double min=-1, double max=1); - double RoundToNearest(double num); void BeforeThreadedGenerateData() override; void ThreadedGenerateData( const InputImageRegionType &outputRegionForThread, ThreadIdType threadId) override; void AfterThreadedGenerateData() override; PolyDataType m_FiberPolyData; vtkSmartPointer m_Points; vtkSmartPointer m_Cells; BundleType m_Tractogram; double m_AngularThreshold; double m_StepSize; int m_MaxLength; double m_MinTractLength; double m_MaxTractLength; int m_SeedsPerVoxel; bool m_RandomSampling; double m_SamplingDistance; - int m_NumberOfSamples; + double m_DeflectionMod; + bool m_OnlyForwardSamples; + bool m_UseStopVotes; SimpleFastMutexLock m_Mutex; ItkUcharImgType::Pointer m_StoppingRegions; ItkUcharImgType::Pointer m_SeedImage; ItkUcharImgType::Pointer m_MaskImage; bool m_AposterioriCurvCheck; bool m_RemoveWmEndFibers; bool m_AvoidStop; int m_Threads; bool m_DemoMode; void BuildFibers(bool check); int CheckCurvature(FiberType* fib, bool front); // decision forest - mitk::TrackingForestHandler m_ForestHandler; + mitk::TrackingForestHandler m_ForestHandler; typename InputImageType::Pointer m_InputImage; std::vector< PolyDataType > m_PolyDataContainer; std::chrono::time_point m_StartTime; std::chrono::time_point m_EndTime; private: }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkMLBSTrackingFilter.cpp" #endif #endif //__itkMLBSTrackingFilter_h_ diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp index e871bb0e70..5524b207a6 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp @@ -1,898 +1,868 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _TrackingForestHandler_cpp #define _TrackingForestHandler_cpp #include "mitkTrackingForestHandler.h" #include #include namespace mitk { template< int ShOrder, int NumberOfSignalFeatures > TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::TrackingForestHandler() : m_WmSampleDistance(-1) , m_NumTrees(30) , m_MaxTreeDepth(50) , m_SampleFraction(1.0) , m_NumberOfSamples(0) , m_GmSamplesPerVoxel(50) { + vnl_vector_fixed ref; ref.fill(0); ref[0]=1; + + itk::OrientationDistributionFunction< double, 200 > odf; + m_DirectionContainer.clear(); + for (unsigned int i = 0; i0) // only used directions on one hemisphere + m_DirectionContainer.push_back(odf.GetDirection(i)); // store indices for later mapping the classifier output to the actual direction + } } template< int ShOrder, int NumberOfSignalFeatures > TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::~TrackingForestHandler() { } template< int ShOrder, int NumberOfSignalFeatures > - typename TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InterpolatedRawImageType::PixelType TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::GetImageValues(itk::Point itkP, typename InterpolatedRawImageType::Pointer image) + typename TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::DwiFeatureImageType::PixelType TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::GetDwiFeaturesAtPosition(itk::Point itkP, typename DwiFeatureImageType::Pointer image) { // transform physical point to index coordinates itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; image->TransformPhysicalPointToIndex(itkP, idx); image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); - typename InterpolatedRawImageType::PixelType pix; pix.Fill(0.0); + typename DwiFeatureImageType::PixelType pix; pix.Fill(0.0); if ( image->GetLargestPossibleRegion().IsInside(idx) ) pix = image->GetPixel(idx); else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) { // trilinear interpolation vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = image->GetPixel(idx) * interpWeights[0]; - typename InterpolatedRawImageType::IndexType tmpIdx = idx; tmpIdx[0]++; + typename DwiFeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[7]; } return pix; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InputDataValidForTracking() { - if (m_RawData.empty()) + if (m_InputDwis.empty()) mitkThrow() << "No diffusion-weighted images set!"; if (!IsForestValid()) mitkThrow() << "No or invalid random forest detected!"; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InitForTracking() { InputDataValidForTracking(); MITK_INFO << "Spherically interpolating raw data and creating feature image ..."; typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; typename InterpolationFilterType::Pointer filter = InterpolationFilterType::New(); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(0)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(0)) ); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(0))); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_InputDwis.at(0)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_InputDwis.at(0)) ); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_InputDwis.at(0))); filter->SetLambda(0.006); filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); filter->Update(); - vnl_vector_fixed ref; ref.fill(0); ref[0]=1; - itk::OrientationDistributionFunction< double, NumberOfSignalFeatures*2 > odf; - m_DirectionIndices.clear(); - for (unsigned int f=0; f0) // only used directions on one hemisphere - m_DirectionIndices.push_back(f); // store indices for later mapping the classifier output to the actual direction - } + m_DwiFeatureImages.clear(); + + //m_DwiFeatureImages.push_back(filter->GetCoefficientImage()); - m_FeatureImage = FeatureImageType::New(); - m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); - m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); - m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection()); - m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); - m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); - m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); - m_FeatureImage->Allocate(); - - // get signal values and store them in the feature image - itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); - while(!it.IsAtEnd()) { - typename FeatureImageType::PixelType pix; - for (unsigned int f=0; fSetPixel(it.GetIndex(), pix); - ++it; - } + typename DwiFeatureImageType::Pointer dwiFeatureImage = DwiFeatureImageType::New(); + dwiFeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); + dwiFeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); + dwiFeatureImage->SetDirection(filter->GetOutput()->GetDirection()); + dwiFeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); + dwiFeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); + dwiFeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); + dwiFeatureImage->Allocate(); + + // get signal values and store them in the feature image + vnl_vector_fixed ref; ref.fill(0); ref[0]=1; + itk::OrientationDistributionFunction< double, 2*NumberOfSignalFeatures > odf; + itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); + while(!it.IsAtEnd()) + { + typename DwiFeatureImageType::PixelType pix; + int f = 0; + for (unsigned int i = 0; i0) // only used directions on one hemisphere + { + pix[f] = it.Get()[i]; + f++; + } + } + dwiFeatureImage->SetPixel(it.GetIndex(), pix); + ++it; + } - //m_Forest->multithreadPrediction = false; + m_DwiFeatureImages.push_back(dwiFeatureImage); + } } template< int ShOrder, int NumberOfSignalFeatures > template< class TPixelType > TPixelType TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::GetImageValue(itk::Point itkP, itk::Image* image, bool interpolate) { // transform physical point to index coordinates itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; image->TransformPhysicalPointToIndex(itkP, idx); image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); TPixelType pix = 0.0; if ( image->GetLargestPossibleRegion().IsInside(idx) ) { pix = image->GetPixel(idx); if (!interpolate) return pix; } else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) { // trilinear interpolation vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = image->GetPixel(idx) * interpWeights[0]; typename itk::Image::IndexType tmpIdx = idx; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[7]; } if (pix!=pix) mitkThrow() << "nan values in volume modification image!"; return pix; } - template< int ShOrder, int NumberOfSignalFeatures > - typename TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::FeatureImageType::PixelType TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::GetFeatureValues(itk::Point itkP) - { - // transform physical point to index coordinates - itk::Index<3> idx; - itk::ContinuousIndex< double, 3> cIdx; - m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx); - m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx); - - typename FeatureImageType::PixelType pix; pix.Fill(0.0); - if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) ) - pix = m_FeatureImage->GetPixel(idx); - else - return pix; - - double frac_x = cIdx[0] - idx[0]; - double frac_y = cIdx[1] - idx[1]; - double frac_z = cIdx[2] - idx[2]; - if (frac_x<0) - { - idx[0] -= 1; - frac_x += 1; - } - if (frac_y<0) - { - idx[1] -= 1; - frac_y += 1; - } - if (frac_z<0) - { - idx[2] -= 1; - frac_z += 1; - } - frac_x = 1-frac_x; - frac_y = 1-frac_y; - frac_z = 1-frac_z; - - // int coordinates inside image? - if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 && - idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 && - idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1) - { - // trilinear interpolation - vnl_vector_fixed interpWeights; - interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); - interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); - interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); - interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); - interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); - interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); - interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); - interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); - - pix = m_FeatureImage->GetPixel(idx) * interpWeights[0]; - typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1]; - tmpIdx = idx; tmpIdx[1]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2]; - tmpIdx = idx; tmpIdx[2]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3]; - tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4]; - tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5]; - tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6]; - tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; - pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7]; - } - - return pix; - } - template< int ShOrder, int NumberOfSignalFeatures > vnl_vector_fixed TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& w, ItkUcharImgType::Pointer mask) { + vnl_vector_fixed direction; direction.fill(0); itk::Index<3> idx; - m_FeatureImage->TransformPhysicalPointToIndex(pos, idx); + m_DwiFeatureImages.at(0)->TransformPhysicalPointToIndex(pos, idx); if (mask.IsNotNull() && ((mask->GetLargestPossibleRegion().IsInside(idx) && mask->GetPixel(idx)<=0) || !mask->GetLargestPossibleRegion().IsInside(idx)) ) return direction; + //std::chrono::time_point startTime = std::chrono::system_clock::now(); + //std::chrono::milliseconds ms = std::chrono::duration_cast(std::chrono::system_clock::now() - startTime); + //MITK_INFO << "time 1: " << ms.count() << "ms"; + // store feature pixel values in a vigra data type vigra::MultiArray<2, double> featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,m_Forest->feature_count()) ); - typename FeatureImageType::PixelType featurePixel = GetFeatureValues(pos); + typename DwiFeatureImageType::PixelType dwiFeaturePixel = GetDwiFeaturesAtPosition(pos, m_DwiFeatureImages.at(0)); for (unsigned int f=0; f ref; ref.fill(0); ref[0]=1; for (unsigned int f=NumberOfSignalFeatures; f0) { int add_feat_c = 0; for (auto img : m_AdditionalFeatureImages.at(0)) { float v = GetImageValue(pos, img, false); add_feat_c++; featureData(0,NumberOfSignalFeatures+2+add_feat_c) = v; } } // perform classification vigra::MultiArray<2, double> probs(vigra::Shape2(1, m_Forest->class_count())); m_Forest->predictProbabilities(featureData, probs); double pNonFib = 0; // probability that we left the white matter w = 0; // weight of the predicted direction candidates = 0; // directions with probability > 0 for (int i=0; iclass_count(); i++) // for each class (number of possible directions + out-of-wm class) { if (probs(0,i)>0) // if probability of respective class is 0, do nothing { // get label of class (does not correspond to the loop variable i) int classLabel = 0; m_Forest->ext_param_.to_classlabel(i, classLabel); - if (classLabel 0 (DO WE NEED THIS???) - vnl_vector_fixed d = m_DirContainer.GetDirection(m_DirectionIndices.at(classLabel)); // get direction vector assiciated with the respective direction index + vnl_vector_fixed d = m_DirectionContainer.at(classLabel); // get direction vector assiciated with the respective direction index if (olddir.magnitude()>0) // do we have a previous streamline direction or did we just start? { // TODO: check if hard curvature threshold is necessary. // alternatively try square of dot pruduct as weight. // TODO: check if additional weighting with dot product as directional prior is necessary. are there alternatives on the classification level? double dot = dot_product(d, olddir); // claculate angle between the candidate direction vector and the previous streamline direction if (fabs(dot)>angularThreshold) // is angle between the directions smaller than our hard threshold? { if (dot<0) // make sure we don't walk backwards d *= -1; double w_i = probs(0,i)*fabs(dot); direction += w_i*d; // weight contribution to output direction with its probability and the angular deviation from the previous direction w += w_i; // increase output weight of the final direction } } else { direction += probs(0,i)*d; w += probs(0,i); } } else pNonFib += probs(0,i); // probability that we are not in the white matter anymore } } // if we did not find a suitable direction, make sure that we return (0,0,0) if (pNonFib>w && w>0) { candidates = 0; w = 0; direction.fill(0.0); } return direction; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::StartTraining() { m_StartTime = std::chrono::system_clock::now(); InputDataValidForTraining(); - PreprocessInputDataForTraining(); - CalculateFeaturesForTraining(); - TrainForest(); + InitForTraining(); + CalculateTrainingSamples(); + + MITK_INFO << "Maximum tree depths: " << m_MaxTreeDepth; + MITK_INFO << "Sample fraction per tree: " << m_SampleFraction; + MITK_INFO << "Number of trees: " << m_NumTrees; + + DefaultSplitType splitter; + splitter.UsePointBasedWeights(true); + splitter.SetWeights(m_Weights); + splitter.UseRandomSplit(false); + splitter.SetPrecision(mitk::eps); + splitter.SetMaximumTreeDepth(m_MaxTreeDepth); + + std::vector< std::shared_ptr< vigra::RandomForest > > trees; + int count = 0; +#pragma omp parallel for + for (int i = 0; i < m_NumTrees; ++i) + { + std::shared_ptr< vigra::RandomForest > lrf = std::make_shared< vigra::RandomForest >(); + lrf->set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal + lrf->set_options().sample_with_replacement(true); // if sampled with replacement or not + lrf->set_options().samples_per_tree(m_SampleFraction); // Fraction of samples that are used to train a tree + lrf->set_options().tree_count(1); // Number of trees that are calculated; + lrf->set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node + lrf->ext_param_.max_tree_depth = m_MaxTreeDepth; + + lrf->learn(m_FeatureData, m_LabelData,vigra::rf::visitors::VisitorBase(),splitter); +#pragma omp critical + { + count++; + MITK_INFO << "Tree " << count << " finished training."; + trees.push_back(lrf); + } + } + + for (int i = 1; i < m_NumTrees; ++i) + trees.at(0)->trees_.push_back(trees.at(i)->trees_[0]); + + m_Forest = trees.at(0); + m_Forest->options_.tree_count_ = m_NumTrees; + MITK_INFO << "Training finsihed"; + m_EndTime = std::chrono::system_clock::now(); std::chrono::hours hh = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::minutes mm = std::chrono::duration_cast(m_EndTime - m_StartTime); mm %= 60; MITK_INFO << "Training took " << hh.count() << "h and " << mm.count() << "m"; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InputDataValidForTraining() { - if (m_RawData.empty()) + if (m_InputDwis.empty()) mitkThrow() << "No diffusion-weighted images set!"; if (m_Tractograms.empty()) mitkThrow() << "No tractograms set!"; - if (m_RawData.size()!=m_Tractograms.size()) + if (m_InputDwis.size()!=m_Tractograms.size()) mitkThrow() << "Unequal number of diffusion-weighted images and tractograms detected!"; } template< int ShOrder, int NumberOfSignalFeatures > bool TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::IsForestValid() { int additional_features = 0; if (m_AdditionalFeatureImages.size()>0) additional_features = m_AdditionalFeatureImages.at(0).size(); if(m_Forest && m_Forest->tree_count()>0 && m_Forest->feature_count()==(NumberOfSignalFeatures+3+additional_features)) return true; return false; } template< int ShOrder, int NumberOfSignalFeatures > - void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::PreprocessInputDataForTraining() + void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InitForTraining() { typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; MITK_INFO << "Spherical signal interpolation and sampling ..."; - for (unsigned int i=0; iSetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(i)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(i)) ); - qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(i))); - qballfilter->SetLambda(0.006); - qballfilter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); - qballfilter->Update(); - // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); - m_InterpolatedRawImages.push_back(qballfilter->GetOutput()); + typename InterpolationFilterType::Pointer filter = InterpolationFilterType::New(); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_InputDwis.at(i)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_InputDwis.at(i)) ); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_InputDwis.at(i))); + filter->SetLambda(0.006); + filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); + filter->Update(); + + //m_DwiFeatureImages.push_back(filter->GetCoefficientImage()); + + { + typename DwiFeatureImageType::Pointer dwiFeatureImage = DwiFeatureImageType::New(); + dwiFeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); + dwiFeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); + dwiFeatureImage->SetDirection(filter->GetOutput()->GetDirection()); + dwiFeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); + dwiFeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); + dwiFeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); + dwiFeatureImage->Allocate(); + + // get signal values and store them in the feature image + vnl_vector_fixed ref; ref.fill(0); ref[0]=1; + itk::OrientationDistributionFunction< double, 2*NumberOfSignalFeatures > odf; + itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); + while(!it.IsAtEnd()) + { + typename DwiFeatureImageType::PixelType pix; + int f = 0; + for (unsigned int i = 0; i0) // only used directions on one hemisphere + { + pix[f] = it.Get()[i]; + f++; + } + } + dwiFeatureImage->SetPixel(it.GetIndex(), pix); + ++it; + } + + m_DwiFeatureImages.push_back(dwiFeatureImage); + } if (i>=m_AdditionalFeatureImages.size()) { m_AdditionalFeatureImages.push_back(std::vector< ItkFloatImgType::Pointer >()); } if (i>=m_FiberVolumeModImages.size()) { ItkFloatImgType::Pointer img = ItkFloatImgType::New(); - img->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); - img->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); - img->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); - img->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - img->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - img->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + img->SetSpacing( m_DwiFeatureImages.at(i)->GetSpacing() ); + img->SetOrigin( m_DwiFeatureImages.at(i)->GetOrigin() ); + img->SetDirection( m_DwiFeatureImages.at(i)->GetDirection() ); + img->SetLargestPossibleRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + img->SetBufferedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + img->SetRequestedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); img->Allocate(); img->FillBuffer(1); m_FiberVolumeModImages.push_back(img); } if (m_FiberVolumeModImages.at(i)==nullptr) { m_FiberVolumeModImages.at(i) = ItkFloatImgType::New(); - m_FiberVolumeModImages.at(i)->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); - m_FiberVolumeModImages.at(i)->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); - m_FiberVolumeModImages.at(i)->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); - m_FiberVolumeModImages.at(i)->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - m_FiberVolumeModImages.at(i)->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - m_FiberVolumeModImages.at(i)->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + m_FiberVolumeModImages.at(i)->SetSpacing( m_DwiFeatureImages.at(i)->GetSpacing() ); + m_FiberVolumeModImages.at(i)->SetOrigin( m_DwiFeatureImages.at(i)->GetOrigin() ); + m_FiberVolumeModImages.at(i)->SetDirection( m_DwiFeatureImages.at(i)->GetDirection() ); + m_FiberVolumeModImages.at(i)->SetLargestPossibleRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + m_FiberVolumeModImages.at(i)->SetBufferedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + m_FiberVolumeModImages.at(i)->SetRequestedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); m_FiberVolumeModImages.at(i)->Allocate(); m_FiberVolumeModImages.at(i)->FillBuffer(1); } if (i>=m_MaskImages.size()) { ItkUcharImgType::Pointer newMask = ItkUcharImgType::New(); - newMask->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); - newMask->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); - newMask->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); - newMask->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - newMask->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - newMask->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + newMask->SetSpacing( m_DwiFeatureImages.at(i)->GetSpacing() ); + newMask->SetOrigin( m_DwiFeatureImages.at(i)->GetOrigin() ); + newMask->SetDirection( m_DwiFeatureImages.at(i)->GetDirection() ); + newMask->SetLargestPossibleRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + newMask->SetBufferedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + newMask->SetRequestedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); newMask->Allocate(); newMask->FillBuffer(1); m_MaskImages.push_back(newMask); } if (m_MaskImages.at(i)==nullptr) { m_MaskImages.at(i) = ItkUcharImgType::New(); - m_MaskImages.at(i)->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); - m_MaskImages.at(i)->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); - m_MaskImages.at(i)->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); - m_MaskImages.at(i)->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - m_MaskImages.at(i)->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); - m_MaskImages.at(i)->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + m_MaskImages.at(i)->SetSpacing( m_DwiFeatureImages.at(i)->GetSpacing() ); + m_MaskImages.at(i)->SetOrigin( m_DwiFeatureImages.at(i)->GetOrigin() ); + m_MaskImages.at(i)->SetDirection( m_DwiFeatureImages.at(i)->GetDirection() ); + m_MaskImages.at(i)->SetLargestPossibleRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + m_MaskImages.at(i)->SetBufferedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); + m_MaskImages.at(i)->SetRequestedRegion( m_DwiFeatureImages.at(i)->GetLargestPossibleRegion() ); m_MaskImages.at(i)->Allocate(); m_MaskImages.at(i)->FillBuffer(1); } } MITK_INFO << "Resampling fibers and calculating number of samples ..."; m_NumberOfSamples = 0; for (unsigned int t=0; t::Pointer env = itk::TractDensityImageFilter< ItkUcharImgType >::New(); env->SetFiberBundle(m_Tractograms.at(t)); env->SetInputImage(mask); env->SetBinaryOutput(true); env->SetUseImageGeometry(true); env->Update(); wmmask = env->GetOutput(); if (t>=m_WhiteMatterImages.size()) m_WhiteMatterImages.push_back(wmmask); else m_WhiteMatterImages.at(t) = wmmask; } // Calculate white-matter samples if (m_WmSampleDistance<0) { - typename InterpolatedRawImageType::Pointer image = m_InterpolatedRawImages.at(t); + typename DwiFeatureImageType::Pointer image = m_DwiFeatureImages.at(t); float minSpacing = 1; if(image->GetSpacing()[0]GetSpacing()[1] && image->GetSpacing()[0]GetSpacing()[2]) minSpacing = image->GetSpacing()[0]; else if (image->GetSpacing()[1] < image->GetSpacing()[2]) minSpacing = image->GetSpacing()[1]; else minSpacing = image->GetSpacing()[2]; m_WmSampleDistance = minSpacing*0.5; } m_Tractograms.at(t)->ResampleSpline(m_WmSampleDistance); unsigned int wmSamples = m_Tractograms.at(t)->GetNumberOfPoints()-2*m_Tractograms.at(t)->GetNumFibers(); MITK_INFO << "Unweighted samples inside of WM: " << wmSamples; m_NumberOfSamples += wmSamples; // calculate gray-matter samples itk::ImageRegionConstIterator it(wmmask, wmmask->GetLargestPossibleRegion()); int OUTOFWM = 0; while(!it.IsAtEnd()) { if (it.Get()==0 && mask->GetPixel(it.GetIndex())>0) OUTOFWM++; ++it; } MITK_INFO << "Non-white matter voxels: " << OUTOFWM; if (m_GmSamplesPerVoxel>0) { m_GmSamples.push_back(m_GmSamplesPerVoxel); m_NumberOfSamples += m_GmSamplesPerVoxel*OUTOFWM; } else if (OUTOFWM>0) { MITK_INFO << "Non-white matter samples: " << wmSamples; m_GmSamples.push_back(0.5+(double)wmSamples/(double)OUTOFWM); m_NumberOfSamples += m_GmSamples.back()*OUTOFWM; MITK_INFO << "Non-white matter samples per voxel: " << m_GmSamples.back(); } else { m_GmSamples.push_back(0); } MITK_INFO << "Samples outside of WM: " << m_GmSamples.back()*OUTOFWM; } MITK_INFO << "Number of samples: " << m_NumberOfSamples; } template< int ShOrder, int NumberOfSignalFeatures > - void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::CalculateFeaturesForTraining() + void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::CalculateTrainingSamples() { vnl_vector_fixed ref; ref.fill(0); ref[0]=1; - itk::OrientationDistributionFunction< double, 2*NumberOfSignalFeatures > directions; - std::vector< int > directionIndices; - for (unsigned int f=0; f<2*NumberOfSignalFeatures; f++) - { - if (dot_product(ref, directions.GetDirection(f))>0) - directionIndices.push_back(f); - } int numDirectionFeatures = 3; m_FeatureData.reshape( vigra::Shape2(m_NumberOfSamples, NumberOfSignalFeatures+numDirectionFeatures+m_AdditionalFeatureImages.at(0).size()) ); m_LabelData.reshape( vigra::Shape2(m_NumberOfSamples,1) ); m_Weights.reshape( vigra::Shape2(m_NumberOfSamples,1) ); MITK_INFO << "Number of features: " << m_FeatureData.shape(1); itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer m_RandGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); m_RandGen->SetSeed(); MITK_INFO << "Creating training data ..."; int sampleCounter = 0; for (unsigned int t=0; t it(wmMask, wmMask->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0))) { - typename InterpolatedRawImageType::PixelType pix = image->GetPixel(it.GetIndex()); + typename DwiFeatureImageType::PixelType pix = image->GetPixel(it.GetIndex()); // null direction for (unsigned int f=0; f itkP; image->TransformIndexToPhysicalPoint(it.GetIndex(), itkP); float v = GetImageValue(itkP, img, false); add_feat_c++; m_FeatureData(sampleCounter,NumberOfSignalFeatures+2+add_feat_c) = v; } - m_LabelData(sampleCounter,0) = directionIndices.size(); + m_LabelData(sampleCounter,0) = m_DirectionContainer.size(); m_Weights(sampleCounter,0) = 1.0; sampleCounter++; // random directions for (int i=1; i probe; probe[0] = m_RandGen->GetVariate()*2-1; probe[1] = m_RandGen->GetVariate()*2-1; probe[2] = m_RandGen->GetVariate()*2-1; probe.normalize(); if (dot_product(ref, probe)<0) probe *= -1; for (unsigned int f=NumberOfSignalFeatures; f itkP; image->TransformIndexToPhysicalPoint(it.GetIndex(), itkP); float v = GetImageValue(itkP, img, false); add_feat_c++; m_FeatureData(sampleCounter,NumberOfSignalFeatures+2+add_feat_c) = v; } - m_LabelData(sampleCounter,0) = directionIndices.size(); + m_LabelData(sampleCounter,0) = m_DirectionContainer.size(); m_Weights(sampleCounter,0) = 1.0; sampleCounter++; } } ++it; } // white matter samples mitk::FiberBundle::Pointer fib = m_Tractograms.at(t); vtkSmartPointer< vtkPolyData > polyData = fib->GetFiberPolyData(); for (int i=0; iGetNumFibers(); i++) { vtkCell* cell = polyData->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); double fiber_weight = fib->GetFiberWeight(i); vnl_vector_fixed dirOld; dirOld.fill(0.0); for (int j=0; jGetPoint(j); itk::Point itkP1; itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2]; float volume_mod = GetImageValue(itkP1, fiber_folume, false); vnl_vector_fixed dir; dir.fill(0.0); itk::Point itkP2; double* p2 = points->GetPoint(j+1); itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2]; dir[0]=itkP2[0]-itkP1[0]; dir[1]=itkP2[1]-itkP1[1]; dir[2]=itkP2[2]-itkP1[2]; if (dir.magnitude()<0.0001) { MITK_INFO << "streamline error!"; continue; } dir.normalize(); if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2]) { MITK_INFO << "ERROR: NaN direction!"; continue; } if (j==0) { dirOld = dir; continue; } // get voxel values - typename InterpolatedRawImageType::PixelType pix = GetImageValues(itkP1, image); + typename DwiFeatureImageType::PixelType pix = GetDwiFeaturesAtPosition(itkP1, image); for (unsigned int f=0; f(itkP1, img, false); add_feat_c++; m_FeatureData(sampleCounter,NumberOfSignalFeatures+2+add_feat_c) = v; } // set label values double angle = 0; double m = dir.magnitude(); if (m>0.0001) { - for (unsigned int f=0; fangle) { - m_LabelData(sampleCounter,0) = f; + m_LabelData(sampleCounter,0) = l; m_Weights(sampleCounter,0) = fiber_weight*volume_mod; //MITK_INFO << "m_Weights(sampleCounter,0): " << m_Weights(sampleCounter,0) << ' ' << fiber_weight << ' ' << volume_mod; angle = a; } + l++; } } dirOld = dir; sampleCounter++; } } } } - template< int ShOrder, int NumberOfSignalFeatures > - void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::TrainForest() - { - MITK_INFO << "Maximum tree depths: " << m_MaxTreeDepth; - MITK_INFO << "Sample fraction per tree: " << m_SampleFraction; - MITK_INFO << "Number of trees: " << m_NumTrees; - - - DefaultSplitType splitter; - splitter.UsePointBasedWeights(true); - splitter.SetWeights(m_Weights); - splitter.UseRandomSplit(false); - splitter.SetPrecision(mitk::eps); - splitter.SetMaximumTreeDepth(m_MaxTreeDepth); - - std::vector< std::shared_ptr< vigra::RandomForest > > trees; - int count = 0; -#pragma omp parallel for - for (int i = 0; i < m_NumTrees; ++i) - { - std::shared_ptr< vigra::RandomForest > lrf = std::make_shared< vigra::RandomForest >(); - lrf->set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal - lrf->set_options().sample_with_replacement(true); // if sampled with replacement or not - lrf->set_options().samples_per_tree(m_SampleFraction); // Fraction of samples that are used to train a tree - lrf->set_options().tree_count(1); // Number of trees that are calculated; - lrf->set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node - lrf->ext_param_.max_tree_depth = m_MaxTreeDepth; - - lrf->learn(m_FeatureData, m_LabelData,vigra::rf::visitors::VisitorBase(),splitter); -#pragma omp critical - { - count++; - MITK_INFO << "Tree " << count << " finished training."; - trees.push_back(lrf); - } - } - - for (int i = 1; i < m_NumTrees; ++i) - trees.at(0)->trees_.push_back(trees.at(i)->trees_[0]); - - m_Forest = trees.at(0); - m_Forest->options_.tree_count_ = m_NumTrees; - MITK_INFO << "Training finsihed"; - } - template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::SaveForest(std::string forestFile) { MITK_INFO << "Saving forest to " << forestFile; if (IsForestValid()) { vigra::rf_export_HDF5( *m_Forest, forestFile, "" ); MITK_INFO << "Forest saved successfully."; } else MITK_INFO << "Forest invalid! Could not be saved."; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::LoadForest(std::string forestFile) { MITK_INFO << "Loading forest from " << forestFile; m_Forest = std::make_shared< vigra::RandomForest >(); vigra::rf_import_HDF5( *m_Forest, forestFile); } } #endif diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h index cbaf4c99e1..80b07e72ef 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h @@ -1,149 +1,147 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _TrackingForestHandler #define _TrackingForestHandler #include "mitkBaseData.h" #include #include #include #include #include #include #undef DIFFERENCE #define VIGRA_STATIC_LIB #include #include #include //#include #include #include #include #define _USE_MATH_DEFINES #include namespace mitk { /** * \brief Manages random forests for fiber tractography. The preparation of the features from the inputa data and the training process are handled here. The data preprocessing and actual prediction for the tracking process is also performed here. The tracking itself is performed in MLBSTrackingFilter. */ -template< int ShOrder=6, int NumberOfSignalFeatures=100 > +template< int ShOrder=6, int NumberOfSignalFeatures=28 > class TrackingForestHandler { public: TrackingForestHandler(); ~TrackingForestHandler(); typedef itk::Image ItkShortImgType; typedef itk::Image ItkFloatImgType; typedef itk::Image ItkUcharImgType; - typedef itk::Image< itk::Vector< float, NumberOfSignalFeatures*2 > , 3 > InterpolatedRawImageType; - typedef itk::Image< Vector< float, NumberOfSignalFeatures > , 3 > FeatureImageType; + //typedef itk::Image< itk::Vector< float, NumberOfSignalFeatures*2 > , 3 > DwiFeatureImageType; + typedef itk::Image< itk::Vector< float, NumberOfSignalFeatures > , 3 > DwiFeatureImageType; + typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; - void SetRawData( std::vector< Image::Pointer > images ){ m_RawData = images; } - void AddRawData( Image::Pointer img ){ m_RawData.push_back(img); } + void SetDwis( std::vector< Image::Pointer > images ){ m_InputDwis = images; } + void AddDwi( Image::Pointer img ){ m_InputDwis.push_back(img); } void SetTractograms( std::vector< FiberBundle::Pointer > tractograms ) { m_Tractograms.clear(); for (auto fib : tractograms) { m_Tractograms.push_back(fib->GetDeepCopy()); } } void SetMaskImages( std::vector< ItkUcharImgType::Pointer > images ){ m_MaskImages = images; } void SetWhiteMatterImages( std::vector< ItkUcharImgType::Pointer > images ){ m_WhiteMatterImages = images; } void SetFiberVolumeModImages( std::vector< ItkFloatImgType::Pointer > images ){ m_FiberVolumeModImages = images; } void SetAdditionalFeatureImages( std::vector< std::vector< ItkFloatImgType::Pointer > > images ){ m_AdditionalFeatureImages = images; } void StartTraining(); void SaveForest(std::string forestFile); void LoadForest(std::string forestFile); // training parameters - void SetNumTrees(int num){ m_NumTrees = num; } void SetMaxTreeDepth(int depth){ m_MaxTreeDepth = depth; } void SetStepSize(double step){ m_WmSampleDistance = step; } void SetGrayMatterSamplesPerVoxel(int samples){ m_GmSamplesPerVoxel = samples; } void SetSampleFraction(double fraction){ m_SampleFraction = fraction; } std::shared_ptr< vigra::RandomForest > GetForest(){ return m_Forest; } void InitForTracking(); ///< calls InputDataValidForTracking() and creates feature images from the war input DWI vnl_vector_fixed Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& w, ItkUcharImgType::Pointer mask=nullptr); ///< predicts next progression direction at the given position bool IsForestValid(); ///< true is forest is not null, has more than 0 trees and the correct number of features (NumberOfSignalFeatures + 3) protected: // tracking void InputDataValidForTracking(); ///< check if raw data is set and tracking forest is valid - typename FeatureImageType::PixelType GetFeatureValues(itk::Point itkP); ///< get trilinearly interpolated feature values at given world position template< class TPixelType > TPixelType GetImageValue(itk::Point itkP, itk::Image* image, bool interpolate); - // training void InputDataValidForTraining(); ///< Check if everything is tehere for training (raw datasets, fiber tracts) - void PreprocessInputDataForTraining(); ///< Generate masks if necessary, resample fibers, spherically interpolate raw DWIs - void CalculateFeaturesForTraining(); ///< Calculate GM and WM features using the interpolated raw data, the WM masks and the fibers - void TrainForest(); ///< start training process - typename InterpolatedRawImageType::PixelType GetImageValues(itk::Point itkP, typename InterpolatedRawImageType::Pointer image); ///< get trilinearly interpolated raw image values at given world position + void InitForTraining(); ///< Generate masks if necessary, resample fibers, spherically interpolate raw DWIs + void CalculateTrainingSamples(); ///< Calculate GM and WM features using the interpolated raw data, the WM masks and the fibers + typename DwiFeatureImageType::PixelType GetDwiFeaturesAtPosition(itk::Point itkP, typename DwiFeatureImageType::Pointer image); ///< get trilinearly interpolated raw image values at given world position + + std::vector< Image::Pointer > m_InputDwis; ///< original input DWI data + std::shared_ptr< vigra::RandomForest > m_Forest; ///< random forest classifier + std::chrono::time_point m_StartTime; + std::chrono::time_point m_EndTime; - std::vector< Image::Pointer > m_RawData; ///< original input DWI data - std::shared_ptr< vigra::RandomForest > m_Forest; ///< random forest classifier - std::chrono::time_point m_StartTime; - std::chrono::time_point m_EndTime; - // only for training + std::vector< typename DwiFeatureImageType::Pointer > m_DwiFeatureImages; std::vector< std::vector< ItkFloatImgType::Pointer > > m_AdditionalFeatureImages; + std::vector< ItkFloatImgType::Pointer > m_FiberVolumeModImages; ///< used to correct the fiber density std::vector< FiberBundle::Pointer > m_Tractograms; ///< training tractograms std::vector< ItkUcharImgType::Pointer > m_MaskImages; ///< binary mask images to constrain training to a certain area (e.g. brain mask) std::vector< ItkUcharImgType::Pointer > m_WhiteMatterImages; ///< defines white matter voxels. if not set, theses mask images are automatically generated from the input tractograms - std::vector< typename InterpolatedRawImageType::Pointer > m_InterpolatedRawImages; ///< spherically interpolated and resampled raw datasets + double m_WmSampleDistance; ///< deterines the number of white matter samples (distance of sampling points on each fiber). int m_NumTrees; ///< number of trees in random forest int m_MaxTreeDepth; ///< limits the tree depth double m_SampleFraction; ///< fraction of samples used to train each tree unsigned int m_NumberOfSamples; ///< stores overall number of samples used for training std::vector< unsigned int > m_GmSamples; ///< number of gray matter samples int m_GmSamplesPerVoxel; ///< number of gray matter samplees per voxel. if -1, then the number is automatically chosen to gain an overall number of GM samples close to the number of WM samples. vigra::MultiArray<2, double> m_FeatureData; ///< vigra container for training features // only for tracking - typename FeatureImageType::Pointer m_FeatureImage; ///< feature image used for tracking vigra::MultiArray<2, double> m_LabelData; ///< vigra container for training labels - vigra::MultiArray<2, double> m_Weights; ///< vigra container for training labels - std::vector< int > m_DirectionIndices; ///< maps each of the NumberOfSignalFeatures possible output directions to one of the 2*NumberOfSignalFeatures ODF directions. - itk::OrientationDistributionFunction< double, NumberOfSignalFeatures*2 > m_DirContainer; ///< direction container + vigra::MultiArray<2, double> m_Weights; ///< vigra container for training sample weights + + std::vector< vnl_vector_fixed > m_DirectionContainer; }; } #include "mitkTrackingForestHandler.cpp" #endif diff --git a/Modules/DiffusionImaging/FiberTracking/Testing/mitkMachineLearningTrackingTest.cpp b/Modules/DiffusionImaging/FiberTracking/Testing/mitkMachineLearningTrackingTest.cpp index 85d39d2c4e..7c9064cbbc 100644 --- a/Modules/DiffusionImaging/FiberTracking/Testing/mitkMachineLearningTrackingTest.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Testing/mitkMachineLearningTrackingTest.cpp @@ -1,102 +1,101 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkTestingMacros.h" #include #include #include #include #include #include #include #include #include #include "mitkTestFixture.h" class mitkMachineLearningTrackingTestSuite : public mitk::TestFixture { CPPUNIT_TEST_SUITE(mitkMachineLearningTrackingTestSuite); MITK_TEST(Track1); CPPUNIT_TEST_SUITE_END(); typedef itk::Image ItkUcharImgType; private: /** Members used inside the different (sub-)tests. All members are initialized via setUp().*/ mitk::FiberBundle::Pointer ref; mitk::TrackingForestHandler<> tfh; mitk::Image::Pointer dwi; ItkUcharImgType::Pointer seed; public: void setUp() override { ref = NULL; std::vector fibInfile = mitk::IOUtil::Load(GetTestDataFilePath("DiffusionImaging/MachineLearningTracking/ReferenceTracts.fib")); mitk::BaseData::Pointer baseData = fibInfile.at(0); ref = dynamic_cast(baseData.GetPointer()); dwi = dynamic_cast(mitk::IOUtil::LoadImage(GetTestDataFilePath("DiffusionImaging/MachineLearningTracking/DiffusionImage.dwi")).GetPointer()); mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(GetTestDataFilePath("DiffusionImaging/MachineLearningTracking/seed.nrrd")).GetPointer()); seed = ItkUcharImgType::New(); mitk::CastToItkImage(img, seed); tfh.LoadForest(GetTestDataFilePath("DiffusionImaging/MachineLearningTracking/forest.rf")); - tfh.AddRawData(dwi); + tfh.AddDwi(dwi); } void tearDown() override { ref = NULL; } void Track1() { typedef itk::MLBSTrackingFilter<> TrackerType; TrackerType::Pointer tracker = TrackerType::New(); tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi)); tracker->SetDemoMode(false); tracker->SetSeedImage(seed); tracker->SetSeedsPerVoxel(1); tracker->SetStepSize(-1); tracker->SetMinTractLength(20); tracker->SetMaxTractLength(400); tracker->SetForestHandler(tfh); - tracker->SetNumberOfSamples(30); tracker->SetAposterioriCurvCheck(false); tracker->SetRemoveWmEndFibers(false); tracker->SetAvoidStop(true); tracker->SetSamplingDistance(0.5); tracker->SetRandomSampling(false); tracker->Update(); vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); // mitk::IOUtil::Save(outFib, mitk::IOUtil::GetTempPath()+"RefFib.fib"); CPPUNIT_ASSERT_MESSAGE("Should be equal", ref->Equals(outFib)); } }; MITK_TEST_SUITE_REGISTRATION(mitkMachineLearningTracking) diff --git a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp index 5c9d692613..4023389efe 100755 --- a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp +++ b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp @@ -1,180 +1,189 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define _USE_MATH_DEFINES #include using namespace std; const int numOdfSamples = 200; typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; /*! \brief Perform machine learning based streamline tractography */ int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Machine Learning Based Streamline Tractography"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription("Perform machine learning based streamline tractography"); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("image", "i", mitkCommandLineParser::String, "DWI:", "input diffusion-weighted image", us::Any(), false); parser.addArgument("addfeatures", "a", mitkCommandLineParser::StringList, "Additional feature images:", "specify a list of float images that hold additional features (float)", us::Any()); parser.addArgument("forest", "f", mitkCommandLineParser::String, "Forest:", "input random forest (HDF5 file)", us::Any(), false); parser.addArgument("out", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output fiberbundle", us::Any(), false); parser.addArgument("stop", "st", mitkCommandLineParser::String, "Stop image:", "streamlines entering the binary mask will stop immediately", us::Any()); parser.addArgument("mask", "m", mitkCommandLineParser::String, "Mask image:", "restrict tractography with a binary mask image", us::Any()); parser.addArgument("seed", "s", mitkCommandLineParser::String, "Seed image:", "binary mask image defining seed voxels", us::Any()); parser.addArgument("stepsize", "se", mitkCommandLineParser::Float, "Stepsize:", "stepsize (in voxels)", us::Any()); - parser.addArgument("samples", "ns", mitkCommandLineParser::Int, "Samples:", "number of neighborhood samples", us::Any()); parser.addArgument("samplingdist", "sd", mitkCommandLineParser::Float, "Sampling distance:", "distance of neighborhood sampling points (in voxels)", us::Any()); parser.addArgument("seeds", "nse", mitkCommandLineParser::Int, "Seeds per voxel:", "number of seed points per voxel", us::Any()); + parser.addArgument("stopvotes", "sv", mitkCommandLineParser::Int, "Use stop votes:", "use stop votes", us::Any()); + parser.addArgument("forward", "fs", mitkCommandLineParser::Int, "Use only forward samples:", "use only forward samples", us::Any()); + + map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; string imageFile = us::any_cast(parsedArgs["image"]); string forestFile = us::any_cast(parsedArgs["forest"]); string outFile = us::any_cast(parsedArgs["out"]); + bool stopvotes = true; + if (parsedArgs.count("stopvotes")) + stopvotes = us::any_cast(parsedArgs["stopvotes"]); + + bool forward = true; + if (parsedArgs.count("forward")) + forward = us::any_cast(parsedArgs["forward"]); + string maskFile = ""; if (parsedArgs.count("mask")) maskFile = us::any_cast(parsedArgs["mask"]); string seedFile = ""; if (parsedArgs.count("seed")) seedFile = us::any_cast(parsedArgs["seed"]); string stopFile = ""; if (parsedArgs.count("stop")) stopFile = us::any_cast(parsedArgs["stop"]); float stepsize = -1; if (parsedArgs.count("stepsize")) stepsize = us::any_cast(parsedArgs["stepsize"]); float samplingdist = 0.25; if (parsedArgs.count("samplingdist")) samplingdist = us::any_cast(parsedArgs["samplingdist"]); - int samples = 30; - if (parsedArgs.count("samples")) - samples = us::any_cast(parsedArgs["samples"]); - int seeds = 1; if (parsedArgs.count("seeds")) seeds = us::any_cast(parsedArgs["seeds"]); mitkCommandLineParser::StringContainerType addFeatFiles; if (parsedArgs.count("addfeatures")) addFeatFiles = us::any_cast(parsedArgs["addfeatures"]); typedef itk::Image ItkUcharImgType; MITK_INFO << "loading diffusion-weighted image"; mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::LoadImage(imageFile).GetPointer()); ItkUcharImgType::Pointer mask; if (!maskFile.empty()) { MITK_INFO << "loading mask image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(maskFile).GetPointer()); mask = ItkUcharImgType::New(); mitk::CastToItkImage(img, mask); } ItkUcharImgType::Pointer seed; if (!seedFile.empty()) { MITK_INFO << "loading seed image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(seedFile).GetPointer()); seed = ItkUcharImgType::New(); mitk::CastToItkImage(img, seed); } ItkUcharImgType::Pointer stop; if (!stopFile.empty()) { MITK_INFO << "loading stop image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(stopFile).GetPointer()); stop = ItkUcharImgType::New(); mitk::CastToItkImage(img, stop); } MITK_INFO << "loading additional feature images"; typedef itk::Image ItkFloatImgType; std::vector< std::vector< ItkFloatImgType::Pointer > > addFeatImages; addFeatImages.push_back(std::vector< ItkFloatImgType::Pointer >()); for (auto file : addFeatFiles) { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(file).GetPointer()); ItkFloatImgType::Pointer itkimg = ItkFloatImgType::New(); mitk::CastToItkImage(img, itkimg); addFeatImages.at(0).push_back(itkimg); } - mitk::TrackingForestHandler<> tfh; + mitk::TrackingForestHandler<6,100> tfh; tfh.LoadForest(forestFile); - tfh.AddRawData(dwi); + tfh.AddDwi(dwi); tfh.SetAdditionalFeatureImages(addFeatImages); - typedef itk::MLBSTrackingFilter<> TrackerType; + typedef itk::MLBSTrackingFilter<6,100> TrackerType; TrackerType::Pointer tracker = TrackerType::New(); tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi)); tracker->SetMaskImage(mask); tracker->SetSeedImage(seed); tracker->SetStoppingRegions(stop); tracker->SetSeedsPerVoxel(seeds); tracker->SetStepSize(stepsize); tracker->SetForestHandler(tfh); tracker->SetSamplingDistance(samplingdist); - tracker->SetNumberOfSamples(samples); + tracker->SetUseStopVotes(stopvotes); + tracker->SetOnlyForwardSamples(forward); + //tracker->SetDeflectionMod(deflection); //tracker->SetAvoidStop(false); tracker->SetAposterioriCurvCheck(false); tracker->SetRemoveWmEndFibers(false); tracker->Update(); vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); mitk::IOUtil::Save(outFib, outFile); return EXIT_SUCCESS; } diff --git a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp index 5701702874..085c02d333 100755 --- a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp +++ b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp @@ -1,194 +1,194 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include #include #include #include #include #include #define _USE_MATH_DEFINES #include using namespace std; /*! \brief Train random forest classifier for machine learning based streamline tractography */ int main(int argc, char* argv[]) { MITK_INFO << "DFTraining"; mitkCommandLineParser parser; parser.setTitle("Training for Machine Learning Based Streamline Tractography"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription("Train random forest classifier for machine learning based streamline tractography"); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("images", "i", mitkCommandLineParser::StringList, "DWIs:", "input diffusion-weighted images", us::Any(), false); parser.addArgument("tractograms", "t", mitkCommandLineParser::StringList, "Tractograms:", "input training tractograms (.fib, vtk ascii file format)", us::Any(), false); parser.addArgument("forest", "f", mitkCommandLineParser::OutputFile, "Forest:", "output random forest (HDF5)", us::Any(), false); parser.addArgument("masks", "m", mitkCommandLineParser::StringList, "Masks:", "restrict trining using a binary mask image", us::Any()); parser.addArgument("wmmasks", "w", mitkCommandLineParser::StringList, "WM-Masks:", "if no binary white matter mask is specified, the envelope of the input tractogram is used", us::Any()); parser.addArgument("volmod", "v", mitkCommandLineParser::StringList, "Volume modification images:", "specify a list of float images that modify the fiber density", us::Any()); parser.addArgument("addfeatures", "a", mitkCommandLineParser::StringList, "Additional feature images:", "specify a list of float images that hold additional features (float)", us::Any()); parser.addArgument("stepsize", "s", mitkCommandLineParser::Float, "Stepsize:", "resampling parameter for the input tractogram in mm (determines number of white-matter samples)", us::Any()); parser.addArgument("gmsamples", "g", mitkCommandLineParser::Int, "Number of gray matter samples per voxel:", "Number of gray matter samples per voxel", us::Any()); parser.addArgument("numtrees", "n", mitkCommandLineParser::Int, "Number of trees:", "number of trees", us::Any()); parser.addArgument("max_tree_depth", "d", mitkCommandLineParser::Int, "Max. tree depth:", "maximum tree depth", us::Any()); parser.addArgument("sample_fraction", "sf", mitkCommandLineParser::Float, "Sample fraction:", "fraction of samples used per tree", us::Any()); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; mitkCommandLineParser::StringContainerType imageFiles = us::any_cast(parsedArgs["images"]); mitkCommandLineParser::StringContainerType wmMaskFiles; if (parsedArgs.count("wmmasks")) wmMaskFiles = us::any_cast(parsedArgs["wmmasks"]); mitkCommandLineParser::StringContainerType volModFiles; if (parsedArgs.count("volmod")) volModFiles = us::any_cast(parsedArgs["volmod"]); mitkCommandLineParser::StringContainerType addFeatFiles; if (parsedArgs.count("addfeatures")) addFeatFiles = us::any_cast(parsedArgs["addfeatures"]); mitkCommandLineParser::StringContainerType maskFiles; if (parsedArgs.count("masks")) maskFiles = us::any_cast(parsedArgs["masks"]); string forestFile = us::any_cast(parsedArgs["forest"]); mitkCommandLineParser::StringContainerType tractogramFiles; if (parsedArgs.count("tractograms")) tractogramFiles = us::any_cast(parsedArgs["tractograms"]); int numTrees = 50; if (parsedArgs.count("numtrees")) numTrees = us::any_cast(parsedArgs["numtrees"]); int gmsamples = -1; if (parsedArgs.count("gmsamples")) gmsamples = us::any_cast(parsedArgs["gmsamples"]); float stepsize = -1; if (parsedArgs.count("stepsize")) stepsize = us::any_cast(parsedArgs["stepsize"]); int max_tree_depth = 25; if (parsedArgs.count("max_tree_depth")) max_tree_depth = us::any_cast(parsedArgs["max_tree_depth"]); double sample_fraction = 0.6; if (parsedArgs.count("sample_fraction")) sample_fraction = us::any_cast(parsedArgs["sample_fraction"]); MITK_INFO << "loading diffusion-weighted images"; std::vector< mitk::Image::Pointer > rawData; for (auto imgFile : imageFiles) { mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::LoadImage(imgFile).GetPointer()); rawData.push_back(dwi); } typedef itk::Image ItkFloatImgType; typedef itk::Image ItkUcharImgType; MITK_INFO << "loading mask images"; std::vector< ItkUcharImgType::Pointer > maskImageVector; for (auto maskFile : maskFiles) { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(maskFile).GetPointer()); ItkUcharImgType::Pointer mask = ItkUcharImgType::New(); mitk::CastToItkImage(img, mask); maskImageVector.push_back(mask); } MITK_INFO << "loading white matter mask images"; std::vector< ItkUcharImgType::Pointer > wmMaskImageVector; for (auto wmFile : wmMaskFiles) { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(wmFile).GetPointer()); ItkUcharImgType::Pointer wmmask = ItkUcharImgType::New(); mitk::CastToItkImage(img, wmmask); wmMaskImageVector.push_back(wmmask); } MITK_INFO << "loading tractograms"; std::vector< mitk::FiberBundle::Pointer > tractograms; for (auto tractFile : tractogramFiles) { mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(tractFile).at(0).GetPointer()); tractograms.push_back(fib); } MITK_INFO << "loading white volume modification images"; std::vector< ItkFloatImgType::Pointer > volumeModImages; for (auto file : volModFiles) { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(file).GetPointer()); ItkFloatImgType::Pointer itkimg = ItkFloatImgType::New(); mitk::CastToItkImage(img, itkimg); volumeModImages.push_back(itkimg); } MITK_INFO << "loading additional feature images"; std::vector< std::vector< ItkFloatImgType::Pointer > > addFeatImages; for (int i=0; i()); int c = 0; for (auto file : addFeatFiles) { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(file).GetPointer()); ItkFloatImgType::Pointer itkimg = ItkFloatImgType::New(); mitk::CastToItkImage(img, itkimg); addFeatImages.at(c%addFeatImages.size()).push_back(itkimg); c++; } - mitk::TrackingForestHandler<> forestHandler; - forestHandler.SetRawData(rawData); + mitk::TrackingForestHandler<6,100> forestHandler; + forestHandler.SetDwis(rawData); forestHandler.SetMaskImages(maskImageVector); forestHandler.SetWhiteMatterImages(wmMaskImageVector); forestHandler.SetFiberVolumeModImages(volumeModImages); forestHandler.SetAdditionalFeatureImages(addFeatImages); forestHandler.SetTractograms(tractograms); forestHandler.SetNumTrees(numTrees); forestHandler.SetMaxTreeDepth(max_tree_depth); forestHandler.SetGrayMatterSamplesPerVoxel(gmsamples); forestHandler.SetSampleFraction(sample_fraction); forestHandler.SetStepSize(stepsize); forestHandler.StartTraining(); forestHandler.SaveForest(forestFile); return EXIT_SUCCESS; } diff --git a/Modules/DiffusionImaging/MiniApps/FiberProcessing.cpp b/Modules/DiffusionImaging/MiniApps/FiberProcessing.cpp index 734b2aca9e..270ddab2e8 100644 --- a/Modules/DiffusionImaging/MiniApps/FiberProcessing.cpp +++ b/Modules/DiffusionImaging/MiniApps/FiberProcessing.cpp @@ -1,215 +1,220 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include mitk::FiberBundle::Pointer LoadFib(std::string filename) { std::vector fibInfile = mitk::IOUtil::Load(filename); if( fibInfile.empty() ) std::cout << "File " << filename << " could not be read!"; mitk::BaseData::Pointer baseData = fibInfile.at(0); return dynamic_cast(baseData.GetPointer()); } /*! \brief Modify input tractogram: fiber resampling, compression, pruning and transformation. */ int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Fiber Processing"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription("Modify input tractogram: fiber resampling, compression, pruning and transformation."); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("input", "i", mitkCommandLineParser::InputFile, "Input:", "input fiber bundle (.fib)", us::Any(), false); parser.addArgument("outFile", "o", mitkCommandLineParser::OutputFile, "Output:", "output fiber bundle (.fib)", us::Any(), false); parser.addArgument("smooth", "s", mitkCommandLineParser::Float, "Spline resampling:", "Resample fiber using splines with the given point distance (in mm)"); parser.addArgument("compress", "c", mitkCommandLineParser::Float, "Compress:", "Compress fiber using the given error threshold (in mm)"); parser.addArgument("minLength", "l", mitkCommandLineParser::Float, "Minimum length:", "Minimum fiber length (in mm)"); parser.addArgument("maxLength", "m", mitkCommandLineParser::Float, "Maximum length:", "Maximum fiber length (in mm)"); parser.addArgument("maxAngle", "a", mitkCommandLineParser::Float, "Maximum angle:", "Maximum angular STDEV over 1cm (in degree)"); + parser.addArgument("remove", "rm", mitkCommandLineParser::Int, "Remove fibers exceeding curvature threshold:", "if 0, only the high curvature parts are removed"); parser.addArgument("mirror", "p", mitkCommandLineParser::Int, "Invert coordinates:", "Invert fiber coordinates XYZ (e.g. 010 to invert y-coordinate of each fiber point)"); parser.addArgument("rotate-x", "rx", mitkCommandLineParser::Float, "Rotate x-axis:", "Rotate around x-axis (if copy is given the copy is rotated, in deg)"); parser.addArgument("rotate-y", "ry", mitkCommandLineParser::Float, "Rotate y-axis:", "Rotate around y-axis (if copy is given the copy is rotated, in deg)"); parser.addArgument("rotate-z", "rz", mitkCommandLineParser::Float, "Rotate z-axis:", "Rotate around z-axis (if copy is given the copy is rotated, in deg)"); parser.addArgument("scale-x", "sx", mitkCommandLineParser::Float, "Scale x-axis:", "Scale in direction of x-axis (if copy is given the copy is scaled)"); parser.addArgument("scale-y", "sy", mitkCommandLineParser::Float, "Scale y-axis:", "Scale in direction of y-axis (if copy is given the copy is scaled)"); parser.addArgument("scale-z", "sz", mitkCommandLineParser::Float, "Scale z-axis", "Scale in direction of z-axis (if copy is given the copy is scaled)"); parser.addArgument("translate-x", "tx", mitkCommandLineParser::Float, "Translate x-axis:", "Translate in direction of x-axis (if copy is given the copy is translated, in mm)"); parser.addArgument("translate-y", "ty", mitkCommandLineParser::Float, "Translate y-axis:", "Translate in direction of y-axis (if copy is given the copy is translated, in mm)"); parser.addArgument("translate-z", "tz", mitkCommandLineParser::Float, "Translate z-axis:", "Translate in direction of z-axis (if copy is given the copy is translated, in mm)"); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; + bool remove = true; + if (parsedArgs.count("remove")) + remove = us::any_cast(parsedArgs["remove"]); + float smoothDist = -1; if (parsedArgs.count("smooth")) smoothDist = us::any_cast(parsedArgs["smooth"]); float compress = -1; if (parsedArgs.count("compress")) compress = us::any_cast(parsedArgs["compress"]); float minFiberLength = -1; if (parsedArgs.count("minLength")) minFiberLength = us::any_cast(parsedArgs["minLength"]); float maxFiberLength = -1; if (parsedArgs.count("maxLength")) maxFiberLength = us::any_cast(parsedArgs["maxLength"]); float maxAngularDev = -1; if (parsedArgs.count("maxAngle")) maxAngularDev = us::any_cast(parsedArgs["maxAngle"]); int axis = 0; if (parsedArgs.count("mirror")) axis = us::any_cast(parsedArgs["mirror"]); float rotateX = 0; if (parsedArgs.count("rotate-x")) rotateX = us::any_cast(parsedArgs["rotate-x"]); float rotateY = 0; if (parsedArgs.count("rotate-y")) rotateY = us::any_cast(parsedArgs["rotate-y"]); float rotateZ = 0; if (parsedArgs.count("rotate-z")) rotateZ = us::any_cast(parsedArgs["rotate-z"]); float scaleX = 0; if (parsedArgs.count("scale-x")) scaleX = us::any_cast(parsedArgs["scale-x"]); float scaleY = 0; if (parsedArgs.count("scale-y")) scaleY = us::any_cast(parsedArgs["scale-y"]); float scaleZ = 0; if (parsedArgs.count("scale-z")) scaleZ = us::any_cast(parsedArgs["scale-z"]); float translateX = 0; if (parsedArgs.count("translate-x")) translateX = us::any_cast(parsedArgs["translate-x"]); float translateY = 0; if (parsedArgs.count("translate-y")) translateY = us::any_cast(parsedArgs["translate-y"]); float translateZ = 0; if (parsedArgs.count("translate-z")) translateZ = us::any_cast(parsedArgs["translate-z"]); string inFileName = us::any_cast(parsedArgs["input"]); string outFileName = us::any_cast(parsedArgs["outFile"]); try { mitk::FiberBundle::Pointer fib = LoadFib(inFileName); if (maxAngularDev>0) { auto filter = itk::FiberCurvatureFilter::New(); filter->SetInputFiberBundle(fib); filter->SetAngularDeviation(maxAngularDev); filter->SetDistance(10); - filter->SetRemoveFibers(true); + filter->SetRemoveFibers(remove); filter->Update(); fib = filter->GetOutputFiberBundle(); } if (minFiberLength>0) fib->RemoveShortFibers(minFiberLength); if (maxFiberLength>0) fib->RemoveLongFibers(maxFiberLength); if (smoothDist>0) fib->ResampleSpline(smoothDist); if (compress>0) fib->Compress(compress); if (axis/100==1) fib->MirrorFibers(0); if ((axis%100)/10==1) fib->MirrorFibers(1); if (axis%10==1) fib->MirrorFibers(2); if (rotateX > 0 || rotateY > 0 || rotateZ > 0){ std::cout << "Rotate " << rotateX << " " << rotateY << " " << rotateZ; fib->RotateAroundAxis(rotateX, rotateY, rotateZ); } if (translateX > 0 || translateY > 0 || translateZ > 0){ fib->TranslateFibers(translateX, translateY, translateZ); } if (scaleX > 0 || scaleY > 0 || scaleZ > 0) fib->ScaleFibers(scaleX, scaleY, scaleZ); mitk::IOUtil::SaveBaseData(fib.GetPointer(), outFileName ); } catch (itk::ExceptionObject e) { std::cout << e; return EXIT_FAILURE; } catch (std::exception e) { std::cout << e.what(); return EXIT_FAILURE; } catch (...) { std::cout << "ERROR!?!"; return EXIT_FAILURE; } return EXIT_SUCCESS; } diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp index 5c80118200..8a53615a5b 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp @@ -1,430 +1,432 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // Blueberry #include #include // Qmitk #include "QmitkMLBTView.h" #include "QmitkStdMultiWidget.h" // Qt #include #include #include #include #include #include #include #include #include #include #include #include #include #define _USE_MATH_DEFINES #include const std::string QmitkMLBTView::VIEW_ID = "org.mitk.views.mlbtview"; using namespace berry; QmitkMLBTView::QmitkMLBTView() : QmitkFunctionality() , m_Controls( 0 ) , m_MultiWidget( NULL ) { m_TrackingTimer = std::make_shared(this); m_LastLoadedForestName = "(none)"; } // Destructor QmitkMLBTView::~QmitkMLBTView() { } void QmitkMLBTView::CreateQtPartControl( QWidget *parent ) { // build up qt view, unless already done if ( !m_Controls ) { // create GUI widgets from the Qt Designer's .ui file m_Controls = new Ui::QmitkMLBTViewControls; m_Controls->setupUi( parent ); connect( m_Controls->m_StartTrainingButton, SIGNAL ( clicked() ), this, SLOT( StartTrainingThread() ) ); connect( &m_TrainingWatcher, SIGNAL ( finished() ), this, SLOT( OnTrainingThreadStop() ) ); connect( m_Controls->m_StartTrackingButton, SIGNAL ( clicked() ), this, SLOT( StartTrackingThread() ) ); connect( &m_TrackingWatcher, SIGNAL ( finished() ), this, SLOT( OnTrackingThreadStop() ) ); connect( m_Controls->m_SaveForestButton, SIGNAL ( clicked() ), this, SLOT( SaveForest() ) ); connect( m_Controls->m_LoadForestButton, SIGNAL ( clicked() ), this, SLOT( LoadForest() ) ); connect( m_TrackingTimer.get(), SIGNAL(timeout()), this, SLOT(BuildFibers()) ); connect( m_Controls->m_TimerIntervalBox, SIGNAL(valueChanged(int)), this, SLOT( ChangeTimerInterval(int) )); connect( m_Controls->m_DemoModeBox, SIGNAL(stateChanged(int)), this, SLOT( ToggleDemoMode(int) )); connect( m_Controls->m_PauseTrackingButton, SIGNAL ( clicked() ), this, SLOT( PauseTracking() ) ); connect( m_Controls->m_AbortTrackingButton, SIGNAL ( clicked() ), this, SLOT( AbortTracking() ) ); connect( m_Controls->m_AddTwButton, SIGNAL ( clicked() ), this, SLOT( AddTrainingWidget() ) ); connect( m_Controls->m_RemoveTwButton, SIGNAL ( clicked() ), this, SLOT( RemoveTrainingWidget() ) ); m_Controls->m_TrackingMaskImageBox->SetDataStorage(this->GetDataStorage()); m_Controls->m_TrackingSeedImageBox->SetDataStorage(this->GetDataStorage()); m_Controls->m_TrackingStopImageBox->SetDataStorage(this->GetDataStorage()); m_Controls->m_TrackingRawImageBox->SetDataStorage(this->GetDataStorage()); mitk::NodePredicateIsDWI::Pointer isDiffusionImage = mitk::NodePredicateIsDWI::New(); mitk::TNodePredicateDataType::Pointer isMitkImage = mitk::TNodePredicateDataType::New(); mitk::NodePredicateNot::Pointer noDiffusionImage = mitk::NodePredicateNot::New(isDiffusionImage); mitk::NodePredicateAnd::Pointer finalPredicate = mitk::NodePredicateAnd::New(isMitkImage, noDiffusionImage); mitk::NodePredicateProperty::Pointer isBinaryPredicate = mitk::NodePredicateProperty::New("binary", mitk::BoolProperty::New(true)); finalPredicate = mitk::NodePredicateAnd::New(finalPredicate, isBinaryPredicate); m_Controls->m_TrackingMaskImageBox->SetPredicate(finalPredicate); m_Controls->m_TrackingSeedImageBox->SetPredicate(finalPredicate); m_Controls->m_TrackingStopImageBox->SetPredicate(finalPredicate); m_Controls->m_TrackingRawImageBox->SetPredicate(isDiffusionImage); m_Controls->m_TrackingMaskImageBox->SetZeroEntryText("--"); m_Controls->m_TrackingSeedImageBox->SetZeroEntryText("--"); m_Controls->m_TrackingStopImageBox->SetZeroEntryText("--"); AddTrainingWidget(); UpdateGui(); } } void QmitkMLBTView::AddTrainingWidget() { std::shared_ptr tw = std::make_shared(); tw->SetDataStorage(this->GetDataStorage()); m_Controls->m_TwFrame->layout()->addWidget(tw.get()); m_TrainingWidgets.push_back(tw); } void QmitkMLBTView::RemoveTrainingWidget() { if(m_TrainingWidgets.size()>1) { m_TrainingWidgets.back().reset(); m_TrainingWidgets.pop_back(); } } void QmitkMLBTView::UpdateGui() { if (m_ForestHandler.IsForestValid()) { std::string label_text="Random forest available: "+m_LastLoadedForestName; m_Controls->statusLabel->setText( QString(label_text.c_str()) ); m_Controls->m_SaveForestButton->setEnabled(true); m_Controls->m_StartTrackingButton->setEnabled(true); } else { m_Controls->statusLabel->setText("Please load or train random forest!"); m_Controls->m_SaveForestButton->setEnabled(false); m_Controls->m_StartTrackingButton->setEnabled(false); } } void QmitkMLBTView::AbortTracking() { if (tracker.IsNotNull()) { tracker->m_AbortTracking = true; } } void QmitkMLBTView::PauseTracking() { if (tracker.IsNotNull()) { tracker->m_PauseTracking = !tracker->m_PauseTracking; } } void QmitkMLBTView::ChangeTimerInterval(int value) { m_TrackingTimer->setInterval(value); } void QmitkMLBTView::ToggleDemoMode(int state) { if (tracker.IsNotNull()) { tracker->SetDemoMode(m_Controls->m_DemoModeBox->isChecked()); tracker->m_Stop = false; } } void QmitkMLBTView::BuildFibers() { if (m_Controls->m_DemoModeBox->isChecked() && tracker.IsNotNull() && tracker->m_BuildFibersFinished) { vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); outFib->SetFiberColors(255,255,255); m_TractogramNode->SetData(outFib); m_SamplingPointsNode->SetData(tracker->m_SamplingPointset); m_AlternativePointsNode->SetData(tracker->m_AlternativePointset); mitk::RenderingManager::GetInstance()->RequestUpdateAll(); tracker->m_BuildFibersFinished = false; tracker->m_BuildFibersReady = 0; tracker->m_Stop = false; } } void QmitkMLBTView::LoadForest() { QString filename = QFileDialog::getOpenFileName(0, tr("Load Forest"), QDir::currentPath(), tr("HDF5 random forest file (*.rf)") ); if(filename.isEmpty() || filename.isNull()) return; m_ForestHandler.LoadForest( filename.toStdString() ); QFileInfo fi( filename ); m_LastLoadedForestName = QString( fi.baseName() + "." + fi.completeSuffix() ).toStdString(); UpdateGui(); } void QmitkMLBTView::StartTrackingThread() { m_TractogramNode = mitk::DataNode::New(); m_TractogramNode->SetName("MLBT Result"); m_TractogramNode->SetProperty("Fiber2DSliceThickness", mitk::FloatProperty::New(20)); m_TractogramNode->SetProperty("Fiber2DfadeEFX", mitk::BoolProperty::New(false)); m_TractogramNode->SetProperty("LineWidth", mitk::IntProperty::New(2)); m_TractogramNode->SetProperty("color",mitk::ColorProperty::New(0, 1, 1)); this->GetDataStorage()->Add(m_TractogramNode); m_SamplingPointsNode = mitk::DataNode::New(); m_SamplingPointsNode->SetName("SamplingPoints"); m_SamplingPointsNode->SetProperty("pointsize", mitk::FloatProperty::New(0.2)); m_SamplingPointsNode->SetProperty("color", mitk::ColorProperty::New(1,1,1)); mitk::PointSetShapeProperty::Pointer bla = mitk::PointSetShapeProperty::New(); bla->SetValue(8); m_SamplingPointsNode->SetProperty("Pointset.2D.shape", bla); m_SamplingPointsNode->SetProperty("Pointset.2D.distance to plane", mitk::FloatProperty::New(1.5)); m_SamplingPointsNode->SetProperty("point 2D size", mitk::FloatProperty::New(0.1)); m_SamplingPointsNode->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_SamplingPointsNode); m_AlternativePointsNode = mitk::DataNode::New(); m_AlternativePointsNode->SetName("AlternativePoints"); m_AlternativePointsNode->SetProperty("pointsize", mitk::FloatProperty::New(0.2)); m_AlternativePointsNode->SetProperty("color", mitk::ColorProperty::New(1,0,0)); m_AlternativePointsNode->SetProperty("Pointset.2D.shape", bla); m_AlternativePointsNode->SetProperty("Pointset.2D.distance to plane", mitk::FloatProperty::New(1.5)); m_AlternativePointsNode->SetProperty("point 2D size", mitk::FloatProperty::New(0.1)); m_AlternativePointsNode->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_AlternativePointsNode); QFuture future = QtConcurrent::run( this, &QmitkMLBTView::StartTracking ); m_TrackingWatcher.setFuture(future); m_TrackingThreadIsRunning = true; m_Controls->m_StartTrackingButton->setEnabled(false); m_TrackingTimer->start(m_Controls->m_TimerIntervalBox->value()); } void QmitkMLBTView::OnTrackingThreadStop() { m_TrackingThreadIsRunning = false; m_Controls->m_StartTrackingButton->setEnabled(true); vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); m_TractogramNode->SetData(outFib); m_TractogramNode->SetProperty("Fiber2DSliceThickness", mitk::FloatProperty::New(1)); if (m_Controls->m_DemoModeBox->isChecked()) { m_SamplingPointsNode->SetData(tracker->m_SamplingPointset); m_AlternativePointsNode->SetData(tracker->m_AlternativePointset); } tracker = NULL; m_TrackingTimer->stop(); mitk::RenderingManager::GetInstance()->RequestUpdateAll(); } void QmitkMLBTView::StartTracking() { if ( m_Controls->m_TrackingRawImageBox->GetSelectedNode().IsNull() || !m_ForestHandler.IsForestValid()) return; mitk::Image::Pointer dwi = dynamic_cast(m_Controls->m_TrackingRawImageBox->GetSelectedNode()->GetData()); - m_ForestHandler.AddRawData(dwi); + m_ForestHandler.AddDwi(dwi); // int numThread = itk::MultiThreader::GetGlobalDefaultNumberOfThreads(); tracker = TrackerType::New(); tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi) ); tracker->SetDemoMode(m_Controls->m_DemoModeBox->isChecked()); if (m_Controls->m_DemoModeBox->isChecked()) tracker->SetNumberOfThreads(1); if (m_Controls->m_TrackingMaskImageBox->GetSelectedNode().IsNotNull()) { mitk::Image::Pointer mask = dynamic_cast(m_Controls->m_TrackingMaskImageBox->GetSelectedNode()->GetData()); ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); mitk::CastToItkImage(mask, itkMask); tracker->SetMaskImage(itkMask); } if (m_Controls->m_TrackingSeedImageBox->GetSelectedNode().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(m_Controls->m_TrackingSeedImageBox->GetSelectedNode()->GetData()); ItkUcharImgType::Pointer itkImg = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkImg); tracker->SetSeedImage(itkImg); } if (m_Controls->m_TrackingStopImageBox->GetSelectedNode().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(m_Controls->m_TrackingStopImageBox->GetSelectedNode()->GetData()); ItkUcharImgType::Pointer itkImg = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkImg); tracker->SetStoppingRegions(itkImg); } tracker->SetSeedsPerVoxel(m_Controls->m_NumberOfSeedsBox->value()); tracker->SetStepSize(m_Controls->m_TrackingStepSizeBox->value()); tracker->SetMinTractLength(m_Controls->m_MinLengthBox->value()); tracker->SetMaxTractLength(m_Controls->m_MaxLengthBox->value()); tracker->SetAposterioriCurvCheck(m_Controls->m_Curvcheck2->isChecked()); tracker->SetRemoveWmEndFibers(false); tracker->SetAvoidStop(m_Controls->m_AvoidStop->isChecked()); tracker->SetForestHandler(m_ForestHandler); tracker->SetSamplingDistance(m_Controls->m_SamplingDistanceBox->value()); - tracker->SetNumberOfSamples(m_Controls->m_NumSamplesBox->value()); + tracker->SetDeflectionMod(m_Controls->m_DeflectionModBox->value()); tracker->SetRandomSampling(m_Controls->m_RandomSampling->isChecked()); + tracker->SetUseStopVotes(m_Controls->m_UseStopVotes->isChecked()); + tracker->SetOnlyForwardSamples(m_Controls->m_OnlyForwardSamples->isChecked()); tracker->Update(); } void QmitkMLBTView::SaveForest() { if (!m_ForestHandler.IsForestValid()) { UpdateGui(); return; } QString filename = QFileDialog::getSaveFileName(0, tr("Save Forest"), QDir::currentPath()+"/forest.rf", tr("HDF5 random forest file (*.rf)") ); if(filename.isEmpty() || filename.isNull()) return; if(!filename.endsWith(".rf")) filename += ".rf"; m_ForestHandler.SaveForest( filename.toStdString() ); } void QmitkMLBTView::StartTrainingThread() { if (!this->IsTrainingInputValid()) { QMessageBox::warning(nullptr, "Training aborted", "Training could not be started. Not all necessary datasets were selected."); return; } QFuture future = QtConcurrent::run( this, &QmitkMLBTView::StartTraining ); m_TrainingWatcher.setFuture(future); m_Controls->m_StartTrainingButton->setEnabled(false); m_Controls->m_SaveForestButton->setEnabled(false); m_Controls->m_LoadForestButton->setEnabled(false); } void QmitkMLBTView::OnTrainingThreadStop() { m_Controls->m_StartTrainingButton->setEnabled(true); m_Controls->m_SaveForestButton->setEnabled(true); m_Controls->m_LoadForestButton->setEnabled(true); UpdateGui(); } void QmitkMLBTView::StartTraining() { std::vector< mitk::Image::Pointer > m_SelectedDiffImages; std::vector< mitk::FiberBundle::Pointer > m_SelectedFB; std::vector< ItkUcharImgType::Pointer > m_MaskImages; std::vector< ItkUcharImgType::Pointer > m_WhiteMatterImages; for (auto w : m_TrainingWidgets) { m_SelectedDiffImages.push_back(dynamic_cast(w->GetImage()->GetData())); m_SelectedFB.push_back(dynamic_cast(w->GetFibers()->GetData())); if (w->GetMask().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(w->GetMask()->GetData()); ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkMask); m_MaskImages.push_back(itkMask); } else m_MaskImages.push_back(nullptr); if (w->GetWhiteMatter().IsNotNull()) { mitk::Image::Pointer img = dynamic_cast(w->GetWhiteMatter()->GetData()); ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); mitk::CastToItkImage(img, itkMask); m_WhiteMatterImages.push_back(itkMask); } else m_WhiteMatterImages.push_back(nullptr); } - m_ForestHandler.SetRawData(m_SelectedDiffImages); + m_ForestHandler.SetDwis(m_SelectedDiffImages); m_ForestHandler.SetTractograms(m_SelectedFB); m_ForestHandler.SetMaskImages(m_MaskImages); m_ForestHandler.SetWhiteMatterImages(m_WhiteMatterImages); m_ForestHandler.SetNumTrees(m_Controls->m_NumTreesBox->value()); m_ForestHandler.SetMaxTreeDepth(m_Controls->m_MaxDepthBox->value()); m_ForestHandler.SetGrayMatterSamplesPerVoxel(m_Controls->m_GmSamplingBox->value()); m_ForestHandler.SetSampleFraction(m_Controls->m_SampleFractionBox->value()); m_ForestHandler.SetStepSize(m_Controls->m_TrainingStepSizeBox->value()); m_ForestHandler.StartTraining(); } void QmitkMLBTView::StdMultiWidgetAvailable (QmitkStdMultiWidget &stdMultiWidget) { m_MultiWidget = &stdMultiWidget; } void QmitkMLBTView::StdMultiWidgetNotAvailable() { m_MultiWidget = NULL; } void QmitkMLBTView::Activated() { } bool QmitkMLBTView::IsTrainingInputValid(void) const { for (auto widget : m_TrainingWidgets) { if (widget->GetImage().IsNull() || widget->GetFibers().IsNull()) { return false; } } return true; } diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h index 4b34f9b68b..cf3d7690c0 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h @@ -1,112 +1,112 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include "ui_QmitkMLBTViewControls.h" #ifndef Q_MOC_RUN #include "mitkDataStorage.h" #include "mitkDataStorageSelection.h" #include #include #endif #include #include #include #include #include #include /*! \brief View to perform machine learning based fiber tractography. Includes training of the random forst classifier as well as the actual tractography. */ // Forward Qt class declarations class QmitkMLBTView : public QmitkFunctionality { // this is needed for all Qt objects that should have a Qt meta-object // (everything that derives from QObject and wants to have signal/slots) Q_OBJECT public: static const std::string VIEW_ID; typedef itk::Image ItkUcharImgType; - typedef itk::MLBSTrackingFilter<> TrackerType; + typedef itk::MLBSTrackingFilter<6,100> TrackerType; QmitkMLBTView(); virtual ~QmitkMLBTView(); virtual void CreateQtPartControl(QWidget *parent) override; virtual void StdMultiWidgetAvailable (QmitkStdMultiWidget &stdMultiWidget) override; virtual void StdMultiWidgetNotAvailable() override; virtual void Activated() override; protected slots: void StartTrackingThread(); void OnTrackingThreadStop(); void StartTrainingThread(); void OnTrainingThreadStop(); void SaveForest(); void LoadForest(); void BuildFibers(); void ChangeTimerInterval(int value); void ToggleDemoMode(int state); void PauseTracking(); void AbortTracking(); void AddTrainingWidget(); void RemoveTrainingWidget(); protected: void StartTracking(); void StartTraining(); void UpdateGui(); Ui::QmitkMLBTViewControls* m_Controls; QmitkStdMultiWidget* m_MultiWidget; - mitk::TrackingForestHandler<> m_ForestHandler; + mitk::TrackingForestHandler<6,100> m_ForestHandler; QFutureWatcher m_TrainingWatcher; QFutureWatcher m_TrackingWatcher; bool m_TrackingThreadIsRunning; TrackerType::Pointer tracker; std::shared_ptr m_TrackingTimer; mitk::DataNode::Pointer m_TractogramNode; mitk::DataNode::Pointer m_SamplingPointsNode; mitk::DataNode::Pointer m_AlternativePointsNode; std::vector< std::shared_ptr > m_TrainingWidgets; private: bool IsTrainingInputValid(void) const; std::string m_LastLoadedForestName; }; diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui index 23a0ba0b40..44ead02172 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui @@ -1,762 +1,785 @@ QmitkMLBTViewControls 0 0 696 907 Form Please load or train random forest! Qt::AlignCenter 1 0 0 678 - 804 + 813 Training Save Forest QFrame::NoFrame QFrame::Raised 0 0 0 0 Add additional training data pair ... :/org_mitk_icons/icons/tango/scalable/actions/list-add.svg:/org_mitk_icons/icons/tango/scalable/actions/list-add.svg Remove training data pair ... :/org_mitk_icons/icons/tango/scalable/actions/list-remove.svg:/org_mitk_icons/icons/tango/scalable/actions/list-remove.svg QFrame::NoFrame QFrame::Raised 0 0 0 0 0 6 QFrame::NoFrame QFrame::Raised 0 0 0 0 0 Input DWI: Qt::AlignCenter Reference Tractogram: Qt::AlignCenter Mask: Qt::AlignCenter WM: Qt::AlignCenter QFrame::NoFrame QFrame::Raised 0 0 0 0 Maximum tree depth. 1 999999999 - 30 + 25 Non-WM Sampling Points: Fiber Sampling: Number of tress in the final random forest. 1 999999999 - 50 + 30 Fiber sampling in mm. Determines the number of white-matter samples (-1 = auto). 3 -1.000000000000000 999.000000000000000 0.100000000000000 -1.000000000000000 Num. Trees: Number of sampling points outside of the white-matter (-1 = automatic estimation). -1 999999999 -1 Max. Depth: Sample Fraction: Fraction of samples used to train each tree. 3 1.000000000000000 0.100000000000000 - 1.000000000000000 + 0.700000000000000 Start Training Qt::Vertical 20 40 0 0 678 - 804 + 813 Tractography - - + + + + Random sampling + + + false + + + + + + + Secondary curvature check + + + false + + + + + QFrame::NoFrame QFrame::Raised - + 0 0 0 0 - - + + + + + + + + + + + - Demo Mode + Mask Image: - - - - 1 - - - 1000 + + + + Seed Image: - - 10 + + + + + + Stop Image: - - + + - Random sampling + Avoid premature termination - false + true - - + + QFrame::NoFrame QFrame::Raised - + 0 0 0 0 - - - - - - - - + + + + Pause tractography + + + ... + + + + :/org_mitk_icons/icons/tango/scalable/actions/media-playback-pause.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-pause.svg + + - - + + + + Start tractography + - Mask Image: + ... + + + + :/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg - - + + + + Abort tractography + - Seed Image: + ... + + + + :/org_mitk_icons/icons/tango/scalable/actions/media-playback-stop.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-stop.svg - - + + + + + + + Load Forest + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + - Stop Image: + Demo Mode + + + + + + + 1 + + + 1000 + + + 10 + + + + Use stop votes + + + true + + + QFrame::NoFrame QFrame::Raised 0 0 0 0 Number of seeds per voxel. 1 999 999999999.000000000000000 1.000000000000000 400.000000000000000 999999999.000000000000000 1.000000000000000 20.000000000000000 - - - - 999999999 - - - 50 - - - Sampling Distance: Min. Length Num. Seeds: 0.100000000000000 0.500000000000000 Max. Length Step Size: Input DWI: - Num. Samples: + Deflection distance modifier: 0.500000000000000 - - - - - - - Avoid premature termination - - - true - - - - - - - QFrame::NoFrame - - - QFrame::Raised - - - - 0 - - - 0 - - - 0 - - - 0 - - - - - Pause tractography - - - ... - - - - :/org_mitk_icons/icons/tango/scalable/actions/media-playback-pause.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-pause.svg - - - - - - - Start tractography - - - ... - - - - :/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg - - - - - - - Abort tractography + + + + 1.000000000000000 - - ... + + 0.100000000000000 - - - :/org_mitk_icons/icons/tango/scalable/actions/media-playback-stop.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-stop.svg + + 1.000000000000000 - - - - Qt::Vertical - - - - 20 - 40 - - - - - - + + - Secondary curvature check + Only frontal samples true - - - - Load Forest - - - QmitkDataStorageComboBox QComboBox
QmitkDataStorageComboBox.h
QmitkDataStorageComboBoxWithSelectNone QComboBox
QmitkDataStorageComboBoxWithSelectNone.h