diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp index 24917a8991..6c9dfe8ba8 100755 --- a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/FitFibersToImage.cpp @@ -1,554 +1,752 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include using namespace std; typedef itksys::SystemTools ist; -typedef itk::Image ItkUcharImgType; -typedef std::tuple< ItkUcharImgType::Pointer, std::string > MaskType; -typedef mitk::DiffusionPropertyHelper DPH; -typedef itk::Point PointType; typedef itk::Point PointType4; -typedef mitk::StickModel<> ModelType; -typedef mitk::BallModel<> BallModelType; typedef itk::Image< float, 4 > PeakImgType; +const float UPSCALE = 1000.0; -vnl_vector_fixed GetCLosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed fiber_dir, bool flip_x, bool flip_y, bool flip_z ) +vnl_vector_fixed GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed fiber_dir, int& id ) { int m_NumDirs = peak_image->GetLargestPossibleRegion().GetSize()[3]/3; vnl_vector_fixed out_dir; out_dir.fill(0); float angle = 0.8; for (int i=0; i dir; idx[3] = i*3; dir[0] = peak_image->GetPixel(idx); idx[3] += 1; dir[1] = peak_image->GetPixel(idx); idx[3] += 1; dir[2] = peak_image->GetPixel(idx); float mag = dir.magnitude(); if (magangle) { angle = fabs(a); if (a<0) out_dir = -dir; else out_dir = dir; out_dir *= mag; - out_dir *= 1000; + out_dir *= UPSCALE; // for the fit it's better if the magnitude is larger since many fibers pass each voxel and otherwise the individual contributions would be very small + id = i; } } return out_dir; } -std::vector SolveLinear( std::string , std::vector< mitk::FiberBundle::Pointer > inputTractogram, mitk::Image::Pointer inputImage, bool flip_x=false, bool flip_y=false, bool flip_z=false ) +class VnlCostFunction : public vnl_cost_function +{ +public: + + vnl_sparse_matrix_linear_system< double >* S; + vnl_sparse_matrix< double > m_A; + vnl_vector< double > m_b; + itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen; + + void SetProblem(vnl_sparse_matrix< double >& A, vnl_vector& b) + { + S = new vnl_sparse_matrix_linear_system(A, b); + m_A = A; + m_b = b; + randGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); + } + + VnlCostFunction(const int NumVars) : vnl_cost_function(NumVars) + { + } + + double f_itk(vnl_vector const &x) const + { + double min = x.min_value(); + if( min<0 ) + return 10000 * -min; + + double cost = S->get_rms_error(x); + double regu = x.squared_magnitude()/x.size(); + cost += regu; + return cost; + } + + double f(vnl_vector const &x) + { + double cost = S->get_rms_error(x); + + double regu = x.squared_magnitude()/x.size(); + +// unsigned int norm = 0; +// for (unsigned int i=0; i const &x, vnl_vector &dx) +// { +// fdgradf(x, dx); +// } + + void gradf(vnl_vector const &x, vnl_vector &dx) + { + dx.fill(0.0); + double mag = x.magnitude(); + unsigned int N = m_b.size(); + + vnl_vector d; d.set_size(N); + S->multiply(x,d); + d -= m_b; + + vnl_vector numerator; numerator.set_size(x.size()); + S->transpose_multiply(d, dx); + dx *= 2.0/N; + + if (mag>0) + dx += x/mag; + } + +}; + +class ItkCostFunction : public itk::SingleValuedCostFunction +{ +public: + /** Standard class typedefs. */ + typedef ItkCostFunction Self; + typedef SingleValuedCostFunction Superclass; + typedef itk::SmartPointer Pointer; + typedef itk::SmartPointer ConstPointer; + + /** Method for creation through the object factory. */ + itkNewMacro(Self) + + /** Run-time type information (and related methods). */ + itkTypeMacro(ItkCostFunction, SingleValuedCostfunction) + + void SetVnlCostFunction(VnlCostFunction& cf) { m_VnlCostFunction = cf; } + unsigned int GetNumberOfParameters(void) const { return m_VnlCostFunction.get_number_of_unknowns(); } // itk::CostFunction + + MeasureType GetValue(const ParametersType & parameters) const + { + return m_VnlCostFunction.f_itk(parameters); + } + + void GetDerivative(const ParametersType &, DerivativeType & ) const { + throw itk::ExceptionObject( __FILE__, __LINE__, "No derivative is available for this cost function."); + } + +protected: + ItkCostFunction(){} + ~ItkCostFunction(){} + + VnlCostFunction m_VnlCostFunction = VnlCostFunction(1); + +private: + ItkCostFunction(const Self &); //purposely not implemented + void operator = (const Self &); //purposely not implemented +}; + +void OptimizeItk(VnlCostFunction& cf, vnl_vector& x, int iter, double lb, double ub) +{ + ItkCostFunction::ParametersType p; p.SetData(x.data_block(), x.size()); + + ItkCostFunction::Pointer itk_cf = ItkCostFunction::New(); + itk_cf->SetVnlCostFunction(cf); + + std::pair< double, double > bounds; bounds.first = lb; bounds.second = ub; + MITK_INFO << bounds; + +// itk::Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen = itk::Statistics::MersenneTwisterRandomVariateGenerator::New(); +// itk::OnePlusOneEvolutionaryOptimizer::Pointer opt = itk::OnePlusOneEvolutionaryOptimizer::New(); +// opt->SetCostFunction(itk_cf); +// opt->MinimizeOn(); +// opt->SetInitialPosition(p); +// opt->SetNormalVariateGenerator(randGen); +// opt->SetMaximumIteration(iter); +// opt->StartOptimization(); + +// itk::ParticleSwarmOptimizer::Pointer opt = itk::ParticleSwarmOptimizer::New(); +// opt->SetCostFunction(itk_cf); +// opt->SetInitialPosition(p); +// opt->SetParameterBounds(bounds, x.size()); +// opt->SetMaximalNumberOfIterations(iter); +// opt->SetNumberOfParticles(100); +// opt->SetParametersConvergenceTolerance(0.01, x.size()); +// opt->SetNumberOfGenerationsWithMinimalImprovement(3); +// opt->StartOptimization(); + +// itk::GradientDescentOptimizer::Pointer opt = itk::GradientDescentOptimizer::New(); +// opt->SetCostFunction(itk_cf); +// opt->SetInitialPosition(p); +// opt->SetMinimize(true); +// opt->SetNumberOfIterations(iter); +// opt->StartOptimization(); + + itk::SPSAOptimizer::Pointer opt = itk::SPSAOptimizer::New(); + opt->SetCostFunction(itk_cf); + opt->SetInitialPosition(p); + opt->SetMinimize(true); + opt->SetNumberOfPerturbations(iter); + opt->StartOptimization(); + + x.copy_in(opt->GetCurrentPosition().data_block()); + for (unsigned int i=0; iGetCurrentPosition()[i]; +// MITK_INFO << "Cost: " << opt->GetCurrentCost(); + +} + +std::vector FitFibers( std::string , std::vector< mitk::FiberBundle::Pointer > input_tracts, mitk::Image::Pointer inputImage, bool single_fiber_fit, int max_iter, float g_tol, bool lb ) { typedef mitk::ImageToItk< PeakImgType > CasterType; CasterType::Pointer caster = CasterType::New(); caster->SetInput(inputImage); caster->Update(); PeakImgType::Pointer itkImage = caster->GetOutput(); unsigned int* image_size = inputImage->GetDimensions(); int sz_x = image_size[0]; int sz_y = image_size[1]; int sz_z = image_size[2]; + int sz_peaks = image_size[3]; int num_voxels = sz_x*sz_y*sz_z; - unsigned int num_unknowns = inputTractogram.size(); //inputTractogram->GetNumFibers(); - unsigned int number_of_residuals = num_voxels * 3; + unsigned int num_unknowns = input_tracts.size(); //inputTractogram->GetNumFibers(); + if (single_fiber_fit) + { + num_unknowns = 0; + for (unsigned int bundle=0; bundleGetNumFibers(); + } + + unsigned int number_of_residuals = num_voxels * sz_peaks; // create linear system MITK_INFO << "Num. unknowns: " << num_unknowns; MITK_INFO << "Num. residuals: " << number_of_residuals; - vigra::MultiArray<2, double> test_m(vigra::Shape2(number_of_residuals, 1000000), 0.0); - MITK_INFO << test_m.height(); + MITK_INFO << "Creating system ..."; + vnl_sparse_matrix< double > A; A.set_size(number_of_residuals, num_unknowns); + vnl_vector b; b.set_size(number_of_residuals); b.fill(0.0); - MITK_INFO << "Creating matrices ..."; - vigra::MultiArray<2, double> A(vigra::Shape2(number_of_residuals, num_unknowns), 0.0); - vigra::MultiArray<2, double> b(vigra::Shape2(number_of_residuals, 1), 0.0); - vigra::MultiArray<2, double> x(vigra::Shape2(num_unknowns, 1), 1.0); + float max_peak_mag = 0; + int max_peak_idx = -1; - MITK_INFO << "Filling matrices ..."; - for (unsigned int bundle=0; bundle polydata = inputTractogram.at(bundle)->GetFiberPolyData(); + vtkSmartPointer polydata = input_tracts.at(bundle)->GetFiberPolyData(); - for (int i=0; iGetNumFibers(); ++i) + for (int i=0; iGetNumFibers(); ++i) { vtkCell* cell = polydata->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); if (numPoints<2) MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; for (int j=0; jGetPoint(j); PointType4 p; p[0]=p1[0]; p[1]=p1[1]; p[2]=p1[2]; p[3]=0; itk::Index<4> idx4; itkImage->TransformPhysicalPointToIndex(p, idx4); if (!itkImage->GetLargestPossibleRegion().IsInside(idx4)) continue; double* p2 = points->GetPoint(j+1); vnl_vector_fixed fiber_dir; fiber_dir[0] = p[0]-p2[0]; fiber_dir[1] = p[1]-p2[1]; fiber_dir[2] = p[2]-p2[2]; fiber_dir.normalize(); - vnl_vector_fixed odf_peak = GetCLosestPeak(idx4, itkImage, fiber_dir, flip_x, flip_y, flip_z); - if (odf_peak.magnitude()<0.001) + int peak_id = -1; + vnl_vector_fixed odf_peak = GetClosestPeak(idx4, itkImage, fiber_dir, peak_id); + float peak_mag = odf_peak.magnitude(); + if (peak_id<0) continue; int x = idx4[0]; int y = idx4[1]; int z = idx4[2]; - unsigned int linear_index = x + sz_x*y + sz_x*sz_y*z; + unsigned int linear_index = sz_peaks*(x + sz_x*y + sz_x*sz_y*z); - for (unsigned int k=0; k<3; ++k) + if (peak_mag>max_peak_mag) { - b(3*linear_index + k, 0) = (double)odf_peak[k]; - A(3*linear_index + k, bundle) = A(3*linear_index + k, bundle) + (double)fiber_dir[k]; + max_peak_mag = peak_mag; + max_peak_idx = linear_index + 3*peak_id; } - } - } - } - - MITK_INFO << "Solving linear system"; - vigra::linalg::nonnegativeLeastSquares(A, b, x); - - std::vector weights; - for (unsigned int i=0; i SolveEvo(mitk::FiberBundle* inputTractogram, mitk::Image::Pointer inputImage, ModelType signalModel, BallModelType ballModel, float start_weight, int num_iterations=1000) -{ - std::vector out_weights; - DPH::ImageType::Pointer itkImage = DPH::GetItkVectorImage(inputImage); - itk::VectorImage< double, 3>::Pointer simulatedImage = itk::VectorImage< double, 3>::New(); - simulatedImage->SetSpacing(itkImage->GetSpacing()); - simulatedImage->SetOrigin(itkImage->GetOrigin()); - simulatedImage->SetDirection(itkImage->GetDirection()); - simulatedImage->SetRegions(itkImage->GetLargestPossibleRegion()); - simulatedImage->SetVectorLength(itkImage->GetVectorLength()); - simulatedImage->Allocate(); - DPH::ImageType::PixelType zero_signal; - zero_signal.SetSize(itkImage->GetVectorLength()); - zero_signal.Fill(0); - simulatedImage->FillBuffer(zero_signal); - - MITK_INFO << "start_weight: " << start_weight; - double step = start_weight/10; - - vtkSmartPointer polydata = inputTractogram->GetFiberPolyData(); - - unsigned int* image_size = inputImage->GetDimensions(); - int sz_x = image_size[0]; - int sz_y = image_size[1]; - - MITK_INFO << "INITIALIZING"; - std::vector< std::vector< ModelType::PixelType > > fiber_model_signals; - std::vector< std::vector< DPH::ImageType::IndexType > > fiber_image_indices; - std::map< unsigned int, std::vector > image_index_to_fiber_indices; - - std::vector< int > fiber_indices; - int f = 0; - for (int i=0; iGetNumFibers(); ++i) - { - vtkCell* cell = polydata->GetCell(i); - int numPoints = cell->GetNumberOfPoints(); - vtkPoints* points = cell->GetPoints(); - -// if (numPoints<2) -// continue; - - std::vector< ModelType::PixelType > model_signals; - std::vector< DPH::ImageType::IndexType > image_indices; - - for (int j=0; jGetPoint(j); - PointType p; - p[0]=p1[0]; - p[1]=p1[1]; - p[2]=p1[2]; - - DPH::ImageType::IndexType idx; - itkImage->TransformPhysicalPointToIndex(p, idx); - if (!itkImage->GetLargestPossibleRegion().IsInside(idx)) - continue; - - double* p2 = points->GetPoint(j+1); - ModelType::GradientType d; - d[0] = p[0]-p2[0]; - d[1] = p[1]-p2[1]; - d[2] = p[2]-p2[2]; - signalModel.SetFiberDirection(d); - - ModelType::PixelType sig = signalModel.SimulateMeasurement(); - simulatedImage->SetPixel(idx, simulatedImage->GetPixel(idx) + start_weight * sig); - - model_signals.push_back(step*sig); - image_indices.push_back(idx); - unsigned int linear_index = idx[0] + sz_x*idx[1] + sz_x*sz_y*idx[2]; - - if (image_index_to_fiber_indices.count(linear_index)==0) - { - image_index_to_fiber_indices[linear_index] = {i}; - } - else - { - std::vector< int > index_fiber_indices = image_index_to_fiber_indices[linear_index]; - if(std::find(index_fiber_indices.begin(), index_fiber_indices.end(), i) == index_fiber_indices.end()) + if (peak_mag model_signals = fiber_model_signals.at(f); - std::vector< DPH::ImageType::IndexType > image_indices = fiber_image_indices.at(f); - - double E_ext_old = 0; - double E_ext_new = 0; - double E_int_old = 0; - double E_int_new = 0; - - int add = std::rand()%2; -// add = 1; - int use_ball = std::rand()%2; - use_ball = 0; - - if (add==0 && use_ball==0 && out_weights[f]GetPixel(idx); - ModelType::PixelType simulated_val = simulatedImage->GetPixel(idx); - ModelType::PixelType model_val = model_signals.at(c); // value of fiber model at currnet fiber position - if (use_ball==1) - model_val = ballSignal; - - // EXTERNAL ENERGY - for (unsigned int g=0; g index_fiber_indices = image_index_to_fiber_indices[linear_index]; - - if (index_fiber_indices.size()>1) - { - float mean_weight = 0; - for (auto neighbor_index : index_fiber_indices) - { - if (neighbor_index!=f) - mean_weight += out_weights[neighbor_index]; - } - mean_weight /= (index_fiber_indices.size()-1); - E_int_old += fabs(mean_weight-old_weight); - E_int_new += fabs(mean_weight-new_weight); - } } - E_ext_old /= image_indices.size(); - E_ext_new /= image_indices.size(); - E_int_old /= image_indices.size(); - E_int_new /= image_indices.size(); - -// MITK_INFO << "EXT: " << E_ext_old << " --> " << E_ext_new; -// MITK_INFO << "INT: " << E_int_old-E_int_new; - - double R = exp( (E_ext_old-E_ext_new)/T + 0*(E_int_old-E_int_new)/T ); -// MITK_INFO << R; - - float p = static_cast (rand()) / static_cast (RAND_MAX); - - p = E_ext_new; - R = E_ext_old; - - MITK_INFO << add << " - " << i; - if (pGetPixel(idx); - ModelType::PixelType dVal = model_signals.at(c); - if (use_ball==1) - dVal = ballSignal; + ++fiber_count; + } + } - if (add==1) - { - sVal += dVal; - } - else - { - sVal -= dVal; - } + vnl_vector_fixed max_corr_fiber_dir; max_corr_fiber_dir.fill(0.0); + vnl_vector_fixed min_corr_fiber_dir; min_corr_fiber_dir.fill(0.0); + for (unsigned int i=0; iSetPixel(idx, sVal); - ++c; - } + min_corr_fiber_dir[0] += A.get(min_peak_idx, i); + min_corr_fiber_dir[1] += A.get(min_peak_idx+1, i); + min_corr_fiber_dir[2] += A.get(min_peak_idx+2, i); + } - out_weights[f] = new_weight; - } + float upper_bound = max_peak_mag/max_corr_fiber_dir.magnitude(); + float lower_bound = min_peak_mag/min_corr_fiber_dir.magnitude(); + + if (!lb || lower_bound>=upper_bound) + lower_bound = 0; + + MITK_INFO << "Lower bound: " << lower_bound; + MITK_INFO << "Upper bound: " << upper_bound; + + itk::TimeProbe clock; + clock.Start(); + + MITK_INFO << "Fitting fibers"; + VnlCostFunction cost(num_unknowns); + cost.SetProblem(A, b); + + MITK_INFO << g_tol << " " << max_iter; + vnl_vector x; x.set_size(num_unknowns); x.fill( (upper_bound-lower_bound)/2 ); +// OptimizeItk(cost, x, max_iter, lower_bound, upper_bound); + + vnl_lbfgsb minimizer(cost); + vnl_vector l; l.set_size(num_unknowns); l.fill(lower_bound); + vnl_vector u; u.set_size(num_unknowns); u.fill(upper_bound); + vnl_vector bound_selection; bound_selection.set_size(num_unknowns); bound_selection.fill(2); + minimizer.set_bound_selection(bound_selection); + minimizer.set_lower_bound(l); + minimizer.set_upper_bound(u); + minimizer.set_trace(true); + minimizer.set_projected_gradient_tolerance(g_tol); + if (max_iter>0) + minimizer.set_max_function_evals(max_iter); + minimizer.minimize(x); + MITK_INFO << "Residual error: " << minimizer.get_end_error(); + MITK_INFO << "NumEvals: " << minimizer.get_num_evaluations(); + MITK_INFO << "NumIterations: " << minimizer.get_num_iterations(); + +// vnl_sparse_matrix_linear_system S(A, b); +// vnl_lsqr linear_solver( S ); +// linear_solver.set_max_iterations(max_iter); +// linear_solver.minimize(x); + + clock.Stop(); + int h = clock.GetTotal()/3600; + int m = ((int)clock.GetTotal()%3600)/60; + int s = (int)clock.GetTotal()%60; + MITK_INFO << "Optimization took " << h << "h, " << m << "m and " << s << "s"; - } -// MITK_INFO << "Accepted: " << (float)accepted/fiber_indices.size(); + std::vector weights; + float max_w = 0; + for (unsigned int i=0; imax_w) + max_w = x[i]; + weights.push_back(x[i]); } + MITK_INFO << "Max w: " << max_w; - typedef itk::ImageFileWriter< itk::VectorImage< double, 3> > WriterType; - WriterType::Pointer writer = WriterType::New(); - writer->SetFileName("/home/neher/Projects/TractPlausibility/model_signal.nrrd"); - writer->SetInput(simulatedImage); - writer->Update(); - - return out_weights; + return weights; } - std::vector< string > get_file_list(const std::string& path) { std::vector< string > file_list; - itk::Directory::Pointer dir = itk::Directory::New(); if (dir->Load(path.c_str())) { int n = dir->GetNumberOfFiles(); for (int r = 0; r < n; r++) { const char *filename = dir->GetFile(r); std::string ext = ist::GetFilenameExtension(filename); if (ext==".fib" || ext==".trk") file_list.push_back(path + '/' + filename); } } - return file_list; } int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("Fit Fibers To Image"); parser.setCategory("Fiber Tracking Evaluation"); parser.setDescription(""); parser.setContributor("MIC"); parser.setArgumentPrefix("--", "-"); - parser.addArgument("input_tractograms", "i1", mitkCommandLineParser::StringList, "Input tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); - parser.addArgument("input_peaks", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); - parser.addArgument("out", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); + parser.addArgument("", "i1", mitkCommandLineParser::StringList, "Input tractograms:", "input tractograms (.fib, vtk ascii file format)", us::Any(), false); + parser.addArgument("", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); + parser.addArgument("", "it", mitkCommandLineParser::Int, "", ""); + parser.addArgument("", "s", mitkCommandLineParser::Bool, "", ""); + parser.addArgument("", "lb", mitkCommandLineParser::Bool, "", ""); + parser.addArgument("", "g", mitkCommandLineParser::Float, "", ""); + parser.addArgument("", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; - mitkCommandLineParser::StringContainerType fib_files = us::any_cast(parsedArgs["input_tractograms"]); - string dwiFile = us::any_cast(parsedArgs["input_peaks"]); - string outRoot = us::any_cast(parsedArgs["out"]); + mitkCommandLineParser::StringContainerType fib_files = us::any_cast(parsedArgs["i1"]); + string dwiFile = us::any_cast(parsedArgs["i2"]); + string outRoot = us::any_cast(parsedArgs["o"]); + + bool single_fib = false; + if (parsedArgs.count("s")) + single_fib = us::any_cast(parsedArgs["s"]); + + int max_iter = 0; + if (parsedArgs.count("it")) + max_iter = us::any_cast(parsedArgs["it"]); + + float g_tol = 1e-5; + if (parsedArgs.count("g")) + g_tol = us::any_cast(parsedArgs["g"]); + + bool lb = false; + if (parsedArgs.count("lb")) + lb = us::any_cast(parsedArgs["lb"]); try { - std::vector< mitk::FiberBundle::Pointer > bundles; + std::vector< mitk::FiberBundle::Pointer > input_tracts; - mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Diffusion Weighted Images", "Fiberbundles"}, {}); - mitk::Image::Pointer inputImage = dynamic_cast(mitk::IOUtil::Load(dwiFile, &functor)[0].GetPointer()); + mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Peak Image", "Fiberbundles"}, {}); + mitk::Image::Pointer inputImage = dynamic_cast(mitk::IOUtil::Load(dwiFile, &functor)[0].GetPointer()); float minSpacing = 1; if(inputImage->GetGeometry()->GetSpacing()[0]GetGeometry()->GetSpacing()[1] && inputImage->GetGeometry()->GetSpacing()[0]GetGeometry()->GetSpacing()[2]) minSpacing = inputImage->GetGeometry()->GetSpacing()[0]; else if (inputImage->GetGeometry()->GetSpacing()[1] < inputImage->GetGeometry()->GetSpacing()[2]) minSpacing = inputImage->GetGeometry()->GetSpacing()[1]; else minSpacing = inputImage->GetGeometry()->GetSpacing()[2]; std::vector< std::string > fib_names; for (auto item : fib_files) { if ( ist::FileIsDirectory(item) ) { for ( auto fibFile : get_file_list(item) ) { mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::Load(fibFile)[0].GetPointer()); if (inputTractogram.IsNull()) continue; - inputTractogram->ResampleLinear(minSpacing/4); - bundles.push_back(inputTractogram); + inputTractogram->ResampleLinear(minSpacing/10); + input_tracts.push_back(inputTractogram); fib_names.push_back(fibFile); } } else { mitk::FiberBundle::Pointer inputTractogram = dynamic_cast(mitk::IOUtil::Load(item)[0].GetPointer()); if (inputTractogram.IsNull()) continue; - inputTractogram->ResampleLinear(minSpacing/4); - bundles.push_back(inputTractogram); + inputTractogram->ResampleLinear(minSpacing/10); + input_tracts.push_back(inputTractogram); fib_names.push_back(item); } } - std::vector weights = SolveLinear(outRoot, bundles, inputImage); + std::vector weights = FitFibers(outRoot, input_tracts, inputImage, single_fib, max_iter, g_tol, lb); - for (unsigned int i=0; iSetFiberWeight(i, weights.at(i)); - bundle->SetFiberWeights(weights.at(i)); - mitk::IOUtil::Save(bundle, outRoot + name + "_fitted.fib"); + unsigned int fiber_count = 0; + + for (unsigned int bundle=0; bundleGetNumFibers(); i++) + { + fib->SetFiberWeight(i, weights.at(fiber_count)); + ++fiber_count; + } + + std::string name = fib_names.at(bundle); + name = ist::GetFilenameWithoutExtension(name); + mitk::IOUtil::Save(fib, outRoot + name + "_fitted.fib"); + } + } + else + { + for (unsigned int i=0; iSetFiberWeights(weights.at(i)); + mitk::IOUtil::Save(bundle, outRoot + name + "_fitted.fib"); + } + } + + + // OUTPUT IMAGES + MITK_INFO << "Generating output images ..."; + typedef mitk::ImageToItk< PeakImgType > CasterType; + CasterType::Pointer caster = CasterType::New(); + caster->SetInput(inputImage); + caster->Update(); + PeakImgType::Pointer itkImage = caster->GetOutput(); + + itk::ImageDuplicator< PeakImgType >::Pointer duplicator = itk::ImageDuplicator< PeakImgType >::New(); + duplicator->SetInputImage(itkImage); + duplicator->Update(); + PeakImgType::Pointer unexplained_image = duplicator->GetOutput(); + + duplicator->SetInputImage(unexplained_image); + duplicator->Update(); + PeakImgType::Pointer residual_image = duplicator->GetOutput(); + + duplicator->SetInputImage(residual_image); + duplicator->Update(); + PeakImgType::Pointer explained_image = duplicator->GetOutput(); + explained_image->FillBuffer(0.0); + + for (unsigned int bundle=0; bundle polydata = input_tracts.at(bundle)->GetFiberPolyData(); + + for (int i=0; iGetNumFibers(); ++i) + { + vtkCell* cell = polydata->GetCell(i); + int numPoints = cell->GetNumberOfPoints(); + vtkPoints* points = cell->GetPoints(); + + if (numPoints<2) + MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; + + float w = input_tracts.at(bundle)->GetFiberWeight(i)/UPSCALE; + + for (int j=0; jGetPoint(j); + PointType4 p; + p[0]=p1[0]; + p[1]=p1[1]; + p[2]=p1[2]; + p[3]=0; + + itk::Index<4> idx4; + itkImage->TransformPhysicalPointToIndex(p, idx4); + if (!itkImage->GetLargestPossibleRegion().IsInside(idx4)) + continue; + + double* p2 = points->GetPoint(j+1); + vnl_vector_fixed fiber_dir; + fiber_dir[0] = p[0]-p2[0]; + fiber_dir[1] = p[1]-p2[1]; + fiber_dir[2] = p[2]-p2[2]; + fiber_dir.normalize(); + + int peak_id = -1; + GetClosestPeak(idx4, itkImage, fiber_dir, peak_id); + if (peak_id<0) + continue; + + vnl_vector_fixed unexplained_dir; + vnl_vector_fixed explained_dir; + vnl_vector_fixed res_dir; + vnl_vector_fixed peak_dir; + + idx4[3] = peak_id*3; + unexplained_dir[0] = unexplained_image->GetPixel(idx4); + explained_dir[0] = explained_image->GetPixel(idx4); + peak_dir[0] = itkImage->GetPixel(idx4); + res_dir[0] = residual_image->GetPixel(idx4); + + idx4[3] += 1; + unexplained_dir[1] = unexplained_image->GetPixel(idx4); + explained_dir[1] = explained_image->GetPixel(idx4); + peak_dir[1] = itkImage->GetPixel(idx4); + res_dir[1] = residual_image->GetPixel(idx4); + + idx4[3] += 1; + unexplained_dir[2] = unexplained_image->GetPixel(idx4); + explained_dir[2] = explained_image->GetPixel(idx4); + peak_dir[2] = itkImage->GetPixel(idx4); + res_dir[2] = residual_image->GetPixel(idx4); + + if (dot_product(peak_dir, fiber_dir)<0) + fiber_dir *= -1; + fiber_dir *= w; + + idx4[3] = peak_id*3; + residual_image->SetPixel(idx4, res_dir[0] - fiber_dir[0]); + + idx4[3] += 1; + residual_image->SetPixel(idx4, res_dir[1] - fiber_dir[1]); + + idx4[3] += 1; + residual_image->SetPixel(idx4, res_dir[2] - fiber_dir[2]); + + + if ( fabs(unexplained_dir[0]) - fabs(fiber_dir[0]) < 0 ) // did we "overexplain" stuff? + fiber_dir = unexplained_dir; + + idx4[3] = peak_id*3; + unexplained_image->SetPixel(idx4, unexplained_dir[0] - fiber_dir[0]); + explained_image->SetPixel(idx4, explained_dir[0] + fiber_dir[0]); + + idx4[3] += 1; + unexplained_image->SetPixel(idx4, unexplained_dir[1] - fiber_dir[1]); + explained_image->SetPixel(idx4, explained_dir[1] + fiber_dir[1]); + + idx4[3] += 1; + unexplained_image->SetPixel(idx4, unexplained_dir[2] - fiber_dir[2]); + explained_image->SetPixel(idx4, explained_dir[2] + fiber_dir[2]); + } + + } } + + itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); + writer->SetInput(unexplained_image); + writer->SetFileName(outRoot + "unexplained_image.nrrd"); + writer->Update(); + + writer->SetInput(explained_image); + writer->SetFileName(outRoot + "explained_image.nrrd"); + writer->Update(); + + writer->SetInput(residual_image); + writer->SetFileName(outRoot + "residual_image.nrrd"); + writer->Update(); } catch (itk::ExceptionObject e) { std::cout << e; return EXIT_FAILURE; } catch (std::exception e) { std::cout << e.what(); return EXIT_FAILURE; } catch (...) { std::cout << "ERROR!?!"; return EXIT_FAILURE; } return EXIT_SUCCESS; }