diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp index 7c8cbed799..6412ef569a 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.cpp @@ -1,452 +1,423 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkEnergyComputer.h" using namespace mitk; -EnergyComputer::EnergyComputer(MTRand* rgen, ItkQBallImgType* data, const int *dsz, double *cellsize, SphereInterpolator *sp, ParticleGrid *pcon, float *spimg, int spmult, vnl_matrix_fixed rotMatrix) +EnergyComputer::EnergyComputer(MTRand* rgen, ItkQBallImgType* qballImage, SphereInterpolator *sp, ParticleGrid *particleGrid, float *mask, vnl_matrix_fixed rotMatrix) { - mtrand = rgen; + m_ParticleGrid = particleGrid; + m_RandGen = rgen; m_RotationMatrix = rotMatrix; - m_QBallImageData = data; - m_QBallImageSize = dsz; + m_ImageData = qballImage; m_SphereInterpolator = sp; + m_MaskImageData = mask; - m_MaskImageData = spimg; + m_Size[0] = m_ImageData->GetLargestPossibleRegion().GetSize()[0]; + m_Size[1] = m_ImageData->GetLargestPossibleRegion().GetSize()[1]; + m_Size[2] = m_ImageData->GetLargestPossibleRegion().GetSize()[2]; - nip = m_QBallImageSize[0]; + m_Spacing[0] = m_ImageData->GetSpacing()[0]; + m_Spacing[1] = m_ImageData->GetSpacing()[1]; + m_Spacing[2] = m_ImageData->GetSpacing()[2]; - - w = m_QBallImageSize[1]; - h = m_QBallImageSize[2]; - d = m_QBallImageSize[3]; - - voxsize_w = cellsize[0]; - voxsize_h = cellsize[1]; - voxsize_d = cellsize[2]; - - - w_sp = m_QBallImageSize[1]*spmult; - h_sp = m_QBallImageSize[2]*spmult; - d_sp = m_QBallImageSize[3]*spmult; - - voxsize_sp_w = cellsize[0]/spmult; - voxsize_sp_h = cellsize[1]/spmult; - voxsize_sp_d = cellsize[2]/spmult; - - - fprintf(stderr,"Data size (voxels) : %i x %i x %i\n",w,h,d); - fprintf(stderr,"voxel size: %f x %f x %f\n",voxsize_w,voxsize_h,voxsize_d); - fprintf(stderr,"mask_oversamp_mult: %i\n",spmult); + nip = QBALL_ODFSIZE; if (nip != sp->nverts) - { fprintf(stderr,"EnergyComputer: error during init: data does not match with interpolation scheme\n"); - } - - m_ParticleGrid = pcon; - - - int totsz = w_sp*h_sp*d_sp; - cumulspatprob = (float*) malloc(sizeof(float) * totsz); - spatidx = (int*) malloc(sizeof(int) * totsz); - if (cumulspatprob == 0 || spatidx == 0) - { - fprintf(stderr,"EnergyCOmputerBase: out of memory!\n"); - return ; - } + int totsz = m_Size[0]*m_Size[1]*m_Size[2]; + cumulspatprob.resize(totsz, 0.0); + spatidx.resize(totsz, 0); - scnt = 0; + m_NumActiveVoxels = 0; cumulspatprob[0] = 0; - for (int x = 1; x < w_sp;x++) - for (int y = 1; y < h_sp;y++) - for (int z = 1; z < d_sp;z++) + for (int x = 1; x < m_Size[0];x++) + for (int y = 1; y < m_Size[1];y++) + for (int z = 1; z < m_Size[2];z++) { - int idx = x+(y+z*h_sp)*w_sp; + int idx = x+(y+z*m_Size[1])*m_Size[0]; if (m_MaskImageData[idx] > 0.5) { - cumulspatprob[scnt+1] = cumulspatprob[scnt] + m_MaskImageData[idx]; - spatidx[scnt] = idx; - scnt++; + cumulspatprob[m_NumActiveVoxels+1] = cumulspatprob[m_NumActiveVoxels] + m_MaskImageData[idx]; + spatidx[m_NumActiveVoxels] = idx; + m_NumActiveVoxels++; } } - for (int k = 0; k < scnt; k++) - { - cumulspatprob[k] /= cumulspatprob[scnt]; - } + for (int k = 0; k < m_NumActiveVoxels; k++) + cumulspatprob[k] /= cumulspatprob[m_NumActiveVoxels]; - fprintf(stderr,"#active voxels: %i (in mask units) \n",scnt); + fprintf(stderr,"EnergyComputer: %i active voxels found\n",m_NumActiveVoxels); } void EnergyComputer::setParameters(float pwei,float pwid,float chempot_connection, float length,float curv_hardthres, float inex_balance, float chempot2, float meanv) { this->chempot2 = chempot2; meanval_sq = meanv; eigencon_energy = chempot_connection; eigen_energy = 0; particle_weight = pwei; float bal = 1/(1+exp(-inex_balance)); ex_strength = 2*bal; // cleanup (todo) in_strength = 2*(1-bal)/length/length; // cleanup (todo) // in_strength = 0.64/length/length; // cleanup (todo) particle_length_sq = length*length; curv_hard = curv_hardthres; float sigma_s = pwid; gamma_s = 1/(sigma_s*sigma_s); gamma_reg_s =1/(length*length/4); } void EnergyComputer::drawSpatPosition(vnl_vector_fixed& R) { - float r = mtrand->frand(); + float r = m_RandGen->frand(); int j; int rl = 1; - int rh = scnt; + int rh = m_NumActiveVoxels; while(rh != rl) { j = rl + (rh-rl)/2; if (r < cumulspatprob[j]) { rh = j; continue; } if (r > cumulspatprob[j]) { rl = j+1; continue; } break; } - R[0] = voxsize_sp_w*((float)(spatidx[rh-1] % w_sp) + mtrand->frand()); - R[1] = voxsize_sp_h*((float)((spatidx[rh-1]/w_sp) % h_sp) + mtrand->frand()); - R[2] = voxsize_sp_d*((float)(spatidx[rh-1]/(w_sp*h_sp)) + mtrand->frand()); + R[0] = m_Spacing[0]*((float)(spatidx[rh-1] % m_Size[0]) + m_RandGen->frand()); + R[1] = m_Spacing[1]*((float)((spatidx[rh-1]/m_Size[0]) % m_Size[1]) + m_RandGen->frand()); + R[2] = m_Spacing[2]*((float)(spatidx[rh-1]/(m_Size[0]*m_Size[1])) + m_RandGen->frand()); } float EnergyComputer::SpatProb(vnl_vector_fixed R) { - int rx = int(R[0]/voxsize_sp_w); - int ry = int(R[1]/voxsize_sp_h); - int rz = int(R[2]/voxsize_sp_d); - if (rx >= 0 && rx < w_sp && ry >= 0 && ry < h_sp && rz >= 0 && rz < d_sp){ - return m_MaskImageData[rx + w_sp* (ry + h_sp*rz)]; + int rx = int(R[0]/m_Spacing[0]); + int ry = int(R[1]/m_Spacing[1]); + int rz = int(R[2]/m_Spacing[2]); + if (rx >= 0 && rx < m_Size[0] && ry >= 0 && ry < m_Size[1] && rz >= 0 && rz < m_Size[2]){ + return m_MaskImageData[rx + m_Size[0]* (ry + m_Size[1]*rz)]; } else return 0; } float EnergyComputer::evaluateODF(vnl_vector_fixed &R, vnl_vector_fixed &N, float &len) { const int CU = 10; vnl_vector_fixed Rs; float Dn = 0; - int xint,yint,zint,spatindex; + int xint,yint,zint; vnl_vector_fixed n; n[0] = N[0]; n[1] = N[1]; n[2] = N[2]; n = m_RotationMatrix*n; m_SphereInterpolator->getInterpolation(n); for (int i=-CU; i <= CU;i++) { Rs = R + (N * len) * ((float)i/CU); - float Rx = Rs[0]/voxsize_w-0.5; - float Ry = Rs[1]/voxsize_h-0.5; - float Rz = Rs[2]/voxsize_d-0.5; + float Rx = Rs[0]/m_Spacing[0]-0.5; + float Ry = Rs[1]/m_Spacing[1]-0.5; + float Rz = Rs[2]/m_Spacing[2]-0.5; xint = int(floor(Rx)); yint = int(floor(Ry)); zint = int(floor(Rz)); - if (xint >= 0 && xint < w-1 && yint >= 0 && yint < h-1 && zint >= 0 && zint < d-1) + if (xint >= 0 && xint < m_Size[0]-1 && yint >= 0 && yint < m_Size[1]-1 && zint >= 0 && zint < m_Size[2]-1) { float xfrac = Rx-xint; float yfrac = Ry-yint; float zfrac = Rz-zint; ItkQBallImgType::IndexType index; float weight; weight = (1-xfrac)*(1-yfrac)*(1-zfrac); index[0] = xint; index[1] = yint; index[2] = zint; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(1-yfrac)*(1-zfrac); index[0] = xint+1; index[1] = yint; index[2] = zint; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (1-xfrac)*(yfrac)*(1-zfrac); index[0] = xint; index[1] = yint+1; index[2] = zint; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (1-xfrac)*(1-yfrac)*(zfrac); index[0] = xint; index[1] = yint; index[2] = zint+1; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(yfrac)*(1-zfrac); index[0] = xint+1; index[1] = yint+1; index[2] = zint; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (1-xfrac)*(yfrac)*(zfrac); index[0] = xint; index[1] = yint+1; index[2] = zint+1; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(1-yfrac)*(zfrac); index[0] = xint+1; index[1] = yint; index[2] = zint+1; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; weight = (xfrac)*(yfrac)*(zfrac); index[0] = xint+1; index[1] = yint+1; index[2] = zint+1; - Dn += (m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + - m_QBallImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; + Dn += (m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[0]-1]*m_SphereInterpolator->interpw[0] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[1]-1]*m_SphereInterpolator->interpw[1] + + m_ImageData->GetPixel(index)[m_SphereInterpolator->idx[2]-1]* m_SphereInterpolator->interpw[2])*weight; } } Dn *= 1.0/(2*CU+1); return Dn; } float EnergyComputer::computeExternalEnergy(vnl_vector_fixed &R, vnl_vector_fixed &N, float &len, Particle *dp) { float m = SpatProb(R); if (m == 0) return -INFINITY; float Dn = evaluateODF(R,N,len); float Sn = 0; float Pn = 0; m_ParticleGrid->ComputeNeighbors(R); for (;;) { Particle *p = m_ParticleGrid->GetNextNeighbor(); if (p == 0) break; if (dp != p) { float dot = fabs(dot_product(N,p->N)); float bw = mbesseli0(dot); float dpos = (p->R-R).squared_magnitude(); float w = mexp(dpos*gamma_s); Sn += w*(bw+chempot2); w = mexp(dpos*gamma_reg_s); Pn += w*bw; } } float energy = 0; energy += 2*(Dn/particle_weight-Sn) - (mbesseli0(1.0)+chempot2); return energy*ex_strength; } float EnergyComputer::computeInternalEnergy(Particle *dp) { float energy = eigen_energy; if (dp->pID != -1) energy += computeInternalEnergyConnection(dp,+1); if (dp->mID != -1) energy += computeInternalEnergyConnection(dp,-1); return energy; } float EnergyComputer::computeInternalEnergyConnection(Particle *p1,int ep1) { Particle *p2 = 0; int ep2; if (ep1 == 1) p2 = m_ParticleGrid->GetParticle(p1->pID); else p2 = m_ParticleGrid->GetParticle(p1->mID); if (p2->mID == p1->ID) ep2 = -1; else if (p2->pID == p1->ID) ep2 = 1; else fprintf(stderr,"EnergyComputer_connec: Connections are inconsistent!\n"); if (p2 == 0) fprintf(stderr,"bug2"); return computeInternalEnergyConnection(p1,ep1,p2,ep2); } float EnergyComputer::computeInternalEnergyConnection(Particle *p1,int ep1, Particle *p2, int ep2) { if ((dot_product(p1->N,p2->N))*ep1*ep2 > -curv_hard) return -INFINITY; vnl_vector_fixed R1 = p1->R + (p1->N * (p1->len * ep1)); vnl_vector_fixed R2 = p2->R + (p2->N * (p2->len * ep2)); if ((R1-R2).squared_magnitude() > particle_length_sq) return -INFINITY; vnl_vector_fixed R = (p2->R + p1->R); R *= 0.5; if (SpatProb(R) == 0) return -INFINITY; float norm1 = (R1-R).squared_magnitude(); float norm2 = (R2-R).squared_magnitude(); float energy = (eigencon_energy-norm1-norm2)*in_strength; return energy; } float EnergyComputer::mbesseli0(float x) { // BESSEL_APPROXCOEFF[0] = -0.1714; // BESSEL_APPROXCOEFF[1] = 0.5332; // BESSEL_APPROXCOEFF[2] = -1.4889; // BESSEL_APPROXCOEFF[3] = 2.0389; float y = x*x; float erg = -0.1714; erg += y*0.5332; erg += y*y*-1.4889; erg += y*y*y*2.0389; return erg; } float EnergyComputer::mexp(float x) { return((x>=7.0) ? 0 : ((x>=5.0) ? (-0.0029*x+0.0213) : ((x>=3.0) ? (-0.0215*x+0.1144) : ((x>=2.0) ? (-0.0855*x+0.3064) : ((x>=1.0) ? (-0.2325*x+0.6004) : ((x>=0.5) ? (-0.4773*x+0.8452) : ((x>=0.0) ? (-0.7869*x+1.0000) : 1 ))))))); // return exp(-x); } //ComputeFiberCorrelation() //{ // float bD = 15; // vnl_matrix_fixed bDir = // *itk::PointShell >::DistributePointShell(); // const int N = QBALL_ODFSIZE; // vnl_matrix_fixed temp = bDir.transpose(); // vnl_matrix_fixed C = temp*bDir; // vnl_matrix_fixed Q = C; // vnl_vector_fixed mean; // for(int i=0; i repMean; // for (int i=0; i P = Q*Q; // std::vector pointer; // pointer.reserve(N*N); // double * start = C.data_block(); // double * end = start + N*N; // for (double * iter = start; iter != end; ++iter) // { // pointer.push_back(iter); // } // std::sort(pointer.begin(), pointer.end(), LessDereference()); // vnl_vector_fixed alpha; // vnl_vector_fixed beta; // for (int i=0; im_Meanval_sq = (sum*sum)/N; // vnl_vector_fixed alpha_0; // vnl_vector_fixed alpha_2; // vnl_vector_fixed alpha_4; // vnl_vector_fixed alpha_6; // for(int i=0; i T; // T.set_column(0,alpha_0); // T.set_column(1,alpha_2); // T.set_column(2,alpha_4); // T.set_column(3,alpha_6); // vnl_vector_fixed coeff = vnl_matrix_inverse(T).pinverse()*beta; // MITK_INFO << "itkGibbsTrackingFilter: Bessel oefficients: " << coeff; // BESSEL_APPROXCOEFF = new float[4]; // BESSEL_APPROXCOEFF[0] = coeff(0); // BESSEL_APPROXCOEFF[1] = coeff(1); // BESSEL_APPROXCOEFF[2] = coeff(2); // BESSEL_APPROXCOEFF[3] = coeff(3); // BESSEL_APPROXCOEFF[0] = -0.1714; // BESSEL_APPROXCOEFF[1] = 0.5332; // BESSEL_APPROXCOEFF[2] = -1.4889; // BESSEL_APPROXCOEFF[3] = 2.0389; //} diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h index 8ef31620c9..fc143f8974 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkEnergyComputer.h @@ -1,100 +1,104 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _ENCOMP #define _ENCOMP #include #include #include #include #include using namespace mitk; class MitkDiffusionImaging_EXPORT EnergyComputer { public: - typedef itk::Vector OdfVectorType; - typedef itk::Image ItkQBallImgType; + typedef itk::Vector OdfVectorType; + typedef itk::Image ItkQBallImgType; + typedef itk::Image ItkFloatImageType; - float eigen_energy; - vnl_matrix_fixed m_RotationMatrix; - ItkQBallImgType* m_QBallImageData; - const int *m_QBallImageSize; - SphereInterpolator *m_SphereInterpolator; - ParticleGrid *m_ParticleGrid; + vnl_matrix_fixed m_RotationMatrix; + ItkQBallImgType* m_ImageData; + vnl_vector_fixed m_Size; + vnl_vector_fixed m_Spacing; + SphereInterpolator* m_SphereInterpolator; + ParticleGrid* m_ParticleGrid; + + std::vector< float > cumulspatprob; + std::vector< int > spatidx; + + float *m_MaskImageData; + int m_NumActiveVoxels; int w,h,d; float voxsize_w; float voxsize_h; float voxsize_d; int w_sp,h_sp,d_sp; float voxsize_sp_w; float voxsize_sp_h; float voxsize_sp_d; - MTRand* mtrand; + MTRand* m_RandGen; + float eigen_energy; int nip; // number of data vertices on sphere - float *m_MaskImageData; - float *cumulspatprob; - int *spatidx; - int scnt; float eigencon_energy; float chempot2; float meanval_sq; float gamma_s; float gamma_reg_s; float particle_weight; float ex_strength; float in_strength; float particle_length_sq; float curv_hard; - EnergyComputer(MTRand* rgen, ItkQBallImgType* data, const int *dsz, double *cellsize, SphereInterpolator *sp, ParticleGrid *pcon, float *spimg, int spmult, vnl_matrix_fixed rotMatrix); + EnergyComputer(MTRand* rgen, ItkQBallImgType* qballImage, SphereInterpolator *sp, ParticleGrid *pcon, float *mask, vnl_matrix_fixed rotMatrix); void setParameters(float pwei,float pwid,float chempot_connection, float length,float curv_hardthres, float inex_balance, float chempot2, float meanv); void drawSpatPosition(vnl_vector_fixed& R); float SpatProb(vnl_vector_fixed R); float evaluateODF(vnl_vector_fixed &R, vnl_vector_fixed &N, float &len); float computeExternalEnergy(vnl_vector_fixed &R, vnl_vector_fixed &N, float &len, Particle *dp); float computeInternalEnergy(Particle *dp); float computeInternalEnergyConnection(Particle *p1,int ep1); float computeInternalEnergyConnection(Particle *p1,int ep1, Particle *p2, int ep2); float mbesseli0(float x); float mexp(float x); }; #endif diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h index 75c2210434..86af2a78d7 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h @@ -1,100 +1,96 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _SAMPLER #define _SAMPLER // MITK #include #include #include #include #include // ITK #include // MISC #include namespace mitk { class MitkDiffusionImaging_EXPORT MetropolisHastingsSampler { public: typedef itk::Image< float, 3 > ItkFloatImageType; MetropolisHastingsSampler(ParticleGrid* grid, EnergyComputer* enComp, MTRand* randGen, float curvThres); void SetTemperature(float val); void MakeProposal(); int GetNumAcceptedProposals(); protected: // connection proposal related methods void ImplementTrack(Track& T); void RemoveAndSaveTrack(EndPoint P); void MakeTrackProposal(EndPoint P); void ComputeEndPointProposalDistribution(EndPoint P); // generate random vectors vnl_vector_fixed DistortVector(float sigma, vnl_vector_fixed& vec); vnl_vector_fixed GetRandomDirection(); MTRand* m_RandGen; Track m_ProposalTrack; Track m_BackupTrack; SimpSamp m_SimpSamp; float m_InTemp; float m_ExTemp; float m_Density; float m_BirthProb; float m_DeathProb; float m_ShiftProb; float m_OptShiftProb; float m_ConnectionProb; float m_Sigma; float m_Gamma; float m_Z; float m_DistanceThreshold; float m_CurvatureThreshold; float m_TractProb; float m_StopProb; float m_DelProb; float m_ParticleLength; float m_ChempotParticle; ParticleGrid* m_ParticleGrid; - const int* datasz; EnergyComputer* m_EnergyComputer; - float width; - float height; - float depth; int m_AcceptedProposals; }; } #endif diff --git a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp index ee16a181e1..be493149bb 100644 --- a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp +++ b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp @@ -1,397 +1,397 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "itkGibbsTrackingFilter.h" // MITK #include #include #include #include #include #include #include #include // ITK #include #pragma GCC visibility push(default) #include #pragma GCC visibility pop // MISC #include #include #include namespace itk{ template< class ItkQBallImageType > GibbsTrackingFilter< ItkQBallImageType >::GibbsTrackingFilter(): m_TempStart(0.1), m_TempEnd(0.001), m_NumIt(500000), m_ParticleWeight(0), m_ParticleWidth(0), m_ParticleLength(0), m_ChempotConnection(10), m_InexBalance(0), m_Chempot2(0.2), m_FiberLength(10), m_AbortTracking(false), m_NumConnections(0), m_NumParticles(0), m_NumAcceptedFibers(0), m_CurrentStep(0), m_BuildFibers(false), m_Steps(10), m_Memory(0), m_ProposalAcceptance(0), m_CurvatureHardThreshold(0.7), m_Meanval_sq(0.0), m_DuplicateImage(true) { } template< class ItkQBallImageType > GibbsTrackingFilter< ItkQBallImageType >::~GibbsTrackingFilter() { } // fill output fiber bundle datastructure template< class ItkQBallImageType > typename GibbsTrackingFilter< ItkQBallImageType >::FiberPolyDataType GibbsTrackingFilter< ItkQBallImageType >::GetFiberBundle() { if (!m_AbortTracking) { m_BuildFibers = true; while (m_BuildFibers){} } return m_FiberPolyData; } template< class ItkQBallImageType > bool GibbsTrackingFilter< ItkQBallImageType > ::EstimateParticleWeight() { MITK_INFO << "itkGibbsTrackingFilter: estimating particle weight"; typedef itk::DiffusionQballGeneralizedFaImageFilter GfaFilterType; GfaFilterType::Pointer gfaFilter = GfaFilterType::New(); gfaFilter->SetInput(m_QBallImage); gfaFilter->SetComputationMethod(GfaFilterType::GFA_STANDARD); gfaFilter->Update(); ItkFloatImageType::Pointer gfaImage = gfaFilter->GetOutput(); float samplingStart = 1.0; float samplingStop = 0.66; // GFA iterator typedef ImageRegionIterator< ItkFloatImageType > GfaIteratorType; GfaIteratorType gfaIt(gfaImage, gfaImage->GetLargestPossibleRegion() ); // Mask iterator typedef ImageRegionConstIterator< ItkFloatImageType > MaskIteratorType; MaskIteratorType mit(m_MaskImage, m_MaskImage->GetLargestPossibleRegion() ); // Input iterator typedef ImageRegionConstIterator< ItkQBallImageType > InputIteratorType; InputIteratorType it(m_QBallImage, m_QBallImage->GetLargestPossibleRegion() ); float upper = 0; int count = 0; for(float thr=samplingStart; thr>samplingStop; thr-=0.01) { it.GoToBegin(); mit.GoToBegin(); gfaIt.GoToBegin(); while( !gfaIt.IsAtEnd() ) { if(gfaIt.Get()>thr && mit.Get()>0) { itk::OrientationDistributionFunction odf(it.Get().GetDataPointer()); upper += odf.GetMaxValue()-odf.GetMeanValue(); ++count; } ++it; ++mit; ++gfaIt; } } if (count>0) upper /= count; else return false; m_ParticleWeight = upper/6; return true; } // perform global tracking template< class ItkQBallImageType > void GibbsTrackingFilter< ItkQBallImageType >::GenerateData() { if (m_QBallImage.IsNull() && m_TensorImage.IsNotNull()) { TensorImageToQBallImageFilter::Pointer filter = TensorImageToQBallImageFilter::New(); filter->SetInput( m_TensorImage ); filter->Update(); m_QBallImage = filter->GetOutput(); } // image sizes and spacing int imgSize[4] = { QBALL_ODFSIZE, m_QBallImage->GetLargestPossibleRegion().GetSize().GetElement(0), m_QBallImage->GetLargestPossibleRegion().GetSize().GetElement(1), m_QBallImage->GetLargestPossibleRegion().GetSize().GetElement(2)}; double spacing[3] = {m_QBallImage->GetSpacing().GetElement(0),m_QBallImage->GetSpacing().GetElement(1),m_QBallImage->GetSpacing().GetElement(2)}; // make sure image has enough slices if (imgSize[1]<3 || imgSize[2]<3 || imgSize[3]<3) { MITK_INFO << "itkGibbsTrackingFilter: image size < 3 not supported"; m_AbortTracking = true; } // calculate rotation matrix vnl_matrix temp = m_QBallImage->GetDirection().GetVnlMatrix(); vnl_matrix directionMatrix; directionMatrix.set_size(3,3); vnl_copy(temp, directionMatrix); vnl_vector_fixed d0 = directionMatrix.get_column(0); d0.normalize(); vnl_vector_fixed d1 = directionMatrix.get_column(1); d1.normalize(); vnl_vector_fixed d2 = directionMatrix.get_column(2); d2.normalize(); directionMatrix.set_column(0, d0); directionMatrix.set_column(1, d1); directionMatrix.set_column(2, d2); vnl_matrix_fixed I = directionMatrix*directionMatrix.transpose(); if(!I.is_identity(mitk::eps)){ MITK_INFO << "itkGibbsTrackingFilter: image direction is not a rotation matrix. Tracking not possible!"; m_AbortTracking = true; } // generate local working copy of QBall image typename ItkQBallImageType::Pointer qballImage; if (m_DuplicateImage) { typedef itk::ImageDuplicator< ItkQBallImageType > DuplicateFilterType; typename DuplicateFilterType::Pointer duplicator = DuplicateFilterType::New(); duplicator->SetInputImage( m_QBallImage ); duplicator->Update(); qballImage = duplicator->GetOutput(); } else { qballImage = m_QBallImage; } // perform mean subtraction on odfs typedef ImageRegionIterator< ItkQBallImageType > InputIteratorType; InputIteratorType it(qballImage, qballImage->GetLargestPossibleRegion() ); it.GoToBegin(); while (!it.IsAtEnd()) { itk::OrientationDistributionFunction odf(it.Get().GetDataPointer()); float mean = odf.GetMeanValue(); odf -= mean; it.Set(odf.GetDataPointer()); ++it; } // mask image int maskImageSize[3]; float *mask; if(m_MaskImage.IsNotNull()) { mask = (float*) m_MaskImage->GetBufferPointer(); maskImageSize[0] = m_MaskImage->GetLargestPossibleRegion().GetSize().GetElement(0); maskImageSize[1] = m_MaskImage->GetLargestPossibleRegion().GetSize().GetElement(1); maskImageSize[2] = m_MaskImage->GetLargestPossibleRegion().GetSize().GetElement(2); } else { mask = 0; maskImageSize[0] = imgSize[1]; maskImageSize[1] = imgSize[2]; maskImageSize[2] = imgSize[3]; } int mask_oversamp_mult = maskImageSize[0]/imgSize[1]; // get paramters float minSpacing; if(spacing[0]m_NumIt) { MITK_INFO << "itkGibbsTrackingFilter: not enough iterations!"; m_AbortTracking = true; } if (m_CurvatureHardThreshold < mitk::eps) m_CurvatureHardThreshold = 0; unsigned long singleIts = (unsigned long)((1.0*m_NumIt) / (1.0*m_Steps)); MTRand randGen(1); srand(1); SphereInterpolator* interpolator = LoadSphereInterpolator(); MITK_INFO << "itkGibbsTrackingFilter: setting up MH-sampler"; ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, m_ParticleLength); - EnergyComputer encomp(&randGen, qballImage, imgSize,spacing,interpolator,particleGrid,mask,mask_oversamp_mult, directionMatrix); + EnergyComputer encomp(&randGen, qballImage, interpolator, particleGrid, mask, directionMatrix); encomp.setParameters(m_ParticleWeight,m_ParticleWidth,m_ChempotConnection*m_ParticleLength*m_ParticleLength,m_ParticleLength,m_CurvatureHardThreshold,m_InexBalance,m_Chempot2, m_Meanval_sq); MetropolisHastingsSampler* sampler = new MetropolisHastingsSampler(particleGrid, &encomp, &randGen, m_CurvatureHardThreshold); // main loop m_NumAcceptedFibers = 0; for( int m_CurrentStep = 1; m_CurrentStep <= m_Steps; m_CurrentStep++ ) { m_ProposalAcceptance = (float)sampler->GetNumAcceptedProposals()/m_NumIt; m_NumParticles = particleGrid->m_NumParticles; m_NumConnections = particleGrid->m_NumConnections; MITK_INFO << "itkGibbsTrackingFilter: proposal acceptance: " << 100*m_ProposalAcceptance << "%"; MITK_INFO << "itkGibbsTrackingFilter: particles: " << m_NumParticles; MITK_INFO << "itkGibbsTrackingFilter: connections: " << m_NumConnections; MITK_INFO << "itkGibbsTrackingFilter: progress: " << 100*(float)m_CurrentStep/m_Steps << "%"; float temperature = m_TempStart * exp(alpha*(((1.0)*m_CurrentStep)/((1.0)*m_Steps))); sampler->SetTemperature(temperature); for (unsigned long i=0; iMakeProposal(); if (m_BuildFibers) { m_ProposalAcceptance = (float)sampler->GetNumAcceptedProposals()/m_NumIt; m_NumParticles = particleGrid->m_NumParticles; m_NumConnections = particleGrid->m_NumConnections; FiberBuilder fiberBuilder(particleGrid, m_MaskImage); m_FiberPolyData = fiberBuilder.iterate(m_FiberLength); m_NumAcceptedFibers = m_FiberPolyData->GetNumberOfLines(); m_BuildFibers = false; } } } FiberBuilder fiberBuilder(particleGrid, m_MaskImage); m_FiberPolyData = fiberBuilder.iterate(m_FiberLength); m_NumAcceptedFibers = m_FiberPolyData->GetNumberOfLines(); delete interpolator; delete sampler; delete particleGrid; m_AbortTracking = true; m_BuildFibers = false; MITK_INFO << "itkGibbsTrackingFilter: done generate data"; } template< class ItkQBallImageType > SphereInterpolator* GibbsTrackingFilter< ItkQBallImageType >::LoadSphereInterpolator() { QString applicationDir = QCoreApplication::applicationDirPath(); applicationDir.append("/"); mitk::StandardFileLocations::GetInstance()->AddDirectoryForSearch( applicationDir.toStdString().c_str(), false ); applicationDir.append("../"); mitk::StandardFileLocations::GetInstance()->AddDirectoryForSearch( applicationDir.toStdString().c_str(), false ); applicationDir.append("../../"); mitk::StandardFileLocations::GetInstance()->AddDirectoryForSearch( applicationDir.toStdString().c_str(), false ); std::string lutPath = mitk::StandardFileLocations::GetInstance()->FindFile("FiberTrackingLUTBaryCoords.bin"); ifstream BaryCoords; BaryCoords.open(lutPath.c_str(), ios::in | ios::binary); float* coords; if (BaryCoords.is_open()) { float tmp; coords = new float [1630818]; BaryCoords.seekg (0, ios::beg); for (int i=0; i<1630818; i++) { BaryCoords.read((char *)&tmp, sizeof(tmp)); coords[i] = tmp; } BaryCoords.close(); } else { MITK_INFO << "itkGibbsTrackingFilter: unable to open barycoords file"; m_AbortTracking = true; } ifstream Indices; lutPath = mitk::StandardFileLocations::GetInstance()->FindFile("FiberTrackingLUTIndices.bin"); Indices.open(lutPath.c_str(), ios::in | ios::binary); int* ind; if (Indices.is_open()) { int tmp; ind = new int [1630818]; Indices.seekg (0, ios::beg); for (int i=0; i<1630818; i++) { Indices.read((char *)&tmp, 4); ind[i] = tmp; } Indices.close(); } else { MITK_INFO << "itkGibbsTrackingFilter: unable to open indices file"; m_AbortTracking = true; } // initialize sphere interpolator with lookuptables return new SphereInterpolator(coords, ind, QBALL_ODFSIZE, 301, 0.5); } }