diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp index fc29808d57..3c7279f0fc 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.cpp @@ -1,415 +1,415 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkParticleGrid.h" #include #include using namespace mitk; -ParticleGrid::ParticleGrid(ItkFloatImageType* image, float particleLength) +ParticleGrid::ParticleGrid(ItkFloatImageType* image, float particleLength, int cellCapacity) { // initialize counters m_NumParticles = 0; m_NumConnections = 0; m_NumCellOverflows = 0; m_ParticleLength = particleLength; // define isotropic grid from voxel spacing and particle length float cellSize = 2*m_ParticleLength; m_GridSize[0] = image->GetLargestPossibleRegion().GetSize()[0]*image->GetSpacing()[0]/cellSize +1; m_GridSize[1] = image->GetLargestPossibleRegion().GetSize()[1]*image->GetSpacing()[1]/cellSize +1; m_GridSize[2] = image->GetLargestPossibleRegion().GetSize()[2]*image->GetSpacing()[2]/cellSize +1; m_GridScale[0] = 1/cellSize; m_GridScale[1] = 1/cellSize; m_GridScale[2] = 1/cellSize; - m_CellCapacity = 1024; // maximum number of particles per grid cell - m_ContainerCapacity = 100000; // initial particle container capacity + m_CellCapacity = cellCapacity; // maximum number of particles per grid cell + m_ContainerCapacity = 100000; // initial particle container capacity int numCells = m_GridSize[0]*m_GridSize[1]*m_GridSize[2]; // number of grid cells m_Particles.resize(m_ContainerCapacity); // allocate and initialize particles m_Grid.resize(numCells*m_CellCapacity, NULL); // allocate and initialize particle grid m_OccupationCount.resize(numCells, 0); // allocate and initialize occupation counter array m_NeighbourTracker.cellidx.resize(8, 0); // allocate and initialize neighbour tracker m_NeighbourTracker.cellidx_c.resize(8, 0); for (int i = 0;i < m_ContainerCapacity;i++) // initialize particle IDs m_Particles[i].ID = i; std::cout << "ParticleGrid: allocated " << (sizeof(Particle)*m_ContainerCapacity + sizeof(Particle*)*m_GridSize[0]*m_GridSize[1]*m_GridSize[2])/1048576 << "mb for " << m_ContainerCapacity/1000 << "k particles." << std::endl; } ParticleGrid::~ParticleGrid() { } // remove all particles void ParticleGrid::ResetGrid() { // initialize counters m_NumParticles = 0; m_NumConnections = 0; m_NumCellOverflows = 0; m_Particles.clear(); m_Grid.clear(); m_OccupationCount.clear(); m_NeighbourTracker.cellidx.clear(); m_NeighbourTracker.cellidx_c.clear(); int numCells = m_GridSize[0]*m_GridSize[1]*m_GridSize[2]; // number of grid cells m_Particles.resize(m_ContainerCapacity); // allocate and initialize particles m_Grid.resize(numCells*m_CellCapacity, NULL); // allocate and initialize particle grid m_OccupationCount.resize(numCells, 0); // allocate and initialize occupation counter array m_NeighbourTracker.cellidx.resize(8, 0); // allocate and initialize neighbour tracker m_NeighbourTracker.cellidx_c.resize(8, 0); for (int i = 0;i < m_ContainerCapacity;i++) // initialize particle IDs m_Particles[i].ID = i; } bool ParticleGrid::ReallocateGrid() { std::cout << "ParticleGrid: reallocating ..." << std::endl; int new_capacity = m_ContainerCapacity + 100000; // increase container capacity by 100k particles try { m_Particles.resize(new_capacity); // reallocate particles for (int i = 0; i R) { if (m_NumParticles >= m_ContainerCapacity) { if (!ReallocateGrid()) return NULL; } int xint = int(R[0]*m_GridScale[0]); if (xint < 0) return NULL; if (xint >= m_GridSize[0]) return NULL; int yint = int(R[1]*m_GridScale[1]); if (yint < 0) return NULL; if (yint >= m_GridSize[1]) return NULL; int zint = int(R[2]*m_GridScale[2]); if (zint < 0) return NULL; if (zint >= m_GridSize[2]) return NULL; int idx = xint + m_GridSize[0]*(yint + m_GridSize[1]*zint); if (m_OccupationCount[idx] < m_CellCapacity) { Particle *p = &(m_Particles[m_NumParticles]); p->pos = R; p->mID = -1; p->pID = -1; m_NumParticles++; p->gridindex = m_CellCapacity*idx + m_OccupationCount[idx]; m_Grid[p->gridindex] = p; m_OccupationCount[idx]++; return p; } else { m_NumCellOverflows++; return NULL; } } bool ParticleGrid::TryUpdateGrid(int k) { Particle* p = &(m_Particles[k]); int xint = int(p->pos[0]*m_GridScale[0]); if (xint < 0) return false; if (xint >= m_GridSize[0]) return false; int yint = int(p->pos[1]*m_GridScale[1]); if (yint < 0) return false; if (yint >= m_GridSize[1]) return false; int zint = int(p->pos[2]*m_GridScale[2]); if (zint < 0) return false; if (zint >= m_GridSize[2]) return false; int idx = xint + m_GridSize[0]*(yint+ zint*m_GridSize[1]); int cellidx = p->gridindex/m_CellCapacity; if (idx != cellidx) // cell has changed { if (m_OccupationCount[idx] < m_CellCapacity) { // remove from old position in grid; int grdindex = p->gridindex; m_Grid[grdindex] = m_Grid[cellidx*m_CellCapacity + m_OccupationCount[cellidx]-1]; m_Grid[grdindex]->gridindex = grdindex; m_OccupationCount[cellidx]--; // insert at new position in grid p->gridindex = idx*m_CellCapacity + m_OccupationCount[idx]; m_Grid[p->gridindex] = p; m_OccupationCount[idx]++; return true; } else { m_NumCellOverflows++; return false; } } return true; } void ParticleGrid::RemoveParticle(int k) { Particle* p = &(m_Particles[k]); int gridIndex = p->gridindex; int cellIdx = gridIndex/m_CellCapacity; int idx = gridIndex%m_CellCapacity; // remove pending connections if (p->mID != -1) DestroyConnection(p,-1); if (p->pID != -1) DestroyConnection(p,+1); // remove from grid if (idx < m_OccupationCount[cellIdx]-1) { m_Grid[gridIndex] = m_Grid[cellIdx*m_CellCapacity+m_OccupationCount[cellIdx]-1]; m_Grid[cellIdx*m_CellCapacity+m_OccupationCount[cellIdx]-1] = NULL; m_Grid[gridIndex]->gridindex = gridIndex; } m_OccupationCount[cellIdx]--; // remove from container if (k < m_NumParticles-1) { Particle* last = &m_Particles[m_NumParticles-1]; // last particle // update connections of last particle because its index is changing if (last->mID!=-1) { if ( m_Particles[last->mID].mID == m_NumParticles-1 ) m_Particles[last->mID].mID = k; else if ( m_Particles[last->mID].pID == m_NumParticles-1 ) m_Particles[last->mID].pID = k; } if (last->pID!=-1) { if ( m_Particles[last->pID].mID == m_NumParticles-1 ) m_Particles[last->pID].mID = k; else if ( m_Particles[last->pID].pID == m_NumParticles-1 ) m_Particles[last->pID].pID = k; } m_Particles[k] = m_Particles[m_NumParticles-1]; // move very last particle to empty slot m_Particles[m_NumParticles-1].ID = m_NumParticles-1; // update ID of removed particle to match the index m_Particles[k].ID = k; // update ID of moved particle m_Grid[m_Particles[k].gridindex] = &m_Particles[k]; // update address of moved particle } m_NumParticles--; } void ParticleGrid::ComputeNeighbors(vnl_vector_fixed &R) { float xfrac = R[0]*m_GridScale[0]; float yfrac = R[1]*m_GridScale[1]; float zfrac = R[2]*m_GridScale[2]; int xint = int(xfrac); int yint = int(yfrac); int zint = int(zfrac); int dx = -1; if (xfrac-xint > 0.5) dx = 1; if (xint <= 0) { xint = 0; dx = 1; } if (xint >= m_GridSize[0]-1) { xint = m_GridSize[0]-1; dx = -1; } int dy = -1; if (yfrac-yint > 0.5) dy = 1; if (yint <= 0) {yint = 0; dy = 1; } if (yint >= m_GridSize[1]-1) {yint = m_GridSize[1]-1; dy = -1;} int dz = -1; if (zfrac-zint > 0.5) dz = 1; if (zint <= 0) {zint = 0; dz = 1; } if (zint >= m_GridSize[2]-1) {zint = m_GridSize[2]-1; dz = -1;} m_NeighbourTracker.cellidx[0] = xint + m_GridSize[0]*(yint+zint*m_GridSize[1]); m_NeighbourTracker.cellidx[1] = m_NeighbourTracker.cellidx[0] + dx; m_NeighbourTracker.cellidx[2] = m_NeighbourTracker.cellidx[1] + dy*m_GridSize[0]; m_NeighbourTracker.cellidx[3] = m_NeighbourTracker.cellidx[2] - dx; m_NeighbourTracker.cellidx[4] = m_NeighbourTracker.cellidx[0] + dz*m_GridSize[0]*m_GridSize[1]; m_NeighbourTracker.cellidx[5] = m_NeighbourTracker.cellidx[4] + dx; m_NeighbourTracker.cellidx[6] = m_NeighbourTracker.cellidx[5] + dy*m_GridSize[0]; m_NeighbourTracker.cellidx[7] = m_NeighbourTracker.cellidx[6] - dx; m_NeighbourTracker.cellidx_c[0] = m_CellCapacity*m_NeighbourTracker.cellidx[0]; m_NeighbourTracker.cellidx_c[1] = m_CellCapacity*m_NeighbourTracker.cellidx[1]; m_NeighbourTracker.cellidx_c[2] = m_CellCapacity*m_NeighbourTracker.cellidx[2]; m_NeighbourTracker.cellidx_c[3] = m_CellCapacity*m_NeighbourTracker.cellidx[3]; m_NeighbourTracker.cellidx_c[4] = m_CellCapacity*m_NeighbourTracker.cellidx[4]; m_NeighbourTracker.cellidx_c[5] = m_CellCapacity*m_NeighbourTracker.cellidx[5]; m_NeighbourTracker.cellidx_c[6] = m_CellCapacity*m_NeighbourTracker.cellidx[6]; m_NeighbourTracker.cellidx_c[7] = m_CellCapacity*m_NeighbourTracker.cellidx[7]; m_NeighbourTracker.cellcnt = 0; m_NeighbourTracker.pcnt = 0; } Particle* ParticleGrid::GetNextNeighbor() { if (m_NeighbourTracker.pcnt < m_OccupationCount[m_NeighbourTracker.cellidx[m_NeighbourTracker.cellcnt]]) { return m_Grid[m_NeighbourTracker.cellidx_c[m_NeighbourTracker.cellcnt] + (m_NeighbourTracker.pcnt++)]; } else { for(;;) { m_NeighbourTracker.cellcnt++; if (m_NeighbourTracker.cellcnt >= 8) return 0; if (m_OccupationCount[m_NeighbourTracker.cellidx[m_NeighbourTracker.cellcnt]] > 0) break; } m_NeighbourTracker.pcnt = 1; return m_Grid[m_NeighbourTracker.cellidx_c[m_NeighbourTracker.cellcnt]]; } } void ParticleGrid::CreateConnection(Particle *P1,int ep1, Particle *P2, int ep2) { if (ep1 == -1) P1->mID = P2->ID; else P1->pID = P2->ID; if (ep2 == -1) P2->mID = P1->ID; else P2->pID = P1->ID; m_NumConnections++; } void ParticleGrid::DestroyConnection(Particle *P1,int ep1, Particle *P2, int ep2) { if (ep1 == -1) P1->mID = -1; else P1->pID = -1; if (ep2 == -1) P2->mID = -1; else P2->pID = -1; m_NumConnections--; } void ParticleGrid::DestroyConnection(Particle *P1,int ep1) { Particle *P2 = 0; if (ep1 == 1) { P2 = &m_Particles[P1->pID]; P1->pID = -1; } else { P2 = &m_Particles[P1->mID]; P1->mID = -1; } if (P2->mID == P1->ID) P2->mID = -1; else P2->pID = -1; m_NumConnections--; } bool ParticleGrid::CheckConsistency() { for (int i=0; iID != i) { std::cout << "Particle ID error!" << std::endl; return false; } if (p->mID!=-1) { Particle* p2 = &m_Particles[p->mID]; if (p2->mID!=p->ID && p2->pID!=p->ID) { std::cout << "Connection inconsistent!" << std::endl; return false; } } if (p->pID!=-1) { Particle* p2 = &m_Particles[p->pID]; if (p2->mID!=p->ID && p2->pID!=p->ID) { std::cout << "Connection inconsistent!" << std::endl; return false; } } } return true; } diff --git a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h index 64d1ff83a5..4eb2e596f5 100644 --- a/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h +++ b/Modules/DiffusionImaging/Tractography/GibbsTracking/mitkParticleGrid.h @@ -1,121 +1,121 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef _PARTICLEGRID #define _PARTICLEGRID // MITK #include "MitkDiffusionImagingExports.h" #include // ITK #include namespace mitk { class MitkDiffusionImaging_EXPORT ParticleGrid { public: typedef itk::Image< float, 3 > ItkFloatImageType; int m_NumParticles; // number of particles int m_NumConnections; // number of connections int m_NumCellOverflows; // number of cell overflows float m_ParticleLength; - ParticleGrid(ItkFloatImageType* image, float particleLength); + ParticleGrid(ItkFloatImageType* image, float particleLength, int cellCapacity); ~ParticleGrid(); Particle* GetParticle(int ID); Particle* NewParticle(vnl_vector_fixed R); bool TryUpdateGrid(int k); void RemoveParticle(int k); void ComputeNeighbors(vnl_vector_fixed &R); Particle* GetNextNeighbor(); void CreateConnection(Particle *P1,int ep1, Particle *P2, int ep2); void DestroyConnection(Particle *P1,int ep1, Particle *P2, int ep2); void DestroyConnection(Particle *P1,int ep1); bool CheckConsistency(); void ResetGrid(); protected: bool ReallocateGrid(); std::vector< Particle* > m_Grid; // the grid std::vector< Particle > m_Particles; // particle container std::vector< int > m_OccupationCount; // number of particles per grid cell int m_ContainerCapacity; // maximal number of particles vnl_vector_fixed< int, 3 > m_GridSize; // grid dimensions vnl_vector_fixed< float, 3 > m_GridScale; // scaling factor for grid int m_CellCapacity; // particle capacity of single cell in grid struct NeighborTracker // to run over the neighbors { std::vector< int > cellidx; std::vector< int > cellidx_c; int cellcnt; int pcnt; } m_NeighbourTracker; }; class MitkDiffusionImaging_EXPORT Track { public: std::vector< EndPoint > track; float m_Energy; float m_Probability; int m_Length; Track() { track.resize(1000); } ~Track(){} void clear() { m_Length = 0; m_Energy = 0; m_Probability = 1; } bool isequal(Track& t) { for (int i = 0; i < m_Length;i++) { if (track[i].p != t.track[i].p || track[i].ep != t.track[i].ep) return false; } return true; } }; } #endif diff --git a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp index 5597642336..a12951d4e6 100644 --- a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp +++ b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.cpp @@ -1,477 +1,477 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "itkGibbsTrackingFilter.h" // MITK #include #include #include #include #include //#include #include #include // ITK #include #include #include // MISC #include #include #include #include namespace itk{ template< class ItkQBallImageType > GibbsTrackingFilter< ItkQBallImageType >::GibbsTrackingFilter(): m_StartTemperature(0.1), m_EndTemperature(0.001), m_Iterations(500000), m_ParticleWeight(0), m_ParticleWidth(0), m_ParticleLength(0), m_ConnectionPotential(10), m_InexBalance(0), m_ParticlePotential(0.2), m_MinFiberLength(10), m_AbortTracking(false), m_NumConnections(0), m_NumParticles(0), m_NumAcceptedFibers(0), m_CurrentStep(0), m_BuildFibers(false), m_Steps(10), m_ProposalAcceptance(0), m_CurvatureThreshold(0.7), m_DuplicateImage(true), m_RandomSeed(-1), m_LoadParameterFile(""), m_LutPath("") { } template< class ItkQBallImageType > GibbsTrackingFilter< ItkQBallImageType >::~GibbsTrackingFilter() { } // fill output fiber bundle datastructure template< class ItkQBallImageType > typename GibbsTrackingFilter< ItkQBallImageType >::FiberPolyDataType GibbsTrackingFilter< ItkQBallImageType >::GetFiberBundle() { if (!m_AbortTracking) { m_BuildFibers = true; while (m_BuildFibers){} } return m_FiberPolyData; } template< class ItkQBallImageType > void GibbsTrackingFilter< ItkQBallImageType > ::EstimateParticleWeight() { MITK_INFO << "GibbsTrackingFilter: estimating particle weight"; float minSpacing; if(m_QBallImage->GetSpacing()[0]GetSpacing()[1] && m_QBallImage->GetSpacing()[0]GetSpacing()[2]) minSpacing = m_QBallImage->GetSpacing()[0]; else if (m_QBallImage->GetSpacing()[1] < m_QBallImage->GetSpacing()[2]) minSpacing = m_QBallImage->GetSpacing()[1]; else minSpacing = m_QBallImage->GetSpacing()[2]; float m_ParticleLength = 1.5*minSpacing; float m_ParticleWidth = 0.5*minSpacing; // seed random generators Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen = Statistics::MersenneTwisterRandomVariateGenerator::New(); if (m_RandomSeed>-1) randGen->SetSeed(m_RandomSeed); else randGen->SetSeed(); // instantiate all necessary components SphereInterpolator* interpolator = new SphereInterpolator(m_LutPath); - ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, m_ParticleLength); + ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, m_ParticleLength, m_ParticleGridCellCapacity); GibbsEnergyComputer* encomp = new GibbsEnergyComputer(m_QBallImage, m_MaskImage, particleGrid, interpolator, randGen); // EnergyComputer* encomp = new EnergyComputer(m_QBallImage, m_MaskImage, particleGrid, interpolator, randGen); MetropolisHastingsSampler* sampler = new MetropolisHastingsSampler(particleGrid, encomp, randGen, m_CurvatureThreshold); float alpha = log(m_EndTemperature/m_StartTemperature); m_ParticleWeight = 0.01; int ppv = 0; // main loop int neededParts = 3000; while (ppvSetParameters(m_ParticleWeight,m_ParticleWidth,m_ConnectionPotential*m_ParticleLength*m_ParticleLength,m_CurvatureThreshold,m_InexBalance,m_ParticlePotential); for( int step = 0; step < 10; step++ ) { // update temperatur for simulated annealing process float temperature = m_StartTemperature * exp(alpha*(((1.0)*step)/((1.0)*10))); sampler->SetTemperature(temperature); for (unsigned long i=0; i<10000; i++) sampler->MakeProposal(); } ppv = particleGrid->m_NumParticles; particleGrid->ResetGrid(); } delete sampler; delete encomp; delete particleGrid; delete interpolator; MITK_INFO << "GibbsTrackingFilter: finished estimating particle weight"; } // perform global tracking template< class ItkQBallImageType > void GibbsTrackingFilter< ItkQBallImageType >::GenerateData() { TimeProbe preClock; preClock.Start(); // check if input is qball or tensor image and generate qball if necessary if (m_QBallImage.IsNull() && m_TensorImage.IsNotNull()) { TensorImageToQBallImageFilter::Pointer filter = TensorImageToQBallImageFilter::New(); filter->SetInput( m_TensorImage ); filter->Update(); m_QBallImage = filter->GetOutput(); } else if (m_DuplicateImage) // generate local working copy of QBall image (if not disabled) { typedef itk::ImageDuplicator< ItkQBallImageType > DuplicateFilterType; typename DuplicateFilterType::Pointer duplicator = DuplicateFilterType::New(); duplicator->SetInputImage( m_QBallImage ); duplicator->Update(); m_QBallImage = duplicator->GetOutput(); } // perform mean subtraction on odfs typedef ImageRegionIterator< ItkQBallImageType > InputIteratorType; InputIteratorType it(m_QBallImage, m_QBallImage->GetLargestPossibleRegion() ); it.GoToBegin(); while (!it.IsAtEnd()) { itk::OrientationDistributionFunction odf(it.Get().GetDataPointer()); float mean = odf.GetMeanValue(); odf -= mean; it.Set(odf.GetDataPointer()); ++it; } // check if mask image is given if it needs resampling PrepareMaskImage(); // load parameter file LoadParameters(); // prepare parameters float minSpacing; if(m_QBallImage->GetSpacing()[0]GetSpacing()[1] && m_QBallImage->GetSpacing()[0]GetSpacing()[2]) minSpacing = m_QBallImage->GetSpacing()[0]; else if (m_QBallImage->GetSpacing()[1] < m_QBallImage->GetSpacing()[2]) minSpacing = m_QBallImage->GetSpacing()[1]; else minSpacing = m_QBallImage->GetSpacing()[2]; if(m_ParticleLength == 0) m_ParticleLength = 1.5*minSpacing; if(m_ParticleWidth == 0) m_ParticleWidth = 0.5*minSpacing; if(m_ParticleWeight == 0) EstimateParticleWeight(); float alpha = log(m_EndTemperature/m_StartTemperature); m_Steps = m_Iterations/10000; if (m_Steps<10) m_Steps = 10; if (m_Steps>m_Iterations) { MITK_INFO << "GibbsTrackingFilter: not enough iterations!"; m_AbortTracking = true; } if (m_CurvatureThreshold < mitk::eps) m_CurvatureThreshold = 0; unsigned long singleIts = (unsigned long)((1.0*m_Iterations) / (1.0*m_Steps)); // seed random generators Statistics::MersenneTwisterRandomVariateGenerator::Pointer randGen = Statistics::MersenneTwisterRandomVariateGenerator::New(); if (m_RandomSeed>-1) randGen->SetSeed(m_RandomSeed); else randGen->SetSeed(); // load sphere interpolator to evaluate the ODFs SphereInterpolator* interpolator = new SphereInterpolator(m_LutPath); // initialize the actual tracking components (ParticleGrid, Metropolis Hastings Sampler and Energy Computer) - ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, m_ParticleLength); + ParticleGrid* particleGrid = new ParticleGrid(m_MaskImage, m_ParticleLength, m_ParticleGridCellCapacity); GibbsEnergyComputer* encomp = new GibbsEnergyComputer(m_QBallImage, m_MaskImage, particleGrid, interpolator, randGen); encomp->SetParameters(m_ParticleWeight,m_ParticleWidth,m_ConnectionPotential*m_ParticleLength*m_ParticleLength,m_CurvatureThreshold,m_InexBalance,m_ParticlePotential); MetropolisHastingsSampler* sampler = new MetropolisHastingsSampler(particleGrid, encomp, randGen, m_CurvatureThreshold); MITK_INFO << "----------------------------------------"; MITK_INFO << "Iterations: " << m_Iterations; MITK_INFO << "Steps: " << m_Steps; MITK_INFO << "Particle length: " << m_ParticleLength; MITK_INFO << "Particle width: " << m_ParticleWidth; MITK_INFO << "Particle weight: " << m_ParticleWeight; MITK_INFO << "Start temperature: " << m_StartTemperature; MITK_INFO << "End temperature: " << m_EndTemperature; MITK_INFO << "In/Ex balance: " << m_InexBalance; MITK_INFO << "Min. fiber length: " << m_MinFiberLength; MITK_INFO << "Curvature threshold: " << m_CurvatureThreshold; MITK_INFO << "Random seed: " << m_RandomSeed; MITK_INFO << "----------------------------------------"; // main loop preClock.Stop(); TimeProbe clock; clock.Start(); m_NumAcceptedFibers = 0; unsigned long counter = 1; for( m_CurrentStep = 1; m_CurrentStep <= m_Steps; m_CurrentStep++ ) { // update temperatur for simulated annealing process float temperature = m_StartTemperature * exp(alpha*(((1.0)*m_CurrentStep)/((1.0)*m_Steps))); sampler->SetTemperature(temperature); for (unsigned long i=0; iMakeProposal(); if (m_BuildFibers || (i==singleIts-1 && m_CurrentStep==m_Steps)) { m_ProposalAcceptance = (float)sampler->GetNumAcceptedProposals()/counter; m_NumParticles = particleGrid->m_NumParticles; m_NumConnections = particleGrid->m_NumConnections; FiberBuilder fiberBuilder(particleGrid, m_MaskImage); m_FiberPolyData = fiberBuilder.iterate(m_MinFiberLength); m_NumAcceptedFibers = m_FiberPolyData->GetNumberOfLines(); m_BuildFibers = false; } counter++; } m_ProposalAcceptance = (float)sampler->GetNumAcceptedProposals()/counter; m_NumParticles = particleGrid->m_NumParticles; m_NumConnections = particleGrid->m_NumConnections; MITK_INFO << "GibbsTrackingFilter: proposal acceptance: " << 100*m_ProposalAcceptance << "%"; MITK_INFO << "GibbsTrackingFilter: particles: " << m_NumParticles; MITK_INFO << "GibbsTrackingFilter: connections: " << m_NumConnections; MITK_INFO << "GibbsTrackingFilter: progress: " << 100*(float)m_CurrentStep/m_Steps << "%"; MITK_INFO << "GibbsTrackingFilter: cell overflows: " << particleGrid->m_NumCellOverflows; MITK_INFO << "----------------------------------------"; if (m_AbortTracking) break; } clock.Stop(); delete sampler; delete encomp; delete interpolator; delete particleGrid; m_AbortTracking = true; m_BuildFibers = false; int h = clock.GetTotal()/3600; int m = ((int)clock.GetTotal()%3600)/60; int s = (int)clock.GetTotal()%60; MITK_INFO << "GibbsTrackingFilter: finished gibbs tracking in " << h << "h, " << m << "m and " << s << "s"; m = (int)preClock.GetTotal()/60; s = (int)preClock.GetTotal()%60; MITK_INFO << "GibbsTrackingFilter: preparation of the data took " << m << "m and " << s << "s"; MITK_INFO << "GibbsTrackingFilter: " << m_NumAcceptedFibers << " fibers accepted"; SaveParameters(); } template< class ItkQBallImageType > void GibbsTrackingFilter< ItkQBallImageType >::PrepareMaskImage() { if(m_MaskImage.IsNull()) { MITK_INFO << "GibbsTrackingFilter: generating default mask image"; m_MaskImage = ItkFloatImageType::New(); m_MaskImage->SetSpacing( m_QBallImage->GetSpacing() ); m_MaskImage->SetOrigin( m_QBallImage->GetOrigin() ); m_MaskImage->SetDirection( m_QBallImage->GetDirection() ); m_MaskImage->SetRegions( m_QBallImage->GetLargestPossibleRegion() ); m_MaskImage->Allocate(); m_MaskImage->FillBuffer(1.0); } else if ( m_MaskImage->GetLargestPossibleRegion().GetSize()[0]!=m_QBallImage->GetLargestPossibleRegion().GetSize()[0] || m_MaskImage->GetLargestPossibleRegion().GetSize()[1]!=m_QBallImage->GetLargestPossibleRegion().GetSize()[1] || m_MaskImage->GetLargestPossibleRegion().GetSize()[2]!=m_QBallImage->GetLargestPossibleRegion().GetSize()[2] || m_MaskImage->GetSpacing()[0]!=m_QBallImage->GetSpacing()[0] || m_MaskImage->GetSpacing()[1]!=m_QBallImage->GetSpacing()[1] || m_MaskImage->GetSpacing()[2]!=m_QBallImage->GetSpacing()[2] ) { MITK_INFO << "GibbsTrackingFilter: resampling mask image"; typedef itk::ResampleImageFilter< ItkFloatImageType, ItkFloatImageType, float > ResamplerType; ResamplerType::Pointer resampler = ResamplerType::New(); resampler->SetOutputSpacing( m_QBallImage->GetSpacing() ); resampler->SetOutputOrigin( m_QBallImage->GetOrigin() ); resampler->SetOutputDirection( m_QBallImage->GetDirection() ); resampler->SetSize( m_QBallImage->GetLargestPossibleRegion().GetSize() ); resampler->SetInput( m_MaskImage ); resampler->SetDefaultPixelValue(1.0); resampler->Update(); m_MaskImage = resampler->GetOutput(); MITK_INFO << "GibbsTrackingFilter: resampling finished"; } } // load tracking paramters from xml file (.gtp) template< class ItkQBallImageType > bool GibbsTrackingFilter< ItkQBallImageType >::LoadParameters() { m_AbortTracking = true; try { if( m_LoadParameterFile.length()==0 ) { m_AbortTracking = false; return true; } MITK_INFO << "GibbsTrackingFilter: loading parameter file " << m_LoadParameterFile; TiXmlDocument doc( m_LoadParameterFile ); doc.LoadFile(); TiXmlHandle hDoc(&doc); TiXmlElement* pElem; TiXmlHandle hRoot(0); pElem = hDoc.FirstChildElement().Element(); hRoot = TiXmlHandle(pElem); pElem = hRoot.FirstChildElement("parameter_set").Element(); QString iterations(pElem->Attribute("iterations")); m_Iterations = iterations.toULong(); QString particleLength(pElem->Attribute("particle_length")); m_ParticleLength = particleLength.toFloat(); QString particleWidth(pElem->Attribute("particle_width")); m_ParticleWidth = particleWidth.toFloat(); QString partWeight(pElem->Attribute("particle_weight")); m_ParticleWeight = partWeight.toFloat(); QString startTemp(pElem->Attribute("temp_start")); m_StartTemperature = startTemp.toFloat(); QString endTemp(pElem->Attribute("temp_end")); m_EndTemperature = endTemp.toFloat(); QString inExBalance(pElem->Attribute("inexbalance")); m_InexBalance = inExBalance.toFloat(); QString fiberLength(pElem->Attribute("fiber_length")); m_MinFiberLength = fiberLength.toFloat(); QString curvThres(pElem->Attribute("curvature_threshold")); m_CurvatureThreshold = cos(curvThres.toFloat()*M_PI/180); m_AbortTracking = false; MITK_INFO << "GibbsTrackingFilter: parameter file loaded successfully"; return true; } catch(...) { MITK_INFO << "GibbsTrackingFilter: could not load parameter file"; return false; } } // save current tracking paramters to xml file (.gtp) template< class ItkQBallImageType > bool GibbsTrackingFilter< ItkQBallImageType >::SaveParameters() { try { if( m_SaveParameterFile.length()==0 ) { MITK_INFO << "GibbsTrackingFilter: no filename specified to save parameters"; return true; } MITK_INFO << "GibbsTrackingFilter: saving parameter file " << m_SaveParameterFile; TiXmlDocument documentXML; TiXmlDeclaration* declXML = new TiXmlDeclaration( "1.0", "", "" ); documentXML.LinkEndChild( declXML ); TiXmlElement* mainXML = new TiXmlElement("global_tracking_parameter_file"); mainXML->SetAttribute("file_version", "0.1"); documentXML.LinkEndChild(mainXML); TiXmlElement* paramXML = new TiXmlElement("parameter_set"); paramXML->SetAttribute("iterations", QString::number(m_Iterations).toStdString()); paramXML->SetAttribute("particle_length", QString::number(m_ParticleLength).toStdString()); paramXML->SetAttribute("particle_width", QString::number(m_ParticleWidth).toStdString()); paramXML->SetAttribute("particle_weight", QString::number(m_ParticleWeight).toStdString()); paramXML->SetAttribute("temp_start", QString::number(m_StartTemperature).toStdString()); paramXML->SetAttribute("temp_end", QString::number(m_EndTemperature).toStdString()); paramXML->SetAttribute("inexbalance", QString::number(m_InexBalance).toStdString()); paramXML->SetAttribute("fiber_length", QString::number(m_MinFiberLength).toStdString()); paramXML->SetAttribute("curvature_threshold", QString::number(m_CurvatureThreshold).toStdString()); mainXML->LinkEndChild(paramXML); QString filename(m_SaveParameterFile.c_str()); if(!filename.endsWith(".gtp")) filename += ".gtp"; documentXML.SaveFile( filename.toStdString() ); MITK_INFO << "GibbsTrackingFilter: parameter file saved successfully"; return true; } catch(...) { MITK_INFO << "GibbsTrackingFilter: could not save parameter file"; return false; } } } diff --git a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h index 658d1650d5..70e833477d 100644 --- a/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h +++ b/Modules/DiffusionImaging/Tractography/itkGibbsTrackingFilter.h @@ -1,145 +1,148 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef itkGibbsTrackingFilter_h #define itkGibbsTrackingFilter_h // MITK #include // ITK #include #include #include #include // VTK #include #include #include #include #include namespace itk{ template< class ItkQBallImageType > class GibbsTrackingFilter : public ProcessObject { public: typedef GibbsTrackingFilter Self; typedef ProcessObject Superclass; typedef SmartPointer< Self > Pointer; typedef SmartPointer< const Self > ConstPointer; itkNewMacro(Self) itkTypeMacro( GibbsTrackingFilter, ProcessObject ) typedef Image< DiffusionTensor3D, 3 > ItkTensorImage; typedef typename ItkQBallImageType::Pointer ItkQBallImageTypePointer; typedef Image< float, 3 > ItkFloatImageType; typedef vtkSmartPointer< vtkPolyData > FiberPolyDataType; // parameter setter itkSetMacro( StartTemperature, float ) itkSetMacro( EndTemperature, float ) itkSetMacro( Iterations, unsigned long ) itkSetMacro( ParticleWeight, float ) itkSetMacro( ParticleWidth, float ) itkSetMacro( ParticleLength, float ) itkSetMacro( ConnectionPotential, float ) itkSetMacro( InexBalance, float ) itkSetMacro( ParticlePotential, float ) itkSetMacro( MinFiberLength, int ) itkSetMacro( AbortTracking, bool ) itkSetMacro( CurvatureThreshold, float) itkSetMacro( DuplicateImage, bool ) itkSetMacro( RandomSeed, int ) itkSetMacro( LoadParameterFile, std::string ) itkSetMacro( SaveParameterFile, std::string ) itkSetMacro( LutPath, std::string ) // getter itkGetMacro( ParticleWeight, float ) itkGetMacro( ParticleWidth, float ) itkGetMacro( ParticleLength, float ) itkGetMacro( CurrentStep, unsigned long ) itkGetMacro( NumParticles, int ) itkGetMacro( NumConnections, int ) itkGetMacro( NumAcceptedFibers, int ) itkGetMacro( ProposalAcceptance, float ) itkGetMacro( Steps, unsigned int) // input data itkSetMacro(QBallImage, typename ItkQBallImageType::Pointer) itkSetMacro(MaskImage, ItkFloatImageType::Pointer) itkSetMacro(TensorImage, ItkTensorImage::Pointer) void GenerateData(); virtual void Update(){ this->GenerateData(); } FiberPolyDataType GetFiberBundle(); protected: GibbsTrackingFilter(); virtual ~GibbsTrackingFilter(); void EstimateParticleWeight(); void PrepareMaskImage(); bool LoadParameters(); bool SaveParameters(); // Input Images typename ItkQBallImageType::Pointer m_QBallImage; typename ItkFloatImageType::Pointer m_MaskImage; typename ItkTensorImage::Pointer m_TensorImage; // Tracking parameters float m_StartTemperature; // Start temperature float m_EndTemperature; // End temperature unsigned long m_Iterations; // Total number of iterations unsigned long m_CurrentStep; // current tracking step float m_ParticleWeight; // w (unitless) float m_ParticleWidth; // sigma (mm) float m_ParticleLength; // l (mm) float m_ConnectionPotential; // gross L (chemisches potential, default 10) float m_InexBalance; // gewichtung zwischen den lambdas; -5 ... 5 -> nur intern ... nur extern,default 0 float m_ParticlePotential; // default 0.2 int m_MinFiberLength; // discard all fibers shortan than the specified length in mm bool m_AbortTracking; // set flag to abort tracking int m_NumAcceptedFibers; // number of reconstructed fibers generated by the FiberBuilder volatile bool m_BuildFibers; // set flag to generate fibers from particle grid unsigned int m_Steps; // number of temperature decrease steps float m_ProposalAcceptance; // proposal acceptance rate (0-1) float m_CurvatureThreshold; // curvature threshold in radians (1 -> no curvature is accepted, -1 all curvature angles are accepted) bool m_DuplicateImage; // generates a working copy of the qball image so that the original image won't be changed by the mean subtraction int m_NumParticles; // current number of particles in grid int m_NumConnections; // current number of connections between particles in grid int m_RandomSeed; // seed value for random generator (-1 for standard seeding) std::string m_LoadParameterFile; // filename of parameter file (reader) std::string m_SaveParameterFile; // filename of parameter file (writer) std::string m_LutPath; // path to lookuptables used by the sphere interpolator FiberPolyDataType m_FiberPolyData; // container for reconstructed fibers + + //Constant values + static const int m_ParticleGridCellCapacity = 1024; }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkGibbsTrackingFilter.cpp" #endif #endif