diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.cpp b/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.cpp index 7dbd5d83b2..8df1a4f902 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.cpp +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.cpp @@ -1,219 +1,220 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "mitkSegmentAnythingPythonService.h" #include "mitkIOUtil.h" #include #include #include #include using namespace std::chrono_literals; namespace mitk { const std::string SIGNALCONSTANTS::READY = "READY"; const std::string SIGNALCONSTANTS::KILL = "KILL"; const std::string SIGNALCONSTANTS::CUDA_OUT_OF_MEMORY_ERROR = "CudaOutOfMemoryError"; bool mitk::SegmentAnythingPythonService::IsPythonReady = false; mitk::SegmentAnythingPythonService::Status mitk::SegmentAnythingPythonService::CurrentStatus = - mitk::SegmentAnythingPythonService::Status::READY; + mitk::SegmentAnythingPythonService::Status::OFF; } mitk::SegmentAnythingPythonService::SegmentAnythingPythonService( std::string workingDir, std::string modelType, std::string checkPointPath, unsigned int gpuId) : m_PythonPath(workingDir), m_ModelType(modelType), m_CheckpointPath(checkPointPath), m_GpuId(gpuId) { this->CreateTempDirs(m_PARENT_TEMP_DIR_PATTERN); } mitk::SegmentAnythingPythonService::~SegmentAnythingPythonService() { if (mitk::SegmentAnythingPythonService::IsPythonReady) { this->StopAsyncProcess(); } mitk::SegmentAnythingPythonService::IsPythonReady = false; + mitk::SegmentAnythingPythonService::CurrentStatus = mitk::SegmentAnythingPythonService::Status::OFF; std::filesystem::remove_all(this->GetMitkTempDir()); } -void mitk::SegmentAnythingPythonService::onPythonProcessEvent(itk::Object * /*pCaller*/, - const itk::EventObject &e, - void *) +void mitk::SegmentAnythingPythonService::onPythonProcessEvent(itk::Object*, const itk::EventObject &e, void*) { std::string testCOUT,testCERR; const auto *pEvent = dynamic_cast(&e); if (pEvent) { testCOUT = testCOUT + pEvent->GetOutput(); testCOUT.erase(std::find_if(testCOUT.rbegin(), testCOUT.rend(), [](unsigned char ch) { return !std::isspace(ch);}).base(), testCOUT.end()); // remove trailing whitespaces, if any if (SIGNALCONSTANTS::READY == testCOUT) { mitk::SegmentAnythingPythonService::IsPythonReady = true; mitk::SegmentAnythingPythonService::CurrentStatus = mitk::SegmentAnythingPythonService::Status::READY; } if (SIGNALCONSTANTS::KILL == testCOUT) { mitk::SegmentAnythingPythonService::CurrentStatus = mitk::SegmentAnythingPythonService::Status::KILLED; mitk::SegmentAnythingPythonService::IsPythonReady = false; } if (SIGNALCONSTANTS::CUDA_OUT_OF_MEMORY_ERROR == testCOUT) { mitk::SegmentAnythingPythonService::CurrentStatus = mitk::SegmentAnythingPythonService::Status::CUDAError; mitk::SegmentAnythingPythonService::IsPythonReady = false; } MITK_INFO << testCOUT; } const auto *pErrEvent = dynamic_cast(&e); if (pErrEvent) { testCERR = testCERR + pErrEvent->GetOutput(); mitk::SegmentAnythingPythonService::IsPythonReady = false; MITK_ERROR << testCERR; } } void mitk::SegmentAnythingPythonService::StopAsyncProcess() { std::stringstream controlStream; controlStream << SIGNALCONSTANTS::KILL; this->WriteControlFile(controlStream); m_Future.get(); } void mitk::SegmentAnythingPythonService::StartAsyncProcess() { if (nullptr != m_DaemonExec) { this->StopAsyncProcess(); } if (this->GetMitkTempDir().empty()) { this->CreateTempDirs(m_PARENT_TEMP_DIR_PATTERN); } std::stringstream controlStream; controlStream << SIGNALCONSTANTS::READY; this->WriteControlFile(controlStream); m_DaemonExec = ProcessExecutor::New(); itk::CStyleCommand::Pointer spCommand = itk::CStyleCommand::New(); spCommand->SetCallback(&mitk::SegmentAnythingPythonService::onPythonProcessEvent); m_DaemonExec->AddObserver(ExternalProcessOutputEvent(), spCommand); m_Future = std::async(std::launch::async, &mitk::SegmentAnythingPythonService::start_python_daemon, this); } void mitk::SegmentAnythingPythonService::TransferPointsToProcess(std::stringstream &triggerCSV) { CheckStatus(); std::string triggerFilePath = m_InDir + IOUtil::GetDirectorySeparator() + m_TRIGGER_FILENAME; std::ofstream csvfile; csvfile.open(triggerFilePath, std::ofstream::out | std::ofstream::trunc); csvfile << triggerCSV.rdbuf(); csvfile.close(); } void mitk::SegmentAnythingPythonService::WriteControlFile(std::stringstream &statusStream) { std::string controlFilePath = m_InDir + IOUtil::GetDirectorySeparator() + "control.txt"; std::ofstream controlFile; controlFile.open(controlFilePath, std::ofstream::out | std::ofstream::trunc); controlFile << statusStream.rdbuf(); controlFile.close(); } void mitk::SegmentAnythingPythonService::start_python_daemon() { ProcessExecutor::ArgumentListType args; std::string command = "python"; args.push_back("-u"); args.push_back("run_inference_daemon.py"); args.push_back("--input-folder"); args.push_back(m_InDir); args.push_back("--output-folder"); args.push_back(m_OutDir); args.push_back("--trigger-file"); args.push_back(m_TRIGGER_FILENAME); args.push_back("--model-type"); args.push_back(m_ModelType); args.push_back("--checkpoint"); args.push_back(m_CheckpointPath); try { std::string cudaEnv = "CUDA_VISIBLE_DEVICES=" + std::to_string(m_GpuId); itksys::SystemTools::PutEnv(cudaEnv.c_str()); std::stringstream logStream; for (const auto &arg : args) logStream << arg << " "; logStream << m_PythonPath; MITK_INFO << logStream.str(); m_DaemonExec->Execute(m_PythonPath, command, args); } catch (const mitk::Exception &e) { MITK_ERROR << e.GetDescription(); return; } MITK_INFO << "Python process ended."; } -void mitk::SegmentAnythingPythonService::CheckStatus() +bool mitk::SegmentAnythingPythonService::CheckStatus() { switch (CurrentStatus) { + case mitk::SegmentAnythingPythonService::Status::READY: + return true; case mitk::SegmentAnythingPythonService::Status::CUDAError: mitkThrow() << "Error: Cuda Out of Memory. Change your model type in Preferences and Activate SAM again."; case mitk::SegmentAnythingPythonService::Status::KILLED: mitkThrow() << "Error: Python process is already terminated. Cannot load requested segmentation."; default: - MITK_INFO << "READY"; + return false; } } void mitk::SegmentAnythingPythonService::CreateTempDirs(const std::string &dirPattern) { this->SetMitkTempDir(IOUtil::CreateTemporaryDirectory(dirPattern)); m_InDir = IOUtil::CreateTemporaryDirectory("sam-in-XXXXXX", m_MitkTempDir); m_OutDir = IOUtil::CreateTemporaryDirectory("sam-out-XXXXXX", m_MitkTempDir); } mitk::Image::Pointer mitk::SegmentAnythingPythonService::RetrieveImageFromProcess() { auto outputImagePath = m_OutDir + IOUtil::GetDirectorySeparator() + m_CurrentUId + ".nii.gz"; while (!std::filesystem::exists(outputImagePath)) { this->CheckStatus(); std::this_thread::sleep_for(100ms); } Image::Pointer outputImage = mitk::IOUtil::Load(outputImagePath); return outputImage; } void mitk::SegmentAnythingPythonService::TransferImageToProcess(const Image *inputAtTimeStep, std::string &UId) { std::string inputImagePath = m_InDir + IOUtil::GetDirectorySeparator() + UId + ".nii.gz"; IOUtil::Save(inputAtTimeStep, inputImagePath); m_CurrentUId = UId; } diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.h b/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.h index e8141cbb87..e4979d55a8 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.h +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingPythonService.h @@ -1,99 +1,98 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef mitkSegmentAnythingPythonService_h #define mitkSegmentAnythingPythonService_h #include #include #include #include #include "mitkImage.h" namespace mitk { /** CHANGE THIS --ashis \brief Segment Anything Model interactive 2D tool. \ingroup ToolManagerEtAl \sa mitk::Tool \sa QmitkInteractiveSegmentation */ class MITKSEGMENTATION_EXPORT SegmentAnythingPythonService : public itk::Object { public: enum Status { READY, - PROCESSING, + OFF, KILLED, CUDAError }; SegmentAnythingPythonService(std::string, std::string, std::string, unsigned int); ~SegmentAnythingPythonService(); itkSetMacro(MitkTempDir, std::string); itkGetConstMacro(MitkTempDir, std::string); /** * @brief Static function to print out everything from itk::EventObject. * Used as callback in mitk::ProcessExecutor object. * */ static void onPythonProcessEvent(itk::Object*, const itk::EventObject&, void*); bool static IsPythonReady; static Status CurrentStatus; - static void CheckStatus(); + static bool CheckStatus(); void StartAsyncProcess(); void StopAsyncProcess(); void TransferImageToProcess(const Image*, std::string &UId); //void TransferPointsToProcess(std::vector>&); void TransferPointsToProcess(std::stringstream&); Image::Pointer RetrieveImageFromProcess(); private: void start_python_daemon(); void WriteControlFile(std::stringstream&); void CreateTempDirs(const std::string&); std::string m_MitkTempDir; std::string m_PythonPath; std::string m_ModelType; std::string m_CheckpointPath; std::string m_InDir, m_OutDir; std::string m_CurrentUId; unsigned int m_GpuId = 0; bool m_IsAuto = false; bool m_IsReady = false; const std::string TEMPLATE_FILENAME = "XXXXXX_000_0000.nii.gz"; const std::string m_PARENT_TEMP_DIR_PATTERN = "mitk-sam-XXXXXX"; const std::string m_TRIGGER_FILENAME = "trigger.csv"; std::future m_Future; ProcessExecutor::Pointer m_DaemonExec; }; struct SIGNALCONSTANTS { static const std::string READY; static const std::string KILL; - static const std::string SUCCESS; - static const std::string PROCESSING; + static const std::string OFF; static const std::string CUDA_OUT_OF_MEMORY_ERROR; }; } // namespace #endif diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp index 350ff1bab0..7d2519ab1a 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp @@ -1,363 +1,363 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "mitkSegmentAnythingTool.h" #include "mitkInteractionPositionEvent.h" #include "mitkPointSetShapeProperty.h" #include "mitkProperties.h" #include "mitkToolManager.h" // us #include "mitkIOUtil.h" #include "mitkSegTool2D.h" #include #include #include #include #include #include #include #include using namespace std::chrono_literals; namespace mitk { MITK_TOOL_MACRO(MITKSEGMENTATION_EXPORT, SegmentAnythingTool, "SegmentAnythingTool"); } mitk::SegmentAnythingTool::SegmentAnythingTool() : SegWithPreviewTool(false, "PressMoveReleaseAndPointSetting") { this->ResetsToEmptyPreviewOn(); this->IsTimePointChangeAwareOff(); } const char **mitk::SegmentAnythingTool::GetXPM() const { return nullptr; } const char *mitk::SegmentAnythingTool::GetName() const { return "SAM"; } us::ModuleResource mitk::SegmentAnythingTool::GetIconResource() const { us::Module *module = us::GetModuleContext()->GetModule(); us::ModuleResource resource = module->GetResource("AI.svg"); return resource; } void mitk::SegmentAnythingTool::Activated() { Superclass::Activated(); m_PointSetPositive = mitk::PointSet::New(); // ensure that the seed points are visible for all timepoints. // dynamic_cast(m_PointSet->GetTimeGeometry())->SetStepDuration(std::numeric_limits::max()); m_PointSetNode = mitk::DataNode::New(); m_PointSetNode->SetData(m_PointSetPositive); m_PointSetNode->SetName(std::string(this->GetName()) + "_PointSetPositive"); m_PointSetNode->SetBoolProperty("helper object", true); m_PointSetNode->SetColor(0.0, 1.0, 0.0); m_PointSetNode->SetVisibility(true); m_PointSetNode->SetProperty("Pointset.2D.shape", mitk::PointSetShapeProperty::New(mitk::PointSetShapeProperty::CIRCLE)); m_PointSetNode->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_PointSetNode, this->GetToolManager()->GetWorkingData(0)); m_PointSetNegative = mitk::PointSet::New(); m_PointSetNodeNegative = mitk::DataNode::New(); m_PointSetNodeNegative->SetData(m_PointSetNegative); m_PointSetNodeNegative->SetName(std::string(this->GetName()) + "_PointSetNegative"); m_PointSetNodeNegative->SetBoolProperty("helper object", true); m_PointSetNodeNegative->SetColor(1.0, 0.0, 0.0); m_PointSetNodeNegative->SetVisibility(true); m_PointSetNodeNegative->SetProperty("Pointset.2D.shape", mitk::PointSetShapeProperty::New(mitk::PointSetShapeProperty::CIRCLE)); m_PointSetNodeNegative->SetProperty("Pointset.2D.fill shape", mitk::BoolProperty::New(true)); this->GetDataStorage()->Add(m_PointSetNodeNegative, this->GetToolManager()->GetWorkingData(0)); this->SetLabelTransferScope(LabelTransferScope::AllLabels); this->SetLabelTransferMode(LabelTransferMode::AddLabel); } void mitk::SegmentAnythingTool::Deactivated() { this->ClearSeeds(); GetDataStorage()->Remove(m_PointSetNode); GetDataStorage()->Remove(m_PointSetNodeNegative); m_PointSetNode = nullptr; m_PointSetNodeNegative = nullptr; m_PointSetPositive = nullptr; m_PointSetNegative = nullptr; m_PythonService.reset(); Superclass::Deactivated(); } void mitk::SegmentAnythingTool::ConnectActionsAndFunctions() { CONNECT_FUNCTION("ShiftSecondaryButtonPressed", OnAddNegativePoint); CONNECT_FUNCTION("ShiftPrimaryButtonPressed", OnAddPoint); CONNECT_FUNCTION("DeletePoint", OnDelete); } void mitk::SegmentAnythingTool::InitSAMPythonProcess() { if (nullptr != m_PythonService) { m_PythonService.reset(); } m_PythonService = std::make_unique( this->GetPythonPath(), this->GetModelType(), this->GetCheckpointPath(), this->GetGpuId()); m_PythonService->StartAsyncProcess(); } bool mitk::SegmentAnythingTool::IsPythonReady() const { - return mitk::SegmentAnythingPythonService::IsPythonReady; + return m_PythonService->CheckStatus(); } void mitk::SegmentAnythingTool::OnAddNegativePoint(StateMachineAction *, InteractionEvent *interactionEvent) { if (!this->GetIsReady() || m_PointSetPositive->GetSize() == 0) { return; } if (!this->IsUpdating() && m_PointSetNegative.IsNotNull()) { const auto positionEvent = dynamic_cast(interactionEvent); if (positionEvent != nullptr) { m_PointSetNegative->InsertPoint(m_PointSetCount, positionEvent->GetPositionInWorld()); m_PointSetCount++; this->UpdatePreview(); } } } void mitk::SegmentAnythingTool::OnAddPoint(StateMachineAction*, InteractionEvent* interactionEvent) { if (!this->GetIsReady()) { return; } m_IsGenerateEmbeddings = false; if ((nullptr == this->GetWorkingPlaneGeometry()) || !mitk::Equal(*(interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry()), *(this->GetWorkingPlaneGeometry()))) { m_IsGenerateEmbeddings = true; this->ClearSeeds(); this->SetWorkingPlaneGeometry(interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry()->Clone()); } if (!this->IsUpdating() && m_PointSetPositive.IsNotNull()) { const auto positionEvent = dynamic_cast(interactionEvent); if (positionEvent != nullptr) { m_PointSetPositive->InsertPoint(m_PointSetCount, positionEvent->GetPositionInWorld()); m_PointSetCount++; this->UpdatePreview(); } } } void mitk::SegmentAnythingTool::OnDelete(StateMachineAction*, InteractionEvent*) { if (!this->IsUpdating() && m_PointSetPositive.IsNotNull()) { PointSet::Pointer removeSet = m_PointSetPositive; decltype(m_PointSetPositive->GetMaxId().Index()) maxId = 0; if (m_PointSetPositive->GetSize() > 0) { maxId = m_PointSetPositive->GetMaxId().Index(); } if (m_PointSetNegative->GetSize() > 0 && (maxId < m_PointSetNegative->GetMaxId().Index())) { removeSet = m_PointSetNegative; } removeSet->RemovePointAtEnd(0); --m_PointSetCount; this->UpdatePreview(); } } void mitk::SegmentAnythingTool::ClearPicks() { this->ClearSeeds(); this->UpdatePreview(); } bool mitk::SegmentAnythingTool::HasPicks() const { return this->m_PointSetPositive.IsNotNull() && this->m_PointSetPositive->GetSize() > 0; } void mitk::SegmentAnythingTool::ClearSeeds() { if (this->m_PointSetPositive.IsNotNull()) { m_PointSetCount -= m_PointSetPositive->GetSize(); this->m_PointSetPositive = mitk::PointSet::New(); // renew pointset //ensure that the seed points are visible for all timepoints. dynamic_cast(m_PointSetPositive->GetTimeGeometry())->SetStepDuration(std::numeric_limits::max()); this->m_PointSetNode->SetData(this->m_PointSetPositive); } if (this->m_PointSetNegative.IsNotNull()) { m_PointSetCount -= m_PointSetNegative->GetSize(); this->m_PointSetNegative = mitk::PointSet::New(); // renew pointset // ensure that the seed points are visible for all timepoints. dynamic_cast(m_PointSetNegative->GetTimeGeometry())->SetStepDuration(std::numeric_limits::max()); this->m_PointSetNodeNegative->SetData(this->m_PointSetNegative); } } void mitk::SegmentAnythingTool::DoUpdatePreview(const Image *inputAtTimeStep, const Image *oldSegAtTimeStep, LabelSetImage *previewImage, TimeStepType timeStep) { if (nullptr != oldSegAtTimeStep && nullptr != previewImage && m_PointSetPositive.IsNotNull()) { if (this->HasPicks() && nullptr != m_PythonService) { std::string uniquePlaneID = GetHashForCurrentPlane(); try { this->EmitSAMStatusMessageEvent("Prompting SAM..."); m_PythonService->TransferImageToProcess(inputAtTimeStep, uniquePlaneID); auto csvStream = this->GetPointsAsCSVString(inputAtTimeStep->GetGeometry()); m_ProgressCommand->SetProgress(100); m_PythonService->TransferPointsToProcess(csvStream); m_ProgressCommand->SetProgress(150); std::this_thread::sleep_for(100ms); Image::Pointer outputImage = m_PythonService->RetrieveImageFromProcess(); m_ProgressCommand->SetProgress(180); // auto endloading = std::chrono::system_clock::now(); // MITK_INFO << "Loaded image in MITK. Elapsed: " // << std::chrono::duration_cast(endloading- endPython).count(); // mitk::SegTool2D::WriteSliceToVolume(previewImage, this->GetWorkingPlaneGeometry(), outputImage, timeStep, // true); previewImage->InitializeByLabeledImage(outputImage); previewImage->SetGeometry(this->GetWorkingPlaneGeometry()->Clone()); std::filesystem::remove(outputImagePath); this->EmitSAMStatusMessageEvent("Successfully generated segmentation."); } catch (const mitk::Exception &e) { this->EmitSAMStatusMessageEvent(e.GetDescription()); mitkThrow() << e.GetDescription(); } } else if (nullptr != this->GetWorkingPlaneGeometry()) { previewImage->SetGeometry(this->GetWorkingPlaneGeometry()->Clone()); this->ResetPreviewContentAtTimeStep(timeStep); } } } std::string mitk::SegmentAnythingTool::GetHashForCurrentPlane() { mitk::Vector3D normal = this->GetWorkingPlaneGeometry()->GetNormal(); std::stringstream hashstream; hashstream << normal[0] << normal[1] << normal[2]; mitk::Point3D point = m_PointSetPositive->GetPoint(0); for (int i = 0; i < 3; ++i) { if (normal[i] != 0) { hashstream << point[i]; } } size_t hashVal = std::hash{}(hashstream.str()); return std::to_string(hashVal); } std::stringstream mitk::SegmentAnythingTool::GetPointsAsCSVString(const mitk::BaseGeometry *baseGeometry) { MITK_INFO << "No.of points: " << m_PointSetPositive->GetSize(); std::stringstream pointsAndLabels; pointsAndLabels << "Point,Label\n"; mitk::PointSet::PointsIterator pointSetItPos = m_PointSetPositive->Begin(); mitk::PointSet::PointsIterator pointSetItNeg = m_PointSetNegative->Begin(); const char SPACE = ' '; while (pointSetItPos != m_PointSetPositive->End() || pointSetItNeg != m_PointSetNegative->End()) { if (pointSetItPos != m_PointSetPositive->End()) { mitk::Point3D point = pointSetItPos.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); pointsAndLabels << (int)p2D[0] << SPACE << (int)p2D[1] << ",1" << std::endl; } ++pointSetItPos; } if (pointSetItNeg != m_PointSetNegative->End()) { mitk::Point3D point = pointSetItNeg.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); pointsAndLabels << (int)p2D[0] << SPACE << (int)p2D[1] << ",0" << std::endl; } ++pointSetItNeg; } } return pointsAndLabels; } std::vector> mitk::SegmentAnythingTool::GetPointsAsVector(const mitk::BaseGeometry *baseGeometry) { std::vector> clickVec; clickVec.reserve(m_PointSetPositive->GetSize() + m_PointSetNegative->GetSize()); mitk::PointSet::PointsIterator pointSetItPos = m_PointSetPositive->Begin(); mitk::PointSet::PointsIterator pointSetItNeg = m_PointSetNegative->Begin(); while (pointSetItPos != m_PointSetPositive->End() || pointSetItNeg != m_PointSetNegative->End()) { if (pointSetItPos != m_PointSetPositive->End()) { mitk::Point3D point = pointSetItPos.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); clickVec.push_back(std::pair(p2D, "1")); } ++pointSetItPos; } if (pointSetItNeg != m_PointSetNegative->End()) { mitk::Point3D point = pointSetItNeg.Value(); if (baseGeometry->IsInside(point)) { Point2D p2D = Get2DIndicesfrom3DWorld(baseGeometry, point); clickVec.push_back(std::pair(p2D, "0")); } ++pointSetItNeg; } } return clickVec; } mitk::Point2D mitk::SegmentAnythingTool::Get2DIndicesfrom3DWorld(const mitk::BaseGeometry* baseGeometry, mitk::Point3D& point3d) { baseGeometry->WorldToIndex(point3d, point3d); MITK_INFO << point3d[0] << " " << point3d[1] << " " << point3d[2]; // remove Point2D point2D; point2D.SetElement(0, point3d[0]); point2D.SetElement(1, point3d[1]); return point2D; } void mitk::SegmentAnythingTool::EmitSAMStatusMessageEvent(const std::string status) { SAMStatusMessageEvent.Send(status); } diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h index 14b47d4a78..1e6f65692e 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h @@ -1,146 +1,145 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef mitkSegmentAnythingTool_h #define mitkSegmentAnythingTool_h #include "mitkSegWithPreviewTool.h" #include "mitkPointSet.h" #include "mitkProcessExecutor.h" #include "mitkSegmentAnythingPythonService.h" #include #include #include namespace us { class ModuleResource; } namespace mitk { /** CHANGE THIS --ashis \brief Segment Anything Model interactive 2D tool. \ingroup ToolManagerEtAl \sa mitk::Tool \sa QmitkInteractiveSegmentation */ class MITKSEGMENTATION_EXPORT SegmentAnythingTool : public SegWithPreviewTool { public: mitkClassMacro(SegmentAnythingTool, SegWithPreviewTool); itkFactorylessNewMacro(Self); itkCloneMacro(Self); const char **GetXPM() const override; const char *GetName() const override; us::ModuleResource GetIconResource() const override; void Activated() override; void Deactivated() override; /** * Clears all picks and updates the preview. */ void ClearPicks(); bool HasPicks() const; itkSetMacro(MitkTempDir, std::string); itkGetConstMacro(MitkTempDir, std::string); itkSetMacro(PythonPath, std::string); itkGetConstMacro(PythonPath, std::string); itkSetMacro(ModelType, std::string); itkGetConstMacro(ModelType, std::string); itkSetMacro(CheckpointPath, std::string); itkGetConstMacro(CheckpointPath, std::string); itkSetMacro(GpuId, unsigned int); itkGetConstMacro(GpuId, unsigned int); itkSetMacro(IsAuto, bool); itkGetConstMacro(IsAuto, bool); itkBooleanMacro(IsAuto); itkSetMacro(IsReady, bool); itkGetConstMacro(IsReady, bool); itkBooleanMacro(IsReady); void InitSAMPythonProcess(); bool IsPythonReady() const; Message1 SAMStatusMessageEvent; protected: SegmentAnythingTool(); ~SegmentAnythingTool() = default; void ConnectActionsAndFunctions() override; /* * @brief Add point action of StateMachine pattern */ virtual void OnAddPoint(StateMachineAction*, InteractionEvent* interactionEvent); virtual void OnAddNegativePoint(StateMachineAction*, InteractionEvent* interactionEvent); /* * @brief Delete action of StateMachine pattern */ virtual void OnDelete(StateMachineAction*, InteractionEvent*); /* * @brief Clear all seed points and call UpdatePreview to reset the segmentation Preview */ void ClearSeeds(); void DoUpdatePreview(const Image* inputAtTimeStep, const Image* oldSegAtTimeStep, LabelSetImage* previewImage, TimeStepType timeStep) override; std::vector> GetPointsAsVector(const mitk::BaseGeometry*); std::stringstream GetPointsAsCSVString(const mitk::BaseGeometry *baseGeometry); std::string GetHashForCurrentPlane(); void EmitSAMStatusMessageEvent(const std::string); - private: static mitk::Point2D Get2DIndicesfrom3DWorld(const mitk::BaseGeometry*, mitk::Point3D&); //void WriteCSVFile(std::stringstream&); // move to python service //void CreateTempDirs(const std::string& dirPattern); std::string m_MitkTempDir; std::string m_PythonPath; std::string m_ModelType; std::string m_CheckpointPath; std::string m_InDir, m_OutDir, outputImagePath, m_TriggerCSVPath; // rename as per standards unsigned int m_GpuId = 0; PointSet::Pointer m_PointSetPositive; PointSet::Pointer m_PointSetNegative; DataNode::Pointer m_PointSetNode; DataNode::Pointer m_PointSetNodeNegative; bool m_IsGenerateEmbeddings = true; bool m_IsAuto = false; bool m_IsReady = false; const std::string TEMPLATE_FILENAME = "XXXXXX_000_0000.nii.gz"; const std::string m_PARENT_TEMP_DIR_PATTERN = "mitk-sam-XXXXXX"; const std::string m_TRIGGER_FILENAME = "trigger.csv"; int m_PointSetCount = 0; std::unique_ptr m_PythonService; }; } // namespace #endif diff --git a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp index eb63c0e47b..c59bb4f404 100644 --- a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp +++ b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.cpp @@ -1,319 +1,343 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "QmitkSegmentAnythingToolGUI.h" #include "mitkSegmentAnythingTool.h" #include "mitkProcessExecutor.h" #include #include #include #include #include #include #include #include #include MITK_TOOL_GUI_MACRO(MITKSEGMENTATIONUI_EXPORT, QmitkSegmentAnythingToolGUI, "") namespace { mitk::IPreferences *GetPreferences() { auto *preferencesService = mitk::CoreServices::GetPreferencesService(); return preferencesService->GetSystemPreferences()->Node("org.mitk.views.segmentation"); } } QmitkSegmentAnythingToolGUI::QmitkSegmentAnythingToolGUI() : QmitkSegWithPreviewToolGUIBase(true) { // Nvidia-smi command returning zero doesn't always imply lack of GPUs. // Pytorch uses its own libraries to communicate to the GPUs. Hence, only a warning can be given. if (m_GpuLoader.GetGPUCount() == 0) { std::string warning = "WARNING: No GPUs were detected on your machine. The SegmentAnything tool can be very slow."; this->ShowErrorMessage(warning); } m_EnableConfirmSegBtnFnc = [this](bool enabled) { bool result = false; auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { result = enabled && tool->HasPicks(); } return result; }; m_Prefences = GetPreferences(); } QmitkSegmentAnythingToolGUI::~QmitkSegmentAnythingToolGUI() { auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->SAMStatusMessageEvent -= mitk::MessageDelegate1( this, &QmitkSegmentAnythingToolGUI::StatusMessageListener); } } void QmitkSegmentAnythingToolGUI::InitializeUI(QBoxLayout *mainLayout) { m_Controls.setupUi(this); m_Controls.statusLabel->setTextFormat(Qt::RichText); QString welcomeText; if (m_GpuLoader.GetGPUCount() != 0) { welcomeText = "STATUS: Welcome to Segment Anything tool. You're in luck: " + QString::number(m_GpuLoader.GetGPUCount()) + " GPU(s) were detected."; } else { welcomeText = "STATUS: Welcome to Segment Anything tool. Sorry, " + QString::number(m_GpuLoader.GetGPUCount()) + " GPUs were detected."; } connect(m_Controls.activateButton, SIGNAL(clicked()), this, SLOT(OnActivateBtnClicked())); connect(m_Controls.resetButton, SIGNAL(clicked()), this, SLOT(OnResetPicksClicked())); QIcon arrowIcon = QmitkStyleManager::ThemeIcon( QStringLiteral(":/org_mitk_icons/icons/tango/scalable/actions/media-playback-start.svg")); m_Controls.activateButton->setIcon(arrowIcon); bool isInstalled = this->ValidatePrefences(); if (isInstalled) { m_PythonPath = QString::fromStdString(m_Prefences->Get("sam python path", "")); QString modelType = QString::fromStdString(m_Prefences->Get("sam modeltype", "")); welcomeText += " SAM is already found installed. Model type '" + modelType + "' selected in Preferences."; } else { welcomeText += " SAM tool is not configured correctly. Please go to Preferences (Cntl+P) > Segment Anything to configure and/or install SAM."; } this->EnableAll(isInstalled); this->WriteStatusMessage(welcomeText); this->ShowProgressBar(false); m_Controls.samProgressBar->setMaximum(0); mainLayout->addLayout(m_Controls.verticalLayout); Superclass::InitializeUI(mainLayout); } bool QmitkSegmentAnythingToolGUI::ValidatePrefences() { const QString storageDir = QString::fromStdString(m_Prefences->Get("sam python path", "")); bool isInstalled = QmitkSegmentAnythingToolGUI::IsSAMInstalled(storageDir); std::string modelType = m_Prefences->Get("sam modeltype", ""); std::string path = m_Prefences->Get("sam parent path", ""); return (isInstalled && !modelType.empty() && !path.empty()); } void QmitkSegmentAnythingToolGUI::EnableAll(bool isEnable) { m_Controls.activateButton->setEnabled(isEnable); } void QmitkSegmentAnythingToolGUI::WriteStatusMessage(const QString &message) { m_Controls.statusLabel->setText(message); m_Controls.statusLabel->setStyleSheet("font-weight: bold; color: white"); qApp->processEvents(); } void QmitkSegmentAnythingToolGUI::WriteErrorMessage(const QString &message) { m_Controls.statusLabel->setText(message); m_Controls.statusLabel->setStyleSheet("font-weight: bold; color: red"); qApp->processEvents(); } void QmitkSegmentAnythingToolGUI::ShowErrorMessage(const std::string &message, QMessageBox::Icon icon) { this->setCursor(Qt::ArrowCursor); QMessageBox *messageBox = new QMessageBox(icon, nullptr, message.c_str()); messageBox->exec(); delete messageBox; MITK_WARN << message; } void QmitkSegmentAnythingToolGUI::StatusMessageListener(const std::string &message) { this->WriteStatusMessage(QString::fromStdString(message)); if (message.rfind("Error", 0) == 0) { this->EnableAll(true); } } void QmitkSegmentAnythingToolGUI::OnActivateBtnClicked() { auto tool = this->GetConnectedToolAs(); if (nullptr == tool) { return; } try { this->EnableAll(false); qApp->processEvents(); if (!QmitkSegmentAnythingToolGUI::IsSAMInstalled(m_PythonPath)) { throw std::runtime_error(WARNING_SAM_NOT_FOUND); } tool->SetIsAuto(false); tool->SetPythonPath(m_PythonPath.toStdString()); tool->SetGpuId(m_Prefences->GetInt("sam gpuid", 0)); const QString modelType = QString::fromStdString(m_Prefences->Get("sam modeltype", "")); tool->SetModelType(modelType.toStdString()); this->WriteStatusMessage( QString("STATUS: Checking if model is already downloaded... This might take a while.")); tool->SAMStatusMessageEvent += mitk::MessageDelegate1( this, &QmitkSegmentAnythingToolGUI::StatusMessageListener); if (this->DownloadModel(modelType)) { - this->ActivateSAMDaemon(); - this->WriteStatusMessage(QString("STATUS: Model found. SAM Activated.")); + this->WriteStatusMessage(QString("STATUS: Model found. Activating Segment Anything tool...")); + if (this->ActivateSAMDaemon()) + { + this->WriteStatusMessage(QString("STATUS: Model found. SAM Activated.")); + } + else + { + this->WriteErrorMessage(QString("STATUS: Model found. Couldn't init tool backend.")); + this->EnableAll(true); + } } else { tool->IsReadyOff(); this->WriteStatusMessage(QString("STATUS: Model not found. Starting download...")); } } catch (const std::exception &e) { std::stringstream errorMsg; errorMsg << "STATUS: Error while processing parameters for SAM segmentation. Reason: " << e.what(); this->ShowErrorMessage(errorMsg.str()); this->WriteErrorMessage(QString::fromStdString(errorMsg.str())); m_Controls.activateButton->setEnabled(true); return; } catch (...) { std::string errorMsg = "Unkown error occured while generation SAM segmentation."; this->ShowErrorMessage(errorMsg); m_Controls.activateButton->setEnabled(true); return; } } -void QmitkSegmentAnythingToolGUI::ActivateSAMDaemon() +bool QmitkSegmentAnythingToolGUI::ActivateSAMDaemon() { auto tool = this->GetConnectedToolAs(); if (nullptr == tool) { - return; + return false; } this->ShowProgressBar(true); - tool->InitSAMPythonProcess(); - while (!tool->IsPythonReady()) + qApp->processEvents(); + try { - qApp->processEvents(); + tool->InitSAMPythonProcess(); + while (!tool->IsPythonReady()) + { + qApp->processEvents(); + } + tool->IsReadyOn(); + } + catch (...) + { + tool->IsReadyOff(); } - tool->IsReadyOn(); this->ShowProgressBar(false); + return tool->GetIsReady(); } bool QmitkSegmentAnythingToolGUI::DownloadModel(const QString &modelType) { QUrl url = QmitkSegmentAnythingToolGUI::VALID_MODELS_URL_MAP[modelType]; QString modelFileName = url.fileName(); const QString storageDir = QString::fromStdString(m_Prefences->Get("sam parent path", "")); QString checkPointPath = storageDir + QDir::separator() + modelFileName; if (QFile::exists(checkPointPath)) { auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->SetCheckpointPath(checkPointPath.toStdString()); } return true; } connect(&m_Manager, SIGNAL(finished(QNetworkReply*)), this, SLOT(FileDownloaded(QNetworkReply*))); QNetworkRequest request(url); m_Manager.get(request); this->ShowProgressBar(true); return false; } void QmitkSegmentAnythingToolGUI::FileDownloaded(QNetworkReply *reply) { const QString storageDir = QString::fromStdString(m_Prefences->Get("sam parent path", "")); const QString &modelFileName = reply->url().fileName(); QFile file(storageDir + QDir::separator() + modelFileName); if (file.open(QIODevice::WriteOnly)) { file.write(reply->readAll()); file.close(); disconnect(&m_Manager, SIGNAL(finished(QNetworkReply *)), this, SLOT(FileDownloaded(QNetworkReply *))); this->WriteStatusMessage(QString("STATUS: Model successfully downloaded. Activating SAM....")); auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->SetCheckpointPath(file.fileName().toStdString()); - this->ActivateSAMDaemon(); - this->WriteStatusMessage(QString("STATUS: Model successfully downloaded. SAM Activated.")); + if (this->ActivateSAMDaemon()) + { + this->WriteStatusMessage(QString("STATUS: Model successfully downloaded. Segment Anything activated.")); + } + else + { + this->WriteErrorMessage(QString("STATUS: Model successfully downloaded. But couldn't init tool backend.")); + this->EnableAll(true); + } } } else { this->WriteErrorMessage("STATUS: Model couldn't be downloaded. SAM not activated."); } this->EnableAll(true); this->ShowProgressBar(false); } void QmitkSegmentAnythingToolGUI::ShowProgressBar(bool enabled) { m_Controls.samProgressBar->setEnabled(enabled); m_Controls.samProgressBar->setVisible(enabled); } bool QmitkSegmentAnythingToolGUI::IsSAMInstalled(const QString &pythonPath) { QString fullPath = pythonPath; bool isPythonExists = false; bool isSamExists = false; #ifdef _WIN32 isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python.exe")); if (!(fullPath.endsWith("Scripts", Qt::CaseInsensitive) || fullPath.endsWith("Scripts/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("Scripts"); isPythonExists = (!isPythonExists) ? QFile::exists(fullPath + QDir::separator() + QString("python.exe")) : isPythonExists; } #else isPythonExists = QFile::exists(fullPath + QDir::separator() + QString("python3")); if (!(fullPath.endsWith("bin", Qt::CaseInsensitive) || fullPath.endsWith("bin/", Qt::CaseInsensitive))) { fullPath += QDir::separator() + QString("bin"); isPythonExists = (!isPythonExists) ? QFile::exists(fullPath + QDir::separator() + QString("python3")) : isPythonExists; } #endif isSamExists = QFile::exists(fullPath + QDir::separator() + QString("run_inference_daemon.py")); bool isExists = isSamExists && isPythonExists; return isExists; } void QmitkSegmentAnythingToolGUI::OnResetPicksClicked() { auto tool = this->GetConnectedToolAs(); if (nullptr != tool) { tool->ClearPicks(); } } diff --git a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h index 513821eeb8..e36f95aa6c 100644 --- a/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h +++ b/Modules/SegmentationUI/Qmitk/QmitkSegmentAnythingToolGUI.h @@ -1,122 +1,122 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef QmitkSegmentAnythingToolGUI_h #define QmitkSegmentAnythingToolGUI_h #include "QmitkSegWithPreviewToolGUIBase.h" #include #include "ui_QmitkSegmentAnythingGUIControls.h" #include "QmitknnUNetGPU.h" #include "QmitkSetupVirtualEnvUtil.h" #include #include #include #include #include /** \ingroup org_mitk_gui_qt_interactivesegmentation_internal \brief GUI for mitk::SegmentAnythingTool. \sa mitk::PickingTool */ class MITKSEGMENTATIONUI_EXPORT QmitkSegmentAnythingToolGUI : public QmitkSegWithPreviewToolGUIBase { Q_OBJECT public: mitkClassMacro(QmitkSegmentAnythingToolGUI, QmitkSegWithPreviewToolGUIBase); itkFactorylessNewMacro(Self); itkCloneMacro(Self); inline static const QMap VALID_MODELS_URL_MAP = { {"vit_b", QUrl("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth")}, //{"vit_b", QUrl("http://www.google.ru/images/srpr/logo3w.png")}, {"vit_l", QUrl("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_l_0b3195.pth")}, {"vit_h", QUrl("https://dl.fbaipublicfiles.com/segment_anything/sam_vit_h_4b8939.pth")}}; /** * @brief Checks if SegmentAnything is found inside the selected python virtual environment. * @return bool */ static bool IsSAMInstalled(const QString &); protected slots: /** * @brief Qt Slot */ void OnResetPicksClicked(); /** * @brief Qt Slot */ void OnActivateBtnClicked(); /** * @brief Qt Slot */ void FileDownloaded(QNetworkReply *); protected: QmitkSegmentAnythingToolGUI(); ~QmitkSegmentAnythingToolGUI(); void InitializeUI(QBoxLayout *mainLayout) override; /** * @brief Writes any message in white on the tool pane. */ void WriteStatusMessage(const QString &); /** * @brief Writes any message in red on the tool pane. */ void WriteErrorMessage(const QString &); void StatusMessageListener(const std::string&); /** * @brief Creates a QMessage object and shows on screen. */ void ShowErrorMessage(const std::string &, QMessageBox::Icon = QMessageBox::Critical); /** * @brief Enable (or Disable) GUI elements. */ void EnableAll(bool); /** * */ bool DownloadModel(const QString &); /** * */ void ShowProgressBar(bool); - void ActivateSAMDaemon(); + bool ActivateSAMDaemon(); bool ValidatePrefences(); private: mitk::IPreferences *m_Prefences; QNetworkAccessManager m_Manager; Ui_QmitkSegmentAnythingGUIControls m_Controls; QString m_PythonPath; QmitkGPULoader m_GpuLoader; bool m_FirstPreviewComputation = true; const std::string WARNING_SAM_NOT_FOUND = "SAM is not detected in the selected python environment. Please reinstall SAM."; }; #endif