diff --git a/Modules/Segmentation/Interactions/mitknnUnetTool.cpp b/Modules/Segmentation/Interactions/mitknnUnetTool.cpp index 5c08b44a0a..99e388a13a 100644 --- a/Modules/Segmentation/Interactions/mitknnUnetTool.cpp +++ b/Modules/Segmentation/Interactions/mitknnUnetTool.cpp @@ -1,320 +1,325 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "mitknnUnetTool.h" #include "mitkIOUtil.h" #include "mitkProcessExecutor.h" #include #include #include #include #include #include #include namespace mitk { MITK_TOOL_MACRO(MITKSEGMENTATION_EXPORT, nnUNetTool, "nnUNet tool"); } mitk::nnUNetTool::nnUNetTool() { this->SetMitkTempDir(IOUtil::CreateTemporaryDirectory("mitk-XXXXXX")); } mitk::nnUNetTool::~nnUNetTool() { itksys::SystemTools::RemoveADirectory(this->GetMitkTempDir()); } void mitk::nnUNetTool::Activated() { Superclass::Activated(); } void mitk::nnUNetTool::UpdateCleanUp() { // This overriden method is intentionally left out for setting later upon demand // in the `RenderOutputBuffer` method. } void mitk::nnUNetTool::RenderOutputBuffer() { if (this->m_OutputBuffer != nullptr) { Superclass::SetNodeProperties(this->m_OutputBuffer); this->ClearOutputBuffer(); try { if (nullptr != this->GetPreviewSegmentationNode()) { this->GetPreviewSegmentationNode()->SetVisibility(!this->GetSelectedLabels().empty()); } if (this->GetSelectedLabels().empty()) { this->ResetPreviewNode(); } } catch (const mitk::Exception &e) { MITK_INFO << e.GetDescription(); } } } void mitk::nnUNetTool::SetNodeProperties(LabelSetImage::Pointer segmentation) { // This overriden method doesn't set node properties. Intentionally left out for setting later upon demand // in the `RenderOutputBuffer` method. this->m_OutputBuffer = segmentation; } mitk::LabelSetImage::Pointer mitk::nnUNetTool::GetOutputBuffer() { return this->m_OutputBuffer; } void mitk::nnUNetTool::ClearOutputBuffer() { this->m_OutputBuffer = nullptr; } us::ModuleResource mitk::nnUNetTool::GetIconResource() const { us::Module *module = us::GetModuleContext()->GetModule(); us::ModuleResource resource = module->GetResource("Watershed_48x48.png"); return resource; } const char **mitk::nnUNetTool::GetXPM() const { return nullptr; } const char *mitk::nnUNetTool::GetName() const { return "nnUNet"; } +mitk::DataStorage *mitk::nnUNetTool::GetDataStorage() +{ + return this->GetToolManager()->GetDataStorage(); +} + namespace { void onPythonProcessEvent(itk::Object * /*pCaller*/, const itk::EventObject &e, void *) { std::string testCOUT; std::string testCERR; const auto *pEvent = dynamic_cast(&e); if (pEvent) { testCOUT = testCOUT + pEvent->GetOutput(); MITK_INFO << testCOUT; } const auto *pErrEvent = dynamic_cast(&e); if (pErrEvent) { testCERR = testCERR + pErrEvent->GetOutput(); MITK_ERROR << testCERR; } } } // namespace mitk::LabelSetImage::Pointer mitk::nnUNetTool::ComputeMLPreview(const Image *inputAtTimeStep, TimeStepType /*timeStep*/) { Image::Pointer _inputAtTimeStep = inputAtTimeStep->Clone(); std::string inDir, outDir, inputImagePath, outputImagePath, scriptPath; std::string templateFilename = "XXXXXX_000_0000.nii.gz"; ProcessExecutor::Pointer spExec = ProcessExecutor::New(); itk::CStyleCommand::Pointer spCommand = itk::CStyleCommand::New(); spCommand->SetCallback(&onPythonProcessEvent); spExec->AddObserver(ExternalProcessOutputEvent(), spCommand); ProcessExecutor::ArgumentListType args; inDir = IOUtil::CreateTemporaryDirectory("nnunet-in-XXXXXX", this->GetMitkTempDir()); std::ofstream tmpStream; inputImagePath = IOUtil::CreateTemporaryFile(tmpStream, templateFilename, inDir + IOUtil::GetDirectorySeparator()); tmpStream.close(); std::size_t found = inputImagePath.find_last_of(IOUtil::GetDirectorySeparator()); std::string fileName = inputImagePath.substr(found + 1); std::string token = fileName.substr(0, fileName.find("_")); if (this->GetNoPip()) { scriptPath = this->GetnnUNetDirectory() + IOUtil::GetDirectorySeparator() + "nnunet" + IOUtil::GetDirectorySeparator() + "inference" + IOUtil::GetDirectorySeparator() + "predict_simple.py"; } try { IOUtil::Save(_inputAtTimeStep.GetPointer(), inputImagePath); if (this->GetMultiModal()) { for (size_t i = 0; i < this->m_OtherModalPaths.size(); ++i) { std::string inModalFile = this->m_OtherModalPaths[i]; std::string outModalFile = inDir + IOUtil::GetDirectorySeparator() + token + "_000_000" + std::to_string(i + 1) + ".nii.gz"; std::ifstream src(inModalFile, std::ios::binary); std::ofstream dst(outModalFile, std::ios::binary); dst << src.rdbuf(); dst.close(); src.close(); } } } catch (const mitk::Exception &e) { /* Can't throw mitk exception to the caller. Refer: T28691 */ MITK_ERROR << e.GetDescription(); return nullptr; } // Code calls external process std::string command = "nnUNet_predict"; if (this->GetNoPip()) { #ifdef _WIN32 command = "python"; #else command = "python3"; #endif } for (ModelParams &modelparam : m_ParamQ) { outDir = IOUtil::CreateTemporaryDirectory("nnunet-out-XXXXXX", this->GetMitkTempDir()); outputImagePath = outDir + IOUtil::GetDirectorySeparator() + token + "_000.nii.gz"; modelparam.outputDir = outDir; args.clear(); if (this->GetNoPip()) { args.push_back(scriptPath); } args.push_back("-i"); args.push_back(inDir); args.push_back("-o"); args.push_back(outDir); args.push_back("-t"); args.push_back(modelparam.task); if (modelparam.model.find("cascade") != std::string::npos) { args.push_back("-ctr"); } else { args.push_back("-tr"); } args.push_back(modelparam.trainer); args.push_back("-m"); args.push_back(modelparam.model); args.push_back("-p"); args.push_back(modelparam.planId); if (!modelparam.folds.empty()) { args.push_back("-f"); for (auto fold : modelparam.folds) { args.push_back(fold); } } // args.push_back("--all_in_gpu"); // args.push_back(this->GetAllInGPU() ? std::string("True") : std::string("False")); // args.push_back("--num_threads_preprocessing"); // args.push_back(std::to_string(this->GetPreprocessingThreads())); args.push_back("--num_threads_nifti_save"); args.push_back("1"); // fixing to 1 if (!this->GetMirror()) { args.push_back("--disable_tta"); } if (!this->GetMixedPrecision()) { args.push_back("--disable_mixed_precision"); } if (this->GetEnsemble() && !this->GetPostProcessingJsonDirectory().empty()) { args.push_back("--save_npz"); } try { std::string resultsFolderEnv = "RESULTS_FOLDER=" + this->GetModelDirectory(); itksys::SystemTools::PutEnv(resultsFolderEnv.c_str()); std::string cudaEnv = "CUDA_VISIBLE_DEVICES=" + std::to_string(this->GetGpuId()); itksys::SystemTools::PutEnv(cudaEnv.c_str()); spExec->Execute(this->GetPythonPath(), command, args); } catch (const mitk::Exception &e) { /* Can't throw mitk exception to the caller. Refer: T28691 */ MITK_ERROR << e.GetDescription(); return nullptr; } } if (this->GetEnsemble() && !this->GetPostProcessingJsonDirectory().empty()) { args.clear(); command = "nnUNet_ensemble"; outDir = IOUtil::CreateTemporaryDirectory("nnunet-ensemble-out-XXXXXX", this->GetMitkTempDir()); outputImagePath = outDir + IOUtil::GetDirectorySeparator() + token + "_000.nii.gz"; args.push_back("-f"); for (ModelParams &modelparam : m_ParamQ) { args.push_back(modelparam.outputDir); } args.push_back("-o"); args.push_back(outDir); args.push_back("-pp"); args.push_back(this->GetPostProcessingJsonDirectory()); spExec->Execute(this->GetPythonPath(), command, args); } try { LabelSetImage::Pointer resultImage = LabelSetImage::New(); Image::Pointer outputImage = IOUtil::Load(outputImagePath); resultImage->InitializeByLabeledImage(outputImage); resultImage->SetGeometry(_inputAtTimeStep->GetGeometry()); return resultImage; } catch (const mitk::Exception &e) { /* Can't throw mitk exception to the caller. Refer: T28691 */ MITK_ERROR << e.GetDescription(); return nullptr; } } diff --git a/Modules/Segmentation/Interactions/mitknnUnetTool.h b/Modules/Segmentation/Interactions/mitknnUnetTool.h index 519f3ea016..0e0d103c87 100644 --- a/Modules/Segmentation/Interactions/mitknnUnetTool.h +++ b/Modules/Segmentation/Interactions/mitknnUnetTool.h @@ -1,191 +1,197 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef mitknnUnetTool_h_Included #define mitknnUnetTool_h_Included #include "mitkAutoMLSegmentationWithPreviewTool.h" #include "mitkCommon.h" +#include "mitkToolManager.h" #include #include namespace us { class ModuleResource; } namespace mitk { /** * @brief nnUNet parameter request object holding all model parameters for input. * Also holds output temporary directory path. */ struct ModelParams { std::string task; std::vector folds; std::string model; std::string trainer; std::string planId; std::string outputDir; }; /** \brief nnUNet segmentation tool. \ingroup Interaction \ingroup ToolManagerEtAl \warning Only to be instantiated by mitk::ToolManager. */ class MITKSEGMENTATION_EXPORT nnUNetTool : public AutoMLSegmentationWithPreviewTool { public: mitkClassMacro(nnUNetTool, AutoMLSegmentationWithPreviewTool); itkFactorylessNewMacro(Self); itkCloneMacro(Self); const char **GetXPM() const override; const char *GetName() const override; us::ModuleResource GetIconResource() const override; void Activated() override; itkSetMacro(nnUNetDirectory, std::string); itkGetConstMacro(nnUNetDirectory, std::string); itkSetMacro(ModelDirectory, std::string); itkGetConstMacro(ModelDirectory, std::string); itkSetMacro(PythonPath, std::string); itkGetConstMacro(PythonPath, std::string); itkSetMacro(MitkTempDir, std::string); itkGetConstMacro(MitkTempDir, std::string); itkSetMacro(PostProcessingJsonDirectory, std::string); itkGetConstMacro(PostProcessingJsonDirectory, std::string); /* itkSetMacro(UseGPU, bool); itkGetConstMacro(UseGPU, bool); itkBooleanMacro(UseGPU); itkSetMacro(AllInGPU, bool); itkGetConstMacro(AllInGPU, bool); itkBooleanMacro(AllInGPU); */ itkSetMacro(MixedPrecision, bool); itkGetConstMacro(MixedPrecision, bool); itkBooleanMacro(MixedPrecision); itkSetMacro(Mirror, bool); itkGetConstMacro(Mirror, bool); itkBooleanMacro(Mirror); itkSetMacro(MultiModal, bool); itkGetConstMacro(MultiModal, bool); itkBooleanMacro(MultiModal); itkSetMacro(NoPip, bool); itkGetConstMacro(NoPip, bool); itkBooleanMacro(NoPip); itkSetMacro(Ensemble, bool); itkGetConstMacro(Ensemble, bool); itkBooleanMacro(Ensemble); itkSetMacro(GpuId, unsigned int); itkGetConstMacro(GpuId, unsigned int); /** * @brief vector of ModelParams. * Size > 1 only for ensemble prediction. */ std::vector m_ParamQ; /** * @brief Holds paths to other input image modalities. * */ std::vector m_OtherModalPaths; /** * @brief Renders the output LabelSetImage. * To called in the main thread. */ void RenderOutputBuffer(); /** * @brief Get the Output Buffer object * * @return LabelSetImage::Pointer */ LabelSetImage::Pointer GetOutputBuffer(); /** * @brief Sets the outputBuffer to nullptr * */ void ClearOutputBuffer(); + /** + * @brief Returns the DataStorage from the ToolManager + */ + mitk::DataStorage *GetDataStorage(); + protected: /** * @brief Construct a new nnUNet Tool object and temp directory. * */ nnUNetTool(); /** * @brief Destroy the nnUNet Tool object and deletes the temp directory. * */ ~nnUNetTool(); /** * @brief Overriden method from the tool manager to execute the segmentation * Implementation: * 1. Saves the inputAtTimeStep in a temporary directory. * 2. Copies other modalities, renames and saves in the temporary directory, if required. * 3. Sets RESULTS_FOLDER and CUDA_VISIBLE_DEVICES variables in the environment. * 3. Iterates through the parameter queue (m_ParamQ) and executes "nnUNet_predict" command with the parameters * 4. Expects an output image to be saved in the temporary directory by the python proces. Loads it as * LabelSetImage and returns. * * @param inputAtTimeStep * @param timeStep * @return LabelSetImage::Pointer */ LabelSetImage::Pointer ComputeMLPreview(const Image *inputAtTimeStep, TimeStepType timeStep) override; void UpdateCleanUp() override; void SetNodeProperties(LabelSetImage::Pointer) override; private: std::string m_MitkTempDir; std::string m_nnUNetDirectory; std::string m_ModelDirectory; std::string m_PythonPath; std::string m_PostProcessingJsonDirectory; //bool m_UseGPU; kept for future //bool m_AllInGPU; bool m_MixedPrecision; bool m_Mirror; bool m_NoPip; bool m_MultiModal; bool m_Ensemble = false; LabelSetImage::Pointer m_OutputBuffer; unsigned int m_GpuId; }; } // namespace mitk #endif diff --git a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp index 741e1c3f9d..01a0ea51e8 100644 --- a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp +++ b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp @@ -1,269 +1,268 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "QmitknnUNetToolGUI.h" #include "mitknnUnetTool.h" #include -#include #include MITK_TOOL_GUI_MACRO(MITKSEGMENTATIONUI_EXPORT, QmitknnUNetToolGUI, "") QmitknnUNetToolGUI::QmitknnUNetToolGUI() : QmitkAutoMLSegmentationToolGUIBase() { // Nvidia-smi command returning zero doesn't alway mean lack of GPUs. // Pytorch uses its own libraries to communicate to the GPUs. Hence, only a warning can be given. if (m_GpuLoader.GetGPUCount() == 0) { - std::string warning= "WARNING: No GPUs were detected on your machine. The nnUNet tool might not work."; + std::string warning = "WARNING: No GPUs were detected on your machine. The nnUNet tool might not work."; ShowErrorMessage(warning); } m_SegmentationThread = new QThread(this); m_Worker = new nnUNetSegmentationWorker; m_Worker->moveToThread(m_SegmentationThread); + + // define predicates for multi modal data selection combobox + auto imageType = mitk::TNodePredicateDataType::New(); + auto labelSetImageType = mitk::NodePredicateNot::New(mitk::TNodePredicateDataType::New()); + this->m_MultiModalPredicate = mitk::NodePredicateAnd::New(imageType, labelSetImageType).GetPointer(); } QmitknnUNetToolGUI::~QmitknnUNetToolGUI() { this->m_SegmentationThread->quit(); this->m_SegmentationThread->wait(); } void QmitknnUNetToolGUI::ConnectNewTool(mitk::AutoSegmentationWithPreviewTool *newTool) { Superclass::ConnectNewTool(newTool); newTool->IsTimePointChangeAwareOff(); } void QmitknnUNetToolGUI::InitializeUI(QBoxLayout *mainLayout) { m_Controls.setupUi(this); #ifndef _WIN32 m_Controls.pythonEnvComboBox->addItem("/usr/bin"); #endif m_Controls.pythonEnvComboBox->addItem("Select"); AutoParsePythonPaths(); connect(m_Controls.previewButton, SIGNAL(clicked()), this, SLOT(OnSettingsAccepted())); connect(m_Controls.modeldirectoryBox, SIGNAL(directoryChanged(const QString &)), this, SLOT(OnDirectoryChanged(const QString &))); connect( m_Controls.modelBox, SIGNAL(currentTextChanged(const QString &)), this, SLOT(OnModelChanged(const QString &))); connect(m_Controls.taskBox, SIGNAL(currentTextChanged(const QString &)), this, SLOT(OnTaskChanged(const QString &))); connect( m_Controls.plannerBox, SIGNAL(currentTextChanged(const QString &)), this, SLOT(OnTrainerChanged(const QString &))); connect(m_Controls.nopipBox, SIGNAL(stateChanged(int)), this, SLOT(OnCheckBoxChanged(int))); connect(m_Controls.multiModalBox, SIGNAL(stateChanged(int)), this, SLOT(OnCheckBoxChanged(int))); connect(m_Controls.multiModalSpinBox, SIGNAL(valueChanged(int)), this, SLOT(OnModalitiesNumberChanged(int))); + connect(m_Controls.posSpinBox, SIGNAL(valueChanged(int)), this, SLOT(OnModalPositionChanged(int))); connect(m_Controls.pythonEnvComboBox, #if QT_VERSION >= 0x050F00 // 5.15 SIGNAL(textActivated(const QString &)), #elif QT_VERSION >= 0x050C00 // 5.12 SIGNAL(activated(const QString &)), #endif this, SLOT(OnPythonPathChanged(const QString &))); connect(this, &QmitknnUNetToolGUI::Operate, m_Worker, &nnUNetSegmentationWorker::DoWork); connect(m_Worker, &nnUNetSegmentationWorker::Finished, this, &QmitknnUNetToolGUI::SegmentationResultHandler); connect(m_Worker, &nnUNetSegmentationWorker::Failed, this, &QmitknnUNetToolGUI::SegmentationProcessFailed); connect(m_SegmentationThread, &QThread::finished, m_Worker, &QObject::deleteLater); m_Controls.codedirectoryBox->setVisible(false); m_Controls.nnUnetdirLabel->setVisible(false); m_Controls.multiModalSpinBox->setVisible(false); m_Controls.multiModalSpinLabel->setVisible(false); + m_Controls.posSpinBoxLabel->setVisible(false); + m_Controls.posSpinBox->setVisible(false); m_Controls.statusLabel->setTextFormat(Qt::RichText); m_Controls.statusLabel->setText("STATUS: Welcome to nnUNet. " + QString::number(m_GpuLoader.GetGPUCount()) + " GPUs were detected."); if (m_GpuLoader.GetGPUCount() != 0) { m_Controls.gpuSpinBox->setMaximum(m_GpuLoader.GetGPUCount() - 1); } mainLayout->addLayout(m_Controls.verticalLayout); Superclass::InitializeUI(mainLayout); m_UI_ROWS = m_Controls.advancedSettingsLayout->rowCount(); // Must do. Row count is correct only here. } void QmitknnUNetToolGUI::OnSettingsAccepted() { - auto tool = this->GetConnectedToolAs(); + mitk::nnUNetTool::Pointer tool = this->GetConnectedToolAs(); if (nullptr != tool) { try { QString modelName = m_Controls.modelBox->currentText(); QString taskName = m_Controls.taskBox->currentText(); bool isNoPip = m_Controls.nopipBox->isChecked(); QString pythonPathTextItem = m_Controls.pythonEnvComboBox->currentText(); QString pythonPath = pythonPathTextItem.mid(pythonPathTextItem.indexOf(" ") + 1); #ifdef _WIN32 if (!isNoPip && !(pythonPath.endsWith("Scripts", Qt::CaseInsensitive) || pythonPath.endsWith("Scripts/", Qt::CaseInsensitive))) { pythonPath += QDir::separator() + QString("Scripts"); } #else if (!(pythonPath.endsWith("bin", Qt::CaseInsensitive) || pythonPath.endsWith("bin/", Qt::CaseInsensitive))) { pythonPath += QDir::separator() + QString("bin"); } #endif std::string nnUNetDirectory; if (isNoPip) { nnUNetDirectory = m_Controls.codedirectoryBox->directory().toStdString(); } else if (!IsNNUNetInstalled(pythonPath)) { throw std::runtime_error("nnUNet is not detected in the selected python environment. Please select a valid " "python environment or install nnUNet."); } QString trainerPlanner = m_Controls.trainerBox->currentText(); QString splitterString = "__"; tool->EnsembleOff(); if (modelName.startsWith("ensemble", Qt::CaseInsensitive)) { QString ppJsonFile = QDir::cleanPath(m_ModelDirectory + QDir::separator() + modelName + QDir::separator() + taskName + QDir::separator() + trainerPlanner + QDir::separator() + "postprocessing.json"); if (QFile(ppJsonFile).exists()) { tool->EnsembleOn(); tool->SetPostProcessingJsonDirectory(ppJsonFile.toStdString()); splitterString = "--"; } } QStringList trainerSplitParts = trainerPlanner.split(splitterString, QString::SplitBehavior::SkipEmptyParts); std::vector requestQ; if (tool->GetEnsemble()) { foreach (QString modelSet, trainerSplitParts) { modelSet.remove("ensemble_", Qt::CaseInsensitive); QStringList splitParts = modelSet.split("__", QString::SplitBehavior::SkipEmptyParts); QString modelName = splitParts.first(); QString trainer = splitParts.at(1); QString planId = splitParts.at(2); auto testfold = std::vector(1, "1"); mitk::ModelParams modelObject = MapToRequest(modelName, taskName, trainer, planId, testfold); requestQ.push_back(modelObject); } } else { QString trainer = trainerSplitParts.first(); QString planId = trainerSplitParts.last(); std::vector fetchedFolds = FetchSelectedFoldsFromUI(); mitk::ModelParams modelObject = MapToRequest(modelName, taskName, trainer, planId, fetchedFolds); requestQ.push_back(modelObject); } tool->m_ParamQ.clear(); tool->m_ParamQ = requestQ; tool->SetnnUNetDirectory(nnUNetDirectory); tool->SetPythonPath(pythonPath.toStdString()); tool->SetModelDirectory(m_ModelDirectory.left(m_ModelDirectory.lastIndexOf(QDir::separator())).toStdString()); // checkboxes tool->SetMirror(m_Controls.mirrorBox->isChecked()); tool->SetMixedPrecision(m_Controls.mixedPrecisionBox->isChecked()); tool->SetNoPip(isNoPip); tool->SetMultiModal(m_Controls.multiModalBox->isChecked()); // Spinboxes tool->SetGpuId(static_cast(m_Controls.gpuSpinBox->value())); // Multi-Modal tool->MultiModalOff(); if (m_Controls.multiModalBox->isChecked()) { tool->m_OtherModalPaths.clear(); tool->m_OtherModalPaths = FetchMultiModalPathsFromUI(); tool->MultiModalOn(); } if (!m_SegmentationThread->isRunning()) { MITK_DEBUG << "Starting thread..."; m_SegmentationThread->start(); } m_Controls.statusLabel->setText("STATUS: Starting Segmentation task... This might take a while."); emit Operate(tool); } catch (const std::exception &e) { - this->setCursor(Qt::ArrowCursor); - std::stringstream stream; - stream << "Error while processing parameters for nnUNet segmentation. Reason: " << e.what(); - QMessageBox *messageBox = new QMessageBox(QMessageBox::Critical, nullptr, stream.str().c_str()); - messageBox->exec(); - delete messageBox; - MITK_ERROR << stream.str(); + std::stringstream errorMsg; + errorMsg << "Error while processing parameters for nnUNet segmentation. Reason: " << e.what(); + ShowErrorMessage(errorMsg.str()); return; } catch (...) { - this->setCursor(Qt::ArrowCursor); - std::stringstream stream; - stream << "Unkown error occured while generation nnUNet segmentation."; - QMessageBox *messageBox = new QMessageBox(QMessageBox::Critical, nullptr, stream.str().c_str()); - messageBox->exec(); - delete messageBox; - MITK_ERROR << stream.str(); + std::string errorMsg = "Unkown error occured while generation nnUNet segmentation."; + ShowErrorMessage(errorMsg); return; } } } -std::vector QmitknnUNetToolGUI::FetchMultiModalPathsFromUI() + +std::vector QmitknnUNetToolGUI::FetchMultiModalPathsFromUI() //Needs to REWRITE { std::vector paths; - if (m_Controls.multiModalBox->isChecked() && !m_ModalPaths.empty()) + /* if (m_Controls.multiModalBox->isChecked() && !m_Modalities.empty()) { - for (auto modality : m_ModalPaths) + for (auto modality : m_Modalities) { paths.push_back(modality->currentPath().toStdString()); } - } + }*/ return paths; } bool QmitknnUNetToolGUI::IsNNUNetInstalled(const QString &pythonPath) { QString fullPath = pythonPath; #ifdef _WIN32 if (!(fullPath.endsWith("Scripts", Qt::CaseInsensitive) || fullPath.endsWith("Scripts/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("Scripts"); } #else if (!(fullPath.endsWith("bin", Qt::CaseInsensitive) || fullPath.endsWith("bin/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("bin"); } #endif return QFile::exists(fullPath + QDir::separator() + QString("nnUNet_predict")); } -void QmitknnUNetToolGUI::ShowErrorMessage(std::string &message) +void QmitknnUNetToolGUI::ShowErrorMessage(std::string &message, QMessageBox::Icon icon) { this->setCursor(Qt::ArrowCursor); - QMessageBox *messageBox = new QMessageBox(QMessageBox::Critical, nullptr, message.c_str()); + QMessageBox *messageBox = new QMessageBox(icon, nullptr, message.c_str()); messageBox->exec(); delete messageBox; MITK_WARN << message; } diff --git a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h index ecd51ce75a..4227cc675e 100644 --- a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h +++ b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h @@ -1,196 +1,204 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file.s ============================================================================*/ #ifndef QmitknnUNetToolGUI_h_Included #define QmitknnUNetToolGUI_h_Included #include "QmitkAutoMLSegmentationToolGUIBase.h" #include "QmitknnUNetGPU.h" #include "QmitknnUNetWorker.h" #include "mitknnUnetTool.h" #include "ui_QmitknnUNetToolGUIControls.h" #include #include #include +#include #include +#include class MITKSEGMENTATIONUI_EXPORT QmitknnUNetToolGUI : public QmitkAutoMLSegmentationToolGUIBase { Q_OBJECT public: mitkClassMacro(QmitknnUNetToolGUI, QmitkAutoMLSegmentationToolGUIBase); itkFactorylessNewMacro(Self); itkCloneMacro(Self); protected slots: /** * @brief Qt slot * */ void OnSettingsAccepted(); /** * @brief Qt slot * */ void OnDirectoryChanged(const QString &); /** * @brief Qt slot * */ void OnModelChanged(const QString &); /** * @brief Qt slot * */ void OnTaskChanged(const QString &); /** * @brief Qt slot * */ void OnTrainerChanged(const QString &); /** * @brief Qt slot * */ void OnCheckBoxChanged(int); /** * @brief Qthread slot to capture failures from thread worker and * shows error message * */ void SegmentationProcessFailed(); /** * @brief Qthread to capture sucessfull nnUNet segmentation. * Further, renders the LabelSet image */ void SegmentationResultHandler(mitk::nnUNetTool *); /** * @brief Qt Slot * */ void OnModalitiesNumberChanged(int); /** * @brief Qt Slot * */ void OnPythonPathChanged(const QString &); + /** + * @brief Qt Slot + * + */ + void OnModalPositionChanged(int); + signals: /** * @brief signal for starting the segmentation which is caught by a worker thread. */ void Operate(mitk::nnUNetTool *); protected: QmitknnUNetToolGUI(); ~QmitknnUNetToolGUI(); void ConnectNewTool(mitk::AutoSegmentationWithPreviewTool *newTool) override; void InitializeUI(QBoxLayout *mainLayout) override; void EnableWidgets(bool enabled) override; private: /** * @brief Creates a QMessage object and shows on screen. */ - void ShowErrorMessage(std::string &); + void ShowErrorMessage(std::string &, QMessageBox::Icon = QMessageBox::Critical); /** * @brief Searches and parses paths of python virtual enviroments * from predefined lookout locations */ void AutoParsePythonPaths(); /** * @brief Clears all combo boxes * Any new combo box added in the future can be featured here for clearance. * */ void ClearAllComboBoxes(); /** * @brief Checks if nnUNet_predict command is valid in the selected python virtual environment. * * @return bool */ bool IsNNUNetInstalled(const QString &); /** * @brief Mapper function to map QString entries from UI to ModelParam attributes. * * @return mitk::ModelParams */ mitk::ModelParams MapToRequest( const QString &, const QString &, const QString &, const QString &, const std::vector &); /** * @brief Returns checked fold names from the ctk-Checkable-ComboBox. * * @return std::vector */ std::vector FetchSelectedFoldsFromUI(); /** * @brief Returns all paths from the dynamically generated ctk-path-line-edit boxes. * * @return std::vector */ std::vector FetchMultiModalPathsFromUI(); /** * @brief Template function to fetch all folders inside a given path. * The type can be any of stl or Qt containers which supports push_back call. * * @tparam T * @return T */ template static T FetchFoldersFromDir(const QString &); /** * @brief Stores path of the model director (RESULTS_FOLDER appended by "nnUNet"). * */ QString m_ModelDirectory; Ui_QmitknnUNetToolGUIControls m_Controls; QThread *m_SegmentationThread; nnUNetSegmentationWorker *m_Worker; QmitkGPULoader m_GpuLoader; /** * @brief Stores all dynamically added ctk-path-line-edit UI elements. * */ - std::vector m_ModalPaths; + std::vector m_Modalities; - std::vector m_ModalOrder; + mitk::NodePredicateBase::Pointer m_MultiModalPredicate; /** * @brief Stores row count of the "advancedSettingsLayout" layout element. This value helps dynamically add * ctk-path-line-edit UI elements at the right place. Forced to initialize in the InitializeUI method since there is * no guarantee of retrieving exact row count anywhere else. * */ int m_UI_ROWS; }; #endif diff --git a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUIControls.ui b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUIControls.ui index d195dafb36..52ef7bdb8c 100644 --- a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUIControls.ui +++ b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUIControls.ui @@ -1,448 +1,430 @@ QmitknnUNetToolGUIControls 0 0 192 352 0 0 100 0 100000 100000 QmitknnUNetToolWidget 0 0 0 0 0 0 nnUNet Results Folder: Configuration: 0 0 Task: 0 0 Trainer Plan: 0 0 Plan: Fold: - - - - - 0 - 0 - - - - Multi-Modal: - - - - - - - 0 0 No. of Extra Modalities: + + + + + 0 + 0 + + + + Multi-Modal: + + + + + + + - + + 0 0 5 Advanced true true Qt::AlignRight 6 0 0 0 No Pip 0 0 Mixed Precision true 0 0 GPU Id: 0 0 Enable Mirroring true 0 0 Python Path: /usr/bin - - - - - - 0 - 0 - - - - Available Tasks: - - - - - - - - + 0 0 nnUNet Path: 0 0 100000 16777215 Preview 0 0 ctkDirectoryButton QWidget
ctkDirectoryButton.h
1
ctkComboBox QComboBox
ctkComboBox.h
1
ctkCheckableComboBox QComboBox
ctkCheckableComboBox.h
1
ctkCheckBox QCheckBox
ctkCheckBox.h
1
ctkCollapsibleGroupBox QGroupBox
ctkCollapsibleGroupBox.h
1
diff --git a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolSlots.cpp b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolSlots.cpp index e69a15b970..5a5536dc3b 100644 --- a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolSlots.cpp +++ b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolSlots.cpp @@ -1,286 +1,324 @@ #include "QmitknnUNetToolGUI.h" #include #include #include #include void QmitknnUNetToolGUI::EnableWidgets(bool enabled) { Superclass::EnableWidgets(enabled); m_Controls.previewButton->setEnabled(false); } void QmitknnUNetToolGUI::ClearAllComboBoxes() { m_Controls.modelBox->clear(); m_Controls.taskBox->clear(); m_Controls.foldBox->clear(); m_Controls.trainerBox->clear(); } template T QmitknnUNetToolGUI::FetchFoldersFromDir(const QString &path) { T folders; for (QDirIterator it(path, QDir::AllDirs, QDirIterator::NoIteratorFlags); it.hasNext();) { it.next(); if (!it.fileName().startsWith('.')) { folders.push_back(it.fileName()); } } return folders; } void QmitknnUNetToolGUI::OnDirectoryChanged(const QString &resultsFolder) { m_Controls.previewButton->setEnabled(false); this->ClearAllComboBoxes(); m_ModelDirectory = resultsFolder + QDir::separator() + "nnUNet"; auto models = FetchFoldersFromDir(m_ModelDirectory); QStringList validlist; // valid list of models supported by nnUNet validlist << "2d" << "3d_lowres" << "3d_fullres" << "3d_cascade_fullres" << "ensembles"; std::for_each(models.begin(), models.end(), [this, validlist](QString model) { if (validlist.contains(model, Qt::CaseInsensitive)) m_Controls.modelBox->addItem(model); }); } void QmitknnUNetToolGUI::OnModelChanged(const QString &text) { QString updatedPath(QDir::cleanPath(m_ModelDirectory + QDir::separator() + text)); m_Controls.taskBox->clear(); auto datasets = FetchFoldersFromDir(updatedPath); std::for_each(datasets.begin(), datasets.end(), [this](QString dataset) { m_Controls.taskBox->addItem(dataset); }); } void QmitknnUNetToolGUI::OnTaskChanged(const QString &text) { QString updatedPath = QDir::cleanPath(m_ModelDirectory + QDir::separator() + m_Controls.modelBox->currentText() + QDir::separator() + text); m_Controls.trainerBox->clear(); auto trainerPlanners = FetchFoldersFromDir(updatedPath); QStringList trainers, planners; foreach (QString trainerPlanner, trainerPlanners) { trainers << trainerPlanner.split("__", QString::SplitBehavior::SkipEmptyParts).first(); planners << trainerPlanner.split("__", QString::SplitBehavior::SkipEmptyParts).last(); } trainers.removeDuplicates(); planners.removeDuplicates(); std::for_each(trainers.begin(), trainers.end(), [this](QString trainer) { m_Controls.trainerBox->addItem(trainer); }); std::for_each(planners.begin(), planners.end(), [this](QString planner) { m_Controls.plannerBox->addItem(planner); }); } void QmitknnUNetToolGUI::OnTrainerChanged(const QString &trainerSelected) { m_Controls.foldBox->clear(); if (m_Controls.modelBox->currentText() != "ensembles") { QString updatedPath(QDir::cleanPath(m_ModelDirectory + QDir::separator() + m_Controls.modelBox->currentText() + QDir::separator() + m_Controls.taskBox->currentText() + QDir::separator() + m_Controls.trainerBox->currentText() + "__" + trainerSelected)); auto folds = FetchFoldersFromDir(updatedPath); std::for_each(folds.begin(), folds.end(), [this](QString fold) { if (fold.startsWith("fold_", Qt::CaseInsensitive)) // imposed by nnUNet m_Controls.foldBox->addItem(fold); }); if (m_Controls.foldBox->count() != 0) { m_Controls.previewButton->setEnabled(true); } } else { m_Controls.previewButton->setEnabled(true); } } void QmitknnUNetToolGUI::OnPythonPathChanged(const QString &pyEnv) { if (pyEnv == QString("Select")) { QString path = QFileDialog::getExistingDirectory(m_Controls.pythonEnvComboBox->parentWidget(), "Python Path", "dir"); if (!path.isEmpty()) { m_Controls.pythonEnvComboBox->insertItem(0, path); m_Controls.pythonEnvComboBox->setCurrentIndex(0); } } else if (!IsNNUNetInstalled(pyEnv)) { std::string warning = "WARNING: nnUNet is not detected on the Python environment you selected. Please select another " "environment or create one. For more info refer https://github.com/MIC-DKFZ/nnUNet"; ShowErrorMessage(warning); } } void QmitknnUNetToolGUI::OnCheckBoxChanged(int state) { bool visibility = false; if (state == Qt::Checked) { visibility = true; } ctkCheckBox *box = qobject_cast(sender()); if (box != nullptr) { if (box->objectName() == QString("nopipBox")) { m_Controls.codedirectoryBox->setVisible(visibility); m_Controls.nnUnetdirLabel->setVisible(visibility); } else if (box->objectName() == QString("multiModalBox")) { m_Controls.multiModalSpinLabel->setVisible(visibility); m_Controls.multiModalSpinBox->setVisible(visibility); + m_Controls.posSpinBoxLabel->setVisible(visibility); + m_Controls.posSpinBox->setVisible(visibility); if (!visibility) { OnModalitiesNumberChanged(0); m_Controls.multiModalSpinBox->setValue(0); + m_Controls.posSpinBox->setMaximum(0); } else { - ctkPathLineEdit *multiModalPath = new ctkPathLineEdit(this); - QSpinBox *multiModalOrderBox = new QSpinBox(this); - multiModalPath->setObjectName(QString("multiModalPath" + QString::number(0))); - multiModalPath->setCurrentPath("default_loaded_image"); - multiModalPath->setDisabled(true); - m_Controls.advancedSettingsLayout->addWidget( - multiModalPath, this->m_UI_ROWS + m_ModalPaths.size() + 1, 1, 1, 3); - m_Controls.advancedSettingsLayout->addWidget(multiModalOrderBox, this->m_UI_ROWS + m_ModalPaths.size() + 1, 0); - m_ModalPaths.push_back(multiModalPath); - m_ModalOrder.push_back(multiModalOrderBox); - m_UI_ROWS += 1; + QmitkDataStorageComboBox *defaultImage = new QmitkDataStorageComboBox(this, true); + defaultImage->setObjectName(QString("multiModal_" + QString::number(0))); + mitk::nnUNetTool::Pointer tool = this->GetConnectedToolAs(); + defaultImage->SetDataStorage(tool->GetDataStorage()); + defaultImage->SetSelectedNode(tool->GetDataStorage()->GetNode()); + defaultImage->setDisabled(true); + m_Controls.advancedSettingsLayout->addWidget(defaultImage, this->m_UI_ROWS + m_Modalities.size() + 1, 1, 1, 3); + m_Modalities.push_back(defaultImage); + m_Controls.posSpinBox->setMaximum(this->m_Modalities.size() - 1); + m_UI_ROWS++; } } } } void QmitknnUNetToolGUI::OnModalitiesNumberChanged(int num) { - while (num > static_cast(this->m_ModalPaths.size())) + while (num > static_cast(this->m_Modalities.size()-1)) { - ctkPathLineEdit *multiModalPath = new ctkPathLineEdit(this); - QSpinBox *multiModalOrderBox = new QSpinBox(this); - multiModalPath->setObjectName(QString("multiModalPath" + QString::number(m_ModalPaths.size() + 1))); - m_Controls.advancedSettingsLayout->addWidget(multiModalPath, this->m_UI_ROWS + m_ModalPaths.size() + 1, 1, 1, 3); - m_Controls.advancedSettingsLayout->addWidget(multiModalOrderBox, this->m_UI_ROWS + m_ModalPaths.size() + 1, 0); - m_ModalPaths.push_back(multiModalPath); - m_ModalOrder.push_back(multiModalOrderBox); + QmitkDataStorageComboBox *multiModalBox = new QmitkDataStorageComboBox(this, true); + mitk::nnUNetTool::Pointer tool = this->GetConnectedToolAs(); + multiModalBox->SetDataStorage(tool->GetDataStorage()); + multiModalBox->SetPredicate(this->m_MultiModalPredicate); + multiModalBox->setObjectName(QString("multiModal_" + QString::number(m_Modalities.size() + 1))); + m_Controls.advancedSettingsLayout->addWidget(multiModalBox, this->m_UI_ROWS + m_Modalities.size() + 1, 1, 1, 3); + m_Modalities.push_back(multiModalBox); } - while (num < static_cast(this->m_ModalPaths.size() - 1) && !m_ModalPaths.empty()) + while (num < static_cast(this->m_Modalities.size()-1) && !m_Modalities.empty()) { - ctkPathLineEdit *child = m_ModalPaths.back(); + QmitkDataStorageComboBox *child = m_Modalities.back(); + if (child->objectName() == "multiModal_0") + { + std::iter_swap(this->m_Modalities.end() - 2, this->m_Modalities.end()-1); + child = m_Modalities.back(); + } delete child; // delete the layout item - m_ModalPaths.pop_back(); - auto *_child = m_ModalOrder.back(); - delete _child; // delete the layout item - m_ModalOrder.pop_back(); + m_Modalities.pop_back(); } + m_Controls.posSpinBox->setMaximum(this->m_Modalities.size()-1); m_Controls.advancedSettingsLayout->update(); } + +void QmitknnUNetToolGUI::OnModalPositionChanged(int posIdx) +{ + if (posIdx < static_cast(this->m_Modalities.size())) + { + int currPos = 0; + bool stopCheck = false; + // for-loop clears all widgets from the QGridLayout and also, finds the position of loaded-image widget. + for (QmitkDataStorageComboBox *multiModalBox : this->m_Modalities) + { + m_Controls.advancedSettingsLayout->removeWidget(multiModalBox); + multiModalBox->setParent(nullptr); + if (multiModalBox->objectName() != "multiModal_0" && !stopCheck) + { + currPos++; + } + else + { + stopCheck = true; + } + } + // moving the loaded-image widget to the required position + std::iter_swap(this->m_Modalities.begin() + currPos, this->m_Modalities.begin() + posIdx); + // re-adding all widgets in the order + for (int i = 0; i < static_cast(this->m_Modalities.size()); ++i) + { + QmitkDataStorageComboBox *multiModalBox = this->m_Modalities[i]; + m_Controls.advancedSettingsLayout->addWidget(multiModalBox, this->m_UI_ROWS + i + 1, 1, 1, 3); + } + m_Controls.advancedSettingsLayout->update(); + } +} + void QmitknnUNetToolGUI::AutoParsePythonPaths() { QString homeDir = QDir::homePath(); std::vector searchDirs; #ifdef _WIN32 searchDirs.push_back(QString("C:") + QDir::separator() + QString("ProgramData") + QDir::separator() + QString("anaconda3")); #else // Add search locations for possible standard python paths here searchDirs.push_back(homeDir + QDir::separator() + "environments"); searchDirs.push_back(homeDir + QDir::separator() + "anaconda3"); searchDirs.push_back(homeDir + QDir::separator() + "miniconda3"); searchDirs.push_back(homeDir + QDir::separator() + "opt" + QDir::separator() + "miniconda3"); searchDirs.push_back(homeDir + QDir::separator() + "opt" + QDir::separator() + "anaconda3"); #endif for (QString searchDir : searchDirs) { if (searchDir.endsWith("anaconda3", Qt::CaseInsensitive)) { if (QDir(searchDir).exists()) { m_Controls.pythonEnvComboBox->insertItem(0, "(base): " + searchDir); searchDir.append((QDir::separator() + QString("envs"))); } } for (QDirIterator subIt(searchDir, QDir::AllDirs, QDirIterator::NoIteratorFlags); subIt.hasNext();) { subIt.next(); QString envName = subIt.fileName(); if (!envName.startsWith('.')) // Filter out irrelevent hidden folders, if any. { m_Controls.pythonEnvComboBox->insertItem(0, "(" + envName + "): " + subIt.filePath()); } } } m_Controls.pythonEnvComboBox->setCurrentIndex(-1); } std::vector QmitknnUNetToolGUI::FetchSelectedFoldsFromUI() { std::vector folds; if (!(m_Controls.foldBox->allChecked() || m_Controls.foldBox->noneChecked())) { QModelIndexList foldList = m_Controls.foldBox->checkedIndexes(); foreach (QModelIndex index, foldList) { QString foldQString = m_Controls.foldBox->itemText(index.row()).split("_", QString::SplitBehavior::SkipEmptyParts).last(); folds.push_back(foldQString.toStdString()); } } return folds; } mitk::ModelParams QmitknnUNetToolGUI::MapToRequest(const QString &modelName, const QString &taskName, const QString &trainer, const QString &planId, const std::vector &folds) { mitk::ModelParams requestObject; requestObject.model = modelName.toStdString(); requestObject.trainer = trainer.toStdString(); requestObject.planId = planId.toStdString(); requestObject.task = taskName.toStdString(); requestObject.folds = folds; return requestObject; } void QmitknnUNetToolGUI::SegmentationProcessFailed() { m_Controls.statusLabel->setText( "STATUS: Error in the segmentation process. No resulting segmentation can be loaded."); this->setCursor(Qt::ArrowCursor); std::stringstream stream; stream << "Error in the segmentation process. No resulting segmentation can be loaded."; QMessageBox *messageBox = new QMessageBox(QMessageBox::Critical, nullptr, stream.str().c_str()); messageBox->exec(); delete messageBox; MITK_ERROR << stream.str(); } void QmitknnUNetToolGUI::SegmentationResultHandler(mitk::nnUNetTool *tool) { MITK_INFO << "Finished slot"; tool->RenderOutputBuffer(); this->SetLabelSetPreview(tool->GetMLPreview()); tool->IsTimePointChangeAwareOn(); m_Controls.statusLabel->setText("STATUS: Segmentation task finished successfully."); }