diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp new file mode 100644 index 0000000000..94bfc6ffd5 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp @@ -0,0 +1,425 @@ +#include "itkFitFibersToImageFilter.h" + +#include + +namespace itk{ + +FitFibersToImageFilter::FitFibersToImageFilter() + : m_FitIndividualFibers(true) + , m_GradientTolerance(1e-5) + , m_Lambda(0.1) + , m_MaxIterations(20) + , m_FiberSampling(10) + , m_Coverage(0) + , m_Overshoot(0) + , m_FilterOutliers(true) + , m_MeanWeight(1.0) + , m_MedianWeight(1.0) + , m_MinWeight(1.0) + , m_MaxWeight(1.0) + , m_Verbose(true) + , m_DeepCopy(true) +{ + this->SetNumberOfRequiredOutputs(3); +} + +FitFibersToImageFilter::~FitFibersToImageFilter() +{ + +} + +void FitFibersToImageFilter::GenerateData() +{ + int sz_x = m_PeakImage->GetLargestPossibleRegion().GetSize(0); + int sz_y = m_PeakImage->GetLargestPossibleRegion().GetSize(1); + int sz_z = m_PeakImage->GetLargestPossibleRegion().GetSize(2); + int sz_peaks = m_PeakImage->GetLargestPossibleRegion().GetSize(3)/3 + 1; // +1 for zero - peak + int num_voxels = sz_x*sz_y*sz_z; + + float minSpacing = 1; + if(m_PeakImage->GetSpacing()[0]GetSpacing()[1] && m_PeakImage->GetSpacing()[0]GetSpacing()[2]) + minSpacing = m_PeakImage->GetSpacing()[0]; + else if (m_PeakImage->GetSpacing()[1] < m_PeakImage->GetSpacing()[2]) + minSpacing = m_PeakImage->GetSpacing()[1]; + else + minSpacing = m_PeakImage->GetSpacing()[2]; + + unsigned int num_unknowns = m_Tractograms.size(); + if (m_FitIndividualFibers) + { + num_unknowns = 0; + for (unsigned int bundle=0; bundleGetNumFibers(); + } + + for (unsigned int bundle=0; bundleGetDeepCopy(); + m_Tractograms.at(bundle)->ResampleLinear(minSpacing/m_FiberSampling); + } + + unsigned int number_of_residuals = num_voxels * sz_peaks; + + MITK_INFO << "Num. unknowns: " << num_unknowns; + MITK_INFO << "Num. residuals: " << number_of_residuals; + MITK_INFO << "Creating system ..."; + + vnl_sparse_matrix A; + vnl_vector b; + A.set_size(number_of_residuals, num_unknowns); + b.set_size(number_of_residuals); b.fill(0.0); + + double TD = 0; + double FD = 0; + unsigned int dir_count = 0; + unsigned int fiber_count = 0; + + for (unsigned int bundle=0; bundle polydata = m_Tractograms.at(bundle)->GetFiberPolyData(); + + for (int i=0; iGetNumFibers(); ++i) + { + vtkCell* cell = polydata->GetCell(i); + int numPoints = cell->GetNumberOfPoints(); + vtkPoints* points = cell->GetPoints(); + + if (numPoints<2) + MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; + + for (int j=0; jGetPoint(j); + PointType4 p; + p[0]=p1[0]; + p[1]=p1[1]; + p[2]=p1[2]; + p[3]=0; + + itk::Index<4> idx4; + m_PeakImage->TransformPhysicalPointToIndex(p, idx4); + if (!m_PeakImage->GetLargestPossibleRegion().IsInside(idx4)) + continue; + + double* p2 = points->GetPoint(j+1); + vnl_vector_fixed fiber_dir; + fiber_dir[0] = p[0]-p2[0]; + fiber_dir[1] = p[1]-p2[1]; + fiber_dir[2] = p[2]-p2[2]; + fiber_dir.normalize(); + + double w = 1; + int peak_id = sz_peaks-1; + vnl_vector_fixed odf_peak = GetClosestPeak(idx4, m_PeakImage, fiber_dir, peak_id, w); + float peak_mag = odf_peak.magnitude(); + + int x = idx4[0]; + int y = idx4[1]; + int z = idx4[2]; + + unsigned int linear_index = x + sz_x*y + sz_x*sz_y*z + sz_x*sz_y*sz_z*peak_id; + + if (b[linear_index] == 0 && peak_id<3) + { + dir_count++; + FD += peak_mag; + } + TD += w; + + if (m_FitIndividualFibers) + { + b[linear_index] = peak_mag; + A.put(linear_index, fiber_count, A.get(linear_index, fiber_count) + w); + } + else + { + b[linear_index] = peak_mag; + A.put(linear_index, bundle, A.get(linear_index, bundle) + w); + } + } + + ++fiber_count; + } + } + + TD /= (dir_count*fiber_count); + FD /= dir_count; + A /= TD; + b *= 100.0/FD; // times 100 because we want to avoid too small values for computational reasons + + double init_lambda = 1e5; // initialization for lambda estimation + + itk::TimeProbe clock; + clock.Start(); + + VnlCostFunction cost(num_unknowns); + cost.SetProblem(A, b, init_lambda); + m_Weights.set_size(num_unknowns); m_Weights.fill( TD/100.0 * FD/2.0 ); + vnl_lbfgsb minimizer(cost); + vnl_vector l; l.set_size(num_unknowns); l.fill(0); + vnl_vector bound_selection; bound_selection.set_size(num_unknowns); bound_selection.fill(1); + minimizer.set_bound_selection(bound_selection); + minimizer.set_lower_bound(l); + minimizer.set_projected_gradient_tolerance(m_GradientTolerance); + + MITK_INFO << "Estimating regularization"; + minimizer.set_trace(false); + minimizer.set_max_function_evals(1); + minimizer.minimize(m_Weights); + vnl_vector dx; dx.set_size(num_unknowns); dx.fill(0.0); + cost.grad_regu_localMSE(m_Weights, dx); + double r = dx.magnitude()/m_Weights.magnitude(); + cost.m_Lambda *= m_Lambda*55.0/r; + if (cost.m_Lambda>10e7) + { + MITK_INFO << "Regularization estimation failed. Using default value."; + cost.m_Lambda = fiber_count; + } + MITK_INFO << "Using regularization factor of " << cost.m_Lambda << " (λ: " << m_Lambda << ")"; + + MITK_INFO << "Fitting fibers"; + minimizer.set_trace(m_Verbose); + + minimizer.set_max_function_evals(m_MaxIterations); + minimizer.minimize(m_Weights); + + std::vector< double > weights; + if (m_FilterOutliers) + { + for (auto w : m_Weights) + weights.push_back(w); + std::sort(weights.begin(), weights.end()); + MITK_INFO << "Setting upper weight bound to " << weights.at(num_unknowns*0.95); + vnl_vector u; u.set_size(num_unknowns); u.fill(weights.at(num_unknowns*0.95)); + minimizer.set_upper_bound(u); + bound_selection.fill(2); + minimizer.set_bound_selection(bound_selection); + minimizer.minimize(m_Weights); + weights.clear(); + } + + for (auto w : m_Weights) + weights.push_back(w); + std::sort(weights.begin(), weights.end()); + + m_MeanWeight = m_Weights.mean(); + m_MedianWeight = weights.at(num_unknowns*0.5); + m_MinWeight = weights.at(0); + m_MaxWeight = weights.at(num_unknowns-1); + + MITK_INFO << "*************************"; + MITK_INFO << "Weight statistics"; + MITK_INFO << "Mean: " << m_MeanWeight; + MITK_INFO << "Median: " << m_MedianWeight; + MITK_INFO << "75% quantile: " << weights.at(num_unknowns*0.75); + MITK_INFO << "95% quantile: " << weights.at(num_unknowns*0.95); + MITK_INFO << "99% quantile: " << weights.at(num_unknowns*0.99); + MITK_INFO << "Min: " << m_MinWeight; + MITK_INFO << "Max: " << m_MaxWeight; + MITK_INFO << "*************************"; + MITK_INFO << "NumEvals: " << minimizer.get_num_evaluations(); + MITK_INFO << "NumIterations: " << minimizer.get_num_iterations(); + MITK_INFO << "Residual cost: " << minimizer.get_end_error(); + MITK_INFO << "Final RMS: " << cost.S->get_rms_error(m_Weights); + + clock.Stop(); + int h = clock.GetTotal()/3600; + int m = ((int)clock.GetTotal()%3600)/60; + int s = (int)clock.GetTotal()%60; + MITK_INFO << "Optimization took " << h << "h, " << m << "m and " << s << "s"; + + // transform back for peak image creation + A *= FD/100.0; + b *= FD/100.0; + + MITK_INFO << "Weighting fibers"; + if (m_FitIndividualFibers) + { + unsigned int fiber_count = 0; + for (unsigned int bundle=0; bundleGetNumFibers(); i++) + { + m_Tractograms.at(bundle)->SetFiberWeight(i, m_Weights[fiber_count]); + ++fiber_count; + } + m_Tractograms.at(bundle)->Compress(0.1); + m_Tractograms.at(bundle)->ColorFibersByFiberWeights(false, true); + } + } + else + { + for (unsigned int i=0; iSetFiberWeights(m_Weights[i]); + m_Tractograms.at(i)->Compress(0.1); + m_Tractograms.at(i)->ColorFibersByFiberWeights(false, true); + } + } + + MITK_INFO << "Generating output images ..."; + + itk::ImageDuplicator< PeakImgType >::Pointer duplicator = itk::ImageDuplicator< PeakImgType >::New(); + duplicator->SetInputImage(m_PeakImage); + duplicator->Update(); + m_UnderexplainedImage = duplicator->GetOutput(); + m_UnderexplainedImage->FillBuffer(0.0); + + duplicator->SetInputImage(m_UnderexplainedImage); + duplicator->Update(); + m_OverexplainedImage = duplicator->GetOutput(); + m_OverexplainedImage->FillBuffer(0.0); + + duplicator->SetInputImage(m_OverexplainedImage); + duplicator->Update(); + m_ResidualImage = duplicator->GetOutput(); + m_ResidualImage->FillBuffer(0.0); + + duplicator->SetInputImage(m_ResidualImage); + duplicator->Update(); + m_FittedImage = duplicator->GetOutput(); + m_FittedImage->FillBuffer(0.0); + + vnl_vector fitted_b; fitted_b.set_size(b.size()); + cost.S->multiply(m_Weights, fitted_b); + + for (unsigned int r=0; r idx; + unsigned int linear_index = r; + idx[0] = linear_index % sz_x; linear_index /= sz_x; + idx[1] = linear_index % sz_y; linear_index /= sz_y; + idx[2] = linear_index % sz_z; linear_index /= sz_z; + int peak_id = linear_index % sz_peaks; + + if (peak_id peak_dir; + + idx[3] = peak_id*3; + peak_dir[0] = m_PeakImage->GetPixel(idx); + idx[3] += 1; + peak_dir[1] = m_PeakImage->GetPixel(idx); + idx[3] += 1; + peak_dir[2] = m_PeakImage->GetPixel(idx); + + peak_dir.normalize(); + peak_dir *= fitted_b[r]; + + idx[3] = peak_id*3; + m_FittedImage->SetPixel(idx, peak_dir[0]); + + idx[3] += 1; + m_FittedImage->SetPixel(idx, peak_dir[1]); + + idx[3] += 1; + m_FittedImage->SetPixel(idx, peak_dir[2]); + } + } + + FD = 0; + m_Coverage = 0; + m_Overshoot = 0; + + itk::Index<4> idx; + for (idx[0]=0; idx[0] peak_dir; + vnl_vector_fixed fitted_dir; + vnl_vector_fixed overshoot_dir; + for (idx[3]=0; idx[3]<(itk::IndexValueType)m_PeakImage->GetLargestPossibleRegion().GetSize(3); ++idx[3]) + { + peak_dir[idx[3]%3] = m_PeakImage->GetPixel(idx); + fitted_dir[idx[3]%3] = m_FittedImage->GetPixel(idx); + m_ResidualImage->SetPixel(idx, m_PeakImage->GetPixel(idx) - m_FittedImage->GetPixel(idx)); + + if (idx[3]%3==2) + { + FD += peak_dir.magnitude(); + + itk::Index<4> tidx= idx; + if (peak_dir.magnitude()>fitted_dir.magnitude()) + { + m_Coverage += fitted_dir.magnitude(); + m_UnderexplainedImage->SetPixel(tidx, peak_dir[2]-fitted_dir[2]); tidx[3]--; + m_UnderexplainedImage->SetPixel(tidx, peak_dir[1]-fitted_dir[1]); tidx[3]--; + m_UnderexplainedImage->SetPixel(tidx, peak_dir[0]-fitted_dir[0]); + } + else + { + overshoot_dir[0] = fitted_dir[0]-peak_dir[0]; + overshoot_dir[1] = fitted_dir[1]-peak_dir[1]; + overshoot_dir[2] = fitted_dir[2]-peak_dir[2]; + m_Coverage += peak_dir.magnitude(); + m_Overshoot += overshoot_dir.magnitude(); + m_OverexplainedImage->SetPixel(tidx, overshoot_dir[2]); tidx[3]--; + m_OverexplainedImage->SetPixel(tidx, overshoot_dir[1]); tidx[3]--; + m_OverexplainedImage->SetPixel(tidx, overshoot_dir[0]); + } + } + } + } + + m_Coverage = m_Coverage/FD; + m_Overshoot = m_Overshoot/FD; + + MITK_INFO << std::fixed << "Coverage: " << setprecision(1) << 100.0*m_Coverage << "%"; + MITK_INFO << std::fixed << "Overshoot: " << setprecision(1) << 100.0*m_Overshoot << "%"; +} + +vnl_vector_fixed FitFibersToImageFilter::GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed fiber_dir, int& id, double& w ) +{ + int m_NumDirs = peak_image->GetLargestPossibleRegion().GetSize()[3]/3; + vnl_vector_fixed out_dir; out_dir.fill(0); + float angle = 0.8; + + for (int i=0; i dir; + idx[3] = i*3; + dir[0] = peak_image->GetPixel(idx); + idx[3] += 1; + dir[1] = peak_image->GetPixel(idx); + idx[3] += 1; + dir[2] = peak_image->GetPixel(idx); + + float mag = dir.magnitude(); + if (magangle) + { + angle = fabs(a); + w = angle; + if (a<0) + out_dir = -dir; + else + out_dir = dir; + out_dir *= mag; + id = i; + } + } + + return out_dir; +} + +std::vector FitFibersToImageFilter::GetTractograms() const +{ + return m_Tractograms; +} + +void FitFibersToImageFilter::SetTractograms(const std::vector &tractograms) +{ + m_Tractograms = tractograms; +} + +} + + + diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h new file mode 100644 index 0000000000..d784f3b2b7 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h @@ -0,0 +1,252 @@ +#ifndef __itkFitFibersToImageFilter_h__ +#define __itkFitFibersToImageFilter_h__ + +// MITK +#include +#include +#include +#include +#include +#include +#include +#include +#include + +namespace itk{ + +/** +* \brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). */ + +class FitFibersToImageFilter : public ImageSource< mitk::PeakImage::ItkPeakImageType > +{ + +public: + + typedef FitFibersToImageFilter Self; + typedef ProcessObject Superclass; + typedef SmartPointer< Self > Pointer; + typedef SmartPointer< const Self > ConstPointer; + + typedef itk::Point PointType4; + typedef mitk::PeakImage::ItkPeakImageType PeakImgType; + + itkFactorylessNewMacro(Self) + itkCloneMacro(Self) + itkTypeMacro( FitFibersToImageFilter, ImageSource ) + + itkSetMacro( PeakImage, PeakImgType::Pointer) + itkGetMacro( PeakImage, PeakImgType::Pointer) + itkSetMacro( FitIndividualFibers, bool) + itkGetMacro( FitIndividualFibers, bool) + itkSetMacro( GradientTolerance, double) + itkGetMacro( GradientTolerance, double) + itkSetMacro( Lambda, double) + itkGetMacro( Lambda, double) + itkSetMacro( MaxIterations, int) + itkGetMacro( MaxIterations, int) + itkSetMacro( FiberSampling, float) + itkGetMacro( FiberSampling, float) + itkSetMacro( FilterOutliers, bool) + itkGetMacro( FilterOutliers, bool) + itkSetMacro( Verbose, bool) + itkGetMacro( Verbose, bool) + itkSetMacro( DeepCopy, bool) + itkGetMacro( DeepCopy, bool) + + itkGetMacro( Weights, vnl_vector) + itkGetMacro( FittedImage, PeakImgType::Pointer) + itkGetMacro( ResidualImage, PeakImgType::Pointer) + itkGetMacro( OverexplainedImage, PeakImgType::Pointer) + itkGetMacro( UnderexplainedImage, PeakImgType::Pointer) + itkGetMacro( Coverage, double) + itkGetMacro( Overshoot, double) + itkGetMacro( MeanWeight, double) + itkGetMacro( MedianWeight, double) + itkGetMacro( MinWeight, double) + itkGetMacro( MaxWeight, double) + + void SetTractograms(const std::vector &tractograms); + + void GenerateData() override; + + std::vector GetTractograms() const; + +protected: + + FitFibersToImageFilter(); + virtual ~FitFibersToImageFilter(); + + vnl_vector_fixed GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer m_PeakImage , vnl_vector_fixed fiber_dir, int& id, double& w ); + + std::vector< mitk::FiberBundle::Pointer > m_Tractograms; + PeakImgType::Pointer m_PeakImage; + bool m_FitIndividualFibers; + double m_GradientTolerance; + double m_Lambda; + int m_MaxIterations; + float m_FiberSampling; + double m_Coverage; + double m_Overshoot; + bool m_FilterOutliers; + double m_MeanWeight; + double m_MedianWeight; + double m_MinWeight; + double m_MaxWeight; + bool m_Verbose; + bool m_DeepCopy; + + // output + vnl_vector m_Weights; + PeakImgType::Pointer m_UnderexplainedImage; + PeakImgType::Pointer m_OverexplainedImage; + PeakImgType::Pointer m_ResidualImage; + PeakImgType::Pointer m_FittedImage; +}; + +} + + +class VnlCostFunction : public vnl_cost_function +{ +public: + + vnl_sparse_matrix_linear_system< double >* S; + vnl_sparse_matrix< double > m_A; + vnl_sparse_matrix< double > m_A_Ones; // matrix indicating active weights with 1 + vnl_vector< double > m_b; + double m_Lambda; // regularization factor + + vnl_vector row_sums; // number of active weights per row + vnl_vector local_weight_means; // mean weight of each row + + void SetProblem(vnl_sparse_matrix< double >& A, vnl_vector& b, double lambda) + { + S = new vnl_sparse_matrix_linear_system(A, b); + m_A = A; + m_b = b; + m_Lambda = lambda; + + m_A_Ones.set_size(m_A.rows(), m_A.cols()); + m_A.reset(); + while (m_A.next()) + m_A_Ones.put(m_A.getrow(), m_A.getcolumn(), 1); + + unsigned int N = m_b.size(); + vnl_vector ones; ones.set_size(dim); ones.fill(1.0); + row_sums.set_size(N); + m_A_Ones.mult(ones, row_sums); + local_weight_means.set_size(N); + } + + VnlCostFunction(const int NumVars) : vnl_cost_function(NumVars) + { + } + + void regu_MSE(vnl_vector const &x, double& cost) + { + double mean = x.mean(); + vnl_vector tx = x-mean; + cost += m_Lambda*1e8*tx.squared_magnitude()/x.size(); + } + + void regu_MSM(vnl_vector const &x, double& cost) + { + cost += m_Lambda*1e8*x.squared_magnitude()/x.size(); + } + + void regu_localMSE(vnl_vector const &x, double& cost) + { + m_A_Ones.mult(x, local_weight_means); + local_weight_means = element_quotient(local_weight_means, row_sums); + + m_A_Ones.reset(); + double regu = 0; + while (m_A_Ones.next()) + { + double d = 0; + if (x[m_A_Ones.getcolumn()]>local_weight_means[m_A_Ones.getrow()]) + d = std::exp(x[m_A_Ones.getcolumn()]) - std::exp(local_weight_means[m_A_Ones.getrow()]); + else + d = x[m_A_Ones.getcolumn()] - local_weight_means[m_A_Ones.getrow()]; + regu += d*d; + } + cost += m_Lambda*regu/dim; + } + + void grad_regu_MSE(vnl_vector const &x, vnl_vector &dx) + { + double mean = x.mean(); + vnl_vector tx = x-mean; + + vnl_vector tx2(dim, 0.0); + vnl_vector h(dim, 1.0); + for (int c=0; c const &x, vnl_vector &dx) + { + dx += m_Lambda*1e8*2.0*x/dim; + } + + void grad_regu_localMSE(vnl_vector const &x, vnl_vector &dx) + { + m_A_Ones.mult(x, local_weight_means); + local_weight_means = element_quotient(local_weight_means, row_sums); + + vnl_vector exp_x = x.apply(std::exp); + vnl_vector exp_means = local_weight_means.apply(std::exp); + + vnl_vector tdx(dim, 0); + m_A_Ones.reset(); + while (m_A_Ones.next()) + { + int c = m_A_Ones.getcolumn(); + int r = m_A_Ones.getrow(); + + if (x[c]>local_weight_means[r]) + tdx[c] += exp_x[c] * ( exp_x[c] - exp_means[r] ); + else + tdx[c] += x[c] - local_weight_means[r]; + } + dx += tdx*2.0*m_Lambda/dim; + } + + + double f(vnl_vector const &x) + { + double cost = S->get_rms_error(x); + cost *= cost; + + regu_localMSE(x, cost); + + return cost; + } + + void gradf(vnl_vector const &x, vnl_vector &dx) + { + dx.fill(0.0); + unsigned int N = m_b.size(); + + vnl_vector d; d.set_size(N); + S->multiply(x,d); + d -= m_b; + + S->transpose_multiply(d, dx); + dx *= 2.0/N; + + grad_regu_localMSE(x,dx); + } +}; + +#ifndef ITK_MANUAL_INSTANTIATION +#include "itkFitFibersToImageFilter.cpp" +#endif + +#endif // __itkFitFibersToImageFilter_h__ diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/CMakeLists.txt b/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/CMakeLists.txt index e65098bb44..b636b424e6 100755 --- a/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/CMakeLists.txt +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/CMakeLists.txt @@ -1,43 +1,44 @@ option(BUILD_DiffusionFiberProcessingCmdApps "Build commandline tools for diffusion fiber processing" OFF) if(BUILD_DiffusionFiberProcessingCmdApps OR MITK_BUILD_ALL_APPS) # needed include directories include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ) # list of diffusion cmdapps # if an app requires additional dependencies # they are added after a "^^" and separated by "_" set( diffusionFiberProcessingcmdapps TractDensity^^MitkFiberTracking Sift2WeightCopy^^MitkFiberTracking FiberExtraction^^MitkFiberTracking FiberProcessing^^MitkFiberTracking + FitFibersToImage^^MitkFiberTracking FiberDirectionExtraction^^MitkFiberTracking FiberJoin^^MitkFiberTracking ) foreach(diffusionFiberProcessingcmdapp ${diffusionFiberProcessingcmdapps}) # extract cmd app name and dependencies string(REPLACE "^^" "\\;" cmdapp_info ${diffusionFiberProcessingcmdapp}) set(cmdapp_info_list ${cmdapp_info}) list(GET cmdapp_info_list 0 appname) list(GET cmdapp_info_list 1 raw_dependencies) string(REPLACE "_" "\\;" dependencies "${raw_dependencies}") set(dependencies_list ${dependencies}) mitkFunctionCreateCommandLineApp( NAME ${appname} DEPENDS MitkCore MitkDiffusionCore ${dependencies_list} PACKAGE_DEPENDS ITK ) endforeach() if(EXECUTABLE_IS_ENABLED) MITK_INSTALL_TARGETS(EXECUTABLES ${EXECUTABLE_TARGET}) endif() endif() diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/FitFibersToImage.cpp b/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/FitFibersToImage.cpp new file mode 100755 index 0000000000..a8a0539732 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/FiberProcessing/FitFibersToImage.cpp @@ -0,0 +1,203 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +typedef itksys::SystemTools ist; +typedef itk::Point PointType4; +typedef itk::Image< float, 4 > PeakImgType; + +std::vector< string > get_file_list(const std::string& path) +{ + std::vector< string > file_list; + itk::Directory::Pointer dir = itk::Directory::New(); + + if (dir->Load(path.c_str())) + { + int n = dir->GetNumberOfFiles(); + for (int r = 0; r < n; r++) + { + const char *filename = dir->GetFile(r); + std::string ext = ist::GetFilenameExtension(filename); + if (ext==".fib" || ext==".trk") + file_list.push_back(path + '/' + filename); + } + } + return file_list; +} + +/*! +\brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). +*/ +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + + parser.setTitle("Fit Fibers To Image"); + parser.setCategory("Fiber Tracking and Processing Methods"); + parser.setDescription("Assigns a weight to each fiber in order to optimally explain the input peak image"); + parser.setContributor("MIC"); + + parser.setArgumentPrefix("--", "-"); + parser.addArgument("", "i1", mitkCommandLineParser::StringList, "Input tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); + parser.addArgument("", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); + parser.addArgument("", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); + + parser.addArgument("max_iter", "", mitkCommandLineParser::Int, "Max. iterations:", "maximum number of optimizer iterations", 20); + parser.addArgument("bundle_based", "", mitkCommandLineParser::Bool, "Bundle based fit:", "fit one weight per input tractogram/bundle, not for each fiber", false); + parser.addArgument("min_g", "", mitkCommandLineParser::Float, "Min. g:", "lower termination threshold for gradient magnitude", 1e-5); + parser.addArgument("lambda", "", mitkCommandLineParser::Float, "Lambda:", "modifier for regularization", 0.1); + parser.addArgument("save_res", "", mitkCommandLineParser::Bool, "Residuals:", "save residual images", false); + parser.addArgument("filter_outliers", "", mitkCommandLineParser::Bool, "Filter outliers:", "perform second optimization run with an upper weight bound based on the first weight estimation (95% quantile)", true); + + map parsedArgs = parser.parseArguments(argc, argv); + if (parsedArgs.size()==0) + return EXIT_FAILURE; + + mitkCommandLineParser::StringContainerType fib_files = us::any_cast(parsedArgs["i1"]); + string peak_file_name = us::any_cast(parsedArgs["i2"]); + string outRoot = us::any_cast(parsedArgs["o"]); + + bool single_fib = true; + if (parsedArgs.count("bundle_based")) + single_fib = !us::any_cast(parsedArgs["bundle_based"]); + + bool save_residuals = false; + if (parsedArgs.count("save_res")) + save_residuals = us::any_cast(parsedArgs["save_res"]); + + int max_iter = 20; + if (parsedArgs.count("max_iter")) + max_iter = us::any_cast(parsedArgs["max_iter"]); + + float g_tol = 1e-5; + if (parsedArgs.count("min_g")) + g_tol = us::any_cast(parsedArgs["min_g"]); + + float lambda = 0.1; + if (parsedArgs.count("lambda")) + lambda = us::any_cast(parsedArgs["lambda"]); + + bool filter_outliers = true; + if (parsedArgs.count("filter_outliers")) + filter_outliers = us::any_cast(parsedArgs["filter_outliers"]); + + try + { + std::vector< mitk::FiberBundle::Pointer > input_tracts; + + mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Peak Image", "Fiberbundles"}, {}); + mitk::Image::Pointer inputImage = dynamic_cast(mitk::IOUtil::Load(peak_file_name, &functor)[0].GetPointer()); + + typedef mitk::ImageToItk< PeakImgType > CasterType; + CasterType::Pointer caster = CasterType::New(); + caster->SetInput(inputImage); + caster->Update(); + PeakImgType::Pointer peak_image = caster->GetOutput(); + + std::vector< std::string > fib_names; + for (auto item : fib_files) + { + if ( ist::FileIsDirectory(item) ) + { + for ( auto fibFile : get_file_list(item) ) + { + mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::Load(fibFile)[0].GetPointer()); + if (inputTractogram.IsNull()) + continue; + input_tracts.push_back(inputTractogram); + fib_names.push_back(fibFile); + } + } + else + { + mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::Load(item)[0].GetPointer()); + if (inputTractogram.IsNull()) + continue; + input_tracts.push_back(inputTractogram); + fib_names.push_back(item); + } + } + + itk::FitFibersToImageFilter::Pointer fitter = itk::FitFibersToImageFilter::New(); + fitter->SetPeakImage(peak_image); + fitter->SetTractograms(input_tracts); + fitter->SetFitIndividualFibers(single_fib); + fitter->SetMaxIterations(max_iter); + fitter->SetGradientTolerance(g_tol); + fitter->SetLambda(lambda); + fitter->SetFilterOutliers(filter_outliers); + fitter->Update(); + + if (save_residuals) + { + itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); + writer->SetInput(fitter->GetFittedImage()); + writer->SetFileName(outRoot + "fitted_image.nrrd"); + writer->Update(); + + writer->SetInput(fitter->GetResidualImage()); + writer->SetFileName(outRoot + "residual_image.nrrd"); + writer->Update(); + + writer->SetInput(fitter->GetOverexplainedImage()); + writer->SetFileName(outRoot + "overexplained_image.nrrd"); + writer->Update(); + + writer->SetInput(fitter->GetUnderexplainedImage()); + writer->SetFileName(outRoot + "underexplained_image.nrrd"); + writer->Update(); + } + + std::vector< mitk::FiberBundle::Pointer > output_tracts = fitter->GetTractograms(); + for (unsigned int bundle=0; bundle -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include - -using namespace std; -typedef itksys::SystemTools ist; -typedef itk::Point PointType4; -typedef itk::Image< float, 4 > PeakImgType; - -vnl_vector_fixed GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed fiber_dir, int& id, double& w ) -{ - int m_NumDirs = peak_image->GetLargestPossibleRegion().GetSize()[3]/3; - vnl_vector_fixed out_dir; out_dir.fill(0); - float angle = 0.8; - - for (int i=0; i dir; - idx[3] = i*3; - dir[0] = peak_image->GetPixel(idx); - idx[3] += 1; - dir[1] = peak_image->GetPixel(idx); - idx[3] += 1; - dir[2] = peak_image->GetPixel(idx); - - float mag = dir.magnitude(); - if (magangle) - { - angle = fabs(a); - w = angle; - if (a<0) - out_dir = -dir; - else - out_dir = dir; - out_dir *= mag; - id = i; - } - } - - return out_dir; -} - -class VnlCostFunction : public vnl_cost_function -{ -public: - - vnl_sparse_matrix_linear_system< double >* S; - vnl_sparse_matrix< double > m_A; - vnl_sparse_matrix< double > m_A_Ones; // matrix indicating active weights with 1 - vnl_vector< double > m_b; - double m_Lambda; // regularization factor - - vnl_vector row_sums; // number of active weights per row - vnl_vector local_weight_means; // mean weight of each row - - void SetProblem(vnl_sparse_matrix< double >& A, vnl_vector& b, double lambda) - { - S = new vnl_sparse_matrix_linear_system(A, b); - m_A = A; - m_b = b; - m_Lambda = lambda; - - m_A_Ones.set_size(m_A.rows(), m_A.cols()); - m_A.reset(); - while (m_A.next()) - m_A_Ones.put(m_A.getrow(), m_A.getcolumn(), 1); - - unsigned int N = m_b.size(); - vnl_vector ones; ones.set_size(dim); ones.fill(1.0); - row_sums.set_size(N); - m_A_Ones.mult(ones, row_sums); - local_weight_means.set_size(N); - } - - VnlCostFunction(const int NumVars) : vnl_cost_function(NumVars) - { - } - - void regu_MSE(vnl_vector const &x, double& cost) - { - double mean = x.mean(); - vnl_vector tx = x-mean; - cost += m_Lambda*1e8*tx.squared_magnitude()/x.size(); - } - - void regu_MSM(vnl_vector const &x, double& cost) - { - cost += m_Lambda*1e8*x.squared_magnitude()/x.size(); - } - - void regu_localMSE(vnl_vector const &x, double& cost) - { - m_A_Ones.mult(x, local_weight_means); - local_weight_means = element_quotient(local_weight_means, row_sums); - - m_A_Ones.reset(); - unsigned int num_elements = 0; - double regu = 0; - while (m_A_Ones.next()) - { - double d = 0; - if (x[m_A_Ones.getcolumn()]>local_weight_means[m_A_Ones.getrow()]) - d = std::exp(x[m_A_Ones.getcolumn()]) - std::exp(local_weight_means[m_A_Ones.getrow()]); - else - d = x[m_A_Ones.getcolumn()] - local_weight_means[m_A_Ones.getrow()]; - regu += d*d; - ++num_elements; - } - cost += m_Lambda*1e3*regu/num_elements; - } - - void grad_regu_MSE(vnl_vector const &x, vnl_vector &dx) - { - double mean = x.mean(); - vnl_vector tx = x-mean; - - vnl_vector tx2(dim, 0.0); - vnl_vector h(dim, 1.0); - for (int c=0; c const &x, vnl_vector &dx) - { - dx += m_Lambda*1e8*2.0*x/dim; - } - - void grad_regu_localMSE(vnl_vector const &x, vnl_vector &dx) - { - m_A_Ones.mult(x, local_weight_means); - local_weight_means = element_quotient(local_weight_means, row_sums); - - vnl_vector exp_x = x.apply(std::exp); - vnl_vector exp_means = local_weight_means.apply(std::exp); - - vnl_vector tdx(dim, 0); - m_A_Ones.reset(); - while (m_A_Ones.next()) - { - int c = m_A_Ones.getcolumn(); - int r = m_A_Ones.getrow(); - if (row_sums[r]==0) - continue; - - if (x[c]>local_weight_means[r]) - tdx[c] += (exp_x[c] * ( exp_x[c] - exp_means[r] ))/row_sums[r]; - else - tdx[c] += (x[c] - local_weight_means[r])/row_sums[r]; - } - dx += tdx*1e3*2.0*m_Lambda; - - // vnl_vector dr; dr.set_size(dim); dr.fill(0); - // for (unsigned int r=0; r weights; weights.set_size(n); - // vnl_matrix temp(n,n,1); temp.fill_diagonal(n-1); - - // int i=0; - // for (auto w : m_A_Ones.get_row(r)) - // { - // weights[i]=w.second; - // ++i; - // } - - // weights -= local_weight_means[r]; - // weights = temp*weights; - - // i=0; - // for (auto w : m_A_Ones.get_row(r)) - // { - // dr[w.second] += weights[i]; - // ++i; - // } - // } - - // dx += dr*2.0*m_Lambda; - } - - - double f(vnl_vector const &x) - { - double cost = S->get_rms_error(x); - cost *= cost; - - // cost for e^x - // vnl_vector x_exp; x_exp.set_size(x.size()); - // for (unsigned int c=0; cget_rms_error(x_exp); - // cost *= cost; - - regu_localMSE(x, cost); - // regu_MSM(x, cost); - - return cost; - } - - void gradf(vnl_vector const &x, vnl_vector &dx) - { - dx.fill(0.0); - unsigned int N = m_b.size(); - - // vnl_vector x_exp; x_exp.set_size(x.size()); - // for (unsigned int c=0; c d; d.set_size(N); - S->multiply(x,d); - d -= m_b; - - S->transpose_multiply(d, dx); - dx *= 2.0/N; - - // for (unsigned int c=0; c FitFibers( std::string , std::vector< mitk::FiberBundle::Pointer > input_tracts, mitk::Image::Pointer inputImage, vnl_sparse_matrix< double >& A, vnl_vector& b, bool single_fiber_fit, int max_iter, float g_tol, float lambda ) -{ - typedef mitk::ImageToItk< PeakImgType > CasterType; - CasterType::Pointer caster = CasterType::New(); - caster->SetInput(inputImage); - caster->Update(); - PeakImgType::Pointer itkImage = caster->GetOutput(); - - unsigned int* image_size = inputImage->GetDimensions(); - int sz_x = image_size[0]; - int sz_y = image_size[1]; - int sz_z = image_size[2]; - int sz_peaks = image_size[3]/3 + 1; // +1 for zero - peak - int num_voxels = sz_x*sz_y*sz_z; - - unsigned int num_unknowns = input_tracts.size(); - if (single_fiber_fit) - { - num_unknowns = 0; - for (unsigned int bundle=0; bundleGetNumFibers(); - } - - unsigned int number_of_residuals = num_voxels * sz_peaks; - - // create linear system - MITK_INFO << "Num. unknowns: " << num_unknowns; - MITK_INFO << "Num. residuals: " << number_of_residuals; - - MITK_INFO << "Creating system ..."; - A.set_size(number_of_residuals, num_unknowns); - b.set_size(number_of_residuals); b.fill(0.0); - - double TD = 0; - double FD = 0; - unsigned int dir_count = 0; - unsigned int fiber_count = 0; - - for (unsigned int bundle=0; bundle polydata = input_tracts.at(bundle)->GetFiberPolyData(); - - for (int i=0; iGetNumFibers(); ++i) - { - vtkCell* cell = polydata->GetCell(i); - int numPoints = cell->GetNumberOfPoints(); - vtkPoints* points = cell->GetPoints(); - - if (numPoints<2) - MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; - - for (int j=0; jGetPoint(j); - PointType4 p; - p[0]=p1[0]; - p[1]=p1[1]; - p[2]=p1[2]; - p[3]=0; - - itk::Index<4> idx4; - itkImage->TransformPhysicalPointToIndex(p, idx4); - if (!itkImage->GetLargestPossibleRegion().IsInside(idx4)) - continue; - - double* p2 = points->GetPoint(j+1); - vnl_vector_fixed fiber_dir; - fiber_dir[0] = p[0]-p2[0]; - fiber_dir[1] = p[1]-p2[1]; - fiber_dir[2] = p[2]-p2[2]; - fiber_dir.normalize(); - - double w = 1; - int peak_id = sz_peaks-1; - vnl_vector_fixed odf_peak = GetClosestPeak(idx4, itkImage, fiber_dir, peak_id, w); - float peak_mag = odf_peak.magnitude(); - - int x = idx4[0]; - int y = idx4[1]; - int z = idx4[2]; - - unsigned int linear_index = x + sz_x*y + sz_x*sz_y*z + sz_x*sz_y*sz_z*peak_id; - - if (b[linear_index] == 0 && peak_id<3) - { - dir_count++; - FD += peak_mag; - } - TD += w; - - if (single_fiber_fit) - { - b[linear_index] = peak_mag; - A.put(linear_index, fiber_count, A.get(linear_index, fiber_count) + w); - } - else - { - b[linear_index] = peak_mag; - A.put(linear_index, bundle, A.get(linear_index, bundle) + w); - } - } - - ++fiber_count; - } - } - - TD /= (dir_count*fiber_count); - FD /= dir_count; - - A /= TD; - b *= 100.0/FD; // times 100 because we want to avoid too small values for computational reasons - -// MITK_INFO << "TD: " << TD; -// MITK_INFO << "FD: " << FD; -// MITK_INFO << "Regularization: " << lambda; - - itk::TimeProbe clock; - clock.Start(); - - MITK_INFO << "Fitting fibers"; - VnlCostFunction cost(num_unknowns); - cost.SetProblem(A, b, lambda); - - vnl_vector x; x.set_size(num_unknowns); x.fill( TD/100.0 * FD/2.0 ); - - vnl_lbfgsb minimizer(cost); - vnl_vector l; l.set_size(num_unknowns); l.fill(0); - - vnl_vector bound_selection; bound_selection.set_size(num_unknowns); bound_selection.fill(1); - minimizer.set_bound_selection(bound_selection); - minimizer.set_lower_bound(l); - minimizer.set_trace(true); - minimizer.set_projected_gradient_tolerance(g_tol); - if (max_iter>0) - minimizer.set_max_function_evals(max_iter); - minimizer.minimize(x); - - // SECOND RUN - std::vector< double > weights; - for (auto w : x) - weights.push_back(w); - sort(weights.begin(), weights.end()); - MITK_INFO << "Setting upper weight bound to " << weights.at(num_unknowns*0.95); - vnl_vector u; u.set_size(num_unknowns); u.fill(weights.at(num_unknowns*0.95)); - minimizer.set_upper_bound(u); - bound_selection.fill(2); - minimizer.set_bound_selection(bound_selection); - minimizer.minimize(x); - - weights.clear(); - for (auto w : x) - weights.push_back(w); - sort(weights.begin(), weights.end()); - - MITK_INFO << "*************************"; - MITK_INFO << "Weight statistics"; - MITK_INFO << "Mean: " << x.mean(); - MITK_INFO << "Median: " << weights.at(num_unknowns*0.5); - MITK_INFO << "75% quantile: " << weights.at(num_unknowns*0.75); - MITK_INFO << "95% quantile: " << weights.at(num_unknowns*0.95); - MITK_INFO << "99% quantile: " << weights.at(num_unknowns*0.99); - MITK_INFO << "Min: " << weights.at(0); - MITK_INFO << "Max: " << weights.at(num_unknowns-1); - MITK_INFO << "*************************"; - MITK_INFO << "NumEvals: " << minimizer.get_num_evaluations(); - MITK_INFO << "NumIterations: " << minimizer.get_num_iterations(); - MITK_INFO << "Residual cost: " << minimizer.get_end_error(); - MITK_INFO << "Final RMS: " << cost.S->get_rms_error(x); - - clock.Stop(); - int h = clock.GetTotal()/3600; - int m = ((int)clock.GetTotal()%3600)/60; - int s = (int)clock.GetTotal()%60; - MITK_INFO << "Optimization took " << h << "h, " << m << "m and " << s << "s"; - - // transform back for peak image creation - A *= FD/100.0; - b *= FD/100.0; - - return x; -} - -std::vector< string > get_file_list(const std::string& path) -{ - std::vector< string > file_list; - itk::Directory::Pointer dir = itk::Directory::New(); - - if (dir->Load(path.c_str())) - { - int n = dir->GetNumberOfFiles(); - for (int r = 0; r < n; r++) - { - const char *filename = dir->GetFile(r); - std::string ext = ist::GetFilenameExtension(filename); - if (ext==".fib" || ext==".trk") - file_list.push_back(path + '/' + filename); - } - } - return file_list; -} - -/*! -\brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). -*/ -int main(int argc, char* argv[]) -{ - mitkCommandLineParser parser; - - parser.setTitle("Fit Fibers To Image"); - parser.setCategory("Fiber Tracking Evaluation"); - parser.setDescription(""); - parser.setContributor("MIC"); - - parser.setArgumentPrefix("--", "-"); - parser.addArgument("", "i1", mitkCommandLineParser::StringList, "Input tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); - parser.addArgument("", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); - parser.addArgument("", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); - - parser.addArgument("max_iter", "", mitkCommandLineParser::Int, "Max. iterations:", "maximum number of optimizer iterations", 20); - parser.addArgument("bundle_based", "", mitkCommandLineParser::Bool, "Bundle based fit:", "fit one weight per input tractogram/bundle, not for each fiber", false); - parser.addArgument("min_g", "", mitkCommandLineParser::Float, "Min. g:", "lower termination threshold for gradient magnitude", 1e-5); - parser.addArgument("lambda", "", mitkCommandLineParser::Float, "Lambda:", "weighting factor for regularization", 1.0); - parser.addArgument("save_res", "", mitkCommandLineParser::Bool, "Residuals:", "save residual images", false); - - map parsedArgs = parser.parseArguments(argc, argv); - if (parsedArgs.size()==0) - return EXIT_FAILURE; - - mitkCommandLineParser::StringContainerType fib_files = us::any_cast(parsedArgs["i1"]); - string peak_file_name = us::any_cast(parsedArgs["i2"]); - string outRoot = us::any_cast(parsedArgs["o"]); - - bool single_fib = true; - if (parsedArgs.count("bundle_based")) - single_fib = !us::any_cast(parsedArgs["bundle_based"]); - - bool save_residuals = false; - if (parsedArgs.count("save_res")) - save_residuals = us::any_cast(parsedArgs["residuals"]); - - int max_iter = 20; - if (parsedArgs.count("it")) - max_iter = us::any_cast(parsedArgs["it"]); - - float g_tol = 1e-5; - if (parsedArgs.count("min_g")) - g_tol = us::any_cast(parsedArgs["min_g"]); - - float lambda = 1.0; - if (parsedArgs.count("lambda")) - lambda = us::any_cast(parsedArgs["lambda"]); - - try - { - std::vector< mitk::FiberBundle::Pointer > input_tracts; - - mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Peak Image", "Fiberbundles"}, {}); - mitk::Image::Pointer inputImage = dynamic_cast(mitk::IOUtil::Load(peak_file_name, &functor)[0].GetPointer()); - - float minSpacing = 1; - if(inputImage->GetGeometry()->GetSpacing()[0]GetGeometry()->GetSpacing()[1] && inputImage->GetGeometry()->GetSpacing()[0]GetGeometry()->GetSpacing()[2]) - minSpacing = inputImage->GetGeometry()->GetSpacing()[0]; - else if (inputImage->GetGeometry()->GetSpacing()[1] < inputImage->GetGeometry()->GetSpacing()[2]) - minSpacing = inputImage->GetGeometry()->GetSpacing()[1]; - else - minSpacing = inputImage->GetGeometry()->GetSpacing()[2]; - - std::vector< std::string > fib_names; - for (auto item : fib_files) - { - if ( ist::FileIsDirectory(item) ) - { - for ( auto fibFile : get_file_list(item) ) - { - mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::Load(fibFile)[0].GetPointer()); - if (inputTractogram.IsNull()) - continue; - inputTractogram->ResampleLinear(minSpacing/10); - input_tracts.push_back(inputTractogram); - fib_names.push_back(fibFile); - } - } - else - { - mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::Load(item)[0].GetPointer()); - if (inputTractogram.IsNull()) - continue; - inputTractogram->ResampleLinear(minSpacing/10); - input_tracts.push_back(inputTractogram); - fib_names.push_back(item); - } - } - - vnl_sparse_matrix A; - vnl_vector b; - vnl_vector x = FitFibers(outRoot, input_tracts, inputImage, A, b, single_fib, max_iter, g_tol, lambda); - - MITK_INFO << "Weighting fibers"; - if (single_fib) - { - unsigned int fiber_count = 0; - for (unsigned int bundle=0; bundleGetNumFibers(); i++) - { - input_tracts.at(bundle)->SetFiberWeight(i, x[fiber_count]); - ++fiber_count; - } - } - } - else - { - for (unsigned int i=0; iSetFiberWeights(x[i]); - } - - if (save_residuals) - { - // OUTPUT IMAGES - MITK_INFO << "Generating output images ..."; - typedef mitk::ImageToItk< PeakImgType > CasterType; - CasterType::Pointer caster = CasterType::New(); - caster->SetInput(inputImage); - caster->Update(); - PeakImgType::Pointer peak_image = caster->GetOutput(); - - itk::ImageDuplicator< PeakImgType >::Pointer duplicator = itk::ImageDuplicator< PeakImgType >::New(); - duplicator->SetInputImage(peak_image); - duplicator->Update(); - PeakImgType::Pointer underexplained_image = duplicator->GetOutput(); - underexplained_image->FillBuffer(0.0); - - duplicator->SetInputImage(underexplained_image); - duplicator->Update(); - PeakImgType::Pointer overexplained_image = duplicator->GetOutput(); - overexplained_image->FillBuffer(0.0); - - duplicator->SetInputImage(overexplained_image); - duplicator->Update(); - PeakImgType::Pointer residual_image = duplicator->GetOutput(); - residual_image->FillBuffer(0.0); - - duplicator->SetInputImage(residual_image); - duplicator->Update(); - PeakImgType::Pointer fitted_image = duplicator->GetOutput(); - fitted_image->FillBuffer(0.0); - - vnl_sparse_matrix_linear_system S(A, b); - vnl_vector fitted_b; fitted_b.set_size(b.size()); - S.multiply(x, fitted_b); - - unsigned int* image_size = inputImage->GetDimensions(); - int sz_x = image_size[0]; - int sz_y = image_size[1]; - int sz_z = image_size[2]; - int sz_peaks = image_size[3]/3 + 1; // +1 for zero - peak - for (unsigned int r=0; r idx; - unsigned int linear_index = r; - idx[0] = linear_index % sz_x; linear_index /= sz_x; - idx[1] = linear_index % sz_y; linear_index /= sz_y; - idx[2] = linear_index % sz_z; linear_index /= sz_z; - int peak_id = linear_index % sz_peaks; - - if (peak_id peak_dir; - - idx[3] = peak_id*3; - peak_dir[0] = peak_image->GetPixel(idx); - idx[3] += 1; - peak_dir[1] = peak_image->GetPixel(idx); - idx[3] += 1; - peak_dir[2] = peak_image->GetPixel(idx); - - peak_dir.normalize(); - peak_dir *= fitted_b[r]; - - idx[3] = peak_id*3; - fitted_image->SetPixel(idx, peak_dir[0]); - - idx[3] += 1; - fitted_image->SetPixel(idx, peak_dir[1]); - - idx[3] += 1; - fitted_image->SetPixel(idx, peak_dir[2]); - } - } - - itk::Index<4> idx; - for (idx[0]=0; idx[0] peak_dir; - vnl_vector_fixed fitted_dir; - for (idx[3]=0; idx[3]GetPixel(idx); - fitted_dir[idx[3]%3] = fitted_image->GetPixel(idx); - residual_image->SetPixel(idx, peak_image->GetPixel(idx) - fitted_image->GetPixel(idx)); - - if (idx[3]%3==2) - { - itk::Index<4> tidx= idx; - if (peak_dir.magnitude()>fitted_dir.magnitude()) - { - underexplained_image->SetPixel(tidx, peak_dir[2]-fitted_dir[2]); tidx[3]--; - underexplained_image->SetPixel(tidx, peak_dir[1]-fitted_dir[1]); tidx[3]--; - underexplained_image->SetPixel(tidx, peak_dir[0]-fitted_dir[0]); - } - else - { - overexplained_image->SetPixel(tidx, fitted_dir[2]-peak_dir[2]); tidx[3]--; - overexplained_image->SetPixel(tidx, fitted_dir[1]-peak_dir[1]); tidx[3]--; - overexplained_image->SetPixel(tidx, fitted_dir[0]-peak_dir[0]); - } - } - } - } - - itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); - writer->SetInput(fitted_image); - writer->SetFileName(outRoot + "fitted_image.nrrd"); - writer->Update(); - - writer->SetInput(residual_image); - writer->SetFileName(outRoot + "residual_image.nrrd"); - writer->Update(); - - writer->SetInput(overexplained_image); - writer->SetFileName(outRoot + "overexplained_image.nrrd"); - writer->Update(); - - writer->SetInput(underexplained_image); - writer->SetFileName(outRoot + "underexplained_image.nrrd"); - writer->Update(); - } - - for (unsigned int bundle=0; bundleCompress(0.1); - std::string name = fib_names.at(bundle); - name = ist::GetFilenameWithoutExtension(name); - mitk::IOUtil::Save(input_tracts.at(bundle), outRoot + name + "_fitted.fib"); - } - } - catch (itk::ExceptionObject e) - { - std::cout << e; - return EXIT_FAILURE; - } - catch (std::exception e) - { - std::cout << e.what(); - return EXIT_FAILURE; - } - catch (...) - { - std::cout << "ERROR!?!"; - return EXIT_FAILURE; - } - return EXIT_SUCCESS; -} diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp new file mode 100755 index 0000000000..b1eecb53d6 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp @@ -0,0 +1,271 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +using namespace std; +typedef itksys::SystemTools ist; +typedef itk::Point PointType4; +typedef itk::Image< float, 4 > PeakImgType; + +std::vector< string > get_file_list(const std::string& path) +{ + std::vector< string > file_list; + itk::Directory::Pointer dir = itk::Directory::New(); + + if (dir->Load(path.c_str())) + { + int n = dir->GetNumberOfFiles(); + for (int r = 0; r < n; r++) + { + const char *filename = dir->GetFile(r); + std::string ext = ist::GetFilenameExtension(filename); + if (ext==".fib" || ext==".trk") + file_list.push_back(path + '/' + filename); + } + } + return file_list; +} + +/*! +\brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). +*/ +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + + parser.setTitle(""); + parser.setCategory("Fiber Tracking Evaluation"); + parser.setDescription(""); + parser.setContributor("MIC"); + + parser.setArgumentPrefix("--", "-"); + parser.addArgument("", "i1", mitkCommandLineParser::InputFile, "Input tractogram:", "input tractogram (.fib, vtk ascii file format)", us::Any(), false); + parser.addArgument("", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); + parser.addArgument("", "i3", mitkCommandLineParser::InputFile, "", "", us::Any(), false); + parser.addArgument("min_gain", "", mitkCommandLineParser::Float, "Min. gain:", "process stops if remaining bundles don't contribute enough", 0.05); + + parser.addArgument("", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); + + parser.addArgument("max_iter", "", mitkCommandLineParser::Int, "Max. iterations:", "maximum number of optimizer iterations", 20); + parser.addArgument("bundle_based", "", mitkCommandLineParser::Bool, "Bundle based fit:", "fit one weight per input tractogram/bundle, not for each fiber", false); + parser.addArgument("min_g", "", mitkCommandLineParser::Float, "Min. gradient:", "lower termination threshold for gradient magnitude", 1e-5); + parser.addArgument("lambda", "", mitkCommandLineParser::Float, "Lambda:", "modifier for regularization", 0.1); + parser.addArgument("filter_outliers", "", mitkCommandLineParser::Bool, "Filter outliers:", "perform second optimization run with an upper weight bound based on the first weight estimation (95% quantile)", true); + + map parsedArgs = parser.parseArguments(argc, argv); + if (parsedArgs.size()==0) + return EXIT_FAILURE; + + string fib_file = us::any_cast(parsedArgs["i1"]); + string peak_file_name = us::any_cast(parsedArgs["i2"]); + string candidate_folder = us::any_cast(parsedArgs["i3"]); + string outRoot = us::any_cast(parsedArgs["o"]); + + bool single_fib = true; + if (parsedArgs.count("bundle_based")) + single_fib = !us::any_cast(parsedArgs["bundle_based"]); + + int max_iter = 20; + if (parsedArgs.count("max_iter")) + max_iter = us::any_cast(parsedArgs["max_iter"]); + + float g_tol = 1e-5; + if (parsedArgs.count("min_g")) + g_tol = us::any_cast(parsedArgs["min_g"]); + + float min_gain = 0.05; + if (parsedArgs.count("min_gain")) + min_gain = us::any_cast(parsedArgs["min_gain"]); + + float lambda = 0.1; + if (parsedArgs.count("lambda")) + lambda = us::any_cast(parsedArgs["lambda"]); + + bool filter_outliers = true; + if (parsedArgs.count("filter_outliers")) + filter_outliers = us::any_cast(parsedArgs["filter_outliers"]); + + try + { + + mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Peak Image", "Fiberbundles"}, {}); + mitk::Image::Pointer inputImage = dynamic_cast(mitk::IOUtil::Load(peak_file_name, &functor)[0].GetPointer()); + + typedef mitk::ImageToItk< PeakImgType > CasterType; + CasterType::Pointer caster = CasterType::New(); + caster->SetInput(inputImage); + caster->Update(); + PeakImgType::Pointer peak_image = caster->GetOutput(); + + std::vector< mitk::FiberBundle::Pointer > input_reference; + mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(fib_file)[0].GetPointer()); + if (fib.IsNull()) + return EXIT_FAILURE; + input_reference.push_back(fib); + + std::vector< mitk::FiberBundle::Pointer > input_candidates; + std::vector< string > candidate_tract_files = get_file_list(candidate_folder); + for (string f : candidate_tract_files) + { + mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(f)[0].GetPointer()); + if (fib.IsNull()) + continue; + input_candidates.push_back(fib); + } + + int iteration = 0; + std::string name = ist::GetFilenameWithoutExtension(fib_file); + + itk::FitFibersToImageFilter::Pointer fitter = itk::FitFibersToImageFilter::New(); + fitter->SetTractograms(input_reference); + fitter->SetFitIndividualFibers(single_fib); + fitter->SetMaxIterations(max_iter); + fitter->SetGradientTolerance(g_tol); + fitter->SetLambda(lambda); + fitter->SetFilterOutliers(filter_outliers); + fitter->SetPeakImage(peak_image); + fitter->SetVerbose(false); + fitter->SetDeepCopy(false); + fitter->Update(); + + + fitter->GetTractograms().at(0)->SetFiberWeights(fitter->GetCoverage()); + fitter->GetTractograms().at(0)->ColorFibersByFiberWeights(false, false); + + mitk::IOUtil::Save(fitter->GetTractograms().at(0), outRoot + "0_" + name + ".fib"); + peak_image = fitter->GetUnderexplainedImage(); + + itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); + writer->SetInput(peak_image); + writer->SetFileName(outRoot + boost::lexical_cast(iteration) + "_" + name + ".nrrd"); + writer->Update(); + + double coverage = fitter->GetCoverage(); + MITK_INFO << "Iteration: " << iteration; + MITK_INFO << std::fixed << "Coverage: " << setprecision(1) << 100.0*coverage << "%"; +// fitter->SetPeakImage(peak_image); + + while (!input_candidates.empty()) + { + streambuf *old = cout.rdbuf(); // <-- save + stringstream ss; + std::cout.rdbuf (ss.rdbuf()); // <-- redirect + + double next_coverage = 0; + mitk::FiberBundle::Pointer best_candidate = nullptr; + for (auto fib : input_candidates) + { + + // WHY NECESSARY AGAIN?? + itk::FitFibersToImageFilter::Pointer fitter = itk::FitFibersToImageFilter::New(); + fitter->SetFitIndividualFibers(single_fib); + fitter->SetMaxIterations(max_iter); + fitter->SetGradientTolerance(g_tol); + fitter->SetLambda(lambda); + fitter->SetFilterOutliers(filter_outliers); + fitter->SetVerbose(false); + fitter->SetPeakImage(peak_image); + fitter->SetDeepCopy(false); + // ****************************** + fitter->SetTractograms({fib}); + fitter->Update(); + + double candidate_coverage = fitter->GetCoverage(); + + if (candidate_coverage>next_coverage) + { + next_coverage = candidate_coverage; + if ((1.0-coverage) * next_coverage >= min_gain) + { + best_candidate = fitter->GetTractograms().at(0); + peak_image = fitter->GetUnderexplainedImage(); + } + } + } + + if (best_candidate.IsNull()) + { + std::cout.rdbuf (old); // <-- restore + break; + } + +// fitter->SetPeakImage(peak_image); + best_candidate->SetFiberWeights((1.0-coverage) * next_coverage); + best_candidate->ColorFibersByFiberWeights(false, false); + + coverage += (1.0-coverage) * next_coverage; + + int i=0; + std::vector< mitk::FiberBundle::Pointer > remaining_candidates; + std::vector< string > remaining_candidate_files; + for (auto fib : input_candidates) + { + if (fib!=best_candidate) + { + remaining_candidates.push_back(fib); + remaining_candidate_files.push_back(candidate_tract_files.at(i)); + } + else + name = ist::GetFilenameWithoutExtension(candidate_tract_files.at(i)); + ++i; + } + input_candidates = remaining_candidates; + candidate_tract_files = remaining_candidate_files; + + iteration++; + mitk::IOUtil::Save(best_candidate, outRoot + boost::lexical_cast(iteration) + "_" + name + ".fib"); + writer->SetInput(peak_image); + writer->SetFileName(outRoot + boost::lexical_cast(iteration) + "_" + name + ".nrrd"); + writer->Update(); + std::cout.rdbuf (old); // <-- restore + + MITK_INFO << "Iteration: " << iteration; + MITK_INFO << std::fixed << "Coverage: " << setprecision(1) << 100.0*coverage << "% (+" << 100*(1.0-coverage) * next_coverage << "%)"; + } + MITK_INFO << "DONE"; + } + catch (itk::ExceptionObject e) + { + std::cout << e; + return EXIT_FAILURE; + } + catch (std::exception e) + { + std::cout << e.what(); + return EXIT_FAILURE; + } + catch (...) + { + std::cout << "ERROR!?!"; + return EXIT_FAILURE; + } + return EXIT_SUCCESS; +} diff --git a/Modules/DiffusionImaging/FiberTracking/files.cmake b/Modules/DiffusionImaging/FiberTracking/files.cmake index 4649890eb3..0ae54b7a4a 100644 --- a/Modules/DiffusionImaging/FiberTracking/files.cmake +++ b/Modules/DiffusionImaging/FiberTracking/files.cmake @@ -1,91 +1,92 @@ set(CPP_FILES mitkFiberTrackingModuleActivator.cpp ## IO datastructures IODataStructures/FiberBundle/mitkFiberBundle.cpp IODataStructures/FiberBundle/mitkTrackvis.cpp IODataStructures/PlanarFigureComposite/mitkPlanarFigureComposite.cpp IODataStructures/mitkTractographyForest.cpp # Interactions # Tractography Algorithms/GibbsTracking/mitkParticleGrid.cpp Algorithms/GibbsTracking/mitkMetropolisHastingsSampler.cpp Algorithms/GibbsTracking/mitkEnergyComputer.cpp Algorithms/GibbsTracking/mitkGibbsEnergyComputer.cpp Algorithms/GibbsTracking/mitkFiberBuilder.cpp Algorithms/GibbsTracking/mitkSphereInterpolator.cpp Algorithms/itkStreamlineTrackingFilter.cpp Algorithms/TrackingHandlers/mitkTrackingDataHandler.cpp Algorithms/TrackingHandlers/mitkTrackingHandlerTensor.cpp Algorithms/TrackingHandlers/mitkTrackingHandlerPeaks.cpp Algorithms/TrackingHandlers/mitkTrackingHandlerOdf.cpp ) set(H_FILES # DataStructures -> FiberBundle IODataStructures/FiberBundle/mitkFiberBundle.h IODataStructures/FiberBundle/mitkTrackvis.h IODataStructures/mitkFiberfoxParameters.h IODataStructures/mitkTractographyForest.h # Algorithms Algorithms/itkTractDensityImageFilter.h Algorithms/itkTractsToFiberEndingsImageFilter.h Algorithms/itkTractsToRgbaImageFilter.h Algorithms/itkTractsToVectorImageFilter.h Algorithms/itkEvaluateDirectionImagesFilter.h Algorithms/itkEvaluateTractogramDirectionsFilter.h Algorithms/itkFiberCurvatureFilter.h + Algorithms/itkFitFibersToImageFilter.h # Tractography Algorithms/TrackingHandlers/mitkTrackingDataHandler.h Algorithms/TrackingHandlers/mitkTrackingHandlerRandomForest.h Algorithms/TrackingHandlers/mitkTrackingHandlerTensor.h Algorithms/TrackingHandlers/mitkTrackingHandlerPeaks.h Algorithms/TrackingHandlers/mitkTrackingHandlerOdf.h Algorithms/itkGibbsTrackingFilter.h Algorithms/itkStochasticTractographyFilter.h Algorithms/GibbsTracking/mitkParticle.h Algorithms/GibbsTracking/mitkParticleGrid.h Algorithms/GibbsTracking/mitkMetropolisHastingsSampler.h Algorithms/GibbsTracking/mitkSimpSamp.h Algorithms/GibbsTracking/mitkEnergyComputer.h Algorithms/GibbsTracking/mitkGibbsEnergyComputer.h Algorithms/GibbsTracking/mitkSphereInterpolator.h Algorithms/GibbsTracking/mitkFiberBuilder.h Algorithms/itkStreamlineTrackingFilter.h # Fiberfox Fiberfox/itkFibersFromPlanarFiguresFilter.h Fiberfox/itkTractsToDWIImageFilter.h Fiberfox/itkKspaceImageFilter.h Fiberfox/itkDftImageFilter.h Fiberfox/itkFieldmapGeneratorFilter.h Fiberfox/SignalModels/mitkDiffusionSignalModel.h Fiberfox/SignalModels/mitkTensorModel.h Fiberfox/SignalModels/mitkBallModel.h Fiberfox/SignalModels/mitkDotModel.h Fiberfox/SignalModels/mitkAstroStickModel.h Fiberfox/SignalModels/mitkStickModel.h Fiberfox/SignalModels/mitkRawShModel.h Fiberfox/SignalModels/mitkDiffusionNoiseModel.h Fiberfox/SignalModels/mitkRicianNoiseModel.h Fiberfox/SignalModels/mitkChiSquareNoiseModel.h Fiberfox/Sequences/mitkAcquisitionType.h Fiberfox/Sequences/mitkSingleShotEpi.h Fiberfox/Sequences/mitkCartesianReadout.h ) set(RESOURCE_FILES # Binary directory resources FiberTrackingLUTBaryCoords.bin FiberTrackingLUTIndices.bin )