diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp index 061b06623f..a313be04f9 100644 --- a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkFitFibersToImageFilter.cpp @@ -1,955 +1,971 @@ #include "itkFitFibersToImageFilter.h" #include namespace itk{ FitFibersToImageFilter::FitFibersToImageFilter() : m_PeakImage(nullptr) , m_DiffImage(nullptr) , m_ScalarImage(nullptr) , m_MaskImage(nullptr) , m_FitIndividualFibers(true) , m_GradientTolerance(1e-5) , m_Lambda(0.1) , m_MaxIterations(20) , m_FiberSampling(10) , m_Coverage(0) , m_Overshoot(0) , m_RMSE(0.0) , m_FilterOutliers(false) , m_MeanWeight(1.0) , m_MedianWeight(1.0) , m_MinWeight(1.0) , m_MaxWeight(1.0) , m_Verbose(true) , m_DeepCopy(true) , m_ResampleFibers(true) , m_NumUnknowns(0) , m_NumResiduals(0) , m_NumCoveredDirections(0) , m_SignalModel(nullptr) , sz_x(0) , sz_y(0) , sz_z(0) , m_MeanTractDensity(0) , m_MeanSignal(0) , fiber_count(0) , m_Regularization(VnlCostFunction::REGU::VOXEL_VARIANCE) { this->SetNumberOfRequiredOutputs(3); } FitFibersToImageFilter::~FitFibersToImageFilter() { } void FitFibersToImageFilter::CreateDiffSystem() { sz_x = m_DiffImage->GetLargestPossibleRegion().GetSize(0); sz_y = m_DiffImage->GetLargestPossibleRegion().GetSize(1); sz_z = m_DiffImage->GetLargestPossibleRegion().GetSize(2); dim_four_size = m_DiffImage->GetVectorLength(); int num_voxels = sz_x*sz_y*sz_z; float minSpacing = 1; if(m_DiffImage->GetSpacing()[0]GetSpacing()[1] && m_DiffImage->GetSpacing()[0]GetSpacing()[2]) minSpacing = m_DiffImage->GetSpacing()[0]; else if (m_DiffImage->GetSpacing()[1] < m_DiffImage->GetSpacing()[2]) minSpacing = m_DiffImage->GetSpacing()[1]; else minSpacing = m_DiffImage->GetSpacing()[2]; if (m_ResampleFibers) for (unsigned int bundle=0; bundleGetDeepCopy(); m_Tractograms.at(bundle)->ResampleLinear(minSpacing/m_FiberSampling); std::cout.rdbuf (old); } m_NumResiduals = num_voxels * dim_four_size; MITK_INFO << "Num. unknowns: " << m_NumUnknowns; MITK_INFO << "Num. residuals: " << m_NumResiduals; MITK_INFO << "Creating system ..."; A.set_size(m_NumResiduals, m_NumUnknowns); b.set_size(m_NumResiduals); b.fill(0.0); m_MeanTractDensity = 0; m_MeanSignal = 0; m_NumCoveredDirections = 0; fiber_count = 0; vnl_vector voxel_indicator; voxel_indicator.set_size(sz_x*sz_y*sz_z); voxel_indicator.fill(0); m_GroupSizes.clear(); for (unsigned int bundle=0; bundle polydata = m_Tractograms.at(bundle)->GetFiberPolyData(); m_GroupSizes.push_back(m_Tractograms.at(bundle)->GetNumFibers()); for (int i=0; iGetNumFibers(); ++i) { vtkCell* cell = polydata->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); if (numPoints<2) MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; for (int j=0; jGetPoint(j); PointType3 p; p[0]=p1[0]; p[1]=p1[1]; p[2]=p1[2]; itk::Index<3> idx3; m_DiffImage->TransformPhysicalPointToIndex(p, idx3); if (!m_DiffImage->GetLargestPossibleRegion().IsInside(idx3) || (m_MaskImage.IsNotNull() && m_MaskImage->GetPixel(idx3)==0)) continue; double* p2 = points->GetPoint(j+1); mitk::DiffusionSignalModel<>::GradientType fiber_dir; fiber_dir[0] = p[0]-p2[0]; fiber_dir[1] = p[1]-p2[1]; fiber_dir[2] = p[2]-p2[2]; fiber_dir.Normalize(); int x = idx3[0]; int y = idx3[1]; int z = idx3[2]; mitk::DiffusionSignalModel<>::PixelType simulated_pixel = m_SignalModel->SimulateMeasurement(fiber_dir); VectorImgType::PixelType measured_pixel = m_DiffImage->GetPixel(idx3); double simulated_mean = 0; double measured_mean = 0; int num_nonzero_g = 0; for (int g=0; gGetGradientDirection(g).GetNorm()GetLargestPossibleRegion().GetSize(0); sz_y = m_PeakImage->GetLargestPossibleRegion().GetSize(1); sz_z = m_PeakImage->GetLargestPossibleRegion().GetSize(2); dim_four_size = m_PeakImage->GetLargestPossibleRegion().GetSize(3)/3 + 1; // +1 for zero - peak int num_voxels = sz_x*sz_y*sz_z; float minSpacing = 1; if(m_PeakImage->GetSpacing()[0]GetSpacing()[1] && m_PeakImage->GetSpacing()[0]GetSpacing()[2]) minSpacing = m_PeakImage->GetSpacing()[0]; else if (m_PeakImage->GetSpacing()[1] < m_PeakImage->GetSpacing()[2]) minSpacing = m_PeakImage->GetSpacing()[1]; else minSpacing = m_PeakImage->GetSpacing()[2]; if (m_ResampleFibers) for (unsigned int bundle=0; bundleGetDeepCopy(); m_Tractograms.at(bundle)->ResampleLinear(minSpacing/m_FiberSampling); std::cout.rdbuf (old); } m_NumResiduals = num_voxels * dim_four_size; MITK_INFO << "Num. unknowns: " << m_NumUnknowns; MITK_INFO << "Num. residuals: " << m_NumResiduals; MITK_INFO << "Creating system ..."; A.set_size(m_NumResiduals, m_NumUnknowns); b.set_size(m_NumResiduals); b.fill(0.0); m_MeanTractDensity = 0; m_MeanSignal = 0; m_NumCoveredDirections = 0; fiber_count = 0; m_GroupSizes.clear(); for (unsigned int bundle=0; bundle polydata = m_Tractograms.at(bundle)->GetFiberPolyData(); m_GroupSizes.push_back(m_Tractograms.at(bundle)->GetNumFibers()); for (int i=0; iGetNumFibers(); ++i) { vtkCell* cell = polydata->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); if (numPoints<2) MITK_INFO << "FIBER WITH ONLY ONE POINT ENCOUNTERED!"; for (int j=0; jGetPoint(j); PointType4 p; p[0]=p1[0]; p[1]=p1[1]; p[2]=p1[2]; p[3]=0; itk::Index<4> idx4; m_PeakImage->TransformPhysicalPointToIndex(p, idx4); itk::Index<3> idx3; idx3[0] = idx4[0]; idx3[1] = idx4[1]; idx3[2] = idx4[2]; if (!m_PeakImage->GetLargestPossibleRegion().IsInside(idx4) || (m_MaskImage.IsNotNull() && m_MaskImage->GetPixel(idx3)==0)) continue; double* p2 = points->GetPoint(j+1); vnl_vector_fixed fiber_dir; fiber_dir[0] = p[0]-p2[0]; fiber_dir[1] = p[1]-p2[1]; fiber_dir[2] = p[2]-p2[2]; fiber_dir.normalize(); double w = 1; int peak_id = dim_four_size-1; double peak_mag = 0; GetClosestPeak(idx4, m_PeakImage, fiber_dir, peak_id, w, peak_mag); int x = idx4[0]; int y = idx4[1]; int z = idx4[2]; unsigned int linear_index = x + sz_x*y + sz_x*sz_y*z + sz_x*sz_y*sz_z*peak_id; if (b[linear_index] == 0 && peak_idGetLargestPossibleRegion().GetSize(0); sz_y = m_ScalarImage->GetLargestPossibleRegion().GetSize(1); sz_z = m_ScalarImage->GetLargestPossibleRegion().GetSize(2); int num_voxels = sz_x*sz_y*sz_z; float minSpacing = 1; if(m_ScalarImage->GetSpacing()[0]GetSpacing()[1] && m_ScalarImage->GetSpacing()[0]GetSpacing()[2]) minSpacing = m_ScalarImage->GetSpacing()[0]; else if (m_ScalarImage->GetSpacing()[1] < m_ScalarImage->GetSpacing()[2]) minSpacing = m_ScalarImage->GetSpacing()[1]; else minSpacing = m_ScalarImage->GetSpacing()[2]; if (m_ResampleFibers) for (unsigned int bundle=0; bundleGetDeepCopy(); m_Tractograms.at(bundle)->ResampleLinear(minSpacing/m_FiberSampling); std::cout.rdbuf (old); } m_NumResiduals = num_voxels; MITK_INFO << "Num. unknowns: " << m_NumUnknowns; MITK_INFO << "Num. residuals: " << m_NumResiduals; MITK_INFO << "Creating system ..."; A.set_size(m_NumResiduals, m_NumUnknowns); b.set_size(m_NumResiduals); b.fill(0.0); m_MeanTractDensity = 0; m_MeanSignal = 0; int numCoveredVoxels = 0; fiber_count = 0; m_GroupSizes.clear(); for (unsigned int bundle=0; bundle polydata = m_Tractograms.at(bundle)->GetFiberPolyData(); m_GroupSizes.push_back(m_Tractograms.at(bundle)->GetNumFibers()); for (int i=0; iGetNumFibers(); ++i) { vtkCell* cell = polydata->GetCell(i); int numPoints = cell->GetNumberOfPoints(); vtkPoints* points = cell->GetPoints(); for (int j=0; jGetPoint(j); PointType3 p; p[0]=p1[0]; p[1]=p1[1]; p[2]=p1[2]; itk::Index<3> idx3; m_ScalarImage->TransformPhysicalPointToIndex(p, idx3); if (!m_ScalarImage->GetLargestPossibleRegion().IsInside(idx3) || (m_MaskImage.IsNotNull() && m_MaskImage->GetPixel(idx3)==0)) continue; float image_value = m_ScalarImage->GetPixel(idx3); int x = idx3[0]; int y = idx3[1]; int z = idx3[2]; unsigned int linear_index = x + sz_x*y + sz_x*sz_y*z; if (b[linear_index] == 0) { numCoveredVoxels++; m_MeanSignal += image_value; } m_MeanTractDensity += 1; if (m_FitIndividualFibers) { b[linear_index] = image_value; A.put(linear_index, fiber_count, A.get(linear_index, fiber_count) + 1); } else { b[linear_index] = image_value; A.put(linear_index, bundle, A.get(linear_index, bundle) + 1); } } ++fiber_count; } } m_MeanTractDensity /= (numCoveredVoxels*fiber_count); m_MeanSignal /= numCoveredVoxels; A /= m_MeanTractDensity; b *= 100.0/m_MeanSignal; // times 100 because we want to avoid too small values for computational reasons // NEW FIT // m_MeanTractDensity /= numCoveredVoxels; // m_MeanSignal /= numCoveredVoxels; // b /= m_MeanSignal; // b *= m_MeanTractDensity; } void FitFibersToImageFilter::GenerateData() { m_NumUnknowns = m_Tractograms.size(); if (m_FitIndividualFibers) { m_NumUnknowns = 0; for (unsigned int bundle=0; bundleGetNumFibers(); } else m_FilterOutliers = false; if (m_NumUnknowns<1) { MITK_INFO << "No fibers in tractogram."; return; } fiber_count = 0; sz_x = 0; sz_y = 0; sz_z = 0; m_MeanTractDensity = 0; m_MeanSignal = 0; if (m_PeakImage.IsNotNull()) CreatePeakSystem(); else if (m_DiffImage.IsNotNull()) CreateDiffSystem(); else if (m_ScalarImage.IsNotNull()) CreateScalarSystem(); else mitkThrow() << "No input image set!"; double init_lambda = fiber_count; // initialization for lambda estimation itk::TimeProbe clock; clock.Start(); cost = VnlCostFunction(m_NumUnknowns); cost.SetProblem(A, b, init_lambda, m_Regularization); cost.SetGroupSizes(m_GroupSizes); m_Weights.set_size(m_NumUnknowns); m_Weights.fill( 1.0/m_NumUnknowns ); vnl_lbfgsb minimizer(cost); vnl_vector l; l.set_size(m_NumUnknowns); l.fill(0); vnl_vector bound_selection; bound_selection.set_size(m_NumUnknowns); bound_selection.fill(1); minimizer.set_bound_selection(bound_selection); minimizer.set_lower_bound(l); minimizer.set_projected_gradient_tolerance(m_GradientTolerance); - MITK_INFO << "Regularization type: " << m_Regularization; + if (m_Regularization==VnlCostFunction::REGU::MSM) + MITK_INFO << "Regularization type: MSM"; + else if (m_Regularization==VnlCostFunction::REGU::VARIANCE) + MITK_INFO << "Regularization type: VARIANCE"; + else if (m_Regularization==VnlCostFunction::REGU::LASSO) + MITK_INFO << "Regularization type: LASSO"; + else if (m_Regularization==VnlCostFunction::REGU::VOXEL_VARIANCE) + MITK_INFO << "Regularization type: VOXEL_VARIANCE"; + else if (m_Regularization==VnlCostFunction::REGU::GROUP_LASSO) + MITK_INFO << "Regularization type: GROUP_LASSO"; + else if (m_Regularization==VnlCostFunction::REGU::GROUP_VARIANCE) + MITK_INFO << "Regularization type: GROUP_VARIANCE"; + else if (m_Regularization==VnlCostFunction::REGU::NONE) + MITK_INFO << "Regularization type: NONE"; + if (m_Regularization!=VnlCostFunction::REGU::NONE) // REMOVE FOR NEW FIT AND SET cost.m_Lambda = m_Lambda { MITK_INFO << "Estimating regularization"; minimizer.set_trace(false); minimizer.set_max_function_evals(2); minimizer.minimize(m_Weights); vnl_vector dx; dx.set_size(m_NumUnknowns); dx.fill(0.0); cost.calc_regularization_gradient(m_Weights, dx); if (m_Weights.magnitude()==0) { MITK_INFO << "Regularization estimation failed. Using default value."; cost.m_Lambda = fiber_count*m_Lambda; } else { double r = dx.magnitude()/m_Weights.magnitude(); // wtf??? cost.m_Lambda *= m_Lambda*55.0/r; MITK_INFO << r << " - " << m_Lambda*55.0/r; if (cost.m_Lambda>10e7) { MITK_INFO << "Regularization estimation failed. Using default value."; cost.m_Lambda = fiber_count*m_Lambda; } } } + else + cost.m_Lambda = 0; MITK_INFO << "Using regularization factor of " << cost.m_Lambda << " (λ: " << m_Lambda << ")"; MITK_INFO << "Fitting fibers"; minimizer.set_trace(m_Verbose); minimizer.set_max_function_evals(m_MaxIterations); minimizer.minimize(m_Weights); std::vector< double > weights; if (m_FilterOutliers) { for (auto w : m_Weights) weights.push_back(w); std::sort(weights.begin(), weights.end()); MITK_INFO << "Setting upper weight bound to " << weights.at(m_NumUnknowns*0.99); vnl_vector u; u.set_size(m_NumUnknowns); u.fill(weights.at(m_NumUnknowns*0.99)); minimizer.set_upper_bound(u); bound_selection.fill(2); minimizer.set_bound_selection(bound_selection); minimizer.minimize(m_Weights); weights.clear(); } for (auto w : m_Weights) weights.push_back(w); std::sort(weights.begin(), weights.end()); m_MeanWeight = m_Weights.mean(); m_MedianWeight = weights.at(m_NumUnknowns*0.5); m_MinWeight = weights.at(0); m_MaxWeight = weights.at(m_NumUnknowns-1); MITK_INFO << "*************************"; MITK_INFO << "Weight statistics"; MITK_INFO << "Sum: " << m_Weights.sum(); MITK_INFO << "Mean: " << m_MeanWeight; MITK_INFO << "1% quantile: " << weights.at(m_NumUnknowns*0.01); MITK_INFO << "5% quantile: " << weights.at(m_NumUnknowns*0.05); MITK_INFO << "25% quantile: " << weights.at(m_NumUnknowns*0.25); MITK_INFO << "Median: " << m_MedianWeight; MITK_INFO << "75% quantile: " << weights.at(m_NumUnknowns*0.75); MITK_INFO << "95% quantile: " << weights.at(m_NumUnknowns*0.95); MITK_INFO << "99% quantile: " << weights.at(m_NumUnknowns*0.99); MITK_INFO << "Min: " << m_MinWeight; MITK_INFO << "Max: " << m_MaxWeight; MITK_INFO << "*************************"; MITK_INFO << "NumEvals: " << minimizer.get_num_evaluations(); MITK_INFO << "NumIterations: " << minimizer.get_num_iterations(); MITK_INFO << "Residual cost: " << minimizer.get_end_error(); m_RMSE = cost.S->get_rms_error(m_Weights); MITK_INFO << "Final RMS: " << m_RMSE; clock.Stop(); int h = clock.GetTotal()/3600; int m = ((int)clock.GetTotal()%3600)/60; int s = (int)clock.GetTotal()%60; MITK_INFO << "Optimization took " << h << "h, " << m << "m and " << s << "s"; MITK_INFO << "Weighting fibers"; m_RmsDiffPerBundle.set_size(m_Tractograms.size()); std::streambuf *old = cout.rdbuf(); // <-- save std::stringstream ss; std::cout.rdbuf (ss.rdbuf()); if (m_FitIndividualFibers) { unsigned int fiber_count = 0; for (unsigned int bundle=0; bundle temp_weights; temp_weights.set_size(m_Weights.size()); temp_weights.copy_in(m_Weights.data_block()); for (int i=0; iGetNumFibers(); i++) { m_Tractograms.at(bundle)->SetFiberWeight(i, m_Weights[fiber_count]); temp_weights[fiber_count] = 0; ++fiber_count; } double d_rms = cost.S->get_rms_error(temp_weights) - m_RMSE; m_RmsDiffPerBundle[bundle] = d_rms; m_Tractograms.at(bundle)->Compress(0.1); m_Tractograms.at(bundle)->ColorFibersByFiberWeights(false, true); } } else { for (unsigned int i=0; i temp_weights; temp_weights.set_size(m_Weights.size()); temp_weights.copy_in(m_Weights.data_block()); temp_weights[i] = 0; double d_rms = cost.S->get_rms_error(temp_weights) - m_RMSE; m_RmsDiffPerBundle[i] = d_rms; m_Tractograms.at(i)->SetFiberWeights(m_Weights[i]); m_Tractograms.at(i)->Compress(0.1); m_Tractograms.at(i)->ColorFibersByFiberWeights(false, true); } } std::cout.rdbuf (old); // transform back A *= m_MeanSignal/100.0; b *= m_MeanSignal/100.0; MITK_INFO << "Generating output images ..."; if (m_PeakImage.IsNotNull()) GenerateOutputPeakImages(); else if (m_DiffImage.IsNotNull()) GenerateOutputDiffImages(); else if (m_ScalarImage.IsNotNull()) GenerateOutputScalarImages(); m_Coverage = m_Coverage/m_MeanSignal; m_Overshoot = m_Overshoot/m_MeanSignal; MITK_INFO << std::fixed << "Coverage: " << setprecision(2) << 100.0*m_Coverage << "%"; MITK_INFO << std::fixed << "Overshoot: " << setprecision(2) << 100.0*m_Overshoot << "%"; } void FitFibersToImageFilter::GenerateOutputDiffImages() { VectorImgType::PixelType pix; pix.SetSize(m_DiffImage->GetVectorLength()); pix.Fill(0); itk::ImageDuplicator< VectorImgType >::Pointer duplicator = itk::ImageDuplicator< VectorImgType >::New(); duplicator->SetInputImage(m_DiffImage); duplicator->Update(); m_UnderexplainedImageDiff = duplicator->GetOutput(); m_UnderexplainedImageDiff->FillBuffer(pix); duplicator->SetInputImage(m_UnderexplainedImageDiff); duplicator->Update(); m_OverexplainedImageDiff = duplicator->GetOutput(); m_OverexplainedImageDiff->FillBuffer(pix); duplicator->SetInputImage(m_OverexplainedImageDiff); duplicator->Update(); m_ResidualImageDiff = duplicator->GetOutput(); m_ResidualImageDiff->FillBuffer(pix); duplicator->SetInputImage(m_ResidualImageDiff); duplicator->Update(); m_FittedImageDiff = duplicator->GetOutput(); m_FittedImageDiff->FillBuffer(pix); vnl_vector fitted_b; fitted_b.set_size(b.size()); cost.S->multiply(m_Weights, fitted_b); itk::ImageRegionIterator it1 = itk::ImageRegionIterator(m_DiffImage, m_DiffImage->GetLargestPossibleRegion()); itk::ImageRegionIterator it2 = itk::ImageRegionIterator(m_FittedImageDiff, m_FittedImageDiff->GetLargestPossibleRegion()); itk::ImageRegionIterator it3 = itk::ImageRegionIterator(m_ResidualImageDiff, m_ResidualImageDiff->GetLargestPossibleRegion()); itk::ImageRegionIterator it4 = itk::ImageRegionIterator(m_UnderexplainedImageDiff, m_UnderexplainedImageDiff->GetLargestPossibleRegion()); itk::ImageRegionIterator it5 = itk::ImageRegionIterator(m_OverexplainedImageDiff, m_OverexplainedImageDiff->GetLargestPossibleRegion()); m_MeanSignal = 0; m_Coverage = 0; m_Overshoot = 0; while( !it2.IsAtEnd() ) { itk::Index<3> idx3 = it2.GetIndex(); VectorImgType::PixelType original_pix =it1.Get(); VectorImgType::PixelType fitted_pix =it2.Get(); VectorImgType::PixelType residual_pix =it3.Get(); VectorImgType::PixelType underexplained_pix =it4.Get(); VectorImgType::PixelType overexplained_pix =it5.Get(); int num_nonzero_g = 0; double original_mean = 0; for (int g=0; gGetGradientDirection(g).GetNorm()>=mitk::eps ) { original_mean += original_pix[g]; ++num_nonzero_g; } } original_mean /= num_nonzero_g; for (int g=0; g=0) { underexplained_pix[g] = residual_pix[g]; m_Coverage += fitted_b[linear_index] + original_mean; } m_MeanSignal += b[linear_index] + original_mean; } it2.Set(fitted_pix); it3.Set(residual_pix); it4.Set(underexplained_pix); it5.Set(overexplained_pix); ++it1; ++it2; ++it3; ++it4; ++it5; } } void FitFibersToImageFilter::GenerateOutputScalarImages() { itk::ImageDuplicator< DoubleImgType >::Pointer duplicator = itk::ImageDuplicator< DoubleImgType >::New(); duplicator->SetInputImage(m_ScalarImage); duplicator->Update(); m_UnderexplainedImageScalar = duplicator->GetOutput(); m_UnderexplainedImageScalar->FillBuffer(0); duplicator->SetInputImage(m_UnderexplainedImageScalar); duplicator->Update(); m_OverexplainedImageScalar = duplicator->GetOutput(); m_OverexplainedImageScalar->FillBuffer(0); duplicator->SetInputImage(m_OverexplainedImageScalar); duplicator->Update(); m_ResidualImageScalar = duplicator->GetOutput(); m_ResidualImageScalar->FillBuffer(0); duplicator->SetInputImage(m_ResidualImageScalar); duplicator->Update(); m_FittedImageScalar = duplicator->GetOutput(); m_FittedImageScalar->FillBuffer(0); vnl_vector fitted_b; fitted_b.set_size(b.size()); cost.S->multiply(m_Weights, fitted_b); itk::ImageRegionIterator it1 = itk::ImageRegionIterator(m_ScalarImage, m_ScalarImage->GetLargestPossibleRegion()); itk::ImageRegionIterator it2 = itk::ImageRegionIterator(m_FittedImageScalar, m_FittedImageScalar->GetLargestPossibleRegion()); itk::ImageRegionIterator it3 = itk::ImageRegionIterator(m_ResidualImageScalar, m_ResidualImageScalar->GetLargestPossibleRegion()); itk::ImageRegionIterator it4 = itk::ImageRegionIterator(m_UnderexplainedImageScalar, m_UnderexplainedImageScalar->GetLargestPossibleRegion()); itk::ImageRegionIterator it5 = itk::ImageRegionIterator(m_OverexplainedImageScalar, m_OverexplainedImageScalar->GetLargestPossibleRegion()); m_MeanSignal = 0; m_Coverage = 0; m_Overshoot = 0; while( !it2.IsAtEnd() ) { itk::Index<3> idx3 = it2.GetIndex(); DoubleImgType::PixelType original_pix =it1.Get(); DoubleImgType::PixelType fitted_pix =it2.Get(); DoubleImgType::PixelType residual_pix =it3.Get(); DoubleImgType::PixelType underexplained_pix =it4.Get(); DoubleImgType::PixelType overexplained_pix =it5.Get(); unsigned int linear_index = idx3[0] + sz_x*idx3[1] + sz_x*sz_y*idx3[2]; fitted_pix = fitted_b[linear_index]; residual_pix = original_pix - fitted_pix; if (residual_pix<0) { overexplained_pix = residual_pix; m_Coverage += b[linear_index]; m_Overshoot -= residual_pix; } else if (residual_pix>=0) { underexplained_pix = residual_pix; m_Coverage += fitted_b[linear_index]; } m_MeanSignal += b[linear_index]; it2.Set(fitted_pix); it3.Set(residual_pix); it4.Set(underexplained_pix); it5.Set(overexplained_pix); ++it1; ++it2; ++it3; ++it4; ++it5; } } VnlCostFunction::REGU FitFibersToImageFilter::GetRegularization() const { return m_Regularization; } void FitFibersToImageFilter::SetRegularization(const VnlCostFunction::REGU &Regularization) { m_Regularization = Regularization; } void FitFibersToImageFilter::GenerateOutputPeakImages() { itk::ImageDuplicator< PeakImgType >::Pointer duplicator = itk::ImageDuplicator< PeakImgType >::New(); duplicator->SetInputImage(m_PeakImage); duplicator->Update(); m_UnderexplainedImage = duplicator->GetOutput(); m_UnderexplainedImage->FillBuffer(0.0); duplicator->SetInputImage(m_UnderexplainedImage); duplicator->Update(); m_OverexplainedImage = duplicator->GetOutput(); m_OverexplainedImage->FillBuffer(0.0); duplicator->SetInputImage(m_OverexplainedImage); duplicator->Update(); m_ResidualImage = duplicator->GetOutput(); m_ResidualImage->FillBuffer(0.0); duplicator->SetInputImage(m_ResidualImage); duplicator->Update(); m_FittedImage = duplicator->GetOutput(); m_FittedImage->FillBuffer(0.0); vnl_vector fitted_b; fitted_b.set_size(b.size()); cost.S->multiply(m_Weights, fitted_b); for (unsigned int r=0; r idx4; unsigned int linear_index = r; idx4[0] = linear_index % sz_x; linear_index /= sz_x; idx4[1] = linear_index % sz_y; linear_index /= sz_y; idx4[2] = linear_index % sz_z; linear_index /= sz_z; int peak_id = linear_index % dim_four_size; if (peak_id peak_dir; idx4[3] = peak_id*3; peak_dir[0] = m_PeakImage->GetPixel(idx4); idx4[3] += 1; peak_dir[1] = m_PeakImage->GetPixel(idx4); idx4[3] += 1; peak_dir[2] = m_PeakImage->GetPixel(idx4); peak_dir.normalize(); peak_dir *= fitted_b[r]; idx4[3] = peak_id*3; m_FittedImage->SetPixel(idx4, peak_dir[0]); idx4[3] += 1; m_FittedImage->SetPixel(idx4, peak_dir[1]); idx4[3] += 1; m_FittedImage->SetPixel(idx4, peak_dir[2]); } } m_MeanSignal = 0; m_Coverage = 0; m_Overshoot = 0; itk::Index<4> idx4; for (idx4[0]=0; idx4[0] idx3; idx3[0] = idx4[0]; idx3[1] = idx4[1]; idx3[2] = idx4[2]; if (m_MaskImage.IsNotNull() && m_MaskImage->GetPixel(idx3)==0) continue; vnl_vector_fixed peak_dir; vnl_vector_fixed fitted_dir; vnl_vector_fixed overshoot_dir; for (idx4[3]=0; idx4[3]<(itk::IndexValueType)m_PeakImage->GetLargestPossibleRegion().GetSize(3); ++idx4[3]) { peak_dir[idx4[3]%3] = m_PeakImage->GetPixel(idx4); fitted_dir[idx4[3]%3] = m_FittedImage->GetPixel(idx4); m_ResidualImage->SetPixel(idx4, m_PeakImage->GetPixel(idx4) - m_FittedImage->GetPixel(idx4)); if (idx4[3]%3==2) { m_MeanSignal += peak_dir.magnitude(); itk::Index<4> tidx= idx4; if (peak_dir.magnitude()>fitted_dir.magnitude()) { m_Coverage += fitted_dir.magnitude(); m_UnderexplainedImage->SetPixel(tidx, peak_dir[2]-fitted_dir[2]); tidx[3]--; m_UnderexplainedImage->SetPixel(tidx, peak_dir[1]-fitted_dir[1]); tidx[3]--; m_UnderexplainedImage->SetPixel(tidx, peak_dir[0]-fitted_dir[0]); } else { overshoot_dir[0] = fitted_dir[0]-peak_dir[0]; overshoot_dir[1] = fitted_dir[1]-peak_dir[1]; overshoot_dir[2] = fitted_dir[2]-peak_dir[2]; m_Coverage += peak_dir.magnitude(); m_Overshoot += overshoot_dir.magnitude(); m_OverexplainedImage->SetPixel(tidx, overshoot_dir[2]); tidx[3]--; m_OverexplainedImage->SetPixel(tidx, overshoot_dir[1]); tidx[3]--; m_OverexplainedImage->SetPixel(tidx, overshoot_dir[0]); } } } } } void FitFibersToImageFilter::GetClosestPeak(itk::Index<4> idx, PeakImgType::Pointer peak_image , vnl_vector_fixed fiber_dir, int& id, double& w, double& peak_mag ) { int m_NumDirs = peak_image->GetLargestPossibleRegion().GetSize()[3]/3; vnl_vector_fixed out_dir; out_dir.fill(0); float angle = 0.9; for (int i=0; i dir; idx[3] = i*3; dir[0] = peak_image->GetPixel(idx); idx[3] += 1; dir[1] = peak_image->GetPixel(idx); idx[3] += 1; dir[2] = peak_image->GetPixel(idx); float mag = dir.magnitude(); if (magangle) { angle = a; w = angle; peak_mag = mag; id = i; } } } std::vector FitFibersToImageFilter::GetTractograms() const { return m_Tractograms; } void FitFibersToImageFilter::SetTractograms(const std::vector &tractograms) { m_Tractograms = tractograms; } void FitFibersToImageFilter::SetSignalModel(mitk::DiffusionSignalModel<> *SignalModel) { m_SignalModel = SignalModel; } } diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkTdiToVolumeFractionFilter.cpp b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkTdiToVolumeFractionFilter.cpp new file mode 100644 index 0000000000..8c179673b2 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkTdiToVolumeFractionFilter.cpp @@ -0,0 +1,156 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#ifndef __itkTdiToVolumeFractionFilter_txx +#define __itkTdiToVolumeFractionFilter_txx + +#include "itkTdiToVolumeFractionFilter.h" +#include +#include +#include +#include + +namespace itk { + +template< class TPixelType > +TdiToVolumeFractionFilter< TPixelType >::TdiToVolumeFractionFilter() + : m_TdiThreshold(0.3) + , m_Sqrt(1.0) +{ + this->SetNumberOfRequiredInputs(5); + this->SetNumberOfRequiredOutputs(4); +} + +template< class TPixelType > +void TdiToVolumeFractionFilter< TPixelType >::BeforeThreadedGenerateData() +{ + typename InputImageType::Pointer input_tdi = static_cast< InputImageType * >( this->ProcessObject::GetInput(0) ); + + for (int i=0; i<4; ++i) + { + typename OutputImageType::Pointer outputImage = OutputImageType::New(); + outputImage->SetOrigin( input_tdi->GetOrigin() ); + outputImage->SetRegions( input_tdi->GetLargestPossibleRegion() ); + outputImage->SetSpacing( input_tdi->GetSpacing() ); + outputImage->SetDirection( input_tdi->GetDirection() ); + outputImage->Allocate(); + outputImage->FillBuffer(0.0); + this->SetNthOutput(i, outputImage); + } +} + +template< class TPixelType > +void TdiToVolumeFractionFilter< TPixelType >::ThreadedGenerateData(const OutputImageRegionType& outputRegionForThread, ThreadIdType) +{ + typename OutputImageType::Pointer o_intra_ax = static_cast< OutputImageType * >(this->ProcessObject::GetOutput(0)); + typename OutputImageType::Pointer o_inter_ax = static_cast< OutputImageType * >(this->ProcessObject::GetOutput(1)); + typename OutputImageType::Pointer o_gray_matter = static_cast< OutputImageType * >(this->ProcessObject::GetOutput(2)); + typename OutputImageType::Pointer o_csf = static_cast< OutputImageType * >(this->ProcessObject::GetOutput(3)); + + ImageRegionIterator< OutputImageType > oit0(o_intra_ax, outputRegionForThread); + ImageRegionIterator< OutputImageType > oit1(o_inter_ax, outputRegionForThread); + ImageRegionIterator< OutputImageType > oit2(o_gray_matter, outputRegionForThread); + ImageRegionIterator< OutputImageType > oit3(o_csf, outputRegionForThread); + + typename InputImageType::Pointer i_wm = static_cast< InputImageType * >(this->ProcessObject::GetInput(1)); + typename InputImageType::Pointer i_gm = static_cast< InputImageType * >(this->ProcessObject::GetInput(2)); + typename InputImageType::Pointer i_dgm = static_cast< InputImageType * >(this->ProcessObject::GetInput(3)); + typename InputImageType::Pointer i_csf = static_cast< InputImageType * >(this->ProcessObject::GetInput(4)); + + ImageRegionIterator< OutputImageType > iit0(i_wm, outputRegionForThread); + ImageRegionIterator< OutputImageType > iit1(i_gm, outputRegionForThread); + ImageRegionIterator< OutputImageType > iit2(i_dgm, outputRegionForThread); + ImageRegionIterator< OutputImageType > iit3(i_csf, outputRegionForThread); + + typename InputImageType::Pointer input_tdi = static_cast< InputImageType * >( this->ProcessObject::GetInput(0) ); + ImageRegionConstIterator< InputImageType > tdi_it(input_tdi, outputRegionForThread); + + while( !tdi_it.IsAtEnd() ) + { + TPixelType intra_ax_val = tdi_it.Get(); + intra_ax_val = std::pow(intra_ax_val, 1.0/m_Sqrt); + if (intra_ax_val>m_TdiThreshold) + intra_ax_val = m_TdiThreshold; + intra_ax_val /= m_TdiThreshold; + if (intra_ax_val<0) + intra_ax_val = 0; + + TPixelType inter_ax_val = 0.0; + TPixelType wm_val = iit0.Get(); + TPixelType gm_val = iit1.Get(); + TPixelType dgm_val = iit2.Get(); + TPixelType csf_val = iit3.Get(); + + if (dgm_val>0) + { + TPixelType ex = 1.0 - intra_ax_val; + gm_val += dgm_val * ex; + } + + if (intra_ax_val < 0.0001 && wm_val > 0.0001) + { + if (csf_val+gm_val == 0 ) + gm_val += wm_val; + else + { + if (csf_val>gm_val) + csf_val += wm_val; + else + gm_val += wm_val; + } + } + else if ( intra_ax_val>0.0001 ) + { + auto comp_sum = csf_val+gm_val+intra_ax_val; + if (comp_sum<1.0) + inter_ax_val = 1.0 - comp_sum; + else if (comp_sum>1) + intra_ax_val = 1.0 - (csf_val+gm_val); + if (intra_ax_val<-0.01) + MITK_INFO << "Corrupted volume fraction. intra_ax_val=" << intra_ax_val; + if (intra_ax_val<0) + intra_ax_val = 0; + } + auto comp_sum = csf_val+gm_val+intra_ax_val+inter_ax_val; + if (std::fabs(comp_sum-1.0)>0.00001 && comp_sum > 0.0001) + { + csf_val /= comp_sum; + gm_val /= comp_sum; + intra_ax_val /= comp_sum; + inter_ax_val /= comp_sum; + } + + oit0.Set(intra_ax_val); + oit1.Set(inter_ax_val); + oit2.Set(gm_val); + oit3.Set(csf_val); + + ++tdi_it; + + ++iit0; + ++iit1; + ++iit2; + ++iit3; + + ++oit0; + ++oit1; + ++oit2; + ++oit3; + } +} + +} +#endif diff --git a/Modules/DiffusionImaging/FiberTracking/Algorithms/itkTdiToVolumeFractionFilter.h b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkTdiToVolumeFractionFilter.h new file mode 100644 index 0000000000..b2c1caf944 --- /dev/null +++ b/Modules/DiffusionImaging/FiberTracking/Algorithms/itkTdiToVolumeFractionFilter.h @@ -0,0 +1,78 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +/*=================================================================== + +This file is based heavily on a corresponding ITK filter. + +===================================================================*/ +#ifndef __itkTdiToVolumeFractionFilter_h_ +#define __itkTdiToVolumeFractionFilter_h_ + +#include +#include + +namespace itk{ + +/** +* \brief */ + +template< class TPixelType > +class TdiToVolumeFractionFilter : public ImageToImageFilter< Image< TPixelType, 3 >, Image< TPixelType, 3 > > +{ + +public: + + typedef TdiToVolumeFractionFilter Self; + typedef SmartPointer Pointer; + typedef SmartPointer ConstPointer; + typedef ImageToImageFilter< Image< TPixelType, 3 >, Image< TPixelType, 3 > > Superclass; + + /** Method for creation through the object factory. */ + itkFactorylessNewMacro(Self) + itkCloneMacro(Self) + + /** Runtime information support. */ + itkTypeMacro(TdiToVolumeFractionFilter, ImageToImageFilter) + + typedef typename Superclass::InputImageType InputImageType; + typedef typename Superclass::OutputImageType OutputImageType; + typedef typename Superclass::OutputImageRegionType OutputImageRegionType; + + itkSetMacro( TdiThreshold, TPixelType ) + itkSetMacro( Sqrt, TPixelType ) + +protected: + TdiToVolumeFractionFilter(); + ~TdiToVolumeFractionFilter() override {} + + void BeforeThreadedGenerateData() override; + void ThreadedGenerateData( const OutputImageRegionType &outputRegionForThread, ThreadIdType threadId) override; + +private: + + TPixelType m_TdiThreshold; + TPixelType m_Sqrt; +}; + +} + +#ifndef ITK_MANUAL_INSTANTIATION +#include "itkTdiToVolumeFractionFilter.cpp" +#endif + +#endif //__itkTdiToVolumeFractionFilter_h_ + diff --git a/Modules/DiffusionImaging/FiberTracking/cmdapps/Fiberfox/FiberfoxOptimization.cpp b/Modules/DiffusionImaging/FiberTracking/cmdapps/Fiberfox/FiberfoxOptimization.cpp index f5d4e8cb82..52d2f8be9b 100755 --- a/Modules/DiffusionImaging/FiberTracking/cmdapps/Fiberfox/FiberfoxOptimization.cpp +++ b/Modules/DiffusionImaging/FiberTracking/cmdapps/Fiberfox/FiberfoxOptimization.cpp @@ -1,616 +1,831 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include #include #include #include #include #include #include +#include +#include using namespace mitk; -double CalcErrorSignal(itk::VectorImage< short, 3 >* reference, itk::VectorImage< short, 3 >* simulation, itk::Image< unsigned char,3 >::Pointer mask) +double CalcErrorSignal(const std::vector& histo_mod, itk::VectorImage< short, 3 >* reference, itk::VectorImage< short, 3 >* simulation, itk::Image< unsigned char,3 >::Pointer mask, itk::Image< double,3 >::Pointer fa) { + typedef itk::Image< double, 3 > DoubleImageType; typedef itk::VectorImage< short, 3 > DwiImageType; - try + if (fa.IsNotNull()) + { + itk::ImageRegionIterator< DwiImageType > it1(reference, reference->GetLargestPossibleRegion()); + itk::ImageRegionIterator< DwiImageType > it2(simulation, simulation->GetLargestPossibleRegion()); + itk::ImageRegionConstIterator< DoubleImageType > it3(fa, fa->GetLargestPossibleRegion()); + + unsigned int count = 0; + double error = 0; + while(!it1.IsAtEnd()) + { + if (mask.IsNull() || (mask.IsNotNull() && mask->GetLargestPossibleRegion().IsInside(it1.GetIndex()) && mask->GetPixel(it1.GetIndex())>0) ) + { + double fa = it3.Get(); + if (fa>0) + { + double mod = 1.0; + for (int i=histo_mod.size()-1; i>=0; --i) + if (fa >= (double)i/histo_mod.size()) + { + mod = histo_mod.at(i); + break; + } + + for (unsigned int i=0; iGetVectorLength(); ++i) + { + if (it1.Get()[i]>0) + { + double diff = (double)it2.Get()[i]/it1.Get()[i] - 1.0; + error += std::pow(mod, 4) * fabs(diff); + count++; + } + } + } + } + ++it1; + ++it2; + ++it3; + } + return error/count; + } + else { itk::ImageRegionIterator< DwiImageType > it1(reference, reference->GetLargestPossibleRegion()); itk::ImageRegionIterator< DwiImageType > it2(simulation, simulation->GetLargestPossibleRegion()); unsigned int count = 0; double error = 0; while(!it1.IsAtEnd()) { if (mask.IsNull() || (mask.IsNotNull() && mask->GetLargestPossibleRegion().IsInside(it1.GetIndex()) && mask->GetPixel(it1.GetIndex())>0) ) { for (unsigned int i=0; iGetVectorLength(); ++i) { if (it1.Get()[i]>0) { double diff = (double)it2.Get()[i]/it1.Get()[i] - 1.0; error += fabs(diff); count++; } } } ++it1; ++it2; } return error/count; } - catch(...) - { - return -1; - } return -1; } -double CalcErrorFA(const std::vector& histo_mod, mitk::Image::Pointer dwi1, itk::VectorImage< short, 3 >* dwi2, itk::Image< unsigned char,3 >::Pointer mask, itk::Image< double,3 >::Pointer fa1, itk::Image< double,3 >::Pointer md1) +double CalcErrorFA(const std::vector& histo_mod, mitk::Image::Pointer dwi1, itk::VectorImage< short, 3 >* dwi2, itk::Image< unsigned char,3 >::Pointer mask, itk::Image< double,3 >::Pointer fa1, itk::Image< double,3 >::Pointer md1, bool b0_contrast) { typedef itk::TensorDerivedMeasurementsFilter MeasurementsType; typedef itk::Image< double, 3 > DoubleImageType; + typedef itk::VectorImage< short, 3 > DwiType; + DwiType::Pointer dwi1_itk = mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi1); typedef itk::DiffusionTensor3DReconstructionImageFilter TensorReconstructionImageFilterType; DoubleImageType::Pointer fa2; { mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::Pointer gradientContainerCopy = mitk::DiffusionPropertyHelper::GradientDirectionsContainerType::New(); for(auto it = mitk::DiffusionPropertyHelper::GetGradientContainer(dwi1)->Begin(); it != mitk::DiffusionPropertyHelper::GetGradientContainer(dwi1)->End(); it++) gradientContainerCopy->push_back(it.Value()); TensorReconstructionImageFilterType::Pointer tensorReconstructionFilter = TensorReconstructionImageFilterType::New(); tensorReconstructionFilter->SetBValue( mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi1) ); tensorReconstructionFilter->SetGradientImage(gradientContainerCopy, dwi2 ); tensorReconstructionFilter->Update(); MeasurementsType::Pointer measurementsCalculator = MeasurementsType::New(); measurementsCalculator->SetInput( tensorReconstructionFilter->GetOutput() ); measurementsCalculator->SetMeasure(MeasurementsType::FA); measurementsCalculator->Update(); fa2 = measurementsCalculator->GetOutput(); } DoubleImageType::Pointer md2; if (md1.IsNotNull()) { typedef itk::AdcImageFilter< short, double > AdcFilterType; AdcFilterType::Pointer filter = AdcFilterType::New(); filter->SetInput( dwi2 ); filter->SetGradientDirections( mitk::DiffusionPropertyHelper::GetGradientContainer(dwi1) ); filter->SetB_value( mitk::DiffusionPropertyHelper::GetReferenceBValue(dwi1) ); filter->SetFitSignal(false); filter->Update(); md2 = filter->GetOutput(); } itk::ImageRegionConstIterator< DoubleImageType > it1(fa1, fa1->GetLargestPossibleRegion()); itk::ImageRegionConstIterator< DoubleImageType > it2(fa2, fa2->GetLargestPossibleRegion()); + itk::ImageRegionConstIterator< itk::VectorImage< short, 3 > > it_diff1(dwi1_itk, dwi1_itk->GetLargestPossibleRegion()); + itk::ImageRegionConstIterator< itk::VectorImage< short, 3 > > it_diff2(dwi2, dwi2->GetLargestPossibleRegion()); + unsigned int count = 0; double error = 0; if (md1.IsNotNull() && md2.IsNotNull()) { itk::ImageRegionConstIterator< DoubleImageType > it3(md1, md1->GetLargestPossibleRegion()); itk::ImageRegionConstIterator< DoubleImageType > it4(md2, md2->GetLargestPossibleRegion()); while(!it1.IsAtEnd()) { if (mask.IsNull() || (mask.IsNotNull() && mask->GetLargestPossibleRegion().IsInside(it1.GetIndex()) && mask->GetPixel(it1.GetIndex())>0) ) { double fa = it1.Get(); if (fa>0 && it3.Get()>0) { double mod = 1.0; for (int i=histo_mod.size()-1; i>=0; --i) if (fa >= (double)i/histo_mod.size()) { mod = histo_mod.at(i); break; } - double fa_diff = fabs(it2.Get()/fa - 1.0); - double md_diff = fabs(it4.Get()/it3.Get() - 1.0); - error += mod*mod*mod*mod * (fa_diff + md_diff); + double fa_diff = std::fabs(it2.Get()/fa - 1.0); + double md_diff = std::fabs(it4.Get()/it3.Get() - 1.0); + error += std::pow(mod, 4) * (fa_diff + md_diff); count += 2; + + if (b0_contrast && it_diff1.Get()[0]>0) + { + double b0_diff = (double)it_diff2.Get()[0]/it_diff1.Get()[0] - 1.0; + error += std::pow(mod, 4) * std::fabs(b0_diff); + ++count; + } } } ++it1; ++it2; ++it3; ++it4; + ++it_diff1; + ++it_diff2; } } else { unsigned int count = 0; double error = 0; while(!it1.IsAtEnd()) { if (mask.IsNull() || (mask.IsNotNull() && mask->GetLargestPossibleRegion().IsInside(it1.GetIndex()) && mask->GetPixel(it1.GetIndex())>0) ) { double fa = it1.Get(); if (fa>0) { double mod = 1.0; for (int i=histo_mod.size()-1; i>=0; --i) if (fa >= (double)i/histo_mod.size()) { mod = histo_mod.at(i); break; } double fa_diff = fabs(it2.Get()/fa - 1.0); - error += mod * fa_diff; + error += std::pow(mod, 4) * fa_diff; ++count; + + if (b0_contrast && it_diff1.Get()[0]>0) + { + double b0_diff = (double)it_diff2.Get()[0]/it_diff1.Get()[0] - 1.0; + error += std::pow(mod, 4) * std::fabs(b0_diff); + ++count; + } } } ++it1; ++it2; + ++it_diff1; + ++it_diff2; } } return error/count; } -FiberfoxParameters MakeProposalRelaxation(FiberfoxParameters old_params, double temperature) +FiberfoxParameters MakeProposalScale(FiberfoxParameters old_params, double temperature) { + FiberfoxParameters new_params(old_params); std::random_device r; std::default_random_engine randgen(r()); - std::uniform_int_distribution uint1(0, 4); + std::normal_distribution normal_dist(0, new_params.m_SignalGen.m_SignalScale*0.1*temperature); + double add = 0; + while (add == 0) + add = normal_dist(randgen); + new_params.m_SignalGen.m_SignalScale += add; + MITK_INFO << "Proposal Signal Scale: " << new_params.m_SignalGen.m_SignalScale << " (" << add << ")"; + return new_params; +} +FiberfoxParameters MakeProposalRelaxation(FiberfoxParameters old_params, double temperature) +{ FiberfoxParameters new_params(old_params); + std::random_device r; + std::default_random_engine randgen(r()); + std::uniform_int_distribution uint1(0, 3); int prop = uint1(randgen); switch(prop) { case 0: - { - std::normal_distribution normal_dist(0, new_params.m_SignalGen.m_SignalScale*0.1*temperature); - double add = 0; - while (add == 0) - add = normal_dist(randgen); - new_params.m_SignalGen.m_SignalScale += add; - MITK_INFO << "Proposal Signal Scale: " << new_params.m_SignalGen.m_SignalScale << " (" << add << ")"; - break; - } - case 1: { int model_index = rand()%new_params.m_NonFiberModelList.size(); double t2 = new_params.m_NonFiberModelList[model_index]->GetT2(); std::normal_distribution normal_dist(0, t2*0.1*temperature); double add = 0; while (add == 0) add = normal_dist(randgen); if ( (t2+add)*1.5 > new_params.m_NonFiberModelList[model_index]->GetT1() ) add = -add; t2 += add; new_params.m_NonFiberModelList[model_index]->SetT2(t2); MITK_INFO << "Proposal T2 (Non-Fiber " << model_index << "): " << t2 << " (" << add << ")"; break; } - case 2: + case 1: { int model_index = rand()%new_params.m_FiberModelList.size(); double t2 = new_params.m_FiberModelList[model_index]->GetT2(); std::normal_distribution normal_dist(0, t2*0.1*temperature); double add = 0; while (add == 0) add = normal_dist(randgen); if ( (t2+add)*1.5 > new_params.m_FiberModelList[model_index]->GetT1() ) add = -add; t2 += add; new_params.m_FiberModelList[model_index]->SetT2(t2); MITK_INFO << "Proposal T2 (Fiber " << model_index << "): " << t2 << " (" << add << ")"; break; } - case 3: + case 2: { int model_index = rand()%new_params.m_NonFiberModelList.size(); double t1 = new_params.m_NonFiberModelList[model_index]->GetT1(); std::normal_distribution normal_dist(0, t1*0.1*temperature); double add = 0; while (add == 0) add = normal_dist(randgen); if ( t1+add < new_params.m_NonFiberModelList[model_index]->GetT2() * 1.5 ) add = -add; t1 += add; new_params.m_NonFiberModelList[model_index]->SetT1(t1); MITK_INFO << "Proposal T1 (Non-Fiber " << model_index << "): " << t1 << " (" << add << ")"; break; } - case 4: + case 3: { int model_index = rand()%new_params.m_FiberModelList.size(); double t1 = new_params.m_FiberModelList[model_index]->GetT1(); std::normal_distribution normal_dist(0, t1*0.1*temperature); double add = 0; while (add == 0) add = normal_dist(randgen); if ( t1+add < new_params.m_FiberModelList[model_index]->GetT2() * 1.5 ) add = -add; t1 += add; new_params.m_FiberModelList[model_index]->SetT1(t1); MITK_INFO << "Proposal T1 (Fiber " << model_index << "): " << t1 << " (" << add << ")"; break; } } return new_params; } double UpdateDiffusivity(double d, double temperature) { std::random_device r; std::default_random_engine randgen(r()); std::normal_distribution normal_dist(0, d*0.1*temperature); double add = 0; while (add == 0) add = normal_dist(randgen); if (d+add > 0.0025) d -= add; else if ( d+add < 0.0 ) d -= add; else d += add; return d; } void ProposeDiffusivities(mitk::DiffusionSignalModel<>* signalModel, double temperature) { if (dynamic_cast*>(signalModel)) { mitk::StickModel<>* m = dynamic_cast*>(signalModel); double new_d = UpdateDiffusivity(m->GetDiffusivity(), temperature); MITK_INFO << "d: " << new_d << " (" << new_d-m->GetDiffusivity() << ")"; m->SetDiffusivity(new_d); } else if (dynamic_cast*>(signalModel)) { mitk::TensorModel<>* m = dynamic_cast*>(signalModel); double new_d1 = UpdateDiffusivity(m->GetDiffusivity1(), temperature); double new_d2 = UpdateDiffusivity(m->GetDiffusivity2(), temperature); while (new_d1GetDiffusivity2(), temperature); MITK_INFO << "d1: " << new_d1 << " (" << new_d1-m->GetDiffusivity1() << ")"; MITK_INFO << "d2: " << new_d2 << " (" << new_d2-m->GetDiffusivity2() << ")"; m->SetDiffusivity1(new_d1); m->SetDiffusivity2(new_d2); m->SetDiffusivity3(new_d2); } else if (dynamic_cast*>(signalModel)) { mitk::BallModel<>* m = dynamic_cast*>(signalModel); double new_d = UpdateDiffusivity(m->GetDiffusivity(), temperature); MITK_INFO << "d: " << new_d << " (" << new_d-m->GetDiffusivity() << ")"; m->SetDiffusivity(new_d); } else if (dynamic_cast*>(signalModel)) { mitk::AstroStickModel<>* m = dynamic_cast*>(signalModel); double new_d = UpdateDiffusivity(m->GetDiffusivity(), temperature); MITK_INFO << "d: " << new_d << " (" << new_d-m->GetDiffusivity() << ")"; m->SetDiffusivity(new_d); } } FiberfoxParameters MakeProposalDiff(FiberfoxParameters old_params, double temperature) { FiberfoxParameters new_params(old_params); std::random_device r; std::default_random_engine randgen(r()); std::uniform_int_distribution uint1(0, new_params.m_NonFiberModelList.size() + new_params.m_FiberModelList.size() - 1); unsigned int prop = uint1(randgen); if (prop::Pointer > frac, double temperature) +{ + FiberfoxParameters new_params(old_params); + + MITK_INFO << "Proposal Volume"; + std::random_device r; + std::default_random_engine randgen(r()); + + { + std::normal_distribution normal_dist(0, old_tdi_thr*0.1*temperature); + new_tdi_thr = old_tdi_thr + normal_dist(randgen); + while (new_tdi_thr<=0.01) + new_tdi_thr = old_tdi_thr + normal_dist(randgen); + } + { + std::normal_distribution normal_dist(0, old_sqrt*0.1*temperature); + new_sqrt = old_sqrt + normal_dist(randgen); + while (new_sqrt<=0.01) + new_sqrt = old_sqrt + normal_dist(randgen); + } + + itk::TdiToVolumeFractionFilter< double >::Pointer fraction_generator = itk::TdiToVolumeFractionFilter< double >::New(); + fraction_generator->SetTdiThreshold(new_tdi_thr); + fraction_generator->SetSqrt(new_sqrt); + fraction_generator->SetInput(0, frac.at(0)); + fraction_generator->SetInput(1, frac.at(1)); + fraction_generator->SetInput(2, frac.at(2)); + fraction_generator->SetInput(3, frac.at(3)); + fraction_generator->SetInput(4, frac.at(4)); + fraction_generator->Update(); + + new_params.m_FiberModelList[0]->SetVolumeFractionImage(fraction_generator->GetOutput(0)); + new_params.m_FiberModelList[1]->SetVolumeFractionImage(fraction_generator->GetOutput(1)); + new_params.m_NonFiberModelList[0]->SetVolumeFractionImage(fraction_generator->GetOutput(2)); + new_params.m_NonFiberModelList[1]->SetVolumeFractionImage(fraction_generator->GetOutput(3)); + + MITK_INFO << "TDI Threshold: " << new_tdi_thr << " (" << new_tdi_thr-old_tdi_thr << ")"; + MITK_INFO << "SQRT: " << new_sqrt << " (" << new_sqrt-old_sqrt << ")"; + + return new_params; +} + /*! * \brief Command line interface to optimize Fiberfox parameters. */ int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setTitle("FiberfoxOptimization"); parser.setCategory("Optimize Fiberfox Parameters"); parser.setContributor("MIC"); parser.setArgumentPrefix("--", "-"); - parser.addArgument("parameters", "p", mitkCommandLineParser::InputFile, "Parameter file:", "fiberfox parameter file (.ffp)", us::Any(), false); - parser.addArgument("input", "i", mitkCommandLineParser::String, "Input:", "Input tractogram or diffusion-weighted image.", us::Any(), false); - parser.addArgument("target", "t", mitkCommandLineParser::String, "Target image:", "Approximate target dMRI.", us::Any(), false); - parser.addArgument("mask", "", mitkCommandLineParser::InputFile, "Mask image:", "Error is only calculated inside the mask image", false); - parser.addArgument("fa_image", "", mitkCommandLineParser::InputFile, "FA image", "Optimize FA instead of raw signal"); - parser.addArgument("md_image", "", mitkCommandLineParser::InputFile, "MD image", "Optimize MD in conjunction with FA (recommended when optimizing FA)"); + parser.beginGroup("1. Mandatory Input:"); + parser.addArgument("parameters", "p", mitkCommandLineParser::InputFile, "Parameter File:", "fiberfox parameter file (.ffp)", us::Any(), false); + parser.addArgument("tracts", "t", mitkCommandLineParser::String, "Input Tractogram:", "Input tractogram.", us::Any(), false); + parser.addArgument("out_folder", "o", mitkCommandLineParser::String, "Output Folder:", "", us::Any(), false); + parser.addArgument("dmri", "d", mitkCommandLineParser::String, "Target image:", "Target dMRI to approximate.", us::Any(), false); + parser.addArgument("mask", "", mitkCommandLineParser::InputFile, "Mask image:", "Error is only calculated inside the mask image", false); + parser.endGroup(); + parser.beginGroup("2. Parameters to optimize:"); parser.addArgument("no_diff", "", mitkCommandLineParser::Bool, "Don't optimize diffusivities:", "Don't optimize diffusivities"); - parser.addArgument("no_relax", "", mitkCommandLineParser::Bool, "Don't optimize relaxation times and signal scale:", "Don't optimize relaxation times and signal scale"); - + parser.addArgument("no_relax", "", mitkCommandLineParser::Bool, "Don't optimize relaxation times:", "Don't optimize relaxation times"); + parser.addArgument("no_scale", "", mitkCommandLineParser::Bool, "Don't optimize signal scale:", "Don't optimize global signal scale"); + parser.endGroup(); + + parser.beginGroup("3. Error measure:"); + parser.addArgument("fa_error", "", mitkCommandLineParser::Bool, "Optimize FA", "Optimize FA instead of raw signal. Requires FA image."); + parser.addArgument("fa_image", "", mitkCommandLineParser::InputFile, "FA image:", "Weight error by FA histogram. Always necessary with option fa_error!"); + parser.addArgument("md_image", "", mitkCommandLineParser::InputFile, "MD image:", "Optimize MD in conjunction with FA (recommended when optimizing FA)."); + parser.endGroup(); + + parser.beginGroup("4. Optimization of volume fraction maps:"); + parser.addArgument("tdi", "", mitkCommandLineParser::InputFile, "TDI:", "tract density image"); + parser.addArgument("wm", "", mitkCommandLineParser::InputFile, "WM:", "white matter volume fraction image"); + parser.addArgument("gm", "", mitkCommandLineParser::InputFile, "GM:", "gray matter volume fraction image"); + parser.addArgument("dgm", "", mitkCommandLineParser::InputFile, "DGM:", "subcortical gray matter volume fraction image"); + parser.addArgument("csf", "", mitkCommandLineParser::InputFile, "CSF:", "CSF volume fraction image"); + parser.addArgument("tdi_threshold", "", mitkCommandLineParser::Float, "", "", 0.75); + parser.addArgument("sqrt", "", mitkCommandLineParser::Float, "", "", 1.0); + parser.endGroup(); + + parser.beginGroup("5. General parameters:"); parser.addArgument("iterations", "", mitkCommandLineParser::Int, "Iterations:", "Number of optimizations steps", 1000); - parser.addArgument("start_temp", "", mitkCommandLineParser::Float, "Start temperature:", "", 1.0); - parser.addArgument("end_temp", "", mitkCommandLineParser::Float, "End temperature:", "", 0.1); + parser.addArgument("start_temp", "", mitkCommandLineParser::Float, "Start temperature:", "Higher temperature means larger parameter change proposals", 1.0); + parser.addArgument("end_temp", "", mitkCommandLineParser::Float, "End temperature:", "Higher temperature means larger parameter change proposals", 0.1); + parser.endGroup(); std::map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) return EXIT_FAILURE; std::string paramName = us::any_cast(parsedArgs["parameters"]); + std::string out_folder = us::any_cast(parsedArgs["out_folder"]); + std::string tract_file = us::any_cast(parsedArgs["tracts"]); + + MITK_INFO << "Loading target dMRI and parameters"; + FiberfoxParameters parameters; + parameters.LoadParameters(paramName); + typedef itk::VectorImage< short, 3 > ItkDwiType; + mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Diffusion Weighted Images", "Fiberbundles"}, {}); + mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::Load(us::any_cast(parsedArgs["dmri"]), &functor)[0].GetPointer()); + ItkDwiType::Pointer reference = mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi); + parameters.m_SignalGen.m_ImageRegion = reference->GetLargestPossibleRegion(); + parameters.m_SignalGen.m_ImageSpacing = reference->GetSpacing(); + parameters.m_SignalGen.m_ImageOrigin = reference->GetOrigin(); + parameters.m_SignalGen.m_ImageDirection = reference->GetDirection(); + parameters.SetBvalue(static_cast(dwi->GetProperty(mitk::DiffusionPropertyHelper::REFERENCEBVALUEPROPERTYNAME.c_str()).GetPointer() )->GetValue()); + parameters.SetGradienDirections(static_cast( dwi->GetProperty(mitk::DiffusionPropertyHelper::GRADIENTCONTAINERPROPERTYNAME.c_str()).GetPointer() )->GetGradientDirectionsContainer()); - std::string input = us::any_cast(parsedArgs["input"]); + mitk::FiberBundle::Pointer tracts = dynamic_cast(mitk::IOUtil::Load(tract_file, &functor)[0].GetPointer()); int iterations=1000; if (parsedArgs.count("iterations")) iterations = us::any_cast(parsedArgs["iterations"]); float start_temp=1.0; if (parsedArgs.count("start_temp")) start_temp = us::any_cast(parsedArgs["start_temp"]); float end_temp=0.1; if (parsedArgs.count("end_temp")) end_temp = us::any_cast(parsedArgs["end_temp"]); - bool no_diff=false; - if (parsedArgs.count("no_diff")) - no_diff = true; + float tdi_threshold=0.75; + if (parsedArgs.count("tdi_threshold")) + tdi_threshold = us::any_cast(parsedArgs["tdi_threshold"]); + + float sqrt=1.0; + if (parsedArgs.count("sqrt")) + sqrt = us::any_cast(parsedArgs["sqrt"]); - bool no_relax=false; - if (parsedArgs.count("no_relax")) - no_relax = true; + bool fa_error=false; + if (parsedArgs.count("fa_error")) + fa_error = true; std::string fa_file = ""; if (parsedArgs.count("fa_image")) fa_file = us::any_cast(parsedArgs["fa_image"]); std::string md_file = ""; if (parsedArgs.count("md_image")) md_file = us::any_cast(parsedArgs["md_image"]); - std::string mask_file = ""; - if (parsedArgs.count("mask")) - mask_file = us::any_cast(parsedArgs["mask"]); - - if (no_relax && no_diff) + std::vector< int > possible_proposals; + if (!parsedArgs.count("no_diff")) + { + MITK_INFO << "Optimizing diffusivities"; + possible_proposals.push_back(0); + } + if (!parsedArgs.count("no_relax")) + { + MITK_INFO << "Optimizing relaxation constants"; + possible_proposals.push_back(1); + } + if (!parsedArgs.count("no_scale")) + { + MITK_INFO << "Optimizing global signal scale"; + possible_proposals.push_back(2); + } + if (possible_proposals.empty()) { MITK_INFO << "Incompatible options. Nothing to optimize."; return EXIT_FAILURE; } - itk::Image< unsigned char,3 >::Pointer mask = nullptr; - if (mask_file.compare("")!=0) + itk::ImageFileReader< itk::Image< unsigned char, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< unsigned char, 3 > >::New(); + reader->SetFileName( us::any_cast(parsedArgs["mask"]) ); + reader->Update(); + itk::Image< unsigned char,3 >::Pointer mask = reader->GetOutput(); + + std::vector< itk::Image< double, 3 >::Pointer > fracs; + if ( parsedArgs.count("tdi")>0 && parsedArgs.count("wm")>0 && parsedArgs.count("gm")>0 && parsedArgs.count("dgm")>0 && parsedArgs.count("csf")>0 ) { - mitk::Image::Pointer mitk_mask = dynamic_cast(mitk::IOUtil::Load(mask_file)[0].GetPointer()); - mitk::CastToItkImage(mitk_mask, mask); + MITK_INFO << "Optimizing volume fractions"; + { + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( us::any_cast(parsedArgs["tdi"]) ); + reader->Update(); + fracs.push_back(reader->GetOutput()); + } + { + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( us::any_cast(parsedArgs["wm"]) ); + reader->Update(); + fracs.push_back(reader->GetOutput()); + } + { + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( us::any_cast(parsedArgs["gm"]) ); + reader->Update(); + fracs.push_back(reader->GetOutput()); + } + { + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( us::any_cast(parsedArgs["dgm"]) ); + reader->Update(); + fracs.push_back(reader->GetOutput()); + } + { + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( us::any_cast(parsedArgs["csf"]) ); + reader->Update(); + fracs.push_back(reader->GetOutput()); + } + MITK_INFO << "Initial sqrt: " << sqrt; + MITK_INFO << "Initial TDI threshold: " << tdi_threshold; + + possible_proposals.push_back(3); } std::vector< double > histogram_modifiers; itk::Image< double,3 >::Pointer fa_image = nullptr; if (fa_file.compare("")!=0) { - mitk::Image::Pointer mitk_img = dynamic_cast(mitk::IOUtil::Load(fa_file)[0].GetPointer()); - mitk::CastToItkImage(mitk_img, fa_image); + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( fa_file ); + reader->Update(); + fa_image = reader->GetOutput(); int binsPerDimension = 20; using ImageToHistogramFilterType = itk::Statistics::MaskedImageToHistogramFilter< itk::Image< double,3 >, itk::Image< unsigned char,3 > >; ImageToHistogramFilterType::HistogramType::MeasurementVectorType lowerBound(binsPerDimension); lowerBound.Fill(0.0); ImageToHistogramFilterType::HistogramType::MeasurementVectorType upperBound(binsPerDimension); upperBound.Fill(1.0); ImageToHistogramFilterType::HistogramType::SizeType size(1); size.Fill(binsPerDimension); ImageToHistogramFilterType::Pointer imageToHistogramFilter = ImageToHistogramFilterType::New(); imageToHistogramFilter->SetInput( fa_image ); imageToHistogramFilter->SetHistogramBinMinimum( lowerBound ); imageToHistogramFilter->SetHistogramBinMaximum( upperBound ); imageToHistogramFilter->SetHistogramSize( size ); imageToHistogramFilter->SetMaskImage(mask); imageToHistogramFilter->SetMaskValue(1); imageToHistogramFilter->Update(); ImageToHistogramFilterType::HistogramType* histogram = imageToHistogramFilter->GetOutput(); unsigned int max = 0; for(unsigned int i = 0; i < histogram->GetSize()[0]; ++i) { if (histogram->GetFrequency(i)>max) max = histogram->GetFrequency(i); } + MITK_INFO << "FA histogram modifiers:"; for(unsigned int i = 0; i < histogram->GetSize()[0]; ++i) { histogram_modifiers.push_back((double)max/histogram->GetFrequency(i)); - MITK_INFO << histogram_modifiers.back(); + MITK_INFO << std::pow(histogram_modifiers.back(), 4); } + if (fa_error) + MITK_INFO << "Using FA error measure."; } - itk::Image< double,3 >::Pointer md_image = nullptr; if (md_file.compare("")!=0) { - mitk::Image::Pointer mitk_img = dynamic_cast(mitk::IOUtil::Load(md_file)[0].GetPointer()); - mitk::CastToItkImage(mitk_img, md_image); + itk::ImageFileReader< itk::Image< double, 3 > >::Pointer reader = itk::ImageFileReader< itk::Image< double, 3 > >::New(); + reader->SetFileName( md_file ); + reader->Update(); + md_image = reader->GetOutput(); + if (fa_error) + MITK_INFO << "Using MD error measure."; + } + if (fa_error && fa_image.IsNull()) + { + MITK_INFO << "Incompatible options. Need FA image to calculate FA error."; + return EXIT_FAILURE; } - - FiberfoxParameters parameters; - parameters.LoadParameters(paramName); - - MITK_INFO << "Loading target image"; - typedef itk::VectorImage< short, 3 > ItkDwiType; - mitk::PreferenceListReaderOptionsFunctor functor = mitk::PreferenceListReaderOptionsFunctor({"Diffusion Weighted Images", "Fiberbundles"}, {}); - mitk::Image::Pointer dwi = dynamic_cast(mitk::IOUtil::Load(us::any_cast(parsedArgs["target"]), &functor)[0].GetPointer()); - ItkDwiType::Pointer reference = mitk::DiffusionPropertyHelper::GetItkVectorImage(dwi); - parameters.m_SignalGen.m_ImageRegion = reference->GetLargestPossibleRegion(); - parameters.m_SignalGen.m_ImageSpacing = reference->GetSpacing(); - parameters.m_SignalGen.m_ImageOrigin = reference->GetOrigin(); - parameters.m_SignalGen.m_ImageDirection = reference->GetDirection(); - parameters.SetBvalue(static_cast(dwi->GetProperty(mitk::DiffusionPropertyHelper::REFERENCEBVALUEPROPERTYNAME.c_str()).GetPointer() )->GetValue()); - parameters.SetGradienDirections(static_cast( dwi->GetProperty(mitk::DiffusionPropertyHelper::GRADIENTCONTAINERPROPERTYNAME.c_str()).GetPointer() )->GetGradientDirectionsContainer()); - - mitk::BaseData::Pointer inputData = mitk::IOUtil::Load(input, &functor)[0]; itk::TractsToDWIImageFilter< short >::Pointer tractsToDwiFilter = itk::TractsToDWIImageFilter< short >::New(); - tractsToDwiFilter->SetFiberBundle(dynamic_cast(inputData.GetPointer())); + tractsToDwiFilter->SetFiberBundle(tracts); tractsToDwiFilter->SetParameters(parameters); tractsToDwiFilter->Update(); ItkDwiType::Pointer sim = tractsToDwiFilter->GetOutput(); - { mitk::Image::Pointer image = mitk::GrabItkImageMemory( tractsToDwiFilter->GetOutput() ); image->GetPropertyList()->ReplaceProperty( mitk::DiffusionPropertyHelper::GRADIENTCONTAINERPROPERTYNAME.c_str(), mitk::GradientDirectionsProperty::New( parameters.m_SignalGen.GetGradientDirections() ) ); image->GetPropertyList()->ReplaceProperty( mitk::DiffusionPropertyHelper::REFERENCEBVALUEPROPERTYNAME.c_str(), mitk::FloatProperty::New( parameters.m_SignalGen.GetBvalue() ) ); mitk::DiffusionPropertyHelper propertyHelper( image ); propertyHelper.InitializeImage(); - mitk::IOUtil::Save(image, "initial.dwi"); + mitk::IOUtil::Save(image, out_folder + "/initial.dwi"); } + + double old_tdi_thr = tdi_threshold; + double old_sqrt = sqrt; + double new_tdi_thr; + double new_sqrt; MITK_INFO << "\n\n"; MITK_INFO << "Iterations: " << iterations; MITK_INFO << "start_temp: " << start_temp; MITK_INFO << "end_temp: " << end_temp; double alpha = log(end_temp/start_temp); int accepted = 0; double last_error = 9999999; - if (fa_image.IsNotNull()) + if (fa_error) { MITK_INFO << "Calculating FA error"; - last_error = CalcErrorFA(histogram_modifiers, dwi, sim, mask, fa_image, md_image); + last_error = CalcErrorFA(histogram_modifiers, dwi, sim, mask, fa_image, md_image, true); } else { MITK_INFO << "Calculating raw-image error"; - last_error = CalcErrorSignal(reference, sim, mask); + last_error = CalcErrorSignal(histogram_modifiers, reference, sim, mask, fa_image); } MITK_INFO << "Initial E = " << last_error; MITK_INFO << "\n\n**************************************************************************************"; + std::random_device r; + std::default_random_engine randgen(r()); + std::uniform_int_distribution uint1(0, possible_proposals.size()-1); for (int i=0; i uint1(0, 1); - FiberfoxParameters proposal(parameters); int select = uint1(randgen); - if (no_relax) - select = 0; - else if (no_diff) - select = 1; - - if (select==0) + switch (possible_proposals.at(select)) + { + case 0: + { proposal = MakeProposalDiff(proposal, temperature); - else + break; + } + case 1: + { proposal = MakeProposalRelaxation(proposal, temperature); + break; + } + case 2: + { + proposal = MakeProposalScale(proposal, temperature); + break; + } + case 3: + { + proposal = MakeProposalVolume(old_tdi_thr, old_sqrt, new_tdi_thr, new_sqrt, proposal, fracs, temperature); + break; + } + } std::streambuf *old = cout.rdbuf(); // <-- save std::stringstream ss; std::cout.rdbuf (ss.rdbuf()); itk::TractsToDWIImageFilter< short >::Pointer tractsToDwiFilter = itk::TractsToDWIImageFilter< short >::New(); - tractsToDwiFilter->SetFiberBundle(dynamic_cast(inputData.GetPointer())); + tractsToDwiFilter->SetFiberBundle(dynamic_cast(tracts.GetPointer())); tractsToDwiFilter->SetParameters(proposal); tractsToDwiFilter->Update(); ItkDwiType::Pointer sim = tractsToDwiFilter->GetOutput(); std::cout.rdbuf (old); // <-- restore double new_error = 9999999; - if (fa_image.IsNotNull()) - new_error = CalcErrorFA(histogram_modifiers, dwi, sim, mask, fa_image, md_image); + if (fa_error && fa_image.IsNotNull()) + new_error = CalcErrorFA(histogram_modifiers, dwi, sim, mask, fa_image, md_image, true); else - new_error = CalcErrorSignal(reference, sim, mask); + new_error = CalcErrorSignal(histogram_modifiers, reference, sim, mask, fa_image); MITK_INFO << "E = " << new_error << "(" << new_error-last_error << ")"; if (last_errorGetOutput() ); image->GetPropertyList()->ReplaceProperty( mitk::DiffusionPropertyHelper::GRADIENTCONTAINERPROPERTYNAME.c_str(), mitk::GradientDirectionsProperty::New( parameters.m_SignalGen.GetGradientDirections() ) ); image->GetPropertyList()->ReplaceProperty( mitk::DiffusionPropertyHelper::REFERENCEBVALUEPROPERTYNAME.c_str(), mitk::FloatProperty::New( parameters.m_SignalGen.GetBvalue() ) ); mitk::DiffusionPropertyHelper propertyHelper( image ); propertyHelper.InitializeImage(); - mitk::IOUtil::Save(image, "optimized.dwi"); + mitk::IOUtil::Save(image, out_folder + "/optimized.dwi"); - proposal.SaveParameters("optimized.ffp"); + proposal.SaveParameters(out_folder + "/optimized.ffp"); std::cout.rdbuf (old); // <-- restore accepted++; + old_tdi_thr = new_tdi_thr; + old_sqrt = new_sqrt; + MITK_INFO << "Accepted (acc. rate " << (float)accepted/(i+1) << ")"; parameters = FiberfoxParameters(proposal); last_error = new_error; } MITK_INFO << "\n\n\n"; } return EXIT_SUCCESS; } diff --git a/Modules/DiffusionImaging/FiberTracking/files.cmake b/Modules/DiffusionImaging/FiberTracking/files.cmake index 8c374b4b53..9d63278f86 100644 --- a/Modules/DiffusionImaging/FiberTracking/files.cmake +++ b/Modules/DiffusionImaging/FiberTracking/files.cmake @@ -1,105 +1,106 @@ set(CPP_FILES mitkFiberTrackingModuleActivator.cpp ## IO datastructures IODataStructures/FiberBundle/mitkFiberBundle.cpp IODataStructures/FiberBundle/mitkTrackvis.cpp IODataStructures/PlanarFigureComposite/mitkPlanarFigureComposite.cpp IODataStructures/mitkTractographyForest.cpp IODataStructures/mitkFiberfoxParameters.cpp # Interactions # Tractography Algorithms/GibbsTracking/mitkParticleGrid.cpp Algorithms/GibbsTracking/mitkMetropolisHastingsSampler.cpp Algorithms/GibbsTracking/mitkEnergyComputer.cpp Algorithms/GibbsTracking/mitkGibbsEnergyComputer.cpp Algorithms/GibbsTracking/mitkFiberBuilder.cpp Algorithms/GibbsTracking/mitkSphereInterpolator.cpp Algorithms/itkStreamlineTrackingFilter.cpp Algorithms/TrackingHandlers/mitkTrackingDataHandler.cpp Algorithms/TrackingHandlers/mitkTrackingHandlerTensor.cpp Algorithms/TrackingHandlers/mitkTrackingHandlerPeaks.cpp Algorithms/TrackingHandlers/mitkTrackingHandlerOdf.cpp ) set(H_FILES # DataStructures -> FiberBundle IODataStructures/FiberBundle/mitkFiberBundle.h IODataStructures/FiberBundle/mitkTrackvis.h IODataStructures/mitkFiberfoxParameters.h IODataStructures/mitkTractographyForest.h # Algorithms Algorithms/itkTractDensityImageFilter.h Algorithms/itkTractsToFiberEndingsImageFilter.h Algorithms/itkTractsToRgbaImageFilter.h Algorithms/itkTractsToVectorImageFilter.h Algorithms/itkEvaluateDirectionImagesFilter.h Algorithms/itkEvaluateTractogramDirectionsFilter.h Algorithms/itkFiberCurvatureFilter.h Algorithms/itkFitFibersToImageFilter.h Algorithms/itkTractClusteringFilter.h Algorithms/itkFiberExtractionFilter.h + Algorithms/itkTdiToVolumeFractionFilter.h # Tractography Algorithms/TrackingHandlers/mitkTrackingDataHandler.h Algorithms/TrackingHandlers/mitkTrackingHandlerRandomForest.h Algorithms/TrackingHandlers/mitkTrackingHandlerTensor.h Algorithms/TrackingHandlers/mitkTrackingHandlerPeaks.h Algorithms/TrackingHandlers/mitkTrackingHandlerOdf.h Algorithms/itkGibbsTrackingFilter.h Algorithms/itkStochasticTractographyFilter.h Algorithms/GibbsTracking/mitkParticle.h Algorithms/GibbsTracking/mitkParticleGrid.h Algorithms/GibbsTracking/mitkMetropolisHastingsSampler.h Algorithms/GibbsTracking/mitkSimpSamp.h Algorithms/GibbsTracking/mitkEnergyComputer.h Algorithms/GibbsTracking/mitkGibbsEnergyComputer.h Algorithms/GibbsTracking/mitkSphereInterpolator.h Algorithms/GibbsTracking/mitkFiberBuilder.h Algorithms/itkStreamlineTrackingFilter.h # Clustering Algorithms/ClusteringMetrics/mitkClusteringMetric.h Algorithms/ClusteringMetrics/mitkClusteringMetricEuclideanMean.h Algorithms/ClusteringMetrics/mitkClusteringMetricEuclideanMax.h Algorithms/ClusteringMetrics/mitkClusteringMetricEuclideanStd.h Algorithms/ClusteringMetrics/mitkClusteringMetricAnatomic.h Algorithms/ClusteringMetrics/mitkClusteringMetricScalarMap.h Algorithms/ClusteringMetrics/mitkClusteringMetricInnerAngles.h Algorithms/ClusteringMetrics/mitkClusteringMetricLength.h # Fiberfox Fiberfox/itkFibersFromPlanarFiguresFilter.h Fiberfox/itkTractsToDWIImageFilter.h Fiberfox/itkKspaceImageFilter.h Fiberfox/itkDftImageFilter.h Fiberfox/itkFieldmapGeneratorFilter.h Fiberfox/SignalModels/mitkDiffusionSignalModel.h Fiberfox/SignalModels/mitkTensorModel.h Fiberfox/SignalModels/mitkBallModel.h Fiberfox/SignalModels/mitkDotModel.h Fiberfox/SignalModels/mitkAstroStickModel.h Fiberfox/SignalModels/mitkStickModel.h Fiberfox/SignalModels/mitkRawShModel.h Fiberfox/SignalModels/mitkDiffusionNoiseModel.h Fiberfox/SignalModels/mitkRicianNoiseModel.h Fiberfox/SignalModels/mitkChiSquareNoiseModel.h Fiberfox/Sequences/mitkAcquisitionType.h Fiberfox/Sequences/mitkSingleShotEpi.h Fiberfox/Sequences/mitkCartesianReadout.h ) set(RESOURCE_FILES # Binary directory resources FiberTrackingLUTBaryCoords.bin FiberTrackingLUTIndices.bin )