diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp index 256952705b..1f43985c82 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp @@ -1,782 +1,796 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_txx #define __itkMLBSTrackingFilter_txx #include #include #include #include "itkMLBSTrackingFilter.h" #include #include #include #include #include #define _USE_MATH_DEFINES #include namespace itk { template< int NumImageFeatures > MLBSTrackingFilter< NumImageFeatures > ::MLBSTrackingFilter() : m_FiberPolyData(NULL) , m_Points(NULL) , m_Cells(NULL) , m_AngularThreshold(0.7) , m_StepSize(0) , m_MaxLength(10000) , m_MinTractLength(20.0) , m_MaxTractLength(400.0) , m_SeedsPerVoxel(1) , m_UseDirection(true) , m_NumberOfSamples(50) , m_SamplingDistance(-1) , m_SeedImage(NULL) , m_MaskImage(NULL) , m_DecisionForest(NULL) , m_StoppingRegions(NULL) , m_DemoMode(false) , m_PauseTracking(false) , m_AbortTracking(false) + , m_RemoveWmEndFibers(false) + , m_AposterioriCurvCheck(false) + , m_AvoidStop(true) { this->SetNumberOfRequiredInputs(1); } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures > ::RoundToNearest(double num) { return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::BeforeThreadedGenerateData() { m_InputImage = const_cast(this->GetInput(0)); PreprocessRawData(); m_FiberPolyData = PolyDataType::New(); m_Points = vtkSmartPointer< vtkPoints >::New(); m_Cells = vtkSmartPointer< vtkCellArray >::New(); m_ImageSize.resize(3); m_ImageSize[0] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[0]; m_ImageSize[1] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[1]; m_ImageSize[2] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[2]; m_ImageSpacing.resize(3); m_ImageSpacing[0] = m_FeatureImage->GetSpacing()[0]; m_ImageSpacing[1] = m_FeatureImage->GetSpacing()[1]; m_ImageSpacing[2] = m_FeatureImage->GetSpacing()[2]; double minSpacing; if(m_ImageSpacing[0]GetNumberOfThreads(); i++) { PolyDataType poly = PolyDataType::New(); m_PolyDataContainer.push_back(poly); } m_NotWmImage = ItkDoubleImgType::New(); m_NotWmImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_NotWmImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_NotWmImage->SetDirection( m_FeatureImage->GetDirection() ); m_NotWmImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_NotWmImage->Allocate(); m_NotWmImage->FillBuffer(0); m_WmImage = ItkDoubleImgType::New(); m_WmImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_WmImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_WmImage->SetDirection( m_FeatureImage->GetDirection() ); m_WmImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_WmImage->Allocate(); m_WmImage->FillBuffer(0); m_AvoidStopImage = ItkDoubleImgType::New(); m_AvoidStopImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_AvoidStopImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_AvoidStopImage->SetDirection( m_FeatureImage->GetDirection() ); m_AvoidStopImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_AvoidStopImage->Allocate(); m_AvoidStopImage->FillBuffer(0); if (m_StoppingRegions.IsNull()) { m_StoppingRegions = ItkUcharImgType::New(); m_StoppingRegions->SetSpacing( m_FeatureImage->GetSpacing() ); m_StoppingRegions->SetOrigin( m_FeatureImage->GetOrigin() ); m_StoppingRegions->SetDirection( m_FeatureImage->GetDirection() ); m_StoppingRegions->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_StoppingRegions->Allocate(); m_StoppingRegions->FillBuffer(0); } if (m_SeedImage.IsNull()) { m_SeedImage = ItkUcharImgType::New(); m_SeedImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_SeedImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_SeedImage->SetDirection( m_FeatureImage->GetDirection() ); m_SeedImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_SeedImage->Allocate(); m_SeedImage->FillBuffer(1); } if (m_MaskImage.IsNull()) { // initialize mask image m_MaskImage = ItkUcharImgType::New(); m_MaskImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_MaskImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_MaskImage->SetDirection( m_FeatureImage->GetDirection() ); m_MaskImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_MaskImage->Allocate(); m_MaskImage->FillBuffer(1); } else std::cout << "MLBSTrackingFilter: using mask image" << std::endl; if (m_AngularThreshold<0.0) m_AngularThreshold = 0.5*minSpacing; m_BuildFibersReady = 0; m_BuildFibersFinished = false; m_Threads = 0; m_Tractogram.clear(); std::cout << "MLBSTrackingFilter: Angular threshold: " << m_AngularThreshold << std::endl; std::cout << "MLBSTrackingFilter: Stepsize: " << m_StepSize << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Seeds per voxel: " << m_SeedsPerVoxel << std::endl; std::cout << "MLBSTrackingFilter: Max. sampling distance: " << m_SamplingDistance << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Number of samples: " << m_NumberOfSamples << std::endl; std::cout << "MLBSTrackingFilter: Max. tract length: " << m_MaxTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Min. tract length: " << m_MinTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Starting streamline tracking using " << this->GetNumberOfThreads() << " threads." << std::endl; } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::PreprocessRawData() { typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; std::cout << "MLBSTrackingFilter: Spherical signal interpolation and sampling ..." << std::endl; typename InterpolationFilterType::Pointer filter = InterpolationFilterType::New(); filter->SetGradientImage( m_GradientDirections, m_InputImage ); filter->SetBValue( m_B_Value ); filter->SetLambda(0.006); filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); filter->Update(); // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); // featureImageVector.push_back(itkFeatureImage); std::cout << "MLBSTrackingFilter: Creating feature image ..." << std::endl; vnl_vector_fixed ref; ref.fill(0); ref[0]=1; itk::OrientationDistributionFunction< double, NumImageFeatures*2 > odf; m_DirectionIndices.clear(); for (unsigned int f=0; f0) // only used directions on one hemisphere m_DirectionIndices.push_back(f); } m_FeatureImage = FeatureImageType::New(); m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection()); m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->Allocate(); itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { typename FeatureImageType::PixelType pix; for (unsigned int f=0; fSetPixel(it.GetIndex(), pix); ++it; } } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir) { // vnl_matrix_fixed< double, 3, 3 > rot = m_FeatureImage->GetDirection().GetTranspose(); // dir = rot*dir; dir *= m_StepSize; pos[0] += dir[0]; pos[1] += dir[1]; pos[2] += dir[2]; } template< int NumImageFeatures > bool MLBSTrackingFilter< NumImageFeatures > ::IsValidPosition(itk::Point &pos) { typename FeatureImageType::IndexType idx; m_FeatureImage->TransformPhysicalPointToIndex(pos, idx); if (!m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) || m_MaskImage->GetPixel(idx)==0) return false; return true; } template< int NumImageFeatures > typename MLBSTrackingFilter< NumImageFeatures >::FeatureImageType::PixelType MLBSTrackingFilter< NumImageFeatures >::GetImageValues(itk::Point itkP) { itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx); m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx); typename FeatureImageType::PixelType pix; pix.Fill(0.0); if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) ) pix = m_FeatureImage->GetPixel(idx); else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1) { vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = m_FeatureImage->GetPixel(idx) * interpWeights[0]; typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7]; } } template< int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob, bool avoidStop) { vnl_vector_fixed direction; direction.fill(0); vigra::MultiArray<2, double> featureData; if(m_UseDirection) featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumImageFeatures+3) ); else featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumImageFeatures) ); typename FeatureImageType::PixelType featurePixel = GetImageValues(pos); // pixel values for (unsigned int f=0; f ref; ref.fill(0); ref[0]=1; for (unsigned int f=NumImageFeatures; f probs(vigra::Shape2(1, m_DecisionForest->class_count())); m_DecisionForest->predictProbabilities(featureData, probs); double outProb = 0; prob = 0; candidates = 0; // directions with probability > 0 for (int i=0; iclass_count(); i++) { if (probs(0,i)>0) { int classLabel = 0; m_DecisionForest->ext_param_.to_classlabel(i, classLabel); if (classLabel d = m_ODF.GetDirection(m_DirectionIndices.at(classLabel)); double dot = dot_product(d, olddir); if (olddir.magnitude()>0) { if (fabs(dot)>angularThreshold) { if (dot<0) d *= -1; dot = fabs(dot); direction += probs(0,i)*dot*d; prob += probs(0,i)*dot; } } else { direction += probs(0,i)*d; prob += probs(0,i); } } else outProb += probs(0,i); } } ItkDoubleImgType::IndexType idx; m_NotWmImage->TransformPhysicalPointToIndex(pos, idx); if (m_NotWmImage->GetLargestPossibleRegion().IsInside(idx)) { m_NotWmImage->SetPixel(idx, m_NotWmImage->GetPixel(idx)+outProb); m_WmImage->SetPixel(idx, m_WmImage->GetPixel(idx)+prob); } if (outProb>prob && prob>0) { candidates = 0; prob = 0; direction.fill(0.0); } if (avoidStop && m_AvoidStopImage->GetLargestPossibleRegion().IsInside(idx) && candidates>0 && direction.magnitude()>0.001) m_AvoidStopImage->SetPixel(idx, m_AvoidStopImage->GetPixel(idx)+0.1); return direction; } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures >::GetRandDouble(double min, double max) { return (double)(rand()%((int)(10000*(max-min))) + 10000*min)/10000; } template< int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::GetNewDirection(itk::Point &pos, vnl_vector_fixed& olddir) { vnl_vector_fixed direction; direction.fill(0); ItkUcharImgType::IndexType idx; m_StoppingRegions->TransformPhysicalPointToIndex(pos, idx); if (m_StoppingRegions->GetPixel(idx)>0) return direction; if (olddir.magnitude()>0) olddir.normalize(); int candidates = 0; // number of directions with probability > 0 double prob = 0; direction = Classify(pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood direction *= prob; + + itk::OrientationDistributionFunction< double, 50 > probeVecs; + for (int i=0; i probe; - probe[0] = GetRandDouble()*m_SamplingDistance; - probe[1] = GetRandDouble()*m_SamplingDistance; - probe[2] = GetRandDouble()*m_SamplingDistance; +// probe[0] = GetRandDouble()*m_SamplingDistance; +// probe[1] = GetRandDouble()*m_SamplingDistance; +// probe[2] = GetRandDouble()*m_SamplingDistance; + + probe = probeVecs.GetDirection(i)*m_SamplingDistance; itk::Point temp; temp[0] = pos[0] + probe[0]; temp[1] = pos[1] + probe[1]; temp[2] = pos[2] + probe[2]; candidates = 0; vnl_vector_fixed tempDir = Classify(temp, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) { direction += tempDir*prob; } - else if (candidates==0 && olddir.magnitude()>0) // out of white matter + else if (m_AvoidStop && candidates==0 && olddir.magnitude()>0) // out of white matter { vnl_vector_fixed normProbe = -probe; normProbe.normalize(); double dot = dot_product(normProbe, olddir); if (dot < 0.0) { probe = (normProbe - 2 * dot*olddir)*probe.magnitude(); // reflect } else { probe = -probe; // invert } // look a bit further into the other direction temp[0] = pos[0] + probe[0]*2; temp[1] = pos[1] + probe[1]*2; temp[2] = pos[2] + probe[2]*2; candidates = 0; vnl_vector_fixed tempDir = Classify(temp, candidates, olddir, m_AngularThreshold, prob, true); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? { direction += probe; // go into the direction of the white matter direction += tempDir*prob; // go into the direction of the white matter direction at this location } } } if (direction.magnitude()>0.001) { direction.normalize(); olddir[0] = direction[0]; olddir[1] = direction[1]; olddir[2] = direction[2]; } else direction.fill(0); return direction; } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures >::FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front) { vnl_vector_fixed dirOld = dir; dirOld = dir; for (int step=0; step< m_MaxLength/2; step++) { while (m_PauseTracking){} if (m_DemoMode) { m_Mutex.Lock(); m_BuildFibersReady++; m_Tractogram.push_back(*fib); BuildFibers(true); m_Stop = true; m_Mutex.Unlock(); while (m_Stop){} } // get new position CalculateNewPosition(pos, dir); // is new position inside of image and mask if (!IsValidPosition(pos) || m_AbortTracking) // if not end streamline { return tractLength; } else // if yes, add new point to streamline { tractLength += m_StepSize; if (front) fib->push_front(pos); else fib->push_back(pos); - int curv = CheckCurvature(fib, front); // TODO: Move into classification ??? - if (curv>0) + if (m_AposterioriCurvCheck) { - tractLength -= m_StepSize*curv; - while (curv>0) + int curv = CheckCurvature(fib, front); // TODO: Move into classification ??? + if (curv>0) { - if (front) - fib->pop_front(); - else - fib->pop_back(); - curv--; + tractLength -= m_StepSize*curv; + while (curv>0) + { + if (front) + fib->pop_front(); + else + fib->pop_back(); + curv--; + } + return tractLength; } - return tractLength; } if (tractLength>m_MaxTractLength) return tractLength; } dir = GetNewDirection(pos, dirOld); if (dir.magnitude()<0.0001) return tractLength; } return tractLength; } template< int NumImageFeatures > int MLBSTrackingFilter::CheckCurvature(FiberType* fib, bool front) { double m_Distance = 5; if (fib->size()<3) return 0; double dist = 0; std::vector< vnl_vector_fixed< float, 3 > > vectors; vnl_vector_fixed< float, 3 > meanV; meanV.fill(0); double dev = 0; if (front) { int c=0; while(distsize()-1) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c+1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==0) meanV += v; c++; } } else { int c=fib->size()-1; while(dist0) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c-1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==fib->size()-1) meanV += v; c--; } } meanV.normalize(); for (int c=0; c1.0) angle = 1.0; if (angle<-1.0) angle = -1.0; dev += acos(angle)*180/M_PI; } if (vectors.size()>0) dev /= vectors.size(); if (dev<30) return 0; else return vectors.size(); } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::ThreadedGenerateData(const InputImageRegionType ®ionForThread, ThreadIdType threadId) { m_Mutex.Lock(); m_Threads++; m_Mutex.Unlock(); typedef ImageRegionConstIterator< ItkUcharImgType > MaskIteratorType; MaskIteratorType sit(m_SeedImage, regionForThread ); MaskIteratorType mit(m_MaskImage, regionForThread ); sit.GoToBegin(); mit.GoToBegin(); itk::Point worldPos; while( !sit.IsAtEnd() ) { if (sit.Value()==0 || mit.Value()==0) { ++sit; ++mit; continue; } for (int s=0; s start; unsigned int counter = 0; if (m_SeedsPerVoxel>1) { start[0] = index[0]+GetRandDouble(-0.5, 0.5); start[1] = index[1]+GetRandDouble(-0.5, 0.5); start[2] = index[2]+GetRandDouble(-0.5, 0.5); } else { start[0] = index[0]; start[1] = index[1]; start[2] = index[2]; } // get staring position m_SeedImage->TransformContinuousIndexToPhysicalPoint( start, worldPos ); // get starting direction int candidates = 0; double prob = 0; vnl_vector_fixed dirOld; dirOld.fill(0.0); vnl_vector_fixed dir = Classify(worldPos, candidates, dirOld, 0, prob); if (dir.magnitude()<0.0001) continue; // forward tracking tractLength = FollowStreamline(threadId, worldPos, dir, &fib, 0, false); fib.push_front(worldPos); + + if (m_RemoveWmEndFibers) { itk::Point check = fib.back(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } // backward tracking tractLength = FollowStreamline(threadId, worldPos, -dir, &fib, tractLength, true); counter = fib.size(); + if (m_RemoveWmEndFibers) { itk::Point check = fib.front(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } if (tractLength void MLBSTrackingFilter< NumImageFeatures >::BuildFibers(bool check) { if (m_BuildFibersReady::New(); vtkSmartPointer vNewLines = vtkSmartPointer::New(); vtkSmartPointer vNewPoints = vtkSmartPointer::New(); for (int i=0; i container = vtkSmartPointer::New(); FiberType fib = m_Tractogram.at(i); for (FiberType::iterator it = fib.begin(); it!=fib.end(); it++) { vtkIdType id = vNewPoints->InsertNextPoint((*it).GetDataPointer()); container->GetPointIds()->InsertNextId(id); } vNewLines->InsertNextCell(container); } if (check) for (int i=0; iSetPoints(vNewPoints); m_FiberPolyData->SetLines(vNewLines); m_BuildFibersFinished = true; } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::AfterThreadedGenerateData() { MITK_INFO << "Generating polydata "; BuildFibers(false); MITK_INFO << "done"; } } #endif // __itkDiffusionQballPrincipleDirectionsImageFilter_txx diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h index d3fc9fecf5..5e1a62d0e2 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h @@ -1,184 +1,191 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ /*=================================================================== This file is based heavily on a corresponding ITK filter. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_h_ #define __itkMLBSTrackingFilter_h_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include // classification includes #include #include #include namespace itk{ /** * \brief Performes deterministic streamline tracking on the input tensor image. */ template< int NumImageFeatures=100 > class MLBSTrackingFilter : public ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > { public: typedef MLBSTrackingFilter Self; typedef SmartPointer Pointer; typedef SmartPointer ConstPointer; typedef ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > Superclass; typedef vigra::RandomForest DecisionForestType; typedef typename Superclass::InputImageType InputImageType; typedef typename Superclass::InputImageRegionType InputImageRegionType; typedef Image< Vector< float, NumImageFeatures > , 3 > FeatureImageType; /** Method for creation through the object factory. */ itkFactorylessNewMacro(Self) itkCloneMacro(Self) /** Runtime information support. */ itkTypeMacro(MLBSTrackingFilter, ImageToImageFilter) typedef itk::Image ItkUcharImgType; typedef itk::Image ItkDoubleImgType; typedef itk::Image ItkFloatImgType; typedef vtkSmartPointer< vtkPolyData > PolyDataType; typedef std::deque< itk::Point > FiberType; typedef std::vector< FiberType > BundleType; bool m_PauseTracking; bool m_AbortTracking; bool m_BuildFibersFinished; int m_BuildFibersReady; bool m_Stop; // void RequestFibers(){ m_Stop=true; m_BuildFibersReady=0; m_BuildFibersFinished=false; } itkGetMacro( FiberPolyData, PolyDataType ) ///< Output fibers itkSetMacro( SeedImage, ItkUcharImgType::Pointer) ///< Seeds are only placed inside of this mask. itkSetMacro( MaskImage, ItkUcharImgType::Pointer) ///< Tracking is only performed inside of this mask image. itkSetMacro( SeedsPerVoxel, int) ///< One seed placed in the center of each voxel or multiple seeds randomly placed inside each voxel. itkSetMacro( StepSize, double) ///< Integration step size in mm itkSetMacro( MinTractLength, double ) ///< Shorter tracts are discarded. itkSetMacro( MaxTractLength, double ) itkSetMacro( AngularThreshold, double ) itkSetMacro( UseDirection, bool ) itkSetMacro( SamplingDistance, double ) itkSetMacro( NumberOfSamples, int ) itkSetMacro( StoppingRegions, ItkUcharImgType::Pointer) itkSetMacro( B_Value, float ) itkSetMacro( GradientDirections, mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer ) itkSetMacro( DemoMode, bool ) + itkSetMacro( RemoveWmEndFibers, bool ) + itkSetMacro( AposterioriCurvCheck, bool ) + itkSetMacro( AvoidStop, bool ) void SetDecisionForest( DecisionForestType* forest ) { m_DecisionForest = forest; } itkGetMacro( WmImage, ItkDoubleImgType::Pointer ) itkGetMacro( NotWmImage, ItkDoubleImgType::Pointer ) itkGetMacro( AvoidStopImage, ItkDoubleImgType::Pointer ) protected: MLBSTrackingFilter(); ~MLBSTrackingFilter() {} void CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir); ///< Calculate next integration step. double FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front); ///< Start streamline in one direction. bool IsValidPosition(itk::Point& pos); ///< Are we outside of the mask image? vnl_vector_fixed GetNewDirection(itk::Point& pos, vnl_vector_fixed& olddir); vnl_vector_fixed Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob, bool avoidStop=false); typename FeatureImageType::PixelType GetImageValues(itk::Point itkP); double GetRandDouble(double min=-1, double max=1); double RoundToNearest(double num); void BeforeThreadedGenerateData(); void PreprocessRawData(); void ThreadedGenerateData( const InputImageRegionType &outputRegionForThread, ThreadIdType threadId); void AfterThreadedGenerateData(); PolyDataType m_FiberPolyData; vtkSmartPointer m_Points; vtkSmartPointer m_Cells; BundleType m_Tractogram; double m_AngularThreshold; double m_StepSize; int m_MaxLength; double m_MinTractLength; double m_MaxTractLength; int m_SeedsPerVoxel; bool m_UseDirection; double m_SamplingDistance; int m_NumberOfSamples; std::vector< int > m_ImageSize; std::vector< double > m_ImageSpacing; SimpleFastMutexLock m_Mutex; ItkUcharImgType::Pointer m_StoppingRegions; ItkDoubleImgType::Pointer m_WmImage; ItkDoubleImgType::Pointer m_NotWmImage; ItkDoubleImgType::Pointer m_AvoidStopImage; ItkUcharImgType::Pointer m_SeedImage; ItkUcharImgType::Pointer m_MaskImage; typename FeatureImageType::Pointer m_FeatureImage; typename InputImageType::Pointer m_InputImage; mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer m_GradientDirections; float m_B_Value; + bool m_AposterioriCurvCheck; + bool m_RemoveWmEndFibers; + bool m_AvoidStop; + int m_Threads; bool m_DemoMode; void BuildFibers(bool check); int CheckCurvature(FiberType* fib, bool front); // decision forest DecisionForestType* m_DecisionForest; - itk::OrientationDistributionFunction< double, NumImageFeatures*2 > m_ODF; + itk::OrientationDistributionFunction< double, NumImageFeatures*2 > m_ODF; std::vector< int > m_DirectionIndices; std::vector< PolyDataType > m_PolyDataContainer; private: }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkMLBSTrackingFilter.cpp" #endif #endif //__itkMLBSTrackingFilter_h_ diff --git a/Modules/DiffusionImaging/MiniApps/CMakeLists.txt b/Modules/DiffusionImaging/MiniApps/CMakeLists.txt index 0c801dc612..84dd09cf5b 100755 --- a/Modules/DiffusionImaging/MiniApps/CMakeLists.txt +++ b/Modules/DiffusionImaging/MiniApps/CMakeLists.txt @@ -1,117 +1,128 @@ option(BUILD_DiffusionMiniApps "Build commandline tools for diffusion" OFF) if(BUILD_DiffusionMiniApps OR MITK_BUILD_ALL_APPS) + + find_package(OpenMP) + if(NOT OPENMP_FOUND) + message("OpenMP is not available.") + endif() + if(OPENMP_FOUND) + message(STATUS "Found OpenMP.") + set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") + endif() + # needed include directories include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ) # list of diffusion miniapps # if an app requires additional dependencies # they are added after a "^^" and separated by "_" set( diffusionminiapps DwiDenoising^^ ImageResampler^^ NetworkCreation^^MitkFiberTracking_MitkConnectomics NetworkStatistics^^MitkConnectomics ExportShImage^^ Fiberfox^^MitkFiberTracking MultishellMethods^^MitkFiberTracking PeaksAngularError^^MitkFiberTracking PeakExtraction^^MitkFiberTracking FiberExtraction^^MitkFiberTracking FiberProcessing^^MitkFiberTracking FiberDirectionExtraction^^MitkFiberTracking LocalDirectionalFiberPlausibility^^MitkFiberTracking StreamlineTracking^^MitkFiberTracking GibbsTracking^^MitkFiberTracking CopyGeometry^^ DiffusionIndices^^ TractometerMetrics^^MitkFiberTracking QballReconstruction^^ Registration^^ FileFormatConverter^^MitkFiberTracking TensorReconstruction^^ TensorDerivedMapsExtraction^^ DICOMLoader^^ DFTraining^^MitkFiberTracking DFTracking^^MitkFiberTracking ) foreach(diffusionminiapp ${diffusionminiapps}) # extract mini app name and dependencies string(REPLACE "^^" "\\;" miniapp_info ${diffusionminiapp}) set(miniapp_info_list ${miniapp_info}) list(GET miniapp_info_list 0 appname) list(GET miniapp_info_list 1 raw_dependencies) string(REPLACE "_" "\\;" dependencies "${raw_dependencies}") set(dependencies_list ${dependencies}) mitk_create_executable(${appname} DEPENDS MitkCore MitkDiffusionCore ${dependencies_list} PACKAGE_DEPENDS ITK CPP_FILES ${appname}.cpp mitkCommandLineParser.cpp ) if(EXECUTABLE_IS_ENABLED) # On Linux, create a shell script to start a relocatable application if(UNIX AND NOT APPLE) install(PROGRAMS "${MITK_SOURCE_DIR}/CMake/RunInstalledApp.sh" DESTINATION "." RENAME ${EXECUTABLE_TARGET}.sh) endif() get_target_property(_is_bundle ${EXECUTABLE_TARGET} MACOSX_BUNDLE) if(APPLE) if(_is_bundle) set(_target_locations ${EXECUTABLE_TARGET}.app) set(${_target_locations}_qt_plugins_install_dir ${EXECUTABLE_TARGET}.app/Contents/MacOS) set(_bundle_dest_dir ${EXECUTABLE_TARGET}.app/Contents/MacOS) set(_qt_plugins_for_current_bundle ${EXECUTABLE_TARGET}.app/Contents/MacOS) set(_qt_conf_install_dirs ${EXECUTABLE_TARGET}.app/Contents/Resources) install(TARGETS ${EXECUTABLE_TARGET} BUNDLE DESTINATION . ) else() if(NOT MACOSX_BUNDLE_NAMES) set(_qt_conf_install_dirs bin) set(_target_locations bin/${EXECUTABLE_TARGET}) set(${_target_locations}_qt_plugins_install_dir bin) install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION bin) else() foreach(bundle_name ${MACOSX_BUNDLE_NAMES}) list(APPEND _qt_conf_install_dirs ${bundle_name}.app/Contents/Resources) set(_current_target_location ${bundle_name}.app/Contents/MacOS/${EXECUTABLE_TARGET}) list(APPEND _target_locations ${_current_target_location}) set(${_current_target_location}_qt_plugins_install_dir ${bundle_name}.app/Contents/MacOS) message( " set(${_current_target_location}_qt_plugins_install_dir ${bundle_name}.app/Contents/MacOS) ") install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION ${bundle_name}.app/Contents/MacOS/) endforeach() endif() endif() else() set(_target_locations bin/${EXECUTABLE_TARGET}${CMAKE_EXECUTABLE_SUFFIX}) set(${_target_locations}_qt_plugins_install_dir bin) set(_qt_conf_install_dirs bin) install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION bin) endif() endif() endforeach() # This mini app does not depend on mitkDiffusionImaging at all mitk_create_executable(Dicom2Nrrd DEPENDS MitkCore CPP_FILES Dicom2Nrrd.cpp mitkCommandLineParser.cpp ) # On Linux, create a shell script to start a relocatable application if(UNIX AND NOT APPLE) install(PROGRAMS "${MITK_SOURCE_DIR}/CMake/RunInstalledApp.sh" DESTINATION "." RENAME ${EXECUTABLE_TARGET}.sh) endif() if(EXECUTABLE_IS_ENABLED) MITK_INSTALL_TARGETS(EXECUTABLES ${EXECUTABLE_TARGET}) endif() endif() diff --git a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp index dc9db5d88e..b8ff87d2d4 100755 --- a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp +++ b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp @@ -1,196 +1,199 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include //#include #include #include #include #include #include #include //#include #include #include #include #include #define _USE_MATH_DEFINES #include const int numOdfSamples = 200; typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Machine Learning Based Streamline Tractography"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription(""); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("image", "i", mitkCommandLineParser::String, "DWIs:", "input diffusion-weighted image", us::Any(), false); parser.addArgument("forest", "f", mitkCommandLineParser::String, "Forest:", "input forest", us::Any(), false); parser.addArgument("out", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output fiberbundle", us::Any(), false); parser.addArgument("stop", "st", mitkCommandLineParser::String, "Stop image:", "stop image", us::Any()); parser.addArgument("mask", "m", mitkCommandLineParser::String, "Mask image:", "mask image", us::Any()); parser.addArgument("seed", "s", mitkCommandLineParser::String, "Seed image:", "seed image", us::Any()); parser.addArgument("athres", "a", mitkCommandLineParser::Float, "Angular threshold:", "angular threshold (in radians)", us::Any()); parser.addArgument("stepsize", "se", mitkCommandLineParser::Float, "Stepsize:", "stepsize", us::Any()); parser.addArgument("samples", "ns", mitkCommandLineParser::Int, "Samples:", "samples", us::Any()); parser.addArgument("samplingdist", "sd", mitkCommandLineParser::Float, "Sampling distance:", "sampling distance (in voxels)", us::Any()); parser.addArgument("seeds", "nse", mitkCommandLineParser::Int, "Seeds per voxel:", "seeds per voxel", us::Any()); parser.addArgument("usedirection", "ud", mitkCommandLineParser::Bool, "Use previous direction:", "use previous direction as feature", us::Any()); parser.addArgument("verbose", "v", mitkCommandLineParser::Bool, "Verbose:", "output additional images", us::Any()); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; string imageFile = us::any_cast(parsedArgs["image"]); string forestFile = us::any_cast(parsedArgs["forest"]); string outFile = us::any_cast(parsedArgs["out"]); string maskFile = ""; if (parsedArgs.count("mask")) maskFile = us::any_cast(parsedArgs["mask"]); string seedFile = ""; if (parsedArgs.count("seed")) seedFile = us::any_cast(parsedArgs["seed"]); string stopFile = ""; if (parsedArgs.count("stop")) stopFile = us::any_cast(parsedArgs["stop"]); float stepsize = -1; if (parsedArgs.count("stepsize")) stepsize = us::any_cast(parsedArgs["stepsize"]); float athres = 0.7; if (parsedArgs.count("athres")) athres = us::any_cast(parsedArgs["athres"]); float samplingdist = 0.25; if (parsedArgs.count("samplingdist")) samplingdist = us::any_cast(parsedArgs["samplingdist"]); bool useDirection = false; if (parsedArgs.count("usedirection")) useDirection = true; bool verbose = false; if (parsedArgs.count("verbose")) verbose = true; int samples = 10; if (parsedArgs.count("samples")) samples = us::any_cast(parsedArgs["samples"]); int seeds = 1; if (parsedArgs.count("seeds")) seeds = us::any_cast(parsedArgs["seeds"]); typedef itk::Image ItkUcharImgType; MITK_INFO << "loading diffusion-weighted image"; mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::LoadImage(imageFile).GetPointer()); ItkUcharImgType::Pointer mask; if (!maskFile.empty()) { MITK_INFO << "loading mask image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(maskFile).GetPointer()); mask = ItkUcharImgType::New(); mitk::CastToItkImage(img, mask); } ItkUcharImgType::Pointer seed; if (!seedFile.empty()) { MITK_INFO << "loading seed image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(seedFile).GetPointer()); seed = ItkUcharImgType::New(); mitk::CastToItkImage(img, seed); } ItkUcharImgType::Pointer stop; if (!stopFile.empty()) { MITK_INFO << "loading stop image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(stopFile).GetPointer()); stop = ItkUcharImgType::New(); mitk::CastToItkImage(img, stop); } MITK_INFO << "loading forest"; vigra::RandomForest rf; vigra::rf_import_HDF5(rf, forestFile); typedef itk::MLBSTrackingFilter<100> TrackerType; TrackerType::Pointer tracker = TrackerType::New(); tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi)); tracker->SetGradientDirections( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi) ); tracker->SetB_Value( mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi) ); tracker->SetMaskImage(mask); tracker->SetSeedImage(seed); tracker->SetStoppingRegions(stop); tracker->SetSeedsPerVoxel(seeds); tracker->SetUseDirection(useDirection); tracker->SetStepSize(stepsize); tracker->SetAngularThreshold(athres); tracker->SetDecisionForest(&rf); tracker->SetSamplingDistance(samplingdist); tracker->SetNumberOfSamples(samples); + //tracker->SetAvoidStop(false); + tracker->SetAposterioriCurvCheck(true); + tracker->SetRemoveWmEndFibers(true); tracker->Update(); vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); mitk::IOUtil::SaveBaseData(outFib, outFile); if (verbose) { MITK_INFO << "Writing images..."; string outName = itksys::SystemTools::GetFilenamePath(outFile)+"/"+itksys::SystemTools::GetFilenameWithoutLastExtension(outFile); itk::ImageFileWriter< TrackerType::ItkDoubleImgType >::Pointer writer = itk::ImageFileWriter< TrackerType::ItkDoubleImgType >::New(); writer->SetFileName(outName+"_WhiteMatter.nrrd"); writer->SetInput(tracker->GetWmImage()); writer->Update(); writer->SetFileName(outName+"_NotWhiteMatter.nrrd"); writer->SetInput(tracker->GetNotWmImage()); writer->Update(); writer->SetFileName(outName+"_AvoidStop.nrrd"); writer->SetInput(tracker->GetAvoidStopImage()); writer->Update(); } return EXIT_SUCCESS; } diff --git a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp index 4c2eec7c8d..d4c2bc3d40 100755 --- a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp +++ b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp @@ -1,532 +1,529 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #define _USE_MATH_DEFINES #include const int numOdfSamples = 200; // ODF is sampled in 200 directions but actuyll only 100 are used (symmetric) typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; void TrainForest( vigra::RandomForest &rf, vigra::MultiArray<2, double> &labelData, vigra::MultiArray<2, double> &featureData, int numTrees, int max_tree_depth, double sample_fraction ) { MITK_INFO << "Maximum tree depths: " << max_tree_depth; MITK_INFO << "Sample fraction per tree: " << sample_fraction; MITK_INFO << "Number of trees: " << numTrees; vigra::rf::visitors::OOB_Error oob_v; - rf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal - rf.set_options().sample_with_replacement(true); // if sampled with replacement or not - rf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree - rf.set_options().tree_count(1); // Number of trees that are calculated; - rf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node - // rf.set_options().features_per_node(10); - - rf.learn(featureData, labelData, vigra::rf::visitors::create_visitor(oob_v)); - - // Prepare parallel VariableImportance Calculation - int numMod = featureData.shape(1); - const int numClass = 2 + 2; - - float** varImp = new float*[numMod]; - - for(int i = 0; i < numMod; i++) - varImp[i] = new float[numClass]; - - for (int i = 0; i < numMod; ++i) - for (int j = 0; j < numClass; ++j) - varImp[i][j] = 0.0; - +// rf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal +// rf.set_options().sample_with_replacement(true); // if sampled with replacement or not +// rf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree +// rf.set_options().tree_count(1); // Number of trees that are calculated; +// rf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node +// rf.ext_param_.max_tree_depth = max_tree_depth; +// // rf.set_options().features_per_node(10); +// rf.learn(featureData, labelData, vigra::rf::visitors::create_visitor(oob_v)); + + std::vector< vigra::RandomForest > trees; + int count = 0; #pragma omp parallel for - for (int i = 0; i < numTrees - 1; ++i) + for (int i = 0; i < numTrees; ++i) { vigra::RandomForest lrf; vigra::rf::visitors::OOB_Error loob_v; lrf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal lrf.set_options().sample_with_replacement(true); // if sampled with replacement or not lrf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree lrf.set_options().tree_count(1); // Number of trees that are calculated; lrf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node + lrf.ext_param_.max_tree_depth = max_tree_depth; // lrf.set_options().features_per_node(10); - lrf.learn(featureData, labelData, vigra::rf::visitors::create_visitor(loob_v)); + lrf.learn(featureData, labelData);//, vigra::rf::visitors::create_visitor(loob_v)); #pragma omp critical { - rf.trees_.push_back(lrf.trees_[0]); + count++; + MITK_INFO << "Tree " << count << " finished training."; + trees.push_back(lrf); + //rf.trees_.push_back(lrf.trees_[0]); } } + for (int i = 1; i < numTrees; ++i) + trees.at(0).trees_.push_back(trees.at(i).trees_[0]); + + rf = trees.at(0); rf.options_.tree_count_ = numTrees; MITK_INFO << "Training finsihed"; - MITK_INFO << "The out-of-bag error is: " << oob_v.oob_breiman << std::endl; + //MITK_INFO << "The out-of-bag error is: " << oob_v.oob_breiman << std::endl; } SampledShImageType::PixelType GetImageValues(itk::Point itkP, SampledShImageType::Pointer image) { itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; image->TransformPhysicalPointToIndex(itkP, idx); image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); SampledShImageType::PixelType pix; pix.Fill(0.0); if ( image->GetLargestPossibleRegion().IsInside(idx) ) pix = image->GetPixel(idx); else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) { vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = image->GetPixel(idx) * interpWeights[0]; SampledShImageType::IndexType tmpIdx = idx; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[7]; } return pix; } int main(int argc, char* argv[]) { MITK_INFO << "DFTraining"; mitkCommandLineParser parser; parser.setTitle("Machine Learning Based Streamline Tractography"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription(""); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("images", "i", mitkCommandLineParser::StringList, "DWIs:", "input diffusion-weighted images", us::Any(), false); parser.addArgument("wmmasks", "w", mitkCommandLineParser::StringList, "WM-Masks:", "white matter mask images", us::Any(), false); parser.addArgument("tractograms", "t", mitkCommandLineParser::StringList, "Tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); parser.addArgument("masks", "m", mitkCommandLineParser::StringList, "Masks:", "mask images", us::Any()); parser.addArgument("forest", "f", mitkCommandLineParser::OutputFile, "Forest:", "output forest", us::Any(), false); parser.addArgument("stepsize", "s", mitkCommandLineParser::Float, "Stepsize:", "stepsize", us::Any()); parser.addArgument("gmsamples", "g", mitkCommandLineParser::Int, "Number of gray matter samples per voxel:", "Number of gray matter samples per voxel", us::Any()); parser.addArgument("numtrees", "n", mitkCommandLineParser::Int, "Number of trees:", "number of trees", us::Any()); parser.addArgument("max_tree_depth", "d", mitkCommandLineParser::Int, "Max. tree depth:", "maximum tree depth", us::Any()); parser.addArgument("sample_fraction", "sf", mitkCommandLineParser::Float, "Sample fraction:", "fraction of samples used per tree", us::Any()); parser.addArgument("usedirection", "ud", mitkCommandLineParser::Bool, "bla:", "bla", us::Any()); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; mitkCommandLineParser::StringContainerType imageFiles = us::any_cast(parsedArgs["images"]); mitkCommandLineParser::StringContainerType wmMaskFiles = us::any_cast(parsedArgs["wmmasks"]); mitkCommandLineParser::StringContainerType maskFiles; if (parsedArgs.count("masks")) maskFiles = us::any_cast(parsedArgs["masks"]); string forestFile = us::any_cast(parsedArgs["forest"]); mitkCommandLineParser::StringContainerType tractogramFiles; if (parsedArgs.count("tractograms")) tractogramFiles = us::any_cast(parsedArgs["tractograms"]); int numTrees = 30; if (parsedArgs.count("numtrees")) numTrees = us::any_cast(parsedArgs["numtrees"]); int gmsamples = 50; if (parsedArgs.count("gmsamples")) gmsamples = us::any_cast(parsedArgs["gmsamples"]); float stepsize = -1; if (parsedArgs.count("stepsize")) stepsize = us::any_cast(parsedArgs["stepsize"]); int max_tree_depth = 50; if (parsedArgs.count("max_tree_depth")) max_tree_depth = us::any_cast(parsedArgs["max_tree_depth"]); double sample_fraction = 1.0; if (parsedArgs.count("sample_fraction")) sample_fraction = us::any_cast(parsedArgs["sample_fraction"]); // load DWI images if (imageFiles.size() QballFilterType; MITK_INFO << "loading diffusion-weighted images and reconstructing feature images"; std::vector< SampledShImageType::Pointer > sampledShImages; for (unsigned int i=0; i(mitk::IOUtil::LoadImage(imageFiles.at(i)).GetPointer()); QballFilterType::Pointer qballfilter = QballFilterType::New(); qballfilter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi) ); qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); qballfilter->SetLambda(0.006); qballfilter->SetNormalizationMethod(QballFilterType::QBAR_RAW_SIGNAL); qballfilter->Update(); // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); // featureImageVector.push_back(itkFeatureImage); sampledShImages.push_back(qballfilter->GetOutput()); } typedef itk::Image ItkUcharImgType; std::vector< ItkUcharImgType::Pointer > maskImageVector; std::vector< ItkUcharImgType::Pointer > wmMaskImageVector; MITK_INFO << "loading tractograms"; int numSamples = 0; std::vector< mitk::FiberBundle::Pointer > tractograms; for (unsigned int t=0; t(mitk::IOUtil::LoadImage(maskFiles.at(t)).GetPointer()); mask = ItkUcharImgType::New(); mitk::CastToItkImage(img, mask); maskImageVector.push_back(mask); } mitk::Image::Pointer img2 = dynamic_cast(mitk::IOUtil::LoadImage(wmMaskFiles.at(t)).GetPointer()); ItkUcharImgType::Pointer wmmask = ItkUcharImgType::New(); mitk::CastToItkImage(img2, wmmask); wmMaskImageVector.push_back(wmmask); itk::ImageRegionConstIterator it(wmmask, wmmask->GetLargestPossibleRegion()); int OUTOFWM = 0; // count voxels outside of the white matter mask while(!it.IsAtEnd()) { if (it.Get()==0) if (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0)) OUTOFWM++; ++it; } numSamples += gmsamples*OUTOFWM; // for each of the non-white matter voxels we add a certain number of sampling points. these sampling points are used to tell the classifier where to recognize non-WM tissue MITK_INFO << "Samples outside of WM: " << numSamples << " (" << gmsamples << " per non-WM voxel)"; // load and resample training tractograms - mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::LoadDataNode(tractogramFiles.at(t))->GetData()); + mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(tractogramFiles.at(t)).at(0).GetPointer()); if (stepsize<0) { SampledShImageType::Pointer image = sampledShImages.at(t); float minSpacing = 1; if(image->GetSpacing()[0]GetSpacing()[1] && image->GetSpacing()[0]GetSpacing()[2]) minSpacing = image->GetSpacing()[0]; else if (image->GetSpacing()[1] < image->GetSpacing()[2]) minSpacing = image->GetSpacing()[1]; else minSpacing = image->GetSpacing()[2]; stepsize = minSpacing*0.5; } fib->ResampleSpline(stepsize); tractograms.push_back(fib); numSamples += fib->GetNumberOfPoints(); // each point of the fiber gives us a training direction numSamples -= 2*fib->GetNumFibers(); // we don't use the first and last point because there we do not have a previous direction, which is needed as feature } MITK_INFO << "Number of samples: " << numSamples; // get ODF directions and number of features vnl_vector_fixed ref; ref.fill(0); ref[0]=1; itk::OrientationDistributionFunction< double, numOdfSamples > odf; std::vector< int > directionIndices; for (unsigned int f=0; f0) // we only use directions on one hemisphere (symmetric) directionIndices.push_back(f); // remember indices that are on the desired hemisphere } const int numSignalFeatures = numOdfSamples/2; int numDirectionFeatures = 0; if (useDirection) numDirectionFeatures = 3; vigra::MultiArray<2, double> featureData( vigra::Shape2(numSamples,numSignalFeatures+numDirectionFeatures) ); MITK_INFO << "Number of features: " << featureData.shape(1); vigra::MultiArray<2, double> labelData( vigra::Shape2(numSamples,1) ); itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer m_RandGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); m_RandGen->SetSeed(); MITK_INFO << "Creating training data from tractograms and feature images"; int sampleCounter = 0; for (unsigned int t=0; t it(wmMask, wmMask->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0))) { SampledShImageType::PixelType pix = image->GetPixel(it.GetIndex()); if (useDirection) { // null direction for (unsigned int f=0; f probe; probe[0] = m_RandGen->GetVariate()*2-1; probe[1] = m_RandGen->GetVariate()*2-1; probe[2] = m_RandGen->GetVariate()*2-1; probe.normalize(); if (dot_product(ref, probe)<0) probe *= -1; for (unsigned int f=numSignalFeatures; f idx; idx[0] = it.GetIndex()[0]; idx[1] = it.GetIndex()[1]; idx[2] = it.GetIndex()[2]; itk::Point itkP1; image->TransformContinuousIndexToPhysicalPoint(idx, itkP1); SampledShImageType::PixelType pix = GetImageValues(itkP1, image);; for (unsigned int f=0; f idx; idx[0] = it.GetIndex()[0] + m_RandGen->GetVariate()-0.5; idx[1] = it.GetIndex()[1] + m_RandGen->GetVariate()-0.5; idx[2] = it.GetIndex()[2] + m_RandGen->GetVariate()-0.5; itk::Point itkP1; image->TransformContinuousIndexToPhysicalPoint(idx, itkP1); SampledShImageType::PixelType pix = GetImageValues(itkP1, image);; for (unsigned int f=0; f polyData = fib->GetFiberPolyData(); for (int i=0; iGetNumFibers(); i++) { vtkCell* cell = polyData->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); vnl_vector_fixed dirOld; dirOld.fill(0.0); for (int j=0; jGetPoint(j); itk::Point itkP1; itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2]; vnl_vector_fixed dir; dir.fill(0.0); itk::Point itkP2; double* p2 = points->GetPoint(j+1); itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2]; dir[0]=itkP2[0]-itkP1[0]; dir[1]=itkP2[1]-itkP1[1]; dir[2]=itkP2[2]-itkP1[2]; if (dir.magnitude()<0.0001) { MITK_INFO << "streamline error!"; continue; } dir.normalize(); if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2]) { MITK_INFO << "ERROR: NaN direction!"; continue; } if (j==0) { dirOld = dir; continue; } // get voxel values SampledShImageType::PixelType pix = GetImageValues(itkP1, image); for (unsigned int f=0; f0.0001) { int label = 0; for (unsigned int f=0; fangle) { labelData(sampleCounter,0) = f; angle = a; label = f; } } } dirOld = dir; sampleCounter++; } } } MITK_INFO << "Training forest"; vigra::RandomForest rf; TrainForest( rf, labelData, featureData, numTrees, max_tree_depth, sample_fraction ); MITK_INFO << "Writing forest"; vigra::rf_export_HDF5( rf, forestFile, "" ); MITK_INFO << "Finished training"; return EXIT_SUCCESS; }