diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp index dbd159e376..459633b559 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.cpp @@ -1,580 +1,453 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkMetropolisHastingsSampler.h" using namespace mitk; -MetropolisHastingsSampler::MetropolisHastingsSampler(ParticleGrid* grid, MTRand* randGen) +MetropolisHastingsSampler::MetropolisHastingsSampler(ParticleGrid* grid, EnergyComputer* enComp, MTRand* randGen, float curvThres) : m_AcceptedProposals(0) + , m_ExTemp(0.01) + , m_BirthProb(0.25) + , m_DeathProb(0.05) + , m_ShiftProb(0.15) + , m_OptShiftProb(0.1) + , m_ConnectionProb(0.45) + , m_TractProb(0.5) + , m_DelProb(0.1) + , m_ChempotParticle(0.0) { - mtrand = randGen; + m_RandGen = randGen; m_ParticleGrid = grid; - m_Iterations = 0; - m_AcceptedProposals = 0; - externalEnergy = 0; - internalEnergy = 0; -} - - -void MetropolisHastingsSampler::SetEnergyComputer(EnergyComputer *e) -{ - enc = e; -} - -void MetropolisHastingsSampler::Iterate(float* acceptance, unsigned long* numCon, unsigned long* numPart, bool *abort) -{ - m_AcceptedProposals = 0; - for (int it = 0; it < m_Iterations;it++) - { - if (*abort) - break; + m_EnergyComputer = enComp; - IterateOneStep(); + m_ParticleLength = m_ParticleGrid->m_ParticleLength; + m_DistanceThreshold = m_ParticleLength*m_ParticleLength; + m_Sigma = m_ParticleLength/8.0; + m_Gamma = 1/(m_Sigma*m_Sigma*2); + m_Z = pow(2*M_PI*m_Sigma,3.0/2.0)*(M_PI*m_Sigma/m_ParticleLength); - *numCon = m_ParticleGrid->m_NumConnections; - *numPart = m_ParticleGrid->m_NumParticles; - } - *acceptance = (float)m_AcceptedProposals/m_Iterations; -} + m_CurvatureThreshold = curvThres; -void MetropolisHastingsSampler::SetParameters(float Temp, int numit, float plen, float curv_hardthres, float chempot_particle) -{ - m_Iterations = numit; - - m_BirthProb = 0.25; - m_DeathProb = 0.05; - m_ShiftProb = 0.15; - m_OptShiftProb = 0.1; - m_ConnectionProb = 0.45; - - m_ChempotParticle = chempot_particle; - - float sum = m_BirthProb+m_DeathProb+m_ShiftProb+m_OptShiftProb+m_ConnectionProb; - m_BirthProb /= sum; m_DeathProb /= sum; m_ShiftProb /= sum; m_OptShiftProb /= sum; - - m_InTemp = Temp; - m_ExTemp = 0.01; - m_Density = exp(-chempot_particle/m_InTemp); - - len_def = plen; - len_sig = 0.0; - cap_def = 1.0; - cap_sig = 0.0; - - // shift proposal - sigma_g = len_def/8.0; - gamma_g = 1/(sigma_g*sigma_g*2); - Z_g = pow(2*M_PI*sigma_g,3.0/2.0)*(M_PI*sigma_g/len_def); - - // conn proposal - dthres = len_def; - nthres = curv_hardthres; - T_prop = 0.5; - dthres *= dthres; - stopprobability = exp(-1/T_prop); - del_prob = 0.1; + // ??? + m_StopProb = exp(-1/m_TractProb); } -void MetropolisHastingsSampler::SetTemperature(float temp) +void MetropolisHastingsSampler::SetTemperature(float val) { - m_InTemp = temp; + m_InTemp = val; m_Density = exp(-m_ChempotParticle/m_InTemp); } -vnl_vector_fixed MetropolisHastingsSampler::distortn(float sigma, vnl_vector_fixed& vec) +vnl_vector_fixed MetropolisHastingsSampler::DistortVector(float sigma, vnl_vector_fixed& vec) { - vec[0] += sigma*mtrand->frandn(); - vec[1] += sigma*mtrand->frandn(); - vec[2] += sigma*mtrand->frandn(); + vec[0] += sigma*m_RandGen->frandn(); + vec[1] += sigma*m_RandGen->frandn(); + vec[2] += sigma*m_RandGen->frandn(); } -vnl_vector_fixed MetropolisHastingsSampler::rand_sphere() +vnl_vector_fixed MetropolisHastingsSampler::GetRandomDirection() { vnl_vector_fixed vec; - vec[0] += mtrand->frandn(); - vec[1] += mtrand->frandn(); - vec[2] += mtrand->frandn(); + vec[0] += m_RandGen->frandn(); + vec[1] += m_RandGen->frandn(); + vec[2] += m_RandGen->frandn(); vec.normalize(); return vec; } -void MetropolisHastingsSampler::IterateOneStep() +void MetropolisHastingsSampler::MakeProposal() { - float randnum = mtrand->frand(); - //randnum = 0; + float randnum = m_RandGen->frand(); - /////////////////////////////////////////////////////////////// - //////// Birth Proposal - /////////////////////////////////////////////////////////////// + // Birth Proposal if (randnum < m_BirthProb) { - -#ifdef TIMING - tic(&birthproposal_time); - birthstats.propose(); -#endif - vnl_vector_fixed R; - enc->drawSpatPosition(R); - - //fprintf(stderr,"drawn: %f, %f, %f\n",R[0],R[1],R[2]); - //R.setXYZ(20.5*3, 35.5*3, 1.5*3); - - vnl_vector_fixed N = rand_sphere(); - //N.setXYZ(1,0,0); - float len = len_def;// + len_sig*(mtrand->frand()-0.5); + m_EnergyComputer->drawSpatPosition(R); + vnl_vector_fixed N = GetRandomDirection(); + float len = m_ParticleLength; Particle prop; prop.R = R; prop.N = N; prop.len = len; - float prob = m_Density * m_DeathProb /((m_BirthProb)*(m_ParticleGrid->m_NumParticles+1)); - float ex_energy = enc->computeExternalEnergy(R,N,len,0); - float in_energy = enc->computeInternalEnergy(&prop); - + float ex_energy = m_EnergyComputer->computeExternalEnergy(R,N,len,0); + float in_energy = m_EnergyComputer->computeInternalEnergy(&prop); prob *= exp((in_energy/m_InTemp+ex_energy/m_ExTemp)) ; - if (prob > 1 || mtrand->frand() < prob) + if (prob > 1 || m_RandGen->frand() < prob) { Particle *p = m_ParticleGrid->NewParticle(R); if (p!=0) { p->R = R; p->N = N; p->len = len; -#ifdef TIMING - birthstats.accepted(); -#endif m_AcceptedProposals++; } } - -#ifdef TIMING - toc(&birthproposal_time); -#endif } - /////////////////////////////////////////////////////////////// - //////// Death Proposal - /////////////////////////////////////////////////////////////// + // Death Proposal else if (randnum < m_BirthProb+m_DeathProb) { if (m_ParticleGrid->m_NumParticles > 0) { -#ifdef TIMING - tic(&deathproposal_time); - deathstats.propose(); -#endif - int pnum = rand()%m_ParticleGrid->m_NumParticles; Particle *dp = m_ParticleGrid->GetParticle(pnum); if (dp->pID == -1 && dp->mID == -1) { - - float ex_energy = enc->computeExternalEnergy(dp->R,dp->N,dp->len,dp); - float in_energy = enc->computeInternalEnergy(dp); + float ex_energy = m_EnergyComputer->computeExternalEnergy(dp->R,dp->N,dp->len,dp); + float in_energy = m_EnergyComputer->computeInternalEnergy(dp); float prob = m_ParticleGrid->m_NumParticles * (m_BirthProb) /(m_Density*m_DeathProb); //*SpatProb(dp->R); prob *= exp(-(in_energy/m_InTemp+ex_energy/m_ExTemp)) ; - if (prob > 1 || mtrand->frand() < prob) + if (prob > 1 || m_RandGen->frand() < prob) { m_ParticleGrid->RemoveParticle(pnum); -#ifdef TIMING - deathstats.accepted(); -#endif m_AcceptedProposals++; } } -#ifdef TIMING - toc(&deathproposal_time); -#endif } } - /////////////////////////////////////////////////////////////// - //////// Shift Proposal - /////////////////////////////////////////////////////////////// + // Shift Proposal else if (randnum < m_BirthProb+m_DeathProb+m_ShiftProb) { - float energy = 0; if (m_ParticleGrid->m_NumParticles > 0) { -#ifdef TIMING - tic(&shiftproposal_time); - shiftstats.propose(); -#endif - int pnum = rand()%m_ParticleGrid->m_NumParticles; Particle *p = m_ParticleGrid->GetParticle(pnum); Particle prop_p = *p; - distortn(sigma_g, prop_p.R); - distortn(sigma_g/(2*p->len), prop_p.N); + DistortVector(m_Sigma, prop_p.R); + DistortVector(m_Sigma/(2*p->len), prop_p.N); prop_p.N.normalize(); - float ex_energy = enc->computeExternalEnergy(prop_p.R,prop_p.N,p->len,p) - - enc->computeExternalEnergy(p->R,p->N,p->len,p); - float in_energy = enc->computeInternalEnergy(&prop_p) - enc->computeInternalEnergy(p); + float ex_energy = m_EnergyComputer->computeExternalEnergy(prop_p.R,prop_p.N,p->len,p) + - m_EnergyComputer->computeExternalEnergy(p->R,p->N,p->len,p); + float in_energy = m_EnergyComputer->computeInternalEnergy(&prop_p) - m_EnergyComputer->computeInternalEnergy(p); float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp); - // * SpatProb(p->R) / SpatProb(prop_p.R); - if (mtrand->frand() < prob) + if (m_RandGen->frand() < prob) { vnl_vector_fixed Rtmp = p->R; vnl_vector_fixed Ntmp = p->N; p->R = prop_p.R; p->N = prop_p.N; if (!m_ParticleGrid->TryUpdateGrid(pnum)) { p->R = Rtmp; p->N = Ntmp; } -#ifdef TIMING - shiftstats.accepted(); -#endif m_AcceptedProposals++; } - -#ifdef TIMING - toc(&shiftproposal_time); -#endif - } - } + // Optimal Shift Proposal else if (randnum < m_BirthProb+m_DeathProb+m_ShiftProb+m_OptShiftProb) { - float energy = 0; if (m_ParticleGrid->m_NumParticles > 0) { int pnum = rand()%m_ParticleGrid->m_NumParticles; Particle *p = m_ParticleGrid->GetParticle(pnum); bool no_proposal = false; Particle prop_p = *p; if (p->pID != -1 && p->mID != -1) { Particle *plus = m_ParticleGrid->GetParticle(p->pID); int ep_plus = (plus->pID == p->ID)? 1 : -1; Particle *minus = m_ParticleGrid->GetParticle(p->mID); int ep_minus = (minus->pID == p->ID)? 1 : -1; prop_p.R = (plus->R + plus->N * (plus->len * ep_plus) + minus->R + minus->N * (minus->len * ep_minus)); prop_p.R *= 0.5; prop_p.N = plus->R - minus->R; prop_p.N.normalize(); } else if (p->pID != -1) { Particle *plus = m_ParticleGrid->GetParticle(p->pID); int ep_plus = (plus->pID == p->ID)? 1 : -1; prop_p.R = plus->R + plus->N * (plus->len * ep_plus * 2); prop_p.N = plus->N; } else if (p->mID != -1) { Particle *minus = m_ParticleGrid->GetParticle(p->mID); int ep_minus = (minus->pID == p->ID)? 1 : -1; prop_p.R = minus->R + minus->N * (minus->len * ep_minus * 2); prop_p.N = minus->N; } else no_proposal = true; if (!no_proposal) { float cos = dot_product(prop_p.N, p->N); - float p_rev = exp(-((prop_p.R-p->R).squared_magnitude() + (1-cos*cos))*gamma_g)/Z_g; + float p_rev = exp(-((prop_p.R-p->R).squared_magnitude() + (1-cos*cos))*m_Gamma)/m_Z; - float ex_energy = enc->computeExternalEnergy(prop_p.R,prop_p.N,p->len,p) - - enc->computeExternalEnergy(p->R,p->N,p->len,p); - float in_energy = enc->computeInternalEnergy(&prop_p) - enc->computeInternalEnergy(p); + float ex_energy = m_EnergyComputer->computeExternalEnergy(prop_p.R,prop_p.N,p->len,p) + - m_EnergyComputer->computeExternalEnergy(p->R,p->N,p->len,p); + float in_energy = m_EnergyComputer->computeInternalEnergy(&prop_p) - m_EnergyComputer->computeInternalEnergy(p); float prob = exp(ex_energy/m_ExTemp+in_energy/m_InTemp)*m_ShiftProb*p_rev/(m_OptShiftProb+m_ShiftProb*p_rev); - //* SpatProb(p->R) / SpatProb(prop_p.R); - if (mtrand->frand() < prob) + if (m_RandGen->frand() < prob) { vnl_vector_fixed Rtmp = p->R; vnl_vector_fixed Ntmp = p->N; p->R = prop_p.R; p->N = prop_p.N; if (!m_ParticleGrid->TryUpdateGrid(pnum)) { p->R = Rtmp; p->N = Ntmp; } m_AcceptedProposals++; } } } - } else { if (m_ParticleGrid->m_NumParticles > 0) { - -#ifdef TIMING - tic(&connproposal_time); - connstats.propose(); -#endif - int pnum = rand()%m_ParticleGrid->m_NumParticles; Particle *p = m_ParticleGrid->GetParticle(pnum); EndPoint P; P.p = p; - P.ep = (mtrand->frand() > 0.5)? 1 : -1; + P.ep = (m_RandGen->frand() > 0.5)? 1 : -1; RemoveAndSaveTrack(P); - if (TrackBackup.m_Probability != 0) + if (m_BackupTrack.m_Probability != 0) { MakeTrackProposal(P); - float prob = (TrackProposal.m_Energy-TrackBackup.m_Energy)/m_InTemp ; + float prob = (m_ProposalTrack.m_Energy-m_BackupTrack.m_Energy)/m_InTemp ; - // prob = exp(prob)*(TrackBackup.proposal_probability) - // /(TrackProposal.proposal_probability); - prob = exp(prob)*(TrackBackup.m_Probability * pow(del_prob,TrackProposal.m_Length)) - /(TrackProposal.m_Probability * pow(del_prob,TrackBackup.m_Length)); - if (mtrand->frand() < prob) + prob = exp(prob)*(m_BackupTrack.m_Probability * pow(m_DelProb,m_ProposalTrack.m_Length)) + /(m_ProposalTrack.m_Probability * pow(m_DelProb,m_BackupTrack.m_Length)); + if (m_RandGen->frand() < prob) { - ImplementTrack(TrackProposal); -#ifdef TIMING - connstats.accepted(); -#endif + ImplementTrack(m_ProposalTrack); m_AcceptedProposals++; } else { - ImplementTrack(TrackBackup); + ImplementTrack(m_BackupTrack); } } else - ImplementTrack(TrackBackup); - -#ifdef TIMING - toc(&connproposal_time); -#endif + ImplementTrack(m_BackupTrack); } } } void MetropolisHastingsSampler::ImplementTrack(Track &T) { for (int k = 1; k < T.m_Length;k++) - { m_ParticleGrid->CreateConnection(T.track[k-1].p,T.track[k-1].ep,T.track[k].p,-T.track[k].ep); - } } void MetropolisHastingsSampler::RemoveAndSaveTrack(EndPoint P) { EndPoint Current = P; - int cnt = 0; float energy = 0; float AccumProb = 1.0; - TrackBackup.track[cnt] = Current; - + m_BackupTrack.track[cnt] = Current; EndPoint Next; - - for (;;) { Next.p = 0; if (Current.ep == 1) { if (Current.p->pID != -1) { Next.p = m_ParticleGrid->GetParticle(Current.p->pID); Current.p->pID = -1; m_ParticleGrid->m_NumConnections--; } } else if (Current.ep == -1) { if (Current.p->mID != -1) { Next.p = m_ParticleGrid->GetParticle(Current.p->mID); Current.p->mID = -1; m_ParticleGrid->m_NumConnections--; } } else { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 3\n"); break; } if (Next.p == 0) // no successor { Next.ep = 0; // mark as empty successor break; } else { if (Next.p->pID == Current.p->ID) { Next.p->pID = -1; Next.ep = 1; } else if (Next.p->mID == Current.p->ID) { Next.p->mID = -1; Next.ep = -1; } else { fprintf(stderr,"MetropolisHastingsSampler_randshift: Connection inconsistent 4\n"); break; } } - ComputeEndPointProposalDistribution(Current); - - AccumProb *= (simpsamp.probFor(Next)); + AccumProb *= (m_SimpSamp.probFor(Next)); if (Next.p == 0) // no successor -> break break; - energy += enc->computeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep); + energy += m_EnergyComputer->computeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep); Current = Next; Current.ep *= -1; cnt++; - TrackBackup.track[cnt] = Current; - + m_BackupTrack.track[cnt] = Current; - if (mtrand->rand() > del_prob) - { + if (m_RandGen->rand() > m_DelProb) break; - } - } - TrackBackup.m_Energy = energy; - TrackBackup.m_Probability = AccumProb; - TrackBackup.m_Length = cnt+1; + m_BackupTrack.m_Energy = energy; + m_BackupTrack.m_Probability = AccumProb; + m_BackupTrack.m_Length = cnt+1; } void MetropolisHastingsSampler::MakeTrackProposal(EndPoint P) { EndPoint Current = P; int cnt = 0; float energy = 0; float AccumProb = 1.0; - TrackProposal.track[cnt++] = Current; + m_ProposalTrack.track[cnt++] = Current; Current.p->label = 1; for (;;) { - // next candidate is already connected if ((Current.ep == 1 && Current.p->pID != -1) || (Current.ep == -1 && Current.p->mID != -1)) break; // track too long - if (cnt > 250) - break; +// if (cnt > 250) +// break; ComputeEndPointProposalDistribution(Current); - // // no candidates anymore - // if (simpsamp.isempty()) - // break; - - int k = simpsamp.draw(mtrand->frand()); + int k = m_SimpSamp.draw(m_RandGen->frand()); // stop tracking proposed if (k==0) break; - EndPoint Next = simpsamp.objs[k]; - float probability = simpsamp.probFor(k); + EndPoint Next = m_SimpSamp.objs[k]; + float probability = m_SimpSamp.probFor(k); // accumulate energy and proposal distribution - energy += enc->computeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep); + energy += m_EnergyComputer->computeInternalEnergyConnection(Current.p,Current.ep,Next.p,Next.ep); AccumProb *= probability; // track to next endpoint Current = Next; Current.ep *= -1; Current.p->label = 1; // put label to avoid loops - TrackProposal.track[cnt++] = Current; + m_ProposalTrack.track[cnt++] = Current; } - TrackProposal.m_Energy = energy; - TrackProposal.m_Probability = AccumProb; - TrackProposal.m_Length = cnt; + m_ProposalTrack.m_Energy = energy; + m_ProposalTrack.m_Probability = AccumProb; + m_ProposalTrack.m_Length = cnt; // clear labels - for (int j = 0; j < TrackProposal.m_Length;j++) - TrackProposal.track[j].p->label = 0; + for (int j = 0; j < m_ProposalTrack.m_Length;j++) + m_ProposalTrack.track[j].p->label = 0; } void MetropolisHastingsSampler::ComputeEndPointProposalDistribution(EndPoint P) { Particle *p = P.p; int ep = P.ep; float dist,dot; vnl_vector_fixed R = p->R + (p->N * (ep*p->len) ); m_ParticleGrid->ComputeNeighbors(R); - simpsamp.clear(); + m_SimpSamp.clear(); - simpsamp.add(stopprobability,EndPoint(0,0)); + m_SimpSamp.add(m_StopProb,EndPoint(0,0)); for (;;) { Particle *p2 = m_ParticleGrid->GetNextNeighbor(); if (p2 == 0) break; if (p!=p2 && p2->label == 0) { if (p2->mID == -1) { dist = (p2->R - p2->N * p2->len - R).squared_magnitude(); - if (dist < dthres) + if (dist < m_DistanceThreshold) { dot = dot_product(p2->N,p->N) * ep; - if (dot > nthres) + if (dot > m_CurvatureThreshold) { - float en = enc->computeInternalEnergyConnection(p,ep,p2,-1); - simpsamp.add(exp(en/T_prop),EndPoint(p2,-1)); + float en = m_EnergyComputer->computeInternalEnergyConnection(p,ep,p2,-1); + m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,-1)); } } } if (p2->pID == -1) { dist = (p2->R + p2->N * p2->len - R).squared_magnitude(); - if (dist < dthres) + if (dist < m_DistanceThreshold) { dot = dot_product(p2->N,p->N) * (-ep); - if (dot > nthres) + if (dot > m_CurvatureThreshold) { - float en = enc->computeInternalEnergyConnection(p,ep,p2,+1); - simpsamp.add(exp(en/T_prop),EndPoint(p2,+1)); + float en = m_EnergyComputer->computeInternalEnergyConnection(p,ep,p2,+1); + m_SimpSamp.add(exp(en/m_TractProb),EndPoint(p2,+1)); } } } } } } +int MetropolisHastingsSampler::GetNumAcceptedProposals() +{ + return m_AcceptedProposals; +} + diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h index 409f15ca5a..75c2210434 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkMetropolisHastingsSampler.h @@ -1,108 +1,100 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _SAMPLER #define _SAMPLER // MITK #include #include #include #include #include // ITK #include // MISC #include namespace mitk { class MitkDiffusionImaging_EXPORT MetropolisHastingsSampler { public: typedef itk::Image< float, 3 > ItkFloatImageType; - ParticleGrid* m_ParticleGrid; - const int* datasz; - EnergyComputer* enc; - int m_Iterations; - float width; - float height; - float depth; - int m_AcceptedProposals; + MetropolisHastingsSampler(ParticleGrid* grid, EnergyComputer* enComp, MTRand* randGen, float curvThres); + void SetTemperature(float val); + + void MakeProposal(); + int GetNumAcceptedProposals(); + +protected: + + // connection proposal related methods + void ImplementTrack(Track& T); + void RemoveAndSaveTrack(EndPoint P); + void MakeTrackProposal(EndPoint P); + void ComputeEndPointProposalDistribution(EndPoint P); + + // generate random vectors + vnl_vector_fixed DistortVector(float sigma, vnl_vector_fixed& vec); + vnl_vector_fixed GetRandomDirection(); + + MTRand* m_RandGen; + Track m_ProposalTrack; + Track m_BackupTrack; + SimpSamp m_SimpSamp; float m_InTemp; float m_ExTemp; float m_Density; float m_BirthProb; float m_DeathProb; float m_ShiftProb; float m_OptShiftProb; float m_ConnectionProb; - float sigma_g; - float gamma_g; - float Z_g; - - float dthres; - float nthres; - float T_prop; - float stopprobability; - float del_prob; - - float len_def; - float len_sig; + float m_Sigma; + float m_Gamma; + float m_Z; - float cap_def; - float cap_sig; - - float externalEnergy; - float internalEnergy; + float m_DistanceThreshold; + float m_CurvatureThreshold; + float m_TractProb; + float m_StopProb; + float m_DelProb; + float m_ParticleLength; float m_ChempotParticle; - MTRand* mtrand; - Track TrackProposal, TrackBackup; - SimpSamp simpsamp; - - MetropolisHastingsSampler(ParticleGrid* grid, MTRand* randGen); - - void SetEnergyComputer(EnergyComputer *e); - void SetParameters(float Temp, int numit, float plen, float curv_hardthres, float chempot_particle); - void SetTemperature(float temp); - - void Iterate(float* acceptance, unsigned long* numCon, unsigned long* numPart, bool *abort); - void IterateOneStep(); - - void ImplementTrack(Track& T); - void RemoveAndSaveTrack(EndPoint P); - void MakeTrackProposal(EndPoint P); - void ComputeEndPointProposalDistribution(EndPoint P); - - vnl_vector_fixed distortn(float sigma, vnl_vector_fixed& vec); - vnl_vector_fixed rand_sphere(); -// vnl_vector_fixed rand(float w, float h, float d); + ParticleGrid* m_ParticleGrid; + const int* datasz; + EnergyComputer* m_EnergyComputer; + float width; + float height; + float depth; + int m_AcceptedProposals; }; } #endif diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp index ad1514abaf..ebd714fcc0 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp @@ -1,385 +1,387 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkParticleGrid.h" #include #include using namespace mitk; -ParticleGrid::ParticleGrid(ItkFloatImageType* image, float cellSize) +ParticleGrid::ParticleGrid(ItkFloatImageType* image, float particleLength) { // initialize counters m_NumParticles = 0; m_NumConnections = 0; m_NumCellOverflows = 0; + m_ParticleLength = particleLength; - // define isotropic grid from voxel spacing and particle length (cellSize) + // define isotropic grid from voxel spacing and particle length + float cellSize = 2*m_ParticleLength; m_GridSize[0] = image->GetLargestPossibleRegion().GetSize()[0]*image->GetSpacing()[0]/cellSize +1; m_GridSize[1] = image->GetLargestPossibleRegion().GetSize()[1]*image->GetSpacing()[1]/cellSize +1; m_GridSize[2] = image->GetLargestPossibleRegion().GetSize()[2]*image->GetSpacing()[2]/cellSize +1; m_GridScale[0] = 1/cellSize; m_GridScale[1] = 1/cellSize; m_GridScale[2] = 1/cellSize; m_CellCapacity = 1024; // maximum number of particles per grid cell m_ContainerCapacity = 100000; // initial particle container capacity int numCells = m_GridSize[0]*m_GridSize[1]*m_GridSize[2]; // number of grid cells m_Particles.resize(m_ContainerCapacity); // allocate and initialize particles m_Grid.resize(numCells*m_CellCapacity, NULL); // allocate and initialize particle grid m_OccupationCount.resize(numCells, 0); // allocate and initialize occupation counter array m_NeighbourTracker.cellidx.resize(8, 0); // allocate and initialize neighbour tracker m_NeighbourTracker.cellidx_c.resize(8, 0); for (int i = 0;i < m_ContainerCapacity;i++) // initialize particle IDs m_Particles[i].ID = i; std::cout << "ParticleGrid: allocated " << (sizeof(Particle)*m_ContainerCapacity + sizeof(Particle*)*m_GridSize[0]*m_GridSize[1]*m_GridSize[2])/1048576 << "mb for " << m_ContainerCapacity/1000 << "k particles." << std::endl; } ParticleGrid::~ParticleGrid() { } bool ParticleGrid::ReallocateGrid() { std::cout << "ParticleGrid: reallocating ..." << std::endl; int new_capacity = m_ContainerCapacity + 100000; // increase container capacity by 100k particles try { m_Particles.resize(new_capacity); // reallocate particles for (int i = 0; i R) { if (m_NumParticles >= m_ContainerCapacity) { if (!ReallocateGrid()) return NULL; } int xint = int(R[0]*m_GridScale[0]); if (xint < 0) return NULL; if (xint >= m_GridSize[0]) return NULL; int yint = int(R[1]*m_GridScale[1]); if (yint < 0) return NULL; if (yint >= m_GridSize[1]) return NULL; int zint = int(R[2]*m_GridScale[2]); if (zint < 0) return NULL; if (zint >= m_GridSize[2]) return NULL; int idx = xint + m_GridSize[0]*(yint + m_GridSize[1]*zint); if (m_OccupationCount[idx] < m_CellCapacity) { Particle *p = &(m_Particles[m_NumParticles]); p->R = R; p->mID = -1; p->pID = -1; m_NumParticles++; p->gridindex = m_CellCapacity*idx + m_OccupationCount[idx]; m_Grid[p->gridindex] = p; m_OccupationCount[idx]++; return p; } else { m_NumCellOverflows++; return NULL; } } bool ParticleGrid::TryUpdateGrid(int k) { Particle* p = &(m_Particles[k]); int xint = int(p->R[0]*m_GridScale[0]); if (xint < 0) return false; if (xint >= m_GridSize[0]) return false; int yint = int(p->R[1]*m_GridScale[1]); if (yint < 0) return false; if (yint >= m_GridSize[1]) return false; int zint = int(p->R[2]*m_GridScale[2]); if (zint < 0) return false; if (zint >= m_GridSize[2]) return false; int idx = xint + m_GridSize[0]*(yint+ zint*m_GridSize[1]); int cellidx = p->gridindex/m_CellCapacity; if (idx != cellidx) // cell has changed { if (m_OccupationCount[idx] < m_CellCapacity) { // remove from old position in grid; int grdindex = p->gridindex; m_Grid[grdindex] = m_Grid[cellidx*m_CellCapacity + m_OccupationCount[cellidx]-1]; m_Grid[grdindex]->gridindex = grdindex; m_OccupationCount[cellidx]--; // insert at new position in grid p->gridindex = idx*m_CellCapacity + m_OccupationCount[idx]; m_Grid[p->gridindex] = p; m_OccupationCount[idx]++; return true; } else { m_NumCellOverflows++; return false; } } return true; } void ParticleGrid::RemoveParticle(int k) { Particle* p = &(m_Particles[k]); int gridIndex = p->gridindex; int cellIdx = gridIndex/m_CellCapacity; int idx = gridIndex%m_CellCapacity; // remove pending connections if (p->mID != -1) DestroyConnection(p,-1); if (p->pID != -1) DestroyConnection(p,+1); // remove from grid if (idx < m_OccupationCount[cellIdx]-1) { m_Grid[gridIndex] = m_Grid[cellIdx*m_CellCapacity+m_OccupationCount[cellIdx]-1]; m_Grid[cellIdx*m_CellCapacity+m_OccupationCount[cellIdx]-1] = NULL; m_Grid[gridIndex]->gridindex = gridIndex; } m_OccupationCount[cellIdx]--; // remove from container if (k < m_NumParticles-1) { Particle* last = &m_Particles[m_NumParticles-1]; // last particle // update connections of last particle because its index is changing if (last->mID!=-1) { if ( m_Particles[last->mID].mID == m_NumParticles-1 ) m_Particles[last->mID].mID = k; else if ( m_Particles[last->mID].pID == m_NumParticles-1 ) m_Particles[last->mID].pID = k; } if (last->pID!=-1) { if ( m_Particles[last->pID].mID == m_NumParticles-1 ) m_Particles[last->pID].mID = k; else if ( m_Particles[last->pID].pID == m_NumParticles-1 ) m_Particles[last->pID].pID = k; } m_Particles[k] = m_Particles[m_NumParticles-1]; // move very last particle to empty slot m_Particles[m_NumParticles-1].ID = m_NumParticles-1; // update ID of removed particle to match the index m_Particles[k].ID = k; // update ID of moved particle m_Grid[m_Particles[k].gridindex] = &m_Particles[k]; // update address of moved particle } m_NumParticles--; } void ParticleGrid::ComputeNeighbors(vnl_vector_fixed &R) { float xfrac = R[0]*m_GridScale[0]; float yfrac = R[1]*m_GridScale[1]; float zfrac = R[2]*m_GridScale[2]; int xint = int(xfrac); int yint = int(yfrac); int zint = int(zfrac); int dx = -1; if (xfrac-xint > 0.5) dx = 1; if (xint <= 0) { xint = 0; dx = 1; } if (xint >= m_GridSize[0]-1) { xint = m_GridSize[0]-1; dx = -1; } int dy = -1; if (yfrac-yint > 0.5) dy = 1; if (yint <= 0) {yint = 0; dy = 1; } if (yint >= m_GridSize[1]-1) {yint = m_GridSize[1]-1; dy = -1;} int dz = -1; if (zfrac-zint > 0.5) dz = 1; if (zint <= 0) {zint = 0; dz = 1; } if (zint >= m_GridSize[2]-1) {zint = m_GridSize[2]-1; dz = -1;} m_NeighbourTracker.cellidx[0] = xint + m_GridSize[0]*(yint+zint*m_GridSize[1]); m_NeighbourTracker.cellidx[1] = m_NeighbourTracker.cellidx[0] + dx; m_NeighbourTracker.cellidx[2] = m_NeighbourTracker.cellidx[1] + dy*m_GridSize[0]; m_NeighbourTracker.cellidx[3] = m_NeighbourTracker.cellidx[2] - dx; m_NeighbourTracker.cellidx[4] = m_NeighbourTracker.cellidx[0] + dz*m_GridSize[0]*m_GridSize[1]; m_NeighbourTracker.cellidx[5] = m_NeighbourTracker.cellidx[4] + dx; m_NeighbourTracker.cellidx[6] = m_NeighbourTracker.cellidx[5] + dy*m_GridSize[0]; m_NeighbourTracker.cellidx[7] = m_NeighbourTracker.cellidx[6] - dx; m_NeighbourTracker.cellidx_c[0] = m_CellCapacity*m_NeighbourTracker.cellidx[0]; m_NeighbourTracker.cellidx_c[1] = m_CellCapacity*m_NeighbourTracker.cellidx[1]; m_NeighbourTracker.cellidx_c[2] = m_CellCapacity*m_NeighbourTracker.cellidx[2]; m_NeighbourTracker.cellidx_c[3] = m_CellCapacity*m_NeighbourTracker.cellidx[3]; m_NeighbourTracker.cellidx_c[4] = m_CellCapacity*m_NeighbourTracker.cellidx[4]; m_NeighbourTracker.cellidx_c[5] = m_CellCapacity*m_NeighbourTracker.cellidx[5]; m_NeighbourTracker.cellidx_c[6] = m_CellCapacity*m_NeighbourTracker.cellidx[6]; m_NeighbourTracker.cellidx_c[7] = m_CellCapacity*m_NeighbourTracker.cellidx[7]; m_NeighbourTracker.cellcnt = 0; m_NeighbourTracker.pcnt = 0; } Particle* ParticleGrid::GetNextNeighbor() { if (m_NeighbourTracker.pcnt < m_OccupationCount[m_NeighbourTracker.cellidx[m_NeighbourTracker.cellcnt]]) { return m_Grid[m_NeighbourTracker.cellidx_c[m_NeighbourTracker.cellcnt] + (m_NeighbourTracker.pcnt++)]; } else { for(;;) { m_NeighbourTracker.cellcnt++; if (m_NeighbourTracker.cellcnt >= 8) return 0; if (m_OccupationCount[m_NeighbourTracker.cellidx[m_NeighbourTracker.cellcnt]] > 0) break; } m_NeighbourTracker.pcnt = 1; return m_Grid[m_NeighbourTracker.cellidx_c[m_NeighbourTracker.cellcnt]]; } } void ParticleGrid::CreateConnection(Particle *P1,int ep1, Particle *P2, int ep2) { if (ep1 == -1) P1->mID = P2->ID; else P1->pID = P2->ID; if (ep2 == -1) P2->mID = P1->ID; else P2->pID = P1->ID; m_NumConnections++; } void ParticleGrid::DestroyConnection(Particle *P1,int ep1, Particle *P2, int ep2) { if (ep1 == -1) P1->mID = -1; else P1->pID = -1; if (ep2 == -1) P2->mID = -1; else P2->pID = -1; m_NumConnections--; } void ParticleGrid::DestroyConnection(Particle *P1,int ep1) { Particle *P2 = 0; if (ep1 == 1) { P2 = &m_Particles[P1->pID]; P1->pID = -1; } else { P2 = &m_Particles[P1->mID]; P1->mID = -1; } if (P2->mID == P1->ID) P2->mID = -1; else P2->pID = -1; m_NumConnections--; } bool ParticleGrid::CheckConsistency() { for (int i=0; iID != i) { std::cout << "Particle ID error!" << std::endl; return false; } if (p->mID!=-1) { Particle* p2 = &m_Particles[p->mID]; if (p2->mID!=p->ID && p2->pID!=p->ID) { std::cout << "Connection inconsistent!" << std::endl; return false; } } if (p->pID!=-1) { Particle* p2 = &m_Particles[p->pID]; if (p2->mID!=p->ID && p2->pID!=p->ID) { std::cout << "Connection inconsistent!" << std::endl; return false; } } } return true; } diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h index 281e526bb1..8d8d73e5a8 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h @@ -1,119 +1,120 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _PARTICLEGRID #define _PARTICLEGRID // MITK #include "MitkDiffusionImagingExports.h" #include // ITK #include namespace mitk { class MitkDiffusionImaging_EXPORT ParticleGrid { public: typedef itk::Image< float, 3 > ItkFloatImageType; int m_NumParticles; // number of particles int m_NumConnections; // number of connections int m_NumCellOverflows; // number of cell overflows + float m_ParticleLength; - ParticleGrid(ItkFloatImageType* image, float cellSize); + ParticleGrid(ItkFloatImageType* image, float particleLength); ~ParticleGrid(); Particle* GetParticle(int ID); Particle* NewParticle(vnl_vector_fixed R); bool TryUpdateGrid(int k); void RemoveParticle(int k); void ComputeNeighbors(vnl_vector_fixed &R); Particle* GetNextNeighbor(); void CreateConnection(Particle *P1,int ep1, Particle *P2, int ep2); void DestroyConnection(Particle *P1,int ep1, Particle *P2, int ep2); void DestroyConnection(Particle *P1,int ep1); bool CheckConsistency(); protected: bool ReallocateGrid(); std::vector< Particle* > m_Grid; // the grid std::vector< Particle > m_Particles; // particle container std::vector< int > m_OccupationCount; // number of particles per grid cell int m_ContainerCapacity; // maximal number of particles vnl_vector_fixed< int, 3 > m_GridSize; // grid dimensions vnl_vector_fixed< float, 3 > m_GridScale; // scaling factor for grid int m_CellCapacity; // particle capacity of single cell in grid struct NeighborTracker // to run over the neighbors { std::vector< int > cellidx; std::vector< int > cellidx_c; int cellcnt; int pcnt; } m_NeighbourTracker; }; class MitkDiffusionImaging_EXPORT Track { public: std::vector< EndPoint > track; float m_Energy; float m_Probability; int m_Length; Track() { track.resize(1000); } ~Track(){} void clear() { m_Length = 0; m_Energy = 0; m_Probability = 1; } bool isequal(Track& t) { for (int i = 0; i < m_Length;i++) { if (track[i].p != t.track[i].p || track[i].ep != t.track[i].ep) return false; } return true; } }; } #endif diff --git a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp index 57806429a3..ee16a181e1 100644 --- a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp +++ b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp @@ -1,392 +1,397 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "itkGibbsTrackingFilter.h" // MITK #include #include #include #include #include #include #include #include // ITK #include #pragma GCC visibility push(default) #include #pragma GCC visibility pop // MISC #include #include #include namespace itk{ template< class ItkQBallImageType > GibbsTrackingFilter< ItkQBallImageType >::GibbsTrackingFilter(): m_TempStart(0.1), m_TempEnd(0.001), m_NumIt(500000), m_ParticleWeight(0), m_ParticleWidth(0), m_ParticleLength(0), m_ChempotConnection(10), - m_ChempotParticle(0), m_InexBalance(0), m_Chempot2(0.2), m_FiberLength(10), m_AbortTracking(false), m_NumConnections(0), m_NumParticles(0), m_NumAcceptedFibers(0), m_CurrentStep(0), m_BuildFibers(false), m_Steps(10), m_Memory(0), m_ProposalAcceptance(0), m_CurvatureHardThreshold(0.7), m_Meanval_sq(0.0), m_DuplicateImage(true) { } template< class ItkQBallImageType > GibbsTrackingFilter< ItkQBallImageType >::~GibbsTrackingFilter() { } // fill output fiber bundle datastructure template< class ItkQBallImageType > typename GibbsTrackingFilter< ItkQBallImageType >::FiberPolyDataType GibbsTrackingFilter< ItkQBallImageType >::GetFiberBundle() { if (!m_AbortTracking) { m_BuildFibers = true; while (m_BuildFibers){} } return m_FiberPolyData; } template< class ItkQBallImageType > bool GibbsTrackingFilter< ItkQBallImageType > ::EstimateParticleWeight() { MITK_INFO << "itkGibbsTrackingFilter: estimating particle weight"; typedef itk::DiffusionQballGeneralizedFaImageFilter GfaFilterType; GfaFilterType::Pointer gfaFilter = GfaFilterType::New(); gfaFilter->SetInput(m_QBallImage); gfaFilter->SetComputationMethod(GfaFilterType::GFA_STANDARD); gfaFilter->Update(); ItkFloatImageType::Pointer gfaImage = gfaFilter->GetOutput(); float samplingStart = 1.0; float samplingStop = 0.66; // GFA iterator typedef ImageRegionIterator< ItkFloatImageType > GfaIteratorType; GfaIteratorType gfaIt(gfaImage, gfaImage->GetLargestPossibleRegion() ); // Mask iterator typedef ImageRegionConstIterator< ItkFloatImageType > MaskIteratorType; MaskIteratorType mit(m_MaskImage, m_MaskImage->GetLargestPossibleRegion() ); // Input iterator typedef ImageRegionConstIterator< ItkQBallImageType > InputIteratorType; InputIteratorType it(m_QBallImage, m_QBallImage->GetLargestPossibleRegion() ); float upper = 0; int count = 0; for(float thr=samplingStart; thr>samplingStop; thr-=0.01) { it.GoToBegin(); mit.GoToBegin(); gfaIt.GoToBegin(); while( !gfaIt.IsAtEnd() ) { if(gfaIt.Get()>thr && mit.Get()>0) { itk::OrientationDistributionFunction odf(it.Get().GetDataPointer()); upper += odf.GetMaxValue()-odf.GetMeanValue(); ++count; } ++it; ++mit; ++gfaIt; } } if (count>0) upper /= count; else return false; m_ParticleWeight = upper/6; return true; } // perform global tracking template< class ItkQBallImageType > void GibbsTrackingFilter< ItkQBallImageType >::GenerateData() { if (m_QBallImage.IsNull() && m_TensorImage.IsNotNull()) { TensorImageToQBallImageFilter::Pointer filter = TensorImageToQBallImageFilter::New(); filter->SetInput( m_TensorImage ); filter->Update(); m_QBallImage = filter->GetOutput(); } // image sizes and spacing int imgSize[4] = { QBALL_ODFSIZE, m_QBallImage->GetLargestPossibleRegion().GetSize().GetElement(0), m_QBallImage->GetLargestPossibleRegion().GetSize().GetElement(1), m_QBallImage->GetLargestPossibleRegion().GetSize().GetElement(2)}; double spacing[3] = {m_QBallImage->GetSpacing().GetElement(0),m_QBallImage->GetSpacing().GetElement(1),m_QBallImage->GetSpacing().GetElement(2)}; // make sure image has enough slices if (imgSize[1]<3 || imgSize[2]<3 || imgSize[3]<3) { MITK_INFO << "itkGibbsTrackingFilter: image size < 3 not supported"; m_AbortTracking = true; } // calculate rotation matrix vnl_matrix temp = m_QBallImage->GetDirection().GetVnlMatrix(); vnl_matrix directionMatrix; directionMatrix.set_size(3,3); vnl_copy(temp, directionMatrix); vnl_vector_fixed d0 = directionMatrix.get_column(0); d0.normalize(); vnl_vector_fixed d1 = directionMatrix.get_column(1); d1.normalize(); vnl_vector_fixed d2 = directionMatrix.get_column(2); d2.normalize(); directionMatrix.set_column(0, d0); directionMatrix.set_column(1, d1); directionMatrix.set_column(2, d2); vnl_matrix_fixed I = directionMatrix*directionMatrix.transpose(); if(!I.is_identity(mitk::eps)){ MITK_INFO << "itkGibbsTrackingFilter: image direction is not a rotation matrix. Tracking not possible!"; m_AbortTracking = true; } // generate local working copy of QBall image typename ItkQBallImageType::Pointer qballImage; if (m_DuplicateImage) { typedef itk::ImageDuplicator< ItkQBallImageType > DuplicateFilterType; typename DuplicateFilterType::Pointer duplicator = DuplicateFilterType::New(); duplicator->SetInputImage( m_QBallImage ); duplicator->Update(); qballImage = duplicator->GetOutput(); } else { qballImage = m_QBallImage; } // perform mean subtraction on odfs typedef ImageRegionIterator< ItkQBallImageType > InputIteratorType; InputIteratorType it(qballImage, qballImage->GetLargestPossibleRegion() ); it.GoToBegin(); while (!it.IsAtEnd()) { itk::OrientationDistributionFunction odf(it.Get().GetDataPointer()); float mean = odf.GetMeanValue(); odf -= mean; it.Set(odf.GetDataPointer()); ++it; } // mask image int maskImageSize[3]; float *mask; if(m_MaskImage.IsNotNull()) { mask = (float*) m_MaskImage->GetBufferPointer(); maskImageSize[0] = m_MaskImage->GetLargestPossibleRegion().GetSize().GetElement(0); maskImageSize[1] = m_MaskImage->GetLargestPossibleRegion().GetSize().GetElement(1); maskImageSize[2] = m_MaskImage->GetLargestPossibleRegion().GetSize().GetElement(2); } else { mask = 0; maskImageSize[0] = imgSize[1]; maskImageSize[1] = imgSize[2]; maskImageSize[2] = imgSize[3]; } int mask_oversamp_mult = maskImageSize[0]/imgSize[1]; // get paramters float minSpacing; if(spacing[0]m_NumIt) { MITK_INFO << "itkGibbsTrackingFilter: not enough iterations!"; m_AbortTracking = true; } if (m_CurvatureHardThreshold < mitk::eps) m_CurvatureHardThreshold = 0; unsigned long singleIts = (unsigned long)((1.0*m_NumIt) / (1.0*m_Steps)); MTRand randGen(1); srand(1); SphereInterpolator* interpolator = LoadSphereInterpolator(); MITK_INFO << "itkGibbsTrackingFilter: setting up MH-sampler"; - ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, 2*m_ParticleLength); - MetropolisHastingsSampler* sampler = new MetropolisHastingsSampler(particleGrid, &randGen); + ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, m_ParticleLength); EnergyComputer encomp(&randGen, qballImage, imgSize,spacing,interpolator,particleGrid,mask,mask_oversamp_mult, directionMatrix); encomp.setParameters(m_ParticleWeight,m_ParticleWidth,m_ChempotConnection*m_ParticleLength*m_ParticleLength,m_ParticleLength,m_CurvatureHardThreshold,m_InexBalance,m_Chempot2, m_Meanval_sq); - sampler->SetEnergyComputer(&encomp); - sampler->SetParameters(m_TempStart,singleIts,m_ParticleLength,m_CurvatureHardThreshold,m_ChempotParticle); + MetropolisHastingsSampler* sampler = new MetropolisHastingsSampler(particleGrid, &encomp, &randGen, m_CurvatureHardThreshold); // main loop m_NumAcceptedFibers = 0; - for( int step = 0; step < m_Steps; step++ ) + for( int m_CurrentStep = 1; m_CurrentStep <= m_Steps; m_CurrentStep++ ) { - if (m_AbortTracking) - break; - - m_CurrentStep = step+1; - float temperature = m_TempStart * exp(alpha*(((1.0)*step)/((1.0)*m_Steps))); - - sampler->SetTemperature(temperature); - sampler->Iterate(&m_ProposalAcceptance, &m_NumConnections, &m_NumParticles, &m_AbortTracking); + m_ProposalAcceptance = (float)sampler->GetNumAcceptedProposals()/m_NumIt; + m_NumParticles = particleGrid->m_NumParticles; + m_NumConnections = particleGrid->m_NumConnections; MITK_INFO << "itkGibbsTrackingFilter: proposal acceptance: " << 100*m_ProposalAcceptance << "%"; MITK_INFO << "itkGibbsTrackingFilter: particles: " << m_NumParticles; MITK_INFO << "itkGibbsTrackingFilter: connections: " << m_NumConnections; - MITK_INFO << "itkGibbsTrackingFilter: progress: " << 100*(float)step/m_Steps << "%"; + MITK_INFO << "itkGibbsTrackingFilter: progress: " << 100*(float)m_CurrentStep/m_Steps << "%"; - if (m_BuildFibers) + float temperature = m_TempStart * exp(alpha*(((1.0)*m_CurrentStep)/((1.0)*m_Steps))); + sampler->SetTemperature(temperature); + + for (unsigned long i=0; iGetNumberOfLines(); - m_BuildFibers = false; - } - } + if (m_AbortTracking) + break; + sampler->MakeProposal(); + if (m_BuildFibers) + { + m_ProposalAcceptance = (float)sampler->GetNumAcceptedProposals()/m_NumIt; + m_NumParticles = particleGrid->m_NumParticles; + m_NumConnections = particleGrid->m_NumConnections; + + FiberBuilder fiberBuilder(particleGrid, m_MaskImage); + m_FiberPolyData = fiberBuilder.iterate(m_FiberLength); + m_NumAcceptedFibers = m_FiberPolyData->GetNumberOfLines(); + m_BuildFibers = false; + } + } + } FiberBuilder fiberBuilder(particleGrid, m_MaskImage); m_FiberPolyData = fiberBuilder.iterate(m_FiberLength); m_NumAcceptedFibers = m_FiberPolyData->GetNumberOfLines(); delete interpolator; delete sampler; delete particleGrid; m_AbortTracking = true; m_BuildFibers = false; MITK_INFO << "itkGibbsTrackingFilter: done generate data"; } template< class ItkQBallImageType > SphereInterpolator* GibbsTrackingFilter< ItkQBallImageType >::LoadSphereInterpolator() { QString applicationDir = QCoreApplication::applicationDirPath(); applicationDir.append("/"); mitk::StandardFileLocations::GetInstance()->AddDirectoryForSearch( applicationDir.toStdString().c_str(), false ); applicationDir.append("../"); mitk::StandardFileLocations::GetInstance()->AddDirectoryForSearch( applicationDir.toStdString().c_str(), false ); applicationDir.append("../../"); mitk::StandardFileLocations::GetInstance()->AddDirectoryForSearch( applicationDir.toStdString().c_str(), false ); std::string lutPath = mitk::StandardFileLocations::GetInstance()->FindFile("FiberTrackingLUTBaryCoords.bin"); ifstream BaryCoords; BaryCoords.open(lutPath.c_str(), ios::in | ios::binary); float* coords; if (BaryCoords.is_open()) { float tmp; coords = new float [1630818]; BaryCoords.seekg (0, ios::beg); for (int i=0; i<1630818; i++) { BaryCoords.read((char *)&tmp, sizeof(tmp)); coords[i] = tmp; } BaryCoords.close(); } else { MITK_INFO << "itkGibbsTrackingFilter: unable to open barycoords file"; m_AbortTracking = true; } ifstream Indices; lutPath = mitk::StandardFileLocations::GetInstance()->FindFile("FiberTrackingLUTIndices.bin"); Indices.open(lutPath.c_str(), ios::in | ios::binary); int* ind; if (Indices.is_open()) { int tmp; ind = new int [1630818]; Indices.seekg (0, ios::beg); for (int i=0; i<1630818; i++) { Indices.read((char *)&tmp, 4); ind[i] = tmp; } Indices.close(); } else { MITK_INFO << "itkGibbsTrackingFilter: unable to open indices file"; m_AbortTracking = true; } // initialize sphere interpolator with lookuptables return new SphereInterpolator(coords, ind, QBALL_ODFSIZE, 301, 0.5); } } diff --git a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h index 2e97bb4a63..d198de102b 100644 --- a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h +++ b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h @@ -1,161 +1,157 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef itkGibbsTrackingFilter_h #define itkGibbsTrackingFilter_h // MITK #include // ITK #include #include #include // VTK #include #include #include #include #include namespace itk{ template< class ItkQBallImageType > class GibbsTrackingFilter : public ProcessObject { public: typedef GibbsTrackingFilter Self; typedef ProcessObject Superclass; typedef SmartPointer< Self > Pointer; typedef SmartPointer< const Self > ConstPointer; itkNewMacro(Self) itkTypeMacro( GibbsTrackingFilter, ProcessObject ) typedef Image< DiffusionTensor3D, 3 > ItkTensorImage; typedef typename ItkQBallImageType::Pointer ItkQBallImageTypePointer; typedef Image< float, 3 > ItkFloatImageType; typedef vtkSmartPointer< vtkPolyData > FiberPolyDataType; itkSetMacro( TempStart, float ) itkGetMacro( TempStart, float ) itkSetMacro( TempEnd, float ) itkGetMacro( TempEnd, float ) itkSetMacro( NumIt, unsigned long ) itkGetMacro( NumIt, unsigned long ) itkSetMacro( ParticleWeight, float ) itkGetMacro( ParticleWeight, float ) /** width of particle sigma (std-dev of gaussian around center) **/ itkSetMacro( ParticleWidth, float ) itkGetMacro( ParticleWidth, float ) /** length of particle from midpoint to ends **/ itkSetMacro( ParticleLength, float ) itkGetMacro( ParticleLength, float ) itkSetMacro( ChempotConnection, float ) itkGetMacro( ChempotConnection, float ) - itkSetMacro( ChempotParticle, float ) - itkGetMacro( ChempotParticle, float ) - itkSetMacro( InexBalance, float ) itkGetMacro( InexBalance, float ) itkSetMacro( Chempot2, float ) itkGetMacro( Chempot2, float ) itkSetMacro( FiberLength, int ) itkGetMacro( FiberLength, int ) itkSetMacro( AbortTracking, bool ) itkGetMacro( AbortTracking, bool ) itkSetMacro( CurrentStep, unsigned long ) itkGetMacro( CurrentStep, unsigned long ) itkSetMacro( CurvatureHardThreshold, float) itkGetMacro( CurvatureHardThreshold, float) itkGetMacro(NumParticles, unsigned long) itkGetMacro(NumConnections, unsigned long) itkGetMacro(NumAcceptedFibers, int) itkGetMacro(ProposalAcceptance, float) itkGetMacro(Steps, unsigned int) // input data itkSetMacro(QBallImage, typename ItkQBallImageType::Pointer) itkSetMacro(MaskImage, ItkFloatImageType::Pointer) itkSetMacro(TensorImage, ItkTensorImage::Pointer) void GenerateData(); virtual void Update(){ this->GenerateData(); } FiberPolyDataType GetFiberBundle(); protected: GibbsTrackingFilter(); virtual ~GibbsTrackingFilter(); bool EstimateParticleWeight(); SphereInterpolator* LoadSphereInterpolator(); // Input Images typename ItkQBallImageType::Pointer m_QBallImage; typename ItkFloatImageType::Pointer m_MaskImage; typename ItkTensorImage::Pointer m_TensorImage; // Tracking parameters float m_TempStart; // Start temperature float m_TempEnd; // End temperature unsigned long m_NumIt; // Total number of iterations unsigned long m_CurrentStep; // current tracking step float m_ParticleWeight; // w (unitless) float m_ParticleWidth; //sigma (mm) float m_ParticleLength; // ell (mm) float m_ChempotConnection; // gross L (chemisches potential) - float m_ChempotParticle; // unbenutzt (immer null, wenn groesser dann insgesamt weniger teilchen) float m_InexBalance; // gewichtung zwischen den lambdas; -5 ... 5 -> nur intern ... nur extern,default 0 float m_Chempot2; // typischerweise 0 int m_FiberLength; bool m_AbortTracking; int m_NumAcceptedFibers; volatile bool m_BuildFibers; unsigned int m_Steps; float m_Memory; float m_ProposalAcceptance; float m_CurvatureHardThreshold; float m_Meanval_sq; bool m_DuplicateImage; FiberPolyDataType m_FiberPolyData; unsigned long m_NumParticles; unsigned long m_NumConnections; }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkGibbsTrackingFilter.cpp" #endif #endif