diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp index 996d29a1c0..559440366b 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp @@ -1,365 +1,366 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "mitkSegmentAnythingTool.h" #include "mitkInteractionPositionEvent.h" #include "mitkPointSetShapeProperty.h" #include "mitkProperties.h" #include "mitkToolManager.h" // us #include "mitkSegTool2D.h" #include #include #include #include #include #include using namespace std::chrono_literals; namespace mitk { MITK_TOOL_MACRO(MITKSEGMENTATION_EXPORT, SegmentAnythingTool, "SegmentAnythingTool"); } mitk::SegmentAnythingTool::SegmentAnythingTool() : SegWithPreviewTool(true, "PressMoveReleaseAndPointSetting") { this->ResetsToEmptyPreviewOn(); this->IsTimePointChangeAwareOff(); this->KeepActiveAfterAcceptOn(); } const char **mitk::SegmentAnythingTool::GetXPM() const { return nullptr; } const char *mitk::SegmentAnythingTool::GetName() const { return "Segment Anything"; } us::ModuleResource mitk::SegmentAnythingTool::GetIconResource() const { us::Module *module = us::GetModuleContext()->GetModule(); us::ModuleResource resource = module->GetResource("AI.svg"); return resource; } void mitk::SegmentAnythingTool::Activated() { Superclass::Activated(); m_PointSetPositive = mitk::PointSet::New(); m_PointSetNodePositive = mitk::DataNode::New(); m_PointSetNodePositive->SetData(m_PointSetPositive); m_PointSetNodePositive->SetName(std::string(this->GetName()) + "_PointSetPositive"); m_PointSetNodePositive->SetBoolProperty("helper object", true); m_PointSetNodePositive->SetColor(0.0, 1.0, 0.0); m_PointSetNodePositive->SetVisibility(true); m_PointSetNodePositive->SetProperty("Pointset.2D.shape", mitk::PointSetShapeProperty::New(mitk::PointSetShapeProperty::CIRCLE)); m_PointSetNodePositive->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_PointSetNodePositive, this->GetToolManager()->GetWorkingData(0)); m_PointSetNegative = mitk::PointSet::New(); m_PointSetNodeNegative = mitk::DataNode::New(); m_PointSetNodeNegative->SetData(m_PointSetNegative); m_PointSetNodeNegative->SetName(std::string(this->GetName()) + "_PointSetNegative"); m_PointSetNodeNegative->SetBoolProperty("helper object", true); m_PointSetNodeNegative->SetColor(1.0, 0.0, 0.0); m_PointSetNodeNegative->SetVisibility(true); m_PointSetNodeNegative->SetProperty("Pointset.2D.shape", mitk::PointSetShapeProperty::New(mitk::PointSetShapeProperty::CIRCLE)); m_PointSetNodeNegative->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_PointSetNodeNegative, this->GetToolManager()->GetWorkingData(0)); this->SetLabelTransferScope(LabelTransferScope::ActiveLabel); this->SetLabelTransferMode(LabelTransferMode::MapLabel); } void mitk::SegmentAnythingTool::Deactivated() { this->ClearSeeds(); GetDataStorage()->Remove(m_PointSetNodePositive); GetDataStorage()->Remove(m_PointSetNodeNegative); m_PointSetNodePositive = nullptr; m_PointSetNodeNegative = nullptr; m_PointSetPositive = nullptr; m_PointSetNegative = nullptr; m_PythonService.reset(); Superclass::Deactivated(); } void mitk::SegmentAnythingTool::ConnectActionsAndFunctions() { CONNECT_FUNCTION("ShiftSecondaryButtonPressed", OnAddNegativePoint); CONNECT_FUNCTION("ShiftPrimaryButtonPressed", OnAddPoint); CONNECT_FUNCTION("DeletePoint", OnDelete); } void mitk::SegmentAnythingTool::InitSAMPythonProcess() { if (nullptr != m_PythonService) { m_PythonService.reset(); } m_PythonService = std::make_unique( this->GetPythonPath(), this->GetModelType(), this->GetCheckpointPath(), this->GetGpuId()); m_PythonService->StartAsyncProcess(); } bool mitk::SegmentAnythingTool::IsPythonReady() const { return m_PythonService->CheckStatus(); } void mitk::SegmentAnythingTool::OnAddNegativePoint(StateMachineAction *, InteractionEvent *interactionEvent) { if (!this->GetIsReady() || m_PointSetPositive->GetSize() == 0) { return; } if (!this->IsUpdating() && m_PointSetNegative.IsNotNull()) { const auto positionEvent = dynamic_cast(interactionEvent); if (positionEvent != nullptr) { m_PointSetNegative->InsertPoint(m_PointSetCount, positionEvent->GetPositionInWorld()); m_PointSetCount++; this->UpdatePreview(); } } } void mitk::SegmentAnythingTool::OnAddPoint(StateMachineAction *, InteractionEvent *interactionEvent) { if (!this->GetIsReady()) { return; } m_IsGenerateEmbeddings = false; if ((nullptr == this->GetWorkingPlaneGeometry()) || !mitk::Equal(*(interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry()), *(this->GetWorkingPlaneGeometry()))) { m_IsGenerateEmbeddings = true; this->ClearSeeds(); this->SetWorkingPlaneGeometry(interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry()->Clone()); } if (!this->IsUpdating() && m_PointSetPositive.IsNotNull()) { const auto positionEvent = dynamic_cast(interactionEvent); if (positionEvent != nullptr) { m_PointSetPositive->InsertPoint(m_PointSetCount, positionEvent->GetPositionInWorld()); m_PointSetCount++; this->UpdatePreview(); } } } void mitk::SegmentAnythingTool::OnDelete(StateMachineAction *, InteractionEvent *) { if (!this->IsUpdating() && m_PointSetPositive.IsNotNull()) { PointSet::Pointer removeSet = m_PointSetPositive; decltype(m_PointSetPositive->GetMaxId().Index()) maxId = 0; if (m_PointSetPositive->GetSize() > 0) { maxId = m_PointSetPositive->GetMaxId().Index(); } if (m_PointSetNegative->GetSize() > 0 && (maxId < m_PointSetNegative->GetMaxId().Index())) { removeSet = m_PointSetNegative; } removeSet->RemovePointAtEnd(0); --m_PointSetCount; this->UpdatePreview(); } } void mitk::SegmentAnythingTool::ClearPicks() { this->ClearSeeds(); this->UpdatePreview(); } bool mitk::SegmentAnythingTool::HasPicks() const { return this->m_PointSetPositive.IsNotNull() && this->m_PointSetPositive->GetSize() > 0; } void mitk::SegmentAnythingTool::ClearSeeds() { if (this->m_PointSetPositive.IsNotNull()) { m_PointSetCount -= m_PointSetPositive->GetSize(); this->m_PointSetPositive = mitk::PointSet::New(); // renew pointset this->m_PointSetNodePositive->SetData(this->m_PointSetPositive); } if (this->m_PointSetNegative.IsNotNull()) { m_PointSetCount -= m_PointSetNegative->GetSize(); this->m_PointSetNegative = mitk::PointSet::New(); // renew pointset this->m_PointSetNodeNegative->SetData(this->m_PointSetNegative); } } void mitk::SegmentAnythingTool::ConfirmCleanUp() { auto previewImage = this->GetPreviewSegmentation(); for (unsigned int timeStep = 0; timeStep < previewImage->GetTimeSteps(); ++timeStep) { this->ResetPreviewContentAtTimeStep(timeStep); } this->ClearSeeds(); RenderingManager::GetInstance()->RequestUpdateAll(); } void mitk::SegmentAnythingTool::DoUpdatePreview(const Image *inputAtTimeStep, const Image * /*oldSegAtTimeStep*/, LabelSetImage *previewImage, TimeStepType timeStep) { if (nullptr != previewImage && m_PointSetPositive.IsNotNull()) { if (this->HasPicks() && nullptr != m_PythonService) { std::string uniquePlaneID = GetHashForCurrentPlane(); try { this->EmitSAMStatusMessageEvent("Prompting Segment Anything Model..."); m_PythonService->TransferImageToProcess(inputAtTimeStep, uniquePlaneID); auto csvStream = this->GetPointsAsCSVString(inputAtTimeStep->GetGeometry()); m_ProgressCommand->SetProgress(100); m_PythonService->TransferPointsToProcess(csvStream); m_ProgressCommand->SetProgress(150); std::this_thread::sleep_for(100ms); mitk::LabelSetImage::Pointer outputBuffer = m_PythonService->RetrieveImageFromProcess(); m_ProgressCommand->SetProgress(180); mitk::SegTool2D::WriteSliceToVolume(previewImage, this->GetWorkingPlaneGeometry(), outputBuffer.GetPointer(), timeStep, false); this->SetSelectedLabels({MASK_VALUE}); this->EmitSAMStatusMessageEvent("Successfully generated segmentation."); } catch (const mitk::Exception &e) { this->EmitSAMStatusMessageEvent(e.GetDescription()); mitkThrow() << e.GetDescription(); } } else if (nullptr != this->GetWorkingPlaneGeometry()) { this->ResetPreviewContentAtTimeStep(timeStep); + RenderingManager::GetInstance()->ForceImmediateUpdateAll(); } } } std::string mitk::SegmentAnythingTool::GetHashForCurrentPlane() { mitk::Vector3D normal = this->GetWorkingPlaneGeometry()->GetNormal(); std::stringstream hashstream; hashstream << normal[0] << normal[1] << normal[2]; mitk::Point3D point = m_PointSetPositive->GetPoint(0); for (int i = 0; i < 3; ++i) { if (normal[i] != 0) { hashstream << point[i]; } } size_t hashVal = std::hash{}(hashstream.str()); return std::to_string(hashVal); } std::stringstream mitk::SegmentAnythingTool::GetPointsAsCSVString(const mitk::BaseGeometry *baseGeometry) { MITK_INFO << "No.of points: " << m_PointSetPositive->GetSize(); std::stringstream pointsAndLabels; pointsAndLabels << "Point,Label\n"; mitk::PointSet::PointsIterator pointSetItPos = m_PointSetPositive->Begin(); mitk::PointSet::PointsIterator pointSetItNeg = m_PointSetNegative->Begin(); const char SPACE = ' '; while (pointSetItPos != m_PointSetPositive->End() || pointSetItNeg != m_PointSetNegative->End()) { if (pointSetItPos != m_PointSetPositive->End()) { mitk::Point3D point = pointSetItPos.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); pointsAndLabels << (int)p2D[0] << SPACE << (int)p2D[1] << ",1" << std::endl; } ++pointSetItPos; } if (pointSetItNeg != m_PointSetNegative->End()) { mitk::Point3D point = pointSetItNeg.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); pointsAndLabels << (int)p2D[0] << SPACE << (int)p2D[1] << ",0" << std::endl; } ++pointSetItNeg; } } return pointsAndLabels; } std::vector> mitk::SegmentAnythingTool::GetPointsAsVector( const mitk::BaseGeometry *baseGeometry) { std::vector> clickVec; clickVec.reserve(m_PointSetPositive->GetSize() + m_PointSetNegative->GetSize()); mitk::PointSet::PointsIterator pointSetItPos = m_PointSetPositive->Begin(); mitk::PointSet::PointsIterator pointSetItNeg = m_PointSetNegative->Begin(); while (pointSetItPos != m_PointSetPositive->End() || pointSetItNeg != m_PointSetNegative->End()) { if (pointSetItPos != m_PointSetPositive->End()) { mitk::Point3D point = pointSetItPos.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); clickVec.push_back(std::pair(p2D, "1")); } ++pointSetItPos; } if (pointSetItNeg != m_PointSetNegative->End()) { mitk::Point3D point = pointSetItNeg.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); clickVec.push_back(std::pair(p2D, "0")); } ++pointSetItNeg; } } return clickVec; } mitk::Point2D mitk::SegmentAnythingTool::Get2DIndicesfrom3DWorld(const mitk::BaseGeometry *baseGeometry, mitk::Point3D &point3d) { baseGeometry->WorldToIndex(point3d, point3d); MITK_INFO << point3d[0] << " " << point3d[1] << " " << point3d[2]; // remove Point2D point2D; point2D.SetElement(0, point3d[0]); point2D.SetElement(1, point3d[1]); return point2D; } void mitk::SegmentAnythingTool::EmitSAMStatusMessageEvent(const std::string& status) { SAMStatusMessageEvent.Send(status); } diff --git a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp index fcc1edad50..c3bcd3f38d 100644 --- a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp +++ b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp @@ -1,353 +1,353 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "QmitkSegmentAnythingToolGUI.h" #include "mitkSegmentAnythingTool.h" #include "mitkProcessExecutor.h" #include #include #include #include #include #include #include #include #include MITK_TOOL_GUI_MACRO(MITKSEGMENTATIONUI_EXPORT, QmitkSegmentAnythingToolGUI, "") namespace { mitk::IPreferences *GetPreferences() { auto *preferencesService = mitk::CoreServices::GetPreferencesService(); return preferencesService->GetSystemPreferences()->Node("org.mitk.views.segmentation"); } } QmitkSegmentAnythingToolGUI::QmitkSegmentAnythingToolGUI() : QmitkSegWithPreviewToolGUIBase(true) { m_EnableConfirmSegBtnFnc = [this](bool enabled) { bool result = false; auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { result = enabled && tool->HasPicks(); } return result; }; m_Preferences = GetPreferences(); m_Preferences->OnPropertyChanged += mitk::MessageDelegate1( this, &QmitkSegmentAnythingToolGUI::OnPreferenceChangedEvent); } QmitkSegmentAnythingToolGUI::~QmitkSegmentAnythingToolGUI() { auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->SAMStatusMessageEvent -= mitk::MessageDelegate1( this, &QmitkSegmentAnythingToolGUI::StatusMessageListener); } } void QmitkSegmentAnythingToolGUI::InitializeUI(QBoxLayout *mainLayout) { m_Controls.setupUi(this); m_Controls.statusLabel->setTextFormat(Qt::RichText); QString welcomeText; if (m_GpuLoader.GetGPUCount() != 0) { welcomeText = "STATUS: Welcome to Segment Anything tool. You're in luck: " + QString::number(m_GpuLoader.GetGPUCount()) + " GPU(s) were detected."; } else { welcomeText = "STATUS: Welcome to Segment Anything tool. Sorry, " + QString::number(m_GpuLoader.GetGPUCount()) + " GPUs were detected."; } connect(m_Controls.activateButton, SIGNAL(clicked()), this, SLOT(OnActivateBtnClicked())); connect(m_Controls.resetButton, SIGNAL(clicked()), this, SLOT(OnResetPicksClicked())); QIcon arrowIcon = QmitkStyleManager::ThemeIcon( QStringLiteral(":/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg")); m_Controls.activateButton->setIcon(arrowIcon); bool isInstalled = this->ValidatePrefences(); if (isInstalled) { - m_PythonPath = QString::fromStdString(m_Preferences->Get("sam python path", "")); QString modelType = QString::fromStdString(m_Preferences->Get("sam modeltype", "")); welcomeText += " SAM is already found installed. Model type '" + modelType + "' selected in Preferences."; } else { welcomeText += " SAM tool is not configured correctly. Please go to Preferences (Cntl+P) > Segment Anything to configure and/or install SAM."; } this->EnableAll(isInstalled); this->WriteStatusMessage(welcomeText); this->ShowProgressBar(false); m_Controls.samProgressBar->setMaximum(0); mainLayout->addLayout(m_Controls.verticalLayout); Superclass::InitializeUI(mainLayout); } bool QmitkSegmentAnythingToolGUI::ValidatePrefences() { const QString storageDir = QString::fromStdString(m_Preferences->Get("sam python path", "")); bool isInstalled = QmitkSegmentAnythingToolGUI::IsSAMInstalled(storageDir); std::string modelType = m_Preferences->Get("sam modeltype", ""); std::string path = m_Preferences->Get("sam parent path", ""); return (isInstalled && !modelType.empty() && !path.empty()); } void QmitkSegmentAnythingToolGUI::EnableAll(bool isEnable) { m_Controls.activateButton->setEnabled(isEnable); } void QmitkSegmentAnythingToolGUI::WriteStatusMessage(const QString &message) { m_Controls.statusLabel->setText(message); m_Controls.statusLabel->setStyleSheet("font-weight: bold; color: white"); qApp->processEvents(); } void QmitkSegmentAnythingToolGUI::WriteErrorMessage(const QString &message) { m_Controls.statusLabel->setText(message); m_Controls.statusLabel->setStyleSheet("font-weight: bold; color: red"); qApp->processEvents(); } void QmitkSegmentAnythingToolGUI::ShowErrorMessage(const std::string &message, QMessageBox::Icon icon) { this->setCursor(Qt::ArrowCursor); QMessageBox *messageBox = new QMessageBox(icon, nullptr, message.c_str()); messageBox->exec(); delete messageBox; MITK_WARN << message; } void QmitkSegmentAnythingToolGUI::StatusMessageListener(const std::string &message) { if (message.rfind("Error", 0) == 0) { this->EnableAll(true); this->WriteErrorMessage(QString::fromStdString(message)); } else { this->WriteStatusMessage(QString::fromStdString(message)); } } void QmitkSegmentAnythingToolGUI::OnActivateBtnClicked() { auto tool = this->GetConnectedToolAs(); if (nullptr == tool) { return; } try { this->EnableAll(false); qApp->processEvents(); - if (!QmitkSegmentAnythingToolGUI::IsSAMInstalled(m_PythonPath)) + QString pythonPath = QString::fromStdString(m_Preferences->Get("sam python path", "")); + if (!QmitkSegmentAnythingToolGUI::IsSAMInstalled(pythonPath)) { throw std::runtime_error(WARNING_SAM_NOT_FOUND); } - tool->SetPythonPath(m_PythonPath.toStdString()); + tool->SetPythonPath(pythonPath.toStdString()); tool->SetGpuId(m_Preferences->GetInt("sam gpuid", -1)); const QString modelType = QString::fromStdString(m_Preferences->Get("sam modeltype", "")); tool->SetModelType(modelType.toStdString()); this->WriteStatusMessage( QString("STATUS: Checking if model is already downloaded... This might take a while.")); tool->SAMStatusMessageEvent += mitk::MessageDelegate1( this, &QmitkSegmentAnythingToolGUI::StatusMessageListener); if (this->DownloadModel(modelType)) { this->WriteStatusMessage(QString("STATUS: Model found. Initializing Segment Anything tool...")); if (this->ActivateSAMDaemon()) { this->WriteStatusMessage(QString("STATUS: Model found. Segment Anything tool initialized.")); } else { this->WriteErrorMessage(QString("STATUS: Model found. Couldn't init tool backend.")); this->EnableAll(true); } } else { tool->IsReadyOff(); this->WriteStatusMessage(QString("STATUS: Model type not found. Starting download...")); } } catch (const std::exception &e) { std::stringstream errorMsg; errorMsg << "STATUS: Error while processing parameters for Segment Anything segmentation. Reason: " << e.what(); this->ShowErrorMessage(errorMsg.str()); this->WriteErrorMessage(QString::fromStdString(errorMsg.str())); this->EnableAll(true); return; } catch (...) { std::string errorMsg = "Unkown error occured while generation Segment Anything segmentation."; this->ShowErrorMessage(errorMsg); this->EnableAll(true); return; } } bool QmitkSegmentAnythingToolGUI::ActivateSAMDaemon() { auto tool = this->GetConnectedToolAs(); if (nullptr == tool) { return false; } this->ShowProgressBar(true); qApp->processEvents(); try { tool->InitSAMPythonProcess(); while (!tool->IsPythonReady()) { qApp->processEvents(); } tool->IsReadyOn(); } catch (...) { tool->IsReadyOff(); } this->ShowProgressBar(false); return tool->GetIsReady(); } bool QmitkSegmentAnythingToolGUI::DownloadModel(const QString &modelType) { QUrl url = QmitkSegmentAnythingToolGUI::VALID_MODELS_URL_MAP[modelType]; QString modelFileName = url.fileName(); const QString storageDir = QString::fromStdString(m_Preferences->Get("sam parent path", "")); QString checkPointPath = storageDir + QDir::separator() + modelFileName; if (QFile::exists(checkPointPath)) { auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->SetCheckpointPath(checkPointPath.toStdString()); } return true; } connect(&m_Manager, SIGNAL(finished(QNetworkReply*)), this, SLOT(FileDownloaded(QNetworkReply*))); QNetworkRequest request(url); m_Manager.get(request); this->ShowProgressBar(true); return false; } void QmitkSegmentAnythingToolGUI::FileDownloaded(QNetworkReply *reply) { const QString storageDir = QString::fromStdString(m_Preferences->Get("sam parent path", "")); const QString &modelFileName = reply->url().fileName(); QFile file(storageDir + QDir::separator() + modelFileName); if (file.open(QIODevice::WriteOnly)) { file.write(reply->readAll()); file.close(); disconnect(&m_Manager, SIGNAL(finished(QNetworkReply *)), this, SLOT(FileDownloaded(QNetworkReply *))); this->WriteStatusMessage(QString("STATUS: Model successfully downloaded. Initializing Segment Anything....")); auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->SetCheckpointPath(file.fileName().toStdString()); if (this->ActivateSAMDaemon()) { this->WriteStatusMessage(QString("STATUS: Model successfully downloaded. Segment Anything initialized.")); } else { this->WriteErrorMessage(QString("STATUS: Model successfully downloaded. But, couldn't initialize tool backend.")); this->EnableAll(true); } } } else { this->WriteErrorMessage("STATUS: Model couldn't be downloaded. Segment Anything not initialized."); } this->EnableAll(true); this->ShowProgressBar(false); } void QmitkSegmentAnythingToolGUI::ShowProgressBar(bool enabled) { m_Controls.samProgressBar->setEnabled(enabled); m_Controls.samProgressBar->setVisible(enabled); } bool QmitkSegmentAnythingToolGUI::IsSAMInstalled(const QString &pythonPath) { QString fullPath = pythonPath; bool isPythonExists = false; bool isSamExists = false; #ifdef _WIN32 isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python.exe")); if (!(fullPath.endsWith("Scripts", Qt::CaseInsensitive) || fullPath.endsWith("Scripts/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("Scripts"); isPythonExists = (!isPythonExists) ? QFile::exists(fullPath + QDir::separator() + QString("python.exe")) : isPythonExists; } #else isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python3")); if (!(fullPath.endsWith("bin", Qt::CaseInsensitive) || fullPath.endsWith("bin/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("bin"); isPythonExists = (!isPythonExists) ? QFile::exists(fullPath + QDir::separator() + QString("python3")) : isPythonExists; } #endif isSamExists = QFile::exists(fullPath + QDir::separator() + QString("run_inference_daemon.py")); bool isExists = isSamExists && isPythonExists; return isExists; } void QmitkSegmentAnythingToolGUI::OnResetPicksClicked() { auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->ClearPicks(); } } void QmitkSegmentAnythingToolGUI::OnPreferenceChangedEvent(const mitk::IPreferences::ChangeEvent&) { this->EnableAll(true); this->WriteStatusMessage("A Preference change was detected. Please initialize the tool again."); auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->IsReadyOff(); } } diff --git a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h index cc5f660269..aeaf0ba6ea 100644 --- a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h +++ b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h @@ -1,138 +1,137 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef QmitkSegmentAnythingToolGUI_h #define QmitkSegmentAnythingToolGUI_h #include "QmitkSegWithPreviewToolGUIBase.h" #include #include "ui_QmitkSegmentAnythingGUIControls.h" #include "QmitknnUNetGPU.h" #include "QmitkSetupVirtualEnvUtil.h" #include #include #include #include #include /** \ingroup org_mitk_gui_qt_interactivesegmentation_internal \brief GUI for mitk::SegmentAnythingTool. \sa mitk::PickingTool */ class MITKSEGMENTATIONUI_EXPORT QmitkSegmentAnythingToolGUI : public QmitkSegWithPreviewToolGUIBase { Q_OBJECT public: mitkClassMacro(QmitkSegmentAnythingToolGUI, QmitkSegWithPreviewToolGUIBase); itkFactorylessNewMacro(Self); itkCloneMacro(Self); inline static const QMap VALID_MODELS_URL_MAP = { {"vit_b", QUrl("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")}, {"vit_l", QUrl("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth")}, {"vit_h", QUrl("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth")}}; /** * @brief Checks if SegmentAnything is found inside the selected python virtual environment. * @return bool */ static bool IsSAMInstalled(const QString &); protected slots: /** * @brief Qt Slot */ void OnResetPicksClicked(); /** * @brief Qt Slot */ void OnActivateBtnClicked(); /** * @brief Qt Slot */ void FileDownloaded(QNetworkReply*); protected: QmitkSegmentAnythingToolGUI(); ~QmitkSegmentAnythingToolGUI(); void InitializeUI(QBoxLayout* mainLayout) override; /** * @brief Writes any message in white on the tool pane. */ void WriteStatusMessage(const QString&); /** * @brief Writes any message in red on the tool pane. */ void WriteErrorMessage(const QString&); void StatusMessageListener(const std::string&); void OnPreferenceChangedEvent(const mitk::IPreferences::ChangeEvent&); /** * @brief Creates a QMessage object and shows on screen. */ void ShowErrorMessage(const std::string&, QMessageBox::Icon = QMessageBox::Critical); /** * @brief Enable (or Disable) GUI elements. Currently, on the activate button * is affected. */ void EnableAll(bool); /** * @brief Start download process for the given model type. * Download URL is looked from the VALID_MODELS_URL_MAP. * * @return bool */ bool DownloadModel(const QString&); /** * @brief Enable (or Disable) progressbar on GUI * */ void ShowProgressBar(bool); /** * @brief Requests the tool class to spawn the SAM python daemon * process. Waits until the daemon is started. * * @return bool */ bool ActivateSAMDaemon(); /** * @brief Checks if the preferences are correctly set by the user. * * @return bool */ bool ValidatePrefences(); private: mitk::IPreferences* m_Preferences; QNetworkAccessManager m_Manager; Ui_QmitkSegmentAnythingGUIControls m_Controls; - QString m_PythonPath; QmitkGPULoader m_GpuLoader; bool m_FirstPreviewComputation = true; const std::string WARNING_SAM_NOT_FOUND = "SAM is not detected in the selected python environment. Please reinstall SAM."; }; #endif diff --git a/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.cpp b/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.cpp index eab7b24c41..8ecef4e879 100644 --- a/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.cpp +++ b/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.cpp @@ -1,364 +1,369 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "QmitkSegmentAnythingPreferencePage.h" #include #include #include #include #include #include #include #include #include #include namespace { mitk::IPreferences* GetPreferences() { auto* preferencesService = mitk::CoreServices::GetPreferencesService(); return preferencesService->GetSystemPreferences()->Node("org.mitk.views.segmentation"); } } QmitkSegmentAnythingPreferencePage::QmitkSegmentAnythingPreferencePage() : m_Ui(new Ui::QmitkSegmentAnythingPreferencePage), m_Control(nullptr){} QmitkSegmentAnythingPreferencePage::~QmitkSegmentAnythingPreferencePage(){} void QmitkSegmentAnythingPreferencePage::Init(berry::IWorkbench::Pointer){} void QmitkSegmentAnythingPreferencePage::CreateQtControl(QWidget* parent) { m_Control = new QWidget(parent); m_Ui->setupUi(m_Control); #ifndef _WIN32 m_Ui->sysPythonComboBox->addItem("/usr/bin"); #endif this->AutoParsePythonPaths(); m_Ui->sysPythonComboBox->addItem("Select..."); m_Ui->sysPythonComboBox->setCurrentIndex(0); connect(m_Ui->installSAMButton, SIGNAL(clicked()), this, SLOT(OnInstallBtnClicked())); connect(m_Ui->clearSAMButton, SIGNAL(clicked()), this, SLOT(OnClearInstall())); connect(m_Ui->sysPythonComboBox, QOverload::of(&QComboBox::activated), [=](int index) { OnSystemPythonChanged(m_Ui->sysPythonComboBox->itemText(index)); }); QIcon deleteIcon = QmitkStyleManager::ThemeIcon(QStringLiteral(":/org_mitk_icons/icons/awesome/scalable/actions/edit-delete.svg")); m_Ui->clearSAMButton->setIcon(deleteIcon); const QString storageDir = m_Installer.GetVirtualEnvPath(); - m_IsInstalled = QmitkSegmentAnythingToolGUI::IsSAMInstalled(storageDir); + bool isInstalled = QmitkSegmentAnythingToolGUI::IsSAMInstalled(storageDir); QString welcomeText; - if (m_IsInstalled) + if (isInstalled) { m_PythonPath = GetExactPythonPath(storageDir); m_Installer.SetVirtualEnvPath(m_PythonPath); welcomeText += " Segment Anything tool is already found installed."; m_Ui->installSAMButton->setEnabled(false); } else { - welcomeText += " Segment Anything tool not installed. Please click on \"Install SAM\" above."; + welcomeText += " Segment Anything tool not installed. Please click on \"Install SAM\" above. \ + The installation will create a new virtual environment using the System Python selected above."; m_Ui->installSAMButton->setEnabled(true); } this->WriteStatusMessage(welcomeText); m_Ui->samModelTypeComboBox->addItems(QmitkSegmentAnythingToolGUI::VALID_MODELS_URL_MAP.keys()); m_Ui->gpuComboBox->addItem(CPU_ID); this->SetGPUInfo(); this->Update(); this->Update(); } QWidget* QmitkSegmentAnythingPreferencePage::GetQtControl() const { return m_Control; } bool QmitkSegmentAnythingPreferencePage::PerformOk() { auto* prefs = GetPreferences(); prefs->Put("sam parent path", m_Installer.STORAGE_DIR.toStdString()); prefs->Put("sam python path", m_PythonPath.toStdString()); prefs->Put("sam modeltype", m_Ui->samModelTypeComboBox->currentText().toStdString()); prefs->PutInt("sam gpuid", FetchSelectedGPUFromUI()); return true; } void QmitkSegmentAnythingPreferencePage::PerformCancel(){} void QmitkSegmentAnythingPreferencePage::Update() { auto* prefs = GetPreferences(); m_Ui->samModelTypeComboBox->setCurrentText(QString::fromStdString(prefs->Get("sam modeltype", "vit_b"))); int gpuId = prefs->GetInt("sam gpuid", -1); if (gpuId == -1) { m_Ui->gpuComboBox->setCurrentText(CPU_ID); } else if (m_GpuLoader.GetGPUCount() == 0) { m_Ui->gpuComboBox->setCurrentText(QString::number(gpuId)); } else { std::vector specs = m_GpuLoader.GetAllGPUSpecs(); QmitkGPUSpec gpuSpec = specs[gpuId]; m_Ui->gpuComboBox->setCurrentText(QString::number(gpuSpec.id) + ": " + gpuSpec.name + " (" + gpuSpec.memory + ")"); } } QString QmitkSegmentAnythingPreferencePage::OnSystemPythonChanged(const QString &pyEnv) { QString pyPath; if (pyEnv == QString("Select...")) { QString path = QFileDialog::getExistingDirectory(m_Ui->sysPythonComboBox->parentWidget(), "Python Path", "dir"); if (!path.isEmpty()) { this->OnSystemPythonChanged(path); // recall same function for new path validation bool oldState = m_Ui->sysPythonComboBox->blockSignals(true); // block signal firing while inserting item m_Ui->sysPythonComboBox->insertItem(0, path); m_Ui->sysPythonComboBox->setCurrentIndex(0); m_Ui->sysPythonComboBox->blockSignals(oldState); // unblock signal firing after inserting item. Remove this after Qt6 migration } } else { QString uiPyPath = this->GetPythonPathFromUI(pyEnv); pyPath = this->GetExactPythonPath(uiPyPath); } return pyPath; } QString QmitkSegmentAnythingPreferencePage::GetPythonPathFromUI(const QString &pyUI) const { QString fullPath = pyUI; if (-1 != fullPath.indexOf(")")) { fullPath = fullPath.mid(fullPath.indexOf(")") + 2); } return fullPath.simplified(); } QString QmitkSegmentAnythingPreferencePage::GetExactPythonPath(const QString &pyEnv) const { QString fullPath = pyEnv; bool isPythonExists = false; #ifdef _WIN32 isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python.exe")); if (!isPythonExists && !(fullPath.endsWith("Scripts", Qt::CaseInsensitive) || fullPath.endsWith("Scripts/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("Scripts"); isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python.exe")); } #else isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python3")); if (!isPythonExists && !(fullPath.endsWith("bin", Qt::CaseInsensitive) || fullPath.endsWith("bin/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("bin"); isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python3")); } #endif if (!isPythonExists) { fullPath.clear(); } return fullPath; } void QmitkSegmentAnythingPreferencePage::AutoParsePythonPaths() { QString homeDir = QDir::homePath(); std::vector searchDirs; #ifdef _WIN32 searchDirs.push_back(QString("C:") + QDir::separator() + QString("ProgramData") + QDir::separator() + QString("anaconda3")); #else // Add search locations for possible standard python paths here searchDirs.push_back(homeDir + QDir::separator() + "anaconda3"); searchDirs.push_back(homeDir + QDir::separator() + "miniconda3"); searchDirs.push_back(homeDir + QDir::separator() + "opt" + QDir::separator() + "miniconda3"); searchDirs.push_back(homeDir + QDir::separator() + "opt" + QDir::separator() + "anaconda3"); #endif for (QString searchDir : searchDirs) { if (searchDir.endsWith("anaconda3", Qt::CaseInsensitive)) { if (QDir(searchDir).exists()) { m_Ui->sysPythonComboBox->addItem("(base): " + searchDir); searchDir.append((QDir::separator() + QString("envs"))); } } for (QDirIterator subIt(searchDir, QDir::AllDirs, QDirIterator::NoIteratorFlags); subIt.hasNext();) { subIt.next(); QString envName = subIt.fileName(); if (!envName.startsWith('.')) // Filter out irrelevent hidden folders, if any. { m_Ui->sysPythonComboBox->addItem("(" + envName + "): " + subIt.filePath()); } } } } void QmitkSegmentAnythingPreferencePage::SetGPUInfo() { std::vector specs = m_GpuLoader.GetAllGPUSpecs(); for (const QmitkGPUSpec &gpuSpec : specs) { m_Ui->gpuComboBox->addItem(QString::number(gpuSpec.id) + ": " + gpuSpec.name + " (" + gpuSpec.memory + ")"); } if (specs.empty()) { m_Ui->gpuComboBox->setEditable(true); m_Ui->gpuComboBox->addItem(QString::number(0)); m_Ui->gpuComboBox->setValidator(new QIntValidator(0, 999, this)); m_Ui->gpuComboBox->setCurrentIndex(m_Ui->gpuComboBox->findText("cpu")); } else { m_Ui->gpuComboBox->setCurrentIndex(m_Ui->gpuComboBox->count()-1); } } int QmitkSegmentAnythingPreferencePage::FetchSelectedGPUFromUI() const { QString gpuInfo = m_Ui->gpuComboBox->currentText(); if ("cpu" == gpuInfo) { return -1; } else if(m_GpuLoader.GetGPUCount() == 0) { return static_cast(gpuInfo.toInt()); } else { QString gpuId = gpuInfo.split(":", QString::SplitBehavior::SkipEmptyParts).first(); return static_cast(gpuId.toInt()); } } void QmitkSegmentAnythingPreferencePage::OnInstallBtnClicked() { QString systemPython = OnSystemPythonChanged(m_Ui->sysPythonComboBox->currentText()); if (!systemPython.isEmpty()) { this->WriteStatusMessage("STATUS: Installing SAM..."); m_Ui->installSAMButton->setEnabled(false); m_Installer.SetSystemPythonPath(systemPython); - m_IsInstalled = m_Installer.SetupVirtualEnv(m_Installer.VENV_NAME); - if (m_IsInstalled) + bool isInstalled = false; + bool isFinished = m_Installer.SetupVirtualEnv(m_Installer.VENV_NAME); + if (isFinished) + { + isInstalled = QmitkSegmentAnythingToolGUI::IsSAMInstalled(m_Installer.GetVirtualEnvPath()); + } + if (isInstalled) { m_PythonPath = this->GetExactPythonPath(m_Installer.GetVirtualEnvPath()); this->WriteStatusMessage("STATUS: Successfully installed SAM."); - auto *prefs = GetPreferences(); - prefs->Put("sam python path", m_PythonPath.toStdString()); } else { this->WriteErrorMessage("ERROR: Couldn't install SAM."); m_Ui->installSAMButton->setEnabled(true); } } } void QmitkSegmentAnythingPreferencePage::OnClearInstall() { QDir folderPath(m_Installer.GetVirtualEnvPath()); - m_IsInstalled = folderPath.removeRecursively(); - if (m_IsInstalled) + bool isDeleted = folderPath.removeRecursively(); + if (isDeleted) { this->WriteStatusMessage("Deleted SAM installation."); m_Ui->installSAMButton->setEnabled(true); + m_PythonPath.clear(); } else { MITK_ERROR << "The virtual environment couldn't be removed. Please check if you have the required access " "privileges or, some other process is accessing the folders."; } } void QmitkSegmentAnythingPreferencePage::WriteStatusMessage(const QString &message) { m_Ui->samInstallStatusLabel->setText(message); m_Ui->samInstallStatusLabel->setStyleSheet("font-weight: bold; color: white"); qApp->processEvents(); } void QmitkSegmentAnythingPreferencePage::WriteErrorMessage(const QString &message) { m_Ui->samInstallStatusLabel->setText(message); m_Ui->samInstallStatusLabel->setStyleSheet("font-weight: bold; color: red"); qApp->processEvents(); } QString QmitkSAMInstaller::GetVirtualEnvPath() { return STORAGE_DIR + VENV_NAME; } bool QmitkSAMInstaller::SetupVirtualEnv(const QString &venvName) { if (GetSystemPythonPath().isEmpty()) { return false; } QDir folderPath(GetBaseDir()); folderPath.mkdir(venvName); if (!folderPath.cd(venvName)) { return false; // Check if directory creation was successful. } mitk::ProcessExecutor::ArgumentListType args; auto spExec = mitk::ProcessExecutor::New(); auto spCommand = itk::CStyleCommand::New(); spCommand->SetCallback(&PrintProcessEvent); spExec->AddObserver(mitk::ExternalProcessOutputEvent(), spCommand); args.push_back("-m"); args.push_back("venv"); args.push_back(venvName.toStdString()); #ifdef _WIN32 QString pythonFile = GetSystemPythonPath() + QDir::separator() + "python.exe"; QString pythonExeFolder = "Scripts"; #else QString pythonFile = GetSystemPythonPath() + QDir::separator() + "python3"; QString pythonExeFolder = "bin"; #endif spExec->Execute(GetBaseDir().toStdString(), pythonFile.toStdString(), args); // Setup local virtual environment if (folderPath.cd(pythonExeFolder)) { this->SetPythonPath(folderPath.absolutePath()); this->SetPipPath(folderPath.absolutePath()); this->InstallPytorch(); for (auto &package : PACKAGES) { this->PipInstall(package.toStdString(), &PrintProcessEvent); } std::string pythonCode; // python syntax to check if torch is installed with CUDA. pythonCode.append("import torch;"); pythonCode.append("print('Pytorch was installed with CUDA') if torch.cuda.is_available() else print('PyTorch was " "installed WITHOUT CUDA');"); this->ExecutePython(pythonCode, &PrintProcessEvent); return true; } return false; } diff --git a/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.h b/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.h index e51b7cf568..9ee649e9f0 100644 --- a/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.h +++ b/Plugins/org.mitk.gui.qt.segmentation/src/QmitkSegmentAnythingPreferencePage.h @@ -1,114 +1,113 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef QmitkSegmentAnythingPreferencePage_h #define QmitkSegmentAnythingPreferencePage_h #include #include #include #include #include #include class QWidget; namespace Ui { class QmitkSegmentAnythingPreferencePage; } class QmitkSAMInstaller : public QmitkSetupVirtualEnvUtil { public: const QString VENV_NAME = ".sam"; const QString SAM_VERSION = "1.0"; // currently, unused const std::vector PACKAGES = {QString("git+https://github.com/ASHISRAVINDRAN/sam-mitk.git")}; const QString STORAGE_DIR; inline QmitkSAMInstaller( const QString baseDir = QStandardPaths::writableLocation(QStandardPaths::GenericDataLocation) + QDir::separator() + qApp->organizationName() + QDir::separator()) : QmitkSetupVirtualEnvUtil(baseDir), STORAGE_DIR(baseDir){}; bool SetupVirtualEnv(const QString &) override; QString GetVirtualEnvPath() override; }; class QmitkSegmentAnythingPreferencePage : public QObject, public berry::IQtPreferencePage { Q_OBJECT Q_INTERFACES(berry::IPreferencePage) public: QmitkSegmentAnythingPreferencePage(); ~QmitkSegmentAnythingPreferencePage() override; void Init(berry::IWorkbench::Pointer workbench) override; void CreateQtControl(QWidget* parent) override; QWidget* GetQtControl() const override; bool PerformOk() override; void PerformCancel() override; void Update() override; private slots: void OnInstallBtnClicked(); void OnClearInstall(); QString OnSystemPythonChanged(const QString&); protected: /** * @brief Searches and parses paths of python virtual enviroments * from predefined lookout locations */ void AutoParsePythonPaths(); /** * @brief Get the virtual env path from UI combobox removing any * extra special characters. * * @return QString */ QString GetPythonPathFromUI(const QString &) const; /** * @brief Get the Exact Python Path for any OS * from the virtual environment path. * @return QString */ QString GetExactPythonPath(const QString &) const; /** * @brief Adds GPU information to the gpu combo box. * In case, there aren't any GPUs avaialble, the combo box will be * rendered editable. */ void SetGPUInfo(); /** * @brief Returns GPU id of the selected GPU from the Combo box. * @return int */ int FetchSelectedGPUFromUI() const; void WriteStatusMessage(const QString &); void WriteErrorMessage(const QString &); private: Ui::QmitkSegmentAnythingPreferencePage* m_Ui; QmitkSAMInstaller m_Installer; QWidget* m_Control; QmitkGPULoader m_GpuLoader; QString m_PythonPath; - bool m_IsInstalled = false; const QString CPU_ID = "cpu"; }; #endif