diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp index 9369e58ea1..867930cd12 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp @@ -1,591 +1,591 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef __itkMLBSTrackingFilter_txx #define __itkMLBSTrackingFilter_txx #include #include #include #include "itkMLBSTrackingFilter.h" #include #include #include #include #define _USE_MATH_DEFINES #include namespace itk { template< int NumImageFeatures > MLBSTrackingFilter< NumImageFeatures > ::MLBSTrackingFilter() : m_PauseTracking(false) , m_AbortTracking(false) , m_FiberPolyData(NULL) , m_Points(NULL) , m_Cells(NULL) , m_AngularThreshold(0.7) , m_StepSize(0) , m_MaxLength(10000) , m_MinTractLength(20.0) , m_MaxTractLength(400.0) , m_SeedsPerVoxel(1) , m_RandomSampling(false) , m_SamplingDistance(-1) , m_NumberOfSamples(50) , m_StoppingRegions(NULL) , m_SeedImage(NULL) , m_MaskImage(NULL) , m_AposterioriCurvCheck(false) , m_RemoveWmEndFibers(false) , m_AvoidStop(true) , m_DemoMode(false) { this->SetNumberOfRequiredInputs(1); } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures > ::RoundToNearest(double num) { return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::BeforeThreadedGenerateData() { m_InputImage = const_cast(this->GetInput(0)); m_ForestHandler.InitForTracking(); m_FiberPolyData = PolyDataType::New(); m_Points = vtkSmartPointer< vtkPoints >::New(); m_Cells = vtkSmartPointer< vtkCellArray >::New(); std::vector< double > m_ImageSpacing; m_ImageSpacing.resize(3); m_ImageSpacing[0] = m_InputImage->GetSpacing()[0]; m_ImageSpacing[1] = m_InputImage->GetSpacing()[1]; m_ImageSpacing[2] = m_InputImage->GetSpacing()[2]; double minSpacing; if(m_ImageSpacing[0]GetNumberOfThreads(); i++) { PolyDataType poly = PolyDataType::New(); m_PolyDataContainer.push_back(poly); } if (m_StoppingRegions.IsNull()) { m_StoppingRegions = ItkUcharImgType::New(); m_StoppingRegions->SetSpacing( m_InputImage->GetSpacing() ); m_StoppingRegions->SetOrigin( m_InputImage->GetOrigin() ); m_StoppingRegions->SetDirection( m_InputImage->GetDirection() ); m_StoppingRegions->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_StoppingRegions->Allocate(); m_StoppingRegions->FillBuffer(0); } if (m_SeedImage.IsNull()) { m_SeedImage = ItkUcharImgType::New(); m_SeedImage->SetSpacing( m_InputImage->GetSpacing() ); m_SeedImage->SetOrigin( m_InputImage->GetOrigin() ); m_SeedImage->SetDirection( m_InputImage->GetDirection() ); m_SeedImage->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_SeedImage->Allocate(); m_SeedImage->FillBuffer(1); } if (m_MaskImage.IsNull()) { // initialize mask image m_MaskImage = ItkUcharImgType::New(); m_MaskImage->SetSpacing( m_InputImage->GetSpacing() ); m_MaskImage->SetOrigin( m_InputImage->GetOrigin() ); m_MaskImage->SetDirection( m_InputImage->GetDirection() ); m_MaskImage->SetRegions( m_InputImage->GetLargestPossibleRegion() ); m_MaskImage->Allocate(); m_MaskImage->FillBuffer(1); } else std::cout << "MLBSTrackingFilter: using mask image" << std::endl; if (m_AngularThreshold<0.0) m_AngularThreshold = 0.5*minSpacing; m_BuildFibersReady = 0; m_BuildFibersFinished = false; m_Threads = 0; m_Tractogram.clear(); m_SamplingPointset = mitk::PointSet::New(); m_AlternativePointset = mitk::PointSet::New(); m_StartTime = std::chrono::system_clock::now(); std::cout << "MLBSTrackingFilter: Angular threshold: " << m_AngularThreshold << std::endl; std::cout << "MLBSTrackingFilter: Stepsize: " << m_StepSize << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Seeds per voxel: " << m_SeedsPerVoxel << std::endl; std::cout << "MLBSTrackingFilter: Max. sampling distance: " << m_SamplingDistance << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Number of samples: " << m_NumberOfSamples << std::endl; std::cout << "MLBSTrackingFilter: Max. tract length: " << m_MaxTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Min. tract length: " << m_MinTractLength << " mm" << std::endl; std::cout << "MLBSTrackingFilter: Starting streamline tracking using " << this->GetNumberOfThreads() << " threads." << std::endl; } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir) { // vnl_matrix_fixed< double, 3, 3 > rot = m_FeatureImage->GetDirection().GetTranspose(); // dir = rot*dir; dir *= m_StepSize; pos[0] += dir[0]; pos[1] += dir[1]; pos[2] += dir[2]; } template< int NumImageFeatures > bool MLBSTrackingFilter< NumImageFeatures > ::IsValidPosition(itk::Point &pos) { typename FeatureImageType::IndexType idx; m_InputImage->TransformPhysicalPointToIndex(pos, idx); if (!m_InputImage->GetLargestPossibleRegion().IsInside(idx) || m_MaskImage->GetPixel(idx)==0) return false; return true; } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures >::GetRandDouble(double min, double max) { return (double)(rand()%((int)(10000*(max-min))) + 10000*min)/10000; } template< int NumImageFeatures > vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::GetNewDirection(itk::Point &pos, vnl_vector_fixed& olddir) { if (m_DemoMode) { m_SamplingPointset->Clear(); m_AlternativePointset->Clear(); } vnl_vector_fixed direction; direction.fill(0); ItkUcharImgType::IndexType idx; m_StoppingRegions->TransformPhysicalPointToIndex(pos, idx); if (m_StoppingRegions->GetLargestPossibleRegion().IsInside(idx) && m_StoppingRegions->GetPixel(idx)>0) return direction; if (olddir.magnitude()>0) olddir.normalize(); int candidates = 0; // number of directions with probability > 0 - double prob = 0; + double w = 0; // weight of the direction predicted at each sampling point if (IsValidPosition(pos)) { - direction = m_ForestHandler.Classify(pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood - direction *= prob; + direction = m_ForestHandler.Classify(pos, candidates, olddir, m_AngularThreshold, w); // sample neighborhood + direction *= w; // HERE WE ARE WEIGHTING AGAIN!!! THE EFFECT OF THIS HAS YET TO BE EVALUATED QUANTITATIVELY. } itk::OrientationDistributionFunction< double, 50 > probeVecs; itk::Point sample_pos; int alternatives = 1; for (int i=0; i d; if (m_RandomSampling) { d[0] = GetRandDouble(); d[1] = GetRandDouble(); d[2] = GetRandDouble(); d.normalize(); d *= GetRandDouble(0,m_SamplingDistance); } else { d = probeVecs.GetDirection(i)*m_SamplingDistance; } sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_SamplingPointset->InsertPoint(i, sample_pos); candidates = 0; vnl_vector_fixed tempDir; tempDir.fill(0.0); if (IsValidPosition(sample_pos)) - tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood + tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, w); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) { - direction += tempDir*prob; + direction += tempDir*w; } else if (m_AvoidStop && candidates==0 && olddir.magnitude()>0) // out of white matter { double dot = dot_product(d, olddir); if (dot >= 0.0) // in front of plane defined by pos and olddir d = -d + 2*dot*olddir; // reflect else d = -d; // invert // look a bit further into the other direction sample_pos[0] = pos[0] + d[0]; sample_pos[1] = pos[1] + d[1]; sample_pos[2] = pos[2] + d[2]; if(m_DemoMode) m_AlternativePointset->InsertPoint(alternatives, sample_pos); alternatives++; candidates = 0; vnl_vector_fixed tempDir; tempDir.fill(0.0); if (IsValidPosition(sample_pos)) - tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood + tempDir = m_ForestHandler.Classify(sample_pos, candidates, olddir, m_AngularThreshold, w); // sample neighborhood if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? { direction += d; // go into the direction of the white matter - direction += tempDir*prob; // go into the direction of the white matter direction at this location + direction += tempDir*w; // go into the direction of the white matter direction at this location } } } if (direction.magnitude()>0.001) { direction.normalize(); olddir[0] = direction[0]; olddir[1] = direction[1]; olddir[2] = direction[2]; } else direction.fill(0); return direction; } template< int NumImageFeatures > double MLBSTrackingFilter< NumImageFeatures >::FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front) { vnl_vector_fixed dirOld = dir; dirOld = dir; for (int step=0; step< m_MaxLength/2; step++) { // get new position CalculateNewPosition(pos, dir); // is new position inside of image and mask if (m_AbortTracking) // if not end streamline { return tractLength; } else // if yes, add new point to streamline { tractLength += m_StepSize; if (front) fib->push_front(pos); else fib->push_back(pos); if (m_AposterioriCurvCheck) { int curv = CheckCurvature(fib, front); // TODO: Move into classification ??? if (curv>0) { tractLength -= m_StepSize*curv; while (curv>0) { if (front) fib->pop_front(); else fib->pop_back(); curv--; } return tractLength; } } if (tractLength>m_MaxTractLength) return tractLength; } if (m_DemoMode) // CHECK: warum sind die samplingpunkte der streamline in der visualisierung immer einen schritt voras? { m_Mutex.Lock(); m_BuildFibersReady++; m_Tractogram.push_back(*fib); BuildFibers(true); m_Stop = true; m_Mutex.Unlock(); while (m_Stop){ } } dir = GetNewDirection(pos, dirOld); while (m_PauseTracking){} if (dir.magnitude()<0.0001) return tractLength; } return tractLength; } template< int NumImageFeatures > int MLBSTrackingFilter::CheckCurvature(FiberType* fib, bool front) { double m_Distance = 5; if (fib->size()<3) return 0; double dist = 0; std::vector< vnl_vector_fixed< float, 3 > > vectors; vnl_vector_fixed< float, 3 > meanV; meanV.fill(0); double dev = 0; if (front) { int c=0; while(distsize()-1) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c+1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==0) meanV += v; c++; } } else { int c=fib->size()-1; while(dist0) { itk::Point p1 = fib->at(c); itk::Point p2 = fib->at(c-1); vnl_vector_fixed< float, 3 > v; v[0] = p2[0]-p1[0]; v[1] = p2[1]-p1[1]; v[2] = p2[2]-p1[2]; dist += v.magnitude(); v.normalize(); vectors.push_back(v); if (c==fib->size()-1) meanV += v; c--; } } meanV.normalize(); for (int c=0; c1.0) angle = 1.0; if (angle<-1.0) angle = -1.0; dev += acos(angle)*180/M_PI; } if (vectors.size()>0) dev /= vectors.size(); if (dev<30) return 0; else return vectors.size(); } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::ThreadedGenerateData(const InputImageRegionType ®ionForThread, ThreadIdType threadId) { m_Mutex.Lock(); m_Threads++; m_Mutex.Unlock(); typedef ImageRegionConstIterator< ItkUcharImgType > MaskIteratorType; MaskIteratorType sit(m_SeedImage, regionForThread ); MaskIteratorType mit(m_MaskImage, regionForThread ); sit.GoToBegin(); mit.GoToBegin(); itk::Point worldPos; while( !sit.IsAtEnd() ) { if (sit.Value()==0 || mit.Value()==0) { ++sit; ++mit; continue; } for (int s=0; s start; unsigned int counter = 0; if (m_SeedsPerVoxel>1) { start[0] = index[0]+GetRandDouble(-0.5, 0.5); start[1] = index[1]+GetRandDouble(-0.5, 0.5); start[2] = index[2]+GetRandDouble(-0.5, 0.5); } else { start[0] = index[0]; start[1] = index[1]; start[2] = index[2]; } // get staring position m_SeedImage->TransformContinuousIndexToPhysicalPoint( start, worldPos ); // get starting direction int candidates = 0; double prob = 0; vnl_vector_fixed dirOld; dirOld.fill(0.0); vnl_vector_fixed dir; dir.fill(0.0); if (IsValidPosition(worldPos)) dir = m_ForestHandler.Classify(worldPos, candidates, dirOld, 0, prob); if (dir.magnitude()<0.0001) continue; // forward tracking tractLength = FollowStreamline(threadId, worldPos, dir, &fib, 0, false); fib.push_front(worldPos); if (m_RemoveWmEndFibers) { itk::Point check = fib.back(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } // backward tracking tractLength = FollowStreamline(threadId, worldPos, -dir, &fib, tractLength, true); counter = fib.size(); if (m_RemoveWmEndFibers) { itk::Point check = fib.front(); dirOld.fill(0.0); vnl_vector_fixed check2 = GetNewDirection(check, dirOld); if (check2.magnitude()>0.001) { MITK_INFO << "Detected WM ending. Discarding fiber."; continue; } } if (tractLength void MLBSTrackingFilter< NumImageFeatures >::BuildFibers(bool check) { if (m_BuildFibersReady::New(); vtkSmartPointer vNewLines = vtkSmartPointer::New(); vtkSmartPointer vNewPoints = vtkSmartPointer::New(); for (int i=0; i container = vtkSmartPointer::New(); FiberType fib = m_Tractogram.at(i); for (FiberType::iterator it = fib.begin(); it!=fib.end(); ++it) { vtkIdType id = vNewPoints->InsertNextPoint((*it).GetDataPointer()); container->GetPointIds()->InsertNextId(id); } vNewLines->InsertNextCell(container); } if (check) for (int i=0; iSetPoints(vNewPoints); m_FiberPolyData->SetLines(vNewLines); m_BuildFibersFinished = true; } template< int NumImageFeatures > void MLBSTrackingFilter< NumImageFeatures >::AfterThreadedGenerateData() { MITK_INFO << "Generating polydata "; BuildFibers(false); MITK_INFO << "done"; m_EndTime = std::chrono::system_clock::now(); std::chrono::hours hh = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::minutes mm = std::chrono::duration_cast((m_EndTime - m_StartTime) % std::chrono::hours(1)); MITK_INFO << "Tracking took " << hh.count() << "h and " << mm.count() << "m"; } } #endif // __itkDiffusionQballPrincipleDirectionsImageFilter_txx diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp index 56f0f524ab..1707d6ffe7 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp @@ -1,697 +1,699 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _TrackingForestHandler_cpp #define _TrackingForestHandler_cpp #include "mitkTrackingForestHandler.h" #include #include namespace mitk { template< int ShOrder, int NumberOfSignalFeatures > TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::TrackingForestHandler() : m_WmSampleDistance(-1) , m_NumTrees(30) , m_MaxTreeDepth(50) , m_SampleFraction(1.0) , m_NumberOfSamples(0) , m_GmSamplesPerVoxel(50) { } template< int ShOrder, int NumberOfSignalFeatures > TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::~TrackingForestHandler() { } template< int ShOrder, int NumberOfSignalFeatures > typename TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InterpolatedRawImageType::PixelType TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::GetImageValues(itk::Point itkP, typename InterpolatedRawImageType::Pointer image) { + // transform physical point to index coordinates itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; image->TransformPhysicalPointToIndex(itkP, idx); image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); typename InterpolatedRawImageType::PixelType pix; pix.Fill(0.0); if ( image->GetLargestPossibleRegion().IsInside(idx) ) pix = image->GetPixel(idx); else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) { + // trilinear interpolation vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = image->GetPixel(idx) * interpWeights[0]; typename InterpolatedRawImageType::IndexType tmpIdx = idx; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += image->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += image->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += image->GetPixel(tmpIdx) * interpWeights[7]; } return pix; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InputDataValidForTracking() { if (m_RawData.empty()) mitkThrow() << "No diffusion-weighted images set!"; if (!IsForestValid()) mitkThrow() << "No or invalid random forest detected!"; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InitForTracking() { InputDataValidForTracking(); - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; - MITK_INFO << "Spherically interpolating raw data and creating feature image ..."; + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; typename InterpolationFilterType::Pointer filter = InterpolationFilterType::New(); filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(0)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(0)) ); filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(0))); filter->SetLambda(0.006); filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); filter->Update(); vnl_vector_fixed ref; ref.fill(0); ref[0]=1; itk::OrientationDistributionFunction< double, NumberOfSignalFeatures*2 > odf; m_DirectionIndices.clear(); for (unsigned int f=0; f0) // only used directions on one hemisphere - m_DirectionIndices.push_back(f); + m_DirectionIndices.push_back(f); // store indices for later mapping the classifier output to the actual direction } m_FeatureImage = FeatureImageType::New(); m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection()); m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); m_FeatureImage->Allocate(); + // get signal values and store them in the feature image itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { typename FeatureImageType::PixelType pix; for (unsigned int f=0; fSetPixel(it.GetIndex(), pix); ++it; } } template< int ShOrder, int NumberOfSignalFeatures > typename TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::FeatureImageType::PixelType TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::GetFeatureValues(itk::Point itkP) { + // transform physical point to index coordinates itk::Index<3> idx; itk::ContinuousIndex< double, 3> cIdx; m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx); m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx); typename FeatureImageType::PixelType pix; pix.Fill(0.0); if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) ) pix = m_FeatureImage->GetPixel(idx); else return pix; double frac_x = cIdx[0] - idx[0]; double frac_y = cIdx[1] - idx[1]; double frac_z = cIdx[2] - idx[2]; if (frac_x<0) { idx[0] -= 1; frac_x += 1; } if (frac_y<0) { idx[1] -= 1; frac_y += 1; } if (frac_z<0) { idx[2] -= 1; frac_z += 1; } frac_x = 1-frac_x; frac_y = 1-frac_y; frac_z = 1-frac_z; // int coordinates inside image? if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 && idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 && idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1) { + // trilinear interpolation vnl_vector_fixed interpWeights; interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); pix = m_FeatureImage->GetPixel(idx) * interpWeights[0]; typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1]; tmpIdx = idx; tmpIdx[1]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2]; tmpIdx = idx; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4]; tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5]; tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6]; tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7]; } return pix; } template< int ShOrder, int NumberOfSignalFeatures > -vnl_vector_fixed TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob) +vnl_vector_fixed TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& w) { vnl_vector_fixed direction; direction.fill(0); + // store feature pixel values in a vigra data type vigra::MultiArray<2, double> featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumberOfSignalFeatures+3) ); - typename FeatureImageType::PixelType featurePixel = GetFeatureValues(pos); - - // pixel values for (unsigned int f=0; f ref; ref.fill(0); ref[0]=1; + for (unsigned int f=NumberOfSignalFeatures; f ref; ref.fill(0); ref[0]=1; - for (unsigned int f=NumberOfSignalFeatures; f probs(vigra::Shape2(1, m_Forest->class_count())); m_Forest->predictProbabilities(featureData, probs); - double outProb = 0; - prob = 0; - candidates = 0; // directions with probability > 0 - for (int i=0; iclass_count(); i++) + double pNonFib = 0; // probability that we left the white matter + w = 0; // weight of the predicted direction + candidates = 0; // directions with probability > 0 + for (int i=0; iclass_count(); i++) // for each class (number of possible directions + out-of-wm class) { - if (probs(0,i)>0) + if (probs(0,i)>0) // if probability of respective class is 0, do nothing { + // get label of class (does not correspond to the loop variable i) int classLabel = 0; m_Forest->ext_param_.to_classlabel(i, classLabel); - if (classLabel d = m_DirContainer.GetDirection(m_DirectionIndices.at(classLabel)); - double dot = dot_product(d, olddir); + candidates++; // now we have one direction more with probability > 0 (DO WE NEED THIS???) + vnl_vector_fixed d = m_DirContainer.GetDirection(m_DirectionIndices.at(classLabel)); // get direction vector assiciated with the respective direction index - if (olddir.magnitude()>0) + if (olddir.magnitude()>0) // do we have a previous streamline direction or did we just start? { - if (fabs(dot)>angularThreshold) + double dot = dot_product(d, olddir); // claculate angle between the candidate direction vector and the previous streamline direction + if (fabs(dot)>angularThreshold) // is angle between the directions smaller than our hard threshold? { - if (dot<0) + if (dot<0) // make sure we don't walk backwards d *= -1; - dot = fabs(dot); - direction += probs(0,i)*dot*d; - prob += probs(0,i)*dot; + double w_i = probs(0,i)*fabs(dot); + direction += w_i*d; // weight contribution to output direction with its probability and the angular deviation from the previous direction + w += w_i; // increase output weight of the final direction } } else { direction += probs(0,i)*d; - prob += probs(0,i); + w += probs(0,i); } } else - outProb += probs(0,i); + pNonFib += probs(0,i); // probability that we are not in the whte matter anymore } } - if (outProb>prob && prob>0) + // if we did not find a suitable direction, make sure that we return (0,0,0) + if (pNonFib>w && w>0) { candidates = 0; - prob = 0; + w = 0; direction.fill(0.0); } return direction; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::StartTraining() { m_StartTime = std::chrono::system_clock::now(); InputDataValidForTraining(); PreprocessInputDataForTraining(); CalculateFeaturesForTraining(); TrainForest(); m_EndTime = std::chrono::system_clock::now(); std::chrono::hours hh = std::chrono::duration_cast(m_EndTime - m_StartTime); std::chrono::minutes mm = std::chrono::duration_cast((m_EndTime - m_StartTime) % std::chrono::hours(1)); MITK_INFO << "Training took " << hh.count() << "h and " << mm.count() << "m"; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::InputDataValidForTraining() { if (m_RawData.empty()) mitkThrow() << "No diffusion-weighted images set!"; if (m_Tractograms.empty()) mitkThrow() << "No tractograms set!"; if (m_RawData.size()!=m_Tractograms.size()) mitkThrow() << "Unequal number of diffusion-weighted images and tractograms detected!"; } template< int ShOrder, int NumberOfSignalFeatures > bool TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::IsForestValid() { if(m_Forest && m_Forest->tree_count()>0 && m_Forest->feature_count()==(NumberOfSignalFeatures+3)) return true; return false; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::PreprocessInputDataForTraining() { typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; MITK_INFO << "Spherical signal interpolation and sampling ..."; for (unsigned int i=0; iSetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(i)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(i)) ); qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(i))); qballfilter->SetLambda(0.006); qballfilter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); qballfilter->Update(); // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); m_InterpolatedRawImages.push_back(qballfilter->GetOutput()); if (i>=m_MaskImages.size()) { ItkUcharImgType::Pointer newMask = ItkUcharImgType::New(); newMask->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); newMask->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); newMask->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); newMask->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); newMask->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); newMask->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); newMask->Allocate(); newMask->FillBuffer(1); m_MaskImages.push_back(newMask); } if (m_MaskImages.at(i)==nullptr) { m_MaskImages.at(i) = ItkUcharImgType::New(); m_MaskImages.at(i)->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); m_MaskImages.at(i)->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); m_MaskImages.at(i)->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); m_MaskImages.at(i)->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); m_MaskImages.at(i)->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); m_MaskImages.at(i)->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); m_MaskImages.at(i)->Allocate(); m_MaskImages.at(i)->FillBuffer(1); } } MITK_INFO << "Resampling fibers and calculating number of samples ..."; m_NumberOfSamples = 0; for (unsigned int t=0; t::Pointer env = itk::TractDensityImageFilter< ItkUcharImgType >::New(); env->SetFiberBundle(m_Tractograms.at(t)); env->SetInputImage(mask); env->SetBinaryOutput(true); env->SetUseImageGeometry(true); env->Update(); wmmask = env->GetOutput(); if (t>=m_WhiteMatterImages.size()) m_WhiteMatterImages.push_back(wmmask); else m_WhiteMatterImages.at(t) = wmmask; } // Calculate white-matter samples if (m_WmSampleDistance<0) { typename InterpolatedRawImageType::Pointer image = m_InterpolatedRawImages.at(t); float minSpacing = 1; if(image->GetSpacing()[0]GetSpacing()[1] && image->GetSpacing()[0]GetSpacing()[2]) minSpacing = image->GetSpacing()[0]; else if (image->GetSpacing()[1] < image->GetSpacing()[2]) minSpacing = image->GetSpacing()[1]; else minSpacing = image->GetSpacing()[2]; m_WmSampleDistance = minSpacing*0.5; } m_Tractograms.at(t)->ResampleSpline(m_WmSampleDistance); unsigned int wmSamples = m_Tractograms.at(t)->GetNumberOfPoints()-2*m_Tractograms.at(t)->GetNumFibers(); m_NumberOfSamples += wmSamples; MITK_INFO << "Samples inside of WM: " << wmSamples; // calculate gray-matter samples itk::ImageRegionConstIterator it(wmmask, wmmask->GetLargestPossibleRegion()); int OUTOFWM = 0; while(!it.IsAtEnd()) { if (it.Get()==0 && mask->GetPixel(it.GetIndex())>0) OUTOFWM++; ++it; } MITK_INFO << "Non-white matter voxels: " << OUTOFWM; if (m_GmSamplesPerVoxel>0) { m_GmSamples.push_back(m_GmSamplesPerVoxel); m_NumberOfSamples += m_GmSamplesPerVoxel*OUTOFWM; } else if (OUTOFWM>0) { m_GmSamples.push_back(0.5+(double)wmSamples/(double)OUTOFWM); m_NumberOfSamples += m_GmSamples.back()*OUTOFWM; MITK_INFO << "Non-white matter samples per voxel: " << m_GmSamples.back(); } else { m_GmSamples.push_back(0); } MITK_INFO << "Samples outside of WM: " << m_GmSamples.back()*OUTOFWM; } MITK_INFO << "Number of samples: " << m_NumberOfSamples; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::CalculateFeaturesForTraining() { vnl_vector_fixed ref; ref.fill(0); ref[0]=1; itk::OrientationDistributionFunction< double, 2*NumberOfSignalFeatures > directions; std::vector< int > directionIndices; for (unsigned int f=0; f<2*NumberOfSignalFeatures; f++) { if (dot_product(ref, directions.GetDirection(f))>0) directionIndices.push_back(f); } int numDirectionFeatures = 3; m_FeatureData.reshape( vigra::Shape2(m_NumberOfSamples, NumberOfSignalFeatures+numDirectionFeatures) ); m_LabelData.reshape( vigra::Shape2(m_NumberOfSamples,1) ); MITK_INFO << "Number of features: " << m_FeatureData.shape(1); itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer m_RandGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); m_RandGen->SetSeed(); MITK_INFO << "Creating training data ..."; int sampleCounter = 0; for (unsigned int t=0; t it(wmMask, wmMask->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0))) { typename InterpolatedRawImageType::PixelType pix = image->GetPixel(it.GetIndex()); // null direction for (unsigned int f=0; f probe; probe[0] = m_RandGen->GetVariate()*2-1; probe[1] = m_RandGen->GetVariate()*2-1; probe[2] = m_RandGen->GetVariate()*2-1; probe.normalize(); if (dot_product(ref, probe)<0) probe *= -1; for (unsigned int f=NumberOfSignalFeatures; f polyData = fib->GetFiberPolyData(); for (int i=0; iGetNumFibers(); i++) { vtkCell* cell = polyData->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); vnl_vector_fixed dirOld; dirOld.fill(0.0); for (int j=0; jGetPoint(j); itk::Point itkP1; itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2]; vnl_vector_fixed dir; dir.fill(0.0); itk::Point itkP2; double* p2 = points->GetPoint(j+1); itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2]; dir[0]=itkP2[0]-itkP1[0]; dir[1]=itkP2[1]-itkP1[1]; dir[2]=itkP2[2]-itkP1[2]; if (dir.magnitude()<0.0001) { MITK_INFO << "streamline error!"; continue; } dir.normalize(); if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2]) { MITK_INFO << "ERROR: NaN direction!"; continue; } if (j==0) { dirOld = dir; continue; } // get voxel values typename InterpolatedRawImageType::PixelType pix = GetImageValues(itkP1, image); for (unsigned int f=0; f0.0001) { for (unsigned int f=0; fangle) { m_LabelData(sampleCounter,0) = f; angle = a; } } } dirOld = dir; sampleCounter++; } } } } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::TrainForest() { MITK_INFO << "Maximum tree depths: " << m_MaxTreeDepth; MITK_INFO << "Sample fraction per tree: " << m_SampleFraction; MITK_INFO << "Number of trees: " << m_NumTrees; std::vector< std::shared_ptr< vigra::RandomForest > > trees; int count = 0; #pragma omp parallel for for (int i = 0; i < m_NumTrees; ++i) { std::shared_ptr< vigra::RandomForest > lrf = std::make_shared< vigra::RandomForest >(); lrf->set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal lrf->set_options().sample_with_replacement(true); // if sampled with replacement or not lrf->set_options().samples_per_tree(m_SampleFraction); // Fraction of samples that are used to train a tree lrf->set_options().tree_count(1); // Number of trees that are calculated; lrf->set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node lrf->ext_param_.max_tree_depth = m_MaxTreeDepth; lrf->learn(m_FeatureData, m_LabelData); #pragma omp critical { count++; MITK_INFO << "Tree " << count << " finished training."; trees.push_back(lrf); } } for (int i = 1; i < m_NumTrees; ++i) trees.at(0)->trees_.push_back(trees.at(i)->trees_[0]); m_Forest = trees.at(0); m_Forest->options_.tree_count_ = m_NumTrees; MITK_INFO << "Training finsihed"; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::SaveForest(std::string forestFile) { MITK_INFO << "Saving forest to " << forestFile; if (IsForestValid()) vigra::rf_export_HDF5( *m_Forest, forestFile, "" ); else MITK_INFO << "Forest invalid! Could not be saved."; } template< int ShOrder, int NumberOfSignalFeatures > void TrackingForestHandler< ShOrder, NumberOfSignalFeatures >::LoadForest(std::string forestFile) { MITK_INFO << "Loading forest from " << forestFile; m_Forest = std::make_shared< vigra::RandomForest >(); vigra::rf_import_HDF5( *m_Forest, forestFile); } } #endif diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h index 16d9b98da9..1a958f1650 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h @@ -1,132 +1,132 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _TrackingForestHandler #define _TrackingForestHandler #include "mitkBaseData.h" #include #include #include #include #include #include #undef DIFFERENCE #define VIGRA_STATIC_LIB #include #include #include #define _USE_MATH_DEFINES #include namespace mitk { /** * \brief Manages random forests for fiber tractography. The preparation of the features from the inputa data and the training process are handled here. The data preprocessing and actual prediction for the tracking process is also performed here. The tracking itself is performed in MLBSTrackingFilter. */ template< int ShOrder=6, int NumberOfSignalFeatures=100 > class TrackingForestHandler { public: TrackingForestHandler(); ~TrackingForestHandler(); typedef itk::Image ItkUcharImgType; typedef itk::Image< itk::Vector< float, NumberOfSignalFeatures*2 > , 3 > InterpolatedRawImageType; typedef itk::Image< Vector< float, NumberOfSignalFeatures > , 3 > FeatureImageType; void SetRawData( std::vector< Image::Pointer > images ){ m_RawData = images; } void AddRawData( Image::Pointer img ){ m_RawData.push_back(img); } void SetTractograms( std::vector< FiberBundle::Pointer > tractograms ) { m_Tractograms.clear(); for (auto fib : tractograms) { m_Tractograms.push_back(fib->GetDeepCopy()); } } void SetMaskImages( std::vector< ItkUcharImgType::Pointer > images ){ m_MaskImages = images; } void SetWhiteMatterImages( std::vector< ItkUcharImgType::Pointer > images ){ m_WhiteMatterImages = images; } void StartTraining(); void SaveForest(std::string forestFile); void LoadForest(std::string forestFile); // training parameters void SetNumTrees(int num){ m_NumTrees = num; } void SetMaxTreeDepth(int depth){ m_MaxTreeDepth = depth; } void SetStepSize(double step){ m_WmSampleDistance = step; } void SetGrayMatterSamplesPerVoxel(int samples){ m_GmSamplesPerVoxel = samples; } void SetSampleFraction(double fraction){ m_SampleFraction = fraction; } std::shared_ptr< vigra::RandomForest > GetForest(){ return m_Forest; } void InitForTracking(); ///< calls InputDataValidForTracking() and creates feature images from the war input DWI - vnl_vector_fixed Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob); ///< predicts next progression direction at the given position + vnl_vector_fixed Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& w); ///< predicts next progression direction at the given position bool IsForestValid(); ///< true is forest is not null, has more than 0 trees and the correct number of features (NumberOfSignalFeatures + 3) protected: // tracking void InputDataValidForTracking(); ///< check if raw data is set and tracking forest is valid typename FeatureImageType::PixelType GetFeatureValues(itk::Point itkP); ///< get trilinearly interpolated feature values at given world position // training void InputDataValidForTraining(); ///< Check if everything is tehere for training (raw datasets, fiber tracts) void PreprocessInputDataForTraining(); ///< Generate masks if necessary, resample fibers, spherically interpolate raw DWIs void CalculateFeaturesForTraining(); ///< Calculate GM and WM features using the interpolated raw data, the WM masks and the fibers void TrainForest(); ///< start training process typename InterpolatedRawImageType::PixelType GetImageValues(itk::Point itkP, typename InterpolatedRawImageType::Pointer image); ///< get trilinearly interpolated raw image values at given world position std::vector< Image::Pointer > m_RawData; ///< original input DWI data std::shared_ptr< vigra::RandomForest > m_Forest; ///< random forest classifier std::chrono::time_point m_StartTime; std::chrono::time_point m_EndTime; // only for training std::vector< FiberBundle::Pointer > m_Tractograms; ///< training tractograms std::vector< ItkUcharImgType::Pointer > m_MaskImages; ///< binary mask images to constrain training to a certain area (e.g. brain mask) std::vector< ItkUcharImgType::Pointer > m_WhiteMatterImages; ///< defines white matter voxels. if not set, theses mask images are automatically generated from the input tractograms std::vector< typename InterpolatedRawImageType::Pointer > m_InterpolatedRawImages; ///< spherically interpolated and resampled raw datasets double m_WmSampleDistance; ///< deterines the number of white matter samples (distance of sampling points on each fiber). int m_NumTrees; ///< number of trees in random forest int m_MaxTreeDepth; ///< limits the tree depth double m_SampleFraction; ///< fraction of samples used to train each tree unsigned int m_NumberOfSamples; ///< stores overall number of samples used for training std::vector< unsigned int > m_GmSamples; ///< number of gray matter samples int m_GmSamplesPerVoxel; ///< number of gray matter samplees per voxel. if -1, then the number is automatically chosen to gain an overall number of GM samples close to the number of WM samples. vigra::MultiArray<2, double> m_FeatureData; ///< vigra container for training features // only for tracking typename FeatureImageType::Pointer m_FeatureImage; ///< feature image used for tracking vigra::MultiArray<2, double> m_LabelData; ///< vigra container for training labels std::vector< int > m_DirectionIndices; ///< maps each of the NumberOfSignalFeatures possible output directions to one of the 2*NumberOfSignalFeatures ODF directions. itk::OrientationDistributionFunction< double, NumberOfSignalFeatures*2 > m_DirContainer; ///< direction container }; } #include "mitkTrackingForestHandler.cpp" #endif