diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h index 4cc633bfbe..d784f3b2b7 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.h @@ -1,253 +1,252 @@ #ifndef __itkFitFibersToImageFilter_h__ #define __itkFitFibersToImageFilter_h__ // MITK #include #include #include #include #include #include #include #include #include namespace itk{ /** -* \brief */ +* \brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). */ class FitFibersToImageFilter : public ImageSource< mitk::PeakImage::ItkPeakImageType > { public: typedef FitFibersToImageFilter Self; typedef ProcessObject Superclass; typedef SmartPointer< Self > Pointer; typedef SmartPointer< const Self > ConstPointer; typedef itk::Point PointType4; typedef mitk::PeakImage::ItkPeakImageType PeakImgType; -// typedef std::vector< OutputImageType::Pointer > OutputImageContainerType; itkFactorylessNewMacro(Self) itkCloneMacro(Self) itkTypeMacro( FitFibersToImageFilter, ImageSource ) itkSetMacro( PeakImage, PeakImgType::Pointer) itkGetMacro( PeakImage, PeakImgType::Pointer) itkSetMacro( FitIndividualFibers, bool) itkGetMacro( FitIndividualFibers, bool) itkSetMacro( GradientTolerance, double) itkGetMacro( GradientTolerance, double) itkSetMacro( Lambda, double) itkGetMacro( Lambda, double) itkSetMacro( MaxIterations, int) itkGetMacro( MaxIterations, int) itkSetMacro( FiberSampling, float) itkGetMacro( FiberSampling, float) itkSetMacro( FilterOutliers, bool) itkGetMacro( FilterOutliers, bool) itkSetMacro( Verbose, bool) itkGetMacro( Verbose, bool) itkSetMacro( DeepCopy, bool) itkGetMacro( DeepCopy, bool) itkGetMacro( Weights, vnl_vector) itkGetMacro( FittedImage, PeakImgType::Pointer) itkGetMacro( ResidualImage, PeakImgType::Pointer) itkGetMacro( OverexplainedImage, PeakImgType::Pointer) itkGetMacro( UnderexplainedImage, PeakImgType::Pointer) itkGetMacro( Coverage, double) itkGetMacro( Overshoot, double) itkGetMacro( MeanWeight, double) itkGetMacro( MedianWeight, double) itkGetMacro( MinWeight, double) itkGetMacro( MaxWeight, double) void SetTractograms(const std::vector &tractograms); void GenerateData() override; std::vector GetTractograms() const; protected: FitFibersToImageFilter(); virtual ~FitFibersToImageFilter(); vnl_vector_fixed GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer m_PeakImage , vnl_vector_fixed fiber_dir, int& id, double& w ); std::vector< mitk::FiberBundle::Pointer > m_Tractograms; PeakImgType::Pointer m_PeakImage; bool m_FitIndividualFibers; double m_GradientTolerance; double m_Lambda; int m_MaxIterations; float m_FiberSampling; double m_Coverage; double m_Overshoot; bool m_FilterOutliers; double m_MeanWeight; double m_MedianWeight; double m_MinWeight; double m_MaxWeight; bool m_Verbose; bool m_DeepCopy; // output vnl_vector m_Weights; PeakImgType::Pointer m_UnderexplainedImage; PeakImgType::Pointer m_OverexplainedImage; PeakImgType::Pointer m_ResidualImage; PeakImgType::Pointer m_FittedImage; }; } class VnlCostFunction : public vnl_cost_function { public: vnl_sparse_matrix_linear_system< double >* S; vnl_sparse_matrix< double > m_A; vnl_sparse_matrix< double > m_A_Ones; // matrix indicating active weights with 1 vnl_vector< double > m_b; double m_Lambda; // regularization factor vnl_vector row_sums; // number of active weights per row vnl_vector local_weight_means; // mean weight of each row void SetProblem(vnl_sparse_matrix< double >& A, vnl_vector& b, double lambda) { S = new vnl_sparse_matrix_linear_system(A, b); m_A = A; m_b = b; m_Lambda = lambda; m_A_Ones.set_size(m_A.rows(), m_A.cols()); m_A.reset(); while (m_A.next()) m_A_Ones.put(m_A.getrow(), m_A.getcolumn(), 1); unsigned int N = m_b.size(); vnl_vector ones; ones.set_size(dim); ones.fill(1.0); row_sums.set_size(N); m_A_Ones.mult(ones, row_sums); local_weight_means.set_size(N); } VnlCostFunction(const int NumVars) : vnl_cost_function(NumVars) { } void regu_MSE(vnl_vector const &x, double& cost) { double mean = x.mean(); vnl_vector tx = x-mean; cost += m_Lambda*1e8*tx.squared_magnitude()/x.size(); } void regu_MSM(vnl_vector const &x, double& cost) { cost += m_Lambda*1e8*x.squared_magnitude()/x.size(); } void regu_localMSE(vnl_vector const &x, double& cost) { m_A_Ones.mult(x, local_weight_means); local_weight_means = element_quotient(local_weight_means, row_sums); m_A_Ones.reset(); double regu = 0; while (m_A_Ones.next()) { double d = 0; if (x[m_A_Ones.getcolumn()]>local_weight_means[m_A_Ones.getrow()]) d = std::exp(x[m_A_Ones.getcolumn()]) - std::exp(local_weight_means[m_A_Ones.getrow()]); else d = x[m_A_Ones.getcolumn()] - local_weight_means[m_A_Ones.getrow()]; regu += d*d; } cost += m_Lambda*regu/dim; } void grad_regu_MSE(vnl_vector const &x, vnl_vector &dx) { double mean = x.mean(); vnl_vector tx = x-mean; vnl_vector tx2(dim, 0.0); vnl_vector h(dim, 1.0); for (int c=0; c const &x, vnl_vector &dx) { dx += m_Lambda*1e8*2.0*x/dim; } void grad_regu_localMSE(vnl_vector const &x, vnl_vector &dx) { m_A_Ones.mult(x, local_weight_means); local_weight_means = element_quotient(local_weight_means, row_sums); vnl_vector exp_x = x.apply(std::exp); vnl_vector exp_means = local_weight_means.apply(std::exp); vnl_vector tdx(dim, 0); m_A_Ones.reset(); while (m_A_Ones.next()) { int c = m_A_Ones.getcolumn(); int r = m_A_Ones.getrow(); if (x[c]>local_weight_means[r]) tdx[c] += exp_x[c] * ( exp_x[c] - exp_means[r] ); else tdx[c] += x[c] - local_weight_means[r]; } dx += tdx*2.0*m_Lambda/dim; } double f(vnl_vector const &x) { double cost = S->get_rms_error(x); cost *= cost; regu_localMSE(x, cost); return cost; } void gradf(vnl_vector const &x, vnl_vector &dx) { dx.fill(0.0); unsigned int N = m_b.size(); vnl_vector d; d.set_size(N); S->multiply(x,d); d -= m_b; S->transpose_multiply(d, dx); dx *= 2.0/N; grad_regu_localMSE(x,dx); } }; #ifndef ITK_MANUAL_INSTANTIATION #include "itkFitFibersToImageFilter.cpp" #endif #endif // __itkFitFibersToImageFilter_h__ diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp index 800ca9a878..b1eecb53d6 100755 --- a/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/TractographyEvaluation/TractPlausibilityFit.cpp @@ -1,263 +1,271 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using namespace std; typedef itksys::SystemTools ist; typedef itk::Point PointType4; typedef itk::Image< float, 4 > PeakImgType; std::vector< string > get_file_list(const std::string& path) { std::vector< string > file_list; itk::Directory::Pointer dir = itk::Directory::New(); if (dir->Load(path.c_str())) { int n = dir->GetNumberOfFiles(); for (int r = 0; r < n; r++) { const char *filename = dir->GetFile(r); std::string ext = ist::GetFilenameExtension(filename); if (ext==".fib" || ext==".trk") file_list.push_back(path + '/' + filename); } } return file_list; } /*! \brief Fits the tractogram to the input peak image by assigning a weight to each fiber (similar to https://doi.org/10.1016/j.neuroimage.2015.06.092). */ int main(int argc, char* argv[]) { mitkCommandLineParser parser; - parser.setTitle("Fit Fibers To Image"); - parser.setCategory("Fiber Tracking and Processing Methods"); - parser.setDescription("Assigns a weight to each fiber in order to optimally explain the input peak image"); + parser.setTitle(""); + parser.setCategory("Fiber Tracking Evaluation"); + parser.setDescription(""); parser.setContributor("MIC"); parser.setArgumentPrefix("--", "-"); parser.addArgument("", "i1", mitkCommandLineParser::InputFile, "Input tractogram:", "input tractogram (.fib, vtk ascii file format)", us::Any(), false); parser.addArgument("", "i2", mitkCommandLineParser::InputFile, "Input peaks:", "input peak image", us::Any(), false); parser.addArgument("", "i3", mitkCommandLineParser::InputFile, "", "", us::Any(), false); parser.addArgument("min_gain", "", mitkCommandLineParser::Float, "Min. gain:", "process stops if remaining bundles don't contribute enough", 0.05); parser.addArgument("", "o", mitkCommandLineParser::OutputDirectory, "Output:", "output root", us::Any(), false); parser.addArgument("max_iter", "", mitkCommandLineParser::Int, "Max. iterations:", "maximum number of optimizer iterations", 20); parser.addArgument("bundle_based", "", mitkCommandLineParser::Bool, "Bundle based fit:", "fit one weight per input tractogram/bundle, not for each fiber", false); parser.addArgument("min_g", "", mitkCommandLineParser::Float, "Min. gradient:", "lower termination threshold for gradient magnitude", 1e-5); parser.addArgument("lambda", "", mitkCommandLineParser::Float, "Lambda:", "modifier for regularization", 0.1); parser.addArgument("filter_outliers", "", mitkCommandLineParser::Bool, "Filter outliers:", "perform second optimization run with an upper weight bound based on the first weight estimation (95% quantile)", true); map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; string fib_file = us::any_cast(parsedArgs["i1"]); string peak_file_name = us::any_cast(parsedArgs["i2"]); string candidate_folder = us::any_cast(parsedArgs["i3"]); string outRoot = us::any_cast(parsedArgs["o"]); bool single_fib = true; if (parsedArgs.count("bundle_based")) single_fib = !us::any_cast(parsedArgs["bundle_based"]); int max_iter = 20; if (parsedArgs.count("max_iter")) max_iter = us::any_cast(parsedArgs["max_iter"]); float g_tol = 1e-5; if (parsedArgs.count("min_g")) g_tol = us::any_cast(parsedArgs["min_g"]); float min_gain = 0.05; if (parsedArgs.count("min_gain")) min_gain = us::any_cast(parsedArgs["min_gain"]); float lambda = 0.1; if (parsedArgs.count("lambda")) lambda = us::any_cast(parsedArgs["lambda"]); bool filter_outliers = true; if (parsedArgs.count("filter_outliers")) filter_outliers = us::any_cast(parsedArgs["filter_outliers"]); try { mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Peak Image", "Fiberbundles"}, {}); mitk::Image::Pointer inputImage = dynamic_cast(mitk::IOUtil::Load(peak_file_name, &functor)[0].GetPointer()); typedef mitk::ImageToItk< PeakImgType > CasterType; CasterType::Pointer caster = CasterType::New(); caster->SetInput(inputImage); caster->Update(); PeakImgType::Pointer peak_image = caster->GetOutput(); std::vector< mitk::FiberBundle::Pointer > input_reference; mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(fib_file)[0].GetPointer()); if (fib.IsNull()) return EXIT_FAILURE; input_reference.push_back(fib); std::vector< mitk::FiberBundle::Pointer > input_candidates; std::vector< string > candidate_tract_files = get_file_list(candidate_folder); for (string f : candidate_tract_files) { mitk::FiberBundle::Pointer fib = dynamic_cast(mitk::IOUtil::Load(f)[0].GetPointer()); if (fib.IsNull()) continue; input_candidates.push_back(fib); } int iteration = 0; std::string name = ist::GetFilenameWithoutExtension(fib_file); itk::FitFibersToImageFilter::Pointer fitter = itk::FitFibersToImageFilter::New(); fitter->SetTractograms(input_reference); fitter->SetFitIndividualFibers(single_fib); fitter->SetMaxIterations(max_iter); fitter->SetGradientTolerance(g_tol); fitter->SetLambda(lambda); fitter->SetFilterOutliers(filter_outliers); fitter->SetPeakImage(peak_image); fitter->SetVerbose(false); fitter->SetDeepCopy(false); fitter->Update(); + + + fitter->GetTractograms().at(0)->SetFiberWeights(fitter->GetCoverage()); + fitter->GetTractograms().at(0)->ColorFibersByFiberWeights(false, false); + mitk::IOUtil::Save(fitter->GetTractograms().at(0), outRoot + "0_" + name + ".fib"); peak_image = fitter->GetUnderexplainedImage(); itk::ImageFileWriter< PeakImgType >::Pointer writer = itk::ImageFileWriter< PeakImgType >::New(); writer->SetInput(peak_image); writer->SetFileName(outRoot + boost::lexical_cast(iteration) + "_" + name + ".nrrd"); writer->Update(); double coverage = fitter->GetCoverage(); MITK_INFO << "Iteration: " << iteration; MITK_INFO << std::fixed << "Coverage: " << setprecision(1) << 100.0*coverage << "%"; // fitter->SetPeakImage(peak_image); while (!input_candidates.empty()) { streambuf *old = cout.rdbuf(); // <-- save stringstream ss; std::cout.rdbuf (ss.rdbuf()); // <-- redirect double next_coverage = 0; mitk::FiberBundle::Pointer best_candidate = nullptr; for (auto fib : input_candidates) { // WHY NECESSARY AGAIN?? itk::FitFibersToImageFilter::Pointer fitter = itk::FitFibersToImageFilter::New(); fitter->SetFitIndividualFibers(single_fib); fitter->SetMaxIterations(max_iter); fitter->SetGradientTolerance(g_tol); fitter->SetLambda(lambda); fitter->SetFilterOutliers(filter_outliers); fitter->SetVerbose(false); fitter->SetPeakImage(peak_image); fitter->SetDeepCopy(false); // ****************************** fitter->SetTractograms({fib}); fitter->Update(); double candidate_coverage = fitter->GetCoverage(); if (candidate_coverage>next_coverage) { next_coverage = candidate_coverage; if ((1.0-coverage) * next_coverage >= min_gain) { best_candidate = fitter->GetTractograms().at(0); peak_image = fitter->GetUnderexplainedImage(); } } } if (best_candidate.IsNull()) { std::cout.rdbuf (old); // <-- restore break; } // fitter->SetPeakImage(peak_image); + best_candidate->SetFiberWeights((1.0-coverage) * next_coverage); + best_candidate->ColorFibersByFiberWeights(false, false); + coverage += (1.0-coverage) * next_coverage; int i=0; std::vector< mitk::FiberBundle::Pointer > remaining_candidates; std::vector< string > remaining_candidate_files; for (auto fib : input_candidates) { if (fib!=best_candidate) { remaining_candidates.push_back(fib); remaining_candidate_files.push_back(candidate_tract_files.at(i)); } else name = ist::GetFilenameWithoutExtension(candidate_tract_files.at(i)); ++i; } input_candidates = remaining_candidates; candidate_tract_files = remaining_candidate_files; iteration++; mitk::IOUtil::Save(best_candidate, outRoot + boost::lexical_cast(iteration) + "_" + name + ".fib"); writer->SetInput(peak_image); writer->SetFileName(outRoot + boost::lexical_cast(iteration) + "_" + name + ".nrrd"); writer->Update(); std::cout.rdbuf (old); // <-- restore MITK_INFO << "Iteration: " << iteration; MITK_INFO << std::fixed << "Coverage: " << setprecision(1) << 100.0*coverage << "% (+" << 100*(1.0-coverage) * next_coverage << "%)"; } MITK_INFO << "DONE"; } catch (itk::ExceptionObject e) { std::cout << e; return EXIT_FAILURE; } catch (std::exception e) { std::cout << e.what(); return EXIT_FAILURE; } catch (...) { std::cout << "ERROR!?!"; return EXIT_FAILURE; } return EXIT_SUCCESS; }