diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp index 6c9dfe8ba8..eedda228b1 100755 --- a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp @@ -1,752 +1,745 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include <mitkBaseData.h> #include <mitkImageCast.h> #include <mitkImageToItk.h> #include <metaCommand.h> #include <mitkCommandLineParser.h> #include <usAny.h> #include <mitkIOUtil.h> #include <boost/lexical_cast.hpp> #include <itksys/SystemTools.hxx> #include <itkDirectory.h> #include <mitkFiberBundle.h> #include <mitkPreferenceListReaderOptionsFunctor.h> #include <mitkDiffusionPropertyHelper.h> #include <vnl/vnl_linear_system.h> #include <Eigen/Dense> #include <mitkStickModel.h> #include <mitkBallModel.h> #include <vigra/regression.hxx> #include <itkImageFileWriter.h> #include <itkImageDuplicator.h> #include <itkMersenneTwisterRandomVariateGenerator.h> #include <mitkPeakImage.h> #include <vnl/algo/vnl_lbfgsb.h> #include <vnl/vnl_sparse_matrix.h> #include <vnl/vnl_sparse_matrix_linear_system.h> #include <vnl/algo/vnl_lsqr.h> #include <itkImageDuplicator.h> #include <itkTimeProbe.h> #include <random> #include <itkParticleSwarmOptimizer.h> #include <itkOnePlusOneEvolutionaryOptimizer.h> #include <itkGradientDescentOptimizer.h> #include <itkSPSAOptimizer.h> using namespace std; typedef itksys::SystemTools ist; typedef itk::Point<float, 4> PointType4; typedef itk::Image< float, 4 > PeakImgType; -const float UPSCALE = 1000.0; -vnl_vector_fixed<float,3> GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed<float,3> fiber_dir, int& id ) +vnl_vector_fixed<float,3> GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed<float,3> fiber_dir, int& id, double& w ) { int m_NumDirs = peak_image->GetLargestPossibleRegion().GetSize()[3]/3; vnl_vector_fixed<float,3> out_dir; out_dir.fill(0); float angle = 0.8; for (int i=0; i<m_NumDirs; i++) { vnl_vector_fixed<float,3> dir; idx[3] = i*3; dir[0] = peak_image->GetPixel(idx); idx[3] += 1; dir[1] = peak_image->GetPixel(idx); idx[3] += 1; dir[2] = peak_image->GetPixel(idx); float mag = dir.magnitude(); if (mag<mitk::eps) continue; dir.normalize(); float a = dot_product(dir, fiber_dir); if (fabs(a)>angle) { angle = fabs(a); + w = angle; if (a<0) out_dir = -dir; else out_dir = dir; out_dir *= mag; - out_dir *= UPSCALE; // for the fit it's better if the magnitude is larger since many fibers pass each voxel and otherwise the individual contributions would be very small id = i; } } return out_dir; } class VnlCostFunction : public vnl_cost_function { public: vnl_sparse_matrix_linear_system< double >* S; vnl_sparse_matrix< double > m_A; + vnl_sparse_matrix< double > m_A_Ones; // matrix indicating active weights with 1 vnl_vector< double > m_b; - itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen; + double m_Lambda; // regularization factor - void SetProblem(vnl_sparse_matrix< double >& A, vnl_vector<double>& b) + vnl_vector<double> row_sums; // number of active weights per row + vnl_vector<double> local_weight_means; // mean weight of each row + + void SetProblem(vnl_sparse_matrix< double >& A, vnl_vector<double>& b, double lambda) { S = new vnl_sparse_matrix_linear_system<double>(A, b); m_A = A; m_b = b; - randGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); + m_Lambda = lambda; + + m_A_Ones.set_size(m_A.rows(), m_A.cols()); + m_A.reset(); + while (m_A.next()) + m_A_Ones.put(m_A.getrow(), m_A.getcolumn(), 1); + + unsigned int N = m_b.size(); + vnl_vector<double> ones; ones.set_size(dim); ones.fill(1.0); + row_sums.set_size(N); + m_A_Ones.mult(ones, row_sums); + local_weight_means.set_size(N); } VnlCostFunction(const int NumVars) : vnl_cost_function(NumVars) { } - double f_itk(vnl_vector<double> const &x) const + void regu_MSE(vnl_vector<double> const &x, double& cost) { - double min = x.min_value(); - if( min<0 ) - return 10000 * -min; - - double cost = S->get_rms_error(x); - double regu = x.squared_magnitude()/x.size(); - cost += regu; - return cost; + double mean = x.mean(); + vnl_vector<double> tx = x-mean; + cost += m_Lambda*1e8*tx.squared_magnitude()/x.size(); } - double f(vnl_vector<double> const &x) + void regu_MSM(vnl_vector<double> const &x, double& cost) { - double cost = S->get_rms_error(x); - - double regu = x.squared_magnitude()/x.size(); + cost += m_Lambda*1e8*x.squared_magnitude()/x.size(); + } -// unsigned int norm = 0; -// for (unsigned int i=0; i<m_b.size(); ++i) -// { -// if (m_A.get_row(i).empty()) -// continue; + void regu_localMSE(vnl_vector<double> const &x, double& cost) + { + m_A_Ones.mult(x, local_weight_means); + local_weight_means = element_quotient(local_weight_means, row_sums); -// float mean = 0; -// for (auto el : m_A.get_row(i)) -// mean += x[el.first]; + m_A_Ones.reset(); + unsigned int num_elements = 0; + double regu = 0; + while (m_A_Ones.next()) + { + double d = 0; + if (x[m_A_Ones.getcolumn()]>local_weight_means[m_A_Ones.getrow()]) + d = std::exp(x[m_A_Ones.getcolumn()]) - std::exp(local_weight_means[m_A_Ones.getrow()]); + else + d = x[m_A_Ones.getcolumn()] - local_weight_means[m_A_Ones.getrow()]; + regu += d*d; + ++num_elements; + } + cost += m_Lambda*1e3*regu/num_elements; + } -// mean /= m_A.get_row(i).size(); + void grad_regu_MSE(vnl_vector<double> const &x, vnl_vector<double> &dx) + { + double mean = x.mean(); + vnl_vector<double> tx = x-mean; -// for (auto el : m_A.get_row(i)) -// { -// float d = x[el.first] - mean; -// { -// regu += d*d; -// norm++; -// } -// } -// } -// regu /= norm; + vnl_vector<double> tx2(dim, 0.0); + vnl_vector<double> h(dim, 1.0); + for (int c=0; c<dim; c++) + { + h[c] = dim-1; + tx2[c] += dot_product(h,tx); + h[c] = 1; + } + dx += tx2*m_Lambda*1e8*2.0/(dim*dim); - cost += regu; - return cost; } - // Finite differences gradient (SLOW) -// void gradf(vnl_vector<double> const &x, vnl_vector<double> &dx) -// { -// fdgradf(x, dx); -// } + void grad_regu_MSM(vnl_vector<double> const &x, vnl_vector<double> &dx) + { + dx += m_Lambda*1e8*2.0*x/dim; + } - void gradf(vnl_vector<double> const &x, vnl_vector<double> &dx) + void grad_regu_localMSE(vnl_vector<double> const &x, vnl_vector<double> &dx) { - dx.fill(0.0); - double mag = x.magnitude(); - unsigned int N = m_b.size(); + m_A_Ones.mult(x, local_weight_means); + local_weight_means = element_quotient(local_weight_means, row_sums); - vnl_vector<double> d; d.set_size(N); - S->multiply(x,d); - d -= m_b; + vnl_vector<double> exp_x = x.apply(std::exp); + vnl_vector<double> exp_means = local_weight_means.apply(std::exp); - vnl_vector<double> numerator; numerator.set_size(x.size()); - S->transpose_multiply(d, dx); - dx *= 2.0/N; + vnl_vector<double> tdx(dim, 0); + m_A_Ones.reset(); + while (m_A_Ones.next()) + { + int c = m_A_Ones.getcolumn(); + int r = m_A_Ones.getrow(); + if (row_sums[r]==0) + continue; - if (mag>0) - dx += x/mag; + if (x[c]>local_weight_means[r]) + tdx[c] += (exp_x[c] * ( exp_x[c] - exp_means[r] ))/row_sums[r]; + else + tdx[c] += (x[c] - local_weight_means[r])/row_sums[r]; + } + dx += tdx*1e3*2.0*m_Lambda; + + // vnl_vector<double> dr; dr.set_size(dim); dr.fill(0); + // for (unsigned int r=0; r<m_A_Ones.rows(); r++) + // { + // int n = row_sums[r]; + // vnl_vector<double> weights; weights.set_size(n); + // vnl_matrix<double> temp(n,n,1); temp.fill_diagonal(n-1); + + // int i=0; + // for (auto w : m_A_Ones.get_row(r)) + // { + // weights[i]=w.second; + // ++i; + // } + + // weights -= local_weight_means[r]; + // weights = temp*weights; + + // i=0; + // for (auto w : m_A_Ones.get_row(r)) + // { + // dr[w.second] += weights[i]; + // ++i; + // } + // } + + // dx += dr*2.0*m_Lambda; } -}; - -class ItkCostFunction : public itk::SingleValuedCostFunction -{ -public: - /** Standard class typedefs. */ - typedef ItkCostFunction Self; - typedef SingleValuedCostFunction Superclass; - typedef itk::SmartPointer<Self> Pointer; - typedef itk::SmartPointer<const Self> ConstPointer; - /** Method for creation through the object factory. */ - itkNewMacro(Self) + double f(vnl_vector<double> const &x) + { + double cost = S->get_rms_error(x); + cost *= cost; - /** Run-time type information (and related methods). */ - itkTypeMacro(ItkCostFunction, SingleValuedCostfunction) + // cost for e^x + // vnl_vector<double> x_exp; x_exp.set_size(x.size()); + // for (unsigned int c=0; c<x.size(); c++) + // x_exp[c] = std::exp(x[c]); + // double cost = S->get_rms_error(x_exp); + // cost *= cost; - void SetVnlCostFunction(VnlCostFunction& cf) { m_VnlCostFunction = cf; } - unsigned int GetNumberOfParameters(void) const { return m_VnlCostFunction.get_number_of_unknowns(); } // itk::CostFunction + regu_localMSE(x, cost); + // regu_MSM(x, cost); - MeasureType GetValue(const ParametersType & parameters) const - { - return m_VnlCostFunction.f_itk(parameters); + return cost; } - void GetDerivative(const ParametersType &, DerivativeType & ) const { - throw itk::ExceptionObject( __FILE__, __LINE__, "No derivative is available for this cost function."); - } + void gradf(vnl_vector<double> const &x, vnl_vector<double> &dx) + { + dx.fill(0.0); + unsigned int N = m_b.size(); -protected: - ItkCostFunction(){} - ~ItkCostFunction(){} + // vnl_vector<double> x_exp; x_exp.set_size(x.size()); + // for (unsigned int c=0; c<x.size(); c++) + // x_exp[c] = std::exp(x[c]); - VnlCostFunction m_VnlCostFunction = VnlCostFunction(1); + vnl_vector<double> d; d.set_size(N); + S->multiply(x,d); + d -= m_b; -private: - ItkCostFunction(const Self &); //purposely not implemented - void operator = (const Self &); //purposely not implemented -}; + S->transpose_multiply(d, dx); + dx *= 2.0/N; -void OptimizeItk(VnlCostFunction& cf, vnl_vector<double>& x, int iter, double lb, double ub) -{ - ItkCostFunction::ParametersType p; p.SetData(x.data_block(), x.size()); - - ItkCostFunction::Pointer itk_cf = ItkCostFunction::New(); - itk_cf->SetVnlCostFunction(cf); - - std::pair< double, double > bounds; bounds.first = lb; bounds.second = ub; - MITK_INFO << bounds; - -// itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); -// itk::OnePlusOneEvolutionaryOptimizer::Pointer opt = itk::OnePlusOneEvolutionaryOptimizer::New(); -// opt->SetCostFunction(itk_cf); -// opt->MinimizeOn(); -// opt->SetInitialPosition(p); -// opt->SetNormalVariateGenerator(randGen); -// opt->SetMaximumIteration(iter); -// opt->StartOptimization(); - -// itk::ParticleSwarmOptimizer::Pointer opt = itk::ParticleSwarmOptimizer::New(); -// opt->SetCostFunction(itk_cf); -// opt->SetInitialPosition(p); -// opt->SetParameterBounds(bounds, x.size()); -// opt->SetMaximalNumberOfIterations(iter); -// opt->SetNumberOfParticles(100); -// opt->SetParametersConvergenceTolerance(0.01, x.size()); -// opt->SetNumberOfGenerationsWithMinimalImprovement(3); -// opt->StartOptimization(); - -// itk::GradientDescentOptimizer::Pointer opt = itk::GradientDescentOptimizer::New(); -// opt->SetCostFunction(itk_cf); -// opt->SetInitialPosition(p); -// opt->SetMinimize(true); -// opt->SetNumberOfIterations(iter); -// opt->StartOptimization(); - - itk::SPSAOptimizer::Pointer opt = itk::SPSAOptimizer::New(); - opt->SetCostFunction(itk_cf); - opt->SetInitialPosition(p); - opt->SetMinimize(true); - opt->SetNumberOfPerturbations(iter); - opt->StartOptimization(); - - x.copy_in(opt->GetCurrentPosition().data_block()); - for (unsigned int i=0; i<x.size(); i++) - MITK_INFO << opt->GetCurrentPosition()[i]; -// MITK_INFO << "Cost: " << opt->GetCurrentCost(); + // for (unsigned int c=0; c<x.size(); c++) + // dx[c] *= x_exp[c]; // only for e^x weights -} + grad_regu_localMSE(x,dx); + // grad_regu_MSM(x,dx); + } +}; -std::vector<float> FitFibers( std::string , std::vector< mitk::FiberBundle::Pointer > input_tracts, mitk::Image::Pointer inputImage, bool single_fiber_fit, int max_iter, float g_tol, bool lb ) +vnl_vector<double> FitFibers( std::string , std::vector< mitk::FiberBundle::Pointer > input_tracts, mitk::Image::Pointer inputImage, vnl_sparse_matrix< double >& A, vnl_vector<double>& b, bool single_fiber_fit, int max_iter, float g_tol, float lambda ) { typedef mitk::ImageToItk< PeakImgType > CasterType; CasterType::Pointer caster = CasterType::New(); caster->SetInput(inputImage); caster->Update(); PeakImgType::Pointer itkImage = caster->GetOutput(); unsigned int* image_size = inputImage->GetDimensions(); int sz_x = image_size[0]; int sz_y = image_size[1]; int sz_z = image_size[2]; - int sz_peaks = image_size[3]; + int sz_peaks = image_size[3]/3 + 1; // +1 for zero - peak int num_voxels = sz_x*sz_y*sz_z; - unsigned int num_unknowns = input_tracts.size(); //inputTractogram->GetNumFibers(); + unsigned int num_unknowns = input_tracts.size(); if (single_fiber_fit) { num_unknowns = 0; for (unsigned int bundle=0; bundle<input_tracts.size(); bundle++) num_unknowns += input_tracts.at(bundle)->GetNumFibers(); } unsigned int number_of_residuals = num_voxels * sz_peaks; // create linear system MITK_INFO << "Num. unknowns: " << num_unknowns; MITK_INFO << "Num. residuals: " << number_of_residuals; MITK_INFO << "Creating system ..."; - vnl_sparse_matrix< double > A; A.set_size(number_of_residuals, num_unknowns); - vnl_vector<double> b; b.set_size(number_of_residuals); b.fill(0.0); - - float max_peak_mag = 0; - int max_peak_idx = -1; - - float min_peak_mag = 999999999; - int min_peak_idx = -1; + A.set_size(number_of_residuals, num_unknowns); + b.set_size(number_of_residuals); b.fill(0.0); + double TD = 0; + double FD = 0; + unsigned int dir_count = 0; unsigned int fiber_count = 0; + for (unsigned int bundle=0; bundle<input_tracts.size(); bundle++) { vtkSmartPointer<vtkPolyData> polydata = input_tracts.at(bundle)->GetFiberPolyData(); for (int i=0; i<input_tracts.at(bundle)->GetNumFibers(); ++i) { vtkCell* cell = polydata->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); if (numPoints<2) MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; for (int j=0; j<numPoints-1; ++j) { double* p1 = points->GetPoint(j); PointType4 p; p[0]=p1[0]; p[1]=p1[1]; p[2]=p1[2]; p[3]=0; itk::Index<4> idx4; itkImage->TransformPhysicalPointToIndex(p, idx4); if (!itkImage->GetLargestPossibleRegion().IsInside(idx4)) continue; double* p2 = points->GetPoint(j+1); vnl_vector_fixed<float,3> fiber_dir; fiber_dir[0] = p[0]-p2[0]; fiber_dir[1] = p[1]-p2[1]; fiber_dir[2] = p[2]-p2[2]; fiber_dir.normalize(); - int peak_id = -1; - vnl_vector_fixed<float,3> odf_peak = GetClosestPeak(idx4, itkImage, fiber_dir, peak_id); + double w = 1; + int peak_id = sz_peaks-1; + vnl_vector_fixed<float,3> odf_peak = GetClosestPeak(idx4, itkImage, fiber_dir, peak_id, w); float peak_mag = odf_peak.magnitude(); - if (peak_id<0) - continue; int x = idx4[0]; int y = idx4[1]; int z = idx4[2]; - unsigned int linear_index = sz_peaks*(x + sz_x*y + sz_x*sz_y*z); + unsigned int linear_index = x + sz_x*y + sz_x*sz_y*z + sz_x*sz_y*sz_z*peak_id; - if (peak_mag>max_peak_mag) + if (b[linear_index] == 0 && peak_id<3) { - max_peak_mag = peak_mag; - max_peak_idx = linear_index + 3*peak_id; + dir_count++; + FD += peak_mag; } + TD += w; - if (peak_mag<min_peak_mag) + if (single_fiber_fit) { - min_peak_mag = peak_mag; - min_peak_idx = linear_index + 3*peak_id; + b[linear_index] = peak_mag; + A.put(linear_index, fiber_count, A.get(linear_index, fiber_count) + w); } - - for (unsigned int k=0; k<3; ++k) + else { - if (single_fiber_fit) - { - b[linear_index + 3*peak_id + k] = (double)odf_peak[k]; - A.put(linear_index + 3*peak_id + k, fiber_count, A.get(linear_index + 3*peak_id + k, fiber_count) + (double)fiber_dir[k]); - } - else - { - b[linear_index + 3*peak_id + k] = (double)odf_peak[k]; - A.put(linear_index + 3*peak_id + k, bundle, A.get(linear_index + 3*peak_id + k, bundle) + (double)fiber_dir[k]); - } + b[linear_index] = peak_mag; + A.put(linear_index, bundle, A.get(linear_index, bundle) + w); } - } ++fiber_count; } } - vnl_vector_fixed<float,3> max_corr_fiber_dir; max_corr_fiber_dir.fill(0.0); - vnl_vector_fixed<float,3> min_corr_fiber_dir; min_corr_fiber_dir.fill(0.0); - for (unsigned int i=0; i<fiber_count; ++i) - { - max_corr_fiber_dir[0] += A.get(max_peak_idx, i); - max_corr_fiber_dir[1] += A.get(max_peak_idx+1, i); - max_corr_fiber_dir[2] += A.get(max_peak_idx+2, i); - - min_corr_fiber_dir[0] += A.get(min_peak_idx, i); - min_corr_fiber_dir[1] += A.get(min_peak_idx+1, i); - min_corr_fiber_dir[2] += A.get(min_peak_idx+2, i); - } + TD /= (dir_count*fiber_count); + FD /= dir_count; - float upper_bound = max_peak_mag/max_corr_fiber_dir.magnitude(); - float lower_bound = min_peak_mag/min_corr_fiber_dir.magnitude(); + A /= TD; + b *= 100.0/FD; // times 100 because we want to avoid too small values for computational reasons - if (!lb || lower_bound>=upper_bound) - lower_bound = 0; - - MITK_INFO << "Lower bound: " << lower_bound; - MITK_INFO << "Upper bound: " << upper_bound; +// MITK_INFO << "TD: " << TD; +// MITK_INFO << "FD: " << FD; +// MITK_INFO << "Regularization: " << lambda; itk::TimeProbe clock; clock.Start(); MITK_INFO << "Fitting fibers"; VnlCostFunction cost(num_unknowns); - cost.SetProblem(A, b); + cost.SetProblem(A, b, lambda); - MITK_INFO << g_tol << " " << max_iter; - vnl_vector<double> x; x.set_size(num_unknowns); x.fill( (upper_bound-lower_bound)/2 ); -// OptimizeItk(cost, x, max_iter, lower_bound, upper_bound); + vnl_vector<double> x; x.set_size(num_unknowns); x.fill( TD/100.0 * FD/2.0 ); vnl_lbfgsb minimizer(cost); - vnl_vector<double> l; l.set_size(num_unknowns); l.fill(lower_bound); - vnl_vector<double> u; u.set_size(num_unknowns); u.fill(upper_bound); - vnl_vector<long> bound_selection; bound_selection.set_size(num_unknowns); bound_selection.fill(2); + vnl_vector<double> l; l.set_size(num_unknowns); l.fill(0); + + vnl_vector<long> bound_selection; bound_selection.set_size(num_unknowns); bound_selection.fill(1); minimizer.set_bound_selection(bound_selection); minimizer.set_lower_bound(l); - minimizer.set_upper_bound(u); minimizer.set_trace(true); minimizer.set_projected_gradient_tolerance(g_tol); if (max_iter>0) minimizer.set_max_function_evals(max_iter); minimizer.minimize(x); - MITK_INFO << "Residual error: " << minimizer.get_end_error(); + + // SECOND RUN + std::vector< double > weights; + for (auto w : x) + weights.push_back(w); + sort(weights.begin(), weights.end()); + MITK_INFO << "Setting upper weight bound to " << weights.at(num_unknowns*0.95); + vnl_vector<double> u; u.set_size(num_unknowns); u.fill(weights.at(num_unknowns*0.95)); + minimizer.set_upper_bound(u); + bound_selection.fill(2); + minimizer.set_bound_selection(bound_selection); + minimizer.minimize(x); + + weights.clear(); + for (auto w : x) + weights.push_back(w); + sort(weights.begin(), weights.end()); + + MITK_INFO << "*************************"; + MITK_INFO << "Weight statistics"; + MITK_INFO << "Mean: " << x.mean(); + MITK_INFO << "Median: " << weights.at(num_unknowns*0.5); + MITK_INFO << "75% quantile: " << weights.at(num_unknowns*0.75); + MITK_INFO << "95% quantile: " << weights.at(num_unknowns*0.95); + MITK_INFO << "99% quantile: " << weights.at(num_unknowns*0.99); + MITK_INFO << "Min: " << weights.at(0); + MITK_INFO << "Max: " << weights.at(num_unknowns-1); + MITK_INFO << "*************************"; MITK_INFO << "NumEvals: " << minimizer.get_num_evaluations(); MITK_INFO << "NumIterations: " << minimizer.get_num_iterations(); - -// vnl_sparse_matrix_linear_system<double> S(A, b); -// vnl_lsqr linear_solver( S ); -// linear_solver.set_max_iterations(max_iter); -// linear_solver.minimize(x); + MITK_INFO << "Residual cost: " << minimizer.get_end_error(); + MITK_INFO << "Final RMS: " << cost.S->get_rms_error(x); clock.Stop(); int h = clock.GetTotal()/3600; int m = ((int)clock.GetTotal()%3600)/60; int s = (int)clock.GetTotal()%60; MITK_INFO << "Optimization took " << h << "h, " << m << "m and " << s << "s"; - std::vector<float> weights; - float max_w = 0; - for (unsigned int i=0; i<num_unknowns; ++i) - { -// MITK_INFO << x[i]; - if (x[i]>max_w) - max_w = x[i]; - weights.push_back(x[i]); - } - MITK_INFO << "Max w: " << max_w; + // transform back for peak image creation + A *= FD/100.0; + b *= FD/100.0; - return weights; + return x; } std::vector< string > get_file_list(const std::string& path) { std::vector< string > file_list; itk::Directory::Pointer dir = itk::Directory::New(); if (dir->Load(path.c_str())) { int n = dir->GetNumberOfFiles(); for (int r = 0; r < n; r++) { const char *filename = dir->GetFile(r); std::string ext = ist::GetFilenameExtension(filename); if (ext==".fib" || ext==".trk") file_list.push_back(path + '/' + filename); } } return file_list; } +/*! +\brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). +*/ int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Fit Fibers To Image"); parser.setCategory("Fiber Tracking Evaluation"); parser.setDescription(""); parser.setContributor("MIC"); parser.setArgumentPrefix("--", "-"); parser.addArgument("", "i1", mitkCommandLineParser::StringList, "Input tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); parser.addArgument("", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); - parser.addArgument("", "it", mitkCommandLineParser::Int, "", ""); - parser.addArgument("", "s", mitkCommandLineParser::Bool, "", ""); - parser.addArgument("", "lb", mitkCommandLineParser::Bool, "", ""); - parser.addArgument("", "g", mitkCommandLineParser::Float, "", ""); parser.addArgument("", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); + parser.addArgument("max_iter", "", mitkCommandLineParser::Int, "Max. iterations:", "maximum number of optimizer iterations", 20); + parser.addArgument("bundle_based", "", mitkCommandLineParser::Bool, "Bundle based fit:", "fit one weight per input tractogram/bundle, not for each fiber", false); + parser.addArgument("min_g", "", mitkCommandLineParser::Float, "Min. g:", "lower termination threshold for gradient magnitude", 1e-5); + parser.addArgument("lambda", "", mitkCommandLineParser::Float, "Lambda:", "weighting factor for regularization", 1.0); + parser.addArgument("save_res", "", mitkCommandLineParser::Bool, "Residuals:", "save residual images", false); + map<string, us::Any> parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; mitkCommandLineParser::StringContainerType fib_files = us::any_cast<mitkCommandLineParser::StringContainerType>(parsedArgs["i1"]); - string dwiFile = us::any_cast<string>(parsedArgs["i2"]); + string peak_file_name = us::any_cast<string>(parsedArgs["i2"]); string outRoot = us::any_cast<string>(parsedArgs["o"]); - bool single_fib = false; - if (parsedArgs.count("s")) - single_fib = us::any_cast<bool>(parsedArgs["s"]); + bool single_fib = true; + if (parsedArgs.count("bundle_based")) + single_fib = !us::any_cast<bool>(parsedArgs["bundle_based"]); - int max_iter = 0; + bool save_residuals = false; + if (parsedArgs.count("save_res")) + save_residuals = us::any_cast<bool>(parsedArgs["residuals"]); + + int max_iter = 20; if (parsedArgs.count("it")) max_iter = us::any_cast<int>(parsedArgs["it"]); float g_tol = 1e-5; - if (parsedArgs.count("g")) - g_tol = us::any_cast<float>(parsedArgs["g"]); + if (parsedArgs.count("min_g")) + g_tol = us::any_cast<float>(parsedArgs["min_g"]); - bool lb = false; - if (parsedArgs.count("lb")) - lb = us::any_cast<bool>(parsedArgs["lb"]); + float lambda = 1.0; + if (parsedArgs.count("lambda")) + lambda = us::any_cast<float>(parsedArgs["lambda"]); try { std::vector< mitk::FiberBundle::Pointer > input_tracts; mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Peak Image", "Fiberbundles"}, {}); - mitk::Image::Pointer inputImage = dynamic_cast<mitk::PeakImage*>(mitk::IOUtil::Load(dwiFile, &functor)[0].GetPointer()); + mitk::Image::Pointer inputImage = dynamic_cast<mitk::PeakImage*>(mitk::IOUtil::Load(peak_file_name, &functor)[0].GetPointer()); float minSpacing = 1; if(inputImage->GetGeometry()->GetSpacing()[0]<inputImage->GetGeometry()->GetSpacing()[1] && inputImage->GetGeometry()->GetSpacing()[0]<inputImage->GetGeometry()->GetSpacing()[2]) minSpacing = inputImage->GetGeometry()->GetSpacing()[0]; else if (inputImage->GetGeometry()->GetSpacing()[1] < inputImage->GetGeometry()->GetSpacing()[2]) minSpacing = inputImage->GetGeometry()->GetSpacing()[1]; else minSpacing = inputImage->GetGeometry()->GetSpacing()[2]; std::vector< std::string > fib_names; for (auto item : fib_files) { if ( ist::FileIsDirectory(item) ) { for ( auto fibFile : get_file_list(item) ) { mitk::FiberBundle::Pointer inputTractogram = dynamic_cast<mitk::FiberBundle*>(mitk::IOUtil::Load(fibFile)[0].GetPointer()); if (inputTractogram.IsNull()) continue; inputTractogram->ResampleLinear(minSpacing/10); input_tracts.push_back(inputTractogram); fib_names.push_back(fibFile); } } else { mitk::FiberBundle::Pointer inputTractogram = dynamic_cast<mitk::FiberBundle*>(mitk::IOUtil::Load(item)[0].GetPointer()); if (inputTractogram.IsNull()) continue; inputTractogram->ResampleLinear(minSpacing/10); input_tracts.push_back(inputTractogram); fib_names.push_back(item); } } - std::vector<float> weights = FitFibers(outRoot, input_tracts, inputImage, single_fib, max_iter, g_tol, lb); + vnl_sparse_matrix<double> A; + vnl_vector<double> b; + vnl_vector<double> x = FitFibers(outRoot, input_tracts, inputImage, A, b, single_fib, max_iter, g_tol, lambda); + MITK_INFO << "Weighting fibers"; if (single_fib) { unsigned int fiber_count = 0; - for (unsigned int bundle=0; bundle<input_tracts.size(); bundle++) { - mitk::FiberBundle::Pointer fib = input_tracts.at(bundle); - for (int i=0; i<fib->GetNumFibers(); i++) + for (int i=0; i<input_tracts.at(bundle)->GetNumFibers(); i++) { - fib->SetFiberWeight(i, weights.at(fiber_count)); + input_tracts.at(bundle)->SetFiberWeight(i, x[fiber_count]); ++fiber_count; } - - std::string name = fib_names.at(bundle); - name = ist::GetFilenameWithoutExtension(name); - mitk::IOUtil::Save(fib, outRoot + name + "_fitted.fib"); } } else { for (unsigned int i=0; i<fib_names.size(); ++i) - { - std::string name = fib_names.at(i); - name = ist::GetFilenameWithoutExtension(name); - MITK_INFO << name << ": " << weights.at(i); - mitk::FiberBundle::Pointer bundle = input_tracts.at(i); - bundle->SetFiberWeights(weights.at(i)); - mitk::IOUtil::Save(bundle, outRoot + name + "_fitted.fib"); - } + input_tracts.at(i)->SetFiberWeights(x[i]); } - - // OUTPUT IMAGES - MITK_INFO << "Generating output images ..."; - typedef mitk::ImageToItk< PeakImgType > CasterType; - CasterType::Pointer caster = CasterType::New(); - caster->SetInput(inputImage); - caster->Update(); - PeakImgType::Pointer itkImage = caster->GetOutput(); - - itk::ImageDuplicator< PeakImgType >::Pointer duplicator = itk::ImageDuplicator< PeakImgType >::New(); - duplicator->SetInputImage(itkImage); - duplicator->Update(); - PeakImgType::Pointer unexplained_image = duplicator->GetOutput(); - - duplicator->SetInputImage(unexplained_image); - duplicator->Update(); - PeakImgType::Pointer residual_image = duplicator->GetOutput(); - - duplicator->SetInputImage(residual_image); - duplicator->Update(); - PeakImgType::Pointer explained_image = duplicator->GetOutput(); - explained_image->FillBuffer(0.0); - - for (unsigned int bundle=0; bundle<input_tracts.size(); bundle++) + if (save_residuals) { - vtkSmartPointer<vtkPolyData> polydata = input_tracts.at(bundle)->GetFiberPolyData(); - - for (int i=0; i<input_tracts.at(bundle)->GetNumFibers(); ++i) + // OUTPUT IMAGES + MITK_INFO << "Generating output images ..."; + typedef mitk::ImageToItk< PeakImgType > CasterType; + CasterType::Pointer caster = CasterType::New(); + caster->SetInput(inputImage); + caster->Update(); + PeakImgType::Pointer peak_image = caster->GetOutput(); + + itk::ImageDuplicator< PeakImgType >::Pointer duplicator = itk::ImageDuplicator< PeakImgType >::New(); + duplicator->SetInputImage(peak_image); + duplicator->Update(); + PeakImgType::Pointer underexplained_image = duplicator->GetOutput(); + underexplained_image->FillBuffer(0.0); + + duplicator->SetInputImage(underexplained_image); + duplicator->Update(); + PeakImgType::Pointer overexplained_image = duplicator->GetOutput(); + overexplained_image->FillBuffer(0.0); + + duplicator->SetInputImage(overexplained_image); + duplicator->Update(); + PeakImgType::Pointer residual_image = duplicator->GetOutput(); + residual_image->FillBuffer(0.0); + + duplicator->SetInputImage(residual_image); + duplicator->Update(); + PeakImgType::Pointer fitted_image = duplicator->GetOutput(); + fitted_image->FillBuffer(0.0); + + vnl_sparse_matrix_linear_system<double> S(A, b); + vnl_vector<double> fitted_b; fitted_b.set_size(b.size()); + S.multiply(x, fitted_b); + + unsigned int* image_size = inputImage->GetDimensions(); + int sz_x = image_size[0]; + int sz_y = image_size[1]; + int sz_z = image_size[2]; + int sz_peaks = image_size[3]/3 + 1; // +1 for zero - peak + for (unsigned int r=0; r<b.size(); r++) { - vtkCell* cell = polydata->GetCell(i); - int numPoints = cell->GetNumberOfPoints(); - vtkPoints* points = cell->GetPoints(); - - if (numPoints<2) - MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; - - float w = input_tracts.at(bundle)->GetFiberWeight(i)/UPSCALE; - - for (int j=0; j<numPoints-1; ++j) + itk::Index<4> idx; + unsigned int linear_index = r; + idx[0] = linear_index % sz_x; linear_index /= sz_x; + idx[1] = linear_index % sz_y; linear_index /= sz_y; + idx[2] = linear_index % sz_z; linear_index /= sz_z; + int peak_id = linear_index % sz_peaks; + + if (peak_id<sz_peaks-1) { - double* p1 = points->GetPoint(j); - PointType4 p; - p[0]=p1[0]; - p[1]=p1[1]; - p[2]=p1[2]; - p[3]=0; - - itk::Index<4> idx4; - itkImage->TransformPhysicalPointToIndex(p, idx4); - if (!itkImage->GetLargestPossibleRegion().IsInside(idx4)) - continue; - - double* p2 = points->GetPoint(j+1); - vnl_vector_fixed<float,3> fiber_dir; - fiber_dir[0] = p[0]-p2[0]; - fiber_dir[1] = p[1]-p2[1]; - fiber_dir[2] = p[2]-p2[2]; - fiber_dir.normalize(); - - int peak_id = -1; - GetClosestPeak(idx4, itkImage, fiber_dir, peak_id); - if (peak_id<0) - continue; - - vnl_vector_fixed<float,3> unexplained_dir; - vnl_vector_fixed<float,3> explained_dir; - vnl_vector_fixed<float,3> res_dir; vnl_vector_fixed<float,3> peak_dir; - idx4[3] = peak_id*3; - unexplained_dir[0] = unexplained_image->GetPixel(idx4); - explained_dir[0] = explained_image->GetPixel(idx4); - peak_dir[0] = itkImage->GetPixel(idx4); - res_dir[0] = residual_image->GetPixel(idx4); + idx[3] = peak_id*3; + peak_dir[0] = peak_image->GetPixel(idx); + idx[3] += 1; + peak_dir[1] = peak_image->GetPixel(idx); + idx[3] += 1; + peak_dir[2] = peak_image->GetPixel(idx); - idx4[3] += 1; - unexplained_dir[1] = unexplained_image->GetPixel(idx4); - explained_dir[1] = explained_image->GetPixel(idx4); - peak_dir[1] = itkImage->GetPixel(idx4); - res_dir[1] = residual_image->GetPixel(idx4); + peak_dir.normalize(); + peak_dir *= fitted_b[r]; - idx4[3] += 1; - unexplained_dir[2] = unexplained_image->GetPixel(idx4); - explained_dir[2] = explained_image->GetPixel(idx4); - peak_dir[2] = itkImage->GetPixel(idx4); - res_dir[2] = residual_image->GetPixel(idx4); + idx[3] = peak_id*3; + fitted_image->SetPixel(idx, peak_dir[0]); - if (dot_product(peak_dir, fiber_dir)<0) - fiber_dir *= -1; - fiber_dir *= w; - - idx4[3] = peak_id*3; - residual_image->SetPixel(idx4, res_dir[0] - fiber_dir[0]); - - idx4[3] += 1; - residual_image->SetPixel(idx4, res_dir[1] - fiber_dir[1]); - - idx4[3] += 1; - residual_image->SetPixel(idx4, res_dir[2] - fiber_dir[2]); + idx[3] += 1; + fitted_image->SetPixel(idx, peak_dir[1]); + idx[3] += 1; + fitted_image->SetPixel(idx, peak_dir[2]); + } + } - if ( fabs(unexplained_dir[0]) - fabs(fiber_dir[0]) < 0 ) // did we "overexplain" stuff? - fiber_dir = unexplained_dir; + itk::Index<4> idx; + for (idx[0]=0; idx[0]<sz_x; ++idx[0]) + for (idx[1]=0; idx[1]<sz_y; ++idx[1]) + for (idx[2]=0; idx[2]<sz_z; ++idx[2]) + { + vnl_vector_fixed<float,3> peak_dir; + vnl_vector_fixed<float,3> fitted_dir; + for (idx[3]=0; idx[3]<image_size[3]; ++idx[3]) + { + peak_dir[idx[3]%3] = peak_image->GetPixel(idx); + fitted_dir[idx[3]%3] = fitted_image->GetPixel(idx); + residual_image->SetPixel(idx, peak_image->GetPixel(idx) - fitted_image->GetPixel(idx)); + + if (idx[3]%3==2) + { + itk::Index<4> tidx= idx; + if (peak_dir.magnitude()>fitted_dir.magnitude()) + { + underexplained_image->SetPixel(tidx, peak_dir[2]-fitted_dir[2]); tidx[3]--; + underexplained_image->SetPixel(tidx, peak_dir[1]-fitted_dir[1]); tidx[3]--; + underexplained_image->SetPixel(tidx, peak_dir[0]-fitted_dir[0]); + } + else + { + overexplained_image->SetPixel(tidx, fitted_dir[2]-peak_dir[2]); tidx[3]--; + overexplained_image->SetPixel(tidx, fitted_dir[1]-peak_dir[1]); tidx[3]--; + overexplained_image->SetPixel(tidx, fitted_dir[0]-peak_dir[0]); + } + } + } + } - idx4[3] = peak_id*3; - unexplained_image->SetPixel(idx4, unexplained_dir[0] - fiber_dir[0]); - explained_image->SetPixel(idx4, explained_dir[0] + fiber_dir[0]); + itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); + writer->SetInput(fitted_image); + writer->SetFileName(outRoot + "fitted_image.nrrd"); + writer->Update(); - idx4[3] += 1; - unexplained_image->SetPixel(idx4, unexplained_dir[1] - fiber_dir[1]); - explained_image->SetPixel(idx4, explained_dir[1] + fiber_dir[1]); + writer->SetInput(residual_image); + writer->SetFileName(outRoot + "residual_image.nrrd"); + writer->Update(); - idx4[3] += 1; - unexplained_image->SetPixel(idx4, unexplained_dir[2] - fiber_dir[2]); - explained_image->SetPixel(idx4, explained_dir[2] + fiber_dir[2]); - } + writer->SetInput(overexplained_image); + writer->SetFileName(outRoot + "overexplained_image.nrrd"); + writer->Update(); - } + writer->SetInput(underexplained_image); + writer->SetFileName(outRoot + "underexplained_image.nrrd"); + writer->Update(); } - itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); - writer->SetInput(unexplained_image); - writer->SetFileName(outRoot + "unexplained_image.nrrd"); - writer->Update(); - - writer->SetInput(explained_image); - writer->SetFileName(outRoot + "explained_image.nrrd"); - writer->Update(); - - writer->SetInput(residual_image); - writer->SetFileName(outRoot + "residual_image.nrrd"); - writer->Update(); + for (unsigned int bundle=0; bundle<input_tracts.size(); bundle++) + { + input_tracts.at(bundle)->Compress(0.1); + std::string name = fib_names.at(bundle); + name = ist::GetFilenameWithoutExtension(name); + mitk::IOUtil::Save(input_tracts.at(bundle), outRoot + name + "_fitted.fib"); + } } catch (itk::ExceptionObject e) { std::cout << e; return EXIT_FAILURE; } catch (std::exception e) { std::cout << e.what(); return EXIT_FAILURE; } catch (...) { std::cout << "ERROR!?!"; return EXIT_FAILURE; } return EXIT_SUCCESS; } diff --git a/Modules/DiffusionImaging/Quantification/cmdapps/QballReconstruction.cpp b/Modules/DiffusionImaging/Quantification/cmdapps/QballReconstruction.cpp index 761a18dcd7..22ffb67b83 100644 --- a/Modules/DiffusionImaging/Quantification/cmdapps/QballReconstruction.cpp +++ b/Modules/DiffusionImaging/Quantification/cmdapps/QballReconstruction.cpp @@ -1,263 +1,266 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include <mitkCoreObjectFactory.h> #include "mitkImage.h" #include "itkAnalyticalDiffusionQballReconstructionImageFilter.h" #include <boost/lexical_cast.hpp> #include "mitkCommandLineParser.h" #include <mitkIOUtil.h> #include <itksys/SystemTools.hxx> #include <mitkDiffusionPropertyHelper.h> #include <mitkITKImageImport.h> #include <mitkImageCast.h> #include <mitkProperties.h> #include <mitkIOUtil.h> +#include <mitkPreferenceListReaderOptionsFunctor.h> using namespace mitk; using namespace std; /** * Perform Q-ball reconstruction using a spherical harmonics basis */ int main(int argc, char* argv[]) { - mitkCommandLineParser parser; - parser.setArgumentPrefix("--", "-"); - parser.addArgument("input", "i", mitkCommandLineParser::InputFile, "Input file", "input raw dwi (.dwi or .fsl/.fslgz)", us::Any(), false); - parser.addArgument("outFile", "o", mitkCommandLineParser::OutputFile, "Output file", "output file", us::Any(), false); - parser.addArgument("shOrder", "sh", mitkCommandLineParser::Int, "Spherical harmonics order", "spherical harmonics order", 4, true); - parser.addArgument("b0Threshold", "t", mitkCommandLineParser::Int, "b0 threshold", "baseline image intensity threshold", 0, true); - parser.addArgument("lambda", "r", mitkCommandLineParser::Float, "Lambda", "ragularization factor lambda", 0.006, true); - parser.addArgument("csa", "csa", mitkCommandLineParser::Bool, "Constant solid angle consideration", "use constant solid angle consideration"); - parser.addArgument("outputCoeffs", "shc", mitkCommandLineParser::Bool, "Output coefficients", "output file containing the SH coefficients"); - parser.addArgument("mrtrix", "mb", mitkCommandLineParser::Bool, "MRtrix", "use MRtrix compatible spherical harmonics definition"); + mitkCommandLineParser parser; + parser.setArgumentPrefix("--", "-"); + parser.addArgument("input", "i", mitkCommandLineParser::InputFile, "Input file", "input raw dwi (.dwi or .nii/.nii.gz)", us::Any(), false); + parser.addArgument("outFile", "o", mitkCommandLineParser::OutputFile, "Output file", "output file", us::Any(), false); + parser.addArgument("shOrder", "sh", mitkCommandLineParser::Int, "Spherical harmonics order", "spherical harmonics order", 4, true); + parser.addArgument("b0Threshold", "t", mitkCommandLineParser::Int, "b0 threshold", "baseline image intensity threshold", 0, true); + parser.addArgument("lambda", "r", mitkCommandLineParser::Float, "Lambda", "ragularization factor lambda", 0.006, true); + parser.addArgument("csa", "csa", mitkCommandLineParser::Bool, "Constant solid angle consideration", "use constant solid angle consideration"); + parser.addArgument("outputCoeffs", "shc", mitkCommandLineParser::Bool, "Output coefficients", "output file containing the SH coefficients"); + parser.addArgument("mrtrix", "mb", mitkCommandLineParser::Bool, "MRtrix", "use MRtrix compatible spherical harmonics definition"); - parser.setCategory("Signal Modelling"); - parser.setTitle("Qball Reconstruction"); - parser.setDescription(""); - parser.setContributor("MIC"); + parser.setCategory("Signal Modelling"); + parser.setTitle("Qball Reconstruction"); + parser.setDescription(""); + parser.setContributor("MIC"); - map<string, us::Any> parsedArgs = parser.parseArguments(argc, argv); - if (parsedArgs.size()==0) - return EXIT_FAILURE; + map<string, us::Any> parsedArgs = parser.parseArguments(argc, argv); + if (parsedArgs.size()==0) + return EXIT_FAILURE; - std::string inFileName = us::any_cast<string>(parsedArgs["input"]); - std::string outfilename = us::any_cast<string>(parsedArgs["outFile"]); - outfilename = itksys::SystemTools::GetFilenamePath(outfilename)+"/"+itksys::SystemTools::GetFilenameWithoutExtension(outfilename); + std::string inFileName = us::any_cast<string>(parsedArgs["input"]); + std::string outfilename = us::any_cast<string>(parsedArgs["outFile"]); + outfilename = itksys::SystemTools::GetFilenamePath(outfilename)+"/"+itksys::SystemTools::GetFilenameWithoutExtension(outfilename); - int threshold = 0; - if (parsedArgs.count("b0Threshold")) - threshold = us::any_cast<int>(parsedArgs["b0Threshold"]); + int threshold = 0; + if (parsedArgs.count("b0Threshold")) + threshold = us::any_cast<int>(parsedArgs["b0Threshold"]); - int shOrder = 4; - if (parsedArgs.count("shOrder")) - shOrder = us::any_cast<int>(parsedArgs["shOrder"]); + int shOrder = 4; + if (parsedArgs.count("shOrder")) + shOrder = us::any_cast<int>(parsedArgs["shOrder"]); - float lambda = 0.006; - if (parsedArgs.count("lambda")) - lambda = us::any_cast<float>(parsedArgs["lambda"]); + float lambda = 0.006; + if (parsedArgs.count("lambda")) + lambda = us::any_cast<float>(parsedArgs["lambda"]); - int normalization = 0; - if (parsedArgs.count("csa") && us::any_cast<bool>(parsedArgs["csa"])) - normalization = 6; + int normalization = 0; + if (parsedArgs.count("csa") && us::any_cast<bool>(parsedArgs["csa"])) + normalization = 6; - bool outCoeffs = false; - if (parsedArgs.count("outputCoeffs")) - outCoeffs = us::any_cast<bool>(parsedArgs["outputCoeffs"]); + bool outCoeffs = false; + if (parsedArgs.count("outputCoeffs")) + outCoeffs = us::any_cast<bool>(parsedArgs["outputCoeffs"]); - bool mrTrix = false; - if (parsedArgs.count("mrtrix")) - mrTrix = us::any_cast<bool>(parsedArgs["mrtrix"]); + bool mrTrix = false; + if (parsedArgs.count("mrtrix")) + mrTrix = us::any_cast<bool>(parsedArgs["mrtrix"]); - try - { - std::vector<BaseData::Pointer> infile = mitk::IOUtil::Load(inFileName); - Image::Pointer dwi = dynamic_cast<Image*>(infile.at(0).GetPointer()); - mitk::DiffusionPropertyHelper propertyHelper(dwi); - propertyHelper.AverageRedundantGradients(0.001); - propertyHelper.InitializeImage(); - - mitk::OdfImage::Pointer image = mitk::OdfImage::New(); - mitk::Image::Pointer coeffsImage = mitk::Image::New(); - - std::cout << "SH order: " << shOrder; - std::cout << "lambda: " << lambda; - std::cout << "B0 threshold: " << threshold; - switch ( shOrder ) - { - case 4: - { - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,4,ODF_SAMPLING_SIZE> FilterType; - mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); - mitk::CastToItkImage(dwi, itkVectorImagePointer); - - FilterType::Pointer filter = FilterType::New(); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); - filter->SetThreshold( threshold ); - filter->SetLambda(lambda); - filter->SetUseMrtrixBasis(mrTrix); - if (normalization==0) - filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); - else - filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); - filter->Update(); - image->InitializeByItk( filter->GetOutput() ); - image->SetVolume( filter->GetOutput()->GetBufferPointer() ); - coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); - coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); - break; - } - case 6: - { - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,6,ODF_SAMPLING_SIZE> FilterType; - mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); - mitk::CastToItkImage(dwi, itkVectorImagePointer); - - FilterType::Pointer filter = FilterType::New(); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); - filter->SetThreshold( threshold ); - filter->SetLambda(lambda); - filter->SetUseMrtrixBasis(mrTrix); - if (normalization==0) - filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); - else - filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); - filter->Update(); - image->InitializeByItk( filter->GetOutput() ); - image->SetVolume( filter->GetOutput()->GetBufferPointer() ); - coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); - coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); - break; - } - case 8: - { - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,8,ODF_SAMPLING_SIZE> FilterType; - mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); - mitk::CastToItkImage(dwi, itkVectorImagePointer); + try + { + mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Diffusion Weighted Images"}, {}); + std::vector< mitk::BaseData::Pointer > infile = mitk::IOUtil::Load(inFileName, &functor); - FilterType::Pointer filter = FilterType::New(); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); - filter->SetThreshold( threshold ); - filter->SetLambda(lambda); - filter->SetUseMrtrixBasis(mrTrix); - if (normalization==0) - filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); - else - filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); - filter->Update(); - image->InitializeByItk( filter->GetOutput() ); - image->SetVolume( filter->GetOutput()->GetBufferPointer() ); - coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); - coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); - break; - } - case 10: - { - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,10,ODF_SAMPLING_SIZE> FilterType; - mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); - mitk::CastToItkImage(dwi, itkVectorImagePointer); + Image::Pointer dwi = dynamic_cast<Image*>(infile.at(0).GetPointer()); + mitk::DiffusionPropertyHelper propertyHelper(dwi); + propertyHelper.AverageRedundantGradients(0.001); + propertyHelper.InitializeImage(); - FilterType::Pointer filter = FilterType::New(); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); - filter->SetThreshold( threshold ); - filter->SetLambda(lambda); - filter->SetUseMrtrixBasis(mrTrix); - if (normalization==0) - filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); - else - filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); - filter->Update(); - image->InitializeByItk( filter->GetOutput() ); - image->SetVolume( filter->GetOutput()->GetBufferPointer() ); - coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); - coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); - break; - } - case 12: - { - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,12,ODF_SAMPLING_SIZE> FilterType; - mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); - mitk::CastToItkImage(dwi, itkVectorImagePointer); + mitk::OdfImage::Pointer image = mitk::OdfImage::New(); + mitk::Image::Pointer coeffsImage = mitk::Image::New(); - FilterType::Pointer filter = FilterType::New(); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); - filter->SetThreshold( threshold ); - filter->SetLambda(lambda); - if (normalization==0) - filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); - else - filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); - filter->Update(); - image->InitializeByItk( filter->GetOutput() ); - image->SetVolume( filter->GetOutput()->GetBufferPointer() ); - coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); - coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); - break; - } - default: - { - std::cout << "Supplied SH order not supported. Using default order of 4."; - typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,4,ODF_SAMPLING_SIZE> FilterType; - mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); - mitk::CastToItkImage(dwi, itkVectorImagePointer); - - FilterType::Pointer filter = FilterType::New(); - filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); - filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); - filter->SetThreshold( threshold ); - filter->SetLambda(lambda); - filter->SetUseMrtrixBasis(mrTrix); - if (normalization==0) - filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); - else - filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); - filter->Update(); - image->InitializeByItk( filter->GetOutput() ); - image->SetVolume( filter->GetOutput()->GetBufferPointer() ); - coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); - coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); - } - } + std::cout << "SH order: " << shOrder; + std::cout << "lambda: " << lambda; + std::cout << "B0 threshold: " << threshold; + switch ( shOrder ) + { + case 4: + { + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,4,ODF_SAMPLING_SIZE> FilterType; + mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); + mitk::CastToItkImage(dwi, itkVectorImagePointer); - std::string coeffout = outfilename; - coeffout += "_shcoeffs.nrrd"; + FilterType::Pointer filter = FilterType::New(); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); + filter->SetThreshold( threshold ); + filter->SetLambda(lambda); + filter->SetUseMrtrixBasis(mrTrix); + if (normalization==0) + filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); + else + filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); + filter->Update(); + image->InitializeByItk( filter->GetOutput() ); + image->SetVolume( filter->GetOutput()->GetBufferPointer() ); + coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); + coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); + break; + } + case 6: + { + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,6,ODF_SAMPLING_SIZE> FilterType; + mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); + mitk::CastToItkImage(dwi, itkVectorImagePointer); - outfilename += ".odf"; - mitk::IOUtil::Save(image, outfilename); + FilterType::Pointer filter = FilterType::New(); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); + filter->SetThreshold( threshold ); + filter->SetLambda(lambda); + filter->SetUseMrtrixBasis(mrTrix); + if (normalization==0) + filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); + else + filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); + filter->Update(); + image->InitializeByItk( filter->GetOutput() ); + image->SetVolume( filter->GetOutput()->GetBufferPointer() ); + coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); + coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); + break; + } + case 8: + { + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,8,ODF_SAMPLING_SIZE> FilterType; + mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); + mitk::CastToItkImage(dwi, itkVectorImagePointer); - if (outCoeffs) - mitk::IOUtil::Save(coeffsImage, coeffout); + FilterType::Pointer filter = FilterType::New(); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); + filter->SetThreshold( threshold ); + filter->SetLambda(lambda); + filter->SetUseMrtrixBasis(mrTrix); + if (normalization==0) + filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); + else + filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); + filter->Update(); + image->InitializeByItk( filter->GetOutput() ); + image->SetVolume( filter->GetOutput()->GetBufferPointer() ); + coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); + coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); + break; } - catch ( itk::ExceptionObject &err) + case 10: { - std::cout << "Exception: " << err; + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,10,ODF_SAMPLING_SIZE> FilterType; + mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); + mitk::CastToItkImage(dwi, itkVectorImagePointer); + + FilterType::Pointer filter = FilterType::New(); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); + filter->SetThreshold( threshold ); + filter->SetLambda(lambda); + filter->SetUseMrtrixBasis(mrTrix); + if (normalization==0) + filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); + else + filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); + filter->Update(); + image->InitializeByItk( filter->GetOutput() ); + image->SetVolume( filter->GetOutput()->GetBufferPointer() ); + coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); + coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); + break; } - catch ( std::exception err) + case 12: { - std::cout << "Exception: " << err.what(); + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,12,ODF_SAMPLING_SIZE> FilterType; + mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); + mitk::CastToItkImage(dwi, itkVectorImagePointer); + + FilterType::Pointer filter = FilterType::New(); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); + filter->SetThreshold( threshold ); + filter->SetLambda(lambda); + if (normalization==0) + filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); + else + filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); + filter->Update(); + image->InitializeByItk( filter->GetOutput() ); + image->SetVolume( filter->GetOutput()->GetBufferPointer() ); + coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); + coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); + break; } - catch ( ... ) + default: { - std::cout << "Exception!"; + std::cout << "Supplied SH order not supported. Using default order of 4."; + typedef itk::AnalyticalDiffusionQballReconstructionImageFilter<short,short,float,4,ODF_SAMPLING_SIZE> FilterType; + mitk::DiffusionPropertyHelper::ImageType::Pointer itkVectorImagePointer = mitk::DiffusionPropertyHelper::ImageType::New(); + mitk::CastToItkImage(dwi, itkVectorImagePointer); + + FilterType::Pointer filter = FilterType::New(); + filter->SetBValue(mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi)); + filter->SetGradientImage( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi), itkVectorImagePointer ); + filter->SetThreshold( threshold ); + filter->SetLambda(lambda); + filter->SetUseMrtrixBasis(mrTrix); + if (normalization==0) + filter->SetNormalizationMethod(FilterType::QBAR_STANDARD); + else + filter->SetNormalizationMethod(FilterType::QBAR_SOLID_ANGLE); + filter->Update(); + image->InitializeByItk( filter->GetOutput() ); + image->SetVolume( filter->GetOutput()->GetBufferPointer() ); + coeffsImage->InitializeByItk( filter->GetCoefficientImage().GetPointer() ); + coeffsImage->SetVolume( filter->GetCoefficientImage()->GetBufferPointer() ); } - return EXIT_SUCCESS; + } + + std::string coeffout = outfilename; + coeffout += "_shcoeffs.nrrd"; + + outfilename += ".odf"; + mitk::IOUtil::Save(image, outfilename); + + if (outCoeffs) + mitk::IOUtil::Save(coeffsImage, coeffout); + } + catch ( itk::ExceptionObject &err) + { + std::cout << "Exception: " << err; + } + catch ( std::exception err) + { + std::cout << "Exception: " << err.what(); + } + catch ( ... ) + { + std::cout << "Exception!"; + } + return EXIT_SUCCESS; }