diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp index e7b01fcb32..6e66836263 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp @@ -1,812 +1,816 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_txx #define __itkMLBSTrackingFilter_txx #include #include #include #include "itkMLBSTrackingFilter.h" #include #include #include #include //#include #define _USE_MATH_DEFINES #include namespace itk { template< int NumImageFeatures > MLBSTrackingFilter< NumImageFeatures > ::MLBSTrackingFilter() : m_FiberPolyData(NULL) , m_Points(NULL) , m_Cells(NULL) , m_AngularThreshold(0.7) , m_StepSize(0) , m_MaxLength(10000) , m_MinTractLength(20.0) , m_MaxTractLength(400.0) , m_SeedsPerVoxel(1) , m_NumberOfSamples(50) , m_SamplingDistance(-1) , m_SeedImage(NULL) , m_MaskImage(NULL) , m_DecisionForest(NULL) , m_StoppingRegions(NULL) , m_DemoMode(false) , m_PauseTracking(false) , m_AbortTracking(false) , m_RemoveWmEndFibers(false) , m_AposterioriCurvCheck(false) , m_AvoidStop(true) , m_RandomSampling(false) + , m_Verbose(false) { this->SetNumberOfRequiredInputs(1); } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures > ::RoundToNearest(double num) { return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::BeforeThreadedGenerateData() { m_InputImage = const_cast(this->GetInput(0)); PreprocessRawData(); m_FiberPolyData = PolyDataType::New(); m_Points = vtkSmartPointer< vtkPoints >::New(); m_Cells = vtkSmartPointer< vtkCellArray >::New(); m_ImageSize.resize(3); m_ImageSize[0] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[0]; m_ImageSize[1] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[1]; m_ImageSize[2] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[2]; m_ImageSpacing.resize(3); m_ImageSpacing[0] = m_FeatureImage->GetSpacing()[0]; m_ImageSpacing[1] = m_FeatureImage->GetSpacing()[1]; m_ImageSpacing[2] = m_FeatureImage->GetSpacing()[2]; double minSpacing; if(m_ImageSpacing[0]GetNumberOfThreads(); i++) { PolyDataType poly = PolyDataType::New(); m_PolyDataContainer.push_back(poly); } m_NotWmImage = ItkDoubleImgType::New(); m_NotWmImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_NotWmImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_NotWmImage->SetDirection( m_FeatureImage->GetDirection() ); m_NotWmImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_NotWmImage->Allocate(); m_NotWmImage->FillBuffer(0); m_WmImage = ItkDoubleImgType::New(); m_WmImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_WmImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_WmImage->SetDirection( m_FeatureImage->GetDirection() ); m_WmImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_WmImage->Allocate(); m_WmImage->FillBuffer(0); m_AvoidStopImage = ItkDoubleImgType::New(); m_AvoidStopImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_AvoidStopImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_AvoidStopImage->SetDirection( m_FeatureImage->GetDirection() ); m_AvoidStopImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_AvoidStopImage->Allocate(); m_AvoidStopImage->FillBuffer(0); if (m_StoppingRegions.IsNull()) { m_StoppingRegions = ItkUcharImgType::New(); m_StoppingRegions->SetSpacing( m_FeatureImage->GetSpacing() ); m_StoppingRegions->SetOrigin( m_FeatureImage->GetOrigin() ); m_StoppingRegions->SetDirection( m_FeatureImage->GetDirection() ); m_StoppingRegions->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_StoppingRegions->Allocate(); m_StoppingRegions->FillBuffer(0); } if (m_SeedImage.IsNull()) { m_SeedImage = ItkUcharImgType::New(); m_SeedImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_SeedImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_SeedImage->SetDirection( m_FeatureImage->GetDirection() ); m_SeedImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_SeedImage->Allocate(); m_SeedImage->FillBuffer(1); } if (m_MaskImage.IsNull()) { // initialize mask image m_MaskImage = ItkUcharImgType::New(); m_MaskImage->SetSpacing( m_FeatureImage->GetSpacing() ); m_MaskImage->SetOrigin( m_FeatureImage->GetOrigin() ); m_MaskImage->SetDirection( m_FeatureImage->GetDirection() ); m_MaskImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); m_MaskImage->Allocate(); m_MaskImage->FillBuffer(1); } else std::cout << "MLBSTrackingFilter: using mask image" << std::endl; if (m_AngularThreshold<0.0) m_AngularThreshold = 0.5*minSpacing; m_BuildFibersReady = 0; m_BuildFibersFinished = false; m_Threads = 0; m_Tractogram.clear(); m_SamplingPointset = mitk::PointSet::New(); m_AlternativePointset = mitk::PointSet::New(); std::cout << "MLBSTrackingFilter: Angular threshold: " << m_AngularThreshold << std::endl; std::cout << "MLBSTrackingFilter: Stepsize: " << m_StepSize << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Seeds per voxel: " << m_SeedsPerVoxel << std::endl; std::cout << "MLBSTrackingFilter: Max. sampling distance: " << m_SamplingDistance << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Number of samples: " << m_NumberOfSamples << std::endl; std::cout << "MLBSTrackingFilter: Max. tract length: " << m_MaxTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Min. tract length: " << m_MinTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Starting streamline tracking using " << this->GetNumberOfThreads() << " threads." << std::endl; } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::PreprocessRawData() { typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; std::cout << "MLBSTrackingFilter: Spherical signal interpolation and sampling ..." << std::endl; typename InterpolationFilterType::Pointer filter = InterpolationFilterType::New(); filter->SetGradientImage( m_GradientDirections, m_InputImage ); filter->SetBValue( m_B_Value ); filter->SetLambda(0.006); filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); filter->Update(); // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); // featureImageVector.push_back(itkFeatureImage); std::cout << "MLBSTrackingFilter: Creating feature image ..." << std::endl; vnl_vector_fixed ref; ref.fill(0); ref[0]=1; itk::OrientationDistributionFunction< double, NumImageFeatures*2 > odf; m_DirectionIndices.clear(); for (unsigned int f=0; f0) // only used directions on one hemisphere m_DirectionIndices.push_back(f); } m_FeatureImage = FeatureImageType::New(); m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection()); m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->Allocate(); itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { typename FeatureImageType::PixelType pix; for (unsigned int f=0; fSetPixel(it.GetIndex(), pix); ++it; } } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir) { // vnl_matrix_fixed< double, 3, 3 > rot = m_FeatureImage->GetDirection().GetTranspose(); // dir = rot*dir; dir *= m_StepSize; pos[0] += dir[0]; pos[1] += dir[1]; pos[2] += dir[2]; } template< int NumImageFeatures > bool MLBSTrackingFilter< NumImageFeatures > ::IsValidPosition(itk::Point &pos) { typename FeatureImageType::IndexType idx; m_FeatureImage->TransformPhysicalPointToIndex(pos, idx); if (!m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) || m_MaskImage->GetPixel(idx)==0) return false; return true; } template< int NumImageFeatures > typename MLBSTrackingFilter< NumImageFeatures >::FeatureImageType::PixelType MLBSTrackingFilter< NumImageFeatures >::GetImageValues(itk::Point itkP) { itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx); m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx); typename FeatureImageType::PixelType pix; pix.Fill(0.0); if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) ) pix = m_FeatureImage->GetPixel(idx); else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1) { vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = m_FeatureImage->GetPixel(idx) * interpWeights[0]; typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7]; } return pix; } template< int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob, bool avoidStop) { vnl_vector_fixed direction; direction.fill(0); vigra::MultiArray<2, double> featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumImageFeatures+3) ); typename FeatureImageType::PixelType featurePixel = GetImageValues(pos); // pixel values for (unsigned int f=0; f ref; ref.fill(0); ref[0]=1; for (unsigned int f=NumImageFeatures; f probs(vigra::Shape2(1, m_DecisionForest->class_count())); m_DecisionForest->predictProbabilities(featureData, probs); double outProb = 0; prob = 0; candidates = 0; // directions with probability > 0 for (int i=0; iclass_count(); i++) { if (probs(0,i)>0) { int classLabel = 0; m_DecisionForest->ext_param_.to_classlabel(i, classLabel); if (classLabel d = m_ODF.GetDirection(m_DirectionIndices.at(classLabel)); double dot = dot_product(d, olddir); if (olddir.magnitude()>0) { if (fabs(dot)>angularThreshold) { if (dot<0) d *= -1; dot = fabs(dot); direction += probs(0,i)*dot*d; prob += probs(0,i)*dot; } } else { direction += probs(0,i)*d; prob += probs(0,i); } } else outProb += probs(0,i); } } ItkDoubleImgType::IndexType idx; - m_NotWmImage->TransformPhysicalPointToIndex(pos, idx); - if (m_NotWmImage->GetLargestPossibleRegion().IsInside(idx)) + if (m_Verbose) { - m_NotWmImage->SetPixel(idx, m_NotWmImage->GetPixel(idx)+outProb); - m_WmImage->SetPixel(idx, m_WmImage->GetPixel(idx)+prob); + m_NotWmImage->TransformPhysicalPointToIndex(pos, idx); + if (m_NotWmImage->GetLargestPossibleRegion().IsInside(idx)) + { + m_NotWmImage->SetPixel(idx, m_NotWmImage->GetPixel(idx)+outProb); + m_WmImage->SetPixel(idx, m_WmImage->GetPixel(idx)+prob); + } } if (outProb>prob && prob>0) { candidates = 0; prob = 0; direction.fill(0.0); } - if (avoidStop && m_AvoidStopImage->GetLargestPossibleRegion().IsInside(idx) && candidates>0 && direction.magnitude()>0.001) + if (m_Verbose && avoidStop && m_AvoidStopImage->GetLargestPossibleRegion().IsInside(idx) && candidates>0 && direction.magnitude()>0.001) m_AvoidStopImage->SetPixel(idx, m_AvoidStopImage->GetPixel(idx)+0.1); return direction; } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures >::GetRandDouble(double min, double max) { return (double)(rand()%((int)(10000*(max-min))) + 10000*min)/10000; } template< int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::GetNewDirection(itk::Point &pos, vnl_vector_fixed& olddir) { if (m_DemoMode) { m_SamplingPointset->Clear(); m_AlternativePointset->Clear(); } vnl_vector_fixed direction; direction.fill(0); ItkUcharImgType::IndexType idx; m_StoppingRegions->TransformPhysicalPointToIndex(pos, idx); if (m_StoppingRegions->GetPixel(idx)>0) return direction; if (olddir.magnitude()>0) olddir.normalize(); int candidates = 0; // number of directions with probability > 0 double prob = 0; direction = Classify(pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood direction *= prob; itk::OrientationDistributionFunction< double, 50 > probeVecs; itk::Point sample_pos; int alternatives = 1; for (int i=0; i d; if (m_RandomSampling) { d[0] = GetRandDouble(); d[1] = GetRandDouble(); d[2] = GetRandDouble(); d.normalize(); d *= GetRandDouble(0,m_SamplingDistance); } else { d = probeVecs.GetDirection(i)*m_SamplingDistance; } sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_SamplingPointset->InsertPoint(i, sample_pos); candidates = 0; vnl_vector_fixed tempDir = Classify(sample_pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) { direction += tempDir*prob; } else if (m_AvoidStop && candidates==0 && olddir.magnitude()>0) // out of white matter { double dot = dot_product(d, olddir); if (dot >= 0.0) // in front of plane defined by pos and olddir d = -d + 2*dot*olddir; // reflect else d = -d; // invert // look a bit further into the other direction sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_AlternativePointset->InsertPoint(alternatives, sample_pos); alternatives++; candidates = 0; vnl_vector_fixed tempDir = Classify(sample_pos, candidates, olddir, m_AngularThreshold, prob, true); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? { direction += d; // go into the direction of the white matter direction += tempDir*prob; // go into the direction of the white matter direction at this location } } } if (direction.magnitude()>0.001) { direction.normalize(); olddir[0] = direction[0]; olddir[1] = direction[1]; olddir[2] = direction[2]; } else direction.fill(0); return direction; } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures >::FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front) { vnl_vector_fixed dirOld = dir; dirOld = dir; for (int step=0; step< m_MaxLength/2; step++) { // get new position CalculateNewPosition(pos, dir); // is new position inside of image and mask if (!IsValidPosition(pos) || m_AbortTracking) // if not end streamline { return tractLength; } else // if yes, add new point to streamline { tractLength += m_StepSize; if (front) fib->push_front(pos); else fib->push_back(pos); if (m_AposterioriCurvCheck) { int curv = CheckCurvature(fib, front); // TODO: Move into classification ??? if (curv>0) { tractLength -= m_StepSize*curv; while (curv>0) { if (front) fib->pop_front(); else fib->pop_back(); curv--; } return tractLength; } } if (tractLength>m_MaxTractLength) return tractLength; } if (m_DemoMode) // CHECK: warum sind die samplingpunkte der streamline in der visualisierung immer einen schritt voras? { m_Mutex.Lock(); m_BuildFibersReady++; m_Tractogram.push_back(*fib); BuildFibers(true); m_Stop = true; m_Mutex.Unlock(); while (m_Stop){ } } dir = GetNewDirection(pos, dirOld); while (m_PauseTracking){} if (dir.magnitude()<0.0001) return tractLength; } return tractLength; } template< int NumImageFeatures > int MLBSTrackingFilter::CheckCurvature(FiberType* fib, bool front) { double m_Distance = 5; if (fib->size()<3) return 0; double dist = 0; std::vector< vnl_vector_fixed< float, 3 > > vectors; vnl_vector_fixed< float, 3 > meanV; meanV.fill(0); double dev = 0; if (front) { int c=0; while(distsize()-1) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c+1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==0) meanV += v; c++; } } else { int c=fib->size()-1; while(dist0) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c-1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==fib->size()-1) meanV += v; c--; } } meanV.normalize(); for (int c=0; c1.0) angle = 1.0; if (angle<-1.0) angle = -1.0; dev += acos(angle)*180/M_PI; } if (vectors.size()>0) dev /= vectors.size(); if (dev<30) return 0; else return vectors.size(); } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::ThreadedGenerateData(const InputImageRegionType ®ionForThread, ThreadIdType threadId) { m_Mutex.Lock(); m_Threads++; m_Mutex.Unlock(); typedef ImageRegionConstIterator< ItkUcharImgType > MaskIteratorType; MaskIteratorType sit(m_SeedImage, regionForThread ); MaskIteratorType mit(m_MaskImage, regionForThread ); sit.GoToBegin(); mit.GoToBegin(); itk::Point worldPos; while( !sit.IsAtEnd() ) { if (sit.Value()==0 || mit.Value()==0) { ++sit; ++mit; continue; } for (int s=0; s start; unsigned int counter = 0; if (m_SeedsPerVoxel>1) { start[0] = index[0]+GetRandDouble(-0.5, 0.5); start[1] = index[1]+GetRandDouble(-0.5, 0.5); start[2] = index[2]+GetRandDouble(-0.5, 0.5); } else { start[0] = index[0]; start[1] = index[1]; start[2] = index[2]; } // get staring position m_SeedImage->TransformContinuousIndexToPhysicalPoint( start, worldPos ); // get starting direction int candidates = 0; double prob = 0; vnl_vector_fixed dirOld; dirOld.fill(0.0); vnl_vector_fixed dir = Classify(worldPos, candidates, dirOld, 0, prob); if (dir.magnitude()<0.0001) continue; // forward tracking tractLength = FollowStreamline(threadId, worldPos, dir, &fib, 0, false); fib.push_front(worldPos); if (m_RemoveWmEndFibers) { itk::Point check = fib.back(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } // backward tracking tractLength = FollowStreamline(threadId, worldPos, -dir, &fib, tractLength, true); counter = fib.size(); if (m_RemoveWmEndFibers) { itk::Point check = fib.front(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } if (tractLength void MLBSTrackingFilter< NumImageFeatures >::BuildFibers(bool check) { if (m_BuildFibersReady::New(); vtkSmartPointer vNewLines = vtkSmartPointer::New(); vtkSmartPointer vNewPoints = vtkSmartPointer::New(); for (int i=0; i container = vtkSmartPointer::New(); FiberType fib = m_Tractogram.at(i); for (FiberType::iterator it = fib.begin(); it!=fib.end(); ++it) { vtkIdType id = vNewPoints->InsertNextPoint((*it).GetDataPointer()); container->GetPointIds()->InsertNextId(id); } vNewLines->InsertNextCell(container); } if (check) for (int i=0; iSetPoints(vNewPoints); m_FiberPolyData->SetLines(vNewLines); m_BuildFibersFinished = true; } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::AfterThreadedGenerateData() { MITK_INFO << "Generating polydata "; BuildFibers(false); MITK_INFO << "done"; } } #endif // __itkDiffusionQballPrincipleDirectionsImageFilter_txx diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h index 14f612463e..eac0777107 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h @@ -1,194 +1,195 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ /*=================================================================== This file is based heavily on a corresponding ITK filter. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_h_ #define __itkMLBSTrackingFilter_h_ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include // classification includes #include #include #include namespace itk{ /** * \brief Performes deterministic streamline tracking on the input tensor image. */ template< int NumImageFeatures=100 > class MLBSTrackingFilter : public ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > { public: typedef MLBSTrackingFilter Self; typedef SmartPointer Pointer; typedef SmartPointer ConstPointer; typedef ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > Superclass; typedef vigra::RandomForest DecisionForestType; typedef typename Superclass::InputImageType InputImageType; typedef typename Superclass::InputImageRegionType InputImageRegionType; typedef Image< Vector< float, NumImageFeatures > , 3 > FeatureImageType; /** Method for creation through the object factory. */ itkFactorylessNewMacro(Self) itkCloneMacro(Self) /** Runtime information support. */ itkTypeMacro(MLBSTrackingFilter, ImageToImageFilter) typedef itk::Image ItkUcharImgType; typedef itk::Image ItkDoubleImgType; typedef itk::Image ItkFloatImgType; typedef vtkSmartPointer< vtkPolyData > PolyDataType; typedef std::deque< itk::Point > FiberType; typedef std::vector< FiberType > BundleType; volatile bool m_PauseTracking; bool m_AbortTracking; bool m_BuildFibersFinished; int m_BuildFibersReady; volatile bool m_Stop; mitk::PointSet::Pointer m_SamplingPointset; mitk::PointSet::Pointer m_AlternativePointset; // void RequestFibers(){ m_Stop=true; m_BuildFibersReady=0; m_BuildFibersFinished=false; } itkGetMacro( FiberPolyData, PolyDataType ) ///< Output fibers itkSetMacro( SeedImage, ItkUcharImgType::Pointer) ///< Seeds are only placed inside of this mask. itkSetMacro( MaskImage, ItkUcharImgType::Pointer) ///< Tracking is only performed inside of this mask image. itkSetMacro( SeedsPerVoxel, int) ///< One seed placed in the center of each voxel or multiple seeds randomly placed inside each voxel. itkSetMacro( StepSize, double) ///< Integration step size in mm itkSetMacro( MinTractLength, double ) ///< Shorter tracts are discarded. itkSetMacro( MaxTractLength, double ) itkSetMacro( AngularThreshold, double ) itkSetMacro( SamplingDistance, double ) itkSetMacro( NumberOfSamples, int ) itkSetMacro( StoppingRegions, ItkUcharImgType::Pointer) itkSetMacro( B_Value, float ) itkSetMacro( GradientDirections, mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer ) itkSetMacro( DemoMode, bool ) itkSetMacro( RemoveWmEndFibers, bool ) itkSetMacro( AposterioriCurvCheck, bool ) itkSetMacro( AvoidStop, bool ) itkSetMacro( RandomSampling, bool ) + itkSetMacro( Verbose, bool ) void SetDecisionForest( DecisionForestType* forest ) { m_DecisionForest = forest; } itkGetMacro( WmImage, ItkDoubleImgType::Pointer ) itkGetMacro( NotWmImage, ItkDoubleImgType::Pointer ) itkGetMacro( AvoidStopImage, ItkDoubleImgType::Pointer ) protected: MLBSTrackingFilter(); ~MLBSTrackingFilter() {} void CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir); ///< Calculate next integration step. double FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front); ///< Start streamline in one direction. bool IsValidPosition(itk::Point& pos); ///< Are we outside of the mask image? vnl_vector_fixed GetNewDirection(itk::Point& pos, vnl_vector_fixed& olddir); vnl_vector_fixed Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob, bool avoidStop=false); typename FeatureImageType::PixelType GetImageValues(itk::Point itkP); double GetRandDouble(double min=-1, double max=1); double RoundToNearest(double num); void BeforeThreadedGenerateData(); void PreprocessRawData(); void ThreadedGenerateData( const InputImageRegionType &outputRegionForThread, ThreadIdType threadId); void AfterThreadedGenerateData(); PolyDataType m_FiberPolyData; vtkSmartPointer m_Points; vtkSmartPointer m_Cells; BundleType m_Tractogram; double m_AngularThreshold; double m_StepSize; int m_MaxLength; double m_MinTractLength; double m_MaxTractLength; int m_SeedsPerVoxel; bool m_RandomSampling; double m_SamplingDistance; int m_NumberOfSamples; std::vector< int > m_ImageSize; std::vector< double > m_ImageSpacing; SimpleFastMutexLock m_Mutex; ItkUcharImgType::Pointer m_StoppingRegions; ItkDoubleImgType::Pointer m_WmImage; ItkDoubleImgType::Pointer m_NotWmImage; ItkDoubleImgType::Pointer m_AvoidStopImage; ItkUcharImgType::Pointer m_SeedImage; ItkUcharImgType::Pointer m_MaskImage; typename FeatureImageType::Pointer m_FeatureImage; typename InputImageType::Pointer m_InputImage; mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer m_GradientDirections; float m_B_Value; bool m_AposterioriCurvCheck; bool m_RemoveWmEndFibers; bool m_AvoidStop; - + bool m_Verbose; int m_Threads; bool m_DemoMode; void BuildFibers(bool check); int CheckCurvature(FiberType* fib, bool front); // decision forest DecisionForestType* m_DecisionForest; itk::OrientationDistributionFunction< double, NumImageFeatures*2 > m_ODF; std::vector< int > m_DirectionIndices; std::vector< PolyDataType > m_PolyDataContainer; private: }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkMLBSTrackingFilter.cpp" #endif #endif //__itkMLBSTrackingFilter_h_ diff --git a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp index cedda2c683..093bed8d30 100755 --- a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp +++ b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp @@ -1,193 +1,194 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include //#include #include #include #include #include #include #include //#include #include #include #include #include #define _USE_MATH_DEFINES #include const int numOdfSamples = 200; typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Machine Learning Based Streamline Tractography"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription(""); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("image", "i", mitkCommandLineParser::String, "DWIs:", "input diffusion-weighted image", us::Any(), false); parser.addArgument("forest", "f", mitkCommandLineParser::String, "Forest:", "input forest", us::Any(), false); parser.addArgument("out", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output fiberbundle", us::Any(), false); parser.addArgument("stop", "st", mitkCommandLineParser::String, "Stop image:", "stop image", us::Any()); parser.addArgument("mask", "m", mitkCommandLineParser::String, "Mask image:", "mask image", us::Any()); parser.addArgument("seed", "s", mitkCommandLineParser::String, "Seed image:", "seed image", us::Any()); parser.addArgument("athres", "a", mitkCommandLineParser::Float, "Angular threshold:", "angular threshold (in radians)", us::Any()); parser.addArgument("stepsize", "se", mitkCommandLineParser::Float, "Stepsize:", "stepsize", us::Any()); parser.addArgument("samples", "ns", mitkCommandLineParser::Int, "Samples:", "samples", us::Any()); parser.addArgument("samplingdist", "sd", mitkCommandLineParser::Float, "Sampling distance:", "sampling distance (in voxels)", us::Any()); parser.addArgument("seeds", "nse", mitkCommandLineParser::Int, "Seeds per voxel:", "seeds per voxel", us::Any()); parser.addArgument("verbose", "v", mitkCommandLineParser::Bool, "Verbose:", "output additional images", us::Any()); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; string imageFile = us::any_cast(parsedArgs["image"]); string forestFile = us::any_cast(parsedArgs["forest"]); string outFile = us::any_cast(parsedArgs["out"]); string maskFile = ""; if (parsedArgs.count("mask")) maskFile = us::any_cast(parsedArgs["mask"]); string seedFile = ""; if (parsedArgs.count("seed")) seedFile = us::any_cast(parsedArgs["seed"]); string stopFile = ""; if (parsedArgs.count("stop")) stopFile = us::any_cast(parsedArgs["stop"]); float stepsize = -1; if (parsedArgs.count("stepsize")) stepsize = us::any_cast(parsedArgs["stepsize"]); float athres = 0.7; if (parsedArgs.count("athres")) athres = us::any_cast(parsedArgs["athres"]); float samplingdist = 0.25; if (parsedArgs.count("samplingdist")) samplingdist = us::any_cast(parsedArgs["samplingdist"]); bool verbose = false; if (parsedArgs.count("verbose")) verbose = true; int samples = 10; if (parsedArgs.count("samples")) samples = us::any_cast(parsedArgs["samples"]); int seeds = 1; if (parsedArgs.count("seeds")) seeds = us::any_cast(parsedArgs["seeds"]); typedef itk::Image ItkUcharImgType; MITK_INFO << "loading diffusion-weighted image"; mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::LoadImage(imageFile).GetPointer()); ItkUcharImgType::Pointer mask; if (!maskFile.empty()) { MITK_INFO << "loading mask image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(maskFile).GetPointer()); mask = ItkUcharImgType::New(); mitk::CastToItkImage(img, mask); } ItkUcharImgType::Pointer seed; if (!seedFile.empty()) { MITK_INFO << "loading seed image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(seedFile).GetPointer()); seed = ItkUcharImgType::New(); mitk::CastToItkImage(img, seed); } ItkUcharImgType::Pointer stop; if (!stopFile.empty()) { MITK_INFO << "loading stop image"; mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(stopFile).GetPointer()); stop = ItkUcharImgType::New(); mitk::CastToItkImage(img, stop); } MITK_INFO << "loading forest"; vigra::RandomForest rf; vigra::rf_import_HDF5(rf, forestFile); typedef itk::MLBSTrackingFilter<100> TrackerType; TrackerType::Pointer tracker = TrackerType::New(); tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi)); tracker->SetGradientDirections( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi) ); tracker->SetB_Value( mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi) ); tracker->SetMaskImage(mask); tracker->SetSeedImage(seed); tracker->SetStoppingRegions(stop); tracker->SetSeedsPerVoxel(seeds); tracker->SetStepSize(stepsize); tracker->SetAngularThreshold(athres); tracker->SetDecisionForest(&rf); tracker->SetSamplingDistance(samplingdist); tracker->SetNumberOfSamples(samples); //tracker->SetAvoidStop(false); + tracker->SetVerbose(verbose); tracker->SetAposterioriCurvCheck(false); tracker->SetRemoveWmEndFibers(false); tracker->Update(); vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); mitk::IOUtil::SaveBaseData(outFib, outFile); if (verbose) { MITK_INFO << "Writing images..."; string outName = itksys::SystemTools::GetFilenamePath(outFile)+"/"+itksys::SystemTools::GetFilenameWithoutLastExtension(outFile); itk::ImageFileWriter< TrackerType::ItkDoubleImgType >::Pointer writer = itk::ImageFileWriter< TrackerType::ItkDoubleImgType >::New(); writer->SetFileName(outName+"_WhiteMatter.nrrd"); writer->SetInput(tracker->GetWmImage()); writer->Update(); writer->SetFileName(outName+"_NotWhiteMatter.nrrd"); writer->SetInput(tracker->GetNotWmImage()); writer->Update(); writer->SetFileName(outName+"_AvoidStop.nrrd"); writer->SetInput(tracker->GetAvoidStopImage()); writer->Update(); } return EXIT_SUCCESS; } diff --git a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp index 03d9f9107e..a5b2e34103 100755 --- a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp +++ b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp @@ -1,483 +1,146 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include #include #include #include #include -#include -#include -#include -#include -#include - -#include -#include -#include +#include #define _USE_MATH_DEFINES #include -const int numOdfSamples = 200; // ODF is sampled in 200 directions but actuyll only 100 are used (symmetric) -typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; - -void TrainForest( vigra::RandomForest &rf, vigra::MultiArray<2, double> &labelData, vigra::MultiArray<2, double> &featureData, int numTrees, int max_tree_depth, double sample_fraction ) -{ - MITK_INFO << "Maximum tree depths: " << max_tree_depth; - MITK_INFO << "Sample fraction per tree: " << sample_fraction; - MITK_INFO << "Number of trees: " << numTrees; - vigra::rf::visitors::OOB_Error oob_v; - -// rf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal -// rf.set_options().sample_with_replacement(true); // if sampled with replacement or not -// rf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree -// rf.set_options().tree_count(1); // Number of trees that are calculated; -// rf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node -// rf.ext_param_.max_tree_depth = max_tree_depth; -// // rf.set_options().features_per_node(10); -// rf.learn(featureData, labelData, vigra::rf::visitors::create_visitor(oob_v)); - - std::vector< vigra::RandomForest > trees; - int count = 0; -#pragma omp parallel for - for (int i = 0; i < numTrees; ++i) - { - vigra::RandomForest lrf; - vigra::rf::visitors::OOB_Error loob_v; - - lrf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal - lrf.set_options().sample_with_replacement(true); // if sampled with replacement or not - lrf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree - lrf.set_options().tree_count(1); // Number of trees that are calculated; - lrf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node - lrf.ext_param_.max_tree_depth = max_tree_depth; - // lrf.set_options().features_per_node(10); - - lrf.learn(featureData, labelData);//, vigra::rf::visitors::create_visitor(loob_v)); -#pragma omp critical - { - count++; - MITK_INFO << "Tree " << count << " finished training."; - trees.push_back(lrf); - //rf.trees_.push_back(lrf.trees_[0]); - } - } - - for (int i = 1; i < numTrees; ++i) - trees.at(0).trees_.push_back(trees.at(i).trees_[0]); - - rf = trees.at(0); - rf.options_.tree_count_ = numTrees; - MITK_INFO << "Training finsihed"; - //MITK_INFO << "The out-of-bag error is: " << oob_v.oob_breiman << std::endl; -} - -SampledShImageType::PixelType GetImageValues(itk::Point itkP, SampledShImageType::Pointer image) -{ - itk::Index<3> idx; - itk::ContinuousIndex< double, 3> cIdx; - image->TransformPhysicalPointToIndex(itkP, idx); - image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); - - SampledShImageType::PixelType pix; pix.Fill(0.0); - if ( image->GetLargestPossibleRegion().IsInside(idx) ) - pix = image->GetPixel(idx); - else - return pix; - - double frac_x = cIdx[0] - idx[0]; - double frac_y = cIdx[1] - idx[1]; - double frac_z = cIdx[2] - idx[2]; - if (frac_x<0) - { - idx[0] -= 1; - frac_x += 1; - } - if (frac_y<0) - { - idx[1] -= 1; - frac_y += 1; - } - if (frac_z<0) - { - idx[2] -= 1; - frac_z += 1; - } - frac_x = 1-frac_x; - frac_y = 1-frac_y; - frac_z = 1-frac_z; - - // int coordinates inside image? - if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && - idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && - idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) - { - vnl_vector_fixed interpWeights; - interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); - interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); - interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); - interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); - interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); - interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); - interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); - interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); - - - pix = image->GetPixel(idx) * interpWeights[0]; - SampledShImageType::IndexType tmpIdx = idx; tmpIdx[0]++; - pix += image->GetPixel(tmpIdx) * interpWeights[1]; - tmpIdx = idx; tmpIdx[1]++; - pix += image->GetPixel(tmpIdx) * interpWeights[2]; - tmpIdx = idx; tmpIdx[2]++; - pix += image->GetPixel(tmpIdx) * interpWeights[3]; - tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; - pix += image->GetPixel(tmpIdx) * interpWeights[4]; - tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; - pix += image->GetPixel(tmpIdx) * interpWeights[5]; - tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; - pix += image->GetPixel(tmpIdx) * interpWeights[6]; - tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; - pix += image->GetPixel(tmpIdx) * interpWeights[7]; - } - - return pix; -} - int main(int argc, char* argv[]) { MITK_INFO << "DFTraining"; mitkCommandLineParser parser; parser.setTitle("Machine Learning Based Streamline Tractography"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription(""); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("images", "i", mitkCommandLineParser::StringList, "DWIs:", "input diffusion-weighted images", us::Any(), false); parser.addArgument("wmmasks", "w", mitkCommandLineParser::StringList, "WM-Masks:", "white matter mask images", us::Any(), false); parser.addArgument("tractograms", "t", mitkCommandLineParser::StringList, "Tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); parser.addArgument("masks", "m", mitkCommandLineParser::StringList, "Masks:", "mask images", us::Any()); parser.addArgument("forest", "f", mitkCommandLineParser::OutputFile, "Forest:", "output forest", us::Any(), false); parser.addArgument("stepsize", "s", mitkCommandLineParser::Float, "Stepsize:", "stepsize", us::Any()); parser.addArgument("gmsamples", "g", mitkCommandLineParser::Int, "Number of gray matter samples per voxel:", "Number of gray matter samples per voxel", us::Any()); parser.addArgument("numtrees", "n", mitkCommandLineParser::Int, "Number of trees:", "number of trees", us::Any()); parser.addArgument("max_tree_depth", "d", mitkCommandLineParser::Int, "Max. tree depth:", "maximum tree depth", us::Any()); parser.addArgument("sample_fraction", "sf", mitkCommandLineParser::Float, "Sample fraction:", "fraction of samples used per tree", us::Any()); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; mitkCommandLineParser::StringContainerType imageFiles = us::any_cast(parsedArgs["images"]); mitkCommandLineParser::StringContainerType wmMaskFiles = us::any_cast(parsedArgs["wmmasks"]); mitkCommandLineParser::StringContainerType maskFiles; if (parsedArgs.count("masks")) maskFiles = us::any_cast(parsedArgs["masks"]); string forestFile = us::any_cast(parsedArgs["forest"]); mitkCommandLineParser::StringContainerType tractogramFiles; if (parsedArgs.count("tractograms")) tractogramFiles = us::any_cast(parsedArgs["tractograms"]); int numTrees = 30; if (parsedArgs.count("numtrees")) numTrees = us::any_cast(parsedArgs["numtrees"]); int gmsamples = 50; if (parsedArgs.count("gmsamples")) gmsamples = us::any_cast(parsedArgs["gmsamples"]); float stepsize = -1; if (parsedArgs.count("stepsize")) stepsize = us::any_cast(parsedArgs["stepsize"]); int max_tree_depth = 50; if (parsedArgs.count("max_tree_depth")) max_tree_depth = us::any_cast(parsedArgs["max_tree_depth"]); double sample_fraction = 1.0; if (parsedArgs.count("sample_fraction")) sample_fraction = us::any_cast(parsedArgs["sample_fraction"]); - // load DWI images - if (imageFiles.size() QballFilterType; - MITK_INFO << "loading diffusion-weighted images and reconstructing feature images"; - std::vector< SampledShImageType::Pointer > sampledShImages; + MITK_INFO << "loading diffusion-weighted images"; + std::vector< mitk::Image::Pointer > rawData; for (unsigned int i=0; i(mitk::IOUtil::LoadImage(imageFiles.at(i)).GetPointer()); - - QballFilterType::Pointer qballfilter = QballFilterType::New(); - qballfilter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi) ); - qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - qballfilter->SetLambda(0.006); - qballfilter->SetNormalizationMethod(QballFilterType::QBAR_RAW_SIGNAL); - qballfilter->Update(); - // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); - // featureImageVector.push_back(itkFeatureImage); - sampledShImages.push_back(qballfilter->GetOutput()); + rawData.push_back(dwi); } - typedef itk::Image ItkUcharImgType; + MITK_INFO << "loading mask images"; std::vector< ItkUcharImgType::Pointer > maskImageVector; + for (unsigned int i=0; i(mitk::IOUtil::LoadImage(maskFiles.at(i)).GetPointer()); + ItkUcharImgType::Pointer mask = ItkUcharImgType::New(); + mitk::CastToItkImage(img, mask); + maskImageVector.push_back(mask); + } + + MITK_INFO << "loading white matter mask images"; std::vector< ItkUcharImgType::Pointer > wmMaskImageVector; + for (unsigned int i=0; i(mitk::IOUtil::LoadImage(wmMaskFiles.at(i)).GetPointer()); + ItkUcharImgType::Pointer wmmask = ItkUcharImgType::New(); + mitk::CastToItkImage(img, wmmask); + wmMaskImageVector.push_back(wmmask); + } MITK_INFO << "loading tractograms"; - int numSamples = 0; std::vector< mitk::FiberBundle::Pointer > tractograms; for (unsigned int t=0; t(mitk::IOUtil::LoadImage(maskFiles.at(t)).GetPointer()); - mask = ItkUcharImgType::New(); - mitk::CastToItkImage(img, mask); - maskImageVector.push_back(mask); - } - mitk::Image::Pointer img2 = dynamic_cast(mitk::IOUtil::LoadImage(wmMaskFiles.at(t)).GetPointer()); - ItkUcharImgType::Pointer wmmask = ItkUcharImgType::New(); - mitk::CastToItkImage(img2, wmmask); - wmMaskImageVector.push_back(wmmask); - - itk::ImageRegionConstIterator it(wmmask, wmmask->GetLargestPossibleRegion()); - int OUTOFWM = 0; // count voxels outside of the white matter mask - while(!it.IsAtEnd()) - { - if (it.Get()==0) - if (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0)) - OUTOFWM++; - ++it; - } - numSamples += gmsamples*OUTOFWM; // for each of the non-white matter voxels we add a certain number of sampling points. these sampling points are used to tell the classifier where to recognize non-WM tissue - - MITK_INFO << "Samples outside of WM: " << numSamples << " (" << gmsamples << " per non-WM voxel)"; - - // load and resample training tractograms mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(tractogramFiles.at(t)).at(0).GetPointer()); - if (stepsize<0) - { - SampledShImageType::Pointer image = sampledShImages.at(t); - float minSpacing = 1; - if(image->GetSpacing()[0]GetSpacing()[1] && image->GetSpacing()[0]GetSpacing()[2]) - minSpacing = image->GetSpacing()[0]; - else if (image->GetSpacing()[1] < image->GetSpacing()[2]) - minSpacing = image->GetSpacing()[1]; - else - minSpacing = image->GetSpacing()[2]; - stepsize = minSpacing*0.5; - } - fib->ResampleSpline(stepsize); tractograms.push_back(fib); - numSamples += fib->GetNumberOfPoints(); // each point of the fiber gives us a training direction - numSamples -= 2*fib->GetNumFibers(); // we don't use the first and last point because there we do not have a previous direction, which is needed as feature - } - MITK_INFO << "Number of samples: " << numSamples; - - // get ODF directions and number of features - vnl_vector_fixed ref; ref.fill(0); ref[0]=1; - itk::OrientationDistributionFunction< double, numOdfSamples > odf; - std::vector< int > directionIndices; - for (unsigned int f=0; f0) // we only use directions on one hemisphere (symmetric) - directionIndices.push_back(f); // remember indices that are on the desired hemisphere - } - const int numSignalFeatures = numOdfSamples/2; - int numDirectionFeatures = 3; - - vigra::MultiArray<2, double> featureData( vigra::Shape2(numSamples,numSignalFeatures+numDirectionFeatures) ); - MITK_INFO << "Number of features: " << featureData.shape(1); - vigra::MultiArray<2, double> labelData( vigra::Shape2(numSamples,1) ); - - itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer m_RandGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); - m_RandGen->SetSeed(); - MITK_INFO << "Creating training data from tractograms and feature images"; - int sampleCounter = 0; - for (unsigned int t=0; t it(wmMask, wmMask->GetLargestPossibleRegion()); - while(!it.IsAtEnd()) - { - if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0))) - { - SampledShImageType::PixelType pix = image->GetPixel(it.GetIndex()); - - // null direction - for (unsigned int f=0; f probe; - probe[0] = m_RandGen->GetVariate()*2-1; - probe[1] = m_RandGen->GetVariate()*2-1; - probe[2] = m_RandGen->GetVariate()*2-1; - probe.normalize(); - if (dot_product(ref, probe)<0) - probe *= -1; - for (unsigned int f=numSignalFeatures; f polyData = fib->GetFiberPolyData(); - for (int i=0; iGetNumFibers(); i++) - { - vtkCell* cell = polyData->GetCell(i); - int numPoints = cell->GetNumberOfPoints(); - vtkPoints* points = cell->GetPoints(); - - vnl_vector_fixed dirOld; dirOld.fill(0.0); - - for (int j=0; jGetPoint(j); - itk::Point itkP1; - itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2]; - - vnl_vector_fixed dir; dir.fill(0.0); - - itk::Point itkP2; - double* p2 = points->GetPoint(j+1); - itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2]; - dir[0]=itkP2[0]-itkP1[0]; - dir[1]=itkP2[1]-itkP1[1]; - dir[2]=itkP2[2]-itkP1[2]; - - if (dir.magnitude()<0.0001) - { - MITK_INFO << "streamline error!"; - continue; - } - dir.normalize(); - if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2]) - { - MITK_INFO << "ERROR: NaN direction!"; - continue; - } - - if (j==0) - { - dirOld = dir; - continue; - } - - // get voxel values - SampledShImageType::PixelType pix = GetImageValues(itkP1, image); - for (unsigned int f=0; f0.0001) - { - int label = 0; - for (unsigned int f=0; fangle) - { - labelData(sampleCounter,0) = f; - angle = a; - label = f; - } - } - } - - dirOld = dir; - sampleCounter++; - } - } } - MITK_INFO << "Training forest"; - vigra::RandomForest rf; - TrainForest( rf, labelData, featureData, numTrees, max_tree_depth, sample_fraction ); - MITK_INFO << "Writing forest"; - vigra::rf_export_HDF5( rf, forestFile, "" ); - MITK_INFO << "Finished training"; + mitk::TrackingForestHandler<> forestHandler; + forestHandler.SetRawData(rawData); + forestHandler.SetTractograms(tractograms); + forestHandler.SetNumTrees(numTrees); + forestHandler.SetMaxTreeDepth(max_tree_depth); + forestHandler.SetGrayMatterSamplesPerVoxel(gmsamples); + forestHandler.SetSampleFraction(sample_fraction); + forestHandler.SetStepSize(stepsize); + forestHandler.StartTraining(); + forestHandler.SaveForest(forestFile); return EXIT_SUCCESS; } diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui index a641f8762c..bf1799b4e8 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui @@ -1,640 +1,623 @@ QmitkMLBTViewControls 0 0 359 1127 Form 0 0 0 0 1 0 0 - 290 - 376 + 359 + 1065 Training Load Forest Qt::Vertical 20 40 QFrame::NoFrame QFrame::Raised 0 0 0 0 - - - - Step size: - - - - - - - Num. Trees: - - - - - + + - Max. Depth: + GM Sampling Points: - - + + 3 - - -1.000000000000000 - - 999.000000000000000 + 1.000000000000000 0.100000000000000 - -1.000000000000000 + 1.000000000000000 1 999999999 - - + + - Sample Fraction: + Step size: 1 999999999 10 + + + + Num. Trees: + + + + + + + Max. Depth: + + + + + + + Sample Fraction: + + + 1 999999999 10 - - + + 3 + + -1.000000000000000 + - 1.000000000000000 + 999.000000000000000 0.100000000000000 - 1.000000000000000 - - - - - - - GM Sampling Points: - - - - - - - Use Previous Direction: - - - - - - - - - - true + -1.000000000000000 Start Training Save Forest 0 0 359 1065 Tractography Avoid premature termination true Qt::Vertical 20 40 Secondary curvature check true QFrame::NoFrame QFrame::Raised 0 0 0 0 Demo Mode 1 1000 10 QFrame::NoFrame QFrame::Raised 0 0 0 0 Use Stop Image: Use Seed Image: false Use Mask Image: - true + false QFrame::NoFrame QFrame::Raised 0 0 0 0 0.500000000000000 Max. Length 999999999 50 Num. Samples: Input DWI: Step Size: Sampling Distance: 1 999 Min. Length 0.100000000000000 0.500000000000000 999999999.000000000000000 1.000000000000000 20.000000000000000 Angular Threshold: Num. Seeds: 999999999.000000000000000 1.000000000000000 400.000000000000000 1 90.000000000000000 45.000000000000000 Num. Threads: 1 30 QFrame::NoFrame QFrame::Raised 0 0 0 0 ... :/org_mitk_icons/icons/tango/scalable/actions/media-playback-pause.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-pause.svg ... :/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg ... :/org_mitk_icons/icons/tango/scalable/actions/media-playback-stop.svg:/org_mitk_icons/icons/tango/scalable/actions/media-playback-stop.svg Random sampling false QmitkDataStorageComboBox QComboBox
QmitkDataStorageComboBox.h