diff --git a/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.cpp b/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.cpp index a1b0c7e01a..dd5c3c05dc 100644 --- a/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.cpp +++ b/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.cpp @@ -1,278 +1,331 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ // MITK #include "mitkTotalSegmentatorTool.h" #include "mitkIOUtil.h" +#include #include #include // us #include #include #include #include #include namespace mitk { MITK_TOOL_MACRO(MITKSEGMENTATION_EXPORT, TotalSegmentatorTool, "Total Segmentator"); } mitk::TotalSegmentatorTool::~TotalSegmentatorTool() { itksys::SystemTools::RemoveADirectory(this->GetMitkTempDir()); } mitk::TotalSegmentatorTool::TotalSegmentatorTool() { this->IsTimePointChangeAwareOff(); } void mitk::TotalSegmentatorTool::Activated() { Superclass::Activated(); this->SetLabelTransferMode(LabelTransferMode::AllLabels); } const char **mitk::TotalSegmentatorTool::GetXPM() const { return nullptr; } us::ModuleResource mitk::TotalSegmentatorTool::GetIconResource() const { us::Module *module = us::GetModuleContext()->GetModule(); us::ModuleResource resource = module->GetResource("AI.svg"); return resource; } const char *mitk::TotalSegmentatorTool::GetName() const { return "TotalSegmentator"; } void mitk::TotalSegmentatorTool::onPythonProcessEvent(itk::Object * /*pCaller*/, const itk::EventObject &e, void *) { std::string testCOUT; std::string testCERR; const auto *pEvent = dynamic_cast(&e); if (pEvent) { testCOUT = testCOUT + pEvent->GetOutput(); MITK_INFO << testCOUT; } const auto *pErrEvent = dynamic_cast(&e); if (pErrEvent) { testCERR = testCERR + pErrEvent->GetOutput(); MITK_ERROR << testCERR; } } void mitk::TotalSegmentatorTool::DoUpdatePreview(const Image *inputAtTimeStep, const Image * /*oldSegAtTimeStep*/, LabelSetImage *previewImage, TimeStepType timeStep) { if (this->m_MitkTempDir.empty()) { this->SetMitkTempDir(IOUtil::CreateTemporaryDirectory("mitk-XXXXXX")); } if (m_LabelMapTotal.empty()) { this->ParseLabelNames(this->GetLabelMapPath()); } ProcessExecutor::Pointer spExec = ProcessExecutor::New(); itk::CStyleCommand::Pointer spCommand = itk::CStyleCommand::New(); spCommand->SetCallback(&onPythonProcessEvent); spExec->AddObserver(ExternalProcessOutputEvent(), spCommand); std::string inDir, outDir, inputImagePath, outputImagePath, scriptPath; inDir = IOUtil::CreateTemporaryDirectory("totalseg-in-XXXXXX", this->GetMitkTempDir()); std::ofstream tmpStream; inputImagePath = IOUtil::CreateTemporaryFile(tmpStream, TEMPLATE_FILENAME, inDir + IOUtil::GetDirectorySeparator()); tmpStream.close(); std::size_t found = inputImagePath.find_last_of(IOUtil::GetDirectorySeparator()); std::string fileName = inputImagePath.substr(found + 1); std::string token = fileName.substr(0, fileName.find("_")); outDir = IOUtil::CreateTemporaryDirectory("totalseg-out-XXXXXX", this->GetMitkTempDir()); outputImagePath = outDir + IOUtil::GetDirectorySeparator() + token + "_000.nii.gz"; - ProcessExecutor::ArgumentListType args; + mitk::LabelSetImage::Pointer outputBuffer; IOUtil::Save(inputAtTimeStep, inputImagePath); std::string &outArg = outputImagePath; bool isSubTask = false; if (this->GetSubTask() != DEFAULT_TOTAL_TASK) { isSubTask = true; outputImagePath = outDir + IOUtil::GetDirectorySeparator() + this->GetSubTask() + ".nii.gz"; outArg = outDir; } this->run_totalsegmentator( spExec, inputImagePath, outArg, this->GetFast(), !isSubTask, this->GetGpuId(), DEFAULT_TOTAL_TASK); if (isSubTask) - { + { // Run total segmentator again this->run_totalsegmentator( spExec, inputImagePath, outArg, !isSubTask, !isSubTask, this->GetGpuId(), this->GetSubTask()); + // Construct Label Id map + std::vector files = SUBTASKS_MAP.at(this->GetSubTask()); + std::map labelMapSubtask; + int labelId = 1; + for (auto const& file : files) + { + std::string labelName = file.substr(0, file.find('.')); + labelMapSubtask[labelId] = labelName; + labelId++; + } + // Agglomerate individual mask files into one multi-label image. + std::for_each(files.begin(), + files.end(), + [&](std::string &fileName) { fileName = (outDir + IOUtil::GetDirectorySeparator() + fileName); }); + outputBuffer = AgglomerateLabelFiles(files, inputAtTimeStep->GetDimensions(), inputAtTimeStep->GetGeometry()); + // Assign label names to the agglomerated LabelSetImage + this->MapLabelsToSegmentation(outputBuffer, labelMapSubtask); } + else + { + Image::Pointer outputImage = IOUtil::Load(outputImagePath); + outputBuffer = mitk::LabelSetImage::New(); + outputBuffer->InitializeByLabeledImage(outputImage); + outputBuffer->SetGeometry(inputAtTimeStep->GetGeometry()); + this->MapLabelsToSegmentation(outputBuffer, m_LabelMapTotal); + } + this->TransferLabelSetImageContent(outputBuffer, previewImage, timeStep); +} - Image::Pointer outputImage = IOUtil::Load(outputImagePath); - auto outputBuffer = mitk::LabelSetImage::New(); - outputBuffer->InitializeByLabeledImage(outputImage); - outputBuffer->SetGeometry(inputAtTimeStep->GetGeometry()); - this->MapLabelsToSegmentation(outputBuffer, m_LabelMapTotal); - TransferLabelSetImageContent(outputBuffer, previewImage, timeStep); +mitk::LabelSetImage::Pointer mitk::TotalSegmentatorTool::AgglomerateLabelFiles(std::vector &filePaths, + unsigned int *dimensions, + mitk::BaseGeometry *geometry) +{ + int labelId = 1; + auto aggloLabelImage = mitk::LabelSetImage::New(); + auto initImage = mitk::Image::New(); + initImage->Initialize(mitk::MakeScalarPixelType(), 3, dimensions); + aggloLabelImage->Initialize(initImage); + aggloLabelImage->SetGeometry(geometry); + mitk::LabelSet::Pointer newlayer = mitk::LabelSet::New(); + newlayer->SetLayer(0); + aggloLabelImage->AddLayer(newlayer); + + for (auto const &outputImagePath : filePaths) + { + auto label = mitk::Label::New(); + label->SetName("object-" + std::to_string(labelId)); + label->SetValue(labelId); //TODO: set color + aggloLabelImage->GetActiveLabelSet()->AddLabel(label); + + Image::Pointer outputImage = IOUtil::Load(outputImagePath); + auto source = mitk::LabelSetImage::New(); + source->InitializeByLabeledImage(outputImage); + source->SetGeometry(geometry); + + auto labelSet = aggloLabelImage->GetActiveLabelSet(); + mitk::TransferLabelContent(source, aggloLabelImage, labelSet, 0, 0, false, {{1, labelId}}); + labelId++; + } + return aggloLabelImage; } void mitk::TotalSegmentatorTool::run_totalsegmentator(ProcessExecutor::Pointer spExec, const std::string &inputImagePath, const std::string &outputImagePath, bool isFast, bool isMultiLabel, unsigned int gpuId, const std::string &subTask) { ProcessExecutor::ArgumentListType args; std::string command = "TotalSegmentator"; #if defined(__APPLE__) || defined(_WIN32) command = "python"; #endif args.clear(); #ifdef _WIN32 std::string ending = "Scripts"; if (0 == this->GetPythonPath().compare(this->GetPythonPath().length() - ending.length(), ending.length(), ending)) { args.push_back("TotalSegmentator"); } else { args.push_back("Scripts/TotalSegmentator"); } #endif #if defined(__APPLE__) args.push_back("TotalSegmentator"); #endif args.push_back("-i"); args.push_back(inputImagePath); args.push_back("-o"); args.push_back(outputImagePath); if (subTask != DEFAULT_TOTAL_TASK) { args.push_back("-ta"); args.push_back(subTask); } if (isMultiLabel) { args.push_back("--ml"); } if (isFast) { args.push_back("--fast"); } try { std::string cudaEnv = "CUDA_VISIBLE_DEVICES=" + std::to_string(gpuId); itksys::SystemTools::PutEnv(cudaEnv.c_str()); for (auto &arg : args) MITK_INFO << arg; MITK_INFO << this->GetPythonPath(); spExec->Execute(this->GetPythonPath(), command, args); } catch (const mitk::Exception &e) { MITK_ERROR << e.GetDescription(); return; } } void mitk::TotalSegmentatorTool::ParseLabelNames(const std::string &fileName) { std::fstream newfile; newfile.open(fileName, ios::in); std::stringstream buffer; if (newfile.is_open()) { int line = 0; std::string temp; while (std::getline(newfile, temp)) { if (line > 1 && line < 106) { buffer << temp; } ++line; } } std::string key, val; while (std::getline(std::getline(buffer, key, ':'), val, ',')) { m_LabelMapTotal[std::stoi(key)] = val; } } void mitk::TotalSegmentatorTool::MapLabelsToSegmentation(mitk::LabelSetImage::Pointer outputBuffer, std::map &labelMap) { for (auto const &[key, val] : labelMap) { - mitk::Label *labelptr = outputBuffer->GetLabel(key, 0); + mitk::Label *labelptr = outputBuffer->GetActiveLabelSet()->GetLabel(key); if (nullptr != labelptr) { labelptr->SetName(val); } } } std::string mitk::TotalSegmentatorTool::GetLabelMapPath() { std::string pythonFileName; std::filesystem::path pathToLabelMap(this->GetPythonPath()); pathToLabelMap = pathToLabelMap.parent_path(); #ifdef _WIN32 pythonFileName = pathToLabelMap.string() + "/Lib/site-packages/totalsegmentator/map_to_binary.py"; #else pathToLabelMap.append("lib"); for (auto const &dir_entry : std::filesystem::directory_iterator{pathToLabelMap}) { if (dir_entry.is_directory()) { auto dirName = dir_entry.path().filename().string(); if (dirName.rfind("python", 0) == 0) { pathToLabelMap.append(dir_entry.path().filename().string()); break; } } } pythonFileName = pathToLabelMap.string() + "/site-packages/totalsegmentator/map_to_binary.py"; #endif return pythonFileName; } diff --git a/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.h b/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.h index 770ec99c0b..d507d443fc 100644 --- a/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.h +++ b/Modules/Segmentation/Interactions/mitkTotalSegmentatorTool.h @@ -1,130 +1,150 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef MITKTOTALSEGMENTATORTOOL_H #define MITKTOTALSEGMENTATORTOOL_H #include "mitkSegWithPreviewTool.h" #include #include "mitkProcessExecutor.h" namespace us { class ModuleResource; } namespace mitk { /** \brief TotalSegmentator segmentation tool. \ingroup Interaction \ingroup ToolManagerEtAl \warning Only to be instantiated by mitk::ToolManager. */ class MITKSEGMENTATION_EXPORT TotalSegmentatorTool : public SegWithPreviewTool { public: mitkClassMacro(TotalSegmentatorTool, SegWithPreviewTool); itkFactorylessNewMacro(Self); itkCloneMacro(Self); const char *GetName() const override; const char **GetXPM() const override; us::ModuleResource GetIconResource() const override; void Activated() override; itkSetMacro(MitkTempDir, std::string); itkGetConstMacro(MitkTempDir, std::string); itkSetMacro(SubTask, std::string); itkGetConstMacro(SubTask, std::string); itkSetMacro(PythonPath, std::string); itkGetConstMacro(PythonPath, std::string); itkSetMacro(GpuId, unsigned int); itkGetConstMacro(GpuId, unsigned int); itkSetMacro(Fast, bool); itkGetConstMacro(Fast, bool); itkBooleanMacro(Fast); /** * @brief Static function to print out everything from itk::EventObject. * Used as callback in mitk::ProcessExecutor object. * */ static void onPythonProcessEvent(itk::Object *, const itk::EventObject &e, void *); protected: TotalSegmentatorTool(); ~TotalSegmentatorTool(); /** * @brief Overriden method from the tool manager to execute the segmentation * Implementation: * 1. Creates temp directory, if not done already. * 2. Parses Label names from map_to_binary.py for using later on. * 3. Calls "run_totalsegmentator" method. * 4. Expects an output image to be saved in the temporary directory by the python proces. Loads it as * LabelSetImage and sets to previewImage. * * @param inputAtTimeStep * @param oldSegAtTimeStep * @param previewImage * @param timeStep */ void DoUpdatePreview(const Image* inputAtTimeStep, const Image* oldSegAtTimeStep, LabelSetImage* previewImage, TimeStepType timeStep) override; private: /** * @brief Runs Totalsegmentator python process with desired arguments * */ void run_totalsegmentator(ProcessExecutor::Pointer, const std::string&, const std::string&, bool, bool, unsigned int, const std::string&); /** * @brief Applies the m_LabelMapTotal lookup table on the output segmentation LabelSetImage. * */ void MapLabelsToSegmentation(mitk::LabelSetImage::Pointer, std::map&); /** * @brief Parses map_to_binary.py file to extract label ids and names * and stores as a map for reference in m_LabelMapTotal * * @param filePath */ - void ParseLabelNames(const std::string &); + void ParseLabelNames(const std::string&); /** * @brief Get the Label Map Path from the virtual environment location * * @return std::string */ std::string GetLabelMapPath(); + /** + * @brief Agglomerate many individual mask image files into one multi-label LabelSetImage in the + * given filePath order. + * + * @param filePaths + * @param dimension + * @param geometry + * @return LabelSetImage::Pointer + */ + LabelSetImage::Pointer AgglomerateLabelFiles(std::vector& filePaths, unsigned int* dimension, mitk::BaseGeometry* geometry); + std::string m_MitkTempDir; std::string m_PythonPath; std::string m_SubTask = "total"; unsigned int m_GpuId = 0; std::map m_LabelMapTotal; bool m_Fast = true; const std::string TEMPLATE_FILENAME = "XXXXXX_000_0000.nii.gz"; const std::string DEFAULT_TOTAL_TASK = "total"; + const std::unordered_map> SUBTASKS_MAP = + { + {"body", { "body.nii.gz", "body_trunc.nii.gz", "body_extremities.nii.gz", "skin.nii.gz"}}, + {"hip_implant", {"hip_implant.nii.gz"}}, + {"cerebral_bleed", {"intracerebral_hemorrhage.nii.gz"}}, + {"coronary_arteries", {"coronary_arteries.nii.gz"}}, + {"lung_vessels", {"lung_vessels.nii.gz", "lung_trachea_bronchia.nii.gz"}}, + {"pleural_pericard_effusion", {"pleural_effusion.nii.gz", "pericardial_effusion.nii.gz"}} + }; }; // class } // namespace #endif diff --git a/Modules/SegmentationUI/Qmitk/QmitkTotalSegmentatorToolGUI.h b/Modules/SegmentationUI/Qmitk/QmitkTotalSegmentatorToolGUI.h index 05842ed268..d26fdf5bee 100644 --- a/Modules/SegmentationUI/Qmitk/QmitkTotalSegmentatorToolGUI.h +++ b/Modules/SegmentationUI/Qmitk/QmitkTotalSegmentatorToolGUI.h @@ -1,168 +1,176 @@ #ifndef QmitkTotalSegmentatorToolGUI_h_Included #define QmitkTotalSegmentatorToolGUI_h_Included #include "QmitkMultiLabelSegWithPreviewToolGUIBase.h" #include "QmitkSetupVirtualEnvUtil.h" #include "QmitknnUNetGPU.h" #include "ui_QmitkTotalSegmentatorGUIControls.h" #include #include #include #include /** * @brief Installer class for TotalSegmentator Tool. * Class specifies the virtual environment name, install version, packages required to pip install * and implements SetupVirtualEnv method. * */ class QmitkTotalSegmentatorToolInstaller : public QmitkSetupVirtualEnvUtil { public: const QString VENV_NAME = ".totalsegmentator"; const QString TOTALSEGMENTATOR_VERSION = "1.5.3"; const std::vector PACKAGES = {QString("Totalsegmentator==") + TOTALSEGMENTATOR_VERSION, QString("scipy==1.9.1")}; const QString STORAGE_DIR; inline QmitkTotalSegmentatorToolInstaller( const QString baseDir = QStandardPaths::writableLocation(QStandardPaths::GenericDataLocation) + QDir::separator() + qApp->organizationName() + QDir::separator()) : QmitkSetupVirtualEnvUtil(baseDir), STORAGE_DIR(baseDir){}; bool SetupVirtualEnv(const QString &) override; QString GetVirtualEnvPath() override; }; /** \ingroup org_mitk_gui_qt_interactivesegmentation_internal \brief GUI for mitk::TotalSegmentatorTool. \sa mitk:: */ class MITKSEGMENTATIONUI_EXPORT QmitkTotalSegmentatorToolGUI : public QmitkMultiLabelSegWithPreviewToolGUIBase { Q_OBJECT public: mitkClassMacro(QmitkTotalSegmentatorToolGUI, QmitkMultiLabelSegWithPreviewToolGUIBase); itkFactorylessNewMacro(Self); itkCloneMacro(Self); protected slots: /** * @brief Qt Slot */ void OnPreviewBtnClicked(); /** * @brief Qt Slot */ void OnPythonPathChanged(const QString &); /** * @brief Qt Slot */ QString OnSystemPythonChanged(const QString &); /** * @brief Qt Slot */ void OnInstallBtnClicked(); /** * @brief Qt Slot */ void OnOverrideChecked(int); /** * @brief Qt Slot */ void OnClearInstall(); protected: QmitkTotalSegmentatorToolGUI(); ~QmitkTotalSegmentatorToolGUI() = default; void ConnectNewTool(mitk::SegWithPreviewTool *newTool) override; void InitializeUI(QBoxLayout *mainLayout) override; /** * @brief Enable (or Disable) GUI elements. */ void EnableAll(bool); /** * @brief Searches and parses paths of python virtual enviroments * from predefined lookout locations */ void AutoParsePythonPaths(); /** * @brief Checks if TotalSegmentator command is valid in the selected python virtual environment. * * @return bool */ bool IsTotalSegmentatorInstalled(const QString &); /** * @brief Creates a QMessage object and shows on screen. */ void ShowErrorMessage(const std::string &, QMessageBox::Icon = QMessageBox::Critical); /** * @brief Writes any message in white on the tool pane. */ void WriteStatusMessage(const QString &); /** * @brief Writes any message in red on the tool pane. */ void WriteErrorMessage(const QString &); /** * @brief Adds GPU information to the gpu combo box. * In case, there aren't any GPUs avaialble, the combo box will be * rendered editable. */ void SetGPUInfo(); /** * @brief Returns GPU id of the selected GPU from the Combo box. * * @return unsigned int */ unsigned int FetchSelectedGPUFromUI() const; /** * @brief Get the virtual env path from UI combobox removing any * extra special characters. * * @return QString */ QString GetPythonPathFromUI(const QString &) const; /** * @brief Get the Exact Python Path for any OS * from the virtual environment path. * @return QString */ QString GetExactPythonPath(const QString &) const; /** * @brief For storing values like Python path across sessions. */ QSettings m_Settings; QString m_PythonPath; QmitkGPULoader m_GpuLoader; Ui_QmitkTotalSegmentatorToolGUIControls m_Controls; bool m_FirstPreviewComputation = true; bool m_IsInstalled = false; EnableConfirmSegBtnFunctionType m_SuperclassEnableConfirmSegBtnFnc; const std::string WARNING_TOTALSEG_NOT_FOUND = "TotalSegmentator is not detected in the selected python environment.Please select a valid " "python environment or install TotalSegmentator."; - const QStringList VALID_TASKS = {"total", "cerebral_bleed", "hip_implant", "coronary_arteries"}; + const QStringList VALID_TASKS = { + "total", + "cerebral_bleed", + "hip_implant", + "coronary_arteries", + "body", + "lung_vessels", + "pleural_pericard_effusion" + }; QmitkTotalSegmentatorToolInstaller m_Installer; }; #endif