diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp index b97dec1146..21b95bcd85 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.cpp @@ -1,251 +1,325 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #include "mitkSegmentAnythingTool.h" #include "mitkProperties.h" #include "mitkToolManager.h" #include "mitkInteractionPositionEvent.h" // us #include #include #include #include #include "mitkIOUtil.h" #include #include + namespace mitk { MITK_TOOL_MACRO(MITKSEGMENTATION_EXPORT, SegmentAnythingTool, "SegmentAnythingTool"); } mitk::SegmentAnythingTool::SegmentAnythingTool() : SegWithPreviewTool(false, "PressMoveReleaseAndPointSetting") { this->ResetsToEmptyPreviewOn(); this->IsTimePointChangeAwareOff(); } mitk::SegmentAnythingTool::~SegmentAnythingTool() { std::filesystem::remove_all(this->GetMitkTempDir()); } const char **mitk::SegmentAnythingTool::GetXPM() const { return nullptr; } const char *mitk::SegmentAnythingTool::GetName() const { return "SAM"; } us::ModuleResource mitk::SegmentAnythingTool::GetIconResource() const { us::Module *module = us::GetModuleContext()->GetModule(); us::ModuleResource resource = module->GetResource("AI.svg"); return resource; } void mitk::SegmentAnythingTool::Activated() { Superclass::Activated(); m_PointSet = mitk::PointSet::New(); //ensure that the seed points are visible for all timepoints. dynamic_cast(m_PointSet->GetTimeGeometry())->SetStepDuration(std::numeric_limits::max()); m_PointSetNode = mitk::DataNode::New(); m_PointSetNode->SetData(m_PointSet); m_PointSetNode->SetName(std::string(this->GetName()) + "_PointSet"); m_PointSetNode->SetBoolProperty("helper object", true); m_PointSetNode->SetColor(0.0, 1.0, 0.0); m_PointSetNode->SetVisibility(true); this->GetDataStorage()->Add(m_PointSetNode, this->GetToolManager()->GetWorkingData(0)); this->SetLabelTransferMode(LabelTransferMode::AllLabels); } void mitk::SegmentAnythingTool::Deactivated() { this->ClearSeeds(); // remove from data storage and disable interaction GetDataStorage()->Remove(m_PointSetNode); m_PointSetNode = nullptr; m_PointSet = nullptr; Superclass::Deactivated(); } void mitk::SegmentAnythingTool::ConnectActionsAndFunctions() { CONNECT_FUNCTION("ShiftSecondaryButtonPressed", OnAddPoint); CONNECT_FUNCTION("ShiftPrimaryButtonPressed", OnAddPoint); CONNECT_FUNCTION("DeletePoint", OnDelete); } void mitk::SegmentAnythingTool::OnAddPoint(StateMachineAction*, InteractionEvent* interactionEvent) { - this->SetWorkingPlaneGeometry(const_cast( - interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry())); // try to set without const cast - MITK_INFO << "Set W PlaneG"; + m_IsGenerateEmbeddings = false; + if ((nullptr == this->GetWorkingPlaneGeometry()) || !mitk::Equal(*(interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry()), *(this->GetWorkingPlaneGeometry()))) + { + m_IsGenerateEmbeddings = true; + this->SetWorkingPlaneGeometry(interactionEvent->GetSender()->GetCurrentWorldPlaneGeometry()->Clone()); + } if (!this->IsUpdating() && m_PointSet.IsNotNull()) { const auto positionEvent = dynamic_cast(interactionEvent); if (positionEvent != nullptr) { m_PointSet->InsertPoint(m_PointSet->GetSize(), positionEvent->GetPositionInWorld()); this->UpdatePreview(); } } } void mitk::SegmentAnythingTool::OnDelete(StateMachineAction*, InteractionEvent* /*interactionEvent*/) { if (!this->IsUpdating() && m_PointSet.IsNotNull()) { - // delete last seed point if (this->m_PointSet->GetSize() > 0) { m_PointSet->RemovePointAtEnd(0); - this->UpdatePreview(); } } } void mitk::SegmentAnythingTool::ClearPicks() { this->ClearSeeds(); this->UpdatePreview(); } bool mitk::SegmentAnythingTool::HasPicks() const { return this->m_PointSet.IsNotNull() && this->m_PointSet->GetSize()>0; } void mitk::SegmentAnythingTool::ClearSeeds() { if (this->m_PointSet.IsNotNull()) { - // renew pointset - this->m_PointSet = mitk::PointSet::New(); + this->m_PointSet = mitk::PointSet::New(); // renew pointset //ensure that the seed points are visible for all timepoints. dynamic_cast(m_PointSet->GetTimeGeometry())->SetStepDuration(std::numeric_limits::max()); this->m_PointSetNode->SetData(this->m_PointSet); } } void mitk::SegmentAnythingTool::onPythonProcessEvent(itk::Object * /*pCaller*/, const itk::EventObject &e, void *) { std::string testCOUT; std::string testCERR; const auto *pEvent = dynamic_cast(&e); if (pEvent) { testCOUT = testCOUT + pEvent->GetOutput(); MITK_INFO << testCOUT; } const auto *pErrEvent = dynamic_cast(&e); if (pErrEvent) { testCERR = testCERR + pErrEvent->GetOutput(); MITK_ERROR << testCERR; } } void mitk::SegmentAnythingTool::DoUpdatePreview(const Image* inputAtTimeStep, const Image* oldSegAtTimeStep, LabelSetImage* previewImage, TimeStepType timeStep) { if (nullptr != oldSegAtTimeStep && nullptr != previewImage && m_PointSet.IsNotNull()) { if (this->m_MitkTempDir.empty()) { this->SetMitkTempDir(IOUtil::CreateTemporaryDirectory("mitk-XXXXXX")); } if (this->HasPicks()) { ProcessExecutor::Pointer spExec = ProcessExecutor::New(); itk::CStyleCommand::Pointer spCommand = itk::CStyleCommand::New(); spCommand->SetCallback(&onPythonProcessEvent); spExec->AddObserver(ExternalProcessOutputEvent(), spCommand); - std::string inDir, outDir, inputImagePath, outputImagePath, scriptPath; + std::string inDir, outDir, inputImagePath, pickleFilePath, outputImagePath, scriptPath; inDir = IOUtil::CreateTemporaryDirectory("sam-in-XXXXXX", this->GetMitkTempDir()); std::ofstream tmpStream; - inputImagePath = - IOUtil::CreateTemporaryFile(tmpStream, TEMPLATE_FILENAME, inDir + IOUtil::GetDirectorySeparator()); + inputImagePath = IOUtil::CreateTemporaryFile(tmpStream, TEMPLATE_FILENAME, inDir + IOUtil::GetDirectorySeparator()); tmpStream.close(); std::size_t found = inputImagePath.find_last_of(IOUtil::GetDirectorySeparator()); std::string fileName = inputImagePath.substr(found + 1); std::string token = fileName.substr(0, fileName.find("_")); outDir = IOUtil::CreateTemporaryDirectory("sam-out-XXXXXX", this->GetMitkTempDir()); + pickleFilePath = outDir + IOUtil::GetDirectorySeparator() + "dump.pkl"; + outputImagePath = outDir + IOUtil::GetDirectorySeparator() + token + "_000.nii.gz"; + LabelSetImage::Pointer outputBuffer; //IOUtil::Save(inputAtTimeStep, inputImagePath); - - outputImagePath = outDir + IOUtil::GetDirectorySeparator() + token + "_000.nii.gz"; - //auto pg = this->GetWorkingPlaneGeometry(); - //this->run_generate_embeddings(spExec, inputImagePath, outDir, this->GetGpuId()); + MITK_INFO << "No.of points: " << m_PointSet->GetSize(); + auto point = m_PointSet->GetPoint(0); + MITK_INFO << point[0] << " " << point[1] << " " << point[2]; + Point2D p2D; + p2D.SetElement(0, point[0]); + p2D.SetElement(1, point[1]); + + + this->GetWorkingPlaneGeometry()->WorldToIndex(p2D, p2D); + this->SetPythonPath("C:\\DKFZ\\SAM_work\\sam_env\\Scripts"); + this->SetModelType("vit_b"); + this->SetCheckpointPath("C:\\DKFZ\\SAM_work\\sam_vit_b_01ec64.pth"); + + if (false)//m_IsGenerateEmbeddings) + { + this->run_generate_embeddings( + spExec, inputImagePath, outDir, this->GetModelType(), this->GetCheckpointPath(), this->GetGpuId()); + } + + //run_segmentation_from_points( + // spExec, pickleFilePath, outputImagePath, this->GetModelType(), this->GetCheckpointPath(), this->GetGpuId()); outputImagePath = "C:\\DKFZ\\SAM_work\\test_seg_3d.nii.gz"; Image::Pointer outputImage = IOUtil::Load(outputImagePath); previewImage->InitializeByLabeledImage(outputImage); - //const_cast(this->GetWorkingPlaneGeometry())->ChangeImageGeometryConsideringOriginOffset(true); - previewImage->SetGeometry(const_cast (this->GetWorkingPlaneGeometry())); + previewImage->SetGeometry(this->GetWorkingPlaneGeometry()->Clone()); } } } -void mitk::SegmentAnythingTool::run_generate_embeddings(ProcessExecutor::Pointer spExec, +void mitk::SegmentAnythingTool::run_generate_embeddings(ProcessExecutor* spExec, const std::string& inputImagePath, - const std::string& outputImagePath, - unsigned int gpuId) + const std::string& outputPicklePath, + const std::string& modelType, + const std::string& checkpointPath, + const unsigned int gpuId) { + ProcessExecutor::ArgumentListType args; - std::string command = "TotalSegmentator"; -#if defined(__APPLE__) || defined(_WIN32) - command = "python"; -#endif - args.push_back("-i"); + std::string command = "python"; + + args.push_back("C:\\DKFZ\\SAM_work\\sam-playground\\endpoints\\generate_embedding.py"); + + args.push_back("--input"); args.push_back(inputImagePath); - args.push_back("-o"); + args.push_back("--output"); + args.push_back(outputPicklePath); + + args.push_back("--model-type"); + args.push_back(modelType); + + args.push_back("--checkpoint"); + args.push_back(checkpointPath); + + try + { + std::string cudaEnv = "CUDA_VISIBLE_DEVICES=" + std::to_string(gpuId); + itksys::SystemTools::PutEnv(cudaEnv.c_str()); + + std::stringstream logStream; + for (const auto &arg : args) + logStream << arg << " "; + logStream << this->GetPythonPath(); + MITK_INFO << logStream.str(); + + spExec->Execute(this->GetPythonPath(), command, args); + } + catch (const mitk::Exception &e) + { + MITK_ERROR << e.GetDescription(); + return; + } +} + +void mitk::SegmentAnythingTool::run_segmentation_from_points(ProcessExecutor *spExec, + const std::string &pickleFilePath, + const std::string &outputImagePath, + const std::string &modelType, + const std::string &checkpointPath, + const unsigned int gpuId) +{ + ProcessExecutor::ArgumentListType args; + std::string command = "python"; + + args.push_back("C:\\DKFZ\\SAM_work\\sam-playground\\endpoints\\generate_masks.py"); + + args.push_back("--embedding"); + args.push_back(pickleFilePath); + + args.push_back("--output"); args.push_back(outputImagePath); + args.push_back("--model-type"); + args.push_back(modelType); + + args.push_back("--checkpoint"); + args.push_back(checkpointPath); + + // TODO: add more arguments here-- ashis + try { std::string cudaEnv = "CUDA_VISIBLE_DEVICES=" + std::to_string(gpuId); itksys::SystemTools::PutEnv(cudaEnv.c_str()); std::stringstream logStream; for (const auto &arg : args) logStream << arg << " "; logStream << this->GetPythonPath(); MITK_INFO << logStream.str(); spExec->Execute(this->GetPythonPath(), command, args); } catch (const mitk::Exception &e) { MITK_ERROR << e.GetDescription(); return; } -} \ No newline at end of file +} diff --git a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h index 1a5c9bdf24..1b934f0785 100644 --- a/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h +++ b/Modules/Segmentation/Interactions/mitkSegmentAnythingTool.h @@ -1,115 +1,136 @@ /*============================================================================ The Medical Imaging Interaction Toolkit (MITK) Copyright (c) German Cancer Research Center (DKFZ) All rights reserved. Use of this source code is governed by a 3-clause BSD license that can be found in the LICENSE file. ============================================================================*/ #ifndef mitkSegmentAnythingTool_h #define mitkSegmentAnythingTool_h #include "mitkSegWithPreviewTool.h" #include "mitkPointSet.h" #include "mitkProcessExecutor.h" #include namespace us { class ModuleResource; } namespace mitk { /** CHANGE THIS --ashis \brief Extracts a single region from a segmentation image and creates a new image with same geometry of the input image. The region is extracted in 3D space. This is done by performing region growing within the desired region. Use shift click to add the seed point. \ingroup ToolManagerEtAl \sa mitk::Tool \sa QmitkInteractiveSegmentation */ class MITKSEGMENTATION_EXPORT SegmentAnythingTool : public SegWithPreviewTool { public: mitkClassMacro(SegmentAnythingTool, SegWithPreviewTool); itkFactorylessNewMacro(Self); itkCloneMacro(Self); const char **GetXPM() const override; const char *GetName() const override; us::ModuleResource GetIconResource() const override; void Activated() override; void Deactivated() override; /** * Clears all picks and updates the preview. */ void ClearPicks(); bool HasPicks() const; itkSetMacro(MitkTempDir, std::string); itkGetConstMacro(MitkTempDir, std::string); itkSetMacro(PythonPath, std::string); itkGetConstMacro(PythonPath, std::string); + itkSetMacro(ModelType, std::string); + itkGetConstMacro(ModelType, std::string); + + itkSetMacro(CheckpointPath, std::string); + itkGetConstMacro(CheckpointPath, std::string); + itkSetMacro(GpuId, unsigned int); itkGetConstMacro(GpuId, unsigned int); + itkSetMacro(IsAuto, bool); + itkGetConstMacro(IsAuto, bool); + itkBooleanMacro(IsAuto); + + itkSetMacro(IsReady, bool); + itkGetConstMacro(IsReady, bool); + itkBooleanMacro(IsReady); + /** * @brief Static function to print out everything from itk::EventObject. * Used as callback in mitk::ProcessExecutor object. * */ static void onPythonProcessEvent(itk::Object *, const itk::EventObject &e, void *); protected: - SegmentAnythingTool(); // purposely hidden + SegmentAnythingTool(); ~SegmentAnythingTool() override; void ConnectActionsAndFunctions() override; /// \brief Add point action of StateMachine pattern virtual void OnAddPoint(StateMachineAction*, InteractionEvent* interactionEvent); /// \brief Delete action of StateMachine pattern virtual void OnDelete(StateMachineAction*, InteractionEvent* interactionEvent); /// \brief Clear all seed points. void ClearSeeds(); void DoUpdatePreview(const Image* inputAtTimeStep, const Image* oldSegAtTimeStep, LabelSetImage* previewImage, TimeStepType timeStep) override; private: /** - * @brief Runs Totalsegmentator python process with desired arguments + * @brief Runs SAM python process with desired arguments to generate embeddings for the input image * */ - void run_generate_embeddings(ProcessExecutor::Pointer, const std::string &, const std::string &, unsigned int); + void run_generate_embeddings(ProcessExecutor*, const std::string&, const std::string&, const std::string&, const std::string&, const unsigned int); + + void run_segmentation_from_points(ProcessExecutor *, const std::string &, const std::string &, const std::string &, const std::string &, const unsigned int); + std::string m_MitkTempDir; std::string m_PythonPath; + std::string m_ModelType; + std::string m_CheckpointPath; + unsigned int m_GpuId = 0; PointSet::Pointer m_PointSet; DataNode::Pointer m_PointSetNode; - PlaneGeometry::Pointer m_WorkingPlaneGeometry; - //std::map m_PicklePathDict; + bool m_IsGenerateEmbeddings = true; + bool m_IsAuto = false; + bool m_IsReady = false; const std::string TEMPLATE_FILENAME = "XXXXXX_000_0000.nii.gz"; }; } // namespace #endif