diff --git a/Modules/Classification/CLCore/include/mitkConfigFileReader.h b/Modules/Classification/CLCore/include/mitkConfigFileReader.h index b033a5bf29..16685e479d 100644 --- a/Modules/Classification/CLCore/include/mitkConfigFileReader.h +++ b/Modules/Classification/CLCore/include/mitkConfigFileReader.h @@ -1,206 +1,214 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef ConfigFileReader_h #define ConfigFileReader_h #include #include #include #include #include +#include class ConfigFileReader { protected: typedef std::map ContentType; typedef std::map > ListContentType; ContentType m_ConfigContent; ListContentType m_ListContent; std::map m_ListSize; std::string Trim(std::string const& source, char const * delim = " \t\r\n") { std::string result(source); std::string::size_type index = result.find_last_not_of(delim); if(index != std::string::npos) result.erase(++index); index = result.find_first_not_of(delim); if(index != std::string::npos) result.erase(0, index); else result.erase(); return result; } std::string RemoveComment(std::string const& source, char const * delim = "#;") { std::string result(source); std::string::size_type index = result.find_first_of(delim); if(index != std::string::npos) result.erase(++index); return Trim(result); } std::string ListIndex(std::string const& section, unsigned int index) const { std::stringstream stream; stream << section << "/" << index; std::string result = stream.str(); std::transform(result.begin(), result.end(), result.begin(), ::tolower); return result; } std::string ContentIndex(std::string const& section, std::string const& entry) const { std::string result = section + '/' + entry; std::transform(result.begin(), result.end(), result.begin(), ::tolower); return result; } std::string ListSizeIndex(std::string const& section) const { std::string result = section; std::transform(result.begin(), result.end(), result.begin(), ::tolower); return result; } public: ConfigFileReader(std::string const& configFile) { ReadFile (configFile); } void ReadFile(std::string const& filePath) { std::ifstream file(filePath.c_str()); ReadStream(file); file.close(); } void ReadStream (std::istream& stream) { std::string line; std::string name; std::string value; std::string inSection; bool inConfigSection = true; int posEqual; while (std::getline(stream,line)) { line = RemoveComment(line, "#"); if (! line.length()) continue; if (line[0] == '[') { inConfigSection = true; inSection = Trim(line.substr(1,line.find(']')-1)); continue; } if(line[0] == '{') { std::string address = Trim(line.substr(1,line.find('}')-1)); inSection = ListIndex(address, ListSize(address,0)); m_ListSize[ListSizeIndex(address)]++; inConfigSection = false; continue; } if (inConfigSection) { posEqual=line.find('='); name = Trim(line.substr(0,posEqual)); value = Trim(line.substr(posEqual+1)); m_ConfigContent[ContentIndex(inSection, name)]=value; } else { m_ListContent[inSection].push_back(Trim(line)); } } } std::string Value(std::string const& section, std::string const& entry) const { std::string index = ContentIndex(section,entry); if (m_ConfigContent.find(index) == m_ConfigContent.end()) - throw "Entry doesn't exist " + section + entry; + throw std::string("Entry doesn't exist " + section +"::"+ entry); + std::cout << section << "::" << entry << m_ConfigContent.find(index)->second << std::endl; return m_ConfigContent.find(index)->second; } std::string Value(const std::string & section, const std::string & entry, const std::string& standard) { try { return Value(section, entry); } - catch (const char *) { + catch (const std::string) { m_ConfigContent[ContentIndex(section, entry)] = standard; + std::cout << section << "::" << entry << standard << " (default)" << std::endl; return standard; } } int IntValue(const std::string & section, const std::string & entry) const { int result; std::stringstream stream (Value(section, entry)); stream >> result; return result; } int IntValue(const std::string & section, const std::string & entry, int standard) { try { return IntValue(section, entry); } - catch (const char *) { + catch (const std::string) { std::stringstream stream; stream << standard; m_ConfigContent[ContentIndex(section, entry)] = stream.str(); + std::cout << section << "::" << entry << stream.str() << "(default)" << std::endl; return standard; } } std::vector Vector(std::string const& section, unsigned int index) const { if (m_ListContent.find(ListIndex(section, index)) == m_ListContent.end()) - throw "Entry doesn't exist " + section; + { + throw std::string("Entry doesn't exist " + section); + } return m_ListContent.find(ListIndex(section,index))->second; } unsigned int ListSize(std::string const& section) const { if (m_ListSize.find(ListSizeIndex(section)) == m_ListSize.end()) - throw "Entry doesn't exist " + section; + { + throw std::string("Entry doesn't exist " + section); + } return m_ListSize.find(ListSizeIndex(section))->second; } unsigned int ListSize(std::string const& section, unsigned int standard) { try { return ListSize(ListSizeIndex(section)); } catch (...) { m_ListSize[ListSizeIndex(section)] = standard; return standard; } } }; #endif \ No newline at end of file diff --git a/Modules/Classification/CLImportanceWeighting/src/mitkGeneralizedLinearModel.cpp b/Modules/Classification/CLImportanceWeighting/src/mitkGeneralizedLinearModel.cpp index 2e5f0c7476..d8f64013ad 100644 --- a/Modules/Classification/CLImportanceWeighting/src/mitkGeneralizedLinearModel.cpp +++ b/Modules/Classification/CLImportanceWeighting/src/mitkGeneralizedLinearModel.cpp @@ -1,283 +1,283 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include #include #include #include #include #include #include #include static void _UpdateXMatrix(const vnl_matrix &xData, bool addConstant, v3p_netlib_doublereal *x); static void _UpdatePermXMatrix(const vnl_matrix &xData, bool addConstant, const vnl_vector &permutation, vnl_matrix &x); static void _InitMuEta(mitk::DistSimpleBinominal *dist, mitk::LogItLinking *link, const vnl_vector &yData, vnl_vector &mu, vnl_vector &eta); static void _FinalizeBVector(vnl_vector &b, vnl_vector &perm, int cols); double mitk::GeneralizedLinearModel::Predict( const vnl_vector &c) { LogItLinking link; double mu = 0; int cols = m_B.size(); for (int i = 0; i < cols; ++i) { if (!m_AddConstantColumn) mu += c(i)* m_B(i); else if ( i == 0) mu += 1* m_B(i); else mu += c(i-1)*m_B(i); } return link.InverseLink(mu); } vnl_vector mitk::GeneralizedLinearModel::Predict(const vnl_matrix &x) { LogItLinking link; vnl_vector mu(x.rows()); int cols = m_B.size(); for (unsigned int r = 0 ; r < mu.size(); ++r) { mu(r) = 0; for (int c = 0; c < cols; ++c) { if (!m_AddConstantColumn) mu(r) += x(r,c)*m_B(c); else if ( c == 0) mu(r) += m_B(c); else mu(r) += x(r,c-1)*m_B(c); } mu(r) = link.InverseLink(mu(r)); } return mu; } vnl_vector mitk::GeneralizedLinearModel::B() { return m_B; } vnl_vector mitk::GeneralizedLinearModel::ExpMu(const vnl_matrix &x) { LogItLinking link; vnl_vector mu(x.rows()); int cols = m_B.size(); for (unsigned int r = 0 ; r < mu.size(); ++r) { mu(r) = 0; for (int c = 0; c < cols; ++c) { if (!m_AddConstantColumn) mu(r) += x(r,c)*m_B(c); else if ( c == 0) mu(r) += m_B(c); else mu(r) += x(r,c-1)*m_B(c); } mu(r) = exp(-mu(r)); } return mu; } mitk::GeneralizedLinearModel::GeneralizedLinearModel(const vnl_matrix &xData, const vnl_vector &yData, bool addConstantColumn) : m_AddConstantColumn(addConstantColumn) { EstimatePermutation(xData); DistSimpleBinominal dist; LogItLinking link; vnl_matrix x; int rows = xData.rows(); int cols = m_Permutation.size(); vnl_vector mu(rows); vnl_vector eta(rows); vnl_vector weightedY(rows); vnl_matrix weightedX(rows, cols); vnl_vector oldB(cols); _UpdatePermXMatrix(xData, m_AddConstantColumn, m_Permutation, x); _InitMuEta(&dist, &link, yData, mu, eta); int iter = 0; int iterLimit = 100; double sqrtEps = sqrt(std::numeric_limits::epsilon()); double convertCriterion =1e-6; m_B.set_size(m_Permutation.size()); m_B.fill(0); while (iter <= iterLimit) { ++iter; oldB = m_B; - // Do Row-wise operation. No Vector oepration at this point. + // Do Row-wise operation. No Vector operation at this point. for (int r = 0; r qr(weightedX); m_B = qr.solve(weightedY); eta = x * m_B; for (int r = 0; r < rows; ++r) { mu(r) = link.InverseLink(eta(r)); } bool stayInLoop = false; for(int c= 0; c < cols; ++c) { stayInLoop |= std::abs( m_B(c) - oldB(c)) > convertCriterion * std::max(sqrtEps, std::abs(oldB(c))); } if (!stayInLoop) break; } _FinalizeBVector(m_B, m_Permutation, xData.cols()); } void mitk::GeneralizedLinearModel::EstimatePermutation(const vnl_matrix &xData) { v3p_netlib_integer rows = xData.rows(); v3p_netlib_integer cols = xData.cols(); if (m_AddConstantColumn) ++cols; v3p_netlib_doublereal *x = new v3p_netlib_doublereal[rows* cols]; _UpdateXMatrix(xData, m_AddConstantColumn, x); v3p_netlib_doublereal *qraux = new v3p_netlib_doublereal[cols]; v3p_netlib_integer *jpvt = new v3p_netlib_integer[cols]; std::fill_n(jpvt,cols,0); v3p_netlib_doublereal *work = new v3p_netlib_doublereal[cols]; std::fill_n(work,cols,0); v3p_netlib_integer job = 16; // Make a call to Lapack-DQRDC which does QR with permutation // Permutation is saved in JPVT. v3p_netlib_dqrdc_(x, &rows, &rows, &cols, qraux, jpvt, work, &job); double limit = std::abs(x[0]) * std::max(cols, rows) * std::numeric_limits::epsilon(); // Calculate the rank of the matrix int m_Rank = 0; for (int i = 0; i limit) ? 1 : 0; } // Create a permutation vector m_Permutation.set_size(m_Rank); for (int i = 0; i < m_Rank; ++i) { m_Permutation(i) = jpvt[i]-1; } delete[] x; delete[] qraux; delete[] jpvt; delete[] work; } // Copy a vnl-matrix to an c-array with row-wise representation. // Adds a constant column if required. static void _UpdateXMatrix(const vnl_matrix &xData, bool addConstant, v3p_netlib_doublereal *x) { v3p_netlib_integer rows = xData.rows(); v3p_netlib_integer cols = xData.cols(); if (addConstant) ++cols; for (int r=0; r < rows; ++r) { for (int c=0; c &xData, bool addConstant, const vnl_vector &permutation, vnl_matrix &x) { int rows = xData.rows(); int cols = permutation.size(); x.set_size(rows, cols); for (int r=0; r < rows; ++r) { for (int c=0; c &yData, vnl_vector &mu, vnl_vector &eta) { int rows = yData.size(); mu.set_size(rows); eta.set_size(rows); for (int r = 0; r < rows; ++r) { mu(r) = dist->Init(yData(r)); eta(r) = link->Link(mu(r)); } } // Inverts the permutation on a given b-vector. // Necessary to get a b-vector that match the original data static void _FinalizeBVector(vnl_vector &b, vnl_vector &perm, int cols) { vnl_vector tempB(cols+1); tempB.fill(0); for (unsigned int c = 0; c < perm.size(); ++c) { tempB(perm(c)) = b(c); } b = tempB; } diff --git a/Modules/Classification/CLMRUtilities/src/MRNormalization/mitkMRNormLinearStatisticBasedFilter.cpp b/Modules/Classification/CLMRUtilities/src/MRNormalization/mitkMRNormLinearStatisticBasedFilter.cpp index f6894b5700..7d4e7051da 100644 --- a/Modules/Classification/CLMRUtilities/src/MRNormalization/mitkMRNormLinearStatisticBasedFilter.cpp +++ b/Modules/Classification/CLMRUtilities/src/MRNormalization/mitkMRNormLinearStatisticBasedFilter.cpp @@ -1,147 +1,147 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include "mitkMRNormLinearStatisticBasedFilter.h" #include "mitkImageToItk.h" #include "mitkImageAccessByItk.h" #include "itkImageRegionIterator.h" // MITK #include #include #include // ITK #include #include mitk::MRNormLinearStatisticBasedFilter::MRNormLinearStatisticBasedFilter() : m_CenterMode(MRNormLinearStatisticBasedFilter::MEDIAN) { - this->SetNumberOfIndexedInputs(3); - this->SetNumberOfRequiredInputs(3); + this->SetNumberOfIndexedInputs(2); + this->SetNumberOfRequiredInputs(1); } mitk::MRNormLinearStatisticBasedFilter::~MRNormLinearStatisticBasedFilter() { } void mitk::MRNormLinearStatisticBasedFilter::SetMask( const mitk::Image* mask ) { // Process object is not const-correct so the const_cast is required here Image* nonconstMask = const_cast< mitk::Image * >( mask ); this->SetNthInput(1, nonconstMask ); } const mitk::Image* mitk::MRNormLinearStatisticBasedFilter::GetMask() const { return this->GetInput(1); } void mitk::MRNormLinearStatisticBasedFilter::GenerateInputRequestedRegion() { Superclass::GenerateInputRequestedRegion(); mitk::Image* input = const_cast< mitk::Image * > ( this->GetInput() ); input->SetRequestedRegionToLargestPossibleRegion(); } void mitk::MRNormLinearStatisticBasedFilter::GenerateOutputInformation() { mitk::Image::ConstPointer input = this->GetInput(); mitk::Image::Pointer output = this->GetOutput(); - itkDebugMacro(<<"GenerateOutputInformation()"); + itkDebugMacro(<< "GenerateOutputInformation()"); output->Initialize(input->GetPixelType(), *input->GetTimeGeometry()); output->SetPropertyList(input->GetPropertyList()->Clone()); } template < typename TPixel, unsigned int VImageDimension > void mitk::MRNormLinearStatisticBasedFilter::InternalComputeMask(itk::Image* itkImage) { // Define all necessary Types typedef itk::Image ImageType; typedef itk::Image MaskType; typedef itk::LabelStatisticsImageFilter FilterType; typedef itk::MinimumMaximumImageCalculator MinMaxComputerType; typename MaskType::Pointer itkMask0 = MaskType::New(); mitk::CastToItkImage(this->GetMask(), itkMask0); typename ImageType::Pointer outImage = ImageType::New(); mitk::CastToItkImage(this->GetOutput(0), outImage); typename MinMaxComputerType::Pointer minMaxComputer = MinMaxComputerType::New(); minMaxComputer->SetImage(itkImage); minMaxComputer->Compute(); typename FilterType::Pointer labelStatisticsImageFilter = FilterType::New(); labelStatisticsImageFilter->SetUseHistograms(true); labelStatisticsImageFilter->SetHistogramParameters(256, minMaxComputer->GetMinimum(),minMaxComputer->GetMaximum()); labelStatisticsImageFilter->SetInput( itkImage ); labelStatisticsImageFilter->SetLabelInput(itkMask0); labelStatisticsImageFilter->Update(); double median0 = labelStatisticsImageFilter->GetMedian(1); double mean0 = labelStatisticsImageFilter->GetMean(1); double stddev = labelStatisticsImageFilter->GetSigma(1); double modulo0=0; { auto histo = labelStatisticsImageFilter->GetHistogram(1); double maxFrequency=0; for (auto hIter=histo->Begin();hIter!=histo->End();++hIter) { if (maxFrequency < hIter.GetFrequency()) { maxFrequency = hIter.GetFrequency(); modulo0 = (histo->GetBinMin(0,hIter.GetInstanceIdentifier()) + histo->GetBinMax(0,hIter.GetInstanceIdentifier())) / 2.0; } } } double value0=0; switch (m_CenterMode) { case MRNormLinearStatisticBasedFilter::MEAN: value0=mean0; break; case MRNormLinearStatisticBasedFilter::MEDIAN: value0=median0; break; case MRNormLinearStatisticBasedFilter::MODE: value0=modulo0; break; } double offset = value0; double scaling = stddev; if (scaling < 0.0001) return; itk::ImageRegionIterator inIter(itkImage, itkImage->GetLargestPossibleRegion()); itk::ImageRegionIterator outIter(outImage, outImage->GetLargestPossibleRegion()); while (! inIter.IsAtEnd()) { TPixel value = inIter.Value(); outIter.Set((value - offset) / scaling); ++inIter; ++outIter; } } void mitk::MRNormLinearStatisticBasedFilter::GenerateData() { AccessByItk(GetInput(0),InternalComputeMask); } \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLGlobalImageFeatures.cpp b/Modules/Classification/CLMiniApps/CLGlobalImageFeatures.cpp index 4ea88341b6..b3ff9ccce2 100644 --- a/Modules/Classification/CLMiniApps/CLGlobalImageFeatures.cpp +++ b/Modules/Classification/CLMiniApps/CLGlobalImageFeatures.cpp @@ -1,214 +1,307 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkCLPolyToNrrd_cpp #define mitkCLPolyToNrrd_cpp #include "time.h" #include #include #include #include "mitkCommandLineParser.h" #include #include #include #include +#include +#include +#include + + +#include "itkNearestNeighborInterpolateImageFunction.h" +#include "itkResampleImageFilter.h" + + typedef itk::Image< double, 3 > FloatImageType; typedef itk::Image< unsigned char, 3 > MaskImageType; static std::vector splitDouble(std::string str, char delimiter) { std::vector internal; std::stringstream ss(str); // Turn the string into a stream. std::string tok; double val; while (std::getline(ss, tok, delimiter)) { std::stringstream s2(tok); s2 >> val; internal.push_back(val); } return internal; } +template +void +ResampleImage(itk::Image* itkImage, float resolution, mitk::Image::Pointer& newImage) +{ + typedef itk::Image ImageType; + typedef itk::ResampleImageFilter ResampleFilterType; + + typename ResampleFilterType::Pointer resampler = ResampleFilterType::New(); + auto spacing = itkImage->GetSpacing(); + auto size = itkImage->GetLargestPossibleRegion().GetSize(); + + for (int i = 0; i < VImageDimension; ++i) + { + size[i] = size[i] / (1.0*resolution)*(1.0*spacing[i])+1.0; + } + spacing.Fill(resolution); + + resampler->SetInput(itkImage); + resampler->SetSize(size); + resampler->SetOutputSpacing(spacing); + resampler->SetOutputOrigin(itkImage->GetOrigin()); + resampler->SetOutputDirection(itkImage->GetDirection()); + resampler->Update(); + + newImage->InitializeByItk(resampler->GetOutput()); + mitk::GrabItkImageMemory(resampler->GetOutput(), newImage); +} + + +static void +ResampleMask(mitk::Image::Pointer mask, mitk::Image::Pointer ref, mitk::Image::Pointer& newMask) +{ + typedef itk::NearestNeighborInterpolateImageFunction< MaskImageType> NearestNeighborInterpolateImageFunctionType; + typedef itk::ResampleImageFilter ResampleFilterType; + + NearestNeighborInterpolateImageFunctionType::Pointer nn_interpolator = NearestNeighborInterpolateImageFunctionType::New(); + MaskImageType::Pointer itkMoving = MaskImageType::New(); + MaskImageType::Pointer itkRef = MaskImageType::New(); + mitk::CastToItkImage(mask, itkMoving); + mitk::CastToItkImage(ref, itkRef); + + + ResampleFilterType::Pointer resampler = ResampleFilterType::New(); + resampler->SetInput(itkMoving); + resampler->SetReferenceImage(itkRef); + resampler->UseReferenceImageOn(); + resampler->SetInterpolator(nn_interpolator); + resampler->Update(); + + newMask->InitializeByItk(resampler->GetOutput()); + mitk::GrabItkImageMemory(resampler->GetOutput(), newMask); +} + int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setArgumentPrefix("--", "-"); // required params parser.addArgument("image", "i", mitkCommandLineParser::InputImage, "Input Image", "Path to the input VTK polydata", us::Any(), false); parser.addArgument("mask", "m", mitkCommandLineParser::InputImage, "Input Mask", "Mask Image that specifies the area over for the statistic, (Values = 1)", us::Any(), false); parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output text file", "Target file. The output statistic is appended to this file.", us::Any(), false); parser.addArgument("cooccurence","cooc",mitkCommandLineParser::String, "Use Co-occurence matrix", "calculates Co-occurence based features",us::Any()); parser.addArgument("volume","vol",mitkCommandLineParser::String, "Use Volume-Statistic", "calculates volume based features",us::Any()); parser.addArgument("run-length","rl",mitkCommandLineParser::String, "Use Co-occurence matrix", "calculates Co-occurence based features",us::Any()); parser.addArgument("first-order","fo",mitkCommandLineParser::String, "Use First Order Features", "calculates First order based features",us::Any()); parser.addArgument("header","head",mitkCommandLineParser::String,"Add Header (Labels) to output","",us::Any()); parser.addArgument("description","d",mitkCommandLineParser::String,"Text","Description that is added to the output",us::Any()); parser.addArgument("same-space", "sp", mitkCommandLineParser::String, "Bool", "Set the spacing of all images to equal. Otherwise an error will be thrown. ", us::Any()); + parser.addArgument("resample-mask", "rm", mitkCommandLineParser::Bool, "Bool", "Resamples the mask to the resolution of the input image ", us::Any()); + parser.addArgument("save-resample-mask", "srm", mitkCommandLineParser::String, "String", "If specified the resampled mask is saved to this path (if -rm is 1)", us::Any()); + parser.addArgument("fixed-isotropic", "fi", mitkCommandLineParser::Float, "Float", "Input image resampled to fixed isotropic resolution given in mm. Should be used with resample-mask ", us::Any()); parser.addArgument("direction", "dir", mitkCommandLineParser::String, "Int", "Allows to specify the direction for Cooc and RL. 0: All directions, 1: Only single direction (Test purpose), 2,3,4... Without dimension 0,1,2... ", us::Any()); // Miniapp Infos parser.setCategory("Classification Tools"); parser.setTitle("Global Image Feature calculator"); parser.setDescription("Calculates different global statistics for a given segmentation / image combination"); parser.setContributor("MBI"); std::map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) { return EXIT_FAILURE; } if ( parsedArgs.count("help") || parsedArgs.count("h")) { return EXIT_SUCCESS; } - MITK_INFO << "Version: "<< 1.3; + MITK_INFO << "Version: "<< 1.7; - bool useCooc = parsedArgs.count("cooccurence"); + //bool useCooc = parsedArgs.count("cooccurence"); + + bool resampleMask = false; + if (parsedArgs.count("resample-mask")) + { + resampleMask = us::any_cast(parsedArgs["resample-mask"]); + } mitk::Image::Pointer image = mitk::IOUtil::LoadImage(parsedArgs["image"].ToString()); mitk::Image::Pointer mask = mitk::IOUtil::LoadImage(parsedArgs["mask"].ToString()); + if (parsedArgs.count("fixed-isotropic")) + { + mitk::Image::Pointer newImage = mitk::Image::New(); + float resolution = us::any_cast(parsedArgs["fixed-isotropic"]); + AccessByItk_2(image, ResampleImage, resolution, newImage); + image = newImage; + } + + if (resampleMask) + { + mitk::Image::Pointer newMaskImage = mitk::Image::New(); + ResampleMask(mask, image, newMaskImage); + mask = newMaskImage; + if (parsedArgs.count("save-resample-mask")) + { + mitk::IOUtil::SaveImage(mask, parsedArgs["save-resample-mask"].ToString()); + } + } + + bool fixDifferentSpaces = parsedArgs.count("same-space"); if ( ! mitk::Equal(mask->GetGeometry(0)->GetOrigin(), image->GetGeometry(0)->GetOrigin())) { MITK_INFO << "Not equal Origins"; if (fixDifferentSpaces) { image->GetGeometry(0)->SetOrigin(mask->GetGeometry(0)->GetOrigin()); } else { return -1; } } if ( ! mitk::Equal(mask->GetGeometry(0)->GetSpacing(), image->GetGeometry(0)->GetSpacing())) { MITK_INFO << "Not equal Sapcings"; if (fixDifferentSpaces) { image->GetGeometry(0)->SetSpacing(mask->GetGeometry(0)->GetSpacing()); } else { return -1; } } int direction = 0; if (parsedArgs.count("direction")) { direction = splitDouble(parsedArgs["direction"].ToString(), ';')[0]; } mitk::AbstractGlobalImageFeature::FeatureListType stats; //////////////////////////////////////////////////////////////// - // CAlculate First Order Features + // Calculate First Order Features //////////////////////////////////////////////////////////////// if (parsedArgs.count("first-order")) { MITK_INFO << "Start calculating first order statistics...."; mitk::GIFFirstOrderStatistics::Pointer firstOrderCalculator = mitk::GIFFirstOrderStatistics::New(); auto localResults = firstOrderCalculator->CalculateFeatures(image, mask); stats.insert(stats.end(), localResults.begin(), localResults.end()); MITK_INFO << "Finished calculating first order statistics...."; } //////////////////////////////////////////////////////////////// - // CAlculate Volume based Features + // Calculate Volume based Features //////////////////////////////////////////////////////////////// if (parsedArgs.count("volume")) { MITK_INFO << "Start calculating volumetric ...."; mitk::GIFVolumetricStatistics::Pointer volCalculator = mitk::GIFVolumetricStatistics::New(); auto localResults = volCalculator->CalculateFeatures(image, mask); stats.insert(stats.end(), localResults.begin(), localResults.end()); MITK_INFO << "Finished calculating volumetric...."; } //////////////////////////////////////////////////////////////// - // CAlculate Co-occurence Features + // Calculate Co-occurence Features //////////////////////////////////////////////////////////////// if (parsedArgs.count("cooccurence")) { auto ranges = splitDouble(parsedArgs["cooccurence"].ToString(),';'); - for (int i = 0; i < ranges.size(); ++i) + for (std::size_t i = 0; i < ranges.size(); ++i) { MITK_INFO << "Start calculating coocurence with range " << ranges[i] << "...."; mitk::GIFCooccurenceMatrix::Pointer coocCalculator = mitk::GIFCooccurenceMatrix::New(); coocCalculator->SetRange(ranges[i]); coocCalculator->SetDirection(direction); auto localResults = coocCalculator->CalculateFeatures(image, mask); stats.insert(stats.end(), localResults.begin(), localResults.end()); MITK_INFO << "Finished calculating coocurence with range " << ranges[i] << "...."; } } //////////////////////////////////////////////////////////////// - // CAlculate Run-Length Features + // Calculate Run-Length Features //////////////////////////////////////////////////////////////// if (parsedArgs.count("run-length")) { auto ranges = splitDouble(parsedArgs["run-length"].ToString(),';'); - for (int i = 0; i < ranges.size(); ++i) + for (std::size_t i = 0; i < ranges.size(); ++i) { MITK_INFO << "Start calculating run-length with number of bins " << ranges[i] << "...."; mitk::GIFGrayLevelRunLength::Pointer calculator = mitk::GIFGrayLevelRunLength::New(); calculator->SetRange(ranges[i]); auto localResults = calculator->CalculateFeatures(image, mask); stats.insert(stats.end(), localResults.begin(), localResults.end()); MITK_INFO << "Finished calculating run-length with number of bins " << ranges[i] << "...."; } } - for (int i = 0; i < stats.size(); ++i) + for (std::size_t i = 0; i < stats.size(); ++i) { std::cout << stats[i].first << " - " << stats[i].second < +#include + +#define CONVERT_IMAGE(TYPE, DIM) itk::Image::Pointer itkImage = itk::Image::New(); \ + MITK_INFO << "Data Type for Conversion: "<< typeid(TYPE).name(); \ + mitk::CastToItkImage(image, itkImage); \ + mitk::CastToMitkImage(itkImage, outputImage) + +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + + parser.setTitle("Image Type Converter"); + parser.setCategory("Preprocessing Tools"); + parser.setDescription(""); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--","-"); + // Add command line argument names + parser.addArgument("help", "h",mitkCommandLineParser::Bool, "Help:", "Show this help text"); + parser.addArgument("input", "i", mitkCommandLineParser::InputDirectory, "Input file:", "Input file",us::Any(),false); + parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output file:", "Output file", us::Any(), false); + parser.addArgument("type", "t", mitkCommandLineParser::OutputFile, "Type definition:", "Define Scalar data type: int, uint, short, ushort, char, uchar, float, double", us::Any(), false); + + map parsedArgs = parser.parseArguments(argc, argv); + + if (parsedArgs.size()==0) + return EXIT_FAILURE; + + // Show a help message + if ( parsedArgs.count("help") || parsedArgs.count("h")) + { + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + + std::string inputName = us::any_cast(parsedArgs["input"]); + std::string outputName = us::any_cast(parsedArgs["output"]); + std::string type = us::any_cast(parsedArgs["type"]); + + mitk::Image::Pointer image = mitk::IOUtil::LoadImage(inputName); + mitk::Image::Pointer outputImage = mitk::Image::New(); + + if (type.compare("int") == 0) + { + CONVERT_IMAGE(int, 3); + } + else if (type.compare("uint") == 0) + { + CONVERT_IMAGE(unsigned int, 3); + } + else if (type.compare("char") == 0) + { + CONVERT_IMAGE(char, 3); + } + else if (type.compare("uchar") == 0) + { + CONVERT_IMAGE(unsigned char, 3); + } + else if (type.compare("short") == 0) + { + CONVERT_IMAGE(short, 3); + } + else if (type.compare("ushort") == 0) + { + CONVERT_IMAGE(unsigned short, 3); + } + else if (type.compare("float") == 0) + { + CONVERT_IMAGE(float, 3); + } + else if (type.compare("double") == 0) + { + CONVERT_IMAGE(double, 3); + } + else + { + CONVERT_IMAGE(double, 3); + } + + + mitk::IOUtil::SaveImage(outputImage, outputName); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLMRNormalization.cpp b/Modules/Classification/CLMiniApps/CLMRNormalization.cpp index 755f5d6ed3..47ad30bf60 100644 --- a/Modules/Classification/CLMiniApps/CLMRNormalization.cpp +++ b/Modules/Classification/CLMiniApps/CLMRNormalization.cpp @@ -1,133 +1,134 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkCLPolyToNrrd_cpp #define mitkCLPolyToNrrd_cpp #include "time.h" #include #include #include #include "mitkCommandLineParser.h" #include "itkImageRegionIterator.h" // MITK #include #include #include #include #include // ITK #include #include typedef itk::Image< double, 3 > FloatImageType; typedef itk::Image< unsigned char, 3 > MaskImageType; int main(int argc, char* argv[]) { MITK_INFO << "Start"; mitkCommandLineParser parser; parser.setArgumentPrefix("--", "-"); // required params parser.addArgument("image", "i", mitkCommandLineParser::InputImage, "Input Image", "Path to the input VTK polydata", us::Any(), false); parser.addArgument("mode", "mode", mitkCommandLineParser::InputImage, "Normalisation mode", "1,2,3: Single Area normalization to Mean, Median, Mode, 4,5,6: Mean, Median, Mode of two regions. ", us::Any(), false); parser.addArgument("mask0", "m0", mitkCommandLineParser::InputImage, "Input Mask", "The median of the area covered by this mask will be set to 0", us::Any(), false); parser.addArgument("mask1", "m1", mitkCommandLineParser::InputImage, "Input Mask", "The median of the area covered by this mask will be set to 1", us::Any(), true); parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output Image", "Target file. The output statistic is appended to this file.", us::Any(), false); // Miniapp Infos parser.setCategory("Classification Tools"); parser.setTitle("MR Normalization Tool"); parser.setDescription("Normalizes a MR image. Sets the Median of the tissue covered by mask 0 to 0 and the median of the area covered by mask 1 to 1."); parser.setContributor("MBI"); std::map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) { return EXIT_FAILURE; } if ( parsedArgs.count("help") || parsedArgs.count("h")) { return EXIT_SUCCESS; } MITK_INFO << "Mode access"; - int mode = 5;//us::any_cast(parsedArgs["mode"]); + int mode =stoi(us::any_cast(parsedArgs["mode"])); + MITK_INFO << "Mode: " << mode; MITK_INFO << "Read images"; mitk::Image::Pointer mask1; mitk::Image::Pointer image = mitk::IOUtil::LoadImage(parsedArgs["image"].ToString()); mitk::Image::Pointer mask0 = mitk::IOUtil::LoadImage(parsedArgs["mask0"].ToString()); if (mode > 3) { mask1 = mitk::IOUtil::LoadImage(parsedArgs["mask1"].ToString()); } mitk::MRNormLinearStatisticBasedFilter::Pointer oneRegion = mitk::MRNormLinearStatisticBasedFilter::New(); mitk::MRNormTwoRegionsBasedFilter::Pointer twoRegion = mitk::MRNormTwoRegionsBasedFilter::New(); mitk::Image::Pointer output; - //oneRegion->SetInput(image); + oneRegion->SetInput(image); + oneRegion->SetMask(mask0); twoRegion->SetInput(image); - //oneRegion->SetMask(mask0); twoRegion->SetMask1(mask0); twoRegion->SetMask2(mask1); switch (mode) { case 1: oneRegion->SetCenterMode(mitk::MRNormLinearStatisticBasedFilter::MEAN); oneRegion->Update(); output=oneRegion->GetOutput(); break; case 2: oneRegion->SetCenterMode(mitk::MRNormLinearStatisticBasedFilter::MEDIAN); oneRegion->Update(); output=oneRegion->GetOutput(); break; case 3: oneRegion->SetCenterMode(mitk::MRNormLinearStatisticBasedFilter::MODE); oneRegion->Update(); output=oneRegion->GetOutput(); break; case 4: twoRegion->SetArea1(mitk::MRNormTwoRegionsBasedFilter::MEAN); twoRegion->SetArea2(mitk::MRNormTwoRegionsBasedFilter::MEAN); twoRegion->Update(); output=twoRegion->GetOutput(); break; case 5: twoRegion->SetArea1(mitk::MRNormTwoRegionsBasedFilter::MEDIAN); twoRegion->SetArea2(mitk::MRNormTwoRegionsBasedFilter::MEDIAN); twoRegion->Update(); output=twoRegion->GetOutput(); break; case 6: twoRegion->SetArea1(mitk::MRNormTwoRegionsBasedFilter::MODE); twoRegion->SetArea2(mitk::MRNormTwoRegionsBasedFilter::MODE); twoRegion->Update(); output=twoRegion->GetOutput(); break; } mitk::IOUtil::SaveImage(output, parsedArgs["output"].ToString()); return 0; } #endif \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLMultiForestPrediction.cpp b/Modules/Classification/CLMiniApps/CLMultiForestPrediction.cpp new file mode 100644 index 0000000000..38e1c8bbbe --- /dev/null +++ b/Modules/Classification/CLMiniApps/CLMultiForestPrediction.cpp @@ -0,0 +1,253 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ +#ifndef mitkForest_cpp +#define mitkForest_cpp + +#include "time.h" +#include + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +// ----------------------- Forest Handling ---------------------- +#include + + +int main(int argc, char* argv[]) +{ + MITK_INFO << "Starting MITK_Forest Mini-App"; + double startTime = time(0); + + ////////////////////////////////////////////////////////////////////////////// + // Read Console Input Parameter + ////////////////////////////////////////////////////////////////////////////// + ConfigFileReader allConfig(argv[1]); + + bool readFile = true; + std::stringstream ss; + for (int i = 0; i < argc; ++i ) + { + MITK_INFO << "-----"<< argv[i]<<"------"; + if (readFile) + { + if (argv[i][0] == '+') + { + readFile = false; + continue; + } else + { + try + { + allConfig.ReadFile(argv[i]); + } + catch (std::exception &e) + { + MITK_INFO << e.what(); + } + } + } + else + { + std::string input = argv[i]; + std::replace(input.begin(), input.end(),'_',' '); + ss << input << std::endl; + } + } + allConfig.ReadStream(ss); + + try + { + ////////////////////////////////////////////////////////////////////////////// + // General + ////////////////////////////////////////////////////////////////////////////// + int currentRun = allConfig.IntValue("General","Run",0); + int doTraining = allConfig.IntValue("General","Do Training",1); + std::string forestPath = allConfig.Value("General","Forest Path"); + std::string trainingCollectionPath = allConfig.Value("General","Patient Collection"); + std::string testCollectionPath = allConfig.Value("General", "Patient Test Collection", trainingCollectionPath); + + ////////////////////////////////////////////////////////////////////////////// + // Read Default Classification + ////////////////////////////////////////////////////////////////////////////// + std::vector trainPatients = allConfig.Vector("Training Group",currentRun); + std::vector testPatients = allConfig.Vector("Test Group",currentRun); + std::vector modalities = allConfig.Vector("Modalities", 0); + std::vector outputFilter = allConfig.Vector("Output Filter", 0); + std::string trainMask = allConfig.Value("Data","Training Mask"); + std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask"); + std::string testMask = allConfig.Value("Data","Test Mask"); + std::string resultMask = allConfig.Value("Data", "Result Mask"); + std::string resultProb = allConfig.Value("Data", "Result Propability"); + std::string outputFolder = allConfig.Value("General","Output Folder"); + + std::string writeDataFilePath = allConfig.Value("Forest","File to write data to"); + + ////////////////////////////////////////////////////////////////////////////// + // Read Data Forest Parameter + ////////////////////////////////////////////////////////////////////////////// + int testSingleDataset = allConfig.IntValue("Data", "Test Single Dataset",0); + std::string singleDatasetName = allConfig.Value("Data", "Single Dataset Name", "none"); + std::vector forestVector = allConfig.Vector("Forests", 0); + + ////////////////////////////////////////////////////////////////////////////// + // Read Statistic Parameter + ////////////////////////////////////////////////////////////////////////////// + std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file"); + std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file"); + std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file"); + std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV"); + bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0); + std::vector labelGroupA = allConfig.Vector("LabelsA",0); + std::vector labelGroupB = allConfig.Vector("LabelsB",0); + + + std::ofstream timingFile; + timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app); + timingFile << statisticShortFileLabel << ";"; + std::time_t lastTimePoint; + time(&lastTimePoint); + + ////////////////////////////////////////////////////////////////////////////// + // Read Images + ////////////////////////////////////////////////////////////////////////////// + std::vector usedModalities; + for (int i = 0; i < modalities.size(); ++i) + { + usedModalities.push_back(modalities[i]); + } + usedModalities.push_back(trainMask); + usedModalities.push_back(completeTrainMask); + usedModalities.push_back(testMask); + usedModalities.push_back(statisticGoldStandard); + + // vtkSmartPointer colReader = vtkSmartPointer::New(); + mitk::CollectionReader* colReader = new mitk::CollectionReader(); + colReader->AddDataElementIds(trainPatients); + colReader->SetDataItemNames(usedModalities); + + if (testSingleDataset > 0) + { + testPatients.clear(); + testPatients.push_back(singleDatasetName); + } + colReader->ClearDataElementIds(); + colReader->AddDataElementIds(testPatients); + mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath); + + std::time_t now; + time(&now); + double seconds = std::difftime(now, lastTimePoint); + timingFile << seconds << ";"; + time(&lastTimePoint); + + + mitk::VigraRandomForestClassifier::Pointer forest = mitk::VigraRandomForestClassifier::New(); + MITK_INFO << "Convert Test data"; + auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection, modalities, testMask); + + for (int i = 0; i < forestVector.size(); ++i) + { + forest = dynamic_cast(mitk::IOUtil::Load(forestVector[i])[0].GetPointer()); + + time(&now); + seconds = std::difftime(now, lastTimePoint); + MITK_INFO << "Duration for Training: " << seconds; + timingFile << seconds << ";"; + time(&lastTimePoint); + + MITK_INFO << "Predict Test Data"; + auto testDataNewY = forest->Predict(testDataX); + auto testDataNewProb = forest->GetPointWiseProbabilities(); + + auto maxClassValue = testDataNewProb.cols(); + std::vector names; + for (int j = 0; j < maxClassValue; ++j) + { + std::string name = resultProb + std::to_string(j); + names.push_back(name); + } + + mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask); + mitk::DCUtilities::MatrixToDC3d(testDataNewProb, testCollection, names, testMask); + MITK_INFO << "Converted predicted data"; + + time(&now); + seconds = std::difftime(now, lastTimePoint); + timingFile << seconds << ";"; + time(&lastTimePoint); + + ////////////////////////////////////////////////////////////////////////////// + // Save results to folder + ////////////////////////////////////////////////////////////////////////////// + MITK_INFO << "Write Result to HDD"; + mitk::CollectionWriter::ExportCollectionToFolder(testCollection, + outputFolder + "/result_collection.xml", + outputFilter); + + MITK_INFO << "Calculate Statistic...."; + ////////////////////////////////////////////////////////////////////////////// + // Calculate and Print Statistic + ////////////////////////////////////////////////////////////////////////////// + std::ofstream statisticFile; + statisticFile.open(statisticFilePath.c_str(), std::ios::app); + std::ofstream sstatisticFile; + sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app); + + mitk::CollectionStatistic stat; + stat.SetCollection(testCollection); + stat.SetClassCount(5); + stat.SetGoldName(statisticGoldStandard); + stat.SetTestName(resultMask); + stat.SetMaskName(testMask); + mitk::BinaryValueminusOneToIndexMapper* mapper = new mitk::BinaryValueminusOneToIndexMapper; + stat.SetGroundTruthValueToIndexMapper(mapper); + stat.SetTestValueToIndexMapper(mapper); + stat.Update(); + //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel); + stat.Print(statisticFile, sstatisticFile, true, statisticShortFileLabel + "_"+std::to_string(i)); + statisticFile.close(); + delete mapper; + + time(&now); + seconds = std::difftime(now, lastTimePoint); + timingFile << seconds << std::endl; + time(&lastTimePoint); + timingFile.close(); + } + } + catch (std::string s) + { + MITK_INFO << s; + return 0; + } + catch (char* s) + { + MITK_INFO << s; + } + + return 0; +} + +#endif \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLN4.cpp b/Modules/Classification/CLMiniApps/CLN4.cpp new file mode 100644 index 0000000000..837ab399fd --- /dev/null +++ b/Modules/Classification/CLMiniApps/CLN4.cpp @@ -0,0 +1,112 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include "mitkCommandLineParser.h" +#include "mitkIOUtil.h" +#include +#include "mitkCommandLineParser.h" +#include + +#include + +int main(int argc, char* argv[]) +{ + typedef itk::Image MaskImageType; + typedef itk::Image ImageType; + typedef itk::N4BiasFieldCorrectionImageFilter < ImageType, MaskImageType, ImageType > FilterType; + + mitkCommandLineParser parser; + parser.setTitle("N4 Bias Field Correction"); + parser.setCategory("Classification Command Tools"); + parser.setDescription(""); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--", "-"); + // Add command line argument names + parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Help:", "Show this help text"); + parser.addArgument("input", "i", mitkCommandLineParser::InputDirectory, "Input file:", "Input file", us::Any(), false); + parser.addArgument("mask", "m", mitkCommandLineParser::OutputFile, "Output file:", "Mask file", us::Any(), false); + parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output file:", "Output file", us::Any(), false); + + parser.addArgument("number-of-controllpoints", "noc", mitkCommandLineParser::Int, "Parameter", "The noc for the point grid size defining the B-spline estimate (default 4)", us::Any(), true); + parser.addArgument("number-of-fitting-levels", "nofl", mitkCommandLineParser::Int, "Parameter", "Number of fitting levels for the multi-scale approach (default 1)", us::Any(), true); + parser.addArgument("number-of-histogram-bins", "nofl", mitkCommandLineParser::Int, "Parameter", "number of bins defining the log input intensity histogram (default 200)", us::Any(), true); + parser.addArgument("spline-order", "so", mitkCommandLineParser::Int, "Parameter", "Define the spline order (default 3)", us::Any(), true); + parser.addArgument("winer-filter-noise", "wfn", mitkCommandLineParser::Float, "Parameter", "Noise estimate defining the Wiener filter (default 0.01)", us::Any(), true); + + + map parsedArgs = parser.parseArguments(argc, argv); + + // Show a help message + if (parsedArgs.count("help") || parsedArgs.count("h")) + { + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + + MaskImageType::Pointer itkMsk = MaskImageType::New(); + mitk::Image::Pointer img = mitk::IOUtil::LoadImage(parsedArgs["mask"].ToString()); + mitk::CastToItkImage(img, itkMsk); + + ImageType::Pointer itkImage = ImageType::New(); + mitk::Image::Pointer img2 = mitk::IOUtil::LoadImage(parsedArgs["input"].ToString()); + mitk::CastToItkImage(img2, itkImage); + + FilterType::Pointer filter = FilterType::New(); + filter->SetInput(itkImage); + filter->SetMaskImage(itkMsk); + + + + if (parsedArgs.count("number-of-controllpoints") > 0) + { + int variable = us::any_cast(parsedArgs["maximum-iterations"]); + MITK_INFO << "Number of controll points: " << variable; + filter->SetNumberOfControlPoints(variable); + } + if (parsedArgs.count("number-of-fitting-levels") > 0) + { + int variable = us::any_cast(parsedArgs["number-of-fitting-levels"]); + MITK_INFO << "Number of fitting levels: " << variable; + filter->SetNumberOfFittingLevels(variable); + } + if (parsedArgs.count("number-of-histogram-bins") > 0) + { + int variable = us::any_cast(parsedArgs["number-of-histogram-bins"]); + MITK_INFO << "Number of histogram bins: " << variable; + filter->SetNumberOfHistogramBins(variable); + } + if (parsedArgs.count("spline-order") > 0) + { + int variable = us::any_cast(parsedArgs["spline-order"]); + MITK_INFO << "Spline Order " << variable; + filter->SetSplineOrder(variable); + } + if (parsedArgs.count("winer-filter-noise") > 0) + { + float variable = us::any_cast(parsedArgs["winer-filter-noise"]); + MITK_INFO << "Number of histogram bins: " << variable; + filter->SetWienerFilterNoise(variable); + } + + filter->Update(); + auto out = filter->GetOutput(); + mitk::Image::Pointer outImg = mitk::Image::New(); + mitk::CastToMitkImage(out, outImg); + mitk::IOUtil::SaveImage(outImg, parsedArgs["output"].ToString()); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLNrrdToPoly.cpp b/Modules/Classification/CLMiniApps/CLNrrdToPoly.cpp new file mode 100644 index 0000000000..3dead7c22d --- /dev/null +++ b/Modules/Classification/CLMiniApps/CLNrrdToPoly.cpp @@ -0,0 +1,78 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ +#ifndef mitkCLPolyToNrrd_cpp +#define mitkCLPolyToNrrd_cpp + +#include "time.h" +#include +#include + +#include +#include "mitkCommandLineParser.h" + +// VTK +#include +#include +#include + +typedef itk::Image< double, 3 > FloatImageType; +typedef itk::Image< unsigned char, 3 > MaskImageType; + + +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + parser.setArgumentPrefix("--", "-"); + // required params + parser.addArgument("mask", "m", mitkCommandLineParser::InputImage, "Input Mask", "Mask Image that specifies the area over for the statistic, (Values = 1)", us::Any(), false); + parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output text file", "Target file. The output statistic is appended to this file.", us::Any(), false); + + // Miniapp Infos + parser.setCategory("Classification Tools"); + parser.setTitle("Segmentation to Mask"); + parser.setDescription("Estimates a Mesh from a segmentation"); + parser.setContributor("MBI"); + + map parsedArgs = parser.parseArguments(argc, argv); + + if (parsedArgs.size()==0) + { + return EXIT_FAILURE; + } + if ( parsedArgs.count("help") || parsedArgs.count("h")) + { + return EXIT_SUCCESS; + } + + MITK_INFO << "Version: "<< 1.0; + + mitk::Image::Pointer mask = mitk::IOUtil::LoadImage(parsedArgs["mask"].ToString()); + + + vtkSmartPointer image = mask->GetVtkImageData(); + image->SetOrigin(mask->GetGeometry()->GetOrigin()[0], mask->GetGeometry()->GetOrigin()[1], mask->GetGeometry()->GetOrigin()[2]); + vtkSmartPointer mesher = vtkSmartPointer::New(); + mesher->SetInputData(image); + mitk::Surface::Pointer surf = mitk::Surface::New(); + mesher->SetValue(0,0.5); + mesher->Update(); + surf->SetVtkPolyData(mesher->GetOutput()); + mitk::IOUtil::Save(surf, parsedArgs["output"].ToString()); + + return 0; +} + +#endif \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLVoxelClassification.cpp b/Modules/Classification/CLMiniApps/CLPurfVoxelClassification.cpp similarity index 96% copy from Modules/Classification/CLMiniApps/CLVoxelClassification.cpp copy to Modules/Classification/CLMiniApps/CLPurfVoxelClassification.cpp index 708c7556f5..89d359991f 100644 --- a/Modules/Classification/CLMiniApps/CLVoxelClassification.cpp +++ b/Modules/Classification/CLMiniApps/CLPurfVoxelClassification.cpp @@ -1,438 +1,447 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkForest_cpp #define mitkForest_cpp #include "time.h" #include #include #include #include #include #include #include #include #include #include #include // ----------------------- Forest Handling ---------------------- //#include #include //#include //#include //#include //#include // ----------------------- Point weighting ---------------------- //#include //#include //#include #include //#include //#include //#include //#include int main(int argc, char* argv[]) { MITK_INFO << "Starting MITK_Forest Mini-App"; double startTime = time(0); ////////////////////////////////////////////////////////////////////////////// // Read Console Input Parameter ////////////////////////////////////////////////////////////////////////////// - ConfigFileReader allConfig(argv[2]); + ConfigFileReader allConfig(argv[1]); bool readFile = true; std::stringstream ss; for (int i = 0; i < argc; ++i ) { MITK_INFO << "-----"<< argv[i]<<"------"; if (readFile) { if (argv[i][0] == '+') { readFile = false; continue; } else { try { allConfig.ReadFile(argv[i]); } catch (std::exception &e) { MITK_INFO << e.what(); } } } else { std::string input = argv[i]; std::replace(input.begin(), input.end(),'_',' '); ss << input << std::endl; } } allConfig.ReadStream(ss); try { ////////////////////////////////////////////////////////////////////////////// // General ////////////////////////////////////////////////////////////////////////////// int currentRun = allConfig.IntValue("General","Run",0); int doTraining = allConfig.IntValue("General","Do Training",1); std::string forestPath = allConfig.Value("General","Forest Path"); std::string trainingCollectionPath = allConfig.Value("General","Patient Collection"); std::string testCollectionPath = trainingCollectionPath; - MITK_INFO << "Training collection: " << trainingCollectionPath; ////////////////////////////////////////////////////////////////////////////// // Read Default Classification ////////////////////////////////////////////////////////////////////////////// std::vector trainPatients = allConfig.Vector("Training Group",currentRun); std::vector testPatients = allConfig.Vector("Test Group",currentRun); std::vector modalities = allConfig.Vector("Modalities",0); std::string trainMask = allConfig.Value("Data","Training Mask"); std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask"); std::string testMask = allConfig.Value("Data","Test Mask"); std::string resultMask = allConfig.Value("Data", "Result Mask"); std::string resultProb = allConfig.Value("Data", "Result Propability"); std::string outputFolder = allConfig.Value("General","Output Folder"); std::string writeDataFilePath = allConfig.Value("Forest","File to write data to"); ////////////////////////////////////////////////////////////////////////////// // Read Forest Parameter ////////////////////////////////////////////////////////////////////////////// int minimumSplitNodeSize = allConfig.IntValue("Forest", "Minimum split node size",1); int numberOfTrees = allConfig.IntValue("Forest", "Number of Trees",255); double samplesPerTree = atof(allConfig.Value("Forest", "Samples per Tree").c_str()); if (samplesPerTree <= 0.0000001) { samplesPerTree = 1.0; } MITK_INFO << "Samples per Tree: " << samplesPerTree; int sampleWithReplacement = allConfig.IntValue("Forest", "Sample with replacement",1); double trainPrecision = atof(allConfig.Value("Forest", "Precision").c_str()); if (trainPrecision <= 0.0000000001) { trainPrecision = 0.0; } double weightLambda = atof(allConfig.Value("Forest", "Weight Lambda").c_str()); if (weightLambda <= 0.0000000001) { weightLambda = 0.0; } int maximumTreeDepth = allConfig.IntValue("Forest", "Maximum Tree Depth",10000); int randomSplit = allConfig.IntValue("Forest","Use RandomSplit",0); ////////////////////////////////////////////////////////////////////////////// // Read Statistic Parameter ////////////////////////////////////////////////////////////////////////////// std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file"); std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file"); std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file"); std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV"); bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0); std::vector labelGroupA = allConfig.Vector("LabelsA",0); std::vector labelGroupB = allConfig.Vector("LabelsB",0); ////////////////////////////////////////////////////////////////////////////// // Read Special Parameter ////////////////////////////////////////////////////////////////////////////// bool useWeightedPoints = allConfig.IntValue("Forest", "Use point-based weighting",0); bool writePointsToFile = allConfig.IntValue("Forest", "Write points to file",0); int importanceWeightAlgorithm = allConfig.IntValue("Forest","Importance weight Algorithm",0); std::string importanceWeightName = allConfig.Value("Forest","Importance weight name",""); std::ofstream timingFile; timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app); timingFile << statisticShortFileLabel << ";"; std::time_t lastTimePoint; time(&lastTimePoint); ////////////////////////////////////////////////////////////////////////////// // Read Images ////////////////////////////////////////////////////////////////////////////// std::vector usedModalities; for (int i = 0; i < modalities.size(); ++i) { usedModalities.push_back(modalities[i]); } usedModalities.push_back(trainMask); usedModalities.push_back(completeTrainMask); usedModalities.push_back(testMask); usedModalities.push_back(statisticGoldStandard); usedModalities.push_back(importanceWeightName); // vtkSmartPointer colReader = vtkSmartPointer::New(); mitk::CollectionReader* colReader = new mitk::CollectionReader(); colReader->AddDataElementIds(trainPatients); colReader->SetDataItemNames(usedModalities); //colReader->SetNames(usedModalities); mitk::DataCollection::Pointer trainCollection; if (doTraining) { trainCollection = colReader->LoadCollection(trainingCollectionPath); } colReader->ClearDataElementIds(); colReader->AddDataElementIds(testPatients); mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath); std::time_t now; time(&now); double seconds = std::difftime(now, lastTimePoint); timingFile << seconds << ";"; time(&lastTimePoint); /* if (writePointsToFile) { MITK_INFO << "Use external weights..."; mitk::ExternalWeighting weightReader; weightReader.SetModalities(modalities); weightReader.SetTestCollection(testCollection); weightReader.SetTrainCollection(trainCollection); weightReader.SetTestMask(testMask); weightReader.SetTrainMask(trainMask); weightReader.SetWeightsName("weights"); weightReader.SetCorrectionFactor(1.0); weightReader.SetWeightFileName(writeDataFilePath); weightReader.WriteData(); return 0; }*/ ////////////////////////////////////////////////////////////////////////////// // If required do Training.... ////////////////////////////////////////////////////////////////////////////// //mitk::DecisionForest forest; mitk::VigraRandomForestClassifier::Pointer forest = mitk::VigraRandomForestClassifier::New(); forest->SetSamplesPerTree(samplesPerTree); forest->SetMinimumSplitNodeSize(minimumSplitNodeSize); forest->SetTreeCount(numberOfTrees); forest->UseSampleWithReplacement(sampleWithReplacement); forest->SetPrecision(trainPrecision); forest->SetMaximumTreeDepth(maximumTreeDepth); forest->SetWeightLambda(weightLambda); // TOOD forest.UseRandomSplit(randomSplit); if (doTraining) { // 0 = LR-Estimation // 1 = KNN-Estimation // 2 = Kliep // 3 = Extern Image // 4 = Zadrozny // 5 = Spectral // 6 = uLSIF auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, modalities, trainMask); auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, trainMask, trainMask); - //if (useWeightedPoints) - if (false) + if (useWeightedPoints) + //if (false) { MITK_INFO << "Activated Point-based weighting..."; //forest.UseWeightedPoints(true); forest->UsePointWiseWeight(true); //forest.SetWeightName("calculated_weight"); /*if (importanceWeightAlgorithm == 1) { mitk::KNNDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 2) { mitk::KliepDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 3) { forest.SetWeightName(importanceWeightName); } else if (importanceWeightAlgorithm == 4) { mitk::ZadroznyWeighting est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 5) { mitk::SpectralDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 6) { mitk::ULSIFDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else*/ { mitk::LRDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } auto trainDataW = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, "calculated_weight", trainMask); forest->SetPointWiseWeight(trainDataW); forest->UsePointWiseWeight(true); } forest->Train(trainDataX, trainDataY); // TODO forest.Save(forestPath); } else { // TODO forest.Load(forestPath); } time(&now); seconds = std::difftime(now, lastTimePoint); timingFile << seconds << ";"; time(&lastTimePoint); ////////////////////////////////////////////////////////////////////////////// // If required do Save Forest.... ////////////////////////////////////////////////////////////////////////////// //writer.// (forest); /* auto w = forest->GetTreeWeights(); w(0,0) = 10; forest->SetTreeWeights(w);*/ - mitk::IOUtil::Save(forest,"d:/tmp/forest.forest"); + //mitk::IOUtil::Save(forest,"d:/tmp/forest.forest"); ////////////////////////////////////////////////////////////////////////////// // If required do test ////////////////////////////////////////////////////////////////////////////// auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,modalities, testMask); auto testDataNewY = forest->Predict(testDataX); + auto testDataNewProb = forest->GetPointWiseProbabilities(); //MITK_INFO << testDataNewY; + std::vector names; + names.push_back("prob-1"); + names.push_back("prob-2"); + mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask); + mitk::DCUtilities::MatrixToDC3d(testDataNewProb, testCollection, names, testMask); //forest.SetMaskName(testMask); //forest.SetCollection(testCollection); //forest.Test(); //forest.PrintTree(0); time(&now); seconds = std::difftime(now, lastTimePoint); timingFile << seconds << ";"; time(&lastTimePoint); ////////////////////////////////////////////////////////////////////////////// // Cost-based analysis ////////////////////////////////////////////////////////////////////////////// // TODO Reactivate //MITK_INFO << "Calculate Cost-based Statistic "; //mitk::CostingStatistic costStat; //costStat.SetCollection(testCollection); //costStat.SetCombinedA("combinedHealty"); //costStat.SetCombinedB("combinedTumor"); //costStat.SetCombinedLabel("combinedLabel"); //costStat.SetMaskName(testMask); ////std::vector labelHealthy; ////labelHealthy.push_back("result_prop_Class-0"); ////labelHealthy.push_back("result_prop_Class-4"); ////std::vector labelTumor; ////labelTumor.push_back("result_prop_Class-1"); ////labelTumor.push_back("result_prop_Class-2"); ////labelTumor.push_back("result_prop_Class-3"); //costStat.SetProbabilitiesA(labelGroupA); //costStat.SetProbabilitiesB(labelGroupB); //std::ofstream costStatisticFile; //costStatisticFile.open((statisticFilePath + ".cost").c_str(), std::ios::app); //std::ofstream lcostStatisticFile; //lcostStatisticFile.open((statisticFilePath + ".longcost").c_str(), std::ios::app); //costStat.WriteStatistic(lcostStatisticFile,costStatisticFile,2.5,statisticShortFileLabel); //costStatisticFile.close(); //costStat.CalculateClass(50); ////////////////////////////////////////////////////////////////////////////// // Save results to folder ////////////////////////////////////////////////////////////////////////////// std::vector outputFilter; //outputFilter.push_back(resultMask); //std::vector propNames = forest.GetListOfProbabilityNames(); //outputFilter.insert(outputFilter.begin(), propNames.begin(), propNames.end()); mitk::CollectionWriter::ExportCollectionToFolder(testCollection, outputFolder + "/result_collection.xml", outputFilter); MITK_INFO << "Calculate Statistic...." ; ////////////////////////////////////////////////////////////////////////////// // Calculate and Print Statistic ////////////////////////////////////////////////////////////////////////////// std::ofstream statisticFile; statisticFile.open(statisticFilePath.c_str(), std::ios::app); std::ofstream sstatisticFile; sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app); mitk::CollectionStatistic stat; stat.SetCollection(testCollection); stat.SetClassCount(2); stat.SetGoldName(statisticGoldStandard); stat.SetTestName(resultMask); stat.SetMaskName(testMask); + mitk::BinaryValueminusOneToIndexMapper* mapper = new mitk::BinaryValueminusOneToIndexMapper; + stat.SetGroundTruthValueToIndexMapper(mapper); + stat.SetTestValueToIndexMapper(mapper); stat.Update(); //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel); stat.Print(statisticFile,sstatisticFile,true, statisticShortFileLabel); statisticFile.close(); + delete mapper; time(&now); seconds = std::difftime(now, lastTimePoint); timingFile << seconds << std::endl; time(&lastTimePoint); timingFile.close(); } catch (std::string s) { MITK_INFO << s; return 0; } catch (char* s) { MITK_INFO << s; } return 0; } #endif \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLRandomSampling.cpp b/Modules/Classification/CLMiniApps/CLRandomSampling.cpp new file mode 100644 index 0000000000..ea98e3ee26 --- /dev/null +++ b/Modules/Classification/CLMiniApps/CLRandomSampling.cpp @@ -0,0 +1,158 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include "mitkCommandLineParser.h" +#include "mitkIOUtil.h" +#include + +static vector splitDouble(string str, char delimiter) { + vector internal; + stringstream ss(str); // Turn the string into a stream. + string tok; + double val; + while (getline(ss, tok, delimiter)) { + stringstream s2(tok); + s2 >> val; + internal.push_back(val); + } + + return internal; +} + +static vector splitUInt(string str, char delimiter) { + vector internal; + stringstream ss(str); // Turn the string into a stream. + string tok; + unsigned int val; + while (getline(ss, tok, delimiter)) { + stringstream s2(tok); + s2 >> val; + internal.push_back(val); + } + + return internal; +} + +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + + parser.setTitle("Random Sampling"); + parser.setCategory("Classification Command Tools"); + parser.setDescription(""); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--", "-"); + // Add command line argument names + parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Help:", "Show this help text"); + parser.addArgument("input", "i", mitkCommandLineParser::InputDirectory, "Input file:", "Input file", us::Any(), false); + parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output file:", "Output file", us::Any(), false); + + parser.addArgument("single-rate", "sr", mitkCommandLineParser::OutputFile, "Single Acceptance rate for all voxel", "Output file", us::Any(), true); + parser.addArgument("class-rate", "cr", mitkCommandLineParser::OutputFile, "Class-dependend acceptance rate", "Output file", us::Any(), true); + parser.addArgument("single-number", "sn", mitkCommandLineParser::OutputFile, "Single Number of Voxel for each class", "Output file", us::Any(), true); + parser.addArgument("class-number", "cn", mitkCommandLineParser::OutputFile, "Class-dependedn number of voxels ", "Output file", us::Any(), true); + + map parsedArgs = parser.parseArguments(argc, argv); + + if (parsedArgs.size() == 0) + return EXIT_FAILURE; + + // Show a help message + if (parsedArgs.count("help") || parsedArgs.count("h")) + { + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + + if (parsedArgs.count("single-rate") + parsedArgs.count("class-rate") + parsedArgs.count("single-number") + parsedArgs.count("class-number") < 1) + { + std::cout << "Please specify the sampling rate or number of voxels to be labeled" << std::endl << std::endl; + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + + if (parsedArgs.count("single-rate") + parsedArgs.count("class-rate") + parsedArgs.count("single-number") + parsedArgs.count("class-number") > 2) + { + std::cout << "Please specify only one way for the sampling rate or number of voxels to be labeled" << std::endl << std::endl; + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + + + std::string inputName = us::any_cast(parsedArgs["input"]); + std::string outputName = us::any_cast(parsedArgs["output"]); + mitk::Image::Pointer image = mitk::IOUtil::LoadImage(inputName); + + mitk::RandomImageSampler::Pointer filter = mitk::RandomImageSampler::New(); + filter->SetInput(image); + + if (parsedArgs.count("single-rate")) + { + filter->SetSamplingMode(mitk::RandomImageSamplerMode::SINGLE_ACCEPTANCE_RATE); + auto rate = splitDouble(parsedArgs["single-rate"].ToString(), ';'); + if (rate.size() != 1) + { + std::cout << "Please specify a single double value for single-rate, for example 0.3." << std::endl << std::endl; + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + filter->SetAcceptanceRate(rate[0]); + } + + else if (parsedArgs.count("class-rate")) + { + filter->SetSamplingMode(mitk::RandomImageSamplerMode::CLASS_DEPENDEND_ACCEPTANCE_RATE); + auto rate = splitDouble(parsedArgs["class-rate"].ToString(), ';'); + if (rate.size() < 2) + { + std::cout << "Please specify at least two, semicolon separated values for class-rate, for example '0.3;0.2' ." << std::endl << std::endl; + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + filter->SetAcceptanceRateVector(rate); + } + + else if (parsedArgs.count("single-number")) + { + filter->SetSamplingMode(mitk::RandomImageSamplerMode::SINGLE_NUMBER_OF_ACCEPTANCE); + auto rate = splitUInt(parsedArgs["single-number"].ToString(), ';'); + if (rate.size() != 1) + { + std::cout << "Please specify a single double value for single-number, for example 100." << std::endl << std::endl; + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + filter->SetNumberOfSamples(rate[0]); + } + + else if (parsedArgs.count("class-number")) + { + filter->SetSamplingMode(mitk::RandomImageSamplerMode::CLASS_DEPENDEND_NUMBER_OF_ACCEPTANCE); + auto rate = splitUInt(parsedArgs["class-number"].ToString(), ';'); + if (rate.size() < 2) + { + std::cout << "Please specify at least two, semicolon separated values for class-number, for example '100;200' ." << std::endl << std::endl; + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + filter->SetNumberOfSamplesVector(rate); + } + filter->Update(); + mitk::IOUtil::SaveImage(filter->GetOutput(), outputName); + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLRemoveEmptyVoxels.cpp b/Modules/Classification/CLMiniApps/CLRemoveEmptyVoxels.cpp new file mode 100644 index 0000000000..2f2a955c07 --- /dev/null +++ b/Modules/Classification/CLMiniApps/CLRemoveEmptyVoxels.cpp @@ -0,0 +1,159 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include "mitkCommandLineParser.h" +#include "mitkIOUtil.h" +#include "mitkImageCast.h" +#include +#include +#include "mitkImageGenerator.h" + +int main(int argc, char* argv[]) +{ + typedef itk::Image ImageType; + typedef itk::Image MaskImageType; + typedef ImageType::Pointer ImagePointerType; + typedef MaskImageType::Pointer MaskImagePointerType; + + typedef itk::ImageRegionConstIterator ConstIteratorType; + typedef itk::ImageRegionConstIterator ConstMaskIteratorType; + typedef itk::ImageRegionIterator IteratorType; + typedef itk::ImageRegionIterator MaskIteratorType; + + mitkCommandLineParser parser; + + parser.setTitle("Remove empty voxels Sampling"); + parser.setCategory("Classification Command Tools"); + parser.setDescription(""); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--", "-"); + // Add command line argument names + parser.addArgument("help", "h", mitkCommandLineParser::Bool, "Help:", "Show this help text"); + parser.addArgument("mask-input", "mi", mitkCommandLineParser::InputDirectory, "Input file:", "Input file", us::Any(), false); + parser.addArgument("mask-output", "mo", mitkCommandLineParser::OutputFile, "Output file:", "Output file", us::Any(), false); + + for (int i = 0; i < 100; ++i) + { + stringstream s1; s1 << i; std::string number = s1.str(); + parser.addArgument("input"+number, "i"+number, mitkCommandLineParser::OutputFile, "Input file", "input file", us::Any(), true); + parser.addArgument("output" + number, "o" + number, mitkCommandLineParser::OutputFile, "Output File", "Output file", us::Any(), true); + } + + map parsedArgs = parser.parseArguments(argc, argv); + + if (parsedArgs.size() == 0) + return EXIT_FAILURE; + + // Show a help message + if (parsedArgs.count("help") || parsedArgs.count("h")) + { + std::cout << parser.helpText(); + return EXIT_SUCCESS; + } + + // Load Mask Image and count number of non-zero voxels + mitk::Image::Pointer mask = mitk::IOUtil::LoadImage(parsedArgs["mask-input"].ToString()); + MaskImagePointerType itkMask = MaskImageType::New(); + mitk::CastToItkImage(mask, itkMask); + ConstMaskIteratorType maskIter(itkMask, itkMask->GetLargestPossibleRegion()); + std::size_t nonZero = 0; + while (!maskIter.IsAtEnd()) + { + if (maskIter.Value() > 0) + { + ++nonZero; + } + ++maskIter; + } + maskIter.GoToBegin(); + + // Create new mask image + auto mitkNewMask = mitk::ImageGenerator::GenerateGradientImage(nonZero, 1, 1, 1, 1, 1); + MaskImagePointerType itkNewMask = MaskImageType::New(); + mitk::CastToItkImage(mitkNewMask, itkNewMask); + MaskIteratorType newMaskIter(itkNewMask, itkNewMask->GetLargestPossibleRegion()); + + // Read additional image + std::vector mitkImagesVector; + std::vector itkImageVector; + std::vector itkOutputImageVector; + std::vector inputIteratorVector; + std::vector outputIteratorVector; + for (int i = 0; i < 100; ++i) + { + stringstream s1; s1 << i; std::string number = s1.str(); + if (parsedArgs.count("input" + number) < 1) + break; + if (parsedArgs.count("output" + number) < 1) + break; + + mitk::Image::Pointer image = mitk::IOUtil::LoadImage(parsedArgs["input"+number].ToString()); + mitkImagesVector.push_back(image); + + ImagePointerType itkImage = ImageType::New(); + mitk::CastToItkImage(image, itkImage); + itkImageVector.push_back(itkImage); + + ConstIteratorType iter(itkImage, itkImage->GetLargestPossibleRegion()); + inputIteratorVector.push_back(iter); + + auto mitkNewImage = mitk::ImageGenerator::GenerateGradientImage(nonZero, 1, 1, 1, 1, 1); + ImagePointerType itkNewOutput = ImageType::New(); + mitk::CastToItkImage(mitkNewImage, itkNewOutput); + IteratorType outputIter(itkNewOutput, itkNewOutput->GetLargestPossibleRegion()); + itkOutputImageVector.push_back(itkNewOutput); + outputIteratorVector.push_back(outputIter); + } + + // Convert non-zero voxels to the new images + while (!maskIter.IsAtEnd()) + { + if (maskIter.Value() > 0) + { + newMaskIter.Set(maskIter.Value()); + ++newMaskIter; + for (int i = 0; i < outputIteratorVector.size(); ++i) + { + outputIteratorVector[i].Set(inputIteratorVector[i].Value()); + ++(outputIteratorVector[i]); + } + } + ++maskIter; + for (int i = 0; i < inputIteratorVector.size(); ++i) + { + ++(inputIteratorVector[i]); + } + + } + + // Save the new images + for (int i = 0; i < outputIteratorVector.size(); ++i) + { + stringstream s1; s1 << i; std::string number = s1.str(); + mitk::Image::Pointer mitkImage = mitk::Image::New(); + mitk::CastToMitkImage(itkOutputImageVector[i], mitkImage); + mitk::IOUtil::SaveImage(mitkImage, parsedArgs["output" + number].ToString()); + } + // Save the new mask + { + mitk::Image::Pointer mitkImage = mitk::Image::New(); + mitk::CastToMitkImage(itkNewMask, mitkImage); + mitk::IOUtil::SaveImage(mitkImage, parsedArgs["mask-output"].ToString()); + } + + return EXIT_SUCCESS; +} \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLResampleImageToReference.cpp b/Modules/Classification/CLMiniApps/CLResampleImageToReference.cpp new file mode 100644 index 0000000000..def14065c3 --- /dev/null +++ b/Modules/Classification/CLMiniApps/CLResampleImageToReference.cpp @@ -0,0 +1,116 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ +#ifndef mitkCLResampleImageToReference_cpp +#define mitkCLResampleImageToReference_cpp + +#include "mitkCommandLineParser.h" +#include +#include +#include +#include +#include +#include + +// ITK +#include "itkLinearInterpolateImageFunction.h" +#include "itkWindowedSincInterpolateImageFunction.h" +#include "itkNearestNeighborInterpolateImageFunction.h" +#include "itkIdentityTransform.h" +#include "itkResampleImageFilter.h" + + +template +void +ResampleImageToReferenceFunction(itk::Image* itkReference, mitk::Image::Pointer moving, std::string ergPath) +{ + typedef itk::Image InputImageType; + + // Identify Transform + typedef itk::IdentityTransform T_Transform; + typename T_Transform::Pointer _pTransform = T_Transform::New(); + _pTransform->SetIdentity(); + + typedef itk::WindowedSincInterpolateImageFunction< InputImageType, VImageDimension> WindowedSincInterpolatorType; + typename WindowedSincInterpolatorType::Pointer sinc_interpolator = WindowedSincInterpolatorType::New(); + + typedef itk::LinearInterpolateImageFunction< InputImageType> LinearInterpolateImageFunctionType; + typename LinearInterpolateImageFunctionType::Pointer lin_interpolator = LinearInterpolateImageFunctionType::New(); + + typedef itk::NearestNeighborInterpolateImageFunction< InputImageType> NearestNeighborInterpolateImageFunctionType; + typename NearestNeighborInterpolateImageFunctionType::Pointer nn_interpolator = NearestNeighborInterpolateImageFunctionType::New(); + + typename InputImageType::Pointer itkMoving = InputImageType::New(); + mitk::CastToItkImage(moving,itkMoving); + typedef itk::ResampleImageFilter ResampleFilterType; + + typename ResampleFilterType::Pointer resampler = ResampleFilterType::New(); + resampler->SetInput(itkMoving); + resampler->SetReferenceImage( itkReference ); + resampler->UseReferenceImageOn(); + resampler->SetTransform(_pTransform); + //if ( sincInterpol) + // resampler->SetInterpolator(sinc_interpolator); + //else + resampler->SetInterpolator(lin_interpolator); + + resampler->Update(); + + // Convert back to mitk + mitk::Image::Pointer result = mitk::Image::New(); + result->InitializeByItk(resampler->GetOutput()); + GrabItkImageMemory(resampler->GetOutput(), result); + MITK_INFO << "writing result to: " << ergPath; + mitk::IOUtil::SaveImage(result, ergPath); + //return result; +} + +int main(int argc, char* argv[]) +{ + mitkCommandLineParser parser; + parser.setArgumentPrefix("--", "-"); + // required params + parser.addArgument("fix", "f", mitkCommandLineParser::InputImage, "Input Image", "Path to the input VTK polydata", us::Any(), false); + parser.addArgument("moving", "m", mitkCommandLineParser::OutputFile, "Output text file", "Target file. The output statistic is appended to this file.", us::Any(), false); + parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Extension", "File extension. Default is .nii.gz", us::Any(), false); + + // Miniapp Infos + parser.setCategory("Classification Tools"); + parser.setTitle("Resample Image To Reference"); + parser.setDescription("Resamples an image (moving) to an given image (fix) without additional registration."); + parser.setContributor("MBI"); + + map parsedArgs = parser.parseArguments(argc, argv); + + if (parsedArgs.size() == 0) + { + return EXIT_FAILURE; + } + if (parsedArgs.count("help") || parsedArgs.count("h")) + { + return EXIT_SUCCESS; + } + + mitk::Image::Pointer fix = mitk::IOUtil::LoadImage(parsedArgs["fix"].ToString()); + mitk::Image::Pointer moving = mitk::IOUtil::LoadImage(parsedArgs["moving"].ToString()); + mitk::Image::Pointer erg = mitk::Image::New(); + + AccessByItk_2(fix, ResampleImageToReferenceFunction, moving, parsedArgs["output"].ToString()); + +} + + + +#endif diff --git a/Modules/Classification/CLMiniApps/CLVoxelClassification.cpp b/Modules/Classification/CLMiniApps/CLVoxelClassification.cpp index 708c7556f5..9b747e5934 100644 --- a/Modules/Classification/CLMiniApps/CLVoxelClassification.cpp +++ b/Modules/Classification/CLMiniApps/CLVoxelClassification.cpp @@ -1,438 +1,482 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkForest_cpp #define mitkForest_cpp #include "time.h" #include #include #include #include #include #include #include #include #include #include #include // ----------------------- Forest Handling ---------------------- //#include #include //#include //#include //#include //#include // ----------------------- Point weighting ---------------------- //#include //#include //#include #include //#include //#include //#include //#include int main(int argc, char* argv[]) { MITK_INFO << "Starting MITK_Forest Mini-App"; double startTime = time(0); ////////////////////////////////////////////////////////////////////////////// // Read Console Input Parameter ////////////////////////////////////////////////////////////////////////////// - ConfigFileReader allConfig(argv[2]); + ConfigFileReader allConfig(argv[1]); bool readFile = true; std::stringstream ss; for (int i = 0; i < argc; ++i ) { MITK_INFO << "-----"<< argv[i]<<"------"; if (readFile) { if (argv[i][0] == '+') { readFile = false; continue; } else { try { allConfig.ReadFile(argv[i]); } catch (std::exception &e) { MITK_INFO << e.what(); } } } else { std::string input = argv[i]; std::replace(input.begin(), input.end(),'_',' '); ss << input << std::endl; } } allConfig.ReadStream(ss); try { ////////////////////////////////////////////////////////////////////////////// // General ////////////////////////////////////////////////////////////////////////////// int currentRun = allConfig.IntValue("General","Run",0); int doTraining = allConfig.IntValue("General","Do Training",1); std::string forestPath = allConfig.Value("General","Forest Path"); std::string trainingCollectionPath = allConfig.Value("General","Patient Collection"); - std::string testCollectionPath = trainingCollectionPath; - MITK_INFO << "Training collection: " << trainingCollectionPath; + std::string testCollectionPath = allConfig.Value("General", "Patient Test Collection", trainingCollectionPath); ////////////////////////////////////////////////////////////////////////////// // Read Default Classification ////////////////////////////////////////////////////////////////////////////// std::vector trainPatients = allConfig.Vector("Training Group",currentRun); std::vector testPatients = allConfig.Vector("Test Group",currentRun); - std::vector modalities = allConfig.Vector("Modalities",0); + std::vector modalities = allConfig.Vector("Modalities", 0); + std::vector outputFilter = allConfig.Vector("Output Filter", 0); std::string trainMask = allConfig.Value("Data","Training Mask"); std::string completeTrainMask = allConfig.Value("Data","Complete Training Mask"); std::string testMask = allConfig.Value("Data","Test Mask"); std::string resultMask = allConfig.Value("Data", "Result Mask"); std::string resultProb = allConfig.Value("Data", "Result Propability"); std::string outputFolder = allConfig.Value("General","Output Folder"); std::string writeDataFilePath = allConfig.Value("Forest","File to write data to"); + ////////////////////////////////////////////////////////////////////////////// + // Read Data Forest Parameter + ////////////////////////////////////////////////////////////////////////////// + int testSingleDataset = allConfig.IntValue("Data", "Test Single Dataset",0); + std::string singleDatasetName = allConfig.Value("Data", "Single Dataset Name", "none"); + int trainSingleDataset = allConfig.IntValue("Data", "Train Single Dataset", 0); + std::string singleTrainDatasetName = allConfig.Value("Data", "Train Single Dataset Name", "none"); + ////////////////////////////////////////////////////////////////////////////// // Read Forest Parameter ////////////////////////////////////////////////////////////////////////////// int minimumSplitNodeSize = allConfig.IntValue("Forest", "Minimum split node size",1); int numberOfTrees = allConfig.IntValue("Forest", "Number of Trees",255); double samplesPerTree = atof(allConfig.Value("Forest", "Samples per Tree").c_str()); if (samplesPerTree <= 0.0000001) { samplesPerTree = 1.0; } MITK_INFO << "Samples per Tree: " << samplesPerTree; int sampleWithReplacement = allConfig.IntValue("Forest", "Sample with replacement",1); double trainPrecision = atof(allConfig.Value("Forest", "Precision").c_str()); if (trainPrecision <= 0.0000000001) { trainPrecision = 0.0; } double weightLambda = atof(allConfig.Value("Forest", "Weight Lambda").c_str()); if (weightLambda <= 0.0000000001) { weightLambda = 0.0; } int maximumTreeDepth = allConfig.IntValue("Forest", "Maximum Tree Depth",10000); int randomSplit = allConfig.IntValue("Forest","Use RandomSplit",0); ////////////////////////////////////////////////////////////////////////////// // Read Statistic Parameter ////////////////////////////////////////////////////////////////////////////// std::string statisticFilePath = allConfig.Value("Evaluation", "Statistic output file"); std::string statisticShortFilePath = allConfig.Value("Evaluation", "Statistic short output file"); std::string statisticShortFileLabel = allConfig.Value("Evaluation", "Index for short file"); std::string statisticGoldStandard = allConfig.Value("Evaluation", "Gold Standard Name","GTV"); bool statisticWithHeader = allConfig.IntValue("Evaluation", "Write header in short file",0); std::vector labelGroupA = allConfig.Vector("LabelsA",0); std::vector labelGroupB = allConfig.Vector("LabelsB",0); ////////////////////////////////////////////////////////////////////////////// // Read Special Parameter ////////////////////////////////////////////////////////////////////////////// bool useWeightedPoints = allConfig.IntValue("Forest", "Use point-based weighting",0); bool writePointsToFile = allConfig.IntValue("Forest", "Write points to file",0); int importanceWeightAlgorithm = allConfig.IntValue("Forest","Importance weight Algorithm",0); std::string importanceWeightName = allConfig.Value("Forest","Importance weight name",""); std::ofstream timingFile; timingFile.open((statisticFilePath + ".timing").c_str(), std::ios::app); timingFile << statisticShortFileLabel << ";"; std::time_t lastTimePoint; time(&lastTimePoint); ////////////////////////////////////////////////////////////////////////////// // Read Images ////////////////////////////////////////////////////////////////////////////// std::vector usedModalities; for (int i = 0; i < modalities.size(); ++i) { usedModalities.push_back(modalities[i]); } usedModalities.push_back(trainMask); usedModalities.push_back(completeTrainMask); usedModalities.push_back(testMask); usedModalities.push_back(statisticGoldStandard); usedModalities.push_back(importanceWeightName); - // vtkSmartPointer colReader = vtkSmartPointer::New(); + if (trainSingleDataset > 0) + { + trainPatients.clear(); + trainPatients.push_back(singleTrainDatasetName); + } + mitk::CollectionReader* colReader = new mitk::CollectionReader(); colReader->AddDataElementIds(trainPatients); colReader->SetDataItemNames(usedModalities); //colReader->SetNames(usedModalities); mitk::DataCollection::Pointer trainCollection; if (doTraining) { trainCollection = colReader->LoadCollection(trainingCollectionPath); } + + if (testSingleDataset > 0) + { + testPatients.clear(); + testPatients.push_back(singleDatasetName); + } colReader->ClearDataElementIds(); colReader->AddDataElementIds(testPatients); mitk::DataCollection::Pointer testCollection = colReader->LoadCollection(testCollectionPath); std::time_t now; time(&now); double seconds = std::difftime(now, lastTimePoint); timingFile << seconds << ";"; time(&lastTimePoint); /* if (writePointsToFile) { MITK_INFO << "Use external weights..."; mitk::ExternalWeighting weightReader; weightReader.SetModalities(modalities); weightReader.SetTestCollection(testCollection); weightReader.SetTrainCollection(trainCollection); weightReader.SetTestMask(testMask); weightReader.SetTrainMask(trainMask); weightReader.SetWeightsName("weights"); weightReader.SetCorrectionFactor(1.0); weightReader.SetWeightFileName(writeDataFilePath); weightReader.WriteData(); return 0; }*/ ////////////////////////////////////////////////////////////////////////////// // If required do Training.... ////////////////////////////////////////////////////////////////////////////// //mitk::DecisionForest forest; mitk::VigraRandomForestClassifier::Pointer forest = mitk::VigraRandomForestClassifier::New(); forest->SetSamplesPerTree(samplesPerTree); forest->SetMinimumSplitNodeSize(minimumSplitNodeSize); forest->SetTreeCount(numberOfTrees); forest->UseSampleWithReplacement(sampleWithReplacement); forest->SetPrecision(trainPrecision); forest->SetMaximumTreeDepth(maximumTreeDepth); forest->SetWeightLambda(weightLambda); // TOOD forest.UseRandomSplit(randomSplit); if (doTraining) { // 0 = LR-Estimation // 1 = KNN-Estimation // 2 = Kliep // 3 = Extern Image // 4 = Zadrozny // 5 = Spectral // 6 = uLSIF auto trainDataX = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, modalities, trainMask); auto trainDataY = mitk::DCUtilities::DC3dDToMatrixXi(trainCollection, trainMask, trainMask); - //if (useWeightedPoints) - if (false) + if (useWeightedPoints) + //if (false) { MITK_INFO << "Activated Point-based weighting..."; //forest.UseWeightedPoints(true); forest->UsePointWiseWeight(true); //forest.SetWeightName("calculated_weight"); /*if (importanceWeightAlgorithm == 1) { mitk::KNNDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 2) { mitk::KliepDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 3) { forest.SetWeightName(importanceWeightName); } else if (importanceWeightAlgorithm == 4) { mitk::ZadroznyWeighting est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 5) { mitk::SpectralDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else if (importanceWeightAlgorithm == 6) { mitk::ULSIFDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } else*/ { mitk::LRDensityEstimation est; est.SetCollection(trainCollection); est.SetTrainMask(trainMask); est.SetTestMask(testMask); est.SetModalities(modalities); est.SetWeightName("calculated_weight"); est.Update(); } auto trainDataW = mitk::DCUtilities::DC3dDToMatrixXd(trainCollection, "calculated_weight", trainMask); forest->SetPointWiseWeight(trainDataW); forest->UsePointWiseWeight(true); } + MITK_INFO << "Start training the forest"; forest->Train(trainDataX, trainDataY); - // TODO forest.Save(forestPath); + + MITK_INFO << "Save Forest"; + mitk::IOUtil::Save(forest, forestPath); } else { - // TODO forest.Load(forestPath); + forest = dynamic_cast(mitk::IOUtil::Load(forestPath)[0].GetPointer());// TODO forest.Load(forestPath); } time(&now); seconds = std::difftime(now, lastTimePoint); + MITK_INFO << "Duration for Training: " << seconds; timingFile << seconds << ";"; time(&lastTimePoint); ////////////////////////////////////////////////////////////////////////////// // If required do Save Forest.... ////////////////////////////////////////////////////////////////////////////// //writer.// (forest); /* auto w = forest->GetTreeWeights(); w(0,0) = 10; forest->SetTreeWeights(w);*/ - mitk::IOUtil::Save(forest,"d:/tmp/forest.forest"); ////////////////////////////////////////////////////////////////////////////// // If required do test ////////////////////////////////////////////////////////////////////////////// + MITK_INFO << "Convert Test data"; auto testDataX = mitk::DCUtilities::DC3dDToMatrixXd(testCollection,modalities, testMask); + + MITK_INFO << "Predict Test Data"; auto testDataNewY = forest->Predict(testDataX); + auto testDataNewProb = forest->GetPointWiseProbabilities(); //MITK_INFO << testDataNewY; + auto maxClassValue = testDataNewProb.cols(); + std::vector names; + for (int i = 0; i < maxClassValue; ++i) + { + std::string name = resultProb + std::to_string(i); + MITK_INFO << name; + names.push_back(name); + } + //names.push_back("prob-1"); + //names.push_back("prob-2"); + mitk::DCUtilities::MatrixToDC3d(testDataNewY, testCollection, resultMask, testMask); + mitk::DCUtilities::MatrixToDC3d(testDataNewProb, testCollection, names, testMask); + MITK_INFO << "Converted predicted data"; //forest.SetMaskName(testMask); //forest.SetCollection(testCollection); //forest.Test(); //forest.PrintTree(0); time(&now); seconds = std::difftime(now, lastTimePoint); timingFile << seconds << ";"; time(&lastTimePoint); ////////////////////////////////////////////////////////////////////////////// // Cost-based analysis ////////////////////////////////////////////////////////////////////////////// // TODO Reactivate //MITK_INFO << "Calculate Cost-based Statistic "; //mitk::CostingStatistic costStat; //costStat.SetCollection(testCollection); //costStat.SetCombinedA("combinedHealty"); //costStat.SetCombinedB("combinedTumor"); //costStat.SetCombinedLabel("combinedLabel"); //costStat.SetMaskName(testMask); ////std::vector labelHealthy; ////labelHealthy.push_back("result_prop_Class-0"); ////labelHealthy.push_back("result_prop_Class-4"); ////std::vector labelTumor; ////labelTumor.push_back("result_prop_Class-1"); ////labelTumor.push_back("result_prop_Class-2"); ////labelTumor.push_back("result_prop_Class-3"); //costStat.SetProbabilitiesA(labelGroupA); //costStat.SetProbabilitiesB(labelGroupB); //std::ofstream costStatisticFile; //costStatisticFile.open((statisticFilePath + ".cost").c_str(), std::ios::app); //std::ofstream lcostStatisticFile; //lcostStatisticFile.open((statisticFilePath + ".longcost").c_str(), std::ios::app); //costStat.WriteStatistic(lcostStatisticFile,costStatisticFile,2.5,statisticShortFileLabel); //costStatisticFile.close(); //costStat.CalculateClass(50); ////////////////////////////////////////////////////////////////////////////// // Save results to folder ////////////////////////////////////////////////////////////////////////////// - std::vector outputFilter; + ////std::vector outputFilter; //outputFilter.push_back(resultMask); //std::vector propNames = forest.GetListOfProbabilityNames(); //outputFilter.insert(outputFilter.begin(), propNames.begin(), propNames.end()); + MITK_INFO << "Write Result to HDD"; mitk::CollectionWriter::ExportCollectionToFolder(testCollection, outputFolder + "/result_collection.xml", outputFilter); MITK_INFO << "Calculate Statistic...." ; ////////////////////////////////////////////////////////////////////////////// // Calculate and Print Statistic ////////////////////////////////////////////////////////////////////////////// std::ofstream statisticFile; statisticFile.open(statisticFilePath.c_str(), std::ios::app); std::ofstream sstatisticFile; sstatisticFile.open(statisticShortFilePath.c_str(), std::ios::app); mitk::CollectionStatistic stat; stat.SetCollection(testCollection); - stat.SetClassCount(2); + stat.SetClassCount(5); stat.SetGoldName(statisticGoldStandard); stat.SetTestName(resultMask); stat.SetMaskName(testMask); + mitk::BinaryValueminusOneToIndexMapper* mapper = new mitk::BinaryValueminusOneToIndexMapper; + stat.SetGroundTruthValueToIndexMapper(mapper); + stat.SetTestValueToIndexMapper(mapper); stat.Update(); //stat.Print(statisticFile,sstatisticFile,statisticWithHeader, statisticShortFileLabel); stat.Print(statisticFile,sstatisticFile,true, statisticShortFileLabel); statisticFile.close(); + delete mapper; time(&now); seconds = std::difftime(now, lastTimePoint); timingFile << seconds << std::endl; time(&lastTimePoint); timingFile.close(); } catch (std::string s) { MITK_INFO << s; return 0; } catch (char* s) { MITK_INFO << s; } return 0; } #endif \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CLVoxelFeatures.cpp b/Modules/Classification/CLMiniApps/CLVoxelFeatures.cpp index 02fa596994..dc18b6c60f 100644 --- a/Modules/Classification/CLMiniApps/CLVoxelFeatures.cpp +++ b/Modules/Classification/CLMiniApps/CLVoxelFeatures.cpp @@ -1,328 +1,337 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkCLVoxeFeatures_cpp #define mitkCLVoxeFeatures_cpp #include "time.h" #include #include #include #include #include #include "mitkCommandLineParser.h" #include #include #include "itkDiscreteGaussianImageFilter.h" #include #include "itkHessianRecursiveGaussianImageFilter.h" #include "itkUnaryFunctorImageFilter.h" #include #include "vnl/algo/vnl_symmetric_eigensystem.h" #include static std::vector splitDouble(std::string str, char delimiter) { std::vector internal; std::stringstream ss(str); // Turn the string into a stream. std::string tok; double val; while (std::getline(ss, tok, delimiter)) { std::stringstream s2(tok); s2 >> val; internal.push_back(val); } return internal; } namespace Functor { template class MatrixFirstEigenvalue { public: MatrixFirstEigenvalue() {} virtual ~MatrixFirstEigenvalue() {} int order; inline TOutput operator ()(const TInput& input) { double a,b,c; if (input[0] < 0.01 && input[1] < 0.01 &&input[2] < 0.01 &&input[3] < 0.01 &&input[4] < 0.01 &&input[5] < 0.01) return 0; vnl_symmetric_eigensystem_compute_eigenvals(input[0], input[1],input[2],input[3],input[4],input[5],a,b,c); switch (order) { case 0: return a; case 1: return b; case 2: return c; default: return a; } } bool operator !=(const MatrixFirstEigenvalue) const { return false; } bool operator ==(const MatrixFirstEigenvalue& other) const { return !(*this != other); } }; } template void GaussianFilter(itk::Image* itkImage, double variance, mitk::Image::Pointer &output) { typedef itk::Image ImageType; typedef itk::DiscreteGaussianImageFilter< ImageType, ImageType > GaussFilterType; typename GaussFilterType::Pointer gaussianFilter = GaussFilterType::New(); gaussianFilter->SetInput( itkImage ); gaussianFilter->SetVariance(variance); gaussianFilter->Update(); mitk::CastToMitkImage(gaussianFilter->GetOutput(), output); } template void DifferenceOfGaussFilter(itk::Image* itkImage, double variance, mitk::Image::Pointer &output) { typedef itk::Image ImageType; typedef itk::DiscreteGaussianImageFilter< ImageType, ImageType > GaussFilterType; typedef itk::SubtractImageFilter SubFilterType; typename GaussFilterType::Pointer gaussianFilter1 = GaussFilterType::New(); gaussianFilter1->SetInput( itkImage ); gaussianFilter1->SetVariance(variance); gaussianFilter1->Update(); typename GaussFilterType::Pointer gaussianFilter2 = GaussFilterType::New(); gaussianFilter2->SetInput( itkImage ); gaussianFilter2->SetVariance(variance*0.66*0.66); gaussianFilter2->Update(); typename SubFilterType::Pointer subFilter = SubFilterType::New(); subFilter->SetInput1(gaussianFilter1->GetOutput()); subFilter->SetInput2(gaussianFilter2->GetOutput()); subFilter->Update(); mitk::CastToMitkImage(subFilter->GetOutput(), output); } template void LaplacianOfGaussianFilter(itk::Image* itkImage, double variance, mitk::Image::Pointer &output) { typedef itk::Image ImageType; typedef itk::DiscreteGaussianImageFilter< ImageType, ImageType > GaussFilterType; typedef itk::LaplacianRecursiveGaussianImageFilter LaplacianFilter; typename GaussFilterType::Pointer gaussianFilter = GaussFilterType::New(); gaussianFilter->SetInput( itkImage ); gaussianFilter->SetVariance(variance); gaussianFilter->Update(); typename LaplacianFilter::Pointer laplaceFilter = LaplacianFilter::New(); laplaceFilter->SetInput(gaussianFilter->GetOutput()); laplaceFilter->Update(); mitk::CastToMitkImage(laplaceFilter->GetOutput(), output); } template void HessianOfGaussianFilter(itk::Image* itkImage, double variance, std::vector &out) { typedef itk::Image ImageType; typedef itk::Image FloatImageType; typedef itk::HessianRecursiveGaussianImageFilter HessianFilterType; typedef typename HessianFilterType::OutputImageType VectorImageType; typedef Functor::MatrixFirstEigenvalue DeterminantFunctorType; typedef itk::UnaryFunctorImageFilter DetFilterType; typename HessianFilterType::Pointer hessianFilter = HessianFilterType::New(); hessianFilter->SetInput(itkImage); hessianFilter->SetSigma(std::sqrt(variance)); for (int i = 0; i < VImageDimension; ++i) { typename DetFilterType::Pointer detFilter = DetFilterType::New(); detFilter->SetInput(hessianFilter->GetOutput()); detFilter->GetFunctor().order = i; detFilter->Update(); mitk::CastToMitkImage(detFilter->GetOutput(), out[i]); } } template void LocalHistograms(itk::Image* itkImage, std::vector &out, double offset, double delta) { typedef itk::Image ImageType; typedef itk::Image FloatImageType; typedef itk::MultiHistogramFilter MultiHistogramType; typename MultiHistogramType::Pointer filter = MultiHistogramType::New(); filter->SetInput(itkImage); filter->SetOffset(offset); filter->SetDelta(delta); filter->Update(); - for (int i = 0; i < VImageDimension; ++i) + for (int i = 0; i < 11; ++i) { mitk::Image::Pointer img = mitk::Image::New(); - mitk::CastToMitkImage(filter->GetOutput(), img); + mitk::CastToMitkImage(filter->GetOutput(i), img); out.push_back(img); } } int main(int argc, char* argv[]) { mitkCommandLineParser parser; parser.setArgumentPrefix("--", "-"); // required params parser.addArgument("image", "i", mitkCommandLineParser::InputImage, "Input Image", "Path to the input VTK polydata", us::Any(), false); parser.addArgument("output", "o", mitkCommandLineParser::OutputFile, "Output text file", "Target file. The output statistic is appended to this file.", us::Any(), false); + parser.addArgument("extension", "e", mitkCommandLineParser::OutputFile, "Extension", "File extension. Default is .nii.gz", us::Any(), true); parser.addArgument("gaussian","g",mitkCommandLineParser::String, "Gaussian Filtering of the input images", "Gaussian Filter. Followed by the used variances seperated by ';' ",us::Any()); parser.addArgument("difference-of-gaussian","dog",mitkCommandLineParser::String, "Difference of Gaussian Filtering of the input images", "Difference of Gaussian Filter. Followed by the used variances seperated by ';' ",us::Any()); parser.addArgument("laplace-of-gauss","log",mitkCommandLineParser::String, "Laplacian of Gaussian Filtering", "Laplacian of Gaussian Filter. Followed by the used variances seperated by ';' ",us::Any()); parser.addArgument("hessian-of-gauss","hog",mitkCommandLineParser::String, "Hessian of Gaussian Filtering", "Hessian of Gaussian Filter. Followed by the used variances seperated by ';' ",us::Any()); parser.addArgument("local-histogram", "lh", mitkCommandLineParser::String, "Local Histograms", "Calculate the local histogram based feature. Specify Offset and Delta, for exampel -3;0.6 ", us::Any()); // Miniapp Infos parser.setCategory("Classification Tools"); parser.setTitle("Global Image Feature calculator"); parser.setDescription("Calculates different global statistics for a given segmentation / image combination"); parser.setContributor("MBI"); std::map parsedArgs = parser.parseArguments(argc, argv); if (parsedArgs.size()==0) { return EXIT_FAILURE; } if ( parsedArgs.count("help") || parsedArgs.count("h")) { return EXIT_SUCCESS; } bool useCooc = parsedArgs.count("cooccurence"); mitk::Image::Pointer image = mitk::IOUtil::LoadImage(parsedArgs["image"].ToString()); std::string filename=parsedArgs["output"].ToString(); + std::string extension = ".nii.gz"; + if (parsedArgs.count("extension")) + { + extension = parsedArgs["extension"].ToString(); + } + //////////////////////////////////////////////////////////////// // CAlculate Gaussian Features //////////////////////////////////////////////////////////////// MITK_INFO << "Check for Local Histogram..."; if (parsedArgs.count("local-histogram")) { std::vector outs; auto ranges = splitDouble(parsedArgs["local-histogram"].ToString(), ';'); if (ranges.size() < 2) { MITK_INFO << "Missing Delta and Offset for Local Histogram"; } else { AccessByItk_3(image, LocalHistograms, outs, ranges[0], ranges[1]); for (int i = 0; i < outs.size(); ++i) { - std::string name = filename + "-lh" + us::any_value_to_string(i)+".nii.gz"; + std::string name = filename + "-lh" + us::any_value_to_string(i)+extension; mitk::IOUtil::SaveImage(outs[i], name); } } } //////////////////////////////////////////////////////////////// // CAlculate Gaussian Features //////////////////////////////////////////////////////////////// MITK_INFO << "Check for Gaussian..."; if (parsedArgs.count("gaussian")) { MITK_INFO << "Calculate Gaussian... " << parsedArgs["gaussian"].ToString(); auto ranges = splitDouble(parsedArgs["gaussian"].ToString(),';'); for (int i = 0; i < ranges.size(); ++i) { + MITK_INFO << "Gaussian with sigma: " << ranges[i]; mitk::Image::Pointer output; AccessByItk_2(image, GaussianFilter, ranges[i], output); - std::string name = filename + "-gaussian-" + us::any_value_to_string(ranges[i])+".nii.gz"; + MITK_INFO << "Write output:"; + std::string name = filename + "-gaussian-" + us::any_value_to_string(ranges[i]) + extension; mitk::IOUtil::SaveImage(output, name); } } //////////////////////////////////////////////////////////////// // CAlculate Difference of Gaussian Features //////////////////////////////////////////////////////////////// MITK_INFO << "Check for DoG..."; if (parsedArgs.count("difference-of-gaussian")) { MITK_INFO << "Calculate Difference of Gaussian... " << parsedArgs["difference-of-gaussian"].ToString(); auto ranges = splitDouble(parsedArgs["difference-of-gaussian"].ToString(),';'); for (int i = 0; i < ranges.size(); ++i) { mitk::Image::Pointer output; AccessByItk_2(image, DifferenceOfGaussFilter, ranges[i], output); - std::string name = filename + "-dog-" + us::any_value_to_string(ranges[i])+".nii.gz"; + std::string name = filename + "-dog-" + us::any_value_to_string(ranges[i]) + extension; mitk::IOUtil::SaveImage(output, name); } } MITK_INFO << "Check for LoG..."; //////////////////////////////////////////////////////////////// // CAlculate Laplacian Of Gauss Features //////////////////////////////////////////////////////////////// if (parsedArgs.count("laplace-of-gauss")) { MITK_INFO << "Calculate LoG... " << parsedArgs["laplace-of-gauss"].ToString(); auto ranges = splitDouble(parsedArgs["laplace-of-gauss"].ToString(),';'); for (int i = 0; i < ranges.size(); ++i) { mitk::Image::Pointer output; AccessByItk_2(image, LaplacianOfGaussianFilter, ranges[i], output); - std::string name = filename + "-log-" + us::any_value_to_string(ranges[i])+".nii.gz"; + std::string name = filename + "-log-" + us::any_value_to_string(ranges[i]) + extension; mitk::IOUtil::SaveImage(output, name); } } MITK_INFO << "Check for HoG..."; //////////////////////////////////////////////////////////////// // CAlculate Hessian Of Gauss Features //////////////////////////////////////////////////////////////// if (parsedArgs.count("hessian-of-gauss")) { MITK_INFO << "Calculate HoG... " << parsedArgs["hessian-of-gauss"].ToString(); auto ranges = splitDouble(parsedArgs["hessian-of-gauss"].ToString(),';'); for (int i = 0; i < ranges.size(); ++i) { std::vector outs; outs.push_back(mitk::Image::New()); outs.push_back(mitk::Image::New()); outs.push_back(mitk::Image::New()); AccessByItk_2(image, HessianOfGaussianFilter, ranges[i], outs); - std::string name = filename + "-hog0-" + us::any_value_to_string(ranges[i])+".nii.gz"; + std::string name = filename + "-hog0-" + us::any_value_to_string(ranges[i]) + extension; mitk::IOUtil::SaveImage(outs[0], name); - name = filename + "-hog1-" + us::any_value_to_string(ranges[i])+".nii.gz"; + name = filename + "-hog1-" + us::any_value_to_string(ranges[i]) + extension; mitk::IOUtil::SaveImage(outs[1], name); - name = filename + "-hog2-" + us::any_value_to_string(ranges[i])+".nii.gz"; + name = filename + "-hog2-" + us::any_value_to_string(ranges[i]) + extension; mitk::IOUtil::SaveImage(outs[2], name); } } return 0; } #endif \ No newline at end of file diff --git a/Modules/Classification/CLMiniApps/CMakeLists.txt b/Modules/Classification/CLMiniApps/CMakeLists.txt index b163d6dcb6..c093d569b4 100644 --- a/Modules/Classification/CLMiniApps/CMakeLists.txt +++ b/Modules/Classification/CLMiniApps/CMakeLists.txt @@ -1,62 +1,135 @@ option(BUILD_ClassificationMiniApps "Build commandline tools for classification" OFF) if(BUILD_ClassificationMiniApps OR MITK_BUILD_ALL_APPS) include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ) # list of miniapps # if an app requires additional dependencies # they are added after a "^^" and separated by "_" set( classificationminiapps RandomForestTraining^^MitkCLVigraRandomForest NativeHeadCTSegmentation^^MitkCLVigraRandomForest ManualSegmentationEvaluation^^MitkCLVigraRandomForest - CLGlobalImageFeatures^^MitkCLUtilities - CLMRNormalization^^MitkCLUtilities_MitkCLMRUtilities - CLStaple^^MitkCLUtilities - CLVoxelFeatures^^MitkCLUtilities - CLDicom2Nrrd^^ - CLPolyToNrrd^^ - CLSimpleVoxelClassification^^MitkDataCollection_MitkCLVigraRandomForest - CLVoxelClassification^^MitkDataCollection_MitkCLImportanceWeighting_MitkCLVigraRandomForest - CLBrainMask^^MitkCLUtilities + CLGlobalImageFeatures^^MitkCore_MitkCLUtilities + CLMRNormalization^^MitkCore_MitkCLUtilities_MitkCLMRUtilities + CLStaple^^MitkCore_MitkCLUtilities + CLVoxelFeatures^^MitkCore_MitkCLUtilities + CLDicom2Nrrd^^MitkCore + CLPolyToNrrd^^MitkCore + CLImageTypeConverter^^MitkCore + CLResampleImageToReference^^MitkCore + CLRandomSampling^^MitkCore_MitkCLUtilities + CLRemoveEmptyVoxels^^MitkCore + CLN4^^MitkCore + CLMultiForestPrediction^^MitkDataCollection_MitkCLVigraRandomForest + CLNrrdToPoly^^MitkCore # RandomForestPrediction^^MitkCLVigraRandomForest ) foreach(classificationminiapps ${classificationminiapps}) # extract mini app name and dependencies string(REPLACE "^^" "\\;" miniapp_info ${classificationminiapps}) set(miniapp_info_list ${miniapp_info}) list(GET miniapp_info_list 0 appname) list(GET miniapp_info_list 1 raw_dependencies) string(REPLACE "_" "\\;" dependencies "${raw_dependencies}") set(dependencies_list ${dependencies}) - mitkFunctionCreateCommandLineApp( - NAME ${appname} - DEPENDS MitkCore MitkCLCore ${dependencies_list} - PACKAGE_DEPENDS Vigra Qt5|Core + mitk_create_executable(${appname} + DEPENDS MitkCore MitkCLCore ${dependencies_list} + PACKAGE_DEPENDS ITK Qt4|QtCore Qt5|Core Vigra + CPP_FILES ${appname}.cpp mitkCommandLineParser.cpp ) + + if(EXECUTABLE_IS_ENABLED) + + # On Linux, create a shell script to start a relocatable application + if(UNIX AND NOT APPLE) + install(PROGRAMS "${MITK_SOURCE_DIR}/CMake/RunInstalledApp.sh" DESTINATION "." RENAME ${EXECUTABLE_TARGET}.sh) + endif() + + get_target_property(_is_bundle ${EXECUTABLE_TARGET} MACOSX_BUNDLE) + + if(APPLE) + if(_is_bundle) + set(_target_locations ${EXECUTABLE_TARGET}.app) + set(${_target_locations}_qt_plugins_install_dir ${EXECUTABLE_TARGET}.app/Contents/MacOS) + set(_bundle_dest_dir ${EXECUTABLE_TARGET}.app/Contents/MacOS) + set(_qt_plugins_for_current_bundle ${EXECUTABLE_TARGET}.app/Contents/MacOS) + set(_qt_conf_install_dirs ${EXECUTABLE_TARGET}.app/Contents/Resources) + install(TARGETS ${EXECUTABLE_TARGET} BUNDLE DESTINATION . ) + else() + if(NOT MACOSX_BUNDLE_NAMES) + set(_qt_conf_install_dirs bin) + set(_target_locations bin/${EXECUTABLE_TARGET}) + set(${_target_locations}_qt_plugins_install_dir bin) + install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION bin) + else() + foreach(bundle_name ${MACOSX_BUNDLE_NAMES}) + list(APPEND _qt_conf_install_dirs ${bundle_name}.app/Contents/Resources) + set(_current_target_location ${bundle_name}.app/Contents/MacOS/${EXECUTABLE_TARGET}) + list(APPEND _target_locations ${_current_target_location}) + set(${_current_target_location}_qt_plugins_install_dir ${bundle_name}.app/Contents/MacOS) + message( " set(${_current_target_location}_qt_plugins_install_dir ${bundle_name}.app/Contents/MacOS) ") + + install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION ${bundle_name}.app/Contents/MacOS/) + endforeach() + endif() + endif() + else() + set(_target_locations bin/${EXECUTABLE_TARGET}${CMAKE_EXECUTABLE_SUFFIX}) + set(${_target_locations}_qt_plugins_install_dir bin) + set(_qt_conf_install_dirs bin) + install(TARGETS ${EXECUTABLE_TARGET} RUNTIME DESTINATION bin) + endif() + endif() endforeach() - # This mini app does not depend on MitkCLCore at all - mitkFunctionCreateCommandLineApp( - NAME CLImageConverter - DEPENDS MitkCore ${dependencies_list} - ) + # This mini app does not depend on mitkDiffusionImaging at all - mitkFunctionCreateCommandLineApp( - NAME CLSurWeighting - DEPENDS MitkCore MitkCLUtilities MitkDataCollection MitkCLImportanceWeighting ${dependencies_list} - ) + #mitk_create_executable(CLGlobalImageFeatures + # DEPENDS MitkCore MitkCLUtilities + # CPP_FILES CLGlobalImageFeatures.cpp mitkCommandLineParser.cpp + #) - mitkFunctionCreateCommandLineApp( - NAME CLImageCropper - DEPENDS MitkCore MitkCLUtilities MitkAlgorithmsExt ${dependencies_list} + mitk_create_executable(CLSimpleVoxelClassification + DEPENDS MitkCore MitkCLCore MitkDataCollection MitkCLVigraRandomForest MitkCommandLine + CPP_FILES CLSimpleVoxelClassification.cpp ) + # This mini app does not depend on mitkDiffusionImaging at all + mitk_create_executable(CLVoxelClassification + DEPENDS MitkCore MitkCLCore MitkDataCollection MitkCLImportanceWeighting MitkCLVigraRandomForest + CPP_FILES CLVoxelClassification.cpp + ) + mitk_create_executable(CLBrainMask + DEPENDS MitkCore MitkCLUtilities + CPP_FILES CLBrainMask.cpp mitkCommandLineParser.cpp + ) + mitk_create_executable(CLImageConverter + DEPENDS MitkCore + CPP_FILES CLImageConverter.cpp mitkCommandLineParser.cpp + ) + mitk_create_executable(CLSurWeighting + DEPENDS MitkCore MitkCLUtilities MitkDataCollection MitkCLImportanceWeighting + CPP_FILES CLSurWeighting.cpp mitkCommandLineParser.cpp + ) + mitk_create_executable(CLImageCropper + DEPENDS MitkCore MitkCLUtilities MitkAlgorithmsExt + CPP_FILES CLImageCropper.cpp mitkCommandLineParser.cpp + ) + + # On Linux, create a shell script to start a relocatable application + if(UNIX AND NOT APPLE) + install(PROGRAMS "${MITK_SOURCE_DIR}/CMake/RunInstalledApp.sh" DESTINATION "." RENAME ${EXECUTABLE_TARGET}.sh) + endif() + + if(EXECUTABLE_IS_ENABLED) + MITK_INSTALL_TARGETS(EXECUTABLES ${EXECUTABLE_TARGET}) + endif() endif() diff --git a/Modules/Classification/CLUtilities/files.cmake b/Modules/Classification/CLUtilities/files.cmake index cfcdfcfe53..35294a897a 100644 --- a/Modules/Classification/CLUtilities/files.cmake +++ b/Modules/Classification/CLUtilities/files.cmake @@ -1,24 +1,25 @@ file(GLOB_RECURSE H_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/include/*") set(CPP_FILES Algorithms/itkLabelSampler.cpp Algorithms/itkSmoothedClassProbabilites.cpp + Algorithms/mitkRandomImageSampler.cpp Features/itkNeighborhoodFunctorImageFilter.cpp Features/itkLineHistogramBasedMassImageFilter.cpp GlobalImageFeatures/mitkGIFCooccurenceMatrix.cpp GlobalImageFeatures/mitkGIFGrayLevelRunLength.cpp GlobalImageFeatures/mitkGIFFirstOrderStatistics.cpp GlobalImageFeatures/mitkGIFVolumetricStatistics.cpp #GlobalImageFeatures/itkEnhancedScalarImageToRunLengthFeaturesFilter.hxx #GlobalImageFeatures/itkEnhancedScalarImageToRunLengthMatrixFilter.hxx #GlobalImageFeatures/itkEnhancedHistogramToRunLengthFeaturesFilter.hxx #GlobalImageFeatures/itkEnhancedHistogramToTextureFeaturesFilter.hxx #GlobalImageFeatures/itkEnhancedScalarImageToTextureFeaturesFilter.hxx mitkCLUtil.cpp ) set( TOOL_FILES ) diff --git a/Modules/Classification/CLUtilities/include/itkEnhancedHistogramToTextureFeaturesFilter.hxx b/Modules/Classification/CLUtilities/include/itkEnhancedHistogramToTextureFeaturesFilter.hxx index 78e9de579f..a257e6735d 100644 --- a/Modules/Classification/CLUtilities/include/itkEnhancedHistogramToTextureFeaturesFilter.hxx +++ b/Modules/Classification/CLUtilities/include/itkEnhancedHistogramToTextureFeaturesFilter.hxx @@ -1,658 +1,656 @@ /*========================================================================= * * Copyright Insight Software Consortium * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0.txt * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * *=========================================================================*/ #ifndef __itkEnhancedHistogramToTextureFeaturesFilter_hxx #define __itkEnhancedHistogramToTextureFeaturesFilter_hxx #include "itkEnhancedHistogramToTextureFeaturesFilter.h" #include "itkNumericTraits.h" #include "vnl/vnl_math.h" #include "itkMath.h" #define itkMacroGLCMFeature(name, id) \ template< typename THistogram > \ const \ typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * \ EnhancedHistogramToTextureFeaturesFilter< THistogram > \ ::Get##name##Output() const \ { \ return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(id) ); \ } \ \ template< typename THistogram > \ typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType \ EnhancedHistogramToTextureFeaturesFilter< THistogram > \ ::Get##name() const \ { \ return this->Get##name##Output()->Get(); \ } namespace itk { namespace Statistics { //constructor template< typename THistogram > EnhancedHistogramToTextureFeaturesFilter< THistogram >::EnhancedHistogramToTextureFeaturesFilter(void) { this->ProcessObject::SetNumberOfRequiredInputs(1); // allocate the data objects for the outputs which are // just decorators real types for ( int i = 0; i < 25; ++i ) { this->ProcessObject::SetNthOutput( i, this->MakeOutput(i) ); } } template< typename THistogram > void EnhancedHistogramToTextureFeaturesFilter< THistogram > ::SetInput(const HistogramType *histogram) { this->ProcessObject::SetNthInput( 0, const_cast< HistogramType * >( histogram ) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::HistogramType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetInput() const { return itkDynamicCastInDebugMode< const HistogramType * >( this->GetPrimaryInput() ); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::DataObjectPointer EnhancedHistogramToTextureFeaturesFilter< THistogram > ::MakeOutput( DataObjectPointerArraySizeType itkNotUsed(idx) ) { return MeasurementObjectType::New().GetPointer(); } template< typename THistogram > void EnhancedHistogramToTextureFeaturesFilter< THistogram >::GenerateData(void) { typedef typename HistogramType::ConstIterator HistogramIterator; const HistogramType *inputHistogram = this->GetInput(); //Normalize the absolute frequencies and populate the relative frequency //container TotalRelativeFrequencyType totalFrequency = static_cast< TotalRelativeFrequencyType >( inputHistogram->GetTotalFrequency() ); m_RelativeFrequencyContainer.clear(); typename HistogramType::SizeValueType binsPerAxis = inputHistogram->GetSize(0); std::vector sumP; std::vector diffP; sumP.resize(2*binsPerAxis,0.0); diffP.resize(binsPerAxis,0.0); double numberOfPixels = 0; for ( HistogramIterator hit = inputHistogram->Begin(); hit != inputHistogram->End(); ++hit ) { AbsoluteFrequencyType frequency = hit.GetFrequency(); RelativeFrequencyType relativeFrequency = (totalFrequency > 0) ? frequency / totalFrequency : 0; m_RelativeFrequencyContainer.push_back(relativeFrequency); IndexType index = inputHistogram->GetIndex( hit.GetInstanceIdentifier() ); sumP[index[0] + index[1]] += relativeFrequency; diffP[std::abs(index[0] - index[1])] += relativeFrequency; //if (index[1] == 0) numberOfPixels += frequency; } // Now get the various means and variances. This is takes two passes // through the histogram. double pixelMean; double marginalMean; double marginalDevSquared; double pixelVariance; this->ComputeMeansAndVariances(pixelMean, marginalMean, marginalDevSquared, pixelVariance); // Finally compute the texture features. Another one pass. MeasurementType energy = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType entropy = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType correlation = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType inverseDifferenceMoment = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType inertia = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType clusterShade = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType clusterProminence = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType haralickCorrelation = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType autocorrelation = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType contrast = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType dissimilarity = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType maximumProbability = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType inverseVariance = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType homogeneity1 = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType clusterTendency = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType variance = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType sumAverage = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType sumEntropy = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType sumVariance = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType diffAverage = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType diffEntropy = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType diffVariance = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType inverseDifferenceMomentNormalized = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType inverseDifferenceNormalized = NumericTraits< MeasurementType >::ZeroValue(); MeasurementType inverseDifference = NumericTraits< MeasurementType >::ZeroValue(); double pixelVarianceSquared = pixelVariance * pixelVariance; // Variance is only used in correlation. If variance is 0, then // (index[0] - pixelMean) * (index[1] - pixelMean) // should be zero as well. In this case, set the variance to 1. in // order to avoid NaN correlation. if( Math::FloatAlmostEqual( pixelVarianceSquared, 0.0, 4, 2*NumericTraits::epsilon() ) ) { pixelVarianceSquared = 1.; } const double log2 = std::log(2.0); typename RelativeFrequencyContainerType::const_iterator rFreqIterator = m_RelativeFrequencyContainer.begin(); - MITK_INFO << pixelMean << " - " << pixelVariance; - for ( HistogramIterator hit = inputHistogram->Begin(); hit != inputHistogram->End(); ++hit ) { RelativeFrequencyType frequency = *rFreqIterator; ++rFreqIterator; if ( frequency == 0 ) { continue; // no use doing these calculations if we're just multiplying by // zero. } IndexType index = inputHistogram->GetIndex( hit.GetInstanceIdentifier() ); energy += frequency * frequency; entropy -= ( frequency > 0.0001 ) ? frequency *std::log(frequency) / log2:0; correlation += ( ( index[0] - pixelMean ) * ( index[1] - pixelMean ) * frequency ) / pixelVarianceSquared; inverseDifferenceMoment += frequency / ( 1.0 + ( index[0] - index[1] ) * ( index[0] - index[1] ) ); inertia += ( index[0] - index[1] ) * ( index[0] - index[1] ) * frequency; clusterShade += std::pow( ( index[0] - pixelMean ) + ( index[1] - pixelMean ), 3 ) * frequency; clusterProminence += std::pow( ( index[0] - pixelMean ) + ( index[1] - pixelMean ), 4 ) * frequency; haralickCorrelation += index[0] * index[1] * frequency; // New Features added for Aerts compatibility autocorrelation +=index[0] * index[1] * frequency; contrast += (index[0] - index[1]) * (index[0] - index[1]) * frequency; dissimilarity += std::abs(index[0] - index[1]) * frequency; maximumProbability = std::max(maximumProbability, frequency); if (index[0] != index[1]) inverseVariance += frequency / ((index[0] - index[1])*(index[0] - index[1])); homogeneity1 +=frequency / ( 1.0 + std::abs( index[0] - index[1] )); clusterTendency += std::pow( ( index[0] - pixelMean ) + ( index[1] - pixelMean ), 2 ) * frequency; variance += std::pow( ( index[0] - pixelMean ), 2) * frequency; if (numberOfPixels > 0) { inverseDifferenceMomentNormalized += frequency / ( 1.0 + ( index[0] - index[1] ) * ( index[0] - index[1] ) / numberOfPixels / numberOfPixels); inverseDifferenceNormalized += frequency / ( 1.0 + std::abs( index[0] - index[1] ) / numberOfPixels ); } else { inverseDifferenceMomentNormalized = 0; inverseDifferenceNormalized = 0; } inverseDifference += frequency / ( 1.0 + std::abs( index[0] - index[1] ) ); } for (int i = 0; i < (int)sumP.size();++i) { double frequency = sumP[i]; sumAverage += i * frequency; sumEntropy -= ( frequency > 0.0001 ) ? frequency *std::log(frequency) / log2:0; } for (int i = 0; i < (int)sumP.size();++i) { double frequency = sumP[i]; sumVariance += (i-sumAverage)*(i-sumAverage) * frequency; } for (int i = 0; i < (int)diffP.size();++i) { double frequency = diffP[i]; diffAverage += i * frequency; diffEntropy -= ( frequency > 0.0001 ) ? frequency *std::log(frequency) / log2:0; } for (int i = 0; i < (int)diffP.size();++i) { double frequency = diffP[i]; sumVariance += (i-diffAverage)*(i-diffAverage) * frequency; } if (marginalDevSquared > 0) { haralickCorrelation = ( haralickCorrelation - marginalMean * marginalMean ) / marginalDevSquared; } else { haralickCorrelation =0; } MeasurementObjectType *energyOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(0) ); energyOutputObject->Set(energy); MeasurementObjectType *entropyOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(1) ); entropyOutputObject->Set(entropy); MeasurementObjectType *correlationOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(2) ); correlationOutputObject->Set(correlation); MeasurementObjectType *inverseDifferenceMomentOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(3) ); inverseDifferenceMomentOutputObject->Set(inverseDifferenceMoment); MeasurementObjectType *inertiaOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(4) ); inertiaOutputObject->Set(inertia); MeasurementObjectType *clusterShadeOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(5) ); clusterShadeOutputObject->Set(clusterShade); MeasurementObjectType *clusterProminenceOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(6) ); clusterProminenceOutputObject->Set(clusterProminence); MeasurementObjectType *haralickCorrelationOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(7) ); haralickCorrelationOutputObject->Set(haralickCorrelation); MeasurementObjectType *autocorrelationOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(8) ); autocorrelationOutputObject->Set(autocorrelation); MeasurementObjectType *contrastOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(9) ); contrastOutputObject->Set(contrast); MeasurementObjectType *dissimilarityOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(10) ); dissimilarityOutputObject->Set(dissimilarity); MeasurementObjectType *maximumProbabilityOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(11) ); maximumProbabilityOutputObject->Set(maximumProbability); MeasurementObjectType *inverseVarianceOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(12) ); inverseVarianceOutputObject->Set(inverseVariance); MeasurementObjectType *homogeneity1OutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(13) ); homogeneity1OutputObject->Set(homogeneity1); MeasurementObjectType *clusterTendencyOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(14) ); clusterTendencyOutputObject->Set(clusterTendency); MeasurementObjectType *varianceOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(15) ); varianceOutputObject->Set(variance); MeasurementObjectType *sumAverageOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(16) ); sumAverageOutputObject->Set(sumAverage); MeasurementObjectType *sumEntropyOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(17) ); sumEntropyOutputObject->Set(sumEntropy); MeasurementObjectType *sumVarianceOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(18) ); sumVarianceOutputObject->Set(sumVariance); MeasurementObjectType *diffAverageOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(19) ); diffAverageOutputObject->Set(diffAverage); MeasurementObjectType *diffEntropyOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(20) ); diffEntropyOutputObject->Set(diffEntropy); MeasurementObjectType *diffVarianceOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(21) ); diffVarianceOutputObject->Set(diffVariance); MeasurementObjectType *inverseDifferenceMomentNormalizedOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(22) ); inverseDifferenceMomentNormalizedOutputObject->Set(inverseDifferenceMomentNormalized); MeasurementObjectType *inverseDifferenceNormalizedOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(23) ); inverseDifferenceNormalizedOutputObject->Set(inverseDifferenceNormalized); MeasurementObjectType *inverseDifferenceOutputObject = static_cast< MeasurementObjectType * >( this->ProcessObject::GetOutput(24) ); inverseDifferenceOutputObject->Set(inverseDifference); } template< typename THistogram > void EnhancedHistogramToTextureFeaturesFilter< THistogram >::ComputeMeansAndVariances(double & pixelMean, double & marginalMean, double & marginalDevSquared, double & pixelVariance) { // This function takes two passes through the histogram and two passes through // an array of the same length as a histogram axis. This could probably be // cleverly compressed to one pass, but it's not clear that that's necessary. typedef typename HistogramType::ConstIterator HistogramIterator; const HistogramType *inputHistogram = this->GetInput(); // Initialize everything typename HistogramType::SizeValueType binsPerAxis = inputHistogram->GetSize(0); double *marginalSums = new double[binsPerAxis]; for ( double *ms_It = marginalSums; ms_It < marginalSums + binsPerAxis; ms_It++ ) { *ms_It = 0; } pixelMean = 0; typename RelativeFrequencyContainerType::const_iterator rFreqIterator = m_RelativeFrequencyContainer.begin(); // Ok, now do the first pass through the histogram to get the marginal sums // and compute the pixel mean HistogramIterator hit = inputHistogram->Begin(); while ( hit != inputHistogram->End() ) { RelativeFrequencyType frequency = *rFreqIterator; IndexType index = inputHistogram->GetIndex( hit.GetInstanceIdentifier() ); pixelMean += index[0] * frequency; marginalSums[index[0]] += frequency; ++hit; ++rFreqIterator; } /* Now get the mean and deviaton of the marginal sums. Compute incremental mean and SD, a la Knuth, "The Art of Computer Programming, Volume 2: Seminumerical Algorithms", section 4.2.2. Compute mean and standard deviation using the recurrence relation: M(1) = x(1), M(k) = M(k-1) + (x(k) - M(k-1) ) / k S(1) = 0, S(k) = S(k-1) + (x(k) - M(k-1)) * (x(k) - M(k)) for 2 <= k <= n, then sigma = std::sqrt(S(n) / n) (or divide by n-1 for sample SD instead of population SD). */ marginalMean = marginalSums[0]; marginalDevSquared = 0; for ( unsigned int arrayIndex = 1; arrayIndex < binsPerAxis; arrayIndex++ ) { int k = arrayIndex + 1; double M_k_minus_1 = marginalMean; double S_k_minus_1 = marginalDevSquared; double x_k = marginalSums[arrayIndex]; double M_k = M_k_minus_1 + ( x_k - M_k_minus_1 ) / k; double S_k = S_k_minus_1 + ( x_k - M_k_minus_1 ) * ( x_k - M_k ); marginalMean = M_k; marginalDevSquared = S_k; } marginalDevSquared = marginalDevSquared / binsPerAxis; rFreqIterator = m_RelativeFrequencyContainer.begin(); // OK, now compute the pixel variances. pixelVariance = 0; for ( hit = inputHistogram->Begin(); hit != inputHistogram->End(); ++hit ) { RelativeFrequencyType frequency = *rFreqIterator; IndexType index = inputHistogram->GetIndex( hit.GetInstanceIdentifier() ); pixelVariance += ( index[0] - pixelMean ) * ( index[0] - pixelMean ) * frequency; ++rFreqIterator; } delete[] marginalSums; } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetEnergyOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(0) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetEntropyOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(1) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetCorrelationOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(2) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetInverseDifferenceMomentOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(3) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetInertiaOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(4) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetClusterShadeOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(5) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetClusterProminenceOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(6) ); } template< typename THistogram > const typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementObjectType * EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetHaralickCorrelationOutput() const { return static_cast< const MeasurementObjectType * >( this->ProcessObject::GetOutput(7) ); } itkMacroGLCMFeature(Autocorrelation,8) itkMacroGLCMFeature(Contrast,9) itkMacroGLCMFeature(Dissimilarity,10) itkMacroGLCMFeature(MaximumProbability,11) itkMacroGLCMFeature(InverseVariance,12) itkMacroGLCMFeature(Homogeneity1,13) itkMacroGLCMFeature(ClusterTendency,14) itkMacroGLCMFeature(Variance,15) itkMacroGLCMFeature(SumAverage,16) itkMacroGLCMFeature(SumEntropy,17) itkMacroGLCMFeature(SumVariance,18) itkMacroGLCMFeature(DifferenceAverage,19) itkMacroGLCMFeature(DifferenceEntropy,20) itkMacroGLCMFeature(DifferenceVariance,21) itkMacroGLCMFeature(InverseDifferenceMomentNormalized,22) itkMacroGLCMFeature(InverseDifferenceNormalized,23) itkMacroGLCMFeature(InverseDifference,24) template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetEnergy() const { return this->GetEnergyOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetEntropy() const { return this->GetEntropyOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetCorrelation() const { return this->GetCorrelationOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetInverseDifferenceMoment() const { return this->GetInverseDifferenceMomentOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetInertia() const { return this->GetInertiaOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetClusterShade() const { return this->GetClusterShadeOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetClusterProminence() const { return this->GetClusterProminenceOutput()->Get(); } template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetHaralickCorrelation() const { return this->GetHaralickCorrelationOutput()->Get(); } #define itkMacroGLCMFeatureSwitch(name) \ case name : \ return this->Get##name() template< typename THistogram > typename EnhancedHistogramToTextureFeaturesFilter< THistogram >::MeasurementType EnhancedHistogramToTextureFeaturesFilter< THistogram > ::GetFeature(TextureFeatureName feature) { switch ( feature ) { itkMacroGLCMFeatureSwitch(Energy); itkMacroGLCMFeatureSwitch(Entropy); itkMacroGLCMFeatureSwitch(Correlation); itkMacroGLCMFeatureSwitch(InverseDifferenceMoment); itkMacroGLCMFeatureSwitch(Inertia); itkMacroGLCMFeatureSwitch(ClusterShade); itkMacroGLCMFeatureSwitch(ClusterProminence); itkMacroGLCMFeatureSwitch(HaralickCorrelation); itkMacroGLCMFeatureSwitch(Autocorrelation); itkMacroGLCMFeatureSwitch(Contrast); itkMacroGLCMFeatureSwitch(Dissimilarity); itkMacroGLCMFeatureSwitch(MaximumProbability); itkMacroGLCMFeatureSwitch(InverseVariance); itkMacroGLCMFeatureSwitch(Homogeneity1); itkMacroGLCMFeatureSwitch(ClusterTendency); itkMacroGLCMFeatureSwitch(Variance); itkMacroGLCMFeatureSwitch(SumAverage); itkMacroGLCMFeatureSwitch(SumEntropy); itkMacroGLCMFeatureSwitch(SumVariance); itkMacroGLCMFeatureSwitch(DifferenceAverage); itkMacroGLCMFeatureSwitch(DifferenceEntropy); itkMacroGLCMFeatureSwitch(DifferenceVariance); itkMacroGLCMFeatureSwitch(InverseDifferenceMomentNormalized); itkMacroGLCMFeatureSwitch(InverseDifferenceNormalized); itkMacroGLCMFeatureSwitch(InverseDifference); default: return 0; } } #undef itkMacroGLCMFeatureSwitch template< typename THistogram > void EnhancedHistogramToTextureFeaturesFilter< THistogram > ::PrintSelf(std::ostream & os, Indent indent) const { Superclass::PrintSelf(os, indent); } } // end of namespace Statistics } // end of namespace itk #endif \ No newline at end of file diff --git a/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.cpp b/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.cpp index 95e0e12cee..0b8fb73903 100644 --- a/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.cpp +++ b/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.cpp @@ -1,96 +1,105 @@ #ifndef itkMultiHistogramFilter_cpp #define itkMultiHistogramFilter_cpp #include #include +#include #include template< class TInputImageType, class TOuputImageType> itk::MultiHistogramFilter::MultiHistogramFilter(): m_Offset(-3.0), m_Delta(0.6) { this->SetNumberOfRequiredOutputs(11); this->SetNumberOfRequiredInputs(0); for (int i = 0; i < 11; ++i) { this->SetNthOutput( i, this->MakeOutput(i) ); } } template< class TInputImageType, class TOuputImageType> void - itk::MultiHistogramFilter::GenerateData() +itk::MultiHistogramFilter::BeforeThreadedGenerateData() { - double offset = m_Offset;// -3.0; - double delta = m_Delta;// 0.6; - - typedef itk::NeighborhoodIterator IteratorType; - typedef itk::ConstNeighborhoodIterator ConstIteratorType; - +// MITK_INFO << "Creating output images"; InputImagePointer input = this->GetInput(0); CreateOutputImage(input, this->GetOutput(0)); CreateOutputImage(input, this->GetOutput(1)); CreateOutputImage(input, this->GetOutput(2)); CreateOutputImage(input, this->GetOutput(3)); CreateOutputImage(input, this->GetOutput(4)); CreateOutputImage(input, this->GetOutput(5)); CreateOutputImage(input, this->GetOutput(6)); CreateOutputImage(input, this->GetOutput(7)); CreateOutputImage(input, this->GetOutput(8)); CreateOutputImage(input, this->GetOutput(9)); CreateOutputImage(input, this->GetOutput(10)); +} +template< class TInputImageType, class TOuputImageType> +void +itk::MultiHistogramFilter::ThreadedGenerateData(const OutputImageRegionType & outputRegionForThread, ThreadIdType threadId) +{ + double offset = m_Offset;// -3.0; + double delta = m_Delta;// 0.6; + + typedef itk::ImageRegionIterator IteratorType; + typedef itk::ConstNeighborhoodIterator ConstIteratorType; + + InputImagePointer input = this->GetInput(0); +// MITK_INFO << "Creating output iterator"; typename TInputImageType::SizeType size; size.Fill(5); std::vector iterVector; for (int i = 0; i < 11; ++i) { - IteratorType iter(size, this->GetOutput(i), this->GetOutput(i)->GetLargestPossibleRegion()); + IteratorType iter(this->GetOutput(i), outputRegionForThread); iterVector.push_back(iter); } - ConstIteratorType inputIter( size, input, input->GetLargestPossibleRegion()); + ConstIteratorType inputIter(size, input, outputRegionForThread); while (!inputIter.IsAtEnd()) { for (int i = 0; i < 11; ++i) { - iterVector[i].SetCenterPixel(0); + iterVector[i].Set(0); } for (int i = 0; i < inputIter.Size(); ++i) { double value = inputIter.GetPixel(i); value -= offset; value /= delta; int pos = (int)(value); pos = std::max(0, std::min(10, pos)); - iterVector[pos].SetCenterPixel(iterVector[pos].GetCenterPixel() + 1); + iterVector[pos].Value() += 1;// (iterVector[pos].GetCenterPixel() + 1); } for (int i = 0; i < 11; ++i) { ++(iterVector[i]); } ++inputIter; } } template< class TInputImageType, class TOuputImageType> itk::DataObject::Pointer itk::MultiHistogramFilter::MakeOutput(unsigned int /*idx*/) { DataObject::Pointer output; output = ( TOuputImageType::New() ).GetPointer(); return output.GetPointer(); } template< class TInputImageType, class TOuputImageType> void itk::MultiHistogramFilter::CreateOutputImage(InputImagePointer input, OutputImagePointer output) { output->SetRegions(input->GetLargestPossibleRegion()); output->Allocate(); } -#endif //itkMultiHistogramFilter_cpp \ No newline at end of file +#endif //itkMultiHistogramFilter_cpp diff --git a/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.h b/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.h index 120482e9e1..fc1752806c 100644 --- a/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.h +++ b/Modules/Classification/CLUtilities/include/itkMultiHistogramFilter.h @@ -1,50 +1,53 @@ #ifndef itkMultiHistogramFilter_h #define itkMultiHistogramFilter_h #include "itkImageToImageFilter.h" namespace itk { template class MultiHistogramFilter : public ImageToImageFilter< TInputImageType, TOuputImageType> { public: typedef MultiHistogramFilter Self; typedef ImageToImageFilter< TInputImageType, TOuputImageType > Superclass; typedef SmartPointer< Self > Pointer; typedef typename TInputImageType::ConstPointer InputImagePointer; typedef typename TOuputImageType::Pointer OutputImagePointer; + typedef typename TOuputImageType::RegionType OutputImageRegionType; itkNewMacro (Self); itkTypeMacro(MultiHistogramFilter, ImageToImageFilter); itkSetMacro(Delta, double); itkGetConstMacro(Delta, double); itkSetMacro(Offset, double); itkGetConstMacro(Offset, double); protected: MultiHistogramFilter(); ~MultiHistogramFilter(){}; - virtual void GenerateData(); + virtual void ThreadedGenerateData(const OutputImageRegionType & outputRegionForThread, ThreadIdType threadId); + virtual void BeforeThreadedGenerateData(void); + DataObject::Pointer MakeOutput(unsigned int /*idx*/); void CreateOutputImage(InputImagePointer input, OutputImagePointer output); private: MultiHistogramFilter(const Self &); // purposely not implemented void operator=(const Self &); // purposely not implemented double m_Delta; double m_Offset; }; } #ifndef ITK_MANUAL_INSTANTIATION #include "itkMultiHistogramFilter.cpp" #endif #endif // itkMultiHistogramFilter_h diff --git a/Modules/Classification/CLUtilities/include/mitkRandomImageSampler.h b/Modules/Classification/CLUtilities/include/mitkRandomImageSampler.h new file mode 100644 index 0000000000..b737d78330 --- /dev/null +++ b/Modules/Classification/CLUtilities/include/mitkRandomImageSampler.h @@ -0,0 +1,111 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ +#ifndef __mitkRandomImageSampler_h +#define __mitkRandomImageSampler_h + +#include "MitkCLUtilitiesExports.h" + +//MITK +#include +#include "mitkImageToImageFilter.h" +#include + +namespace mitk +{ + enum RandomImageSamplerMode + { + SINGLE_ACCEPTANCE_RATE, + CLASS_DEPENDEND_ACCEPTANCE_RATE, + SINGLE_NUMBER_OF_ACCEPTANCE, + CLASS_DEPENDEND_NUMBER_OF_ACCEPTANCE + }; + + + class MITKCLUTILITIES_EXPORT RandomImageSampler : public ImageToImageFilter + { + public: + + mitkClassMacro( RandomImageSampler , ImageToImageFilter ); + itkFactorylessNewMacro(Self) + itkCloneMacro(Self) + + itkSetMacro(SamplingMode, RandomImageSamplerMode); + itkGetConstMacro(SamplingMode, RandomImageSamplerMode); + + itkSetMacro(AcceptanceRate, double); + itkGetConstMacro(AcceptanceRate, double); + + itkSetMacro(AcceptanceRateVector, std::vector); + itkGetConstMacro(AcceptanceRateVector, std::vector); + + itkSetMacro(NumberOfSamples, unsigned int); + itkGetConstMacro(NumberOfSamples, unsigned int); + + itkSetMacro(NumberOfSamplesVector, std::vector); + itkGetConstMacro(NumberOfSamplesVector, std::vector); + + private: + /*! + \brief standard constructor + */ + RandomImageSampler(); + /*! + \brief standard destructor + */ + ~RandomImageSampler(); + /*! + \brief Method generating the output information of this filter (e.g. image dimension, image type, etc.). + The interface ImageToImageFilter requires this implementation. Everything is taken from the input image. + */ + virtual void GenerateOutputInformation() override; + /*! + \brief Method generating the output of this filter. Called in the updated process of the pipeline. + This method generates the smoothed output image. + */ + virtual void GenerateData() override; + + /*! + \brief Internal templated method calling the ITK bilteral filter. Here the actual filtering is performed. + */ + template + void ItkImageProcessing(const itk::Image* itkImage); + + /*! + \brief Internal templated method calling the ITK bilteral filter. Here the actual filtering is performed. + */ + template + void ItkImageProcessingClassDependendSampling(const itk::Image* itkImage); + + /*! + \brief Internal templated method calling the ITK bilteral filter. Here the actual filtering is performed. + */ + template + void ItkImageProcessingFixedNumberSampling(const itk::Image* itkImage); + + /*! + \brief Internal templated method calling the ITK bilteral filter. Here the actual filtering is performed. + */ + template + void ItkImageProcessingClassDependendNumberSampling(const itk::Image* itkImage); + + double m_AcceptanceRate; + std::vector m_AcceptanceRateVector; + unsigned int m_NumberOfSamples; + std::vector m_NumberOfSamplesVector; + RandomImageSamplerMode m_SamplingMode; + }; +} //END mitk namespace +#endif diff --git a/Modules/Classification/CLUtilities/src/Algorithms/mitkRandomImageSampler.cpp b/Modules/Classification/CLUtilities/src/Algorithms/mitkRandomImageSampler.cpp new file mode 100644 index 0000000000..feedad12b9 --- /dev/null +++ b/Modules/Classification/CLUtilities/src/Algorithms/mitkRandomImageSampler.cpp @@ -0,0 +1,266 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include "mitkRandomImageSampler.h" +#include +#include "mitkImageAccessByItk.h" +#include "mitkImageCast.h" +#include "itkUnaryFunctorImageFilter.h" +#include +#include "itkImageDuplicator.h" + +mitk::RandomImageSampler::RandomImageSampler() + : m_AcceptanceRate(0.1), m_SamplingMode(RandomImageSamplerMode::SINGLE_ACCEPTANCE_RATE) +{ + //default parameters DomainSigma: 2 , RangeSigma: 50, AutoKernel: true, KernelRadius: 1 +} + +mitk::RandomImageSampler::~RandomImageSampler() +{ +} + +template< class TInput, class TOutput> +class RandomlySampleFunctor +{ +public: + RandomlySampleFunctor() {}; + ~RandomlySampleFunctor() {}; + bool operator!=(const RandomlySampleFunctor &) const + { + return false; + } + bool operator==(const RandomlySampleFunctor & other) const + { + return !(*this != other); + } + inline TOutput operator()(const TInput & A) const + { + if (rand() < RAND_MAX*m_AcceptanceRate) + return A; + else + return TOutput(0); + } + + double m_AcceptanceRate = 0.1; +}; + +template< class TInput, class TOutput> +class RandomlySampleClassDependedFunctor +{ +public: + RandomlySampleClassDependedFunctor() {}; + ~RandomlySampleClassDependedFunctor() {}; + bool operator!=(const RandomlySampleClassDependedFunctor &) const + { + return false; + } + bool operator==(const RandomlySampleClassDependedFunctor & other) const + { + return !(*this != other); + } + inline TOutput operator()(const TInput & A) const + { + std::size_t index = static_cast(A + 0.5); + double samplingRate = 0; + if (index >= 0 && index < m_SamplingRateVector.size()) + { + samplingRate = m_SamplingRateVector[index]; + } + + if (rand() < RAND_MAX*samplingRate) + return A; + else + return TOutput(0); + } + + std::vector m_SamplingRateVector; +}; + +void mitk::RandomImageSampler::GenerateData() +{ + mitk::Image::ConstPointer inputImage = this->GetInput(0); + switch (m_SamplingMode) + { + case SINGLE_ACCEPTANCE_RATE: + AccessByItk(inputImage.GetPointer(), ItkImageProcessing); + break; + case CLASS_DEPENDEND_ACCEPTANCE_RATE : + AccessByItk(inputImage.GetPointer(), ItkImageProcessingClassDependendSampling); + break; + case SINGLE_NUMBER_OF_ACCEPTANCE: + AccessByItk(inputImage.GetPointer(), ItkImageProcessingFixedNumberSampling); + break; + case CLASS_DEPENDEND_NUMBER_OF_ACCEPTANCE: + AccessByItk(inputImage.GetPointer(), ItkImageProcessingClassDependendNumberSampling); + break; + default: + AccessByItk(inputImage.GetPointer(), ItkImageProcessing); + break; + } +} + +template +void mitk::RandomImageSampler::ItkImageProcessing( const itk::Image* itkImage ) +{ + //ITK Image type given from the input image + typedef itk::Image< TPixel, VImageDimension > ItkImageType; + //bilateral filter with same type + typedef RandomlySampleFunctor< typename ItkImageType::PixelType, + typename ItkImageType::PixelType> LocalSampleFunctorType; + typedef itk::UnaryFunctorImageFilter RandomImageSamplerType; + typename RandomImageSamplerType::Pointer RandomImageSampler = RandomImageSamplerType::New(); + RandomImageSampler->SetInput(itkImage); + + LocalSampleFunctorType functor; + functor.m_AcceptanceRate = m_AcceptanceRate; + RandomImageSampler->SetFunctor(functor); + RandomImageSampler->GetFunctor().m_AcceptanceRate = m_AcceptanceRate; + RandomImageSampler->Update(); + + + //get Pointer to output image + mitk::Image::Pointer resultImage = this->GetOutput(); + //write into output image + mitk::CastToMitkImage(RandomImageSampler->GetOutput(), resultImage); +} + +template +void mitk::RandomImageSampler::ItkImageProcessingClassDependendSampling(const itk::Image* itkImage) +{ + //ITK Image type given from the input image + typedef itk::Image< TPixel, VImageDimension > ItkImageType; + //bilateral filter with same type + typedef RandomlySampleClassDependedFunctor< typename ItkImageType::PixelType, + typename ItkImageType::PixelType> LocalSampleFunctorType; + typedef itk::UnaryFunctorImageFilter RandomImageSamplerType; + typename RandomImageSamplerType::Pointer RandomImageSampler = RandomImageSamplerType::New(); + RandomImageSampler->SetInput(itkImage); + + LocalSampleFunctorType functor; + functor.m_SamplingRateVector = m_AcceptanceRateVector; + RandomImageSampler->SetFunctor(functor); + RandomImageSampler->GetFunctor().m_SamplingRateVector = m_AcceptanceRateVector; + RandomImageSampler->Update(); + + + //get Pointer to output image + mitk::Image::Pointer resultImage = this->GetOutput(); + //write into output image + mitk::CastToMitkImage(RandomImageSampler->GetOutput(), resultImage); +} + +template +void mitk::RandomImageSampler::ItkImageProcessingFixedNumberSampling(const itk::Image* itkImage) +{ + //ITK Image type given from the input image + typedef itk::Image< TPixel, VImageDimension > ItkImageType; + typedef itk::ImageDuplicator< ItkImageType > DuplicatorType; + typedef itk::ImageRandomNonRepeatingIteratorWithIndex IteratorType; + + typename DuplicatorType::Pointer duplicator = DuplicatorType::New(); + duplicator->SetInputImage(itkImage); + duplicator->Update(); + typename ItkImageType::Pointer clonedImage = duplicator->GetOutput(); + + //clonedImage->FillBuffer(0); + std::vector counts; + IteratorType iter(clonedImage, clonedImage->GetLargestPossibleRegion()); + iter.SetNumberOfSamples(clonedImage->GetLargestPossibleRegion().GetNumberOfPixels()); + //iter.GoToBegin(); + while (!iter.IsAtEnd()) + { + std::size_t index = static_cast(iter.Value() + 0.5); + while (index >= counts.size()) + { + counts.push_back(0); + } + if (counts[index] < m_NumberOfSamples) + { + //clonedImage->SetPixel(iter.GetIndex(), iter.Value()); + counts[index] += 1; + } + else + { + iter.Set(0.0); + //clonedImage->SetPixel(iter.GetIndex(), 0.0); + } + + ++iter; + } + + //get Pointer to output image + mitk::Image::Pointer resultImage = this->GetOutput(); + //write into output image + mitk::CastToMitkImage(clonedImage, resultImage); +} + +template +void mitk::RandomImageSampler::ItkImageProcessingClassDependendNumberSampling(const itk::Image* itkImage) +{ + //ITK Image type given from the input image + typedef itk::Image< TPixel, VImageDimension > ItkImageType; + typedef itk::ImageDuplicator< ItkImageType > DuplicatorType; + typedef itk::ImageRandomNonRepeatingIteratorWithIndex IteratorType; + + typename DuplicatorType::Pointer duplicator = DuplicatorType::New(); + duplicator->SetInputImage(itkImage); + duplicator->Update(); + typename ItkImageType::Pointer clonedImage = duplicator->GetOutput(); + + std::vector counts; + IteratorType iter(clonedImage, clonedImage->GetLargestPossibleRegion()); + iter.SetNumberOfSamples(clonedImage->GetLargestPossibleRegion().GetNumberOfPixels()); + while (!iter.IsAtEnd()) + { + std::size_t index = static_cast(iter.Value() + 0.5); + if (index < m_NumberOfSamplesVector.size()) + { + while (index >= counts.size()) + { + counts.push_back(0); + } + + if (counts[index] < m_NumberOfSamplesVector[index]) + { + counts[index] += 1; + } + else + { + iter.Set(0.0); + } + } + else + { + iter.Set(0.0); + } + + ++iter; + } + + //get Pointer to output image + mitk::Image::Pointer resultImage = this->GetOutput(); + //write into output image + mitk::CastToMitkImage(clonedImage, resultImage); +} + + +void mitk::RandomImageSampler::GenerateOutputInformation() +{ + mitk::Image::Pointer inputImage = (mitk::Image*) this->GetInput(); + mitk::Image::Pointer output = this->GetOutput(); + itkDebugMacro(<<"GenerateOutputInformation()"); + if(inputImage.IsNull()) return; +} diff --git a/Modules/Classification/CLUtilities/src/GlobalImageFeatures/mitkGIFVolumetricStatistics.cpp b/Modules/Classification/CLUtilities/src/GlobalImageFeatures/mitkGIFVolumetricStatistics.cpp index 662480ae8c..5a2ce2cfa0 100644 --- a/Modules/Classification/CLUtilities/src/GlobalImageFeatures/mitkGIFVolumetricStatistics.cpp +++ b/Modules/Classification/CLUtilities/src/GlobalImageFeatures/mitkGIFVolumetricStatistics.cpp @@ -1,170 +1,173 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #include // MITK #include #include #include // ITK #include #include // VTK #include #include #include // STL #include #include template void CalculateVolumeStatistic(itk::Image* itkImage, mitk::Image::Pointer mask, mitk::GIFVolumetricStatistics::FeatureListType & featureList) { typedef itk::Image ImageType; typedef itk::Image MaskType; typedef itk::LabelStatisticsImageFilter FilterType; typename MaskType::Pointer maskImage = MaskType::New(); mitk::CastToItkImage(mask, maskImage); typename FilterType::Pointer labelStatisticsImageFilter = FilterType::New(); labelStatisticsImageFilter->SetInput( itkImage ); labelStatisticsImageFilter->SetLabelInput(maskImage); labelStatisticsImageFilter->Update(); double volume = labelStatisticsImageFilter->GetCount(1); + double voxelVolume = 1; for (int i = 0; i < (int)(VImageDimension); ++i) { volume *= itkImage->GetSpacing()[i]; + voxelVolume *= itkImage->GetSpacing()[i]; } - featureList.push_back(std::make_pair("Volumetric Features Volume (pixel based)",volume)); + featureList.push_back(std::make_pair("Volumetric Features Volume (pixel based)", volume)); + featureList.push_back(std::make_pair("Volumetric Features Voxel Volume", voxelVolume)); } template void CalculateLargestDiameter(itk::Image* mask, mitk::GIFVolumetricStatistics::FeatureListType & featureList) { typedef itk::Image ImageType; typedef typename ImageType::PointType PointType; typename ImageType::SizeType radius; for (int i=0; i < (int)VImageDimension; ++i) radius[i] = 1; itk::NeighborhoodIterator iterator(radius,mask, mask->GetRequestedRegion()); std::vector borderPoints; while(!iterator.IsAtEnd()) { if (iterator.GetCenterPixel() == 0) { ++iterator; continue; } bool border = false; for (int i = 0; i < (int)(iterator.Size()); ++i) { if (iterator.GetPixel(i) == 0) { border = true; break; } } if (border) { auto centerIndex = iterator.GetIndex(); PointType centerPoint; mask->TransformIndexToPhysicalPoint(centerIndex, centerPoint ); borderPoints.push_back(centerPoint); } ++iterator; } double longestDiameter = 0; unsigned long numberOfBorderPoints = borderPoints.size(); for (int i = 0; i < (int)numberOfBorderPoints; ++i) { auto point = borderPoints[i]; for (int j = i; j < (int)numberOfBorderPoints; ++j) { double newDiameter=point.EuclideanDistanceTo(borderPoints[j]); if (newDiameter > longestDiameter) longestDiameter = newDiameter; } } featureList.push_back(std::make_pair("Volumetric Features Maximum 3D diameter",longestDiameter)); } mitk::GIFVolumetricStatistics::GIFVolumetricStatistics() { } mitk::GIFVolumetricStatistics::FeatureListType mitk::GIFVolumetricStatistics::CalculateFeatures(const Image::Pointer & image, const Image::Pointer &mask) { FeatureListType featureList; AccessByItk_2(image, CalculateVolumeStatistic, mask, featureList); AccessByItk_1(mask, CalculateLargestDiameter, featureList); vtkSmartPointer mesher = vtkSmartPointer::New(); vtkSmartPointer stats = vtkSmartPointer::New(); mesher->SetInputData(mask->GetVtkImageData()); stats->SetInputConnection(mesher->GetOutputPort()); stats->Update(); double pi = vnl_math::pi; double meshVolume = stats->GetVolume(); double meshSurf = stats->GetSurfaceArea(); double pixelVolume = featureList[0].second; double compactness1 = pixelVolume / ( std::sqrt(pi) * std::pow(meshSurf, 2.0/3.0)); double compactness2 = 36*pi*pixelVolume*pixelVolume/meshSurf/meshSurf/meshSurf; double sphericity=std::pow(pi,1/3.0) *std::pow(6*pixelVolume, 2.0/3.0) / meshSurf; double surfaceToVolume = meshSurf / pixelVolume; double sphericalDisproportion = meshSurf / 4 / pi / std::pow(3.0 / 4.0 / pi * pixelVolume, 2.0 / 3.0); featureList.push_back(std::make_pair("Volumetric Features Volume (mesh based)",meshVolume)); featureList.push_back(std::make_pair("Volumetric Features Surface area",meshSurf)); featureList.push_back(std::make_pair("Volumetric Features Surface to volume ratio",surfaceToVolume)); featureList.push_back(std::make_pair("Volumetric Features Sphericity",sphericity)); featureList.push_back(std::make_pair("Volumetric Features Compactness 1",compactness1)); featureList.push_back(std::make_pair("Volumetric Features Compactness 2",compactness2)); featureList.push_back(std::make_pair("Volumetric Features Spherical disproportion",sphericalDisproportion)); return featureList; } mitk::GIFVolumetricStatistics::FeatureNameListType mitk::GIFVolumetricStatistics::GetFeatureNames() { FeatureNameListType featureList; featureList.push_back("Volumetric Features Compactness 1"); featureList.push_back("Volumetric Features Compactness 2"); featureList.push_back("Volumetric Features Sphericity"); featureList.push_back("Volumetric Features Surface to volume ratio"); featureList.push_back("Volumetric Features Surface area"); featureList.push_back("Volumetric Features Volume (mesh based)"); featureList.push_back("Volumetric Features Volume (pixel based)"); featureList.push_back("Volumetric Features Spherical disproportion"); featureList.push_back("Volumetric Features Maximum 3D diameter"); return featureList; } diff --git a/Modules/Classification/CLVigraRandomForest/files.cmake b/Modules/Classification/CLVigraRandomForest/files.cmake index eddacbdbf5..358f417862 100644 --- a/Modules/Classification/CLVigraRandomForest/files.cmake +++ b/Modules/Classification/CLVigraRandomForest/files.cmake @@ -1,22 +1,25 @@ file(GLOB_RECURSE H_FILES RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "${CMAKE_CURRENT_SOURCE_DIR}/include/*") set(CPP_FILES mitkModuleActivator.cpp Classifier/mitkVigraRandomForestClassifier.cpp + Classifier/mitkPURFClassifier.cpp Algorithm/itkHessianMatrixEigenvalueImageFilter.cpp Algorithm/itkStructureTensorEigenvalueImageFilter.cpp + Splitter/mitkAdditionalRFData.cpp Splitter/mitkImpurityLoss.cpp + Splitter/mitkPUImpurityLoss.cpp Splitter/mitkLinearSplitting.cpp Splitter/mitkThresholdSplit.cpp IO/mitkRandomForestIO.cpp IO/mitkVigraRandomForestClassifierSerializer.cpp IO/mitkDummyLsetReader.cpp ) set( TOOL_FILES ) diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkAdditionalRFData.h b/Modules/Classification/CLVigraRandomForest/include/mitkAdditionalRFData.h new file mode 100644 index 0000000000..608579700b --- /dev/null +++ b/Modules/Classification/CLVigraRandomForest/include/mitkAdditionalRFData.h @@ -0,0 +1,33 @@ +#ifndef mitkAdditionalRFData_h +#define mitkAdditionalRFData_h + +#include + + +namespace mitk +{ + class AdditionalRFDataAbstract + { + public: + // This function is necessary to be able to do dynamic casts + virtual void NoFunction() = 0; + virtual ~AdditionalRFDataAbstract() {}; + }; + + class NoRFData : public AdditionalRFDataAbstract + { + public: + virtual void NoFunction() { return; } + virtual ~NoRFData() {}; + }; + + class PURFData : public AdditionalRFDataAbstract + { + public: + vigra::ArrayVector m_Kappa; + virtual void NoFunction(); + virtual ~PURFData() {}; + }; +} + +#endif //mitkAdditionalRFData_h diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h b/Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h index abce530788..f36e5438aa 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h @@ -1,55 +1,57 @@ #ifndef mitkImpurityLoss_h #define mitkImpurityLoss_h #include #include +#include namespace mitk { template , class TWeightContainer = vigra::MultiArrayView<2, double> > class ImpurityLoss { public: typedef TLabelContainer LabelContainerType; typedef TWeightContainer WeightContainerType; template ImpurityLoss(TLabelContainer const &labels, - vigra::ProblemSpec const &ext); + vigra::ProblemSpec const &ext, + AdditionalRFDataAbstract *data); void Reset(); template double Increment(TDataIterator begin, TDataIterator end); template double Decrement(TDataIterator begin, TDataIterator end); template double Init(TArray initCounts); vigra::ArrayVector const& Response(); void UsePointWeights(bool useWeights); bool IsUsingPointWeights(); void SetPointWeights(TWeightContainer weight); WeightContainerType GetPointWeights(); private: bool m_UsePointWeights; TWeightContainer m_PointWeights; //Variable of origin TLabelContainer const& m_Labels; vigra::ArrayVector m_Counts; vigra::ArrayVector m_ClassWeights; double m_TotalCount; TLossFunction m_LossFunction; }; } #include <../src/Splitter/mitkImpurityLoss.cpp> #endif //mitkImpurityLoss_h diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkLinearSplitting.h b/Modules/Classification/CLVigraRandomForest/include/mitkLinearSplitting.h index 31e4ab7a73..46dbb31299 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkLinearSplitting.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkLinearSplitting.h @@ -1,86 +1,91 @@ #ifndef mitkLinearSplitting_h #define mitkLinearSplitting_h #include #include +#include namespace mitk { template class LinearSplitting { public: typedef typename TLossAccumulator::WeightContainerType TWeightContainer; typedef TWeightContainer WeightContainerType; LinearSplitting(); template LinearSplitting(vigra::ProblemSpec const &ext); void UsePointWeights(bool pointWeight); bool IsUsingPointWeights(); void UseRandomSplit(bool randomSplit); bool IsUsingRandomSplit(); void SetPointWeights(WeightContainerType weight); WeightContainerType GetPointWeights(); + void SetAdditionalData(AdditionalRFDataAbstract* data); + AdditionalRFDataAbstract* GetAdditionalData() const; + template void set_external_parameters(vigra::ProblemSpec const &ext); template void operator()(TDataSourceFeature const &column, TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const ®ionResponse); template double LossOfRegion(TDataSourceLabel const & labels, TDataIterator &begin, TDataIterator &end, TArray const & regionResponse); double GetMinimumLoss() { return m_MinimumLoss; } double GetMinimumThreshold() { return m_MinimumThreshold; } std::ptrdiff_t GetMinimumIndex() { return m_MinimumIndex; } vigra::ArrayVector* GetBestCurrentCounts() { return m_BestCurrentCounts; } private: bool m_UsePointWeights; bool m_UseRandomSplit; WeightContainerType m_PointWeights; // From original code vigra::ArrayVector m_ClassWeights; vigra::ArrayVector m_BestCurrentCounts[2]; double m_MinimumLoss; double m_MinimumThreshold; std::ptrdiff_t m_MinimumIndex; vigra::ProblemSpec<> m_ExtParameter; + AdditionalRFDataAbstract* m_AdditionalData; }; } #include <../src/Splitter/mitkLinearSplitting.cpp> #endif //mitkLinearSplitting_h diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h b/Modules/Classification/CLVigraRandomForest/include/mitkPUImpurityLoss.h similarity index 65% copy from Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h copy to Modules/Classification/CLVigraRandomForest/include/mitkPUImpurityLoss.h index abce530788..5088e7b0d4 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkImpurityLoss.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkPUImpurityLoss.h @@ -1,55 +1,72 @@ -#ifndef mitkImpurityLoss_h -#define mitkImpurityLoss_h +#ifndef mitkPUImpurityLoss_h +#define mitkPUImpurityLoss_h #include #include +#include namespace mitk { + + template + class PURFProblemSpec : vigra::ProblemSpec + { + public: + vigra::ArrayVector kappa_; // if classes have different importance + }; + + template , class TWeightContainer = vigra::MultiArrayView<2, double> > - class ImpurityLoss + class PUImpurityLoss { public: typedef TLabelContainer LabelContainerType; typedef TWeightContainer WeightContainerType; template - ImpurityLoss(TLabelContainer const &labels, - vigra::ProblemSpec const &ext); + PUImpurityLoss(TLabelContainer const &labels, + vigra::ProblemSpec const &ext, + AdditionalRFDataAbstract *data); void Reset(); + void UpdatePUCounts(); + template double Increment(TDataIterator begin, TDataIterator end); template double Decrement(TDataIterator begin, TDataIterator end); template double Init(TArray initCounts); vigra::ArrayVector const& Response(); void UsePointWeights(bool useWeights); bool IsUsingPointWeights(); void SetPointWeights(TWeightContainer weight); WeightContainerType GetPointWeights(); private: bool m_UsePointWeights; TWeightContainer m_PointWeights; //Variable of origin TLabelContainer const& m_Labels; vigra::ArrayVector m_Counts; + vigra::ArrayVector m_PUCounts; + vigra::ArrayVector m_Kappa; vigra::ArrayVector m_ClassWeights; double m_TotalCount; + double m_PUTotalCount; + int m_ClassCount; TLossFunction m_LossFunction; }; } -#include <../src/Splitter/mitkImpurityLoss.cpp> +#include <../src/Splitter/mitkPUImpurityLoss.cpp> #endif //mitkImpurityLoss_h diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h b/Modules/Classification/CLVigraRandomForest/include/mitkPURFClassifier.h similarity index 79% copy from Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h copy to Modules/Classification/CLVigraRandomForest/include/mitkPURFClassifier.h index 9eb3a1f270..1a6c02704f 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkPURFClassifier.h @@ -1,93 +1,95 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ -#ifndef mitkVigraRandomForestClassifier_h -#define mitkVigraRandomForestClassifier_h +#ifndef mitkPURFClassifier_h +#define mitkPURFClassifier_h #include #include //#include #include #include namespace mitk { - class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier + class MITKCLVIGRARANDOMFOREST_EXPORT PURFClassifier : public AbstractClassifier { public: - mitkClassMacro(VigraRandomForestClassifier,AbstractClassifier) + mitkClassMacro(PURFClassifier, AbstractClassifier) itkFactorylessNewMacro(Self) itkCloneMacro(Self) - VigraRandomForestClassifier(); + PURFClassifier(); - ~VigraRandomForestClassifier(); + ~PURFClassifier(); void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); - void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); + Eigen::MatrixXi Predict(const Eigen::MatrixXd &X); Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X); bool SupportsPointWiseWeight(); bool SupportsPointWiseProbability(); void ConvertParameter(); + vigra::ArrayVector CalculateKappa(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in); void SetRandomForest(const vigra::RandomForest & rf); const vigra::RandomForest & GetRandomForest() const; void UsePointWiseWeight(bool); void SetMaximumTreeDepth(int); void SetMinimumSplitNodeSize(int); void SetPrecision(double); void SetSamplesPerTree(double); void UseSampleWithReplacement(bool); void SetTreeCount(int); void SetWeightLambda(double); - void SetTreeWeights(Eigen::MatrixXd weights); - void SetTreeWeight(int treeId, double weight); - Eigen::MatrixXd GetTreeWeights() const; - void PrintParameter(std::ostream &str = std::cout); + void SetClassProbabilities(Eigen::VectorXd probabilities); + Eigen::VectorXd GetClassProbabilites(); + private: // *------------------- // * THREADING // *------------------- struct TrainingData; struct PredictionData; struct EigenToVigraTransform; struct Parameter; + vigra::MultiArrayView<2, double> m_Probabilities; Eigen::MatrixXd m_TreeWeights; + Eigen::VectorXd m_ClassProbabilities; Parameter * m_Parameter; vigra::RandomForest m_RandomForest; static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); static ITK_THREAD_RETURN_TYPE PredictCallback(void *); static ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *); static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P); }; } -#endif //mitkVigraRandomForestClassifier_h +#endif //mitkPURFClassifier_h diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkThresholdSplit.h b/Modules/Classification/CLVigraRandomForest/include/mitkThresholdSplit.h index 643e26e1ec..62c1d99116 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkThresholdSplit.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkThresholdSplit.h @@ -1,81 +1,86 @@ #ifndef mitkThresholdSplit_h #define mitkThresholdSplit_h #include #include +#include namespace mitk { template class ThresholdSplit: public vigra::SplitBase { public: ThresholdSplit(); // ThresholdSplit(const ThresholdSplit & other); void SetFeatureCalculator(TFeatureCalculator processor); TFeatureCalculator GetFeatureCalculator() const; void SetCalculatingFeature(bool calculate); bool IsCalculatingFeature() const; void UsePointBasedWeights(bool weightsOn); bool IsUsingPointBasedWeights() const; void UseRandomSplit(bool split) {m_UseRandomSplit = split;} bool IsUsingRandomSplit() const { return m_UseRandomSplit; } void SetPrecision(double value); double GetPrecision() const; void SetMaximumTreeDepth(int value); virtual int GetMaximumTreeDepth() const; + void SetAdditionalData(AdditionalRFDataAbstract* data); + AdditionalRFDataAbstract* GetAdditionalData() const; + void SetWeights(vigra::MultiArrayView<2, double> weights); vigra::MultiArrayView<2, double> GetWeights() const; // From vigra::ThresholdSplit double minGini() const; int bestSplitColumn() const; double bestSplitThreshold() const; template void set_external_parameters(vigra::ProblemSpec const & in); template int findBestSplit(vigra::MultiArrayView<2, T, C> features, vigra::MultiArrayView<2, T2, C2> labels, Region & region, vigra::ArrayVector& childRegions, Random & randint); double region_gini_; private: // From vigra::ThresholdSplit typedef vigra::SplitBase SB; // splitter parameters (used by copy constructor) bool m_CalculatingFeature; bool m_UseWeights; bool m_UseRandomSplit; double m_Precision; int m_MaximumTreeDepth; TFeatureCalculator m_FeatureCalculator; vigra::MultiArrayView<2, double> m_Weights; // variabels to work with vigra::ArrayVector splitColumns; TColumnDecisionFunctor bgfunc; vigra::ArrayVector min_gini_; vigra::ArrayVector min_indices_; vigra::ArrayVector min_thresholds_; int bestSplitIndex; + AdditionalRFDataAbstract* m_AdditionalData; }; } #include <../src/Splitter/mitkThresholdSplit.cpp> #endif //mitkThresholdSplit_h diff --git a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h index 9eb3a1f270..0dcbd2b35b 100644 --- a/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h +++ b/Modules/Classification/CLVigraRandomForest/include/mitkVigraRandomForestClassifier.h @@ -1,93 +1,94 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ #ifndef mitkVigraRandomForestClassifier_h #define mitkVigraRandomForestClassifier_h #include #include //#include #include #include namespace mitk { class MITKCLVIGRARANDOMFOREST_EXPORT VigraRandomForestClassifier : public AbstractClassifier { public: - mitkClassMacro(VigraRandomForestClassifier,AbstractClassifier) + mitkClassMacro(VigraRandomForestClassifier, AbstractClassifier) itkFactorylessNewMacro(Self) itkCloneMacro(Self) - VigraRandomForestClassifier(); + VigraRandomForestClassifier(); ~VigraRandomForestClassifier(); void Train(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); void OnlineTrain(const Eigen::MatrixXd &X, const Eigen::MatrixXi &Y); Eigen::MatrixXi Predict(const Eigen::MatrixXd &X); Eigen::MatrixXi PredictWeighted(const Eigen::MatrixXd &X); bool SupportsPointWiseWeight(); bool SupportsPointWiseProbability(); void ConvertParameter(); void SetRandomForest(const vigra::RandomForest & rf); const vigra::RandomForest & GetRandomForest() const; void UsePointWiseWeight(bool); void SetMaximumTreeDepth(int); void SetMinimumSplitNodeSize(int); void SetPrecision(double); void SetSamplesPerTree(double); void UseSampleWithReplacement(bool); void SetTreeCount(int); void SetWeightLambda(double); void SetTreeWeights(Eigen::MatrixXd weights); void SetTreeWeight(int treeId, double weight); Eigen::MatrixXd GetTreeWeights() const; void PrintParameter(std::ostream &str = std::cout); private: // *------------------- // * THREADING // *------------------- struct TrainingData; struct PredictionData; struct EigenToVigraTransform; struct Parameter; + vigra::MultiArrayView<2, double> m_Probabilities; Eigen::MatrixXd m_TreeWeights; Parameter * m_Parameter; vigra::RandomForest m_RandomForest; static ITK_THREAD_RETURN_TYPE TrainTreesCallback(void *); static ITK_THREAD_RETURN_TYPE PredictCallback(void *); static ITK_THREAD_RETURN_TYPE PredictWeightedCallback(void *); static void VigraPredictWeighted(PredictionData *data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P); }; } #endif //mitkVigraRandomForestClassifier_h diff --git a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkPURFClassifier.cpp similarity index 64% copy from Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp copy to Modules/Classification/CLVigraRandomForest/src/Classifier/mitkPURFClassifier.cpp index 1ee923ba51..1210f113ca 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkPURFClassifier.cpp @@ -1,592 +1,478 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes -#include +#include #include +#include #include #include #include // Vigra includes #include #include // ITK include #include #include #include -typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; +typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultPUSplitType; -struct mitk::VigraRandomForestClassifier::Parameter +struct mitk::PURFClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; -struct mitk::VigraRandomForestClassifier::TrainingData +struct mitk::PURFClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, - const DefaultSplitType & refSplitter, + const DefaultPUSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel, const Parameter parameter) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel), m_Parameter(parameter) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; - const DefaultSplitType & m_Splitter; + const DefaultPUSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; Parameter m_Parameter; }; -struct mitk::VigraRandomForestClassifier::PredictionData +struct mitk::PURFClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, vigra::MultiArrayView<2, double> refProb, vigra::MultiArrayView<2, double> refTreeWeights) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), m_Probabilities(refProb), m_TreeWeights(refTreeWeights) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; vigra::MultiArrayView<2, double> m_TreeWeights; }; -mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() +mitk::PURFClassifier::PURFClassifier() :m_Parameter(nullptr) { - itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); - command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter); + itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); + command->SetCallbackFunction(this, &mitk::PURFClassifier::ConvertParameter); this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command ); } -mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() +mitk::PURFClassifier::~PURFClassifier() { } -bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() +void mitk::PURFClassifier::SetClassProbabilities(Eigen::VectorXd probabilities) +{ + m_ClassProbabilities = probabilities; +} + +Eigen::VectorXd mitk::PURFClassifier::GetClassProbabilites() +{ + return m_ClassProbabilities; +} + +bool mitk::PURFClassifier::SupportsPointWiseWeight() { return true; } -bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() +bool mitk::PURFClassifier::SupportsPointWiseProbability() { return true; } -void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) + +vigra::ArrayVector mitk::PURFClassifier::CalculateKappa(const Eigen::MatrixXd & /* X_in */, const Eigen::MatrixXi & Y_in) { - vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); - vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); - m_RandomForest.onlineLearn(X,Y,0,true); + int maximumValue = Y_in.maxCoeff(); + vigra::ArrayVector kappa(maximumValue + 1); + vigra::ArrayVector counts(maximumValue + 1); + for (int i = 0; i < Y_in.rows(); ++i) + { + counts[Y_in(i, 0)] += 1; + } + for (int i = 0; i < maximumValue+1; ++i) + { + if (counts[i] > 0) + { + kappa[i] = counts[0] * m_ClassProbabilities[i] / counts[i] + 1; + } + else + { + kappa[i] = 1; + } + } + return kappa; } -void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) + +void mitk::PURFClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); - DefaultSplitType splitter; + PURFData* purfData = new PURFData; + purfData->m_Kappa = this->CalculateKappa(X_in, Y_in); + + DefaultPUSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); + splitter.SetAdditionalData(purfData); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.set_options().use_stratification(m_Parameter->Stratification); m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement); m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree); m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize); m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::unique_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; // Set Tree Weights to default m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1); m_TreeWeights.fill(1.0); + delete purfData; } -Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) -{ - // Initialize output Eigen matrices - m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); - m_OutProbability.fill(0); - m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); - m_OutLabel.fill(0); - - // If no weights provided - if(m_TreeWeights.rows() != m_RandomForest.tree_count()) - { - m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); - m_TreeWeights.fill(1); - } - - - vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); - vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); - vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); - vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); - - std::unique_ptr data; - data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); - - itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); - threader->SetSingleMethod(this->PredictCallback,data.get()); - threader->SingleMethodExecute(); - - return m_OutLabel; -} - -Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in) +Eigen::MatrixXi mitk::PURFClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); // If no weights provided if(m_TreeWeights.rows() != m_RandomForest.tree_count()) { m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); m_TreeWeights.fill(1); } - vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); std::unique_ptr data; - data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); + data.reset(new PredictionData(m_RandomForest, X, Y, P, TW)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); - threader->SetSingleMethod(this->PredictWeightedCallback,data.get()); + threader->SetSingleMethod(this->PredictCallback, data.get()); threader->SingleMethodExecute(); + m_Probabilities = data->m_Probabilities; return m_OutLabel; } - - -void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights) -{ - m_TreeWeights = weights; -} - -Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const -{ - return m_TreeWeights; -} - -ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) +ITK_THREAD_RETURN_TYPE mitk::PURFClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process - DefaultSplitType splitter; + DefaultPUSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); + splitter.SetAdditionalData(data->m_Splitter.GetAdditionalData()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.set_options().use_stratification(data->m_Parameter.Stratification); rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement); rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree); rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } return NULL; } -ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) +ITK_THREAD_RETURN_TYPE mitk::PURFClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the last thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) { end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; } vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); return NULL; } -ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg) -{ - // Get the ThreadInfoStruct - typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; - ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); - // assigne the thread id - const unsigned int threadId = infoStruct->ThreadID; - - // Get the user defined parameters containing all - // neccesary informations - PredictionData * data = (PredictionData *)(infoStruct->UserData); - unsigned int numberOfRowsToCalculate = 0; - - // Get number of rows to calculate - numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; - - unsigned int start_index = numberOfRowsToCalculate * threadId; - unsigned int end_index = numberOfRowsToCalculate * (threadId+1); - - // the last thread takes the residuals - if(threadId == infoStruct->NumberOfThreads-1) { - end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; - } - - vigra::MultiArrayView<2, double> split_features; - vigra::MultiArrayView<2, int> split_labels; - vigra::MultiArrayView<2, double> split_probability; - { - vigra::TinyVector lowerBound(start_index,0); - vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); - split_features = data->m_Feature.subarray(lowerBound,upperBound); - } - - { - vigra::TinyVector lowerBound(start_index,0); - vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); - split_labels = data->m_Label.subarray(lowerBound,upperBound); - } - - { - vigra::TinyVector lowerBound(start_index,0); - vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); - split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); - } - - VigraPredictWeighted(data, split_features,split_labels,split_probability); - - return NULL; -} - - -void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P) -{ - - int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_; -//#pragma omp parallel for - for(int row=0; row < vigra::rowCount(X); ++row) - { - vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); - - vigra::ArrayVector::const_iterator weights; - - //totalWeight == totalVoteCount! - double totalWeight = 0.0; - - //Let each tree classify... - for(int k=0; km_RandomForest.options_.tree_count_; ++k) - { - //get weights predicted by single tree - weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow); - double numberOfLeafObservations = (*(weights-1)); - - //update votecount. - for(int l=0; lm_RandomForest.ext_param_.class_count_; ++l) - { - // Either the original weights are taken or the tree is additional weighted by the number of Observations in the leaf node. - double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted)); - cur_w = cur_w * data->m_TreeWeights(k,0); - P(row, l) += (int)cur_w; - //every weight in totalWeight. - totalWeight += cur_w; - } - } - - //Normalise votes in each row by total VoteCount (totalWeight - for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l) - { - P(row, l) /= vigra::detail::RequiresExplicitCast::cast(totalWeight); - } - int erg; - int maxCol = 0; - for (int col=0;colm_RandomForest.class_count();++col) - { - if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol)) - maxCol = col; - } - data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg); - Y(row,0) = erg; - } -} - -void mitk::VigraRandomForestClassifier::ConvertParameter() +void mitk::PURFClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults - MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter"; + MITK_INFO("PURFClassifier") << "Convert Parameter"; if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } -void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) +void mitk::PURFClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { - MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; + MITK_WARN("PURFClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else str << "precision\t\t" << this->m_Parameter->Precision << "\n"; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n"; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } -void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) +void mitk::PURFClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val); } -void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) +void mitk::PURFClassifier::SetMaximumTreeDepth(int val) { this->GetPropertyList()->SetIntProperty("treedepth",val); } -void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) +void mitk::PURFClassifier::SetMinimumSplitNodeSize(int val) { this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val); } -void mitk::VigraRandomForestClassifier::SetPrecision(double val) +void mitk::PURFClassifier::SetPrecision(double val) { this->GetPropertyList()->SetDoubleProperty("precision",val); } -void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) +void mitk::PURFClassifier::SetSamplesPerTree(double val) { this->GetPropertyList()->SetDoubleProperty("samplespertree",val); } -void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) +void mitk::PURFClassifier::UseSampleWithReplacement(bool val) { this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val); } -void mitk::VigraRandomForestClassifier::SetTreeCount(int val) +void mitk::PURFClassifier::SetTreeCount(int val) { this->GetPropertyList()->SetIntProperty("treecount",val); } -void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) +void mitk::PURFClassifier::SetWeightLambda(double val) { this->GetPropertyList()->SetDoubleProperty("lambda",val); } -void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) -{ - m_TreeWeights(treeId,0) = weight; -} - -void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest & rf) +void mitk::PURFClassifier::SetRandomForest(const vigra::RandomForest & rf) { this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth); this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_); this->SetTreeCount(rf.options().tree_count_); this->SetSamplesPerTree(rf.options().training_set_proportion_); this->UseSampleWithReplacement(rf.options().sample_with_replacement_); this->m_RandomForest = rf; } -const vigra::RandomForest & mitk::VigraRandomForestClassifier::GetRandomForest() const +const vigra::RandomForest & mitk::PURFClassifier::GetRandomForest() const { return this->m_RandomForest; } diff --git a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp index 1ee923ba51..faa1859573 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Classifier/mitkVigraRandomForestClassifier.cpp @@ -1,592 +1,593 @@ /*=================================================================== The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center, Division of Medical and Biological Informatics. All rights reserved. This software is distributed WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See LICENSE.txt or http://www.mitk.org for details. ===================================================================*/ // MITK includes #include #include #include #include #include // Vigra includes #include #include // ITK include #include #include #include typedef mitk::ThresholdSplit >,int,vigra::ClassificationTag> DefaultSplitType; struct mitk::VigraRandomForestClassifier::Parameter { vigra::RF_OptionTag Stratification; bool SampleWithReplacement; bool UseRandomSplit; bool UsePointBasedWeights; int TreeCount; int MinimumSplitNodeSize; int TreeDepth; double Precision; double WeightLambda; double SamplesPerTree; }; struct mitk::VigraRandomForestClassifier::TrainingData { TrainingData(unsigned int numberOfTrees, const vigra::RandomForest & refRF, const DefaultSplitType & refSplitter, const vigra::MultiArrayView<2, double> refFeature, const vigra::MultiArrayView<2, int> refLabel, const Parameter parameter) : m_ClassCount(0), m_NumberOfTrees(numberOfTrees), m_RandomForest(refRF), m_Splitter(refSplitter), m_Feature(refFeature), m_Label(refLabel), m_Parameter(parameter) { m_mutex = itk::FastMutexLock::New(); } vigra::ArrayVector::DecisionTree_t> trees_; int m_ClassCount; unsigned int m_NumberOfTrees; const vigra::RandomForest & m_RandomForest; const DefaultSplitType & m_Splitter; const vigra::MultiArrayView<2, double> m_Feature; const vigra::MultiArrayView<2, int> m_Label; itk::FastMutexLock::Pointer m_mutex; Parameter m_Parameter; }; struct mitk::VigraRandomForestClassifier::PredictionData { PredictionData(const vigra::RandomForest & refRF, const vigra::MultiArrayView<2, double> refFeature, vigra::MultiArrayView<2, int> refLabel, vigra::MultiArrayView<2, double> refProb, vigra::MultiArrayView<2, double> refTreeWeights) : m_RandomForest(refRF), m_Feature(refFeature), m_Label(refLabel), m_Probabilities(refProb), m_TreeWeights(refTreeWeights) { } const vigra::RandomForest & m_RandomForest; const vigra::MultiArrayView<2, double> m_Feature; vigra::MultiArrayView<2, int> m_Label; vigra::MultiArrayView<2, double> m_Probabilities; vigra::MultiArrayView<2, double> m_TreeWeights; }; mitk::VigraRandomForestClassifier::VigraRandomForestClassifier() :m_Parameter(nullptr) { itk::SimpleMemberCommand::Pointer command = itk::SimpleMemberCommand::New(); command->SetCallbackFunction(this, &mitk::VigraRandomForestClassifier::ConvertParameter); this->GetPropertyList()->AddObserver( itk::ModifiedEvent(), command ); } mitk::VigraRandomForestClassifier::~VigraRandomForestClassifier() { } bool mitk::VigraRandomForestClassifier::SupportsPointWiseWeight() { return true; } bool mitk::VigraRandomForestClassifier::SupportsPointWiseProbability() { return true; } void mitk::VigraRandomForestClassifier::OnlineTrain(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.onlineLearn(X,Y,0,true); } void mitk::VigraRandomForestClassifier::Train(const Eigen::MatrixXd & X_in, const Eigen::MatrixXi &Y_in) { this->ConvertParameter(); DefaultSplitType splitter; splitter.UsePointBasedWeights(m_Parameter->UsePointBasedWeights); splitter.UseRandomSplit(m_Parameter->UseRandomSplit); splitter.SetPrecision(m_Parameter->Precision); splitter.SetMaximumTreeDepth(m_Parameter->TreeDepth); // Weights handled as member variable if (m_Parameter->UsePointBasedWeights) { // Set influence of the weight (0 no influenc to 1 max influence) this->m_PointWiseWeight.unaryExpr([this](double t){ return std::pow(t, this->m_Parameter->WeightLambda) ;}); vigra::MultiArrayView<2, double> W(vigra::Shape2(this->m_PointWiseWeight.rows(),this->m_PointWiseWeight.cols()),this->m_PointWiseWeight.data()); splitter.SetWeights(W); } vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(Y_in.rows(),Y_in.cols()),Y_in.data()); m_RandomForest.set_options().tree_count(1); // Number of trees that are calculated; m_RandomForest.set_options().use_stratification(m_Parameter->Stratification); m_RandomForest.set_options().sample_with_replacement(m_Parameter->SampleWithReplacement); m_RandomForest.set_options().samples_per_tree(m_Parameter->SamplesPerTree); m_RandomForest.set_options().min_split_node_size(m_Parameter->MinimumSplitNodeSize); m_RandomForest.learn(X, Y,vigra::rf::visitors::VisitorBase(),splitter); std::unique_ptr data(new TrainingData(m_Parameter->TreeCount,m_RandomForest,splitter,X,Y, *m_Parameter)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->TrainTreesCallback,data.get()); threader->SingleMethodExecute(); // set result trees m_RandomForest.set_options().tree_count(m_Parameter->TreeCount); m_RandomForest.ext_param_.class_count_ = data->m_ClassCount; m_RandomForest.trees_ = data->trees_; // Set Tree Weights to default m_TreeWeights = Eigen::MatrixXd(m_Parameter->TreeCount,1); m_TreeWeights.fill(1.0); } Eigen::MatrixXi mitk::VigraRandomForestClassifier::Predict(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); // If no weights provided if(m_TreeWeights.rows() != m_RandomForest.tree_count()) { m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); m_TreeWeights.fill(1); } vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); std::unique_ptr data; - data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); + data.reset(new PredictionData(m_RandomForest, X, Y, P, TW)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); - threader->SetSingleMethod(this->PredictCallback,data.get()); + threader->SetSingleMethod(this->PredictCallback, data.get()); threader->SingleMethodExecute(); + m_Probabilities = data->m_Probabilities; return m_OutLabel; } Eigen::MatrixXi mitk::VigraRandomForestClassifier::PredictWeighted(const Eigen::MatrixXd &X_in) { // Initialize output Eigen matrices m_OutProbability = Eigen::MatrixXd(X_in.rows(),m_RandomForest.class_count()); m_OutProbability.fill(0); m_OutLabel = Eigen::MatrixXi(X_in.rows(),1); m_OutLabel.fill(0); // If no weights provided if(m_TreeWeights.rows() != m_RandomForest.tree_count()) { m_TreeWeights = Eigen::MatrixXd(m_RandomForest.tree_count(),1); m_TreeWeights.fill(1); } vigra::MultiArrayView<2, double> P(vigra::Shape2(m_OutProbability.rows(),m_OutProbability.cols()),m_OutProbability.data()); vigra::MultiArrayView<2, int> Y(vigra::Shape2(m_OutLabel.rows(),m_OutLabel.cols()),m_OutLabel.data()); vigra::MultiArrayView<2, double> X(vigra::Shape2(X_in.rows(),X_in.cols()),X_in.data()); vigra::MultiArrayView<2, double> TW(vigra::Shape2(m_RandomForest.tree_count(),1),m_TreeWeights.data()); std::unique_ptr data; data.reset( new PredictionData(m_RandomForest,X,Y,P,TW)); itk::MultiThreader::Pointer threader = itk::MultiThreader::New(); threader->SetSingleMethod(this->PredictWeightedCallback,data.get()); threader->SingleMethodExecute(); return m_OutLabel; } void mitk::VigraRandomForestClassifier::SetTreeWeights(Eigen::MatrixXd weights) { m_TreeWeights = weights; } Eigen::MatrixXd mitk::VigraRandomForestClassifier::GetTreeWeights() const { return m_TreeWeights; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::TrainTreesCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); TrainingData * data = (TrainingData *)(infoStruct->UserData); unsigned int numberOfTreesToCalculate = 0; // define the number of tress the forest have to calculate numberOfTreesToCalculate = data->m_NumberOfTrees / infoStruct->NumberOfThreads; // the 0th thread takes the residuals if(infoStruct->ThreadID == 0) numberOfTreesToCalculate += data->m_NumberOfTrees % infoStruct->NumberOfThreads; if(numberOfTreesToCalculate != 0){ // Copy the Treestructure defined in userData vigra::RandomForest rf = data->m_RandomForest; // Initialize a splitter for the leraning process DefaultSplitType splitter; splitter.UsePointBasedWeights(data->m_Splitter.IsUsingPointBasedWeights()); splitter.UseRandomSplit(data->m_Splitter.IsUsingRandomSplit()); splitter.SetPrecision(data->m_Splitter.GetPrecision()); splitter.SetMaximumTreeDepth(data->m_Splitter.GetMaximumTreeDepth()); splitter.SetWeights(data->m_Splitter.GetWeights()); rf.trees_.clear(); rf.set_options().tree_count(numberOfTreesToCalculate); rf.set_options().use_stratification(data->m_Parameter.Stratification); rf.set_options().sample_with_replacement(data->m_Parameter.SampleWithReplacement); rf.set_options().samples_per_tree(data->m_Parameter.SamplesPerTree); rf.set_options().min_split_node_size(data->m_Parameter.MinimumSplitNodeSize); rf.learn(data->m_Feature, data->m_Label,vigra::rf::visitors::VisitorBase(),splitter); data->m_mutex->Lock(); for(const auto & tree : rf.trees_) data->trees_.push_back(tree); data->m_ClassCount = rf.class_count(); data->m_mutex->Unlock(); } return NULL; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the last thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) { end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; } vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } data->m_RandomForest.predictLabels(split_features,split_labels); data->m_RandomForest.predictProbabilities(split_features, split_probability); return NULL; } ITK_THREAD_RETURN_TYPE mitk::VigraRandomForestClassifier::PredictWeightedCallback(void * arg) { // Get the ThreadInfoStruct typedef itk::MultiThreader::ThreadInfoStruct ThreadInfoType; ThreadInfoType * infoStruct = static_cast< ThreadInfoType * >( arg ); // assigne the thread id const unsigned int threadId = infoStruct->ThreadID; // Get the user defined parameters containing all // neccesary informations PredictionData * data = (PredictionData *)(infoStruct->UserData); unsigned int numberOfRowsToCalculate = 0; // Get number of rows to calculate numberOfRowsToCalculate = data->m_Feature.shape()[0] / infoStruct->NumberOfThreads; unsigned int start_index = numberOfRowsToCalculate * threadId; unsigned int end_index = numberOfRowsToCalculate * (threadId+1); // the last thread takes the residuals if(threadId == infoStruct->NumberOfThreads-1) { end_index += data->m_Feature.shape()[0] % infoStruct->NumberOfThreads; } vigra::MultiArrayView<2, double> split_features; vigra::MultiArrayView<2, int> split_labels; vigra::MultiArrayView<2, double> split_probability; { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Feature.shape(1)); split_features = data->m_Feature.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index, data->m_Label.shape(1)); split_labels = data->m_Label.subarray(lowerBound,upperBound); } { vigra::TinyVector lowerBound(start_index,0); vigra::TinyVector upperBound(end_index,data->m_Probabilities.shape(1)); split_probability = data->m_Probabilities.subarray(lowerBound,upperBound); } VigraPredictWeighted(data, split_features,split_labels,split_probability); return NULL; } void mitk::VigraRandomForestClassifier::VigraPredictWeighted(PredictionData * data, vigra::MultiArrayView<2, double> & X, vigra::MultiArrayView<2, int> & Y, vigra::MultiArrayView<2, double> & P) { int isSampleWeighted = data->m_RandomForest.options_.predict_weighted_; //#pragma omp parallel for for(int row=0; row < vigra::rowCount(X); ++row) { vigra::MultiArrayView<2, double, vigra::StridedArrayTag> currentRow(rowVector(X, row)); vigra::ArrayVector::const_iterator weights; //totalWeight == totalVoteCount! double totalWeight = 0.0; //Let each tree classify... for(int k=0; km_RandomForest.options_.tree_count_; ++k) { //get weights predicted by single tree weights = data->m_RandomForest.trees_[k /*tree_indices_[k]*/].predict(currentRow); double numberOfLeafObservations = (*(weights-1)); //update votecount. for(int l=0; lm_RandomForest.ext_param_.class_count_; ++l) { // Either the original weights are taken or the tree is additional weighted by the number of Observations in the leaf node. double cur_w = weights[l] * (isSampleWeighted * numberOfLeafObservations + (1-isSampleWeighted)); cur_w = cur_w * data->m_TreeWeights(k,0); P(row, l) += (int)cur_w; //every weight in totalWeight. totalWeight += cur_w; } } //Normalise votes in each row by total VoteCount (totalWeight for(int l=0; l< data->m_RandomForest.ext_param_.class_count_; ++l) { P(row, l) /= vigra::detail::RequiresExplicitCast::cast(totalWeight); } int erg; int maxCol = 0; for (int col=0;colm_RandomForest.class_count();++col) { if (data->m_Probabilities(row,col) > data->m_Probabilities(row, maxCol)) maxCol = col; } data->m_RandomForest.ext_param_.to_classlabel(maxCol, erg); Y(row,0) = erg; } } void mitk::VigraRandomForestClassifier::ConvertParameter() { if(this->m_Parameter == nullptr) this->m_Parameter = new Parameter(); // Get the proerty // Some defaults MITK_INFO("VigraRandomForestClassifier") << "Convert Parameter"; if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) this->m_Parameter->UsePointBasedWeights = false; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) this->m_Parameter->UseRandomSplit = false; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) this->m_Parameter->TreeDepth = 20; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) this->m_Parameter->TreeCount = 100; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) this->m_Parameter->MinimumSplitNodeSize = 5; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) this->m_Parameter->Precision = mitk::eps; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) this->m_Parameter->SamplesPerTree = 0.6; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) this->m_Parameter->SampleWithReplacement = true; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) this->m_Parameter->WeightLambda = 1.0; // Not used yet // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) this->m_Parameter->Stratification = vigra::RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::PrintParameter(std::ostream & str) { if(this->m_Parameter == nullptr) { MITK_WARN("VigraRandomForestClassifier") << "Parameters are not initialized. Please call ConvertParameter() first!"; return; } this->ConvertParameter(); // Get the proerty // Some defaults if(!this->GetPropertyList()->Get("usepointbasedweight",this->m_Parameter->UsePointBasedWeights)) str << "usepointbasedweight\tNOT SET (default " << this->m_Parameter->UsePointBasedWeights << ")" << "\n"; else str << "usepointbasedweight\t" << this->m_Parameter->UsePointBasedWeights << "\n"; if(!this->GetPropertyList()->Get("userandomsplit",this->m_Parameter->UseRandomSplit)) str << "userandomsplit\tNOT SET (default " << this->m_Parameter->UseRandomSplit << ")" << "\n"; else str << "userandomsplit\t" << this->m_Parameter->UseRandomSplit << "\n"; if(!this->GetPropertyList()->Get("treedepth",this->m_Parameter->TreeDepth)) str << "treedepth\t\tNOT SET (default " << this->m_Parameter->TreeDepth << ")" << "\n"; else str << "treedepth\t\t" << this->m_Parameter->TreeDepth << "\n"; if(!this->GetPropertyList()->Get("minimalsplitnodesize",this->m_Parameter->MinimumSplitNodeSize)) str << "minimalsplitnodesize\tNOT SET (default " << this->m_Parameter->MinimumSplitNodeSize << ")" << "\n"; else str << "minimalsplitnodesize\t" << this->m_Parameter->MinimumSplitNodeSize << "\n"; if(!this->GetPropertyList()->Get("precision",this->m_Parameter->Precision)) str << "precision\t\tNOT SET (default " << this->m_Parameter->Precision << ")" << "\n"; else str << "precision\t\t" << this->m_Parameter->Precision << "\n"; if(!this->GetPropertyList()->Get("samplespertree",this->m_Parameter->SamplesPerTree)) str << "samplespertree\tNOT SET (default " << this->m_Parameter->SamplesPerTree << ")" << "\n"; else str << "samplespertree\t" << this->m_Parameter->SamplesPerTree << "\n"; if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->SampleWithReplacement)) str << "samplewithreplacement\tNOT SET (default " << this->m_Parameter->SampleWithReplacement << ")" << "\n"; else str << "samplewithreplacement\t" << this->m_Parameter->SampleWithReplacement << "\n"; if(!this->GetPropertyList()->Get("treecount",this->m_Parameter->TreeCount)) str << "treecount\t\tNOT SET (default " << this->m_Parameter->TreeCount << ")" << "\n"; else str << "treecount\t\t" << this->m_Parameter->TreeCount << "\n"; if(!this->GetPropertyList()->Get("lambda",this->m_Parameter->WeightLambda)) str << "lambda\t\tNOT SET (default " << this->m_Parameter->WeightLambda << ")" << "\n"; else str << "lambda\t\t" << this->m_Parameter->WeightLambda << "\n"; // if(!this->GetPropertyList()->Get("samplewithreplacement",this->m_Parameter->Stratification)) // this->m_Parameter->Stratification = vigra:RF_NONE; // no Property given } void mitk::VigraRandomForestClassifier::UsePointWiseWeight(bool val) { mitk::AbstractClassifier::UsePointWiseWeight(val); this->GetPropertyList()->SetBoolProperty("usepointbasedweight",val); } void mitk::VigraRandomForestClassifier::SetMaximumTreeDepth(int val) { this->GetPropertyList()->SetIntProperty("treedepth",val); } void mitk::VigraRandomForestClassifier::SetMinimumSplitNodeSize(int val) { this->GetPropertyList()->SetIntProperty("minimalsplitnodesize",val); } void mitk::VigraRandomForestClassifier::SetPrecision(double val) { this->GetPropertyList()->SetDoubleProperty("precision",val); } void mitk::VigraRandomForestClassifier::SetSamplesPerTree(double val) { this->GetPropertyList()->SetDoubleProperty("samplespertree",val); } void mitk::VigraRandomForestClassifier::UseSampleWithReplacement(bool val) { this->GetPropertyList()->SetBoolProperty("samplewithreplacement",val); } void mitk::VigraRandomForestClassifier::SetTreeCount(int val) { this->GetPropertyList()->SetIntProperty("treecount",val); } void mitk::VigraRandomForestClassifier::SetWeightLambda(double val) { this->GetPropertyList()->SetDoubleProperty("lambda",val); } void mitk::VigraRandomForestClassifier::SetTreeWeight(int treeId, double weight) { m_TreeWeights(treeId,0) = weight; } void mitk::VigraRandomForestClassifier::SetRandomForest(const vigra::RandomForest & rf) { this->SetMaximumTreeDepth(rf.ext_param().max_tree_depth); this->SetMinimumSplitNodeSize(rf.options().min_split_node_size_); this->SetTreeCount(rf.options().tree_count_); this->SetSamplesPerTree(rf.options().training_set_proportion_); this->UseSampleWithReplacement(rf.options().sample_with_replacement_); this->m_RandomForest = rf; } const vigra::RandomForest & mitk::VigraRandomForestClassifier::GetRandomForest() const { return this->m_RandomForest; } diff --git a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkAdditionalRFData.cpp b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkAdditionalRFData.cpp new file mode 100644 index 0000000000..913b6e41d8 --- /dev/null +++ b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkAdditionalRFData.cpp @@ -0,0 +1,6 @@ +#include + +void mitk::PURFData::NoFunction() +{ + return; +} \ No newline at end of file diff --git a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkImpurityLoss.cpp b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkImpurityLoss.cpp index 527a5523f3..add5f035cc 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkImpurityLoss.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkImpurityLoss.cpp @@ -1,111 +1,112 @@ #ifndef mitkImpurityLoss_cpp #define mitkImpurityLoss_cpp #include template template mitk::ImpurityLoss::ImpurityLoss(TLabelContainer const &labels, - vigra::ProblemSpec const &ext) : + vigra::ProblemSpec const &ext, + AdditionalRFDataAbstract * /*data*/) : m_UsePointWeights(false), m_Labels(labels), m_Counts(ext.class_count_, 0.0), m_ClassWeights(ext.class_weights_), m_TotalCount(0.0) { } template void mitk::ImpurityLoss::Reset() { m_Counts.init(0); m_TotalCount = 0.0; } template template double mitk::ImpurityLoss::Increment(TDataIterator begin, TDataIterator end) { for (TDataIterator iter = begin; iter != end; ++iter) { double pointProbability = 1.0; if (m_UsePointWeights) { pointProbability = m_PointWeights(*iter,0); } m_Counts[m_Labels(*iter,0)] += pointProbability; m_TotalCount += pointProbability; } return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount); } template template double mitk::ImpurityLoss::Decrement(TDataIterator begin, TDataIterator end) { for (TDataIterator iter = begin; iter != end; ++iter) { double pointProbability = 1.0; if (m_UsePointWeights) { pointProbability = m_PointWeights(*iter,0); } m_Counts[m_Labels(*iter,0)] -= pointProbability; m_TotalCount -= pointProbability; } return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount); } template template double mitk::ImpurityLoss::Init(TArray initCounts) { Reset(); std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin()); m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0); return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount); } template vigra::ArrayVector const& mitk::ImpurityLoss::Response() { return m_Counts; } template void mitk::ImpurityLoss::UsePointWeights(bool useWeights) { m_UsePointWeights = useWeights; } template bool mitk::ImpurityLoss::IsUsingPointWeights() { return m_UsePointWeights; } template void mitk::ImpurityLoss::SetPointWeights(TWeightContainer weight) { m_PointWeights = weight; } template typename mitk::ImpurityLoss::WeightContainerType mitk::ImpurityLoss::GetPointWeights() { return m_PointWeights; } #endif // mitkImpurityLoss_cpp diff --git a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkLinearSplitting.cpp b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkLinearSplitting.cpp index 4c915bc304..ecac890a56 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkLinearSplitting.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkLinearSplitting.cpp @@ -1,168 +1,184 @@ #ifndef mitkLinearSplitting_cpp #define mitkLinearSplitting_cpp #include +#include template mitk::LinearSplitting::LinearSplitting() : m_UsePointWeights(false), - m_UseRandomSplit(false) + m_UseRandomSplit(false), + m_AdditionalData(nullptr) { } template template mitk::LinearSplitting::LinearSplitting(vigra::ProblemSpec const &ext) : m_UsePointWeights(false), m_UseRandomSplit(false) { set_external_parameters(ext); } +template +void +mitk::LinearSplitting::SetAdditionalData(AdditionalRFDataAbstract* data) +{ + m_AdditionalData = data; +} + +template +mitk::AdditionalRFDataAbstract * +mitk::LinearSplitting::GetAdditionalData() const +{ + return m_AdditionalData; +} + template void mitk::LinearSplitting::UsePointWeights(bool pointWeight) { m_UsePointWeights = pointWeight; } template bool mitk::LinearSplitting::IsUsingPointWeights() { return m_UsePointWeights; } template void mitk::LinearSplitting::UseRandomSplit(bool randomSplit) { m_UseRandomSplit = randomSplit; } template bool mitk::LinearSplitting::IsUsingRandomSplit() { return m_UseRandomSplit; } template void mitk::LinearSplitting::SetPointWeights(WeightContainerType weight) { m_PointWeights = weight; } template typename mitk::LinearSplitting::WeightContainerType mitk::LinearSplitting::GetPointWeights() { return m_PointWeights; } template template void mitk::LinearSplitting::set_external_parameters(vigra::ProblemSpec const &ext) { m_ExtParameter = ext; } template template void mitk::LinearSplitting::operator()(TDataSourceFeature const &column, TDataSourceLabel const &labels, TDataIterator &begin, TDataIterator &end, TArray const ®ionResponse) { typedef TLossAccumulator LineSearchLoss; std::sort(begin, end, vigra::SortSamplesByDimensions(column, 0)); - LineSearchLoss left(labels, m_ExtParameter); - LineSearchLoss right(labels, m_ExtParameter); + LineSearchLoss left(labels, m_ExtParameter, m_AdditionalData); + LineSearchLoss right(labels, m_ExtParameter, m_AdditionalData); if (m_UsePointWeights) { left.UsePointWeights(true); left.SetPointWeights(m_PointWeights); right.UsePointWeights(true); right.SetPointWeights(m_PointWeights); } m_MinimumLoss = right.Init(regionResponse); m_MinimumThreshold = *begin; m_MinimumIndex = 0; vigra::DimensionNotEqual compareNotEqual(column, 0); if (!m_UseRandomSplit) { TDataIterator iter = begin; // Find the next element that are NOT equal with his neightbour! TDataIterator next = std::adjacent_find(iter, end, compareNotEqual); while(next != end) { // Remove or add the current segment are from the LineSearch double rightLoss = right.Decrement(iter, next +1); double leftLoss = left.Increment(iter, next +1); double currentLoss = rightLoss + leftLoss; if (currentLoss < m_MinimumLoss) { m_BestCurrentCounts[0] = left.Response(); m_BestCurrentCounts[1] = right.Response(); m_MinimumLoss = currentLoss; m_MinimumIndex = next - begin + 1; m_MinimumThreshold = (double(column(*next,0)) + double(column(*(next +1), 0)))/2.0; } iter = next + 1; next = std::adjacent_find(iter, end, compareNotEqual); } } else // If Random split is selected, e.g. ExtraTree behaviour { int size = end - begin + 1; srand(time(NULL)); int offset = rand() % size; TDataIterator iter = begin + offset; double rightLoss = right.Decrement(begin, iter+1); double leftLoss = left.Increment(begin, iter+1); double currentLoss = rightLoss + leftLoss; if (currentLoss < m_MinimumLoss) { m_BestCurrentCounts[0] = left.Response(); m_BestCurrentCounts[1] = right.Response(); m_MinimumLoss = currentLoss; m_MinimumIndex = offset + 1; m_MinimumThreshold = (double(column(*iter,0)) + double(column(*(iter+1), 0)))/2.0; } } } template template double mitk::LinearSplitting::LossOfRegion(TDataSourceLabel const & labels, TDataIterator &/*begin*/, TDataIterator &/*end*/, TArray const & regionResponse) { typedef TLossAccumulator LineSearchLoss; - LineSearchLoss regionLoss(labels, m_ExtParameter); + LineSearchLoss regionLoss(labels, m_ExtParameter, m_AdditionalData); if (m_UsePointWeights) { regionLoss.UsePointWeights(true); regionLoss.SetPointWeights(m_PointWeights); } return regionLoss.Init(regionResponse); } #endif //mitkLinearSplitting_cpp diff --git a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkPUImpurityLoss.cpp b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkPUImpurityLoss.cpp new file mode 100644 index 0000000000..3259b742b5 --- /dev/null +++ b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkPUImpurityLoss.cpp @@ -0,0 +1,136 @@ +#ifndef mitkPUImpurityLoss_cpp +#define mitkPUImpurityLoss_cpp + +#include +#include + +template +template +mitk::PUImpurityLoss::PUImpurityLoss(TLabelContainer const &labels, + vigra::ProblemSpec const &ext, + AdditionalRFDataAbstract *data) : + m_UsePointWeights(false), + m_Labels(labels), + //m_Kappa(ext.kappa_), // Not possible due to data type + m_Counts(ext.class_count_, 0.0), + m_PUCounts(ext.class_count_, 0.0), + m_ClassWeights(ext.class_weights_), + m_TotalCount(0.0), + m_PUTotalCount(0.0), + m_ClassCount(ext.class_count_) +{ + mitk::PURFData * purfdata = dynamic_cast (data); + //const PURFProblemSpec *problem = static_cast * > (&ext); + m_Kappa = vigra::ArrayVector(purfdata->m_Kappa); +} + +template +void +mitk::PUImpurityLoss::Reset() +{ + m_Counts.init(0); + m_TotalCount = 0.0; +} + +template +void +mitk::PUImpurityLoss::UpdatePUCounts() +{ + m_PUTotalCount = 0; + for (int i = 1; i < m_ClassCount; ++i) + { + m_PUCounts[i] = m_Kappa[i] * m_Counts[i]; + m_PUTotalCount += m_PUCounts[i]; + } + m_PUCounts[0] = std::max(0.0, m_TotalCount - m_PUTotalCount); + m_PUTotalCount += m_PUCounts[0]; +} + +template +template +double +mitk::PUImpurityLoss::Increment(TDataIterator begin, TDataIterator end) +{ + for (TDataIterator iter = begin; iter != end; ++iter) + { + double pointProbability = 1.0; + if (m_UsePointWeights) + { + pointProbability = m_PointWeights(*iter,0); + } + m_Counts[m_Labels(*iter,0)] += pointProbability; + m_TotalCount += pointProbability; + } + UpdatePUCounts(); + return m_LossFunction(m_PUCounts, m_ClassWeights, m_PUTotalCount); +} + +template +template +double +mitk::PUImpurityLoss::Decrement(TDataIterator begin, TDataIterator end) +{ + for (TDataIterator iter = begin; iter != end; ++iter) + { + double pointProbability = 1.0; + if (m_UsePointWeights) + { + pointProbability = m_PointWeights(*iter,0); + } + m_Counts[m_Labels(*iter,0)] -= pointProbability; + m_TotalCount -= pointProbability; + } + UpdatePUCounts(); + return m_LossFunction(m_PUCounts, m_ClassWeights, m_PUTotalCount); +} + +template +template +double +mitk::PUImpurityLoss::Init(TArray initCounts) +{ + Reset(); + std::copy(initCounts.begin(), initCounts.end(), m_Counts.begin()); + m_TotalCount = std::accumulate(m_Counts.begin(), m_Counts.end(), 0.0); + return m_LossFunction(m_Counts, m_ClassWeights, m_TotalCount); +} + +template +vigra::ArrayVector const& +mitk::PUImpurityLoss::Response() +{ + return m_Counts; +} + +template +void +mitk::PUImpurityLoss::UsePointWeights(bool useWeights) +{ + m_UsePointWeights = useWeights; +} + +template +bool +mitk::PUImpurityLoss::IsUsingPointWeights() +{ + return m_UsePointWeights; +} + +template +void +mitk::PUImpurityLoss::SetPointWeights(TWeightContainer weight) +{ + m_PointWeights = weight; +} + +template +typename mitk::PUImpurityLoss::WeightContainerType +mitk::PUImpurityLoss::GetPointWeights() +{ + return m_PointWeights; +} + + +#endif // mitkImpurityLoss_cpp + + diff --git a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp index 86a2f635a8..388b3c27cb 100644 --- a/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp +++ b/Modules/Classification/CLVigraRandomForest/src/Splitter/mitkThresholdSplit.cpp @@ -1,298 +1,313 @@ #ifndef mitkThresholdSplit_cpp #define mitkThresholdSplit_cpp #include template mitk::ThresholdSplit::ThresholdSplit() : m_CalculatingFeature(false), m_UseWeights(false), m_UseRandomSplit(false), m_Precision(0.0), - m_MaximumTreeDepth(1000) + m_MaximumTreeDepth(1000), + m_AdditionalData(nullptr) { } //template //mitk::ThresholdSplit::ThresholdSplit(const ThresholdSplit & /*other*/)/*: // m_CalculatingFeature(other.IsCalculatingFeature()), // m_UseWeights(other.IsUsingPointBasedWeights()), // m_UseRandomSplit(other.IsUsingRandomSplit()), // m_Precision(other.GetPrecision()), // m_MaximumTreeDepth(other.GetMaximumTreeDepth()), // m_FeatureCalculator(other.GetFeatureCalculator()), // m_Weights(other.GetWeights())*/ //{ //} +template +void +mitk::ThresholdSplit::SetAdditionalData(AdditionalRFDataAbstract* data) +{ + bgfunc.SetAdditionalData(data); + m_AdditionalData = data; +} + +template +mitk::AdditionalRFDataAbstract * +mitk::ThresholdSplit::GetAdditionalData() const +{ + return m_AdditionalData; +} template void mitk::ThresholdSplit::SetFeatureCalculator(TFeatureCalculator processor) { m_FeatureCalculator = processor; } template TFeatureCalculator mitk::ThresholdSplit::GetFeatureCalculator() const { return m_FeatureCalculator; } template void mitk::ThresholdSplit::SetCalculatingFeature(bool calculate) { m_CalculatingFeature = calculate; } template bool mitk::ThresholdSplit::IsCalculatingFeature() const { return m_CalculatingFeature; } template void mitk::ThresholdSplit::UsePointBasedWeights(bool weightsOn) { m_UseWeights = weightsOn; bgfunc.UsePointWeights(weightsOn); } template bool mitk::ThresholdSplit::IsUsingPointBasedWeights() const { return m_UseWeights; } template void mitk::ThresholdSplit::SetPrecision(double value) { m_Precision = value; } template double mitk::ThresholdSplit::GetPrecision() const { return m_Precision; } template void mitk::ThresholdSplit::SetMaximumTreeDepth(int value) { m_MaximumTreeDepth = value; } template int mitk::ThresholdSplit::GetMaximumTreeDepth() const { return m_MaximumTreeDepth; } template void mitk::ThresholdSplit::SetWeights(vigra::MultiArrayView<2, double> weights) { m_Weights = weights; bgfunc.UsePointWeights(m_UseWeights); bgfunc.SetPointWeights(weights); } template vigra::MultiArrayView<2, double> mitk::ThresholdSplit::GetWeights() const { return m_Weights; } template double mitk::ThresholdSplit::minGini() const { return min_gini_[bestSplitIndex]; } template int mitk::ThresholdSplit::bestSplitColumn() const { return splitColumns[bestSplitIndex]; } template double mitk::ThresholdSplit::bestSplitThreshold() const { return min_thresholds_[bestSplitIndex]; } template template void mitk::ThresholdSplit::set_external_parameters(vigra::ProblemSpec const & in) { SB::set_external_parameters(in); bgfunc.set_external_parameters( SB::ext_param_); int featureCount_ = SB::ext_param_.column_count_; splitColumns.resize(featureCount_); for(int k=0; k template int mitk::ThresholdSplit::findBestSplit(vigra::MultiArrayView<2, T, C> features, vigra::MultiArrayView<2, T2, C2> labels, Region & region, vigra::ArrayVector& childRegions, Random & randint) { typedef typename Region::IndexIterator IndexIteratorType; if (m_CalculatingFeature) { // Do some very fance stuff here!! // This is not so simple as it might look! We need to // remember which feature has been used to be able to // use it for testing again!! // There, no Splitting class is used!! } bgfunc.UsePointWeights(m_UseWeights); bgfunc.UseRandomSplit(m_UseRandomSplit); vigra::detail::Correction::exec(region, labels); // Create initial class count. for(std::size_t i = 0; i < region.classCounts_.size(); ++i) { region.classCounts_[i] = 0; } double regionSum = 0; for (typename Region::IndexIterator iter = region.begin(); iter != region.end(); ++iter) { double probability = 1.0; if (m_UseWeights) { probability = m_Weights(*iter, 0); } region.classCounts_[labels(*iter,0)] += probability; regionSum += probability; } region.classCountsIsValid = true; vigra::ArrayVector vec; // Is pure region? region_gini_ = bgfunc.LossOfRegion(labels, region.begin(), region.end(), region.classCounts()); if (region_gini_ <= m_Precision * regionSum) // Necessary to fix wrong calculation of Gini-Index { return this->makeTerminalNode(features, labels, region, randint); } // Randomize the order of columns for (int i = 0; i < SB::ext_param_.actual_mtry_; ++i) { std::swap(splitColumns[i], splitColumns[i+ randint(features.shape(1) - i)]); } // find the split with the best evaluation value bestSplitIndex = 0; double currentMiniGini = region_gini_; int numberOfTrials = features.shape(1); for (int k = 0; k < numberOfTrials; ++k) { bgfunc(columnVector(features, splitColumns[k]), labels, region.begin(), region.end(), region.classCounts()); min_gini_[k] = bgfunc.GetMinimumLoss(); min_indices_[k] = bgfunc.GetMinimumIndex(); min_thresholds_[k] = bgfunc.GetMinimumThreshold(); // removed classifier test section, because not necessary if (bgfunc.GetMinimumLoss() < currentMiniGini) { currentMiniGini = bgfunc.GetMinimumLoss(); childRegions[0].classCounts() = bgfunc.GetBestCurrentCounts()[0]; childRegions[1].classCounts() = bgfunc.GetBestCurrentCounts()[1]; childRegions[0].classCountsIsValid = true; childRegions[1].classCountsIsValid = true; bestSplitIndex = k; numberOfTrials = SB::ext_param_.actual_mtry_; } } //If only a small improvement, make terminal node... if(vigra::closeAtTolerance(currentMiniGini, region_gini_)) { return this->makeTerminalNode(features, labels, region, randint); } vigra::Node node(SB::t_data, SB::p_data); SB::node_ = node; node.threshold() = min_thresholds_[bestSplitIndex]; node.column() = splitColumns[bestSplitIndex]; // partition the range according to the best dimension vigra::SortSamplesByDimensions > sorter(features, node.column(), node.threshold()); IndexIteratorType bestSplit = std::partition(region.begin(), region.end(), sorter); // Save the ranges of the child stack entries. childRegions[0].setRange( region.begin() , bestSplit ); childRegions[0].rule = region.rule; childRegions[0].rule.push_back(std::make_pair(1, 1.0)); childRegions[1].setRange( bestSplit , region.end() ); childRegions[1].rule = region.rule; childRegions[1].rule.push_back(std::make_pair(1, 1.0)); return vigra::i_ThresholdNode; return 0; } //template //static void UpdateRegionCounts(TRegion & region, TRegionIterator begin, TRegionIterator end, TLabelHolder labels, TWeightsHolder weights) //{ // if(std::accumulate(region.classCounts().begin(), // region.classCounts().end(), 0.0) != region.size()) // { // RandomForestClassCounter< LabelT, // ArrayVector > // counter(labels, region.classCounts()); // std::for_each( region.begin(), region.end(), counter); // region.classCountsIsValid = true; // } //} // //template //static void exec(Region & region, LabelT & labels) //{ // if(std::accumulate(region.classCounts().begin(), // region.classCounts().end(), 0.0) != region.size()) // { // RandomForestClassCounter< LabelT, // ArrayVector > // counter(labels, region.classCounts()); // std::for_each( region.begin(), region.end(), counter); // region.classCountsIsValid = true; // } //} #endif //mitkThresholdSplit_cpp diff --git a/Modules/Classification/DataCollection/Utilities/mitkCollectionStatistic.h b/Modules/Classification/DataCollection/Utilities/mitkCollectionStatistic.h index 780230b91b..18c111f874 100644 --- a/Modules/Classification/DataCollection/Utilities/mitkCollectionStatistic.h +++ b/Modules/Classification/DataCollection/Utilities/mitkCollectionStatistic.h @@ -1,138 +1,147 @@ #ifndef mitkCollectionStatistic_h #define mitkCollectionStatistic_h #include #include #include namespace mitk { struct MITKDATACOLLECTION_EXPORT StatisticData { unsigned int m_TruePositive; unsigned int m_FalsePositive; unsigned int m_TrueNegative; unsigned int m_FalseNegative; double m_DICE; double m_Jaccard; double m_Sensitivity; double m_Specificity; double m_RMSD; StatisticData() : m_TruePositive(0), m_FalsePositive(0), m_TrueNegative(0), m_FalseNegative(0), m_DICE(0), m_Jaccard(0), m_Sensitivity(0), m_Specificity(0), m_RMSD(-1.0) {} }; class ValueToIndexMapper { public: virtual unsigned char operator() (unsigned char value) const = 0; }; +class BinaryValueminusOneToIndexMapper : public virtual ValueToIndexMapper +{ +public: + unsigned char operator() (unsigned char value) const + { + return value-1; + } +}; + class BinaryValueToIndexMapper : public virtual ValueToIndexMapper { public: unsigned char operator() (unsigned char value) const { return value; } }; class MultiClassValueToIndexMapper : public virtual ValueToIndexMapper { public: unsigned char operator() (unsigned char value) const { if (value == 1 || value == 5) return 0; else return 1; } }; class ProgressionValueToIndexMapper : public virtual ValueToIndexMapper { public: unsigned char operator() (unsigned char value) const { if (value == 1 || value == 0) return 0; else return 1; } }; class MITKDATACOLLECTION_EXPORT CollectionStatistic { public: CollectionStatistic(); ~CollectionStatistic(); typedef std::vector DataVector; typedef std::vector MultiDataVector; void SetCollection(DataCollection::Pointer collection); DataCollection::Pointer GetCollection(); void SetClassCount (size_t count); size_t GetClassCount(); void SetGoldName(std::string name); std::string GetGoldName(); void SetTestName(std::string name); std::string GetTestName(); void SetMaskName(std::string name) {m_MaskName = name; } void SetGroundTruthValueToIndexMapper(const ValueToIndexMapper* mapper); const ValueToIndexMapper* GetGroundTruthValueToIndexMapper(void) const; void SetTestValueToIndexMapper(const ValueToIndexMapper* mapper); const ValueToIndexMapper* GetTestValueToIndexMapper(void) const; void Print(std::ostream& out, std::ostream& sout = std::cout, bool withHeader = false, std::string label = "None"); bool Update(); int IsInSameVirtualClass(unsigned char gold, unsigned char test); /** * @brief mitk::CollectionStatistic::GetStatisticData * @param c The class for which to retrieve the statistic data. * @return */ std::vector GetStatisticData(unsigned char c) const; /** * @brief Computes root-mean-square distance of two binary images. */ void ComputeRMSD(); private: size_t m_ClassCount; std::string m_GroundTruthName; std::string m_TestName; std::string m_MaskName; DataCollection::Pointer m_Collection; std::vector m_ConnectionGold; std::vector m_ConnectionTest; std::vector m_ConnectionClass; size_t m_VituralClassCount; MultiDataVector m_ImageClassStatistic; std::vector m_ImageNames; DataVector m_ImageStatistic; StatisticData m_MeanCompleteStatistic; StatisticData m_CompleteStatistic; const ValueToIndexMapper* m_GroundTruthValueToIndexMapper; const ValueToIndexMapper* m_TestValueToIndexMapper; }; } #endif // mitkCollectionStatistic_h diff --git a/Modules/DiffusionImaging/MiniApps/CMakeLists.txt b/Modules/DiffusionImaging/MiniApps/CMakeLists.txt index 5b02b7e213..d214cca3b3 100755 --- a/Modules/DiffusionImaging/MiniApps/CMakeLists.txt +++ b/Modules/DiffusionImaging/MiniApps/CMakeLists.txt @@ -1,57 +1,58 @@ if(BUILD_DiffusionMiniApps OR MITK_BUILD_ALL_APPS) # needed include directories include_directories( ${CMAKE_CURRENT_SOURCE_DIR} ${CMAKE_CURRENT_BINARY_DIR} ) # list of diffusion miniapps # if an app requires additional dependencies # they are added after a "^^" and separated by "_" set( diffusionminiapps NetworkCreation^^MitkFiberTracking_MitkConnectomics NetworkStatistics^^MitkConnectomics Fiberfox^^MitkFiberTracking MultishellMethods^^MitkFiberTracking PeaksAngularError^^MitkFiberTracking PeakExtraction^^MitkFiberTracking FiberExtraction^^MitkFiberTracking FiberProcessing^^MitkFiberTracking FiberDirectionExtraction^^MitkFiberTracking LocalDirectionalFiberPlausibility^^MitkFiberTracking StreamlineTracking^^MitkFiberTracking GibbsTracking^^MitkFiberTracking TractometerMetrics^^MitkFiberTracking FileFormatConverter^^MitkFiberTracking DFTraining^^MitkFiberTracking DFTracking^^MitkFiberTracking + ExtractAllGradients^^ ) foreach(diffusionminiapp ${diffusionminiapps}) # extract mini app name and dependencies string(REPLACE "^^" "\\;" miniapp_info ${diffusionminiapp}) set(miniapp_info_list ${miniapp_info}) list(GET miniapp_info_list 0 appname) list(GET miniapp_info_list 1 raw_dependencies) string(REPLACE "_" "\\;" dependencies "${raw_dependencies}") set(dependencies_list ${dependencies}) mitkFunctionCreateCommandLineApp( NAME ${appname} DEPENDS MitkCore MitkDiffusionCore ${dependencies_list} PACKAGE_DEPENDS ITK ) endforeach() # This mini app does not depend on mitkDiffusionImaging at all mitkFunctionCreateCommandLineApp( NAME Dicom2Nrrd DEPENDS MitkCore ${dependencies_list} ) if(EXECUTABLE_IS_ENABLED) MITK_INSTALL_TARGETS(EXECUTABLES ${EXECUTABLE_TARGET}) endif() endif() diff --git a/Modules/DiffusionImaging/MiniApps/ExtractAllGradients.cpp b/Modules/DiffusionImaging/MiniApps/ExtractAllGradients.cpp new file mode 100644 index 0000000000..b878cb2075 --- /dev/null +++ b/Modules/DiffusionImaging/MiniApps/ExtractAllGradients.cpp @@ -0,0 +1,86 @@ +/*=================================================================== + +The Medical Imaging Interaction Toolkit (MITK) + +Copyright (c) German Cancer Research Center, +Division of Medical and Biological Informatics. +All rights reserved. + +This software is distributed WITHOUT ANY WARRANTY; without +even the implied warranty of MERCHANTABILITY or FITNESS FOR +A PARTICULAR PURPOSE. + +See LICENSE.txt or http://www.mitk.org for details. + +===================================================================*/ + +#include +#include +#include +#include "mitkCommandLineParser.h" +#include + +using namespace mitk; +using namespace std; + +/*! +\brief Copies transformation matrix of one image to another +*/ +int main(int argc, char* argv[]) +{ + typedef itk::VectorImage< short, 3 > ItkDwiType; + + mitkCommandLineParser parser; + + parser.setTitle("Extract all gradients"); + parser.setCategory("Preprocessing Tools"); + parser.setDescription("Extract all gradients from an diffusion image"); + parser.setContributor("MBI"); + + parser.setArgumentPrefix("--", "-"); + parser.addArgument("in", "i", mitkCommandLineParser::InputFile, "Input:", "input image", us::Any(), false); + //parser.addArgument("extension", "e", mitkCommandLineParser::String, "File Extension:", "Extension of the output file", us::Any(), false); + //parser.addArgument("out", "o", mitkCommandLineParser::OutputFile, "Output:", "output image", us::Any(), false); + + map parsedArgs = parser.parseArguments(argc, argv); + if (parsedArgs.size() == 0) + return EXIT_FAILURE; + + MITK_INFO << "Extract parameter"; + // mandatory arguments + /*string inputName = us::any_cast(parsedArgs["in"]); + string extensionName = us::any_cast(parsedArgs["extension"]); + string ouputName = us::any_cast(parsedArgs["out"]);*/ + string in = us::any_cast(parsedArgs["in"]); + string inputName = "E:\\Kollektive\\R02-Lebertumore-Diffusion\\01-Extrahierte-Daten\\" + in + "\\" + in + "-DWI.dwi"; + string extensionName = ".nrrd"; + string ouputName = "E:\\Kollektive\\R02-Lebertumore-Diffusion\\01-Extrahierte-Daten\\" + in + "\\" + in + "-"; + + MITK_INFO << "Load Image: "; + mitk::Image::Pointer image = mitk::IOUtil::LoadImage(inputName); + + //bool isDiffusionImage(mitk::DiffusionPropertyHelper::IsDiffusionWeightedImage(image)); + //if (!isDiffusionImage) + //{ + // MITK_INFO << "Input file is not of type diffusion image"; + // return; + //} + + ItkDwiType::Pointer itkVectorImagePointer = ItkDwiType::New(); + mitk::CastToItkImage(image, itkVectorImagePointer); + + unsigned int channel = 0; + for (unsigned int channel = 0; channel < itkVectorImagePointer->GetVectorLength(); ++channel) + { + itk::ExtractDwiChannelFilter< short >::Pointer filter = itk::ExtractDwiChannelFilter< short >::New(); + filter->SetInput(itkVectorImagePointer); + filter->SetChannelIndex(channel); + filter->Update(); + + mitk::Image::Pointer newImage = mitk::Image::New(); + newImage->InitializeByItk(filter->GetOutput()); + newImage->SetImportChannel(filter->GetOutput()->GetBufferPointer()); + mitk::IOUtil::SaveImage(newImage, ouputName + to_string(channel) + extensionName); + } + return EXIT_SUCCESS; +}