diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp index 6176198c36..b0e835ebf5 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp @@ -1,459 +1,460 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkEnergyComputer.h" #include +#include using namespace mitk; EnergyComputer::EnergyComputer(ItkQBallImgType* qballImage, ItkFloatImageType* mask, ParticleGrid* particleGrid, SphereInterpolator* interpolator, ItkRandGenType* randGen) : m_UseTrilinearInterpolation(true) { m_ParticleGrid = particleGrid; m_RandGen = randGen; m_Image = qballImage; m_SphereInterpolator = interpolator; m_Mask = mask; m_ParticleLength = m_ParticleGrid->m_ParticleLength; m_SquaredParticleLength = m_ParticleLength*m_ParticleLength; m_Size[0] = m_Image->GetLargestPossibleRegion().GetSize()[0]; m_Size[1] = m_Image->GetLargestPossibleRegion().GetSize()[1]; m_Size[2] = m_Image->GetLargestPossibleRegion().GetSize()[2]; if (m_Size[0]<3 || m_Size[1]<3 || m_Size[2]<3) m_UseTrilinearInterpolation = false; m_Spacing[0] = m_Image->GetSpacing()[0]; m_Spacing[1] = m_Image->GetSpacing()[1]; m_Spacing[2] = m_Image->GetSpacing()[2]; // calculate rotation matrix vnl_matrix temp = m_Image->GetDirection().GetVnlMatrix(); vnl_matrix directionMatrix; directionMatrix.set_size(3,3); vnl_copy(temp, directionMatrix); vnl_vector_fixed d0 = directionMatrix.get_column(0); d0.normalize(); vnl_vector_fixed d1 = directionMatrix.get_column(1); d1.normalize(); vnl_vector_fixed d2 = directionMatrix.get_column(2); d2.normalize(); directionMatrix.set_column(0, d0); directionMatrix.set_column(1, d1); directionMatrix.set_column(2, d2); vnl_matrix_fixed I = directionMatrix*directionMatrix.transpose(); if(!I.is_identity(mitk::eps)) fprintf(stderr,"itkGibbsTrackingFilter: image direction is not a rotation matrix. Tracking not possible!\n"); m_RotationMatrix = directionMatrix; if (QBALL_ODFSIZE != m_SphereInterpolator->nverts) fprintf(stderr,"EnergyComputer: error during init: data does not match with interpolation scheme\n"); int totsz = m_Size[0]*m_Size[1]*m_Size[2]; m_CumulatedSpatialProbability.resize(totsz, 0.0); // +1? m_ActiveIndices.resize(totsz, 0); // calculate active voxels and cumulate probabilities m_NumActiveVoxels = 0; m_CumulatedSpatialProbability[0] = 0; for (int x = 0; x < m_Size[0];x++) for (int y = 0; y < m_Size[1];y++) for (int z = 0; z < m_Size[2];z++) { int idx = x+(y+z*m_Size[1])*m_Size[0]; ItkFloatImageType::IndexType index; index[0] = x; index[1] = y; index[2] = z; if (m_Mask->GetPixel(index) > 0.5) { m_CumulatedSpatialProbability[m_NumActiveVoxels+1] = m_CumulatedSpatialProbability[m_NumActiveVoxels] + m_Mask->GetPixel(index); m_ActiveIndices[m_NumActiveVoxels] = idx; m_NumActiveVoxels++; } } for (int k = 0; k < m_NumActiveVoxels; k++) m_CumulatedSpatialProbability[k] /= m_CumulatedSpatialProbability[m_NumActiveVoxels]; std::cout << "EnergyComputer: " << m_NumActiveVoxels << " active voxels found" << std::endl; } void EnergyComputer::SetParameters(float particleWeight, float particleWidth, float connectionPotential, float curvThres, float inexBalance, float particlePotential) { m_ParticleChemicalPotential = particlePotential; m_ConnectionPotential = connectionPotential; m_ParticleWeight = particleWeight; float bal = 1/(1+exp(-inexBalance)); m_ExtStrength = 2*bal; m_IntStrength = 2*(1-bal)/m_SquaredParticleLength; m_CurvatureThreshold = curvThres; float sigma_s = particleWidth; gamma_s = 1/(sigma_s*sigma_s); gamma_reg_s =1/(m_SquaredParticleLength/4); } // draw random position from active voxels void EnergyComputer::DrawRandomPosition(vnl_vector_fixed& R) { float r = m_RandGen->GetVariate();//m_RandGen->frand(); int j; int rl = 1; int rh = m_NumActiveVoxels; while(rh != rl) { j = rl + (rh-rl)/2; if (r < m_CumulatedSpatialProbability[j]) { rh = j; continue; } if (r > m_CumulatedSpatialProbability[j]) { rl = j+1; continue; } break; } R[0] = m_Spacing[0]*((float)(m_ActiveIndices[rh-1] % m_Size[0]) + m_RandGen->GetVariate()); R[1] = m_Spacing[1]*((float)((m_ActiveIndices[rh-1]/m_Size[0]) % m_Size[1]) + m_RandGen->GetVariate()); R[2] = m_Spacing[2]*((float)(m_ActiveIndices[rh-1]/(m_Size[0]*m_Size[1])) + m_RandGen->GetVariate()); } // return spatial probability of position float EnergyComputer::SpatProb(vnl_vector_fixed pos) { ItkFloatImageType::IndexType index; index[0] = floor(pos[0]/m_Spacing[0]); index[1] = floor(pos[1]/m_Spacing[1]); index[2] = floor(pos[2]/m_Spacing[2]); if (m_Mask->GetLargestPossibleRegion().IsInside(index)) // is inside image? return m_Mask->GetPixel(index); else return 0; } float EnergyComputer::EvaluateOdf(vnl_vector_fixed& pos, vnl_vector_fixed dir) { const int sampleSteps = 10; // evaluate ODF at 2*sampleSteps+1 positions along dir vnl_vector_fixed samplePos; // current position to evaluate float result = 0; // average of sampled ODF values int xint, yint, zint; // voxel containing samplePos // rotate particle direction according to image rotation dir = m_RotationMatrix*dir; // get interpolation for rotated direction m_SphereInterpolator->getInterpolation(dir); // sample ODF values along particle direction for (int i=-sampleSteps; i <= sampleSteps;i++) { samplePos = pos + (dir * m_ParticleLength) * ((float)i/sampleSteps); if (!m_UseTrilinearInterpolation) // image has not enough slices to use trilinear interpolation { ItkQBallImgType::IndexType index; index[0] = floor(pos[0]/m_Spacing[0]); index[1] = floor(pos[1]/m_Spacing[1]); index[2] = floor(pos[2]/m_Spacing[2]); if (m_Image->GetLargestPossibleRegion().IsInside(index)) { result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2]); } } else // use trilinear interpolation { float Rx = samplePos[0]/m_Spacing[0]-0.5; float Ry = samplePos[1]/m_Spacing[1]-0.5; float Rz = samplePos[2]/m_Spacing[2]-0.5; xint = floor(Rx); yint = floor(Ry); zint = floor(Rz); if (xint >= 0 && xint < m_Size[0]-1 && yint >= 0 && yint < m_Size[1]-1 && zint >= 0 && zint < m_Size[2]-1) { float xfrac = Rx-xint; float yfrac = Ry-yint; float zfrac = Rz-zint; ItkQBallImgType::IndexType index; float weight; weight = (1-xfrac)*(1-yfrac)*(1-zfrac); index[0] = xint; index[1] = yint; index[2] = zint; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(1-yfrac)*(1-zfrac); index[0] = xint+1; index[1] = yint; index[2] = zint; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (1-xfrac)*(yfrac)*(1-zfrac); index[0] = xint; index[1] = yint+1; index[2] = zint; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (1-xfrac)*(1-yfrac)*(zfrac); index[0] = xint; index[1] = yint; index[2] = zint+1; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(yfrac)*(1-zfrac); index[0] = xint+1; index[1] = yint+1; index[2] = zint; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (1-xfrac)*(yfrac)*(zfrac); index[0] = xint; index[1] = yint+1; index[2] = zint+1; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(1-yfrac)*(zfrac); index[0] = xint+1; index[1] = yint; index[2] = zint+1; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(yfrac)*(zfrac); index[0] = xint+1; index[1] = yint+1; index[2] = zint+1; result += (m_Image->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + m_Image->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; } } } result /= (2*sampleSteps+1); // average result over taken samples return result; } float EnergyComputer::ComputeExternalEnergy(vnl_vector_fixed &R, vnl_vector_fixed &N, Particle *dp) { if (SpatProb(R) == 0) // check if position is inside mask - return -INFINITY; + return itk::NumericTraits::NonpositiveMin(); float odfVal = EvaluateOdf(R, N); // evaluate ODF in given direction float modelVal = 0; m_ParticleGrid->ComputeNeighbors(R); // retrieve neighbouring particles from particle grid Particle* neighbour = m_ParticleGrid->GetNextNeighbor(); while (neighbour!=NULL) // iterate over nieghbouring particles { if (dp != neighbour) // don't evaluate against itself { // see Reisert et al. "Global Reconstruction of Neuronal Fibers", MICCAI 2009 float dot = fabs(dot_product(N,neighbour->dir)); float bw = mbesseli0(dot); float dpos = (neighbour->pos-R).squared_magnitude(); float w = mexp(dpos*gamma_s); modelVal += w*(bw+m_ParticleChemicalPotential); w = mexp(dpos*gamma_reg_s); } neighbour = m_ParticleGrid->GetNextNeighbor(); } float energy = 2*(odfVal/m_ParticleWeight-modelVal) - (mbesseli0(1.0)+m_ParticleChemicalPotential); return energy*m_ExtStrength; } float EnergyComputer::ComputeInternalEnergy(Particle *dp) { float energy = 0; if (dp->pID != -1) // has predecessor energy += ComputeInternalEnergyConnection(dp,+1); if (dp->mID != -1) // has successor energy += ComputeInternalEnergyConnection(dp,-1); return energy; } float EnergyComputer::ComputeInternalEnergyConnection(Particle *p1,int ep1) { Particle *p2 = 0; int ep2; if (ep1 == 1) p2 = m_ParticleGrid->GetParticle(p1->pID); // get predecessor else p2 = m_ParticleGrid->GetParticle(p1->mID); // get successor // check in which direction the connected particle is pointing if (p2->mID == p1->ID) ep2 = -1; else if (p2->pID == p1->ID) ep2 = 1; else std::cout << "EnergyComputer: Connections are inconsistent!" << std::endl; return ComputeInternalEnergyConnection(p1,ep1,p2,ep2); } float EnergyComputer::ComputeInternalEnergyConnection(Particle *p1,int ep1, Particle *p2, int ep2) { // see Reisert et al. "Global Reconstruction of Neuronal Fibers", MICCAI 2009 if ((dot_product(p1->dir,p2->dir))*ep1*ep2 > -m_CurvatureThreshold) // angle between particles is too sharp - return -INFINITY; + return itk::NumericTraits::NonpositiveMin(); // calculate the endpoints of the two particles vnl_vector_fixed endPoint1 = p1->pos + (p1->dir * (m_ParticleLength * ep1)); vnl_vector_fixed endPoint2 = p2->pos + (p2->dir * (m_ParticleLength * ep2)); // check if endpoints are too far apart to connect if ((endPoint1-endPoint2).squared_magnitude() > m_SquaredParticleLength) - return -INFINITY; + return itk::NumericTraits::NonpositiveMin(); // calculate center point of the two particles vnl_vector_fixed R = (p2->pos + p1->pos); R *= 0.5; // they are not allowed to connect if the mask image does not allow it if (SpatProb(R) == 0) - return -INFINITY; + return itk::NumericTraits::NonpositiveMin(); // get distances of endpoints to center point float norm1 = (endPoint1-R).squared_magnitude(); float norm2 = (endPoint2-R).squared_magnitude(); // calculate actual internal energy float energy = (m_ConnectionPotential-norm1-norm2)*m_IntStrength; return energy; } float EnergyComputer::mbesseli0(float x) { // BESSEL_APPROXCOEFF[0] = -0.1714; // BESSEL_APPROXCOEFF[1] = 0.5332; // BESSEL_APPROXCOEFF[2] = -1.4889; // BESSEL_APPROXCOEFF[3] = 2.0389; float y = x*x; float erg = -0.1714; erg += y*0.5332; erg += y*y*-1.4889; erg += y*y*y*2.0389; return erg; } float EnergyComputer::mexp(float x) { return((x>=7.0) ? 0 : ((x>=5.0) ? (-0.0029*x+0.0213) : ((x>=3.0) ? (-0.0215*x+0.1144) : ((x>=2.0) ? (-0.0855*x+0.3064) : ((x>=1.0) ? (-0.2325*x+0.6004) : ((x>=0.5) ? (-0.4773*x+0.8452) : ((x>=0.0) ? (-0.7869*x+1.0000) : 1 ))))))); // return exp(-x); } //ComputeFiberCorrelation() //{ // float bD = 15; // vnl_matrix_fixed bDir = // *itk::PointShell >::DistributePointShell(); // const int N = QBALL_ODFSIZE; // vnl_matrix_fixed temp = bDir.transpose(); // vnl_matrix_fixed C = temp*bDir; // vnl_matrix_fixed Q = C; // vnl_vector_fixed mean; // for(int i=0; i repMean; // for (int i=0; i P = Q*Q; // std::vector pointer; // pointer.reserve(N*N); // double * start = C.data_block(); // double * end = start + N*N; // for (double * iter = start; iter != end; ++iter) // { // pointer.push_back(iter); // } // std::sort(pointer.begin(), pointer.end(), LessDereference()); // vnl_vector_fixed alpha; // vnl_vector_fixed beta; // for (int i=0; im_Meanval_sq = (sum*sum)/N; // vnl_vector_fixed alpha_0; // vnl_vector_fixed alpha_2; // vnl_vector_fixed alpha_4; // vnl_vector_fixed alpha_6; // for(int i=0; i T; // T.set_column(0,alpha_0); // T.set_column(1,alpha_2); // T.set_column(2,alpha_4); // T.set_column(3,alpha_6); // vnl_vector_fixed coeff = vnl_matrix_inverse(T).pinverse()*beta; // MITK_INFO << "itkGibbsTrackingFilter: Bessel oefficients: " << coeff; // BESSEL_APPROXCOEFF = new float[4]; // BESSEL_APPROXCOEFF[0] = coeff(0); // BESSEL_APPROXCOEFF[1] = coeff(1); // BESSEL_APPROXCOEFF[2] = coeff(2); // BESSEL_APPROXCOEFF[3] = coeff(3); // BESSEL_APPROXCOEFF[0] = -0.1714; // BESSEL_APPROXCOEFF[1] = 0.5332; // BESSEL_APPROXCOEFF[2] = -1.4889; // BESSEL_APPROXCOEFF[3] = 2.0389; //} diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h index 21f72bfd26..d847131fb0 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h @@ -1,84 +1,84 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _ENCOMP #define _ENCOMP #include #include #include #include #include using namespace mitk; class MitkDiffusionImaging_EXPORT EnergyComputer { public: typedef itk::Vector OdfVectorType; typedef itk::Image ItkQBallImgType; typedef itk::Image ItkFloatImageType; typedef itk::Statistics::MersenneTwisterRandomVariateGenerator ItkRandGenType; EnergyComputer(ItkQBallImgType* qballImage, ItkFloatImageType* mask, ParticleGrid* particleGrid, SphereInterpolator* interpolator, ItkRandGenType* randGen); void SetParameters(float particleWeight, float particleWidth, float connectionPotential, float curvThres, float inexBalance, float particlePotential); // get random position inside mask void DrawRandomPosition(vnl_vector_fixed& R); // external energy calculation float ComputeExternalEnergy(vnl_vector_fixed& R, vnl_vector_fixed& N, Particle* dp); // internal energy calculation float ComputeInternalEnergyConnection(Particle *p1,int ep1); float ComputeInternalEnergyConnection(Particle *p1,int ep1, Particle *p2, int ep2); float ComputeInternalEnergy(Particle *dp); protected: vnl_matrix_fixed m_RotationMatrix; SphereInterpolator* m_SphereInterpolator; ParticleGrid* m_ParticleGrid; ItkRandGenType* m_RandGen; ItkQBallImgType* m_Image; ItkFloatImageType* m_Mask; vnl_vector_fixed m_Size; vnl_vector_fixed m_Spacing; std::vector< float > m_CumulatedSpatialProbability; std::vector< int > m_ActiveIndices; // indices inside mask bool m_UseTrilinearInterpolation; // is deactivated if less than 3 image slices are available int m_NumActiveVoxels; // voxels inside mask float m_ConnectionPotential; // larger value results in larger energy value -> higher proposal acceptance probability float m_ParticleChemicalPotential; // larger value results in larger energy value -> higher proposal acceptance probability float gamma_s; float gamma_reg_s; float m_ParticleWeight; // defines how much one particle contributes to the artificial signal - float m_ExtStrength; - float m_IntStrength; - float m_ParticleLength; - float m_SquaredParticleLength; - float m_CurvatureThreshold; + float m_ExtStrength; // weighting factor for external energy + float m_IntStrength; // weighting factor for internal energy + float m_ParticleLength; // particle length + float m_SquaredParticleLength; // squared particle length + float m_CurvatureThreshold; // maximum angle accepted between two connected particles float SpatProb(vnl_vector_fixed pos); float EvaluateOdf(vnl_vector_fixed &pos, vnl_vector_fixed dir); float mbesseli0(float x); float mexp(float x); }; #endif diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp index c8c7c62d23..0007ba3368 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp @@ -1,478 +1,478 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkMetropolisHastingsSampler.h" using namespace mitk; MetropolisHastingsSampler::MetropolisHastingsSampler(ParticleGrid* grid, EnergyComputer* enComp, ItkRandGenType* randGen, float curvThres) : m_AcceptedProposals(0) , m_ExTemp(0.01) , m_BirthProb(0.25) , m_DeathProb(0.05) , m_ShiftProb(0.15) , m_OptShiftProb(0.1) , m_ConnectionProb(0.45) , m_TractProb(0.5) , m_DelProb(0.1) , m_ChempotParticle(0.0) { m_RandGen = randGen; m_ParticleGrid = grid; m_EnergyComputer = enComp; m_ParticleLength = m_ParticleGrid->m_ParticleLength; m_DistanceThreshold = m_ParticleLength*m_ParticleLength; m_Sigma = m_ParticleLength/8.0; m_Gamma = 1/(m_Sigma*m_Sigma*2); m_Z = pow(2*M_PI*m_Sigma,3.0/2.0)*(M_PI*m_Sigma/m_ParticleLength); m_CurvatureThreshold = curvThres; m_StopProb = exp(-1/m_TractProb); } void MetropolisHastingsSampler::SetProbabilities(float birth, float death, float shift, float optShift, float connect) { m_BirthProb = birth; m_DeathProb = death; m_ShiftProb = shift; m_OptShiftProb = optShift; m_ConnectionProb = connect; float sum = m_BirthProb+m_DeathProb+m_ShiftProb+m_OptShiftProb+m_ConnectionProb; if (sum!=1 && sum>mitk::eps) { m_BirthProb /= sum; m_DeathProb /= sum; m_ShiftProb /= sum; m_OptShiftProb /= sum; m_ConnectionProb /= sum; } std::cout << "Update proposal probabilities:" << std::endl; std::cout << "Birth: " << m_BirthProb << std::endl; std::cout << "Death: " << m_DeathProb << std::endl; std::cout << "Shift: " << m_ShiftProb << std::endl; std::cout << "Optimal shift: " << m_OptShiftProb << std::endl; std::cout << "Connection: " << m_ConnectionProb << std::endl; } // update temperature of simulated annealing process void MetropolisHastingsSampler::SetTemperature(float val) { m_InTemp = val; m_Density = exp(-m_ChempotParticle/m_InTemp); } // add small random number drawn from gaussian to each vector element -vnl_vector_fixed MetropolisHastingsSampler::DistortVector(float sigma, vnl_vector_fixed& vec) +void MetropolisHastingsSampler::DistortVector(float sigma, vnl_vector_fixed& vec) { vec[0] += m_RandGen->GetNormalVariate(0.0, sigma); vec[1] += m_RandGen->GetNormalVariate(0.0, sigma); vec[2] += m_RandGen->GetNormalVariate(0.0, sigma); } // generate normalized random vector vnl_vector_fixed MetropolisHastingsSampler::GetRandomDirection() { vnl_vector_fixed vec; vec[0] = m_RandGen->GetNormalVariate(); vec[1] = m_RandGen->GetNormalVariate(); vec[2] = m_RandGen->GetNormalVariate(); vec.normalize(); return vec; } // generate actual proposal (birth, death, shift and connection of particle) void MetropolisHastingsSampler::MakeProposal() { float randnum = m_RandGen->GetVariate(); // Birth Proposal if (randnum < m_BirthProb) { vnl_vector_fixed R; m_EnergyComputer->DrawRandomPosition(R); vnl_vector_fixed N = GetRandomDirection(); Particle prop; prop.pos = R; prop.dir = N; float prob = m_Density * m_DeathProb /((m_BirthProb)*(m_ParticleGrid->m_NumParticles+1)); float ex_energy = m_EnergyComputer->ComputeExternalEnergy(R,N,0); float in_energy = m_EnergyComputer->ComputeInternalEnergy(&prop); prob *= exp((in_energy/m_InTemp+ex_energy/m_ExTemp)) ; if (prob > 1 || m_RandGen->GetVariate() < prob) { Particle *p = m_ParticleGrid->NewParticle(R); if (p!=0) { p->pos = R; p->dir = N; m_AcceptedProposals++; } } } // Death Proposal else if (randnum < m_BirthProb+m_DeathProb) { if (m_ParticleGrid->m_NumParticles > 0) { int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles; Particle *dp = m_ParticleGrid->GetParticle(pnum); if (dp->pID == -1 && dp->mID == -1) { float ex_energy = m_EnergyComputer->ComputeExternalEnergy(dp->pos,dp->dir,dp); float in_energy = m_EnergyComputer->ComputeInternalEnergy(dp); float prob = m_ParticleGrid->m_NumParticles * (m_BirthProb) /(m_Density*m_DeathProb); //*SpatProb(dp->R); prob *= exp(-(in_energy/m_InTemp+ex_energy/m_ExTemp)) ; if (prob > 1 || m_RandGen->GetVariate() < prob) { m_ParticleGrid->RemoveParticle(pnum); m_AcceptedProposals++; } } } } // Shift Proposal else if (randnum < m_BirthProb+m_DeathProb+m_ShiftProb) { if (m_ParticleGrid->m_NumParticles > 0) { int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles; Particle *p = m_ParticleGrid->GetParticle(pnum); Particle prop_p = *p; DistortVector(m_Sigma, prop_p.pos); DistortVector(m_Sigma/(2*m_ParticleLength), prop_p.dir); prop_p.dir.normalize(); float ex_energy = m_EnergyComputer->ComputeExternalEnergy(prop_p.pos,prop_p.dir,p) - m_EnergyComputer->ComputeExternalEnergy(p->pos,p->dir,p); float in_energy = m_EnergyComputer->ComputeInternalEnergy(&prop_p) - m_EnergyComputer->ComputeInternalEnergy(p); float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp); if (m_RandGen->GetVariate() < prob) { vnl_vector_fixed Rtmp = p->pos; vnl_vector_fixed Ntmp = p->dir; p->pos = prop_p.pos; p->dir = prop_p.dir; if (!m_ParticleGrid->TryUpdateGrid(pnum)) { p->pos = Rtmp; p->dir = Ntmp; } m_AcceptedProposals++; } } } // Optimal Shift Proposal else if (randnum < m_BirthProb+m_DeathProb+m_ShiftProb+m_OptShiftProb) { if (m_ParticleGrid->m_NumParticles > 0) { int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles; Particle *p = m_ParticleGrid->GetParticle(pnum); bool no_proposal = false; Particle prop_p = *p; if (p->pID != -1 && p->mID != -1) { Particle *plus = m_ParticleGrid->GetParticle(p->pID); int ep_plus = (plus->pID == p->ID)? 1 : -1; Particle *minus = m_ParticleGrid->GetParticle(p->mID); int ep_minus = (minus->pID == p->ID)? 1 : -1; prop_p.pos = (plus->pos + plus->dir * (m_ParticleLength * ep_plus) + minus->pos + minus->dir * (m_ParticleLength * ep_minus)); prop_p.pos *= 0.5; prop_p.dir = plus->pos - minus->pos; prop_p.dir.normalize(); } else if (p->pID != -1) { Particle *plus = m_ParticleGrid->GetParticle(p->pID); int ep_plus = (plus->pID == p->ID)? 1 : -1; prop_p.pos = plus->pos + plus->dir * (m_ParticleLength * ep_plus * 2); prop_p.dir = plus->dir; } else if (p->mID != -1) { Particle *minus = m_ParticleGrid->GetParticle(p->mID); int ep_minus = (minus->pID == p->ID)? 1 : -1; prop_p.pos = minus->pos + minus->dir * (m_ParticleLength * ep_minus * 2); prop_p.dir = minus->dir; } else no_proposal = true; if (!no_proposal) { float cos = dot_product(prop_p.dir, p->dir); float p_rev = exp(-((prop_p.pos-p->pos).squared_magnitude() + (1-cos*cos))*m_Gamma)/m_Z; float ex_energy = m_EnergyComputer->ComputeExternalEnergy(prop_p.pos,prop_p.dir,p) - m_EnergyComputer->ComputeExternalEnergy(p->pos,p->dir,p); float in_energy = m_EnergyComputer->ComputeInternalEnergy(&prop_p) - m_EnergyComputer->ComputeInternalEnergy(p); float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp)*m_ShiftProb*p_rev/(m_OptShiftProb+m_ShiftProb*p_rev); if (m_RandGen->GetVariate() < prob) { vnl_vector_fixed Rtmp = p->pos; vnl_vector_fixed Ntmp = p->dir; p->pos = prop_p.pos; p->dir = prop_p.dir; if (!m_ParticleGrid->TryUpdateGrid(pnum)) { p->pos = Rtmp; p->dir = Ntmp; } m_AcceptedProposals++; } } } } else { if (m_ParticleGrid->m_NumParticles > 0) { int pnum = m_RandGen->GetIntegerVariate()%m_ParticleGrid->m_NumParticles; Particle *p = m_ParticleGrid->GetParticle(pnum); EndPoint P; P.p = p; P.ep = (m_RandGen->GetVariate() > 0.5)? 1 : -1; RemoveAndSaveTrack(P); if (m_BackupTrack.m_Probability != 0) { MakeTrackProposal(P); float prob = (m_ProposalTrack.m_Energy-m_BackupTrack.m_Energy)/m_InTemp ; prob = exp(prob)*(m_BackupTrack.m_Probability * pow(m_DelProb,m_ProposalTrack.m_Length)) /(m_ProposalTrack.m_Probability * pow(m_DelProb,m_BackupTrack.m_Length)); if (m_RandGen->GetVariate() < prob) { ImplementTrack(m_ProposalTrack); m_AcceptedProposals++; } else { ImplementTrack(m_BackupTrack); } } else ImplementTrack(m_BackupTrack); } } } // establish connections between particles stored in input Track void MetropolisHastingsSampler::ImplementTrack(Track &T) { for (int k = 1; k < T.m_Length;k++) m_ParticleGrid->CreateConnection(T.track[k-1].p,T.track[k-1].ep,T.track[k].p,-T.track[k].ep); } // remove pending track from random particle, save it in m_BackupTrack and calculate its probability void MetropolisHastingsSampler::RemoveAndSaveTrack(EndPoint P) { EndPoint Current = P; int cnt = 0; float energy = 0; float AccumProb = 1.0; m_BackupTrack.track[cnt] = Current; EndPoint Next; for (;;) { Next.p = 0; if (Current.ep == 1) { if (Current.p->pID != -1) { Next.p = m_ParticleGrid->GetParticle(Current.p->pID); Current.p->pID = -1; m_ParticleGrid->m_NumConnections--; } } else if (Current.ep == -1) { if (Current.p->mID != -1) { Next.p = m_ParticleGrid->GetParticle(Current.p->mID); Current.p->mID = -1; m_ParticleGrid->m_NumConnections--; } } else { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 3\n"); break; } if (Next.p == 0) // no successor { Next.ep = 0; // mark as empty successor break; } else { if (Next.p->pID == Current.p->ID) { Next.p->pID = -1; Next.ep = 1; } else if (Next.p->mID == Current.p->ID) { Next.p->mID = -1; Next.ep = -1; } else { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 4\n"); break; } } ComputeEndPointProposalDistribution(Current); AccumProb *= (m_SimpSamp.probFor(Next)); if (Next.p == 0) // no successor -> break break; energy += m_EnergyComputer->ComputeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep); Current = Next; Current.ep *= -1; cnt++; m_BackupTrack.track[cnt] = Current; if (m_RandGen->GetVariate() > m_DelProb) break; } m_BackupTrack.m_Energy = energy; m_BackupTrack.m_Probability = AccumProb; m_BackupTrack.m_Length = cnt+1; } // generate new track using kind of a local tracking starting from P in the given direction, store it in m_ProposalTrack and calculate its probability void MetropolisHastingsSampler::MakeTrackProposal(EndPoint P) { EndPoint Current = P; int cnt = 0; float energy = 0; float AccumProb = 1.0; m_ProposalTrack.track[cnt++] = Current; Current.p->label = 1; for (;;) { // next candidate is already connected if ((Current.ep == 1 && Current.p->pID != -1) || (Current.ep == -1 && Current.p->mID != -1)) break; // track too long // if (cnt > 250) // break; ComputeEndPointProposalDistribution(Current); int k = m_SimpSamp.draw(m_RandGen->GetVariate()); // stop tracking proposed if (k==0) break; EndPoint Next = m_SimpSamp.objs[k]; float probability = m_SimpSamp.probFor(k); // accumulate energy and proposal distribution energy += m_EnergyComputer->ComputeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep); AccumProb *= probability; // track to next endpoint Current = Next; Current.ep *= -1; Current.p->label = 1; // put label to avoid loops m_ProposalTrack.track[cnt++] = Current; } m_ProposalTrack.m_Energy = energy; m_ProposalTrack.m_Probability = AccumProb; m_ProposalTrack.m_Length = cnt; // clear labels for (int j = 0; j < m_ProposalTrack.m_Length;j++) m_ProposalTrack.track[j].p->label = 0; } // get neigbouring particles of P and calculate the according connection probabilities void MetropolisHastingsSampler::ComputeEndPointProposalDistribution(EndPoint P) { Particle *p = P.p; int ep = P.ep; float dist,dot; vnl_vector_fixed R = p->pos + (p->dir * (ep*m_ParticleLength) ); m_ParticleGrid->ComputeNeighbors(R); m_SimpSamp.clear(); m_SimpSamp.add(m_StopProb,EndPoint(0,0)); for (;;) { Particle *p2 = m_ParticleGrid->GetNextNeighbor(); if (p2 == 0) break; if (p!=p2 && p2->label == 0) { if (p2->mID == -1) { dist = (p2->pos - p2->dir * m_ParticleLength - R).squared_magnitude(); if (dist < m_DistanceThreshold) { dot = dot_product(p2->dir,p->dir) * ep; if (dot > m_CurvatureThreshold) { float en = m_EnergyComputer->ComputeInternalEnergyConnection(p,ep,p2,-1); m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,-1)); } } } if (p2->pID == -1) { dist = (p2->pos + p2->dir * m_ParticleLength - R).squared_magnitude(); if (dist < m_DistanceThreshold) { dot = dot_product(p2->dir,p->dir) * (-ep); if (dot > m_CurvatureThreshold) { float en = m_EnergyComputer->ComputeInternalEnergyConnection(p,ep,p2,+1); m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,+1)); } } } } } } // return number of accepted proposals int MetropolisHastingsSampler::GetNumAcceptedProposals() { return m_AcceptedProposals; } diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h index 852fc88399..f9396f3f04 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h @@ -1,98 +1,98 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _SAMPLER #define _SAMPLER // MITK #include #include #include #include // ITK #include #include // MISC #include namespace mitk { class MitkDiffusionImaging_EXPORT MetropolisHastingsSampler { public: typedef itk::Image< float, 3 > ItkFloatImageType; typedef itk::Statistics::MersenneTwisterRandomVariateGenerator ItkRandGenType; MetropolisHastingsSampler(ParticleGrid* grid, EnergyComputer* enComp, ItkRandGenType* randGen, float curvThres); void SetTemperature(float val); void MakeProposal(); int GetNumAcceptedProposals(); void SetProbabilities(float birth, float death, float shift, float optShift, float connect); protected: // connection proposal related methods void ImplementTrack(Track& T); void RemoveAndSaveTrack(EndPoint P); void MakeTrackProposal(EndPoint P); void ComputeEndPointProposalDistribution(EndPoint P); // generate random vectors - vnl_vector_fixed DistortVector(float sigma, vnl_vector_fixed& vec); + void DistortVector(float sigma, vnl_vector_fixed& vec); vnl_vector_fixed GetRandomDirection(); ItkRandGenType* m_RandGen; // random generator Track m_ProposalTrack; // stores proposal track Track m_BackupTrack; // stores track removed for new proposal traCK SimpSamp m_SimpSamp; // neighbouring particles and their probabilities for the local tracking float m_InTemp; // simulated annealing temperature float m_ExTemp; // simulated annealing temperature float m_Density; float m_BirthProb; // probability for particle birth float m_DeathProb; // probability for particle death float m_ShiftProb; // probability for particle shift float m_OptShiftProb; // probability for optimal particle shift float m_ConnectionProb; // probability for particle connection proposal float m_Sigma; float m_Gamma; float m_Z; float m_DistanceThreshold; // threshold for maximum distance between connected particles float m_CurvatureThreshold; // threshold for maximum angle between connected particles float m_TractProb; float m_StopProb; float m_DelProb; float m_ParticleLength; float m_ChempotParticle; ParticleGrid* m_ParticleGrid; EnergyComputer* m_EnergyComputer; int m_AcceptedProposals; }; } #endif