diff --git a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp index ef984f966c..69d8878fcc 100644 --- a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp +++ b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.cpp @@ -1,272 +1,418 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "QmitknnUNetToolGUI.h" #include "mitknnUnetTool.h" #include #include MITK_TOOL_GUI_MACRO(MITKSEGMENTATIONUI_EXPORT, QmitknnUNetToolGUI, "") QmitknnUNetToolGUI::QmitknnUNetToolGUI() : QmitkAutoMLSegmentationToolGUIBase() { // Nvidia-smi command returning zero doesn't alway mean lack of GPUs. // Pytorch uses its own libraries to communicate to the GPUs. Hence, only a warning can be given. if (m_GpuLoader.GetGPUCount() == 0) { std::string warning = "WARNING: No GPUs were detected on your machine. The nnUNet tool might not work."; ShowErrorMessage(warning); } m_SegmentationThread = new QThread(this); m_Worker = new nnUNetSegmentationWorker; m_Worker->moveToThread(m_SegmentationThread); // define predicates for multi modal data selection combobox auto imageType = mitk::TNodePredicateDataType::New(); auto labelSetImageType = mitk::NodePredicateNot::New(mitk::TNodePredicateDataType::New()); this->m_MultiModalPredicate = mitk::NodePredicateAnd::New(imageType, labelSetImageType).GetPointer(); } QmitknnUNetToolGUI::~QmitknnUNetToolGUI() { this->m_SegmentationThread->quit(); this->m_SegmentationThread->wait(); } void QmitknnUNetToolGUI::ConnectNewTool(mitk::AutoSegmentationWithPreviewTool *newTool) { Superclass::ConnectNewTool(newTool); newTool->IsTimePointChangeAwareOff(); } void QmitknnUNetToolGUI::InitializeUI(QBoxLayout *mainLayout) { m_Controls.setupUi(this); #ifndef _WIN32 m_Controls.pythonEnvComboBox->addItem("/usr/bin"); #endif m_Controls.pythonEnvComboBox->addItem("Select"); AutoParsePythonPaths(); - connect(m_Controls.previewButton, SIGNAL(clicked()), this, SLOT(OnSettingsAccepted())); + connect(m_Controls.previewButton, SIGNAL(clicked()), this, SLOT(OnPreviewRequested())); connect(m_Controls.modeldirectoryBox, SIGNAL(directoryChanged(const QString &)), this, SLOT(OnDirectoryChanged(const QString &))); connect( m_Controls.modelBox, SIGNAL(currentTextChanged(const QString &)), this, SLOT(OnModelChanged(const QString &))); connect(m_Controls.taskBox, SIGNAL(currentTextChanged(const QString &)), this, SLOT(OnTaskChanged(const QString &))); connect( m_Controls.plannerBox, SIGNAL(currentTextChanged(const QString &)), this, SLOT(OnTrainerChanged(const QString &))); connect(m_Controls.nopipBox, SIGNAL(stateChanged(int)), this, SLOT(OnCheckBoxChanged(int))); connect(m_Controls.multiModalBox, SIGNAL(stateChanged(int)), this, SLOT(OnCheckBoxChanged(int))); connect(m_Controls.multiModalSpinBox, SIGNAL(valueChanged(int)), this, SLOT(OnModalitiesNumberChanged(int))); connect(m_Controls.posSpinBox, SIGNAL(valueChanged(int)), this, SLOT(OnModalPositionChanged(int))); connect(m_Controls.pythonEnvComboBox, #if QT_VERSION >= 0x050F00 // 5.15 SIGNAL(textActivated(const QString &)), #elif QT_VERSION >= 0x050C00 // 5.12 SIGNAL(activated(const QString &)), #endif this, SLOT(OnPythonPathChanged(const QString &))); connect(this, &QmitknnUNetToolGUI::Operate, m_Worker, &nnUNetSegmentationWorker::DoWork); connect(m_Worker, &nnUNetSegmentationWorker::Finished, this, &QmitknnUNetToolGUI::SegmentationResultHandler); connect(m_Worker, &nnUNetSegmentationWorker::Failed, this, &QmitknnUNetToolGUI::SegmentationProcessFailed); connect(m_SegmentationThread, &QThread::finished, m_Worker, &QObject::deleteLater); m_Controls.codedirectoryBox->setVisible(false); m_Controls.nnUnetdirLabel->setVisible(false); m_Controls.multiModalSpinBox->setVisible(false); m_Controls.multiModalSpinLabel->setVisible(false); m_Controls.posSpinBoxLabel->setVisible(false); m_Controls.posSpinBox->setVisible(false); m_Controls.statusLabel->setTextFormat(Qt::RichText); m_Controls.statusLabel->setText("STATUS: Welcome to nnUNet. " + QString::number(m_GpuLoader.GetGPUCount()) + " GPUs were detected."); if (m_GpuLoader.GetGPUCount() != 0) { m_Controls.gpuSpinBox->setMaximum(m_GpuLoader.GetGPUCount() - 1); } mainLayout->addLayout(m_Controls.verticalLayout); Superclass::InitializeUI(mainLayout); m_UI_ROWS = m_Controls.advancedSettingsLayout->rowCount(); // Must do. Row count is correct only here. } + +void QmitknnUNetToolGUI::OnPreviewRequested() +{ + mitk::nnUNetTool::Pointer tool = this->GetConnectedToolAs(); + if (nullptr != tool) + { + try + { + QString modelName = m_Controls.modelBox->currentText(); + if (modelName.startsWith("ensemble", Qt::CaseInsensitive)) + { + ProcessEnsembleModelsParams(tool); + } + else + { + ProcessModelParams(tool); + } + QString pythonPathTextItem = m_Controls.pythonEnvComboBox->currentText(); + QString pythonPath = pythonPathTextItem.mid(pythonPathTextItem.indexOf(" ") + 1); + bool isNoPip = m_Controls.nopipBox->isChecked(); +#ifdef _WIN32 + if (!isNoPip && !(pythonPath.endsWith("Scripts", Qt::CaseInsensitive) || + pythonPath.endsWith("Scripts/", Qt::CaseInsensitive))) + { + pythonPath += QDir::separator() + QString("Scripts"); + } +#else + if (!(pythonPath.endsWith("bin", Qt::CaseInsensitive) || pythonPath.endsWith("bin/", Qt::CaseInsensitive))) + { + pythonPath += QDir::separator() + QString("bin"); + } +#endif + std::string nnUNetDirectory; + if (isNoPip) + { + nnUNetDirectory = m_Controls.codedirectoryBox->directory().toStdString(); + } + else if (!IsNNUNetInstalled(pythonPath)) + { + // throw std::runtime_error("nnUNet is not detected in the selected python environment. Please select a valid " + // "python environment or install nnUNet."); + } + + tool->SetnnUNetDirectory(nnUNetDirectory); + tool->SetPythonPath(pythonPath.toStdString()); + tool->SetModelDirectory(m_ModelDirectory.left(m_ModelDirectory.lastIndexOf(QDir::separator())).toStdString()); + // checkboxes + tool->SetMirror(m_Controls.mirrorBox->isChecked()); + tool->SetMixedPrecision(m_Controls.mixedPrecisionBox->isChecked()); + tool->SetNoPip(isNoPip); + tool->SetMultiModal(m_Controls.multiModalBox->isChecked()); + // Spinboxes + tool->SetGpuId(static_cast(m_Controls.gpuSpinBox->value())); + // Multi-Modal + tool->MultiModalOff(); + if (m_Controls.multiModalBox->isChecked()) + { + tool->m_OtherModalPaths.clear(); + tool->m_OtherModalPaths = FetchMultiModalImagesFromUI(); + tool->MultiModalOn(); + } + if (!m_SegmentationThread->isRunning()) + { + MITK_DEBUG << "Starting thread..."; + m_SegmentationThread->start(); + } + m_Controls.statusLabel->setText("STATUS: Starting Segmentation task... This might take a while."); + emit Operate(tool); + } + catch (const std::exception &e) + { + std::stringstream errorMsg; + errorMsg << "Error while processing parameters for nnUNet segmentation. Reason: " << e.what(); + ShowErrorMessage(errorMsg.str()); + return; + } + catch (...) + { + std::string errorMsg = "Unkown error occured while generation nnUNet segmentation."; + ShowErrorMessage(errorMsg); + return; + } + } +} + + +/* void QmitknnUNetToolGUI::OnSettingsAccepted() { mitk::nnUNetTool::Pointer tool = this->GetConnectedToolAs(); if (nullptr != tool) { try { QString modelName = m_Controls.modelBox->currentText(); QString taskName = m_Controls.taskBox->currentText(); bool isNoPip = m_Controls.nopipBox->isChecked(); QString pythonPathTextItem = m_Controls.pythonEnvComboBox->currentText(); QString pythonPath = pythonPathTextItem.mid(pythonPathTextItem.indexOf(" ") + 1); #ifdef _WIN32 if (!isNoPip && !(pythonPath.endsWith("Scripts", Qt::CaseInsensitive) || pythonPath.endsWith("Scripts/", Qt::CaseInsensitive))) { pythonPath += QDir::separator() + QString("Scripts"); } #else if (!(pythonPath.endsWith("bin", Qt::CaseInsensitive) || pythonPath.endsWith("bin/", Qt::CaseInsensitive))) { pythonPath += QDir::separator() + QString("bin"); } #endif std::string nnUNetDirectory; if (isNoPip) { nnUNetDirectory = m_Controls.codedirectoryBox->directory().toStdString(); } else if (!IsNNUNetInstalled(pythonPath)) { //throw std::runtime_error("nnUNet is not detected in the selected python environment. Please select a valid " // "python environment or install nnUNet."); } QString trainerPlanner = m_Controls.trainerBox->currentText(); QString splitterString = "__"; tool->EnsembleOff(); if (modelName.startsWith("ensemble", Qt::CaseInsensitive)) { QString ppJsonFile = QDir::cleanPath(m_ModelDirectory + QDir::separator() + modelName + QDir::separator() + taskName + QDir::separator() + trainerPlanner + QDir::separator() + "postprocessing.json"); if (QFile(ppJsonFile).exists()) { tool->EnsembleOn(); tool->SetPostProcessingJsonDirectory(ppJsonFile.toStdString()); splitterString = "--"; } } QStringList trainerSplitParts = trainerPlanner.split(splitterString, QString::SplitBehavior::SkipEmptyParts); std::vector requestQ; if (tool->GetEnsemble()) { foreach (QString modelSet, trainerSplitParts) { modelSet.remove("ensemble_", Qt::CaseInsensitive); QStringList splitParts = modelSet.split("__", QString::SplitBehavior::SkipEmptyParts); QString modelName = splitParts.first(); QString trainer = splitParts.at(1); QString planId = splitParts.at(2); auto testfold = std::vector(1, "1"); mitk::ModelParams modelObject = MapToRequest(modelName, taskName, trainer, planId, testfold); requestQ.push_back(modelObject); } } else { QString trainer = m_Controls.trainerBox->currentText();//trainerSplitParts.first(); QString planId = m_Controls.plannerBox->currentText(); //trainerSplitParts.last(); std::vector fetchedFolds = FetchSelectedFoldsFromUI(); mitk::ModelParams modelObject = MapToRequest(modelName, taskName, trainer, planId, fetchedFolds); requestQ.push_back(modelObject); } tool->m_ParamQ.clear(); tool->m_ParamQ = requestQ; tool->SetnnUNetDirectory(nnUNetDirectory); tool->SetPythonPath(pythonPath.toStdString()); tool->SetModelDirectory(m_ModelDirectory.left(m_ModelDirectory.lastIndexOf(QDir::separator())).toStdString()); // checkboxes tool->SetMirror(m_Controls.mirrorBox->isChecked()); tool->SetMixedPrecision(m_Controls.mixedPrecisionBox->isChecked()); tool->SetNoPip(isNoPip); tool->SetMultiModal(m_Controls.multiModalBox->isChecked()); // Spinboxes tool->SetGpuId(static_cast(m_Controls.gpuSpinBox->value())); // Multi-Modal tool->MultiModalOff(); if (m_Controls.multiModalBox->isChecked()) { tool->m_OtherModalPaths.clear(); tool->m_OtherModalPaths = FetchMultiModalImagesFromUI(); tool->MultiModalOn(); } if (!m_SegmentationThread->isRunning()) { MITK_DEBUG << "Starting thread..."; m_SegmentationThread->start(); } m_Controls.statusLabel->setText("STATUS: Starting Segmentation task... This might take a while."); emit Operate(tool); } catch (const std::exception &e) { std::stringstream errorMsg; errorMsg << "Error while processing parameters for nnUNet segmentation. Reason: " << e.what(); ShowErrorMessage(errorMsg.str()); return; } catch (...) { std::string errorMsg = "Unkown error occured while generation nnUNet segmentation."; ShowErrorMessage(errorMsg); return; } } -} +}*/ std::vector QmitknnUNetToolGUI::FetchMultiModalImagesFromUI() { std::vector paths; if (m_Controls.multiModalBox->isChecked() && !m_Modalities.empty()) { for (QmitkDataStorageComboBox *modality : m_Modalities) { if (modality->objectName() != "multiModal_0") { paths.push_back(dynamic_cast(modality->GetSelectedNode()->GetData())); } } } return paths; } bool QmitknnUNetToolGUI::IsNNUNetInstalled(const QString &pythonPath) { QString fullPath = pythonPath; #ifdef _WIN32 if (!(fullPath.endsWith("Scripts", Qt::CaseInsensitive) || fullPath.endsWith("Scripts/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("Scripts"); } #else if (!(fullPath.endsWith("bin", Qt::CaseInsensitive) || fullPath.endsWith("bin/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("bin"); } #endif fullPath = fullPath.mid(fullPath.indexOf(" ") + 1); return QFile::exists(fullPath + QDir::separator() + QString("nnUNet_predict")); } void QmitknnUNetToolGUI::ShowErrorMessage(const std::string &message, QMessageBox::Icon icon) { this->setCursor(Qt::ArrowCursor); QMessageBox *messageBox = new QMessageBox(icon, nullptr, message.c_str()); messageBox->exec(); delete messageBox; MITK_WARN << message; } + +void QmitknnUNetToolGUI::ProcessEnsembleModelsParams(mitk::nnUNetTool::Pointer tool) +{ + QString taskName = m_Controls.taskBox->currentText(); + std::vector requestQ; + QStringList ppDirFolderName; + ppDirFolderName << "ensemble_"; + if (m_EnsembleParams[0]->modelBox->currentText() == m_EnsembleParams[1]->modelBox->currentText()) + { + for (std::unique_ptr &layout : m_EnsembleParams) + { + QString modelName = layout->modelBox->currentText(); + ppDirFolderName << modelName; + ppDirFolderName << "__"; + QString trainer = layout->trainerBox->currentText(); + ppDirFolderName << trainer; + ppDirFolderName << "__"; + QString planId = layout->plannerBox->currentText(); + ppDirFolderName << planId; + ppDirFolderName << "--"; + auto testfold = std::vector(1, "1"); + mitk::ModelParams modelObject = MapToRequest(modelName, taskName, trainer, planId, testfold); + requestQ.push_back(modelObject); + } + tool->EnsembleOn(); + QString ppJsonFile = QDir::cleanPath(m_ModelDirectory + QDir::separator() + "ensembles" + QDir::separator() + + taskName + QDir::separator() + ppDirFolderName.join(QString("")) + + QDir::separator() + "postprocessing.json"); + if (QFile(ppJsonFile).exists()) + { + tool->SetPostProcessingJsonDirectory(ppJsonFile.toStdString()); + } + else + { + // warning message + } + tool->m_ParamQ.clear(); + tool->m_ParamQ = requestQ; + } + else + { + throw std::runtime_error("Both models you have selected for ensembling are the same."); + } +} + +void QmitknnUNetToolGUI::ProcessModelParams(mitk::nnUNetTool::Pointer tool) +{ + tool->EnsembleOff(); + std::vector requestQ; + QString modelName = m_Controls.modelBox->currentText(); + QString taskName = m_Controls.taskBox->currentText(); + QString trainer = m_Controls.trainerBox->currentText(); + QString planId = m_Controls.plannerBox->currentText(); + std::vector fetchedFolds = FetchSelectedFoldsFromUI(); + mitk::ModelParams modelObject = MapToRequest(modelName, taskName, trainer, planId, fetchedFolds); + requestQ.push_back(modelObject); + tool->m_ParamQ.clear(); + tool->m_ParamQ = requestQ; +} diff --git a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h index c98ccf5e6e..c9e01f4c30 100644 --- a/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h +++ b/Modules/SegmentationUI/Qmitk/QmitknnUNetToolGUI.h @@ -1,213 +1,218 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file.s ============================================================================*/ #ifndef QmitknnUNetToolGUI_h_Included #define QmitknnUNetToolGUI_h_Included #include "QmitkAutoMLSegmentationToolGUIBase.h" #include "QmitknnUNetGPU.h" #include "QmitknnUNetWorker.h" #include "mitknnUnetTool.h" #include "ui_QmitknnUNetToolGUIControls.h" #include #include #include #include #include #include #include class MITKSEGMENTATIONUI_EXPORT QmitknnUNetToolGUI : public QmitkAutoMLSegmentationToolGUIBase { Q_OBJECT public: mitkClassMacro(QmitknnUNetToolGUI, QmitkAutoMLSegmentationToolGUIBase); itkFactorylessNewMacro(Self); itkCloneMacro(Self); protected slots: + /** * @brief Qt slot * */ - void OnSettingsAccepted(); + void OnPreviewRequested(); /** * @brief Qt slot * */ void OnDirectoryChanged(const QString &); /** * @brief Qt slot * */ void OnModelChanged(const QString &); /** * @brief Qt slot * */ void OnTaskChanged(const QString &); /** * @brief Qt slot * */ void OnTrainerChanged(const QString &); /** * @brief Qt slot * */ void OnCheckBoxChanged(int); /** * @brief Qthread slot to capture failures from thread worker and * shows error message * */ void SegmentationProcessFailed(); /** * @brief Qthread to capture sucessfull nnUNet segmentation. * Further, renders the LabelSet image */ void SegmentationResultHandler(mitk::nnUNetTool *); /** * @brief Qt Slot * */ void OnModalitiesNumberChanged(int); /** * @brief Qt Slot * */ void OnPythonPathChanged(const QString &); /** * @brief Qt Slot * */ void OnModalPositionChanged(int); signals: /** * @brief signal for starting the segmentation which is caught by a worker thread. */ void Operate(mitk::nnUNetTool *); protected: QmitknnUNetToolGUI(); ~QmitknnUNetToolGUI(); void ConnectNewTool(mitk::AutoSegmentationWithPreviewTool *newTool) override; void InitializeUI(QBoxLayout *mainLayout) override; void EnableWidgets(bool enabled) override; private: + void ProcessEnsembleModelsParams(mitk::nnUNetTool::Pointer); + + void ProcessModelParams(mitk::nnUNetTool::Pointer); + /** * @brief Creates and renders QmitknnUNetTaskParamsUITemplate layout for ensemble input. */ void ShowEnsembleLayout(bool visible = true); /** * @brief Creates a QMessage object and shows on screen. */ void ShowErrorMessage(const std::string &, QMessageBox::Icon = QMessageBox::Critical); /** * @brief Searches and parses paths of python virtual enviroments * from predefined lookout locations */ void AutoParsePythonPaths(); /** * @brief Clears all combo boxes * Any new combo box added in the future can be featured here for clearance. * */ void ClearAllComboBoxes(); /** * @brief Checks if nnUNet_predict command is valid in the selected python virtual environment. * * @return bool */ bool IsNNUNetInstalled(const QString &); /** * @brief Mapper function to map QString entries from UI to ModelParam attributes. * * @return mitk::ModelParams */ mitk::ModelParams MapToRequest( const QString &, const QString &, const QString &, const QString &, const std::vector &); /** * @brief Returns checked fold names from the ctk-Checkable-ComboBox. * * @return std::vector */ std::vector FetchSelectedFoldsFromUI(); /** * @brief Returns all paths from the dynamically generated ctk-path-line-edit boxes. * * @return std::vector */ std::vector FetchMultiModalImagesFromUI(); /** * @brief Template function to fetch all folders inside a given path. * The type can be any of stl or Qt containers which supports push_back call. * * @tparam T * @return T */ template static T FetchFoldersFromDir(const QString &); /** * @brief Stores path of the model director (RESULTS_FOLDER appended by "nnUNet"). * */ QString m_ModelDirectory; Ui_QmitknnUNetToolGUIControls m_Controls; QThread *m_SegmentationThread; nnUNetSegmentationWorker *m_Worker; QmitkGPULoader m_GpuLoader; /** * @brief Stores all dynamically added ctk-path-line-edit UI elements. * */ std::vector m_Modalities; // TODO: fix memory leak std::vector> m_EnsembleParams; mitk::NodePredicateBase::Pointer m_MultiModalPredicate; /** * @brief Stores row count of the "advancedSettingsLayout" layout element. This value helps dynamically add * ctk-path-line-edit UI elements at the right place. Forced to initialize in the InitializeUI method since there is * no guarantee of retrieving exact row count anywhere else. * */ int m_UI_ROWS; }; #endif