diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp new file mode 100644 index 0000000000..a45ecb0df7 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.cpp @@ -0,0 +1,762 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#ifndef __itkMLBSTrackingFilter_txx +#define __itkMLBSTrackingFilter_txx + +#include +#include +#include + +#include "itkMLBSTrackingFilter.h" +#include +#include +#include +#include +#include + +#define _USE_MATH_DEFINES +#include + +namespace itk { + +template< int NumImageFeatures > +MLBSTrackingFilter< NumImageFeatures > +::MLBSTrackingFilter() + : m_FiberPolyData(NULL) + , m_Points(NULL) + , m_Cells(NULL) + , m_AngularThreshold(0.7) + , m_StepSize(0) + , m_MaxLength(10000) + , m_MinTractLength(20.0) + , m_MaxTractLength(400.0) + , m_SeedsPerVoxel(1) + , m_UseDirection(true) + , m_NumberOfSamples(50) + , m_SamplingDistance(-1) + , m_SeedImage(NULL) + , m_MaskImage(NULL) + , m_DecisionForest(NULL) + , m_StoppingRegions(NULL) + , m_DemoMode(false) + , m_PauseTracking(false) + , m_AbortTracking(false) +{ + this->SetNumberOfRequiredInputs(1); +} + +template< int NumImageFeatures > +double MLBSTrackingFilter< NumImageFeatures > +::RoundToNearest(double num) { + return (num > 0.0) ? floor(num + 0.5) : ceil(num - 0.5); +} + +template< int NumImageFeatures > +void MLBSTrackingFilter< NumImageFeatures >::BeforeThreadedGenerateData() +{ + m_InputImage = const_cast(this->GetInput(0)); + PreprocessRawData(); + + m_FiberPolyData = PolyDataType::New(); + m_Points = vtkSmartPointer< vtkPoints >::New(); + m_Cells = vtkSmartPointer< vtkCellArray >::New(); + + m_ImageSize.resize(3); + m_ImageSize[0] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[0]; + m_ImageSize[1] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[1]; + m_ImageSize[2] = m_FeatureImage->GetLargestPossibleRegion().GetSize()[2]; + + m_ImageSpacing.resize(3); + m_ImageSpacing[0] = m_FeatureImage->GetSpacing()[0]; + m_ImageSpacing[1] = m_FeatureImage->GetSpacing()[1]; + m_ImageSpacing[2] = m_FeatureImage->GetSpacing()[2]; + + double minSpacing; + if(m_ImageSpacing[0]GetNumberOfThreads(); i++) + { + PolyDataType poly = PolyDataType::New(); + m_PolyDataContainer.push_back(poly); + } + + m_NotWmImage = ItkDoubleImgType::New(); + m_NotWmImage->SetSpacing( m_FeatureImage->GetSpacing() ); + m_NotWmImage->SetOrigin( m_FeatureImage->GetOrigin() ); + m_NotWmImage->SetDirection( m_FeatureImage->GetDirection() ); + m_NotWmImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); + m_NotWmImage->Allocate(); + m_NotWmImage->FillBuffer(0); + + m_WmImage = ItkDoubleImgType::New(); + m_WmImage->SetSpacing( m_FeatureImage->GetSpacing() ); + m_WmImage->SetOrigin( m_FeatureImage->GetOrigin() ); + m_WmImage->SetDirection( m_FeatureImage->GetDirection() ); + m_WmImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); + m_WmImage->Allocate(); + m_WmImage->FillBuffer(0); + + m_AvoidStopImage = ItkDoubleImgType::New(); + m_AvoidStopImage->SetSpacing( m_FeatureImage->GetSpacing() ); + m_AvoidStopImage->SetOrigin( m_FeatureImage->GetOrigin() ); + m_AvoidStopImage->SetDirection( m_FeatureImage->GetDirection() ); + m_AvoidStopImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); + m_AvoidStopImage->Allocate(); + m_AvoidStopImage->FillBuffer(0); + + if (m_StoppingRegions.IsNull()) + { + m_StoppingRegions = ItkUcharImgType::New(); + m_StoppingRegions->SetSpacing( m_FeatureImage->GetSpacing() ); + m_StoppingRegions->SetOrigin( m_FeatureImage->GetOrigin() ); + m_StoppingRegions->SetDirection( m_FeatureImage->GetDirection() ); + m_StoppingRegions->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); + m_StoppingRegions->Allocate(); + m_StoppingRegions->FillBuffer(0); + } + + if (m_SeedImage.IsNull()) + { + m_SeedImage = ItkUcharImgType::New(); + m_SeedImage->SetSpacing( m_FeatureImage->GetSpacing() ); + m_SeedImage->SetOrigin( m_FeatureImage->GetOrigin() ); + m_SeedImage->SetDirection( m_FeatureImage->GetDirection() ); + m_SeedImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); + m_SeedImage->Allocate(); + m_SeedImage->FillBuffer(1); + } + + if (m_MaskImage.IsNull()) + { + // initialize mask image + m_MaskImage = ItkUcharImgType::New(); + m_MaskImage->SetSpacing( m_FeatureImage->GetSpacing() ); + m_MaskImage->SetOrigin( m_FeatureImage->GetOrigin() ); + m_MaskImage->SetDirection( m_FeatureImage->GetDirection() ); + m_MaskImage->SetRegions( m_FeatureImage->GetLargestPossibleRegion() ); + m_MaskImage->Allocate(); + m_MaskImage->FillBuffer(1); + } + else + std::cout << "MLBSTrackingFilter: using mask image" << std::endl; + + if (m_AngularThreshold<0.0) + m_AngularThreshold = 0.5*minSpacing; + + m_BuildFibersReady = 0; + m_BuildFibersFinished = false; + m_Threads = 0; + m_Tractogram.clear(); + + std::cout << "MLBSTrackingFilter: Angular threshold: " << m_AngularThreshold << std::endl; + std::cout << "MLBSTrackingFilter: Stepsize: " << m_StepSize << " mm" << std::endl; + std::cout << "MLBSTrackingFilter: Seeds per voxel: " << m_SeedsPerVoxel << std::endl; + std::cout << "MLBSTrackingFilter: Max. sampling distance: " << m_SamplingDistance << " mm" << std::endl; + std::cout << "MLBSTrackingFilter: Number of samples: " << m_NumberOfSamples << std::endl; + std::cout << "MLBSTrackingFilter: Max. tract length: " << m_MaxTractLength << " mm" << std::endl; + std::cout << "MLBSTrackingFilter: Min. tract length: " << m_MinTractLength << " mm" << std::endl; + std::cout << "MLBSTrackingFilter: Starting streamline tracking using " << this->GetNumberOfThreads() << " threads." << std::endl; +} + +template< int NumImageFeatures > +void MLBSTrackingFilter< NumImageFeatures >::PreprocessRawData() +{ + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; + + std::cout << "MLBSTrackingFilter: Spherical signal interpolation and sampling ..." << std::endl; + typename InterpolationFilterType::Pointer filter = InterpolationFilterType::New(); + filter->SetGradientImage( m_GradientDirections, m_InputImage ); + filter->SetBValue( m_B_Value ); + filter->SetLambda(0.006); + filter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); + filter->Update(); + // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); + // featureImageVector.push_back(itkFeatureImage); + + std::cout << "MLBSTrackingFilter: Creating feature image ..." << std::endl; + vnl_vector_fixed ref; ref.fill(0); ref[0]=1; + itk::OrientationDistributionFunction< double, NumImageFeatures*2 > odf; + m_DirectionIndices.clear(); + for (unsigned int f=0; f0) // only used directions on one hemisphere + m_DirectionIndices.push_back(f); + } + + m_FeatureImage = FeatureImageType::New(); + m_FeatureImage->SetSpacing(filter->GetOutput()->GetSpacing()); + m_FeatureImage->SetOrigin(filter->GetOutput()->GetOrigin()); + m_FeatureImage->SetDirection(filter->GetOutput()->GetDirection()); + m_FeatureImage->SetLargestPossibleRegion(filter->GetOutput()->GetLargestPossibleRegion()); + m_FeatureImage->SetBufferedRegion(filter->GetOutput()->GetLargestPossibleRegion()); + m_FeatureImage->SetRequestedRegion(filter->GetOutput()->GetLargestPossibleRegion()); + m_FeatureImage->Allocate(); + + itk::ImageRegionIterator< typename InterpolationFilterType::OutputImageType > it(filter->GetOutput(), filter->GetOutput()->GetLargestPossibleRegion()); + while(!it.IsAtEnd()) + { + typename FeatureImageType::PixelType pix; + for (unsigned int f=0; fSetPixel(it.GetIndex(), pix); + ++it; + } +} + +template< int NumImageFeatures > +void MLBSTrackingFilter< NumImageFeatures >::CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir) +{ + // vnl_matrix_fixed< double, 3, 3 > rot = m_FeatureImage->GetDirection().GetTranspose(); + // dir = rot*dir; + + dir *= m_StepSize; + pos[0] += dir[0]; + pos[1] += dir[1]; + pos[2] += dir[2]; +} + +template< int NumImageFeatures > +bool MLBSTrackingFilter< NumImageFeatures > +::IsValidPosition(itk::Point &pos) +{ + typename FeatureImageType::IndexType idx; + m_FeatureImage->TransformPhysicalPointToIndex(pos, idx); + if (!m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) || m_MaskImage->GetPixel(idx)==0) + return false; + + return true; +} + +template< int NumImageFeatures > +typename MLBSTrackingFilter< NumImageFeatures >::FeatureImageType::PixelType MLBSTrackingFilter< NumImageFeatures >::GetImageValues(itk::Point itkP) +{ + itk::Index<3> idx; + itk::ContinuousIndex< double, 3> cIdx; + m_FeatureImage->TransformPhysicalPointToIndex(itkP, idx); + m_FeatureImage->TransformPhysicalPointToContinuousIndex(itkP, cIdx); + + typename FeatureImageType::PixelType pix; pix.Fill(0.0); + if ( m_FeatureImage->GetLargestPossibleRegion().IsInside(idx) ) + pix = m_FeatureImage->GetPixel(idx); + else + return pix; + + double frac_x = cIdx[0] - idx[0]; + double frac_y = cIdx[1] - idx[1]; + double frac_z = cIdx[2] - idx[2]; + if (frac_x<0) + { + idx[0] -= 1; + frac_x += 1; + } + if (frac_y<0) + { + idx[1] -= 1; + frac_y += 1; + } + if (frac_z<0) + { + idx[2] -= 1; + frac_z += 1; + } + frac_x = 1-frac_x; + frac_y = 1-frac_y; + frac_z = 1-frac_z; + + // int coordinates inside image? + if (idx[0] >= 0 && idx[0] < m_FeatureImage->GetLargestPossibleRegion().GetSize(0)-1 && + idx[1] >= 0 && idx[1] < m_FeatureImage->GetLargestPossibleRegion().GetSize(1)-1 && + idx[2] >= 0 && idx[2] < m_FeatureImage->GetLargestPossibleRegion().GetSize(2)-1) + { + vnl_vector_fixed interpWeights; + interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); + interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); + interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); + interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); + interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); + interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); + interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); + interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); + + pix = m_FeatureImage->GetPixel(idx) * interpWeights[0]; + typename FeatureImageType::IndexType tmpIdx = idx; tmpIdx[0]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[1]; + tmpIdx = idx; tmpIdx[1]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[2]; + tmpIdx = idx; tmpIdx[2]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[3]; + tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[4]; + tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[5]; + tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[6]; + tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; + pix += m_FeatureImage->GetPixel(tmpIdx) * interpWeights[7]; + } +} + +template< int NumImageFeatures > +vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob, bool avoidStop) +{ + vnl_vector_fixed direction; direction.fill(0); + + vigra::MultiArray<2, double> featureData; + if(m_UseDirection) + featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumImageFeatures+3) ); + else + featureData = vigra::MultiArray<2, double>( vigra::Shape2(1,NumImageFeatures) ); + + typename FeatureImageType::PixelType featurePixel = GetImageValues(pos); + + // pixel values + for (unsigned int f=0; f ref; ref.fill(0); ref[0]=1; + for (unsigned int f=NumImageFeatures; f probs(vigra::Shape2(1, m_DecisionForest->class_count())); + m_DecisionForest->predictProbabilities(featureData, probs); + + double outProb = 0; + prob = 0; + candidates = 0; // directions with probability > 0 + for (int i=0; iclass_count(); i++) + { + if (probs(0,i)>0) + { + int classLabel = 0; + m_DecisionForest->ext_param_.to_classlabel(i, classLabel); + + if (classLabel d = m_ODF.GetDirection(m_DirectionIndices.at(classLabel)); + double dot = dot_product(d, olddir); + + if (olddir.magnitude()>0) + { + if (fabs(dot)>angularThreshold) + { + if (dot<0) + d *= -1; + dot = fabs(dot); + direction += probs(0,i)*dot*d; + prob += probs(0,i)*dot; + } + } + else + { + direction += probs(0,i)*d; + prob += probs(0,i); + } + } + else + outProb += probs(0,i); + } + } + + ItkDoubleImgType::IndexType idx; + m_NotWmImage->TransformPhysicalPointToIndex(pos, idx); + if (m_NotWmImage->GetLargestPossibleRegion().IsInside(idx)) + { + m_NotWmImage->SetPixel(idx, m_NotWmImage->GetPixel(idx)+outProb); + m_WmImage->SetPixel(idx, m_WmImage->GetPixel(idx)+prob); + } + if (outProb>prob && prob>0) + { + candidates = 0; + prob = 0; + direction.fill(0.0); + } + if (avoidStop && m_AvoidStopImage->GetLargestPossibleRegion().IsInside(idx) && candidates>0 && direction.magnitude()>0.001) + m_AvoidStopImage->SetPixel(idx, m_AvoidStopImage->GetPixel(idx)+0.1); + + return direction; +} + + +template< int NumImageFeatures > +double MLBSTrackingFilter< NumImageFeatures >::GetRandDouble(double min, double max) +{ + return (double)(rand()%((int)(10000*(max-min))) + 10000*min)/10000; +} + +template< int NumImageFeatures > +vnl_vector_fixed MLBSTrackingFilter< NumImageFeatures >::GetNewDirection(itk::Point &pos, vnl_vector_fixed& olddir) +{ + vnl_vector_fixed direction; direction.fill(0); + + ItkUcharImgType::IndexType idx; + m_StoppingRegions->TransformPhysicalPointToIndex(pos, idx); + if (m_StoppingRegions->GetPixel(idx)>0) + return direction; + + if (olddir.magnitude()>0) + olddir.normalize(); + + int candidates = 0; // number of directions with probability > 0 + double prob = 0; + direction = Classify(pos, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood + direction *= prob; + + for (int i=0; i probe; + probe[0] = GetRandDouble()*m_SamplingDistance; + probe[1] = GetRandDouble()*m_SamplingDistance; + probe[2] = GetRandDouble()*m_SamplingDistance; + + itk::Point temp; + temp[0] = pos[0] + probe[0]; + temp[1] = pos[1] + probe[1]; + temp[2] = pos[2] + probe[2]; + + candidates = 0; + vnl_vector_fixed tempDir = Classify(temp, candidates, olddir, m_AngularThreshold, prob); // sample neighborhood + if (candidates>0 && tempDir.magnitude()>0.001) + { + direction += tempDir*prob; + } + else if (candidates==0) // out of white matter + { + vnl_vector_fixed normProbe = -probe; normProbe.normalize(); + double dot = dot_product(normProbe, olddir); + if (dot < 0.0) + { + probe = (normProbe - 2 * dot*olddir)*probe.magnitude(); // reflect + } + else + { + probe = -probe; // invert + } + + // look a bit further into the other direction + temp[0] = pos[0] + probe[0]*2; + temp[1] = pos[1] + probe[1]*2; + temp[2] = pos[2] + probe[2]*2; + candidates = 0; + vnl_vector_fixed tempDir = Classify(temp, candidates, olddir, m_AngularThreshold, prob, true); // sample neighborhood + + if (candidates>0 && tempDir.magnitude()>0.001) // are we back in the white matter? + { + direction += probe; // go into the direction of the white matter + direction += tempDir*prob; // go into the direction of the white matter direction at this location + } + } + } + + if (direction.magnitude()>0.001) + { + direction.normalize(); + olddir[0] = direction[0]; + olddir[1] = direction[1]; + olddir[2] = direction[2]; + } + else + direction.fill(0); + + return direction; +} + +template< int NumImageFeatures > +double MLBSTrackingFilter< NumImageFeatures >::FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front) +{ + vnl_vector_fixed dirOld = dir; + dirOld = dir; + + for (int step=0; step< m_MaxLength/2; step++) + { + while (m_PauseTracking){} + if (m_DemoMode) + { + m_Mutex.Lock(); + m_BuildFibersReady++; + m_Tractogram.push_back(*fib); + BuildFibers(true); + m_Stop = true; + m_Mutex.Unlock(); + while (m_Stop){} + } + + // get new position + CalculateNewPosition(pos, dir); + + // is new position inside of image and mask + if (!IsValidPosition(pos) || m_AbortTracking) // if not end streamline + { + return tractLength; + } + else // if yes, add new point to streamline + { + tractLength += m_StepSize; + if (front) + fib->push_front(pos); + else + fib->push_back(pos); + + int curv = CheckCurvature(fib, front); // TODO: Move into classification + if (curv>0) + { + MITK_INFO << "CURVTHRES!!!"; + tractLength -= m_StepSize*curv; + while (curv>0) + { + if (front) + fib->pop_front(); + else + fib->pop_back(); + curv--; + } + return tractLength; + } + + if (tractLength>m_MaxTractLength) + return tractLength; + } + + dir = GetNewDirection(pos, dirOld); + + if (dir.magnitude()<0.0001) + return tractLength; + } + return tractLength; +} + +template< int NumImageFeatures > +int MLBSTrackingFilter::CheckCurvature(FiberType* fib, bool front) +{ + double m_Distance = 5; + if (fib->size()<3) + return 0; + + double dist = 0; + std::vector< vnl_vector_fixed< float, 3 > > vectors; + vnl_vector_fixed< float, 3 > meanV; meanV.fill(0); + double dev = 0; + + if (front) + { + int c=0; + while(distsize()-1) + { + itk::Point p1 = fib->at(c); + itk::Point p2 = fib->at(c+1); + + vnl_vector_fixed< float, 3 > v; + v[0] = p2[0]-p1[0]; + v[1] = p2[1]-p1[1]; + v[2] = p2[2]-p1[2]; + dist += v.magnitude(); + v.normalize(); + vectors.push_back(v); + if (c==0) + meanV += v; + c++; + } + } + else + { + int c=fib->size()-1; + while(dist0) + { + itk::Point p1 = fib->at(c); + itk::Point p2 = fib->at(c-1); + + vnl_vector_fixed< float, 3 > v; + v[0] = p2[0]-p1[0]; + v[1] = p2[1]-p1[1]; + v[2] = p2[2]-p1[2]; + dist += v.magnitude(); + v.normalize(); + vectors.push_back(v); + if (c==fib->size()-1) + meanV += v; + c--; + } + } + meanV.normalize(); + + for (int c=0; c1.0) + angle = 1.0; + if (angle<-1.0) + angle = -1.0; + dev += acos(angle)*180/M_PI; + } + if (vectors.size()>0) + dev /= vectors.size(); + + if (dev<30) + return 0; + else + return vectors.size(); +} + +template< int NumImageFeatures > +void MLBSTrackingFilter< NumImageFeatures >::ThreadedGenerateData(const InputImageRegionType ®ionForThread, ThreadIdType threadId) +{ + m_Mutex.Lock(); + m_Threads++; + m_Mutex.Unlock(); + typedef ImageRegionConstIterator< ItkUcharImgType > MaskIteratorType; + MaskIteratorType sit(m_SeedImage, regionForThread ); + MaskIteratorType mit(m_MaskImage, regionForThread ); + + sit.GoToBegin(); + mit.GoToBegin(); + itk::Point worldPos; + while( !sit.IsAtEnd() ) + { + if (sit.Value()==0 || mit.Value()==0) + { + ++sit; + ++mit; + continue; + } + + for (int s=0; s start; + unsigned int counter = 0; + + if (m_SeedsPerVoxel>1) + { + start[0] = index[0]+GetRandDouble(-0.5, 0.5); + start[1] = index[1]+GetRandDouble(-0.5, 0.5); + start[2] = index[2]+GetRandDouble(-0.5, 0.5); + } + else + { + start[0] = index[0]; + start[1] = index[1]; + start[2] = index[2]; + } + + // get staring position + m_SeedImage->TransformContinuousIndexToPhysicalPoint( start, worldPos ); + + // get starting direction + int candidates = 0; + double prob = 0; + vnl_vector_fixed dirOld; dirOld.fill(0.0); + vnl_vector_fixed dir = Classify(worldPos, candidates, dirOld, 0, prob); + if (dir.magnitude()<0.0001) + continue; + + // forward tracking + tractLength = FollowStreamline(threadId, worldPos, dir, &fib, 0, false); + fib.push_front(worldPos); + + // backward tracking + tractLength = FollowStreamline(threadId, worldPos, -dir, &fib, tractLength, true); + counter = fib.size(); + + if (tractLength +void MLBSTrackingFilter< NumImageFeatures >::BuildFibers(bool check) +{ + if (m_BuildFibersReady::New(); + vtkSmartPointer vNewLines = vtkSmartPointer::New(); + vtkSmartPointer vNewPoints = vtkSmartPointer::New(); + + for (int i=0; i container = vtkSmartPointer::New(); + FiberType fib = m_Tractogram.at(i); + for (FiberType::iterator it = fib.begin(); it!=fib.end(); it++) + { + vtkIdType id = vNewPoints->InsertNextPoint((*it).GetDataPointer()); + container->GetPointIds()->InsertNextId(id); + } + vNewLines->InsertNextCell(container); + } + if (check) + for (int i=0; iSetPoints(vNewPoints); + m_FiberPolyData->SetLines(vNewLines); + m_BuildFibersFinished = true; +} + +template< int NumImageFeatures > +void MLBSTrackingFilter< NumImageFeatures >::AfterThreadedGenerateData() +{ + MITK_INFO << "Generating polydata "; + BuildFibers(false); + MITK_INFO << "done"; +} + +} + +#endif // __itkDiffusionQballPrincipleDirectionsImageFilter_txx diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h new file mode 100644 index 0000000000..d3fc9fecf5 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/itkMLBSTrackingFilter.h @@ -0,0 +1,184 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +/*=================================================================== + +This file is based heavily on a corresponding ITK filter. + +===================================================================*/ +#ifndef __itkMLBSTrackingFilter_h_ +#define __itkMLBSTrackingFilter_h_ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// classification includes +#include +#include +#include + +namespace itk{ + +/** +* \brief Performes deterministic streamline tracking on the input tensor image. */ + +template< int NumImageFeatures=100 > +class MLBSTrackingFilter : public ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > +{ + +public: + + typedef MLBSTrackingFilter Self; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + typedef ImageToImageFilter< VectorImage< short, 3 >, Image< double, 3 > > Superclass; + + typedef vigra::RandomForest DecisionForestType; + typedef typename Superclass::InputImageType InputImageType; + typedef typename Superclass::InputImageRegionType InputImageRegionType; + typedef Image< Vector< float, NumImageFeatures > , 3 > FeatureImageType; + + /** Method for creation through the object factory. */ + itkFactorylessNewMacro(Self) + itkCloneMacro(Self) + + /** Runtime information support. */ + itkTypeMacro(MLBSTrackingFilter, ImageToImageFilter) + + typedef itk::Image ItkUcharImgType; + typedef itk::Image ItkDoubleImgType; + typedef itk::Image ItkFloatImgType; + typedef vtkSmartPointer< vtkPolyData > PolyDataType; + + typedef std::deque< itk::Point > FiberType; + typedef std::vector< FiberType > BundleType; + + bool m_PauseTracking; + bool m_AbortTracking; + bool m_BuildFibersFinished; + int m_BuildFibersReady; + bool m_Stop; +// void RequestFibers(){ m_Stop=true; m_BuildFibersReady=0; m_BuildFibersFinished=false; } + + itkGetMacro( FiberPolyData, PolyDataType ) ///< Output fibers + itkSetMacro( SeedImage, ItkUcharImgType::Pointer) ///< Seeds are only placed inside of this mask. + itkSetMacro( MaskImage, ItkUcharImgType::Pointer) ///< Tracking is only performed inside of this mask image. + itkSetMacro( SeedsPerVoxel, int) ///< One seed placed in the center of each voxel or multiple seeds randomly placed inside each voxel. + itkSetMacro( StepSize, double) ///< Integration step size in mm + itkSetMacro( MinTractLength, double ) ///< Shorter tracts are discarded. + itkSetMacro( MaxTractLength, double ) + itkSetMacro( AngularThreshold, double ) + itkSetMacro( UseDirection, bool ) + itkSetMacro( SamplingDistance, double ) + itkSetMacro( NumberOfSamples, int ) + itkSetMacro( StoppingRegions, ItkUcharImgType::Pointer) + itkSetMacro( B_Value, float ) + itkSetMacro( GradientDirections, mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer ) + itkSetMacro( DemoMode, bool ) + + void SetDecisionForest( DecisionForestType* forest ) + { + m_DecisionForest = forest; + } + + itkGetMacro( WmImage, ItkDoubleImgType::Pointer ) + itkGetMacro( NotWmImage, ItkDoubleImgType::Pointer ) + itkGetMacro( AvoidStopImage, ItkDoubleImgType::Pointer ) + + protected: + MLBSTrackingFilter(); + ~MLBSTrackingFilter() {} + + void CalculateNewPosition(itk::Point& pos, vnl_vector_fixed& dir); ///< Calculate next integration step. + double FollowStreamline(ThreadIdType threadId, itk::Point pos, vnl_vector_fixed dir, FiberType* fib, double tractLength, bool front); ///< Start streamline in one direction. + bool IsValidPosition(itk::Point& pos); ///< Are we outside of the mask image? + vnl_vector_fixed GetNewDirection(itk::Point& pos, vnl_vector_fixed& olddir); + vnl_vector_fixed Classify(itk::Point& pos, int& candidates, vnl_vector_fixed& olddir, double angularThreshold, double& prob, bool avoidStop=false); + + typename FeatureImageType::PixelType GetImageValues(itk::Point itkP); + double GetRandDouble(double min=-1, double max=1); + double RoundToNearest(double num); + + void BeforeThreadedGenerateData(); + void PreprocessRawData(); + void ThreadedGenerateData( const InputImageRegionType &outputRegionForThread, ThreadIdType threadId); + void AfterThreadedGenerateData(); + + PolyDataType m_FiberPolyData; + vtkSmartPointer m_Points; + vtkSmartPointer m_Cells; + BundleType m_Tractogram; + + double m_AngularThreshold; + double m_StepSize; + int m_MaxLength; + double m_MinTractLength; + double m_MaxTractLength; + int m_SeedsPerVoxel; + bool m_UseDirection; + double m_SamplingDistance; + int m_NumberOfSamples; + std::vector< int > m_ImageSize; + std::vector< double > m_ImageSpacing; + + SimpleFastMutexLock m_Mutex; + ItkUcharImgType::Pointer m_StoppingRegions; + ItkDoubleImgType::Pointer m_WmImage; + ItkDoubleImgType::Pointer m_NotWmImage; + ItkDoubleImgType::Pointer m_AvoidStopImage; + ItkUcharImgType::Pointer m_SeedImage; + ItkUcharImgType::Pointer m_MaskImage; + typename FeatureImageType::Pointer m_FeatureImage; + typename InputImageType::Pointer m_InputImage; + mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer m_GradientDirections; + float m_B_Value; + + int m_Threads; + bool m_DemoMode; + void BuildFibers(bool check); + int CheckCurvature(FiberType* fib, bool front); + + // decision forest + DecisionForestType* m_DecisionForest; + itk::OrientationDistributionFunction< double, NumImageFeatures*2 > m_ODF; + std::vector< int > m_DirectionIndices; + + std::vector< PolyDataType > m_PolyDataContainer; + +private: + +}; + +} + +#ifndef ITK_MANUAL_INSTANTIATION +#include "itkMLBSTrackingFilter.cpp" +#endif + +#endif //__itkMLBSTrackingFilter_h_ + diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp new file mode 100644 index 0000000000..aa50256bdb --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.cpp @@ -0,0 +1,540 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#ifndef _TrackingForestHandler_cpp +#define _TrackingForestHandler_cpp + +#include "mitkTrackingForestHandler.h" +#include +#include + +namespace mitk +{ + +template< int NumberOfSignalFeatures > +TrackingForestHandler< NumberOfSignalFeatures >::TrackingForestHandler() + : m_GrayMatterSamplesPerVoxel(50) + , m_StepSize(-1) + , m_UsePreviousDirection(true) + , m_NumTrees(30) + , m_MaxTreeDepth(50) + , m_SampleFraction(1.0) +{ + +} + +template< int NumberOfSignalFeatures > +TrackingForestHandler< NumberOfSignalFeatures >::~TrackingForestHandler() +{ + +} + +template< int NumberOfSignalFeatures > +typename TrackingForestHandler< NumberOfSignalFeatures >::InterpolatedRawImageType::PixelType TrackingForestHandler< NumberOfSignalFeatures >::GetImageValues(itk::Point itkP, typename InterpolatedRawImageType::Pointer image) +{ + itk::Index<3> idx; + itk::ContinuousIndex< double, 3> cIdx; + image->TransformPhysicalPointToIndex(itkP, idx); + image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); + + typename InterpolatedRawImageType::PixelType pix; pix.Fill(0.0); + if ( image->GetLargestPossibleRegion().IsInside(idx) ) + pix = image->GetPixel(idx); + else + return pix; + + double frac_x = cIdx[0] - idx[0]; + double frac_y = cIdx[1] - idx[1]; + double frac_z = cIdx[2] - idx[2]; + if (frac_x<0) + { + idx[0] -= 1; + frac_x += 1; + } + if (frac_y<0) + { + idx[1] -= 1; + frac_y += 1; + } + if (frac_z<0) + { + idx[2] -= 1; + frac_z += 1; + } + frac_x = 1-frac_x; + frac_y = 1-frac_y; + frac_z = 1-frac_z; + + // int coordinates inside image? + if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && + idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && + idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) + { + vnl_vector_fixed interpWeights; + interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); + interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); + interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); + interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); + interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); + interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); + interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); + interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); + + + pix = image->GetPixel(idx) * interpWeights[0]; + typename InterpolatedRawImageType::IndexType tmpIdx = idx; tmpIdx[0]++; + pix += image->GetPixel(tmpIdx) * interpWeights[1]; + tmpIdx = idx; tmpIdx[1]++; + pix += image->GetPixel(tmpIdx) * interpWeights[2]; + tmpIdx = idx; tmpIdx[2]++; + pix += image->GetPixel(tmpIdx) * interpWeights[3]; + tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; + pix += image->GetPixel(tmpIdx) * interpWeights[4]; + tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; + pix += image->GetPixel(tmpIdx) * interpWeights[5]; + tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; + pix += image->GetPixel(tmpIdx) * interpWeights[6]; + tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; + pix += image->GetPixel(tmpIdx) * interpWeights[7]; + } + + return pix; +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::InputDataValidForTracking() +{ + if (m_RawData.empty()) + mitkThrow() << "No diffusion-weighted images set!"; +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::StartTraining() +{ + InputDataValidForTraining(); + PreprocessInputData(); + CalculateFeatures(); + TrainForest(); +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::InputDataValidForTraining() +{ + if (m_RawData.empty()) + mitkThrow() << "No diffusion-weighted images set!"; + if (m_Tractograms.empty()) + mitkThrow() << "No tractograms set!"; + if (m_RawData.size()!=m_Tractograms.size()) + mitkThrow() << "Unequal number of diffusion-weighted images and tractograms detected!"; +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::PreprocessInputData() +{ + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter InterpolationFilterType; + + MITK_INFO << "Spherical signal interpolation and sampling ..."; + for (unsigned int i=0; iSetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(m_RawData.at(i)), mitk::DiffusionPropertyHelper::GetItkVectorImage(m_RawData.at(i)) ); + qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(m_RawData.at(i))); + qballfilter->SetLambda(0.006); + qballfilter->SetNormalizationMethod(InterpolationFilterType::QBAR_RAW_SIGNAL); + qballfilter->Update(); + // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); + // featureImageVector.push_back(itkFeatureImage); + m_InterpolatedRawImages.push_back(qballfilter->GetOutput()); + + if (i>=m_MaskImages.size()) + { + ItkUcharImgType::Pointer newMask = ItkUcharImgType::New(); + newMask->SetSpacing( m_InterpolatedRawImages.at(i)->GetSpacing() ); + newMask->SetOrigin( m_InterpolatedRawImages.at(i)->GetOrigin() ); + newMask->SetDirection( m_InterpolatedRawImages.at(i)->GetDirection() ); + newMask->SetLargestPossibleRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + newMask->SetBufferedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + newMask->SetRequestedRegion( m_InterpolatedRawImages.at(i)->GetLargestPossibleRegion() ); + newMask->Allocate(); + newMask->FillBuffer(1); + m_MaskImages.push_back(newMask); + } + } + + MITK_INFO << "Resampling fibers and calculating number of samples ..."; + m_NumberOfSamples = 0; + for (unsigned int t=0; t::Pointer env = itk::TractDensityImageFilter< ItkUcharImgType >::New(); + env->SetFiberBundle(m_Tractograms.at(t)); + env->SetInputImage(mask); + env->SetBinaryOutput(true); + env->SetUseImageGeometry(true); + env->Update(); + wmmask = env->GetOutput(); + m_WhiteMatterImages.push_back(wmmask); + } + + itk::ImageRegionConstIterator it(wmmask, wmmask->GetLargestPossibleRegion()); + int OUTOFWM = 0; + while(!it.IsAtEnd()) + { + if (it.Get()==0 && mask->GetPixel(it.GetIndex())>0) + OUTOFWM++; + ++it; + } + m_NumberOfSamples += m_GrayMatterSamplesPerVoxel*OUTOFWM; + MITK_INFO << "Samples outside of WM: " << m_NumberOfSamples; + + if (m_StepSize<0) + { + typename InterpolatedRawImageType::Pointer image = m_InterpolatedRawImages.at(t); + float minSpacing = 1; + if(image->GetSpacing()[0]GetSpacing()[1] && image->GetSpacing()[0]GetSpacing()[2]) + minSpacing = image->GetSpacing()[0]; + else if (image->GetSpacing()[1] < image->GetSpacing()[2]) + minSpacing = image->GetSpacing()[1]; + else + minSpacing = image->GetSpacing()[2]; + m_StepSize = minSpacing*0.5; + } + + m_Tractograms.at(t)->ResampleSpline(m_StepSize); + m_NumberOfSamples += m_Tractograms.at(t)->GetNumberOfPoints(); + m_NumberOfSamples -= 2*m_Tractograms.at(t)->GetNumFibers(); + } + MITK_INFO << "Number of samples: " << m_NumberOfSamples; +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::CalculateFeatures() +{ + vnl_vector_fixed ref; ref.fill(0); ref[0]=1; + itk::OrientationDistributionFunction< double, 2*NumberOfSignalFeatures > directions; + std::vector< int > directionIndices; + for (unsigned int f=0; f<2*NumberOfSignalFeatures; f++) + { + if (dot_product(ref, directions.GetDirection(f))>0) + directionIndices.push_back(f); + } + + int numDirectionFeatures = 0; + if (m_UsePreviousDirection) + numDirectionFeatures = 3; + + m_FeatureData.reshape( vigra::Shape2(m_NumberOfSamples, NumberOfSignalFeatures+numDirectionFeatures) ); + m_LabelData.reshape( vigra::Shape2(m_NumberOfSamples,1) ); + MITK_INFO << "Number of features: " << m_FeatureData.shape(1); + + itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer m_RandGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); + m_RandGen->SetSeed(); + MITK_INFO << "Creating training data ..."; + int sampleCounter = 0; + for (unsigned int t=0; t it(wmMask, wmMask->GetLargestPossibleRegion()); + while(!it.IsAtEnd()) + { + if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0))) + { + typename InterpolatedRawImageType::PixelType pix = image->GetPixel(it.GetIndex()); + if (m_UsePreviousDirection) + { + // null direction + for (unsigned int f=0; f probe; + probe[0] = m_RandGen->GetVariate()*2-1; + probe[1] = m_RandGen->GetVariate()*2-1; + probe[2] = m_RandGen->GetVariate()*2-1; + probe.normalize(); + if (dot_product(ref, probe)<0) + probe *= -1; + for (unsigned int f=NumberOfSignalFeatures; f idx; + idx[0] = it.GetIndex()[0]; + idx[1] = it.GetIndex()[1]; + idx[2] = it.GetIndex()[2]; + itk::Point itkP1; + image->TransformContinuousIndexToPhysicalPoint(idx, itkP1); + typename InterpolatedRawImageType::PixelType pix = GetImageValues(itkP1, image);; + for (unsigned int f=0; f idx; + idx[0] = it.GetIndex()[0] + m_RandGen->GetVariate()-0.5; + idx[1] = it.GetIndex()[1] + m_RandGen->GetVariate()-0.5; + idx[2] = it.GetIndex()[2] + m_RandGen->GetVariate()-0.5; + itk::Point itkP1; + image->TransformContinuousIndexToPhysicalPoint(idx, itkP1); + typename InterpolatedRawImageType::PixelType pix = GetImageValues(itkP1, image);; + for (unsigned int f=0; f polyData = fib->GetFiberPolyData(); + for (int i=0; iGetNumFibers(); i++) + { + vtkCell* cell = polyData->GetCell(i); + int numPoints = cell->GetNumberOfPoints(); + vtkPoints* points = cell->GetPoints(); + + vnl_vector_fixed dirOld; dirOld.fill(0.0); + + for (int j=0; jGetPoint(j); + itk::Point itkP1; + itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2]; + + vnl_vector_fixed dir; dir.fill(0.0); + + itk::Point itkP2; + double* p2 = points->GetPoint(j+1); + itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2]; + dir[0]=itkP2[0]-itkP1[0]; + dir[1]=itkP2[1]-itkP1[1]; + dir[2]=itkP2[2]-itkP1[2]; + + if (dir.magnitude()<0.0001) + { + MITK_INFO << "streamline error!"; + continue; + } + dir.normalize(); + if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2]) + { + MITK_INFO << "ERROR: NaN direction!"; + continue; + } + + if (j==0) + { + dirOld = dir; + continue; + } + + // get voxel values + typename InterpolatedRawImageType::PixelType pix = GetImageValues(itkP1, image); + for (unsigned int f=0; f0.0001) + { + int label = 0; + for (unsigned int f=0; fangle) + { + m_LabelData(sampleCounter,0) = f; + angle = a; + label = f; + } + } + } + + dirOld = dir; + sampleCounter++; + } + } + } +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::TrainForest() +{ + MITK_INFO << "Maximum tree depths: " << m_MaxTreeDepth; + MITK_INFO << "Sample fraction per tree: " << m_SampleFraction; + MITK_INFO << "Number of trees: " << m_NumTrees; + bool random_split = false; + vigra::rf::visitors::OOB_Error oob_v; + MITK_INFO << "Create Split Function"; + // typedef ThresholdSplit, vigra::ClassificationTag> DefaultSplitType; + + m_Forest.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal + m_Forest.set_options().sample_with_replacement(true); // if sampled with replacement or not + m_Forest.set_options().samples_per_tree(m_SampleFraction); // Fraction of samples that are used to train a tree + m_Forest.set_options().tree_count(1); // Number of trees that are calculated; + m_Forest.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node + // rf.set_options().features_per_node(10); + + m_Forest.learn(m_FeatureData, m_LabelData, vigra::rf::visitors::create_visitor(oob_v)); + + // Prepare parallel VariableImportance Calculation + int numMod = m_FeatureData.shape(1); + const int numClass = 2 + 2; + + float** varImp = new float*[numMod]; + + for(int i = 0; i < numMod; i++) + varImp[i] = new float[numClass]; + + for (int i = 0; i < numMod; ++i) + for (int j = 0; j < numClass; ++j) + varImp[i][j] = 0.0; + +#pragma omp parallel for + for (int i = 0; i < m_NumTrees - 1; ++i) + { + vigra::RandomForest lrf; + vigra::rf::visitors::OOB_Error loob_v; + + lrf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal + lrf.set_options().sample_with_replacement(true); // if sampled with replacement or not + lrf.set_options().samples_per_tree(m_SampleFraction); // Fraction of samples that are used to train a tree + lrf.set_options().tree_count(1); // Number of trees that are calculated; + lrf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node + // lrf.set_options().features_per_node(10); + + vigra::rf::visitors::VariableImportanceVisitor lvariableImportance; + lrf.learn(m_FeatureData, m_LabelData, vigra::rf::visitors::create_visitor(loob_v)); + +#pragma omp critical + { + m_Forest.trees_.push_back(lrf.trees_[0]); + } + } + + m_Forest.options_.tree_count_ = m_NumTrees; + MITK_INFO << "Training finsihed"; + MITK_INFO << "The out-of-bag error is: " << oob_v.oob_breiman << std::endl; +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::SaveForest(std::string forestFile) +{ + MITK_INFO << "Saving forest to " << forestFile; + vigra::rf_export_HDF5( m_Forest, forestFile, "" ); +} + +template< int NumberOfSignalFeatures > +void TrackingForestHandler< NumberOfSignalFeatures >::LoadForest(std::string forestFile) +{ + MITK_INFO << "Loading forest from " << forestFile; + vigra::rf_import_HDF5(m_Forest, forestFile); +} + +//// superclass implementations +//template< int NumberOfSignalFeatures > +//void TrackingForestHandler< NumberOfSignalFeatures >::UpdateOutputInformation() +//{ + +//} +//template< int NumberOfSignalFeatures > +//void TrackingForestHandler< NumberOfSignalFeatures >::SetRequestedRegionToLargestPossibleRegion() +//{ + +//} +//template< int NumberOfSignalFeatures > +//bool TrackingForestHandler< NumberOfSignalFeatures >::RequestedRegionIsOutsideOfTheBufferedRegion() +//{ +// return false; +//} +//template< int NumberOfSignalFeatures > +//bool TrackingForestHandler< NumberOfSignalFeatures >::VerifyRequestedRegion() +//{ +// return true; +//} +//template< int NumberOfSignalFeatures > +//void TrackingForestHandler< NumberOfSignalFeatures >::SetRequestedRegion(const itk::DataObject* ) +//{ + +//} + +} + +#endif diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h new file mode 100644 index 0000000000..4aecb1ac41 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/MLTracking/mitkTrackingForestHandler.h @@ -0,0 +1,115 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#ifndef _TrackingForestHandler +#define _TrackingForestHandler + +#include "mitkBaseData.h" + +#include +#include +#include +#include + +// classification includes +//#include "RegressionForestClasses.hxx" +#undef DIFFERENCE +#define VIGRA_STATIC_LIB +#include +#include +#include + +#define _USE_MATH_DEFINES +#include + +namespace mitk +{ + +/** +* \brief */ + +template< int NumberOfSignalFeatures=100 > +class TrackingForestHandler +{ + +public: + + TrackingForestHandler(); + ~TrackingForestHandler(); + + typedef itk::Image ItkUcharImgType; + typedef itk::Image< itk::Vector< float, NumberOfSignalFeatures*2 > , 3 > InterpolatedRawImageType; + + void SetRawData( std::vector< Image::Pointer > images ){ m_RawData = images; } + void SetTractograms( std::vector< FiberBundle::Pointer > tractograms ) + { + m_Tractograms.clear(); + for (int i=0; iGetDeepCopy()); + } + } + void SetMaskImages( std::vector< ItkUcharImgType::Pointer > images ){ m_MaskImages = images; } + void SetWhiteMatterImages( std::vector< ItkUcharImgType::Pointer > images ){ m_WhiteMatterImages = images; } + + void StartTraining(); + void SaveForest(std::string forestFile); + void LoadForest(std::string forestFile); + + void SetNumTrees(int num){ m_NumTrees = num; } + void SetMaxTreeDepth(int depth){ m_MaxTreeDepth = depth; } + void SetUsePreviousDirection(bool use){ m_UsePreviousDirection = use; } + void SetStepSize(double step){ m_StepSize = step; } + void SetGrayMatterSamplesPerVoxel(int samples){ m_GrayMatterSamplesPerVoxel = samples; } + void SetSampleFraction(double fraction){ m_SampleFraction = fraction; } + vigra::RandomForest GetForest(){ return m_Forest; } + +protected: + + void InputDataValidForTracking(); + void InputDataValidForTraining(); + void PreprocessInputData(); + void CalculateFeatures(); + void TrainForest(); + + int m_GrayMatterSamplesPerVoxel; + double m_StepSize; + bool m_UsePreviousDirection; + int m_NumTrees; + int m_MaxTreeDepth; + double m_SampleFraction; + + std::vector< Image::Pointer > m_RawData; + std::vector< FiberBundle::Pointer > m_Tractograms; + std::vector< ItkUcharImgType::Pointer > m_MaskImages; + std::vector< ItkUcharImgType::Pointer > m_WhiteMatterImages; + std::vector< ItkUcharImgType::Pointer > m_SeedImages; + std::vector< ItkUcharImgType::Pointer > m_StopImages; + + int m_NumberOfSamples; + vigra::RandomForest m_Forest; + vigra::MultiArray<2, double> m_FeatureData; + vigra::MultiArray<2, double> m_LabelData; + std::vector< typename InterpolatedRawImageType::Pointer > m_InterpolatedRawImages; + + typename InterpolatedRawImageType::PixelType GetImageValues(itk::Point itkP, typename InterpolatedRawImageType::Pointer image); +}; + +} + +#include "mitkTrackingForestHandler.cpp" + +#endif diff --git a/Modules/DiffusionImaging/FiberTracking/CMakeLists.txt b/Modules/DiffusionImaging/FiberTracking/CMakeLists.txt index 78a0983713..53a22b6883 100644 --- a/Modules/DiffusionImaging/FiberTracking/CMakeLists.txt +++ b/Modules/DiffusionImaging/FiberTracking/CMakeLists.txt @@ -1,65 +1,65 @@ set(_module_deps MitkDiffusionCore MitkGraphAlgorithms) mitk_check_module_dependencies( MODULES ${_module_deps} MISSING_DEPENDENCIES_VAR _missing_deps ) # Enable OpenMP support find_package(OpenMP) if(NOT OPENMP_FOUND) message("OpenMP is not available.") endif() if(OPENMP_FOUND) message(STATUS "Found OpenMP.") set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} ${OpenMP_C_FLAGS}") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${OpenMP_CXX_FLAGS}") endif() if(NOT _missing_deps) set(lut_url http://mitk.org/download/data/FibertrackingLUT.tar.gz) set(lut_tarball ${CMAKE_CURRENT_BINARY_DIR}/FibertrackingLUT.tar.gz) message("Downloading FiberTracking LUT ${lut_url}...") file(DOWNLOAD ${lut_url} ${lut_tarball} EXPECTED_MD5 38ecb6d4a826c9ebb0f4965eb9aeee44 TIMEOUT 60 STATUS status SHOW_PROGRESS ) list(GET status 0 status_code) list(GET status 1 status_msg) if(NOT status_code EQUAL 0) message(SEND_ERROR "${status_msg} (error code ${status_code})") else() message("done.") endif() file(MAKE_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/Resources) message("Unpacking FiberTracking LUT tarball...") execute_process(COMMAND ${CMAKE_COMMAND} -E tar xzf ../FibertrackingLUT.tar.gz WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR}/Resources RESULT_VARIABLE result ERROR_VARIABLE err_msg) if(result) message(SEND_ERROR "Unpacking FibertrackingLUT.tar.gz failed: ${err_msg}") else() message("done.") endif() endif() MITK_CREATE_MODULE( SUBPROJECTS MITK-DTI - INCLUDE_DIRS Algorithms Algorithms/GibbsTracking Algorithms/StochasticTracking IODataStructures IODataStructures/FiberBundle IODataStructures/PlanarFigureComposite Interactions SignalModels Rendering ${CMAKE_CURRENT_BINARY_DIR} + INCLUDE_DIRS Algorithms Algorithms/MLTracking Algorithms/GibbsTracking Algorithms/StochasticTracking IODataStructures IODataStructures/FiberBundle IODataStructures/PlanarFigureComposite Interactions SignalModels Rendering ${CMAKE_CURRENT_BINARY_DIR} DEPENDS ${_module_deps} - PACKAGE_DEPENDS PUBLIC ITK|ITKFFT ITK|ITKDiffusionTensorImage + PACKAGE_DEPENDS PUBLIC ITK|ITKFFT ITK|ITKDiffusionTensorImage Vigra HDF5 #WARNINGS_AS_ERRORS ) if(MODULE_IS_ENABLED) add_subdirectory(Testing) endif() diff --git a/Modules/DiffusionImaging/FiberTracking/files.cmake b/Modules/DiffusionImaging/FiberTracking/files.cmake index adc2368cf2..7a0b4dd4c8 100644 --- a/Modules/DiffusionImaging/FiberTracking/files.cmake +++ b/Modules/DiffusionImaging/FiberTracking/files.cmake @@ -1,78 +1,80 @@ set(CPP_FILES mitkFiberTrackingModuleActivator.cpp ## IO datastructures IODataStructures/FiberBundle/mitkFiberBundle.cpp IODataStructures/FiberBundle/mitkTrackvis.cpp IODataStructures/PlanarFigureComposite/mitkPlanarFigureComposite.cpp # Interactions Interactions/mitkFiberBundleInteractor.cpp # Tractography Algorithms/GibbsTracking/mitkParticleGrid.cpp Algorithms/GibbsTracking/mitkMetropolisHastingsSampler.cpp Algorithms/GibbsTracking/mitkEnergyComputer.cpp Algorithms/GibbsTracking/mitkGibbsEnergyComputer.cpp Algorithms/GibbsTracking/mitkFiberBuilder.cpp Algorithms/GibbsTracking/mitkSphereInterpolator.cpp ) set(H_FILES # DataStructures -> FiberBundle IODataStructures/FiberBundle/mitkFiberBundle.h IODataStructures/FiberBundle/mitkTrackvis.h IODataStructures/mitkFiberfoxParameters.h # Algorithms Algorithms/itkTractDensityImageFilter.h Algorithms/itkTractsToFiberEndingsImageFilter.h Algorithms/itkTractsToRgbaImageFilter.h # moved to DiffusionCore #Algorithms/itkElectrostaticRepulsionDiffusionGradientReductionFilter.h Algorithms/itkFibersFromPlanarFiguresFilter.h Algorithms/itkTractsToDWIImageFilter.h Algorithms/itkTractsToVectorImageFilter.h Algorithms/itkKspaceImageFilter.h Algorithms/itkDftImageFilter.h Algorithms/itkAddArtifactsToDwiImageFilter.h Algorithms/itkFieldmapGeneratorFilter.h Algorithms/itkEvaluateDirectionImagesFilter.h Algorithms/itkEvaluateTractogramDirectionsFilter.h Algorithms/itkFiberCurvatureFilter.h - # (old) Tractography + # Tractography Algorithms/itkGibbsTrackingFilter.h Algorithms/itkStochasticTractographyFilter.h Algorithms/itkStreamlineTrackingFilter.h Algorithms/GibbsTracking/mitkParticle.h Algorithms/GibbsTracking/mitkParticleGrid.h Algorithms/GibbsTracking/mitkMetropolisHastingsSampler.h Algorithms/GibbsTracking/mitkSimpSamp.h Algorithms/GibbsTracking/mitkEnergyComputer.h Algorithms/GibbsTracking/mitkGibbsEnergyComputer.h Algorithms/GibbsTracking/mitkSphereInterpolator.h Algorithms/GibbsTracking/mitkFiberBuilder.h + Algorithms/MLTracking/mitkTrackingForestHandler.h + Algorithms/MLTracking/itkMLBSTrackingFilter.h # Signal Models SignalModels/mitkDiffusionSignalModel.h SignalModels/mitkTensorModel.h SignalModels/mitkBallModel.h SignalModels/mitkDotModel.h SignalModels/mitkAstroStickModel.h SignalModels/mitkStickModel.h SignalModels/mitkRawShModel.h SignalModels/mitkDiffusionNoiseModel.h SignalModels/mitkRicianNoiseModel.h SignalModels/mitkChiSquareNoiseModel.h ) set(RESOURCE_FILES # Binary directory resources FiberTrackingLUTBaryCoords.bin FiberTrackingLUTIndices.bin # Shaders Shaders/mitkShaderFiberClipping.xml ) diff --git a/Modules/DiffusionImaging/MiniApps/CMakeLists.txt b/Modules/DiffusionImaging/MiniApps/CMakeLists.txt index 3325706db6..0c801dc612 100755 --- a/Modules/DiffusionImaging/MiniApps/CMakeLists.txt +++ b/Modules/DiffusionImaging/MiniApps/CMakeLists.txt @@ -1,115 +1,117 @@ option(BUILD_DiffusionMiniApps "Build commandline tools for diffusion" OFF) if(BUILD_DiffusionMiniApps OR MITK_BUILD_ALL_APPS) # needed include directories include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ) # list of diffusion miniapps # if an app requires additional dependencies # they are added after a "^^" and separated by "_" set( diffusionminiapps DwiDenoising^^ ImageResampler^^ NetworkCreation^^MitkFiberTracking_MitkConnectomics NetworkStatistics^^MitkConnectomics ExportShImage^^ Fiberfox^^MitkFiberTracking MultishellMethods^^MitkFiberTracking PeaksAngularError^^MitkFiberTracking PeakExtraction^^MitkFiberTracking FiberExtraction^^MitkFiberTracking FiberProcessing^^MitkFiberTracking FiberDirectionExtraction^^MitkFiberTracking LocalDirectionalFiberPlausibility^^MitkFiberTracking StreamlineTracking^^MitkFiberTracking GibbsTracking^^MitkFiberTracking CopyGeometry^^ DiffusionIndices^^ TractometerMetrics^^MitkFiberTracking QballReconstruction^^ Registration^^ FileFormatConverter^^MitkFiberTracking TensorReconstruction^^ TensorDerivedMapsExtraction^^ DICOMLoader^^ + DFTraining^^MitkFiberTracking + DFTracking^^MitkFiberTracking ) foreach(diffusionminiapp ${diffusionminiapps}) # extract mini app name and dependencies string(REPLACE "^^" "\\;" miniapp_info ${diffusionminiapp}) set(miniapp_info_list ${miniapp_info}) list(GET miniapp_info_list 0 appname) list(GET miniapp_info_list 1 raw_dependencies) string(REPLACE "_" "\\;" dependencies "${raw_dependencies}") set(dependencies_list ${dependencies}) mitk_create_executable(${appname} DEPENDS MitkCore MitkDiffusionCore ${dependencies_list} PACKAGE_DEPENDS ITK CPP_FILES ${appname}.cpp mitkCommandLineParser.cpp ) if(EXECUTABLE_IS_ENABLED) # On Linux, create a shell script to start a relocatable application if(UNIX AND NOT APPLE) install(PROGRAMS "${MITK_SOURCE_DIR}/CMake/RunInstalledApp.sh" DESTINATION "." RENAME ${EXECUTABLE_TARGET}.sh) endif() get_target_property(_is_bundle ${EXECUTABLE_TARGET} MACOSX_BUNDLE) if(APPLE) if(_is_bundle) set(_target_locations ${EXECUTABLE_TARGET}.app) set(${_target_locations}_qt_plugins_install_dir ${EXECUTABLE_TARGET}.app/Contents/MacOS) set(_bundle_dest_dir ${EXECUTABLE_TARGET}.app/Contents/MacOS) set(_qt_plugins_for_current_bundle ${EXECUTABLE_TARGET}.app/Contents/MacOS) set(_qt_conf_install_dirs ${EXECUTABLE_TARGET}.app/Contents/Resources) install(TARGETS ${EXECUTABLE_TARGET} BUNDLE DESTINATION . ) else() if(NOT MACOSX_BUNDLE_NAMES) set(_qt_conf_install_dirs bin) set(_target_locations bin/${EXECUTABLE_TARGET}) set(${_target_locations}_qt_plugins_install_dir bin) install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION bin) else() foreach(bundle_name ${MACOSX_BUNDLE_NAMES}) list(APPEND _qt_conf_install_dirs ${bundle_name}.app/Contents/Resources) set(_current_target_location ${bundle_name}.app/Contents/MacOS/${EXECUTABLE_TARGET}) list(APPEND _target_locations ${_current_target_location}) set(${_current_target_location}_qt_plugins_install_dir ${bundle_name}.app/Contents/MacOS) message( " set(${_current_target_location}_qt_plugins_install_dir ${bundle_name}.app/Contents/MacOS) ") install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION ${bundle_name}.app/Contents/MacOS/) endforeach() endif() endif() else() set(_target_locations bin/${EXECUTABLE_TARGET}${CMAKE_EXECUTABLE_SUFFIX}) set(${_target_locations}_qt_plugins_install_dir bin) set(_qt_conf_install_dirs bin) install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION bin) endif() endif() endforeach() # This mini app does not depend on mitkDiffusionImaging at all mitk_create_executable(Dicom2Nrrd DEPENDS MitkCore CPP_FILES Dicom2Nrrd.cpp mitkCommandLineParser.cpp ) # On Linux, create a shell script to start a relocatable application if(UNIX AND NOT APPLE) install(PROGRAMS "${MITK_SOURCE_DIR}/CMake/RunInstalledApp.sh" DESTINATION "." RENAME ${EXECUTABLE_TARGET}.sh) endif() if(EXECUTABLE_IS_ENABLED) MITK_INSTALL_TARGETS(EXECUTABLES ${EXECUTABLE_TARGET}) endif() endif() diff --git a/Modules/DiffusionImaging/MiniApps/DFTracking.cpp b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp new file mode 100755 index 0000000000..dc9db5d88e --- /dev/null +++ b/Modules/DiffusionImaging/MiniApps/DFTracking.cpp @@ -0,0 +1,196 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include +#include +#include +#include +#include "mitkCommandLineParser.h" +#include +#include +#include +//#include +#include +#include +#include +#include + +#include +#include +//#include +#include +#include +#include +#include + +#define _USE_MATH_DEFINES +#include + +const int numOdfSamples = 200; +typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; + +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + + parser.setTitle("Machine Learning Based Streamline Tractography"); + parser.setCategory("Fiber Tracking and Processing Methods"); + parser.setDescription(""); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--", "-"); + parser.addArgument("image", "i", mitkCommandLineParser::String, "DWIs:", "input diffusion-weighted image", us::Any(), false); + parser.addArgument("forest", "f", mitkCommandLineParser::String, "Forest:", "input forest", us::Any(), false); + parser.addArgument("out", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output fiberbundle", us::Any(), false); + + parser.addArgument("stop", "st", mitkCommandLineParser::String, "Stop image:", "stop image", us::Any()); + parser.addArgument("mask", "m", mitkCommandLineParser::String, "Mask image:", "mask image", us::Any()); + parser.addArgument("seed", "s", mitkCommandLineParser::String, "Seed image:", "seed image", us::Any()); + + parser.addArgument("athres", "a", mitkCommandLineParser::Float, "Angular threshold:", "angular threshold (in radians)", us::Any()); + parser.addArgument("stepsize", "se", mitkCommandLineParser::Float, "Stepsize:", "stepsize", us::Any()); + parser.addArgument("samples", "ns", mitkCommandLineParser::Int, "Samples:", "samples", us::Any()); + parser.addArgument("samplingdist", "sd", mitkCommandLineParser::Float, "Sampling distance:", "sampling distance (in voxels)", us::Any()); + parser.addArgument("seeds", "nse", mitkCommandLineParser::Int, "Seeds per voxel:", "seeds per voxel", us::Any()); + + parser.addArgument("usedirection", "ud", mitkCommandLineParser::Bool, "Use previous direction:", "use previous direction as feature", us::Any()); + parser.addArgument("verbose", "v", mitkCommandLineParser::Bool, "Verbose:", "output additional images", us::Any()); + + map parsedArgs = parser.parseArguments(argc, argv); + if (parsedArgs.size()==0) + return EXIT_FAILURE; + + string imageFile = us::any_cast(parsedArgs["image"]); + string forestFile = us::any_cast(parsedArgs["forest"]); + string outFile = us::any_cast(parsedArgs["out"]); + + string maskFile = ""; + if (parsedArgs.count("mask")) + maskFile = us::any_cast(parsedArgs["mask"]); + + string seedFile = ""; + if (parsedArgs.count("seed")) + seedFile = us::any_cast(parsedArgs["seed"]); + + string stopFile = ""; + if (parsedArgs.count("stop")) + stopFile = us::any_cast(parsedArgs["stop"]); + + float stepsize = -1; + if (parsedArgs.count("stepsize")) + stepsize = us::any_cast(parsedArgs["stepsize"]); + + float athres = 0.7; + if (parsedArgs.count("athres")) + athres = us::any_cast(parsedArgs["athres"]); + + float samplingdist = 0.25; + if (parsedArgs.count("samplingdist")) + samplingdist = us::any_cast(parsedArgs["samplingdist"]); + + bool useDirection = false; + if (parsedArgs.count("usedirection")) + useDirection = true; + + bool verbose = false; + if (parsedArgs.count("verbose")) + verbose = true; + + int samples = 10; + if (parsedArgs.count("samples")) + samples = us::any_cast(parsedArgs["samples"]); + + int seeds = 1; + if (parsedArgs.count("seeds")) + seeds = us::any_cast(parsedArgs["seeds"]); + + typedef itk::Image ItkUcharImgType; + + MITK_INFO << "loading diffusion-weighted image"; + mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::LoadImage(imageFile).GetPointer()); + + ItkUcharImgType::Pointer mask; + if (!maskFile.empty()) + { + MITK_INFO << "loading mask image"; + mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(maskFile).GetPointer()); + mask = ItkUcharImgType::New(); + mitk::CastToItkImage(img, mask); + } + + ItkUcharImgType::Pointer seed; + if (!seedFile.empty()) + { + MITK_INFO << "loading seed image"; + mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(seedFile).GetPointer()); + seed = ItkUcharImgType::New(); + mitk::CastToItkImage(img, seed); + } + + ItkUcharImgType::Pointer stop; + if (!stopFile.empty()) + { + MITK_INFO << "loading stop image"; + mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadImage(stopFile).GetPointer()); + stop = ItkUcharImgType::New(); + mitk::CastToItkImage(img, stop); + } + + MITK_INFO << "loading forest"; + vigra::RandomForest rf; + vigra::rf_import_HDF5(rf, forestFile); + + typedef itk::MLBSTrackingFilter<100> TrackerType; + TrackerType::Pointer tracker = TrackerType::New(); + tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi)); + tracker->SetGradientDirections( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi) ); + tracker->SetB_Value( mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi) ); + tracker->SetMaskImage(mask); + tracker->SetSeedImage(seed); + tracker->SetStoppingRegions(stop); + tracker->SetSeedsPerVoxel(seeds); + tracker->SetUseDirection(useDirection); + tracker->SetStepSize(stepsize); + tracker->SetAngularThreshold(athres); + tracker->SetDecisionForest(&rf); + tracker->SetSamplingDistance(samplingdist); + tracker->SetNumberOfSamples(samples); + tracker->Update(); + vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); + mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); + + mitk::IOUtil::SaveBaseData(outFib, outFile); + + if (verbose) + { + MITK_INFO << "Writing images..."; + string outName = itksys::SystemTools::GetFilenamePath(outFile)+"/"+itksys::SystemTools::GetFilenameWithoutLastExtension(outFile); + itk::ImageFileWriter< TrackerType::ItkDoubleImgType >::Pointer writer = itk::ImageFileWriter< TrackerType::ItkDoubleImgType >::New(); + writer->SetFileName(outName+"_WhiteMatter.nrrd"); + writer->SetInput(tracker->GetWmImage()); + writer->Update(); + + writer->SetFileName(outName+"_NotWhiteMatter.nrrd"); + writer->SetInput(tracker->GetNotWmImage()); + writer->Update(); + + writer->SetFileName(outName+"_AvoidStop.nrrd"); + writer->SetInput(tracker->GetAvoidStopImage()); + writer->Update(); + } + + return EXIT_SUCCESS; +} diff --git a/Modules/DiffusionImaging/MiniApps/DFTraining.cpp b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp new file mode 100755 index 0000000000..4c2eec7c8d --- /dev/null +++ b/Modules/DiffusionImaging/MiniApps/DFTraining.cpp @@ -0,0 +1,532 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include +#include +#include +#include +#include "mitkCommandLineParser.h" +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#define _USE_MATH_DEFINES +#include + +const int numOdfSamples = 200; // ODF is sampled in 200 directions but actuyll only 100 are used (symmetric) +typedef itk::Image< itk::Vector< float, numOdfSamples > , 3 > SampledShImageType; + +void TrainForest( vigra::RandomForest &rf, vigra::MultiArray<2, double> &labelData, vigra::MultiArray<2, double> &featureData, int numTrees, int max_tree_depth, double sample_fraction ) +{ + MITK_INFO << "Maximum tree depths: " << max_tree_depth; + MITK_INFO << "Sample fraction per tree: " << sample_fraction; + MITK_INFO << "Number of trees: " << numTrees; + vigra::rf::visitors::OOB_Error oob_v; + + rf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal + rf.set_options().sample_with_replacement(true); // if sampled with replacement or not + rf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree + rf.set_options().tree_count(1); // Number of trees that are calculated; + rf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node + // rf.set_options().features_per_node(10); + + rf.learn(featureData, labelData, vigra::rf::visitors::create_visitor(oob_v)); + + // Prepare parallel VariableImportance Calculation + int numMod = featureData.shape(1); + const int numClass = 2 + 2; + + float** varImp = new float*[numMod]; + + for(int i = 0; i < numMod; i++) + varImp[i] = new float[numClass]; + + for (int i = 0; i < numMod; ++i) + for (int j = 0; j < numClass; ++j) + varImp[i][j] = 0.0; + +#pragma omp parallel for + for (int i = 0; i < numTrees - 1; ++i) + { + vigra::RandomForest lrf; + vigra::rf::visitors::OOB_Error loob_v; + + lrf.set_options().use_stratification(vigra::RF_NONE); // How the data should be made equal + lrf.set_options().sample_with_replacement(true); // if sampled with replacement or not + lrf.set_options().samples_per_tree(sample_fraction); // Fraction of samples that are used to train a tree + lrf.set_options().tree_count(1); // Number of trees that are calculated; + lrf.set_options().min_split_node_size(5); // Minimum number of datapoints that must be in a node + // lrf.set_options().features_per_node(10); + + lrf.learn(featureData, labelData, vigra::rf::visitors::create_visitor(loob_v)); +#pragma omp critical + { + rf.trees_.push_back(lrf.trees_[0]); + } + } + + rf.options_.tree_count_ = numTrees; + MITK_INFO << "Training finsihed"; + MITK_INFO << "The out-of-bag error is: " << oob_v.oob_breiman << std::endl; +} + +SampledShImageType::PixelType GetImageValues(itk::Point itkP, SampledShImageType::Pointer image) +{ + itk::Index<3> idx; + itk::ContinuousIndex< double, 3> cIdx; + image->TransformPhysicalPointToIndex(itkP, idx); + image->TransformPhysicalPointToContinuousIndex(itkP, cIdx); + + SampledShImageType::PixelType pix; pix.Fill(0.0); + if ( image->GetLargestPossibleRegion().IsInside(idx) ) + pix = image->GetPixel(idx); + else + return pix; + + double frac_x = cIdx[0] - idx[0]; + double frac_y = cIdx[1] - idx[1]; + double frac_z = cIdx[2] - idx[2]; + if (frac_x<0) + { + idx[0] -= 1; + frac_x += 1; + } + if (frac_y<0) + { + idx[1] -= 1; + frac_y += 1; + } + if (frac_z<0) + { + idx[2] -= 1; + frac_z += 1; + } + frac_x = 1-frac_x; + frac_y = 1-frac_y; + frac_z = 1-frac_z; + + // int coordinates inside image? + if (idx[0] >= 0 && idx[0] < image->GetLargestPossibleRegion().GetSize(0)-1 && + idx[1] >= 0 && idx[1] < image->GetLargestPossibleRegion().GetSize(1)-1 && + idx[2] >= 0 && idx[2] < image->GetLargestPossibleRegion().GetSize(2)-1) + { + vnl_vector_fixed interpWeights; + interpWeights[0] = ( frac_x)*( frac_y)*( frac_z); + interpWeights[1] = (1-frac_x)*( frac_y)*( frac_z); + interpWeights[2] = ( frac_x)*(1-frac_y)*( frac_z); + interpWeights[3] = ( frac_x)*( frac_y)*(1-frac_z); + interpWeights[4] = (1-frac_x)*(1-frac_y)*( frac_z); + interpWeights[5] = ( frac_x)*(1-frac_y)*(1-frac_z); + interpWeights[6] = (1-frac_x)*( frac_y)*(1-frac_z); + interpWeights[7] = (1-frac_x)*(1-frac_y)*(1-frac_z); + + + pix = image->GetPixel(idx) * interpWeights[0]; + SampledShImageType::IndexType tmpIdx = idx; tmpIdx[0]++; + pix += image->GetPixel(tmpIdx) * interpWeights[1]; + tmpIdx = idx; tmpIdx[1]++; + pix += image->GetPixel(tmpIdx) * interpWeights[2]; + tmpIdx = idx; tmpIdx[2]++; + pix += image->GetPixel(tmpIdx) * interpWeights[3]; + tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; + pix += image->GetPixel(tmpIdx) * interpWeights[4]; + tmpIdx = idx; tmpIdx[1]++; tmpIdx[2]++; + pix += image->GetPixel(tmpIdx) * interpWeights[5]; + tmpIdx = idx; tmpIdx[2]++; tmpIdx[0]++; + pix += image->GetPixel(tmpIdx) * interpWeights[6]; + tmpIdx = idx; tmpIdx[0]++; tmpIdx[1]++; tmpIdx[2]++; + pix += image->GetPixel(tmpIdx) * interpWeights[7]; + } + + return pix; +} + +int main(int argc, char* argv[]) +{ + MITK_INFO << "DFTraining"; + mitkCommandLineParser parser; + + parser.setTitle("Machine Learning Based Streamline Tractography"); + parser.setCategory("Fiber Tracking and Processing Methods"); + parser.setDescription(""); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--", "-"); + parser.addArgument("images", "i", mitkCommandLineParser::StringList, "DWIs:", "input diffusion-weighted images", us::Any(), false); + parser.addArgument("wmmasks", "w", mitkCommandLineParser::StringList, "WM-Masks:", "white matter mask images", us::Any(), false); + parser.addArgument("tractograms", "t", mitkCommandLineParser::StringList, "Tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); + parser.addArgument("masks", "m", mitkCommandLineParser::StringList, "Masks:", "mask images", us::Any()); + parser.addArgument("forest", "f", mitkCommandLineParser::OutputFile, "Forest:", "output forest", us::Any(), false); + + parser.addArgument("stepsize", "s", mitkCommandLineParser::Float, "Stepsize:", "stepsize", us::Any()); + parser.addArgument("gmsamples", "g", mitkCommandLineParser::Int, "Number of gray matter samples per voxel:", "Number of gray matter samples per voxel", us::Any()); + parser.addArgument("numtrees", "n", mitkCommandLineParser::Int, "Number of trees:", "number of trees", us::Any()); + parser.addArgument("max_tree_depth", "d", mitkCommandLineParser::Int, "Max. tree depth:", "maximum tree depth", us::Any()); + parser.addArgument("sample_fraction", "sf", mitkCommandLineParser::Float, "Sample fraction:", "fraction of samples used per tree", us::Any()); + parser.addArgument("usedirection", "ud", mitkCommandLineParser::Bool, "bla:", "bla", us::Any()); + + map parsedArgs = parser.parseArguments(argc, argv); + if (parsedArgs.size()==0) + return EXIT_FAILURE; + + mitkCommandLineParser::StringContainerType imageFiles = us::any_cast(parsedArgs["images"]); + mitkCommandLineParser::StringContainerType wmMaskFiles = us::any_cast(parsedArgs["wmmasks"]); + + mitkCommandLineParser::StringContainerType maskFiles; + if (parsedArgs.count("masks")) + maskFiles = us::any_cast(parsedArgs["masks"]); + + string forestFile = us::any_cast(parsedArgs["forest"]); + + mitkCommandLineParser::StringContainerType tractogramFiles; + if (parsedArgs.count("tractograms")) + tractogramFiles = us::any_cast(parsedArgs["tractograms"]); + + int numTrees = 30; + if (parsedArgs.count("numtrees")) + numTrees = us::any_cast(parsedArgs["numtrees"]); + + int gmsamples = 50; + if (parsedArgs.count("gmsamples")) + gmsamples = us::any_cast(parsedArgs["gmsamples"]); + + float stepsize = -1; + if (parsedArgs.count("stepsize")) + stepsize = us::any_cast(parsedArgs["stepsize"]); + + int max_tree_depth = 50; + if (parsedArgs.count("max_tree_depth")) + max_tree_depth = us::any_cast(parsedArgs["max_tree_depth"]); + + double sample_fraction = 1.0; + if (parsedArgs.count("sample_fraction")) + sample_fraction = us::any_cast(parsedArgs["sample_fraction"]); + + // load DWI images + if (imageFiles.size() QballFilterType; + + MITK_INFO << "loading diffusion-weighted images and reconstructing feature images"; + std::vector< SampledShImageType::Pointer > sampledShImages; + for (unsigned int i=0; i(mitk::IOUtil::LoadImage(imageFiles.at(i)).GetPointer()); + + QballFilterType::Pointer qballfilter = QballFilterType::New(); + qballfilter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi) ); + qballfilter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + qballfilter->SetLambda(0.006); + qballfilter->SetNormalizationMethod(QballFilterType::QBAR_RAW_SIGNAL); + qballfilter->Update(); + // FeatureImageType::Pointer itkFeatureImage = qballfilter->GetCoefficientImage(); + // featureImageVector.push_back(itkFeatureImage); + sampledShImages.push_back(qballfilter->GetOutput()); + } + + + typedef itk::Image ItkUcharImgType; + std::vector< ItkUcharImgType::Pointer > maskImageVector; + std::vector< ItkUcharImgType::Pointer > wmMaskImageVector; + + MITK_INFO << "loading tractograms"; + int numSamples = 0; + std::vector< mitk::FiberBundle::Pointer > tractograms; + for (unsigned int t=0; t(mitk::IOUtil::LoadImage(maskFiles.at(t)).GetPointer()); + mask = ItkUcharImgType::New(); + mitk::CastToItkImage(img, mask); + maskImageVector.push_back(mask); + } + mitk::Image::Pointer img2 = dynamic_cast(mitk::IOUtil::LoadImage(wmMaskFiles.at(t)).GetPointer()); + ItkUcharImgType::Pointer wmmask = ItkUcharImgType::New(); + mitk::CastToItkImage(img2, wmmask); + wmMaskImageVector.push_back(wmmask); + + itk::ImageRegionConstIterator it(wmmask, wmmask->GetLargestPossibleRegion()); + int OUTOFWM = 0; // count voxels outside of the white matter mask + while(!it.IsAtEnd()) + { + if (it.Get()==0) + if (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0)) + OUTOFWM++; + ++it; + } + numSamples += gmsamples*OUTOFWM; // for each of the non-white matter voxels we add a certain number of sampling points. these sampling points are used to tell the classifier where to recognize non-WM tissue + + MITK_INFO << "Samples outside of WM: " << numSamples << " (" << gmsamples << " per non-WM voxel)"; + + // load and resample training tractograms + mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::LoadDataNode(tractogramFiles.at(t))->GetData()); + if (stepsize<0) + { + SampledShImageType::Pointer image = sampledShImages.at(t); + float minSpacing = 1; + if(image->GetSpacing()[0]GetSpacing()[1] && image->GetSpacing()[0]GetSpacing()[2]) + minSpacing = image->GetSpacing()[0]; + else if (image->GetSpacing()[1] < image->GetSpacing()[2]) + minSpacing = image->GetSpacing()[1]; + else + minSpacing = image->GetSpacing()[2]; + stepsize = minSpacing*0.5; + } + fib->ResampleSpline(stepsize); + tractograms.push_back(fib); + numSamples += fib->GetNumberOfPoints(); // each point of the fiber gives us a training direction + numSamples -= 2*fib->GetNumFibers(); // we don't use the first and last point because there we do not have a previous direction, which is needed as feature + } + MITK_INFO << "Number of samples: " << numSamples; + + // get ODF directions and number of features + vnl_vector_fixed ref; ref.fill(0); ref[0]=1; + itk::OrientationDistributionFunction< double, numOdfSamples > odf; + std::vector< int > directionIndices; + for (unsigned int f=0; f0) // we only use directions on one hemisphere (symmetric) + directionIndices.push_back(f); // remember indices that are on the desired hemisphere + } + const int numSignalFeatures = numOdfSamples/2; + int numDirectionFeatures = 0; + if (useDirection) + numDirectionFeatures = 3; + + vigra::MultiArray<2, double> featureData( vigra::Shape2(numSamples,numSignalFeatures+numDirectionFeatures) ); + MITK_INFO << "Number of features: " << featureData.shape(1); + vigra::MultiArray<2, double> labelData( vigra::Shape2(numSamples,1) ); + + itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer m_RandGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); + m_RandGen->SetSeed(); + MITK_INFO << "Creating training data from tractograms and feature images"; + int sampleCounter = 0; + for (unsigned int t=0; t it(wmMask, wmMask->GetLargestPossibleRegion()); + while(!it.IsAtEnd()) + { + if (it.Get()==0 && (mask.IsNull() || (mask.IsNotNull() && mask->GetPixel(it.GetIndex())>0))) + { + SampledShImageType::PixelType pix = image->GetPixel(it.GetIndex()); + if (useDirection) + { + // null direction + for (unsigned int f=0; f probe; + probe[0] = m_RandGen->GetVariate()*2-1; + probe[1] = m_RandGen->GetVariate()*2-1; + probe[2] = m_RandGen->GetVariate()*2-1; + probe.normalize(); + if (dot_product(ref, probe)<0) + probe *= -1; + for (unsigned int f=numSignalFeatures; f idx; + idx[0] = it.GetIndex()[0]; + idx[1] = it.GetIndex()[1]; + idx[2] = it.GetIndex()[2]; + itk::Point itkP1; + image->TransformContinuousIndexToPhysicalPoint(idx, itkP1); + SampledShImageType::PixelType pix = GetImageValues(itkP1, image);; + for (unsigned int f=0; f idx; + idx[0] = it.GetIndex()[0] + m_RandGen->GetVariate()-0.5; + idx[1] = it.GetIndex()[1] + m_RandGen->GetVariate()-0.5; + idx[2] = it.GetIndex()[2] + m_RandGen->GetVariate()-0.5; + itk::Point itkP1; + image->TransformContinuousIndexToPhysicalPoint(idx, itkP1); + SampledShImageType::PixelType pix = GetImageValues(itkP1, image);; + for (unsigned int f=0; f polyData = fib->GetFiberPolyData(); + for (int i=0; iGetNumFibers(); i++) + { + vtkCell* cell = polyData->GetCell(i); + int numPoints = cell->GetNumberOfPoints(); + vtkPoints* points = cell->GetPoints(); + + vnl_vector_fixed dirOld; dirOld.fill(0.0); + + for (int j=0; jGetPoint(j); + itk::Point itkP1; + itkP1[0] = p1[0]; itkP1[1] = p1[1]; itkP1[2] = p1[2]; + + vnl_vector_fixed dir; dir.fill(0.0); + + itk::Point itkP2; + double* p2 = points->GetPoint(j+1); + itkP2[0] = p2[0]; itkP2[1] = p2[1]; itkP2[2] = p2[2]; + dir[0]=itkP2[0]-itkP1[0]; + dir[1]=itkP2[1]-itkP1[1]; + dir[2]=itkP2[2]-itkP1[2]; + + if (dir.magnitude()<0.0001) + { + MITK_INFO << "streamline error!"; + continue; + } + dir.normalize(); + if (dir[0]!=dir[0] || dir[1]!=dir[1] || dir[2]!=dir[2]) + { + MITK_INFO << "ERROR: NaN direction!"; + continue; + } + + if (j==0) + { + dirOld = dir; + continue; + } + + // get voxel values + SampledShImageType::PixelType pix = GetImageValues(itkP1, image); + for (unsigned int f=0; f0.0001) + { + int label = 0; + for (unsigned int f=0; fangle) + { + labelData(sampleCounter,0) = f; + angle = a; + label = f; + } + } + } + + dirOld = dir; + sampleCounter++; + } + } + } + + MITK_INFO << "Training forest"; + vigra::RandomForest rf; + TrainForest( rf, labelData, featureData, numTrees, max_tree_depth, sample_fraction ); + MITK_INFO << "Writing forest"; + vigra::rf_export_HDF5( rf, forestFile, "" ); + MITK_INFO << "Finished training"; + + return EXIT_SUCCESS; +} diff --git a/Modules/DiffusionImaging/MiniApps/TractometerMetrics.cpp b/Modules/DiffusionImaging/MiniApps/TractometerMetrics.cpp index d37e7c3980..38795b6bf5 100755 --- a/Modules/DiffusionImaging/MiniApps/TractometerMetrics.cpp +++ b/Modules/DiffusionImaging/MiniApps/TractometerMetrics.cpp @@ -1,414 +1,413 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include #include #include #include #include #include #define _USE_MATH_DEFINES #include int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Tractometer Metrics"); parser.setCategory("Fiber Tracking and Processing Methods"); parser.setDescription(""); parser.setContributor("MBI"); parser.setArgumentPrefix("--", "-"); parser.addArgument("input", "i", mitkCommandLineParser::InputFile, "Input:", "input tractogram (.fib, vtk ascii file format)", us::Any(), false); parser.addArgument("out", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); parser.addArgument("labels", "l", mitkCommandLineParser::StringList, "Label pairs:", "label pairs", false); parser.addArgument("labelimage", "li", mitkCommandLineParser::String, "Label image:", "label image", false); parser.addArgument("verbose", "v", mitkCommandLineParser::Bool, "Verbose:", "output valid, invalid and no connections as fiber bundles"); - parser.addArgument("fileID", "id", mitkCommandLineParser::String, "ID:", "optional ID field"); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; mitkCommandLineParser::StringContainerType labelpairs = us::any_cast(parsedArgs["labels"]); string fibFile = us::any_cast(parsedArgs["input"]); string labelImageFile = us::any_cast(parsedArgs["labelimage"]); string outRoot = us::any_cast(parsedArgs["out"]); string fileID = ""; if (parsedArgs.count("fileID")) fileID = us::any_cast(parsedArgs["fileID"]); bool verbose = false; if (parsedArgs.count("verbose")) verbose = us::any_cast(parsedArgs["verbose"]); try { typedef itk::Image ItkShortImgType; typedef itk::Image ItkUcharImgType; // load fiber bundle mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::LoadDataNode(fibFile)->GetData()); mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadDataNode(labelImageFile)->GetData()); typedef mitk::ImageToItk< ItkShortImgType > CasterType; CasterType::Pointer caster = CasterType::New(); caster->SetInput(img); caster->Update(); ItkShortImgType::Pointer labelImage = caster->GetOutput(); string path = itksys::SystemTools::GetFilenamePath(labelImageFile); std::vector< bool > detected; std::vector< std::pair< int, int > > labelsvector; std::vector< ItkUcharImgType::Pointer > bundleMasks; std::vector< ItkUcharImgType::Pointer > bundleMasksCoverage; short max = 0; for (unsigned int i=0; i l; l.first = boost::lexical_cast(labelpairs.at(i)); l.second = boost::lexical_cast(labelpairs.at(i+1)); std::cout << labelpairs.at(i); std::cout << labelpairs.at(i+1); if (l.first>max) max=l.first; if (l.second>max) max=l.second; labelsvector.push_back(l); detected.push_back(false); { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadDataNode(path+"/Bundle"+boost::lexical_cast(labelsvector.size())+"_MASK.nrrd")->GetData()); typedef mitk::ImageToItk< ItkUcharImgType > CasterType; CasterType::Pointer caster = CasterType::New(); caster->SetInput(img); caster->Update(); ItkUcharImgType::Pointer bundle = caster->GetOutput(); bundleMasks.push_back(bundle); } { mitk::Image::Pointer img = dynamic_cast(mitk::IOUtil::LoadDataNode(path+"/Bundle"+boost::lexical_cast(labelsvector.size())+"_MASK_COVERAGE.nrrd")->GetData()); typedef mitk::ImageToItk< ItkUcharImgType > CasterType; CasterType::Pointer caster = CasterType::New(); caster->SetInput(img); caster->Update(); ItkUcharImgType::Pointer bundle = caster->GetOutput(); bundleMasksCoverage.push_back(bundle); } } vnl_matrix< unsigned char > matrix; matrix.set_size(max, max); matrix.fill(0); vtkSmartPointer polyData = inputTractogram->GetFiberPolyData(); int validConnections = 0; int noConnection = 0; int validBundles = 0; int invalidBundles = 0; int invalidConnections = 0; ItkUcharImgType::Pointer coverage = ItkUcharImgType::New(); coverage->SetSpacing(labelImage->GetSpacing()); coverage->SetOrigin(labelImage->GetOrigin()); coverage->SetDirection(labelImage->GetDirection()); coverage->SetLargestPossibleRegion(labelImage->GetLargestPossibleRegion()); coverage->SetBufferedRegion( labelImage->GetLargestPossibleRegion() ); coverage->SetRequestedRegion( labelImage->GetLargestPossibleRegion() ); coverage->Allocate(); coverage->FillBuffer(0); vtkSmartPointer noConnPoints = vtkSmartPointer::New(); vtkSmartPointer noConnCells = vtkSmartPointer::New(); vtkSmartPointer invalidPoints = vtkSmartPointer::New(); vtkSmartPointer invalidCells = vtkSmartPointer::New(); vtkSmartPointer validPoints = vtkSmartPointer::New(); vtkSmartPointer validCells = vtkSmartPointer::New(); boost::progress_display disp(inputTractogram->GetNumFibers()); for (int i=0; iGetNumFibers(); i++) { ++disp; vtkCell* cell = polyData->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); if (numPoints>1) { double* start = points->GetPoint(0); itk::Point itkStart; itkStart[0] = start[0]; itkStart[1] = start[1]; itkStart[2] = start[2]; itk::Index<3> idxStart; labelImage->TransformPhysicalPointToIndex(itkStart, idxStart); double* end = points->GetPoint(numPoints-1); itk::Point itkEnd; itkEnd[0] = end[0]; itkEnd[1] = end[1]; itkEnd[2] = end[2]; itk::Index<3> idxEnd; labelImage->TransformPhysicalPointToIndex(itkEnd, idxEnd); if ( labelImage->GetPixel(idxStart)==0 || labelImage->GetPixel(idxEnd)==0 ) { noConnection++; if (verbose) { vtkSmartPointer container = vtkSmartPointer::New(); for (int j=0; jGetPoint(j); vtkIdType id = noConnPoints->InsertNextPoint(p); container->GetPointIds()->InsertNextId(id); } noConnCells->InsertNextCell(container); } } else { bool invalid = true; for (unsigned int i=0; i l = labelsvector.at(i); if ( (labelImage->GetPixel(idxStart)==l.first && labelImage->GetPixel(idxEnd)==l.second) || (labelImage->GetPixel(idxStart)==l.second && labelImage->GetPixel(idxEnd)==l.first) ) { for (int j=0; jGetPoint(j); itk::Point itkP; itkP[0] = p[0]; itkP[1] = p[1]; itkP[2] = p[2]; itk::Index<3> idx; bundle->TransformPhysicalPointToIndex(itkP, idx); if ( !bundle->GetPixel(idx)>0 && bundle->GetLargestPossibleRegion().IsInside(idx) ) { outside=true; } } if (!outside) { validConnections++; if (detected.at(i)==false) validBundles++; detected.at(i) = true; invalid = false; vtkSmartPointer container = vtkSmartPointer::New(); for (int j=0; jGetPoint(j); vtkIdType id = validPoints->InsertNextPoint(p); container->GetPointIds()->InsertNextId(id); itk::Point itkP; itkP[0] = p[0]; itkP[1] = p[1]; itkP[2] = p[2]; itk::Index<3> idx; coverage->TransformPhysicalPointToIndex(itkP, idx); if ( coverage->GetLargestPossibleRegion().IsInside(idx) ) coverage->SetPixel(idx, 1); } validCells->InsertNextCell(container); } break; } } if (invalid==true) { invalidConnections++; int x = labelImage->GetPixel(idxStart)-1; int y = labelImage->GetPixel(idxEnd)-1; if (x>=0 && y>0 && x container = vtkSmartPointer::New(); for (int j=0; jGetPoint(j); vtkIdType id = invalidPoints->InsertNextPoint(p); container->GetPointIds()->InsertNextId(id); } invalidCells->InsertNextCell(container); } } } } } if (verbose) { mitk::CoreObjectFactory::FileWriterList fileWriters = mitk::CoreObjectFactory::GetInstance()->GetFileWriters(); vtkSmartPointer noConnPolyData = vtkSmartPointer::New(); noConnPolyData->SetPoints(noConnPoints); noConnPolyData->SetLines(noConnCells); mitk::FiberBundle::Pointer noConnFib = mitk::FiberBundle::New(noConnPolyData); string ncfilename = outRoot; ncfilename.append("_NC.fib"); mitk::IOUtil::SaveBaseData(noConnFib.GetPointer(), ncfilename ); vtkSmartPointer invalidPolyData = vtkSmartPointer::New(); invalidPolyData->SetPoints(invalidPoints); invalidPolyData->SetLines(invalidCells); mitk::FiberBundle::Pointer invalidFib = mitk::FiberBundle::New(invalidPolyData); string icfilename = outRoot; icfilename.append("_IC.fib"); mitk::IOUtil::SaveBaseData(invalidFib.GetPointer(), icfilename ); vtkSmartPointer validPolyData = vtkSmartPointer::New(); validPolyData->SetPoints(validPoints); validPolyData->SetLines(validCells); mitk::FiberBundle::Pointer validFib = mitk::FiberBundle::New(validPolyData); string vcfilename = outRoot; vcfilename.append("_VC.fib"); mitk::IOUtil::SaveBaseData(validFib.GetPointer(), vcfilename ); { typedef itk::ImageFileWriter< ItkUcharImgType > WriterType; WriterType::Pointer writer = WriterType::New(); writer->SetFileName(outRoot+"_ABC.nrrd"); writer->SetInput(coverage); writer->Update(); } } // calculate coverage int wmVoxels = 0; int coveredVoxels = 0; itk::ImageRegionIterator it (coverage, coverage->GetLargestPossibleRegion()); while(!it.IsAtEnd()) { bool wm = false; for (unsigned int i=0; iGetPixel(it.GetIndex())>0) { wm = true; wmVoxels++; break; } } if (wm && it.Get()>0) coveredVoxels++; ++it; } int numFibers = inputTractogram->GetNumFibers(); double nc = (double)noConnection/numFibers; double vc = (double)validConnections/numFibers; double ic = (double)invalidConnections/numFibers; if (numFibers==0) { nc = 0.0; vc = 0.0; ic = 0.0; } int vb = validBundles; int ib = invalidBundles; double abc = (double)coveredVoxels/wmVoxels; std::cout << "NC: " << nc; std::cout << "VC: " << vc; std::cout << "IC: " << ic; std::cout << "VB: " << vb; std::cout << "IB: " << ib; std::cout << "ABC: " << abc; string logFile = outRoot; logFile.append("_TRACTOMETER.csv"); ofstream file; file.open (logFile.c_str()); { string sens = itksys::SystemTools::GetFilenameWithoutLastExtension(fibFile); if (!fileID.empty()) sens = fileID; sens.append(","); sens.append(boost::lexical_cast(nc)); sens.append(","); sens.append(boost::lexical_cast(vc)); sens.append(","); sens.append(boost::lexical_cast(ic)); sens.append(","); sens.append(boost::lexical_cast(validBundles)); sens.append(","); sens.append(boost::lexical_cast(invalidBundles)); sens.append(","); sens.append(boost::lexical_cast(abc)); sens.append(";\n"); file << sens; } file.close(); } catch (itk::ExceptionObject e) { std::cout << e; return EXIT_FAILURE; } catch (std::exception e) { std::cout << e.what(); return EXIT_FAILURE; } catch (...) { std::cout << "ERROR!?!"; return EXIT_FAILURE; } return EXIT_SUCCESS; } diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/files.cmake b/Plugins/org.mitk.gui.qt.diffusionimaging/files.cmake index a6f512c7f6..fa0b37ba15 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/files.cmake +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/files.cmake @@ -1,208 +1,211 @@ set(SRC_CPP_FILES QmitkODFDetailsWidget.cpp QmitkODFRenderWidget.cpp QmitkPartialVolumeAnalysisWidget.cpp QmitkIVIMWidget.cpp QmitkTbssRoiAnalysisWidget.cpp QmitkResidualAnalysisWidget.cpp QmitkResidualViewWidget.cpp QmitkTensorModelParametersWidget.cpp QmitkZeppelinModelParametersWidget.cpp QmitkStickModelParametersWidget.cpp QmitkDotModelParametersWidget.cpp QmitkBallModelParametersWidget.cpp QmitkAstrosticksModelParametersWidget.cpp QmitkPrototypeSignalParametersWidget.cpp ) set(INTERNAL_CPP_FILES mitkPluginActivator.cpp QmitkQBallReconstructionView.cpp QmitkPreprocessingView.cpp QmitkDiffusionDicomImportView.cpp QmitkDiffusionQuantificationView.cpp QmitkTensorReconstructionView.cpp QmitkControlVisualizationPropertiesView.cpp QmitkODFDetailsView.cpp QmitkGibbsTrackingView.cpp QmitkStochasticFiberTrackingView.cpp QmitkStreamlineTrackingView.cpp QmitkFiberQuantificationView.cpp -# QmitkFiberBundleDeveloperView.cpp QmitkPartialVolumeAnalysisView.cpp QmitkIVIMView.cpp QmitkTractbasedSpatialStatisticsView.cpp QmitkTbssTableModel.cpp QmitkTbssMetaTableModel.cpp QmitkTbssSkeletonizationView.cpp Connectomics/QmitkConnectomicsDataView.cpp Connectomics/QmitkConnectomicsNetworkOperationsView.cpp Connectomics/QmitkConnectomicsStatisticsView.cpp Connectomics/QmitkNetworkHistogramCanvas.cpp Connectomics/QmitkRandomParcellationView.cpp QmitkDwiSoftwarePhantomView.cpp QmitkOdfMaximaExtractionView.cpp QmitkFiberfoxView.cpp QmitkFiberProcessingView.cpp QmitkFieldmapGeneratorView.cpp QmitkDiffusionRegistrationView.cpp QmitkDenoisingView.cpp + QmitkMLBTView.cpp Perspectives/QmitkFiberProcessingPerspective.cpp Perspectives/QmitkDiffusionImagingAppPerspective.cpp Perspectives/QmitkGibbsTractographyPerspective.cpp Perspectives/QmitkStreamlineTractographyPerspective.cpp Perspectives/QmitkProbabilisticTractographyPerspective.cpp Perspectives/QmitkDIAppSyntheticDataGenerationPerspective.cpp Perspectives/QmitkDIAppIVIMPerspective.cpp Perspectives/QmitkDiffusionDefaultPerspective.cpp ) set(UI_FILES src/internal/QmitkQBallReconstructionViewControls.ui src/internal/QmitkPreprocessingViewControls.ui src/internal/QmitkDiffusionDicomImportViewControls.ui src/internal/QmitkDiffusionQuantificationViewControls.ui src/internal/QmitkTensorReconstructionViewControls.ui src/internal/QmitkControlVisualizationPropertiesViewControls.ui src/internal/QmitkODFDetailsViewControls.ui src/internal/QmitkGibbsTrackingViewControls.ui src/internal/QmitkStochasticFiberTrackingViewControls.ui src/internal/QmitkStreamlineTrackingViewControls.ui src/internal/QmitkFiberQuantificationViewControls.ui # src/internal/QmitkFiberBundleDeveloperViewControls.ui src/internal/QmitkPartialVolumeAnalysisViewControls.ui src/internal/QmitkIVIMViewControls.ui src/internal/QmitkTractbasedSpatialStatisticsViewControls.ui src/internal/QmitkTbssSkeletonizationViewControls.ui src/internal/Connectomics/QmitkConnectomicsDataViewControls.ui src/internal/Connectomics/QmitkConnectomicsNetworkOperationsViewControls.ui src/internal/Connectomics/QmitkConnectomicsStatisticsViewControls.ui src/internal/Connectomics/QmitkRandomParcellationViewControls.ui src/internal/QmitkDwiSoftwarePhantomViewControls.ui src/internal/QmitkOdfMaximaExtractionViewControls.ui src/internal/QmitkFiberfoxViewControls.ui src/internal/QmitkFiberProcessingViewControls.ui src/QmitkTensorModelParametersWidgetControls.ui src/QmitkZeppelinModelParametersWidgetControls.ui src/QmitkStickModelParametersWidgetControls.ui src/QmitkDotModelParametersWidgetControls.ui src/QmitkBallModelParametersWidgetControls.ui src/QmitkAstrosticksModelParametersWidgetControls.ui src/QmitkPrototypeSignalParametersWidgetControls.ui src/internal/QmitkFieldmapGeneratorViewControls.ui src/internal/QmitkDiffusionRegistrationViewControls.ui src/internal/QmitkDenoisingViewControls.ui + src/internal/QmitkMLBTViewControls.ui ) set(MOC_H_FILES src/internal/mitkPluginActivator.h src/internal/QmitkQBallReconstructionView.h src/internal/QmitkPreprocessingView.h src/internal/QmitkDiffusionDicomImportView.h src/internal/QmitkDiffusionQuantificationView.h src/internal/QmitkTensorReconstructionView.h src/internal/QmitkControlVisualizationPropertiesView.h src/internal/QmitkODFDetailsView.h src/QmitkODFRenderWidget.h src/QmitkODFDetailsWidget.h src/internal/QmitkGibbsTrackingView.h src/internal/QmitkStochasticFiberTrackingView.h src/internal/QmitkStreamlineTrackingView.h src/internal/QmitkFiberQuantificationView.h # src/internal/QmitkFiberBundleDeveloperView.h src/internal/QmitkPartialVolumeAnalysisView.h src/QmitkPartialVolumeAnalysisWidget.h src/internal/QmitkIVIMView.h src/internal/QmitkTractbasedSpatialStatisticsView.h src/internal/QmitkTbssSkeletonizationView.h src/QmitkTbssRoiAnalysisWidget.h src/QmitkResidualAnalysisWidget.h src/QmitkResidualViewWidget.h src/internal/Connectomics/QmitkConnectomicsDataView.h src/internal/Connectomics/QmitkConnectomicsNetworkOperationsView.h src/internal/Connectomics/QmitkConnectomicsStatisticsView.h src/internal/Connectomics/QmitkNetworkHistogramCanvas.h src/internal/Connectomics/QmitkRandomParcellationView.h src/internal/QmitkDwiSoftwarePhantomView.h src/internal/QmitkOdfMaximaExtractionView.h src/internal/QmitkFiberfoxView.h src/internal/QmitkFiberProcessingView.h src/QmitkTensorModelParametersWidget.h src/QmitkZeppelinModelParametersWidget.h src/QmitkStickModelParametersWidget.h src/QmitkDotModelParametersWidget.h src/QmitkBallModelParametersWidget.h src/QmitkAstrosticksModelParametersWidget.h src/QmitkPrototypeSignalParametersWidget.h src/internal/QmitkFieldmapGeneratorView.h src/internal/QmitkDiffusionRegistrationView.h src/internal/QmitkDenoisingView.h + src/internal/QmitkMLBTView.h src/internal/Perspectives/QmitkFiberProcessingPerspective.h src/internal/Perspectives/QmitkDiffusionImagingAppPerspective.h src/internal/Perspectives/QmitkGibbsTractographyPerspective.h src/internal/Perspectives/QmitkStreamlineTractographyPerspective.h src/internal/Perspectives/QmitkProbabilisticTractographyPerspective.h src/internal/Perspectives/QmitkDIAppSyntheticDataGenerationPerspective.h src/internal/Perspectives/QmitkDIAppIVIMPerspective.h src/internal/Perspectives/QmitkDiffusionDefaultPerspective.h ) set(CACHED_RESOURCE_FILES # list of resource files which can be used by the plug-in # system without loading the plug-ins shared library, # for example the icon used in the menu and tabs for the # plug-in views in the workbench plugin.xml resources/preprocessing.png resources/dwiimport.png resources/quantification.png resources/reconodf.png resources/recontensor.png resources/vizControls.png resources/OdfDetails.png resources/GibbsTracking.png resources/FiberBundleOperations.png resources/PartialVolumeAnalysis_24.png resources/IVIM_48.png resources/stochFB.png resources/tbss.png resources/connectomics/QmitkConnectomicsDataViewIcon_48.png resources/connectomics/QmitkConnectomicsNetworkOperationsViewIcon_48.png resources/connectomics/QmitkConnectomicsStatisticsViewIcon_48.png resources/connectomics/QmitkRandomParcellationIcon.png resources/arrow.png resources/qball_peaks.png resources/phantom.png resources/tensor.png resources/qball.png resources/StreamlineTracking.png resources/dwi2.png resources/dwi.png resources/odf.png resources/refresh.xpm resources/diffusionregistration.png resources/denoisingicon.png resources/syntheticdata.png resources/ivim.png resources/tractography.png + resources/fiberTracking1.png ) set(QRC_FILES # uncomment the following line if you want to use Qt resources resources/QmitkDiffusionImaging.qrc #resources/QmitkTractbasedSpatialStatisticsView.qrc ) set(CPP_FILES ) foreach(file ${SRC_CPP_FILES}) set(CPP_FILES ${CPP_FILES} src/${file}) endforeach(file ${SRC_CPP_FILES}) foreach(file ${INTERNAL_CPP_FILES}) set(CPP_FILES ${CPP_FILES} src/internal/${file}) endforeach(file ${INTERNAL_CPP_FILES}) diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/plugin.xml b/Plugins/org.mitk.gui.qt.diffusionimaging/plugin.xml index 99206fa33f..162bc32673 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/plugin.xml +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/plugin.xml @@ -1,428 +1,436 @@ Q-Ball reconstruction view Diffusion DICOM data import Calculation of tensor and Q-ball derived measures. Diffusion weighted MRI data simulation tool. + + + + This perspective contains all views necessary to post process fibers. diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/resources/fiberTracking1.png b/Plugins/org.mitk.gui.qt.diffusionimaging/resources/fiberTracking1.png new file mode 100644 index 0000000000..be82a2eeb2 Binary files /dev/null and b/Plugins/org.mitk.gui.qt.diffusionimaging/resources/fiberTracking1.png differ diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp new file mode 100644 index 0000000000..992e0c577e --- /dev/null +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.cpp @@ -0,0 +1,329 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + + +// Blueberry +#include +#include + +// Qmitk +#include "QmitkMLBTView.h" +#include "QmitkStdMultiWidget.h" + +// Qt +#include +#include +#include + +#include "mitkNodePredicateDataType.h" +#include +#include +#include +#include +#include + +#define _USE_MATH_DEFINES +#include + +const std::string QmitkMLBTView::VIEW_ID = "org.mitk.views.mlbtview"; +using namespace berry; + +QmitkMLBTView::QmitkMLBTView() + : QmitkFunctionality() + , m_Controls( 0 ) + , m_MultiWidget( NULL ) +{ + m_TrackingTimer = new QTimer(this); +} + +// Destructor +QmitkMLBTView::~QmitkMLBTView() +{ + +} + +void QmitkMLBTView::CreateQtPartControl( QWidget *parent ) +{ + // build up qt view, unless already done + if ( !m_Controls ) + { + // create GUI widgets from the Qt Designer's .ui file + m_Controls = new Ui::QmitkMLBTViewControls; + m_Controls->setupUi( parent ); + + connect( m_Controls->m_StartTrainingButton, SIGNAL ( clicked() ), this, SLOT( StartTrainingThread() ) ); + connect( &m_TrainingWatcher, SIGNAL ( finished() ), this, SLOT( OnTrainingThreadStop() ) ); + connect( m_Controls->m_StartTrackingButton, SIGNAL ( clicked() ), this, SLOT( StartTrackingThread() ) ); + connect( &m_TrackingWatcher, SIGNAL ( finished() ), this, SLOT( OnTrackingThreadStop() ) ); + connect( m_Controls->m_SaveForestButton, SIGNAL ( clicked() ), this, SLOT( SaveForest() ) ); + connect( m_Controls->m_LoadForestButton, SIGNAL ( clicked() ), this, SLOT( LoadForest() ) ); + connect( m_TrackingTimer, SIGNAL(timeout()), this, SLOT(BuildFibers()) ); + connect( m_Controls->m_TimerIntervalBox, SIGNAL(valueChanged(int)), this, SLOT( ChangeTimerInterval(int) )); + connect( m_Controls->m_DemoModeBox, SIGNAL(stateChanged(int)), this, SLOT( ToggleDemoMode(int) )); + connect( m_Controls->m_PauseTrackingButton, SIGNAL ( clicked() ), this, SLOT( PauseTracking() ) ); + connect( m_Controls->m_AbortTrackingButton, SIGNAL ( clicked() ), this, SLOT( AbortTracking() ) ); + + int numThread = itk::MultiThreader::GetGlobalDefaultNumberOfThreads(); + m_Controls->m_NumberOfThreadsBox->setMaximum(numThread); + m_Controls->m_NumberOfThreadsBox->setValue(numThread); + + m_Controls->m_TrackingMaskImageBox->SetDataStorage(this->GetDataStorage()); + m_Controls->m_TrackingSeedImageBox->SetDataStorage(this->GetDataStorage()); + m_Controls->m_TrackingStopImageBox->SetDataStorage(this->GetDataStorage()); + m_Controls->m_TrackingRawImageBox->SetDataStorage(this->GetDataStorage()); + + mitk::TNodePredicateDataType::Pointer isMitkImage = mitk::TNodePredicateDataType::New(); + mitk::NodePredicateDataType::Pointer isDwi = mitk::NodePredicateDataType::New("DiffusionImage"); + mitk::NodePredicateDataType::Pointer isDti = mitk::NodePredicateDataType::New("TensorImage"); + mitk::NodePredicateDataType::Pointer isQbi = mitk::NodePredicateDataType::New("QBallImage"); + mitk::NodePredicateOr::Pointer isDiffusionImage = mitk::NodePredicateOr::New(isDwi, isDti); + isDiffusionImage = mitk::NodePredicateOr::New(isDiffusionImage, isQbi); + mitk::NodePredicateNot::Pointer noDiffusionImage = mitk::NodePredicateNot::New(isDiffusionImage); + mitk::NodePredicateAnd::Pointer finalPredicate = mitk::NodePredicateAnd::New(isMitkImage, noDiffusionImage); + mitk::NodePredicateProperty::Pointer isBinaryPredicate = mitk::NodePredicateProperty::New("binary", mitk::BoolProperty::New(true)); + finalPredicate = mitk::NodePredicateAnd::New(finalPredicate, isBinaryPredicate); + m_Controls->m_TrackingMaskImageBox->SetPredicate(finalPredicate); + m_Controls->m_TrackingSeedImageBox->SetPredicate(finalPredicate); + m_Controls->m_TrackingStopImageBox->SetPredicate(finalPredicate); + m_Controls->m_TrackingRawImageBox->SetPredicate(isDwi); + } +} + +void QmitkMLBTView::AbortTracking() +{ + if (tracker.IsNotNull()) + { + tracker->m_AbortTracking = true; + } +} + +void QmitkMLBTView::PauseTracking() +{ + if (tracker.IsNotNull()) + { + tracker->m_PauseTracking = !tracker->m_PauseTracking; + } +} + +void QmitkMLBTView::ChangeTimerInterval(int value) +{ + m_TrackingTimer->setInterval(value); +} + +void QmitkMLBTView::ToggleDemoMode(int state) +{ + if (tracker.IsNotNull()) + { + tracker->SetDemoMode(m_Controls->m_DemoModeBox->isChecked()); + tracker->m_Stop = false; + } +} + +void QmitkMLBTView::BuildFibers() +{ + if (m_Controls->m_DemoModeBox->isChecked() && tracker.IsNotNull() && tracker->m_BuildFibersFinished) + { + vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); + mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); + outFib->SetFiberColors(255,0,0); + m_TractogramNode->SetData(outFib); + mitk::RenderingManager::GetInstance()->RequestUpdateAll(); + tracker->m_BuildFibersFinished = false; + tracker->m_Stop = false; + tracker->m_BuildFibersReady = 0; + } +} + +void QmitkMLBTView::LoadForest() +{ + + QString filename = QFileDialog::getOpenFileName(0, tr("Load Forest"), QDir::currentPath(), tr("HDF5 random forest file (*.rf)") ); + if(filename.isEmpty() || filename.isNull()) + return; + + m_ForestHandler.LoadForest( filename.toStdString() ); +} + +void QmitkMLBTView::StartTrackingThread() +{ + m_TractogramNode = mitk::DataNode::New(); + m_TractogramNode->SetName("MLBT Result"); + m_TractogramNode->SetProperty("Fiber2DSliceThickness", mitk::FloatProperty::New(20)); + m_TractogramNode->SetProperty("Fiber2DfadeEFX", mitk::BoolProperty::New(false)); + m_TractogramNode->SetProperty("LineWidth", mitk::IntProperty::New(2)); + m_TractogramNode->SetProperty("color",mitk::ColorProperty::New(0, 1, 1)); + this->GetDataStorage()->Add(m_TractogramNode); + + QFuture future = QtConcurrent::run( this, &QmitkMLBTView::StartTracking ); + m_TrackingWatcher.setFuture(future); + m_TrackingThreadIsRunning = true; + m_Controls->m_StartTrackingButton->setEnabled(false); + m_TrackingTimer->start(10); +} + +void QmitkMLBTView::OnTrackingThreadStop() +{ + m_TrackingThreadIsRunning = false; + m_Controls->m_StartTrackingButton->setEnabled(true); + + + vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); + mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); + outFib->SetFiberColors(255,0,0); +// mitk::DataNode::Pointer node = mitk::DataNode::New(); + m_TractogramNode->SetData(outFib); + + tracker = NULL; + m_TrackingTimer->stop(); +} + +void QmitkMLBTView::StartTracking() +{ + if ( m_Controls->m_TrackingRawImageBox->GetSelectedNode().IsNull() ) + return; + + mitk::Image::Pointer dwi = dynamic_cast(m_Controls->m_TrackingRawImageBox->GetSelectedNode()->GetData()); + tracker = TrackerType::New(); + tracker->SetNumberOfThreads(m_Controls->m_NumberOfThreadsBox->value()); + tracker->SetInput(0, mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi) ); + tracker->SetGradientDirections( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi) ); + tracker->SetB_Value( mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi) ); + tracker->SetDemoMode(m_Controls->m_DemoModeBox->isChecked()); + if (m_Controls->m_TrackingUseMaskImageBox->isChecked() && m_Controls->m_TrackingMaskImageBox->GetSelectedNode().IsNotNull()) + { + mitk::Image::Pointer mask = dynamic_cast(m_Controls->m_TrackingMaskImageBox->GetSelectedNode()->GetData()); + ItkUcharImgType::Pointer itkMask = ItkUcharImgType::New(); + mitk::CastToItkImage(mask, itkMask); + tracker->SetMaskImage(itkMask); + } + if (m_Controls->m_TrackingUseSeedImageBox->isChecked() && m_Controls->m_TrackingSeedImageBox->GetSelectedNode().IsNotNull()) + { + mitk::Image::Pointer img = dynamic_cast(m_Controls->m_TrackingSeedImageBox->GetSelectedNode()->GetData()); + ItkUcharImgType::Pointer itkImg = ItkUcharImgType::New(); + mitk::CastToItkImage(img, itkImg); + tracker->SetSeedImage(itkImg); + } + if (m_Controls->m_TrackingUseStopImageBox->isChecked() && m_Controls->m_TrackingStopImageBox->GetSelectedNode().IsNotNull()) + { + mitk::Image::Pointer img = dynamic_cast(m_Controls->m_TrackingStopImageBox->GetSelectedNode()->GetData()); + ItkUcharImgType::Pointer itkImg = ItkUcharImgType::New(); + mitk::CastToItkImage(img, itkImg); + tracker->SetStoppingRegions(itkImg); + } + tracker->SetSeedsPerVoxel(m_Controls->m_NumberOfSeedsBox->value()); + tracker->SetUseDirection(true); + tracker->SetStepSize(m_Controls->m_TrackingStepSizeBox->value()); + tracker->SetAngularThreshold(cos((float)m_Controls->m_AngularThresholdBox->value()*M_PI/180)); + tracker->SetMinTractLength(m_Controls->m_MinLengthBox->value()); + tracker->SetMaxTractLength(m_Controls->m_MaxLengthBox->value()); + + vigra::RandomForest forest = m_ForestHandler.GetForest(); + tracker->SetDecisionForest(&forest); + tracker->SetSamplingDistance(m_Controls->m_SamplingDistanceBox->value()); + tracker->SetNumberOfSamples(m_Controls->m_NumSamplesBox->value()); + tracker->Update(); +// vtkSmartPointer< vtkPolyData > poly = tracker->GetFiberPolyData(); +// mitk::FiberBundle::Pointer outFib = mitk::FiberBundle::New(poly); +// outFib->SetColorCoding(mitk::FiberBundle::COLORCODING_CUSTOM); +// mitk::DataNode::Pointer node = mitk::DataNode::New(); +// m_TractogramNode->SetData(outFib); +// node->SetData(outFib); +// node->SetName("MLBT Result"); +// this->GetDataStorage()->Add(node); +// mitk::RenderingManager::GetInstance()->RequestUpdateAll(); +} + +void QmitkMLBTView::SaveForest() +{ + QString filename = QFileDialog::getSaveFileName(0, tr("Save Forest"), QDir::currentPath()+"/forest.rf", tr("HDF5 random forest file (*.rf)") ); + + if(filename.isEmpty() || filename.isNull()) + return; + if(!filename.endsWith(".rf")) + filename += ".rf"; + + m_ForestHandler.SaveForest( filename.toStdString() ); +} + +void QmitkMLBTView::StartTrainingThread() +{ + QFuture future = QtConcurrent::run( this, &QmitkMLBTView::StartTraining ); + m_TrainingWatcher.setFuture(future); + m_Controls->m_StartTrainingButton->setEnabled(false); + m_Controls->m_SaveForestButton->setEnabled(false); + m_Controls->m_LoadForestButton->setEnabled(false); +} + +void QmitkMLBTView::OnTrainingThreadStop() +{ + m_Controls->m_StartTrainingButton->setEnabled(true); + m_Controls->m_SaveForestButton->setEnabled(true); + m_Controls->m_LoadForestButton->setEnabled(true); +} + +void QmitkMLBTView::StartTraining() +{ + m_ForestHandler.SetRawData(m_SelectedDiffImages); + m_ForestHandler.SetTractograms(m_SelectedFB); + m_ForestHandler.SetNumTrees(m_Controls->m_NumTreesBox->value()); + m_ForestHandler.SetMaxTreeDepth(m_Controls->m_MaxDepthBox->value()); + m_ForestHandler.SetGrayMatterSamplesPerVoxel(m_Controls->m_GmSamplingBox->value()); + m_ForestHandler.SetSampleFraction(m_Controls->m_SampleFractionBox->value()); + m_ForestHandler.SetStepSize(m_Controls->m_TrainingStepSizeBox->value()); + m_ForestHandler.SetUsePreviousDirection(m_Controls->m_TrackingUsePreviousDirectionBox->isChecked()); + m_ForestHandler.StartTraining(); +} + +void QmitkMLBTView::StdMultiWidgetAvailable (QmitkStdMultiWidget &stdMultiWidget) +{ + m_MultiWidget = &stdMultiWidget; +} + + +void QmitkMLBTView::StdMultiWidgetNotAvailable() +{ + m_MultiWidget = NULL; +} + +void QmitkMLBTView::OnSelectionChanged( std::vector nodes ) +{ + if ( !this->IsVisible() ) + { + // do nothing if nobody wants to see me :-( + return; + } + + m_SelectedFB.clear(); + m_SelectedDiffImages.clear(); + m_MaskImages.clear(); + m_WhiteMatterImages.clear(); + + for( std::vector::iterator it = nodes.begin(); it != nodes.end(); ++it ) + { + mitk::DataNode::Pointer node = *it; + if ( dynamic_cast(node->GetData()) ) + m_SelectedFB.push_back( dynamic_cast(node->GetData()) ); + else if (mitk::DiffusionPropertyHelper::IsDiffusionWeightedImage(node)) + m_SelectedDiffImages.push_back( dynamic_cast(node->GetData()) ); + } +} + +void QmitkMLBTView::Activated() +{ + +} + + diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h new file mode 100644 index 0000000000..98517f2782 --- /dev/null +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTView.h @@ -0,0 +1,107 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + + + +#include +#include + +#include +#include "ui_QmitkMLBTViewControls.h" + +#include "mitkDataStorage.h" +#include "mitkDataStorageSelection.h" +#include +#include +#include +#include +#include + +/*! +\brief +*/ + +// Forward Qt class declarations + + +class QmitkMLBTView : public QmitkFunctionality +{ + + + // this is needed for all Qt objects that should have a Qt meta-object + // (everything that derives from QObject and wants to have signal/slots) + Q_OBJECT + +public: + + static const std::string VIEW_ID; + + typedef itk::Image ItkUcharImgType; + typedef itk::MLBSTrackingFilter<100> TrackerType; + + QmitkMLBTView(); + virtual ~QmitkMLBTView(); + + virtual void CreateQtPartControl(QWidget *parent); + + virtual void StdMultiWidgetAvailable (QmitkStdMultiWidget &stdMultiWidget); + virtual void StdMultiWidgetNotAvailable(); + virtual void Activated(); + + protected slots: + + void StartTrackingThread(); + void OnTrackingThreadStop(); + void StartTrainingThread(); + void OnTrainingThreadStop(); + void SaveForest(); + void LoadForest(); + void BuildFibers(); + void ChangeTimerInterval(int value); + void ToggleDemoMode(int state); + void PauseTracking(); + void AbortTracking(); + +protected: + + void StartTracking(); + void StartTraining(); + + /// \brief called by QmitkFunctionality when DataManager's selection has changed + virtual void OnSelectionChanged( std::vector nodes ); + + Ui::QmitkMLBTViewControls* m_Controls; + QmitkStdMultiWidget* m_MultiWidget; + + mitk::TrackingForestHandler<> m_ForestHandler; + std::vector< mitk::Image::Pointer > m_SelectedDiffImages; + std::vector< mitk::FiberBundle::Pointer > m_SelectedFB; + std::vector< ItkUcharImgType::Pointer > m_MaskImages; + std::vector< ItkUcharImgType::Pointer > m_WhiteMatterImages; + + QFutureWatcher m_TrainingWatcher; + QFutureWatcher m_TrackingWatcher; + bool m_TrackingThreadIsRunning; + TrackerType::Pointer tracker; + QTimer* m_TrackingTimer; + mitk::DataNode::Pointer m_TractogramNode; + +private: + + }; + + + diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui new file mode 100644 index 0000000000..cf19c52391 --- /dev/null +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/QmitkMLBTViewControls.ui @@ -0,0 +1,620 @@ + + + QmitkMLBTViewControls + + + + 0 + 0 + 359 + 1127 + + + + Form + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + 1 + + + + + 0 + 0 + 359 + 1065 + + + + Training + + + + + + Load Forest + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Step size: + + + + + + + Num. Trees: + + + + + + + Max. Depth: + + + + + + + 3 + + + -1.000000000000000 + + + 999.000000000000000 + + + 0.100000000000000 + + + -1.000000000000000 + + + + + + + 1 + + + 999999999 + + + + + + + Sample Fraction: + + + + + + + 1 + + + 999999999 + + + 10 + + + + + + + 1 + + + 999999999 + + + 10 + + + + + + + 3 + + + 1.000000000000000 + + + 0.100000000000000 + + + 1.000000000000000 + + + + + + + GM Sampling Points: + + + + + + + Use Previous Direction: + + + + + + + + + + true + + + + + + + + + + Start Training + + + + + + + Save Forest + + + + + + + + + 0 + 0 + 359 + 1065 + + + + Tractography + + + + + + Use Previous Direction + + + true + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + ... + + + + :/org.mbi.gui.qt.diffusionimaginginternal/resources/Media-playback-pause.svg:/org.mbi.gui.qt.diffusionimaginginternal/resources/Media-playback-pause.svg + + + + + + + ... + + + + :/org.mbi.gui.qt.diffusionimaginginternal/resources/Media-playback-start.svg:/org.mbi.gui.qt.diffusionimaginginternal/resources/Media-playback-start.svg + + + + + + + ... + + + + :/org.mbi.gui.qt.diffusionimaginginternal/resources/Media-playback-stop.svg:/org.mbi.gui.qt.diffusionimaginginternal/resources/Media-playback-stop.svg + + + + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Demo Mode + + + + + + + 1 + + + 1000 + + + 10 + + + + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + 0.500000000000000 + + + + + + + Max. Length + + + + + + + 999999999 + + + 50 + + + + + + + Num. Samples: + + + + + + + Input DWI: + + + + + + + Step Size: + + + + + + + Sampling Distance: + + + + + + + 1 + + + 999 + + + + + + + Min. Length + + + + + + + 0.100000000000000 + + + 0.500000000000000 + + + + + + + 999999999.000000000000000 + + + 1.000000000000000 + + + 20.000000000000000 + + + + + + + Angular Threshold: + + + + + + + Num. Seeds: + + + + + + + 999999999.000000000000000 + + + 1.000000000000000 + + + 400.000000000000000 + + + + + + + 1 + + + 90.000000000000000 + + + 45.000000000000000 + + + + + + + + + + Num. Threads: + + + + + + + 1 + + + 30 + + + + + + + + + + Qt::Vertical + + + + 20 + 40 + + + + + + + + QFrame::NoFrame + + + QFrame::Raised + + + + 0 + + + 0 + + + 0 + + + 0 + + + + + Use Stop Image: + + + + + + + Use Seed Image: + + + false + + + + + + + Use Mask Image: + + + true + + + + + + + + + + + + + + + + + + + + + + + + QmitkDataStorageComboBox + QComboBox +
QmitkDataStorageComboBox.h
+
+
+ + + + +
diff --git a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/mitkPluginActivator.cpp b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/mitkPluginActivator.cpp index 502929a7bf..f8c0de6e51 100644 --- a/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/mitkPluginActivator.cpp +++ b/Plugins/org.mitk.gui.qt.diffusionimaging/src/internal/mitkPluginActivator.cpp @@ -1,106 +1,108 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkPluginActivator.h" #include #include "src/internal/Perspectives/QmitkDiffusionImagingAppPerspective.h" #include "src/internal/Perspectives/QmitkDIAppIVIMPerspective.h" #include "src/internal/Perspectives/QmitkDIAppSyntheticDataGenerationPerspective.h" #include "src/internal/Perspectives/QmitkGibbsTractographyPerspective.h" #include "src/internal/Perspectives/QmitkStreamlineTractographyPerspective.h" #include "src/internal/Perspectives/QmitkProbabilisticTractographyPerspective.h" #include "src/internal/Perspectives/QmitkFiberProcessingPerspective.h" #include "src/internal/Perspectives/QmitkDiffusionDefaultPerspective.h" #include "src/internal/QmitkQBallReconstructionView.h" #include "src/internal/QmitkPreprocessingView.h" #include "src/internal/QmitkDiffusionDicomImportView.h" #include "src/internal/QmitkDiffusionQuantificationView.h" #include "src/internal/QmitkTensorReconstructionView.h" #include "src/internal/QmitkControlVisualizationPropertiesView.h" #include "src/internal/QmitkODFDetailsView.h" #include "src/internal/QmitkGibbsTrackingView.h" #include "src/internal/QmitkStochasticFiberTrackingView.h" #include "src/internal/QmitkFiberQuantificationView.h" #include "src/internal/QmitkPartialVolumeAnalysisView.h" #include "src/internal/QmitkIVIMView.h" #include "src/internal/QmitkTractbasedSpatialStatisticsView.h" #include "src/internal/QmitkTbssSkeletonizationView.h" #include "src/internal/QmitkStreamlineTrackingView.h" #include "src/internal/Connectomics/QmitkConnectomicsDataView.h" #include "src/internal/Connectomics/QmitkConnectomicsNetworkOperationsView.h" #include "src/internal/Connectomics/QmitkConnectomicsStatisticsView.h" #include "src/internal/Connectomics/QmitkRandomParcellationView.h" #include "src/internal/QmitkOdfMaximaExtractionView.h" #include "src/internal/QmitkFiberfoxView.h" #include "src/internal/QmitkFiberProcessingView.h" #include "src/internal/QmitkFieldmapGeneratorView.h" #include "src/internal/QmitkDiffusionRegistrationView.h" #include "src/internal/QmitkDenoisingView.h" +#include "src/internal/QmitkMLBTView.h" namespace mitk { void PluginActivator::start(ctkPluginContext* context) { BERRY_REGISTER_EXTENSION_CLASS(QmitkDiffusionImagingAppPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkGibbsTractographyPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkStreamlineTractographyPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkProbabilisticTractographyPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDIAppSyntheticDataGenerationPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDIAppIVIMPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkFiberProcessingPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDiffusionDefaultPerspective, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkQBallReconstructionView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkPreprocessingView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDiffusionDicomImport, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDiffusionQuantificationView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkTensorReconstructionView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkControlVisualizationPropertiesView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkODFDetailsView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkGibbsTrackingView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkStochasticFiberTrackingView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkFiberQuantificationView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkPartialVolumeAnalysisView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkIVIMView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkTractbasedSpatialStatisticsView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkTbssSkeletonizationView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkConnectomicsDataView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkConnectomicsNetworkOperationsView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkConnectomicsStatisticsView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkStreamlineTrackingView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkOdfMaximaExtractionView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkFiberfoxView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkFiberProcessingView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkFieldmapGeneratorView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDiffusionRegistrationView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkDenoisingView, context) BERRY_REGISTER_EXTENSION_CLASS(QmitkRandomParcellationView, context) + BERRY_REGISTER_EXTENSION_CLASS(QmitkMLBTView, context) } void PluginActivator::stop(ctkPluginContext* context) { Q_UNUSED(context) } } #if QT_VERSION < QT_VERSION_CHECK(5, 0, 0) Q_EXPORT_PLUGIN2(org_mitk_gui_qt_diffusionimaging, mitk::PluginActivator) #endif