diff --git a/.gitignore b/.gitignore
index 6a7ec10..ba33c01 100644
--- a/.gitignore
+++ b/.gitignore
@@ -1,9 +1,12 @@
+*.pyc
*.pickle
*.ipynb_checkpoints*
-*.pyc
+*.npy
*.pkl
*.log
*.png
*.jpg
-__pycache__/
+__pycache__/*
.idea/**
+
+!/assets/*
diff --git a/.idea/mdt-public.iml b/.idea/mdt-public.iml
deleted file mode 100644
index d0876a7..0000000
--- a/.idea/mdt-public.iml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/misc.xml b/.idea/misc.xml
deleted file mode 100644
index 65531ca..0000000
--- a/.idea/misc.xml
+++ /dev/null
@@ -1,4 +0,0 @@
-
-
-
-
\ No newline at end of file
diff --git a/.idea/modules.xml b/.idea/modules.xml
deleted file mode 100644
index a24a3b7..0000000
--- a/.idea/modules.xml
+++ /dev/null
@@ -1,8 +0,0 @@
-
-
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/vcs.xml b/.idea/vcs.xml
deleted file mode 100644
index 94a25f7..0000000
--- a/.idea/vcs.xml
+++ /dev/null
@@ -1,6 +0,0 @@
-
-
-
-
-
-
\ No newline at end of file
diff --git a/.idea/workspace.xml b/.idea/workspace.xml
deleted file mode 100644
index 2c033af..0000000
--- a/.idea/workspace.xml
+++ /dev/null
@@ -1,48 +0,0 @@
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- 1564881037445
-
-
- 1564881037445
-
-
-
-
\ No newline at end of file
diff --git a/cuda_functions/nms_2D/__pycache__/__init__.cpython-35.pyc b/cuda_functions/nms_2D/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 08425eb..0000000
Binary files a/cuda_functions/nms_2D/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/__pycache__/__init__.cpython-36.pyc b/cuda_functions/nms_2D/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 2eb81da..0000000
Binary files a/cuda_functions/nms_2D/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/__pycache__/pth_nms.cpython-35.pyc b/cuda_functions/nms_2D/__pycache__/pth_nms.cpython-35.pyc
deleted file mode 100644
index 1bf0a6c..0000000
Binary files a/cuda_functions/nms_2D/__pycache__/pth_nms.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/__pycache__/pth_nms.cpython-36.pyc b/cuda_functions/nms_2D/__pycache__/pth_nms.cpython-36.pyc
deleted file mode 100644
index 839361c..0000000
Binary files a/cuda_functions/nms_2D/__pycache__/pth_nms.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/_ext/__pycache__/__init__.cpython-35.pyc b/cuda_functions/nms_2D/_ext/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index ab74db1..0000000
Binary files a/cuda_functions/nms_2D/_ext/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/_ext/__pycache__/__init__.cpython-36.pyc b/cuda_functions/nms_2D/_ext/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 3e87955..0000000
Binary files a/cuda_functions/nms_2D/_ext/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/_ext/nms/__pycache__/__init__.cpython-35.pyc b/cuda_functions/nms_2D/_ext/nms/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index e535879..0000000
Binary files a/cuda_functions/nms_2D/_ext/nms/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_2D/_ext/nms/__pycache__/__init__.cpython-36.pyc b/cuda_functions/nms_2D/_ext/nms/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 7e1a9b1..0000000
Binary files a/cuda_functions/nms_2D/_ext/nms/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/__pycache__/__init__.cpython-35.pyc b/cuda_functions/nms_3D/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 1cf1238..0000000
Binary files a/cuda_functions/nms_3D/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/__pycache__/__init__.cpython-36.pyc b/cuda_functions/nms_3D/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index e09a2cb..0000000
Binary files a/cuda_functions/nms_3D/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/__pycache__/pth_nms.cpython-35.pyc b/cuda_functions/nms_3D/__pycache__/pth_nms.cpython-35.pyc
deleted file mode 100644
index 29a502f..0000000
Binary files a/cuda_functions/nms_3D/__pycache__/pth_nms.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/__pycache__/pth_nms.cpython-36.pyc b/cuda_functions/nms_3D/__pycache__/pth_nms.cpython-36.pyc
deleted file mode 100644
index 2fa4c5d..0000000
Binary files a/cuda_functions/nms_3D/__pycache__/pth_nms.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/_ext/__pycache__/__init__.cpython-35.pyc b/cuda_functions/nms_3D/_ext/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 6ee8ff3..0000000
Binary files a/cuda_functions/nms_3D/_ext/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/_ext/__pycache__/__init__.cpython-36.pyc b/cuda_functions/nms_3D/_ext/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index f733093..0000000
Binary files a/cuda_functions/nms_3D/_ext/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/_ext/nms/__pycache__/__init__.cpython-35.pyc b/cuda_functions/nms_3D/_ext/nms/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 10160ab..0000000
Binary files a/cuda_functions/nms_3D/_ext/nms/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/nms_3D/_ext/nms/__pycache__/__init__.cpython-36.pyc b/cuda_functions/nms_3D/_ext/nms/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 74019e7..0000000
Binary files a/cuda_functions/nms_3D/_ext/nms/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_2D/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 6a821bb..0000000
Binary files a/cuda_functions/roi_align_2D/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_2D/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 385ecda..0000000
Binary files a/cuda_functions/roi_align_2D/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_2D/roi_align/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 438fada..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_2D/roi_align/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 5611b92..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/__pycache__/crop_and_resize.cpython-35.pyc b/cuda_functions/roi_align_2D/roi_align/__pycache__/crop_and_resize.cpython-35.pyc
deleted file mode 100644
index e23974d..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/__pycache__/crop_and_resize.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/__pycache__/crop_and_resize.cpython-36.pyc b/cuda_functions/roi_align_2D/roi_align/__pycache__/crop_and_resize.cpython-36.pyc
deleted file mode 100644
index ca931d9..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/__pycache__/crop_and_resize.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_2D/roi_align/_ext/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 080f7b4..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/_ext/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_2D/roi_align/_ext/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 1a5aa20..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/_ext/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 27f3502..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 972175c..0000000
Binary files a/cuda_functions/roi_align_2D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_3D/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 853e83e..0000000
Binary files a/cuda_functions/roi_align_3D/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_3D/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 2cdfb29..0000000
Binary files a/cuda_functions/roi_align_3D/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_3D/roi_align/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index fa3d8d7..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_3D/roi_align/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index cb9081a..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/__pycache__/crop_and_resize.cpython-35.pyc b/cuda_functions/roi_align_3D/roi_align/__pycache__/crop_and_resize.cpython-35.pyc
deleted file mode 100644
index 88ce998..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/__pycache__/crop_and_resize.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/__pycache__/crop_and_resize.cpython-36.pyc b/cuda_functions/roi_align_3D/roi_align/__pycache__/crop_and_resize.cpython-36.pyc
deleted file mode 100644
index 30d30f5..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/__pycache__/crop_and_resize.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_3D/roi_align/_ext/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index d50935c..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/_ext/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_3D/roi_align/_ext/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index e2b65f5..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/_ext/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-35.pyc b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-35.pyc
deleted file mode 100644
index 93afa7e..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-35.pyc and /dev/null differ
diff --git a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-36.pyc b/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-36.pyc
deleted file mode 100644
index 5dd726e..0000000
Binary files a/cuda_functions/roi_align_3D/roi_align/_ext/crop_and_resize/__pycache__/__init__.cpython-36.pyc and /dev/null differ
diff --git a/default_configs.py b/default_configs.py
index 7e57afc..7abc68d 100644
--- a/default_configs.py
+++ b/default_configs.py
@@ -1,137 +1,140 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Default Configurations script. Avoids changing configs of all experiments if general settings are to be changed."""
import os
class DefaultConfigs:
def __init__(self, model, server_env=None, dim=2):
#########################
# I/O #
#########################
self.model = model
self.dim = dim
# int [0 < dataset_size]. select n patients from dataset for prototyping.
self.select_prototype_subset = None
# some default paths.
self.backbone_path = 'models/backbone.py'
self.source_dir = os.path.dirname(os.path.realpath(__file__)) #current dir.
self.input_df_name = 'info_df.pickle'
self.model_path = 'models/{}.py'.format(self.model)
if server_env:
self.source_dir = '/home/jaegerp/code/mamma_code/medicaldetectiontoolkit'
#########################
# Data Loader #
#########################
#random seed for fold_generator and batch_generator.
self.seed = 0
#number of threads for multithreaded batch generation.
self.n_workers = 6
# if True, segmentation losses learn all categories, else only foreground vs. background.
self.class_specific_seg_flag = False
#########################
# Architecture #
#########################
self.weight_decay = 0.0
# nonlinearity to be applied after convs with nonlinearity. one of 'relu' or 'leaky_relu'
self.relu = 'relu'
# if True initializes weights as specified in model script. else use default Pytorch init.
self.custom_init = False
# if True adds high-res decoder levels to feature pyramid: P1 + P0. (e.g. set to true in retina_unet configs)
self.operate_stride1 = False
#########################
# Schedule #
#########################
# number of folds in cross validation.
self.n_cv_splits = 5
# number of probabilistic samples in validation.
self.n_probabilistic_samples = None
#########################
# Testing / Plotting #
#########################
# perform mirroring at test time. (only XY. Z not done to not blow up predictions times).
self.test_aug = True
# if True, test data lies in a separate folder and is not part of the cross validation.
self.hold_out_test_set = False
# if hold_out_test_set provided, ensemble predictions over models of all trained cv-folds.
self.ensemble_folds = False
# color specifications for all box_types in prediction_plot.
self.box_color_palette = {'det': 'b', 'gt': 'r', 'neg_class': 'purple',
'prop': 'w', 'pos_class': 'g', 'pos_anchor': 'c', 'neg_anchor': 'c'}
# scan over confidence score in evaluation to optimize it on the validation set.
self.scan_det_thresh = False
# plots roc-curves / prc-curves in evaluation.
self.plot_stat_curves = False
# evaluates average precision per image and averages over images. instead computing one ap over data set.
self.per_patient_ap = False
# threshold for clustering 2D box predictions to 3D Cubes. Overlap is computed in XY.
self.merge_3D_iou = 0.1
# monitor any value from training.
self.n_monitoring_figures = 1
# dict to assign specific plot_values to monitor_figures > 0. {1: ['class_loss'], 2: ['kl_loss', 'kl_sigmas']}
self.assign_values_to_extra_figure = {}
# save predictions to csv file in experiment dir.
self.save_preds_to_csv = True
+ # select a maximum number of patient cases to test. number or "all" for all
+ self.max_test_patients = "all"
+
#########################
# MRCNN #
#########################
# if True, mask loss is not applied. used for data sets, where no pixel-wise annotations are provided.
self.frcnn_mode = False
# if True, unmolds masks in Mask R-CNN to full-res for plotting/monitoring.
self.return_masks_in_val = False
self.return_masks_in_test = False # needed if doing instance segmentation. evaluation not yet implemented.
# add P6 to Feature Pyramid Network.
self.sixth_pooling = False
# for probabilistic detection
self.n_latent_dims = 0
diff --git a/exec.py b/exec.py
index 271f40c..ae5dcc4 100644
--- a/exec.py
+++ b/exec.py
@@ -1,220 +1,228 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""execution script."""
import argparse
import os
import time
import torch
import utils.exp_utils as utils
from evaluator import Evaluator
from predictor import Predictor
from plotting import plot_batch_prediction
def train(logger):
"""
perform the training routine for a given fold. saves plots and selected parameters to the experiment dir
specified in the configs.
"""
logger.info('performing training in {}D over fold {} on experiment {} with model {}'.format(
cf.dim, cf.fold, cf.exp_dir, cf.model))
net = model.net(cf, logger).cuda()
optimizer = torch.optim.Adam(net.parameters(), lr=cf.learning_rate[0], weight_decay=cf.weight_decay)
model_selector = utils.ModelSelector(cf, logger)
train_evaluator = Evaluator(cf, logger, mode='train')
val_evaluator = Evaluator(cf, logger, mode=cf.val_mode)
starting_epoch = 1
# prepare monitoring
monitor_metrics, TrainingPlot = utils.prepare_monitoring(cf)
if cf.resume_to_checkpoint:
starting_epoch, monitor_metrics = utils.load_checkpoint(cf.resume_to_checkpoint, net, optimizer)
logger.info('resumed to checkpoint {} at epoch {}'.format(cf.resume_to_checkpoint, starting_epoch))
logger.info('loading dataset and initializing batch generators...')
batch_gen = data_loader.get_train_generators(cf, logger)
for epoch in range(starting_epoch, cf.num_epochs + 1):
logger.info('starting training epoch {}'.format(epoch))
for param_group in optimizer.param_groups:
param_group['lr'] = cf.learning_rate[epoch - 1]
start_time = time.time()
net.train()
train_results_list = []
for bix in range(cf.num_train_batches):
batch = next(batch_gen['train'])
tic_fw = time.time()
results_dict = net.train_forward(batch)
tic_bw = time.time()
optimizer.zero_grad()
results_dict['torch_loss'].backward()
optimizer.step()
logger.info('tr. batch {0}/{1} (ep. {2}) fw {3:.3f}s / bw {4:.3f}s / total {5:.3f}s || '
.format(bix + 1, cf.num_train_batches, epoch, tic_bw - tic_fw,
time.time() - tic_bw, time.time() - tic_fw) + results_dict['logger_string'])
train_results_list.append([results_dict['boxes'], batch['pid']])
monitor_metrics['train']['monitor_values'][epoch].append(results_dict['monitor_values'])
_, monitor_metrics['train'] = train_evaluator.evaluate_predictions(train_results_list, monitor_metrics['train'])
train_time = time.time() - start_time
logger.info('starting validation in mode {}.'.format(cf.val_mode))
with torch.no_grad():
net.eval()
if cf.do_validation:
val_results_list = []
val_predictor = Predictor(cf, net, logger, mode='val')
for _ in range(batch_gen['n_val']):
batch = next(batch_gen[cf.val_mode])
if cf.val_mode == 'val_patient':
results_dict = val_predictor.predict_patient(batch)
elif cf.val_mode == 'val_sampling':
results_dict = net.train_forward(batch, is_validation=True)
val_results_list.append([results_dict['boxes'], batch['pid']])
monitor_metrics['val']['monitor_values'][epoch].append(results_dict['monitor_values'])
_, monitor_metrics['val'] = val_evaluator.evaluate_predictions(val_results_list, monitor_metrics['val'])
model_selector.run_model_selection(net, optimizer, monitor_metrics, epoch)
# update monitoring and prediction plots
TrainingPlot.update_and_save(monitor_metrics, epoch)
epoch_time = time.time() - start_time
logger.info('trained epoch {}: took {} sec. ({} train / {} val)'.format(
epoch, epoch_time, train_time, epoch_time-train_time))
batch = next(batch_gen['val_sampling'])
results_dict = net.train_forward(batch, is_validation=True)
logger.info('plotting predictions from validation sampling.')
plot_batch_prediction(batch, results_dict, cf)
def test(logger):
"""
perform testing for a given fold (or hold out set). save stats in evaluator.
"""
logger.info('starting testing model of fold {} in exp {}'.format(cf.fold, cf.exp_dir))
net = model.net(cf, logger).cuda()
test_predictor = Predictor(cf, net, logger, mode='test')
test_evaluator = Evaluator(cf, logger, mode='test')
batch_gen = data_loader.get_test_generator(cf, logger)
test_results_list = test_predictor.predict_test_set(batch_gen, return_results=True)
test_evaluator.evaluate_predictions(test_results_list)
test_evaluator.score_test_df()
if __name__ == '__main__':
parser = argparse.ArgumentParser()
- parser.add_argument('--mode', type=str, default='train_test',
+ parser.add_argument('-m', '--mode', type=str, default='train_test',
help='one out of: train / test / train_test / analysis / create_exp')
- parser.add_argument('--folds', nargs='+', type=int, default=None,
+ parser.add_argument('-f','--folds', nargs='+', type=int, default=None,
help='None runs over all folds in CV. otherwise specify list of folds.')
parser.add_argument('--exp_dir', type=str, default='/path/to/experiment/directory',
help='path to experiment dir. will be created if non existent.')
parser.add_argument('--server_env', default=False, action='store_true',
help='change IO settings to deploy models on a cluster.')
parser.add_argument('--slurm_job_id', type=str, default=None, help='job scheduler info')
parser.add_argument('--use_stored_settings', default=False, action='store_true',
help='load configs from existing exp_dir instead of source dir. always done for testing, '
'but can be set to true to do the same for training. useful in job scheduler environment, '
'where source code might change before the job actually runs.')
parser.add_argument('--resume_to_checkpoint', type=str, default=None,
help='if resuming to checkpoint, the desired fold still needs to be parsed via --folds.')
parser.add_argument('--exp_source', type=str, default='experiments/toy_exp',
help='specifies, from which source experiment to load configs and data_loader.')
+ parser.add_argument('-d', '--dev', default=False, action='store_true', help="development mode: shorten everything")
args = parser.parse_args()
folds = args.folds
resume_to_checkpoint = args.resume_to_checkpoint
if args.mode == 'train' or args.mode == 'train_test':
cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, args.use_stored_settings)
+ if args.dev:
+ folds = [0,1]
+ cf.batch_size, cf.num_epochs, cf.min_save_thresh, cf.save_n_models = 3 if cf.dim==2 else 1, 1, 0, 1
+ cf.num_train_batches, cf.num_val_batches, cf.max_val_patients = 5, 1, 1
+ cf.test_n_epochs = cf.save_n_models
+ cf.max_test_patients = 1
+
cf.slurm_job_id = args.slurm_job_id
model = utils.import_module('model', cf.model_path)
data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
if folds is None:
folds = range(cf.n_cv_splits)
for fold in folds:
cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
cf.fold = fold
cf.resume_to_checkpoint = resume_to_checkpoint
if not os.path.exists(cf.fold_dir):
os.mkdir(cf.fold_dir)
logger = utils.get_logger(cf.fold_dir)
train(logger)
cf.resume_to_checkpoint = None
if args.mode == 'train_test':
test(logger)
elif args.mode == 'test':
cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
cf.slurm_job_id = args.slurm_job_id
model = utils.import_module('model', cf.model_path)
data_loader = utils.import_module('dl', os.path.join(args.exp_source, 'data_loader.py'))
if folds is None:
folds = range(cf.n_cv_splits)
for fold in folds:
cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
logger = utils.get_logger(cf.fold_dir)
cf.fold = fold
test(logger)
# load raw predictions saved by predictor during testing, run aggregation algorithms and evaluation.
elif args.mode == 'analysis':
cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, is_training=False, use_stored_settings=True)
logger = utils.get_logger(cf.exp_dir)
if cf.hold_out_test_set:
cf.folds = args.folds
predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
results_list = predictor.load_saved_predictions(apply_wbc=True)
utils.create_csv_output(results_list, cf, logger)
else:
if folds is None:
folds = range(cf.n_cv_splits)
for fold in folds:
cf.fold_dir = os.path.join(cf.exp_dir, 'fold_{}'.format(fold))
cf.fold = fold
predictor = Predictor(cf, net=None, logger=logger, mode='analysis')
results_list = predictor.load_saved_predictions(apply_wbc=True)
logger.info('starting evaluation...')
evaluator = Evaluator(cf, logger, mode='test')
evaluator.evaluate_predictions(results_list)
evaluator.score_test_df()
# create experiment folder and copy scripts without starting job.
- # usefull for cloud deployment where configs might change before job actually runs.
+ # useful for cloud deployment where configs might change before job actually runs.
elif args.mode == 'create_exp':
cf = utils.prep_exp(args.exp_source, args.exp_dir, args.server_env, use_stored_settings=True)
logger = utils.get_logger(cf.exp_dir)
logger.info('created experiment directory at {}'.format(args.exp_dir))
else:
raise RuntimeError('mode specified in args is not implemented...')
diff --git a/experiments/lidc_exp/__pycache__/configs.cpython-35.pyc b/experiments/lidc_exp/__pycache__/configs.cpython-35.pyc
deleted file mode 100644
index 0f55697..0000000
Binary files a/experiments/lidc_exp/__pycache__/configs.cpython-35.pyc and /dev/null differ
diff --git a/experiments/lidc_exp/__pycache__/configs.cpython-36.pyc b/experiments/lidc_exp/__pycache__/configs.cpython-36.pyc
deleted file mode 100644
index 19e7e83..0000000
Binary files a/experiments/lidc_exp/__pycache__/configs.cpython-36.pyc and /dev/null differ
diff --git a/experiments/lidc_exp/__pycache__/data_loader.cpython-35.pyc b/experiments/lidc_exp/__pycache__/data_loader.cpython-35.pyc
deleted file mode 100644
index 47b52d6..0000000
Binary files a/experiments/lidc_exp/__pycache__/data_loader.cpython-35.pyc and /dev/null differ
diff --git a/experiments/lidc_exp/__pycache__/data_loader.cpython-36.pyc b/experiments/lidc_exp/__pycache__/data_loader.cpython-36.pyc
deleted file mode 100644
index 9ffbba7..0000000
Binary files a/experiments/lidc_exp/__pycache__/data_loader.cpython-36.pyc and /dev/null differ
diff --git a/experiments/lidc_exp/configs.py b/experiments/lidc_exp/configs.py
index a848c3e..67025b3 100644
--- a/experiments/lidc_exp/configs.py
+++ b/experiments/lidc_exp/configs.py
@@ -1,335 +1,335 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import numpy as np
from default_configs import DefaultConfigs
class configs(DefaultConfigs):
def __init__(self, server_env=None):
#########################
# Preprocessing #
#########################
self.root_dir = '/path/to/raw/data'
self.raw_data_dir = '{}/data_nrrd'.format(self.root_dir)
self.pp_dir = '{}/pp_norm'.format(self.root_dir)
self.target_spacing = (0.7, 0.7, 1.25)
#########################
# I/O #
#########################
# one out of [2, 3]. dimension the model operates in.
self.dim = 3
# one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_unet'].
- self.model = 'mrcnn'
+ self.model = 'retina_unet'
DefaultConfigs.__init__(self, self.model, server_env, self.dim)
# int [0 < dataset_size]. select n patients from dataset for prototyping. If None, all data is used.
self.select_prototype_subset = None
# path to preprocessed data.
- self.pp_name = 'pp_norm'
+ self.pp_name = 'lidc_preprocessed_for_G2'
self.input_df_name = 'info_df.pickle'
- self.pp_data_path = '/path/to/preprocessed/data/{}'.format(self.pp_name)
+ self.pp_data_path = '/mnt/HDD2TB/Documents/data/lidc/{}'.format(self.pp_name)
self.pp_test_data_path = self.pp_data_path #change if test_data in separate folder.
# settings for deployment in cloud.
if server_env:
# path to preprocessed data.
self.pp_name = 'pp_fg_slices'
self.crop_name = 'pp_fg_slices_packed'
self.pp_data_path = '/path/to/preprocessed/data/{}/{}'.format(self.pp_name, self.crop_name)
self.pp_test_data_path = self.pp_data_path
self.select_prototype_subset = None
#########################
# Data Loader #
#########################
# select modalities from preprocessed data
self.channels = [0]
self.n_channels = len(self.channels)
# patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
self.pre_crop_size_2D = [300, 300]
self.patch_size_2D = [288, 288]
self.pre_crop_size_3D = [156, 156, 96]
self.patch_size_3D = [128, 128, 64]
self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
# ratio of free sampled batch elements before class balancing is triggered
# (>0 to include "empty"/background patches.)
self.batch_sample_slack = 0.2
# set 2D network to operate in 3D images.
self.merge_2D_to_3D_preds = True
# feed +/- n neighbouring slices into channel dimension. set to None for no context.
self.n_3D_context = None
if self.n_3D_context is not None and self.dim == 2:
self.n_channels *= (self.n_3D_context * 2 + 1)
#########################
# Architecture #
#########################
self.start_filts = 48 if self.dim == 2 else 18
self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
self.norm = None # one of None, 'instance_norm', 'batch_norm'
self.weight_decay = 0
# one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
self.weight_init = None
#########################
# Schedule / Selection #
#########################
self.num_epochs = 100
self.num_train_batches = 200 if self.dim == 2 else 200
self.batch_size = 20 if self.dim == 2 else 8
self.do_validation = True
# decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
# the former is morge accurate, while the latter is faster (depending on volume size)
self.val_mode = 'val_sampling' # one of 'val_sampling' , 'val_patient'
if self.val_mode == 'val_patient':
self.max_val_patients = 50 # if 'None' iterates over entire val_set once.
if self.val_mode == 'val_sampling':
self.num_val_batches = 50
#########################
# Testing / Plotting #
#########################
# set the top-n-epochs to be saved for temporal averaging in testing.
self.save_n_models = 5
self.test_n_epochs = 5
# set a minimum epoch number for saving in case of instabilities in the first phase of training.
self.min_save_thresh = 0 if self.dim == 2 else 0
self.report_score_level = ['patient', 'rois'] # choose list from 'patient', 'rois'
self.class_dict = {1: 'benign', 2: 'malignant'} # 0 is background.
self.patient_class_of_interest = 2 # patient metrics are only plotted for one class.
self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring.
self.model_selection_criteria = ['malignant_ap', 'benign_ap'] # criteria to average over for saving epochs.
self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation.
# threshold for clustering predictions together (wcs = weighted cluster scoring).
# needs to be >= the expected overlap of predictions coming from one model (typically NMS threshold).
# if too high, preds of the same object are separate clusters.
self.wcs_iou = 1e-5
self.plot_prediction_histograms = True
self.plot_stat_curves = False
#########################
# Data Augmentation #
#########################
self.da_kwargs={
'do_elastic_deform': True,
'alpha':(0., 1500.),
'sigma':(30., 50.),
'do_rotation':True,
'angle_x': (0., 2 * np.pi),
'angle_y': (0., 0),
'angle_z': (0., 0),
'do_scale': True,
'scale':(0.8, 1.1),
'random_crop':False,
'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
'border_mode_data': 'constant',
'border_cval_data': 0,
'order_data': 1
}
if self.dim == 3:
self.da_kwargs['do_elastic_deform'] = False
self.da_kwargs['angle_x'] = (0, 0.0)
self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
self.da_kwargs['angle_z'] = (0., 2 * np.pi)
#########################
# Add model specifics #
#########################
{'detection_unet': self.add_det_unet_configs,
'mrcnn': self.add_mrcnn_configs,
'ufrcnn': self.add_mrcnn_configs,
'retina_net': self.add_mrcnn_configs,
'retina_unet': self.add_mrcnn_configs,
}[self.model]()
def add_det_unet_configs(self):
self.learning_rate = [1e-4] * self.num_epochs
# aggregation from pixel perdiction to object scores (connected component). One of ['max', 'median']
self.aggregation_operation = 'max'
# max number of roi candidates to identify per batch element and class.
self.n_roi_candidates = 10 if self.dim == 2 else 30
# loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
self.seg_loss_mode = 'dice_wce'
# if <1, false positive predictions in foreground are penalized less.
self.fp_dice_weight = 1 if self.dim == 2 else 1
self.wce_weights = [1, 1, 1]
self.detection_min_confidence = self.min_det_thresh
# if 'True', loss distinguishes all classes, else only foreground vs. background (class agnostic).
self.class_specific_seg_flag = True
self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
self.head_classes = self.num_seg_classes
def add_mrcnn_configs(self):
# learning rate is a list with one entry per epoch.
self.learning_rate = [1e-4] * self.num_epochs
# disable the re-sampling of mask proposals to original size for speed-up.
# since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
# mask-outputs are optional.
self.return_masks_in_val = True
self.return_masks_in_test = False
# set number of proposal boxes to plot after each epoch.
self.n_plot_rpn_props = 5 if self.dim == 2 else 30
# number of classes for head networks: n_foreground_classes + 1 (background)
self.head_classes = 3
# seg_classes hier refers to the first stage classifier (RPN)
self.num_seg_classes = 2 # foreground vs. background
# feature map strides per pyramid level are inferred from architecture.
self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
# anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
# per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
# choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
self.pyramid_levels = [0, 1, 2, 3]
# number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
self.n_rpn_features = 512 if self.dim == 2 else 128
# anchor ratios and strides per position in feature maps.
self.rpn_anchor_ratios = [0.5, 1, 2]
self.rpn_anchor_stride = 1
# Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION
self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
# loss sampling settings.
self.rpn_train_anchors_per_image = 6 #per batch element
self.train_rois_per_image = 6 #per batch element
self.roi_positive_ratio = 0.5
self.anchor_matching_iou = 0.7
# factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
# poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
self.shem_poolsize = 10
self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1], 0, self.patch_size_3D[2]])
self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1],
self.patch_size_3D[2], self.patch_size_3D[2]])
if self.dim == 2:
self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
self.bbox_std_dev = self.bbox_std_dev[:4]
self.window = self.window[:4]
self.scale = self.scale[:4]
# pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
self.pre_nms_limit = 3000 if self.dim == 2 else 6000
# n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
# since proposals of the entire batch are forwarded through second stage in as one "batch".
self.roi_chunk_size = 2500 if self.dim == 2 else 600
self.post_nms_rois_training = 500 if self.dim == 2 else 75
self.post_nms_rois_inference = 500
# Final selection of detections (refine_detections)
self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class.
self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster.
self.model_min_confidence = 0.1
if self.dim == 2:
self.backbone_shapes = np.array(
[[int(np.ceil(self.patch_size[0] / stride)),
int(np.ceil(self.patch_size[1] / stride))]
for stride in self.backbone_strides['xy']])
else:
self.backbone_shapes = np.array(
[[int(np.ceil(self.patch_size[0] / stride)),
int(np.ceil(self.patch_size[1] / stride)),
int(np.ceil(self.patch_size[2] / stride_z))]
for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
)])
if self.model == 'ufrcnn':
self.operate_stride1 = True
self.class_specific_seg_flag = True
self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
self.frcnn_mode = True
if self.model == 'retina_net' or self.model == 'retina_unet' or self.model == 'prob_detector':
# implement extra anchor-scales according to retina-net publication.
self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
self.rpn_anchor_scales['xy']]
self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
self.rpn_anchor_scales['z']]
self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
self.n_rpn_features = 256 if self.dim == 2 else 64
# pre-selection of detections for NMS-speedup. per entire batch.
self.pre_nms_limit = 10000 if self.dim == 2 else 50000
# anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
self.anchor_matching_iou = 0.5
# if 'True', seg loss distinguishes all classes, else only foreground vs. background (class agnostic).
self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
if self.model == 'retina_unet':
self.operate_stride1 = True
diff --git a/experiments/lidc_exp/data_loader.py b/experiments/lidc_exp/data_loader.py
index 0c05a20..2e64f34 100644
--- a/experiments/lidc_exp/data_loader.py
+++ b/experiments/lidc_exp/data_loader.py
@@ -1,455 +1,461 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
'''
Example Data Loader for the LIDC data set. This dataloader expects preprocessed data in .npy or .npz files per patient and
a pandas dataframe in the same directory containing the meta-info e.g. file paths, labels, foregound slice-ids.
'''
import numpy as np
import os
from collections import OrderedDict
import pandas as pd
import pickle
import time
import subprocess
import utils.dataloader_utils as dutils
# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading import SingleThreadedAugmenter
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
def get_train_generators(cf, logger):
"""
wrapper function for creating the training batch generator pipeline. returns the train/val generators.
selects patients according to cv folds (generated by first run/fold of experiment):
splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
If cf.hold_out_test_set is True, adds the test split to the training data.
"""
all_data = load_dataset(cf, logger)
all_pids_list = np.unique([v['pid'] for (k, v) in all_data.items()])
if not cf.created_fold_id_pickle:
fg = dutils.fold_generator(seed=cf.seed, n_splits=cf.n_cv_splits, len_data=len(all_pids_list)).get_fold_names()
with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'wb') as handle:
pickle.dump(fg, handle)
cf.created_fold_id_pickle = True
else:
with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
fg = pickle.load(handle)
train_ix, val_ix, test_ix, _ = fg[cf.fold]
train_pids = [all_pids_list[ix] for ix in train_ix]
val_pids = [all_pids_list[ix] for ix in val_ix]
if cf.hold_out_test_set:
train_pids += [all_pids_list[ix] for ix in test_ix]
train_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in train_pids)}
val_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in val_pids)}
logger.info("data set loaded with: {} train / {} val / {} test patients".format(len(train_ix), len(val_ix), len(test_ix)))
batch_gen = {}
batch_gen['train'] = create_data_gen_pipeline(train_data, cf=cf, is_training=True)
batch_gen['val_sampling'] = create_data_gen_pipeline(val_data, cf=cf, is_training=False)
if cf.val_mode == 'val_patient':
batch_gen['val_patient'] = PatientBatchIterator(val_data, cf=cf)
- batch_gen['n_val'] = len(val_ix) if cf.max_val_patients is None else cf.max_val_patients
+ batch_gen['n_val'] = len(val_ix) if cf.max_val_patients is None else min(len(val_ix), cf.max_val_patients)
else:
batch_gen['n_val'] = cf.num_val_batches
return batch_gen
def get_test_generator(cf, logger):
"""
wrapper function for creating the test batch generator pipeline.
selects patients according to cv folds (generated by first run/fold of experiment)
If cf.hold_out_test_set is True, gets the data from an external folder instead.
"""
if cf.hold_out_test_set:
- cf.pp_data_path = cf.pp_test_data_path
+ pp_name = cf.pp_test_name
test_ix = None
else:
+ pp_name = None
with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
fold_list = pickle.load(handle)
_, _, test_ix, _ = fold_list[cf.fold]
# warnings.warn('WARNING: using validation set for testing!!!')
- test_data = load_dataset(cf, logger, test_ix)
+ test_data = load_dataset(cf, logger, test_ix, pp_data_path=cf.pp_test_data_path, pp_name=pp_name)
logger.info("data set loaded with: {} test patients".format(len(test_ix)))
batch_gen = {}
batch_gen['test'] = PatientBatchIterator(test_data, cf=cf)
- batch_gen['n_test'] = len(test_ix)
+ batch_gen['n_test'] = len(test_ix) if cf.max_test_patients=="all" else \
+ min(cf.max_test_patients, len(test_ix))
return batch_gen
-def load_dataset(cf, logger, subset_ixs=None):
+def load_dataset(cf, logger, subset_ixs=None, pp_data_path=None, pp_name=None):
"""
loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory.
:param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds.
:return: data: dictionary with one entry per patient (in this case per patient-breast, since they are treated as
individual images for training) each entry is a dictionary containing respective meta-info as well as paths to the preprocessed
numpy arrays to be loaded during batch-generation
"""
+ if pp_data_path is None:
+ pp_data_path = cf.pp_data_path
+ if pp_name is None:
+ pp_name = cf.pp_name
if cf.server_env:
copy_data = True
- target_dir = os.path.join('/ssd', cf.slurm_job_id, cf.pp_name, cf.crop_name)
+ target_dir = os.path.join('/ssd', cf.slurm_job_id, pp_name, cf.crop_name)
if not os.path.exists(target_dir):
- cf.data_source_dir = cf.pp_data_path
+ cf.data_source_dir = pp_data_path
os.makedirs(target_dir)
subprocess.call('rsync -av {} {}'.format(
os.path.join(cf.data_source_dir, cf.input_df_name), os.path.join(target_dir, cf.input_df_name)), shell=True)
logger.info('created target dir and info df at {}'.format(os.path.join(target_dir, cf.input_df_name)))
elif subset_ixs is None:
copy_data = False
- cf.pp_data_path = target_dir
+ pp_data_path = target_dir
- p_df = pd.read_pickle(os.path.join(cf.pp_data_path, cf.input_df_name))
+ p_df = pd.read_pickle(os.path.join(pp_data_path, cf.input_df_name))
if cf.select_prototype_subset is not None:
prototype_pids = p_df.pid.tolist()[:cf.select_prototype_subset]
p_df = p_df[p_df.pid.isin(prototype_pids)]
logger.warning('WARNING: using prototyping data subset!!!')
if subset_ixs is not None:
subset_pids = [np.unique(p_df.pid.tolist())[ix] for ix in subset_ixs]
p_df = p_df[p_df.pid.isin(subset_pids)]
logger.info('subset: selected {} instances from df'.format(len(p_df)))
if cf.server_env:
if copy_data:
copy_and_unpack_data(logger, p_df.pid.tolist(), cf.fold_dir, cf.data_source_dir, target_dir)
class_targets = p_df['class_target'].tolist()
pids = p_df.pid.tolist()
- imgs = [os.path.join(cf.pp_data_path, '{}_img.npy'.format(pid)) for pid in pids]
- segs = [os.path.join(cf.pp_data_path,'{}_rois.npy'.format(pid)) for pid in pids]
+ imgs = [os.path.join(pp_data_path, '{}_img.npy'.format(pid)) for pid in pids]
+ segs = [os.path.join(pp_data_path,'{}_rois.npy'.format(pid)) for pid in pids]
data = OrderedDict()
for ix, pid in enumerate(pids):
# for the experiment conducted here, malignancy scores are binarized: (benign: 1-2, malignant: 3-5)
targets = [1 if ii >= 3 else 0 for ii in class_targets[ix]]
data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid, 'class_target': targets}
data[pid]['fg_slices'] = p_df['fg_slices'].tolist()[ix]
return data
def create_data_gen_pipeline(patient_data, cf, is_training=True):
"""
create mutli-threaded train/val/test batch generation and augmentation pipeline.
:param patient_data: dictionary containing one dictionary per patient in the train/test subset.
:param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
:return: multithreaded_generator
"""
# create instance of batch generator as first element in pipeline.
data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)
# add transformations to pipeline.
my_transforms = []
if is_training:
mirror_transform = Mirror(axes=np.arange(cf.dim))
my_transforms.append(mirror_transform)
spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
random_crop=cf.da_kwargs['random_crop'])
my_transforms.append(spatial_transform)
else:
my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
all_transforms = Compose(my_transforms)
# multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
return multithreaded_generator
class BatchGenerator(SlimDataLoaderBase):
"""
creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
Actual patch_size is obtained after data augmentation.
:param data: data dictionary as provided by 'load_dataset'.
:param batch_size: number of patients to sample for the batch
:return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
"""
def __init__(self, data, batch_size, cf):
super(BatchGenerator, self).__init__(data, batch_size)
self.cf = cf
self.crop_margin = np.array(self.cf.patch_size)/8. #min distance of ROI center to edge of cropped_patch.
self.p_fg = 0.5
def generate_train_batch(self):
batch_data, batch_segs, batch_pids, batch_targets, batch_patient_labels = [], [], [], [], []
class_targets_list = [v['class_target'] for (k, v) in self._data.items()]
if self.cf.head_classes > 2:
# samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio "batch_sample_slack).
batch_ixs = dutils.get_class_balanced_patients(
class_targets_list, self.batch_size, self.cf.head_classes - 1, slack_factor=self.cf.batch_sample_slack)
else:
batch_ixs = np.random.choice(len(class_targets_list), self.batch_size)
patients = list(self._data.items())
for b in batch_ixs:
patient = patients[b][1]
data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] # (c, y, x, z)
seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
batch_pids.append(patient['pid'])
batch_targets.append(patient['class_target'])
if self.cf.dim == 2:
# draw random slice from patient while oversampling slices containing foreground objects with p_fg.
if len(patient['fg_slices']) > 0:
fg_prob = self.p_fg / len(patient['fg_slices'])
bg_prob = (1 - self.p_fg) / (data.shape[3] - len(patient['fg_slices']))
slices_prob = [fg_prob if ix in patient['fg_slices'] else bg_prob for ix in range(data.shape[3])]
slice_id = np.random.choice(data.shape[3], p=slices_prob)
else:
slice_id = np.random.choice(data.shape[3])
# if set to not None, add neighbouring slices to each selected slice in channel dimension.
if self.cf.n_3D_context is not None:
padded_data = dutils.pad_nd_image(data[0], [(data.shape[-1] + (self.cf.n_3D_context*2))], mode='constant')
padded_slice_id = slice_id + self.cf.n_3D_context
data = (np.concatenate([padded_data[..., ii][np.newaxis] for ii in range(
padded_slice_id - self.cf.n_3D_context, padded_slice_id + self.cf.n_3D_context + 1)], axis=0))
else:
data = data[..., slice_id]
seg = seg[..., slice_id]
# pad data if smaller than pre_crop_size.
if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.cf.pre_crop_size)]):
new_shape = [np.max([data.shape[dim + 1], ps]) for dim, ps in enumerate(self.cf.pre_crop_size)]
data = dutils.pad_nd_image(data, new_shape, mode='constant')
seg = dutils.pad_nd_image(seg, new_shape, mode='constant')
# crop patches of size pre_crop_size, while sampling patches containing foreground with p_fg.
crop_dims = [dim for dim, ps in enumerate(self.cf.pre_crop_size) if data.shape[dim + 1] > ps]
if len(crop_dims) > 0:
fg_prob_sample = np.random.rand(1)
# with p_fg: sample random pixel from random ROI and shift center by random value.
if fg_prob_sample < self.p_fg and np.sum(seg) > 0:
seg_ixs = np.argwhere(seg == np.random.choice(np.unique(seg)[1:], 1))
roi_anchor_pixel = seg_ixs[np.random.choice(seg_ixs.shape[0], 1)][0]
assert seg[tuple(roi_anchor_pixel)] > 0
# sample the patch center coords. constrained by edges of images - pre_crop_size /2. And by
# distance to the desired ROI < patch_size /2.
# (here final patch size to account for center_crop after data augmentation).
sample_seg_center = {}
for ii in crop_dims:
low = np.max((self.cf.pre_crop_size[ii]//2, roi_anchor_pixel[ii] - (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
high = np.min((data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2,
roi_anchor_pixel[ii] + (self.cf.patch_size[ii]//2 - self.crop_margin[ii])))
# happens if lesion on the edge of the image. dont care about roi anymore,
# just make sure pre-crop is inside image.
if low >= high:
low = data.shape[ii + 1] // 2 - (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
high = data.shape[ii + 1] // 2 + (data.shape[ii + 1] // 2 - self.cf.pre_crop_size[ii] // 2)
sample_seg_center[ii] = np.random.randint(low=low, high=high)
else:
# not guaranteed to be empty. probability of emptiness depends on the data.
sample_seg_center = {ii: np.random.randint(low=self.cf.pre_crop_size[ii]//2,
high=data.shape[ii + 1] - self.cf.pre_crop_size[ii]//2) for ii in crop_dims}
for ii in crop_dims:
min_crop = int(sample_seg_center[ii] - self.cf.pre_crop_size[ii] // 2)
max_crop = int(sample_seg_center[ii] + self.cf.pre_crop_size[ii] // 2)
data = np.take(data, indices=range(min_crop, max_crop), axis=ii + 1)
seg = np.take(seg, indices=range(min_crop, max_crop), axis=ii)
batch_data.append(data)
batch_segs.append(seg[np.newaxis])
data = np.array(batch_data)
seg = np.array(batch_segs).astype(np.uint8)
class_target = np.array(batch_targets)
return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target}
class PatientBatchIterator(SlimDataLoaderBase):
"""
creates a test generator that iterates over entire given dataset returning 1 patient per batch.
Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D),
if willing to accept speed-loss during training.
:return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
batch_size = n_2D_patches in 2D .
"""
def __init__(self, data, cf): #threads in augmenter
super(PatientBatchIterator, self).__init__(data, 0)
self.cf = cf
self.patient_ix = 0
self.dataset_pids = [v['pid'] for (k, v) in data.items()]
self.patch_size = cf.patch_size
if len(self.patch_size) == 2:
self.patch_size = self.patch_size + [1]
def generate_train_batch(self):
pid = self.dataset_pids[self.patient_ix]
patient = self._data[pid]
data = np.transpose(np.load(patient['data'], mmap_mode='r'), axes=(1, 2, 0))[np.newaxis] # (c, y, x, z)
seg = np.transpose(np.load(patient['seg'], mmap_mode='r'), axes=(1, 2, 0))
batch_class_targets = np.array([patient['class_target']])
# pad data if smaller than patch_size seen during training.
if np.any([data.shape[dim + 1] < ps for dim, ps in enumerate(self.patch_size)]):
new_shape = [data.shape[0]] + [np.max([data.shape[dim + 1], self.patch_size[dim]]) for dim, ps in enumerate(self.patch_size)]
data = dutils.pad_nd_image(data, new_shape) # use 'return_slicer' to crop image back to original shape.
seg = dutils.pad_nd_image(seg, new_shape)
# get 3D targets for evaluation, even if network operates in 2D. 2D predictions will be merged to 3D in predictor.
if self.cf.dim == 3 or self.cf.merge_2D_to_3D_preds:
out_data = data[np.newaxis]
out_seg = seg[np.newaxis, np.newaxis]
out_targets = batch_class_targets
batch_3D = {'data': out_data, 'seg': out_seg, 'class_target': out_targets, 'pid': pid}
converter = ConvertSegToBoundingBoxCoordinates(dim=3, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
batch_3D = converter(**batch_3D)
batch_3D.update({'patient_bb_target': batch_3D['bb_target'],
'patient_roi_labels': batch_3D['roi_labels'],
'original_img_shape': out_data.shape})
if self.cf.dim == 2:
out_data = np.transpose(data, axes=(3, 0, 1, 2)) # (z, c, x, y )
out_seg = np.transpose(seg, axes=(2, 0, 1))[:, np.newaxis]
out_targets = np.array(np.repeat(batch_class_targets, out_data.shape[0], axis=0))
# if set to not None, add neighbouring slices to each selected slice in channel dimension.
if self.cf.n_3D_context is not None:
slice_range = range(self.cf.n_3D_context, out_data.shape[0] + self.cf.n_3D_context)
out_data = np.pad(out_data, ((self.cf.n_3D_context, self.cf.n_3D_context), (0, 0), (0, 0), (0, 0)), 'constant', constant_values=0)
out_data = np.array(
[np.concatenate([out_data[ii] for ii in range(
slice_id - self.cf.n_3D_context, slice_id + self.cf.n_3D_context + 1)], axis=0) for slice_id in
slice_range])
batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': out_targets, 'pid': pid}
converter = ConvertSegToBoundingBoxCoordinates(dim=2, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
batch_2D = converter(**batch_2D)
if self.cf.merge_2D_to_3D_preds:
batch_2D.update({'patient_bb_target': batch_3D['patient_bb_target'],
'patient_roi_labels': batch_3D['patient_roi_labels'],
'original_img_shape': out_data.shape})
else:
batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
'patient_roi_labels': batch_2D['roi_labels'],
'original_img_shape': out_data.shape})
out_batch = batch_3D if self.cf.dim == 3 else batch_2D
patient_batch = out_batch
# crop patient-volume to patches of patch_size used during training. stack patches up in batch dimension.
# in this case, 2D is treated as a special case of 3D with patch_size[z] = 1.
if np.any([data.shape[dim + 1] > self.patch_size[dim] for dim in range(3)]):
patch_crop_coords_list = dutils.get_patch_crop_coords(data[0], self.patch_size)
new_img_batch, new_seg_batch, new_class_targets_batch = [], [], []
for cix, c in enumerate(patch_crop_coords_list):
seg_patch = seg[c[0]:c[1], c[2]: c[3], c[4]:c[5]]
new_seg_batch.append(seg_patch)
# if set to not None, add neighbouring slices to each selected slice in channel dimension.
# correct patch_crop coordinates by added slices of 3D context.
if self.cf.dim == 2 and self.cf.n_3D_context is not None:
tmp_c_5 = c[5] + (self.cf.n_3D_context * 2)
if cix == 0:
data = np.pad(data, ((0, 0), (0, 0), (0, 0), (self.cf.n_3D_context, self.cf.n_3D_context)), 'constant', constant_values=0)
else:
tmp_c_5 = c[5]
new_img_batch.append(data[:, c[0]:c[1], c[2]:c[3], c[4]:tmp_c_5])
data = np.array(new_img_batch) # (n_patches, c, x, y, z)
seg = np.array(new_seg_batch)[:, np.newaxis] # (n_patches, 1, x, y, z)
batch_class_targets = np.repeat(batch_class_targets, len(patch_crop_coords_list), axis=0)
if self.cf.dim == 2:
if self.cf.n_3D_context is not None:
data = np.transpose(data[:, 0], axes=(0, 3, 1, 2))
else:
# all patches have z dimension 1 (slices). discard dimension
data = data[..., 0]
seg = seg[..., 0]
patch_batch = {'data': data, 'seg': seg, 'class_target': batch_class_targets, 'pid': pid}
patch_batch['patch_crop_coords'] = np.array(patch_crop_coords_list)
patch_batch['patient_bb_target'] = patient_batch['patient_bb_target']
patch_batch['patient_roi_labels'] = patient_batch['patient_roi_labels']
patch_batch['original_img_shape'] = patient_batch['original_img_shape']
converter = ConvertSegToBoundingBoxCoordinates(self.cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
patch_batch = converter(**patch_batch)
out_batch = patch_batch
self.patient_ix += 1
if self.patient_ix == len(self.dataset_pids):
self.patient_ix = 0
return out_batch
def copy_and_unpack_data(logger, pids, fold_dir, source_dir, target_dir):
start_time = time.time()
with open(os.path.join(fold_dir, 'file_list.txt'), 'w') as handle:
for pid in pids:
handle.write('{}_img.npz\n'.format(pid))
handle.write('{}_rois.npz\n'.format(pid))
subprocess.call('rsync -av --files-from {} {} {}'.format(os.path.join(fold_dir, 'file_list.txt'),
source_dir, target_dir), shell=True)
dutils.unpack_dataset(target_dir)
copied_files = os.listdir(target_dir)
logger.info("copying and unpacking data set finsihed : {} files in target dir: {}. took {} sec".format(
len(copied_files), target_dir, np.round(time.time() - start_time, 0)))
diff --git a/experiments/toy_exp/__pycache__/configs.cpython-35.pyc b/experiments/toy_exp/__pycache__/configs.cpython-35.pyc
deleted file mode 100644
index 6171c47..0000000
Binary files a/experiments/toy_exp/__pycache__/configs.cpython-35.pyc and /dev/null differ
diff --git a/experiments/toy_exp/__pycache__/configs.cpython-36.pyc b/experiments/toy_exp/__pycache__/configs.cpython-36.pyc
deleted file mode 100644
index 2b334d9..0000000
Binary files a/experiments/toy_exp/__pycache__/configs.cpython-36.pyc and /dev/null differ
diff --git a/experiments/toy_exp/__pycache__/data_loader.cpython-35.pyc b/experiments/toy_exp/__pycache__/data_loader.cpython-35.pyc
deleted file mode 100644
index 7347c22..0000000
Binary files a/experiments/toy_exp/__pycache__/data_loader.cpython-35.pyc and /dev/null differ
diff --git a/experiments/toy_exp/__pycache__/data_loader.cpython-36.pyc b/experiments/toy_exp/__pycache__/data_loader.cpython-36.pyc
deleted file mode 100644
index ab4d287..0000000
Binary files a/experiments/toy_exp/__pycache__/data_loader.cpython-36.pyc and /dev/null differ
diff --git a/experiments/toy_exp/configs.py b/experiments/toy_exp/configs.py
index d414d76..36d6ee7 100644
--- a/experiments/toy_exp/configs.py
+++ b/experiments/toy_exp/configs.py
@@ -1,345 +1,344 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import sys
import os
sys.path.append(os.path.dirname(os.path.realpath(__file__)))
import numpy as np
from default_configs import DefaultConfigs
class configs(DefaultConfigs):
def __init__(self, server_env=None):
#########################
# Preprocessing #
#########################
- self.root_dir = '/path/to/data'
+ self.root_dir = '/mnt/HDD2TB/Documents/data/mdt_toy'
#########################
# I/O #
#########################
# one out of [2, 3]. dimension the model operates in.
self.dim = 2
# one out of ['mrcnn', 'retina_net', 'retina_unet', 'detection_unet', 'ufrcnn', 'detection_unet'].
- self.model = 'ufrcnn'
+ self.model = 'retina_unet'
DefaultConfigs.__init__(self, self.model, server_env, self.dim)
# int [0 < dataset_size]. select n patients from dataset for prototyping.
self.select_prototype_subset = None
self.hold_out_test_set = True
self.n_train_data = 1000
# choose one of the 3 toy experiments described in https://arxiv.org/pdf/1811.08661.pdf
# one of ['donuts_shape', 'donuts_pattern', 'circles_scale'].
toy_mode = 'donuts_shape'
# path to preprocessed data.
self.input_df_name = 'info_df.pickle'
self.pp_name = os.path.join(toy_mode, 'train')
self.pp_data_path = os.path.join(self.root_dir, self.pp_name)
self.pp_test_name = os.path.join(toy_mode, 'test')
self.pp_test_data_path = os.path.join(self.root_dir, self.pp_test_name)
# settings for deployment in cloud.
if server_env:
# path to preprocessed data.
pp_root_dir = '/path/to/data'
self.pp_name = os.path.join(toy_mode, 'train')
self.pp_data_path = os.path.join(pp_root_dir, self.pp_name)
self.pp_test_name = os.path.join(toy_mode, 'test')
self.pp_test_data_path = os.path.join(pp_root_dir, self.pp_test_name)
self.select_prototype_subset = None
#########################
# Data Loader #
#########################
# select modalities from preprocessed data
self.channels = [0]
self.n_channels = len(self.channels)
# patch_size to be used for training. pre_crop_size is the patch_size before data augmentation.
self.pre_crop_size_2D = [320, 320]
self.patch_size_2D = [320, 320]
self.patch_size = self.patch_size_2D if self.dim == 2 else self.patch_size_3D
self.pre_crop_size = self.pre_crop_size_2D if self.dim == 2 else self.pre_crop_size_3D
# ratio of free sampled batch elements before class balancing is triggered
# (>0 to include "empty"/background patches.)
self.batch_sample_slack = 0.2
# set 2D network to operate in 3D images.
self.merge_2D_to_3D_preds = False
# feed +/- n neighbouring slices into channel dimension. set to None for no context.
self.n_3D_context = None
if self.n_3D_context is not None and self.dim == 2:
self.n_channels *= (self.n_3D_context * 2 + 1)
#########################
# Architecture #
#########################
self.start_filts = 48 if self.dim == 2 else 18
self.end_filts = self.start_filts * 4 if self.dim == 2 else self.start_filts * 2
self.res_architecture = 'resnet50' # 'resnet101' , 'resnet50'
self.norm = None # one of None, 'instance_norm', 'batch_norm'
self.weight_decay = 0
# one of 'xavier_uniform', 'xavier_normal', or 'kaiming_normal', None (=default = 'kaiming_uniform')
self.weight_init = None
#########################
# Schedule / Selection #
#########################
self.num_epochs = 100
self.num_train_batches = 200 if self.dim == 2 else 200
self.batch_size = 20 if self.dim == 2 else 8
self.do_validation = True
# decide whether to validate on entire patient volumes (like testing) or sampled patches (like training)
# the former is morge accurate, while the latter is faster (depending on volume size)
self.val_mode = 'val_patient' # one of 'val_sampling' , 'val_patient'
if self.val_mode == 'val_patient':
self.max_val_patients = None # if 'None' iterates over entire val_set once.
if self.val_mode == 'val_sampling':
self.num_val_batches = 50
#########################
# Testing / Plotting #
#########################
# set the top-n-epochs to be saved for temporal averaging in testing.
self.save_n_models = 5
self.test_n_epochs = 5
# set a minimum epoch number for saving in case of instabilities in the first phase of training.
self.min_save_thresh = 0 if self.dim == 2 else 0
self.report_score_level = ['patient', 'rois'] # choose list from 'patient', 'rois'
self.class_dict = {1: 'benign', 2: 'malignant'} # 0 is background.
self.patient_class_of_interest = 2 # patient metrics are only plotted for one class.
self.ap_match_ious = [0.1] # list of ious to be evaluated for ap-scoring.
self.model_selection_criteria = ['benign_ap', 'malignant_ap'] # criteria to average over for saving epochs.
self.min_det_thresh = 0.1 # minimum confidence value to select predictions for evaluation.
# threshold for clustering predictions together (wcs = weighted cluster scoring).
# needs to be >= the expected overlap of predictions coming from one model (typically NMS threshold).
# if too high, preds of the same object are separate clusters.
self.wcs_iou = 1e-5
self.plot_prediction_histograms = True
self.plot_stat_curves = False
#########################
# Data Augmentation #
#########################
self.da_kwargs={
'do_elastic_deform': True,
'alpha':(0., 1500.),
'sigma':(30., 50.),
'do_rotation':True,
'angle_x': (0., 2 * np.pi),
'angle_y': (0., 0),
'angle_z': (0., 0),
'do_scale': True,
'scale':(0.8, 1.1),
'random_crop':False,
'rand_crop_dist': (self.patch_size[0] / 2. - 3, self.patch_size[1] / 2. - 3),
'border_mode_data': 'constant',
'border_cval_data': 0,
'order_data': 1
}
if self.dim == 3:
self.da_kwargs['do_elastic_deform'] = False
self.da_kwargs['angle_x'] = (0, 0.0)
self.da_kwargs['angle_y'] = (0, 0.0) #must be 0!!
self.da_kwargs['angle_z'] = (0., 2 * np.pi)
#########################
# Add model specifics #
#########################
{'detection_unet': self.add_det_unet_configs,
'mrcnn': self.add_mrcnn_configs,
'ufrcnn': self.add_mrcnn_configs,
'ufrcnn_surrounding': self.add_mrcnn_configs,
'retina_net': self.add_mrcnn_configs,
'retina_unet': self.add_mrcnn_configs,
'prob_detector': self.add_mrcnn_configs,
}[self.model]()
def add_det_unet_configs(self):
self.learning_rate = [1e-4] * self.num_epochs
# aggregation from pixel perdiction to object scores (connected component). One of ['max', 'median']
self.aggregation_operation = 'max'
# max number of roi candidates to identify per image (slice in 2D, volume in 3D)
self.n_roi_candidates = 3 if self.dim == 2 else 8
# loss mode: either weighted cross entropy ('wce'), batch-wise dice loss ('dice), or the sum of both ('dice_wce')
self.seg_loss_mode = 'dice_wce'
# if <1, false positive predictions in foreground are penalized less.
self.fp_dice_weight = 1 if self.dim == 2 else 1
self.wce_weights = [1, 1, 1]
self.detection_min_confidence = self.min_det_thresh
# if 'True', loss distinguishes all classes, else only foreground vs. background (class agnostic).
self.class_specific_seg_flag = True
self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
self.head_classes = self.num_seg_classes
def add_mrcnn_configs(self):
# learning rate is a list with one entry per epoch.
self.learning_rate = [1e-4] * self.num_epochs
# disable mask head loss. (e.g. if no pixelwise annotations available)
self.frcnn_mode = False
# disable the re-sampling of mask proposals to original size for speed-up.
# since evaluation is detection-driven (box-matching) and not instance segmentation-driven (iou-matching),
# mask-outputs are optional.
self.return_masks_in_val = True
self.return_masks_in_test = False
# set number of proposal boxes to plot after each epoch.
self.n_plot_rpn_props = 5 if self.dim == 2 else 30
# number of classes for head networks: n_foreground_classes + 1 (background)
self.head_classes = 3
# seg_classes hier refers to the first stage classifier (RPN)
self.num_seg_classes = 2 # foreground vs. background
# feature map strides per pyramid level are inferred from architecture.
self.backbone_strides = {'xy': [4, 8, 16, 32], 'z': [1, 2, 4, 8]}
# anchor scales are chosen according to expected object sizes in data set. Default uses only one anchor scale
# per pyramid level. (outer list are pyramid levels (corresponding to BACKBONE_STRIDES), inner list are scales per level.)
self.rpn_anchor_scales = {'xy': [[8], [16], [32], [64]], 'z': [[2], [4], [8], [16]]}
# choose which pyramid levels to extract features from: P2: 0, P3: 1, P4: 2, P5: 3.
self.pyramid_levels = [0, 1, 2, 3]
# number of feature maps in rpn. typically lowered in 3D to save gpu-memory.
self.n_rpn_features = 512 if self.dim == 2 else 128
# anchor ratios and strides per position in feature maps.
self.rpn_anchor_ratios = [0.5, 1, 2]
self.rpn_anchor_stride = 1
# Threshold for first stage (RPN) non-maximum suppression (NMS): LOWER == HARDER SELECTION
self.rpn_nms_threshold = 0.7 if self.dim == 2 else 0.7
# loss sampling settings.
self.rpn_train_anchors_per_image = 2 #per batch element
self.train_rois_per_image = 2 #per batch element
self.roi_positive_ratio = 0.5
self.anchor_matching_iou = 0.7
# factor of top-k candidates to draw from per negative sample (stochastic-hard-example-mining).
# poolsize to draw top-k candidates from will be shem_poolsize * n_negative_samples.
self.shem_poolsize = 10
self.pool_size = (7, 7) if self.dim == 2 else (7, 7, 3)
self.mask_pool_size = (14, 14) if self.dim == 2 else (14, 14, 5)
self.mask_shape = (28, 28) if self.dim == 2 else (28, 28, 10)
self.rpn_bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
self.bbox_std_dev = np.array([0.1, 0.1, 0.1, 0.2, 0.2, 0.2])
self.window = np.array([0, 0, self.patch_size[0], self.patch_size[1]])
self.scale = np.array([self.patch_size[0], self.patch_size[1], self.patch_size[0], self.patch_size[1]])
if self.dim == 2:
self.rpn_bbox_std_dev = self.rpn_bbox_std_dev[:4]
self.bbox_std_dev = self.bbox_std_dev[:4]
self.window = self.window[:4]
self.scale = self.scale[:4]
# pre-selection in proposal-layer (stage 1) for NMS-speedup. applied per batch element.
self.pre_nms_limit = 3000 if self.dim == 2 else 6000
# n_proposals to be selected after NMS per batch element. too high numbers blow up memory if "detect_while_training" is True,
# since proposals of the entire batch are forwarded through second stage in as one "batch".
self.roi_chunk_size = 800 if self.dim == 2 else 600
self.post_nms_rois_training = 500 if self.dim == 2 else 75
self.post_nms_rois_inference = 500
# Final selection of detections (refine_detections)
self.model_max_instances_per_batch_element = 10 if self.dim == 2 else 30 # per batch element and class.
self.detection_nms_threshold = 1e-5 # needs to be > 0, otherwise all predictions are one cluster.
self.model_min_confidence = 0.1
if self.dim == 2:
self.backbone_shapes = np.array(
[[int(np.ceil(self.patch_size[0] / stride)),
int(np.ceil(self.patch_size[1] / stride))]
for stride in self.backbone_strides['xy']])
else:
self.backbone_shapes = np.array(
[[int(np.ceil(self.patch_size[0] / stride)),
int(np.ceil(self.patch_size[1] / stride)),
int(np.ceil(self.patch_size[2] / stride_z))]
for stride, stride_z in zip(self.backbone_strides['xy'], self.backbone_strides['z']
)])
if self.model == 'ufrcnn':
self.operate_stride1 = True
self.class_specific_seg_flag = True
self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
self.frcnn_mode = True
if self.model == 'retina_net' or self.model == 'retina_unet' or self.model == 'prob_detector':
# implement extra anchor-scales according to retina-net publication.
self.rpn_anchor_scales['xy'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
self.rpn_anchor_scales['xy']]
self.rpn_anchor_scales['z'] = [[ii[0], ii[0] * (2 ** (1 / 3)), ii[0] * (2 ** (2 / 3))] for ii in
self.rpn_anchor_scales['z']]
self.n_anchors_per_pos = len(self.rpn_anchor_ratios) * 3
self.n_rpn_features = 256 if self.dim == 2 else 64
# pre-selection of detections for NMS-speedup. per entire batch.
self.pre_nms_limit = 10000 if self.dim == 2 else 50000
# anchor matching iou is lower than in Mask R-CNN according to https://arxiv.org/abs/1708.02002
self.anchor_matching_iou = 0.5
# if 'True', seg loss distinguishes all classes, else only foreground vs. background (class agnostic).
self.num_seg_classes = 3 if self.class_specific_seg_flag else 2
if self.model == 'retina_unet':
self.operate_stride1 = True
- self.class_specific_seg_flag = True
diff --git a/experiments/toy_exp/data_loader.py b/experiments/toy_exp/data_loader.py
index 0f8d717..3a7062c 100644
--- a/experiments/toy_exp/data_loader.py
+++ b/experiments/toy_exp/data_loader.py
@@ -1,282 +1,305 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import os
from collections import OrderedDict
import pandas as pd
import pickle
import time
import subprocess
import utils.dataloader_utils as dutils
# batch generator tools from https://github.com/MIC-DKFZ/batchgenerators
from batchgenerators.dataloading.data_loader import SlimDataLoaderBase
from batchgenerators.transforms.spatial_transforms import MirrorTransform as Mirror
from batchgenerators.transforms.abstract_transforms import Compose
from batchgenerators.dataloading.multi_threaded_augmenter import MultiThreadedAugmenter
from batchgenerators.dataloading import SingleThreadedAugmenter
from batchgenerators.transforms.spatial_transforms import SpatialTransform
from batchgenerators.transforms.crop_and_pad_transforms import CenterCropTransform
from batchgenerators.transforms.utility_transforms import ConvertSegToBoundingBoxCoordinates
def get_train_generators(cf, logger):
"""
wrapper function for creating the training batch generator pipeline. returns the train/val generators.
selects patients according to cv folds (generated by first run/fold of experiment):
splits the data into n-folds, where 1 split is used for val, 1 split for testing and the rest for training. (inner loop test set)
If cf.hold_out_test_set is True, adds the test split to the training data.
"""
all_data = load_dataset(cf, logger)
all_pids_list = np.unique([v['pid'] for (k, v) in all_data.items()])
train_pids = all_pids_list[:cf.n_train_data]
val_pids = all_pids_list[1000:1500]
train_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in train_pids)}
val_data = {k: v for (k, v) in all_data.items() if any(p == v['pid'] for p in val_pids)}
logger.info("data set loaded with: {} train / {} val patients".format(len(train_pids), len(val_pids)))
batch_gen = {}
batch_gen['train'] = create_data_gen_pipeline(train_data, cf=cf, do_aug=False)
batch_gen['val_sampling'] = create_data_gen_pipeline(val_data, cf=cf, do_aug=False)
if cf.val_mode == 'val_patient':
batch_gen['val_patient'] = PatientBatchIterator(val_data, cf=cf)
- batch_gen['n_val'] = len(val_pids) if cf.max_val_patients is None else cf.max_val_patients
+ batch_gen['n_val'] = len(val_pids) if cf.max_val_patients is None else min(len(val_pids), cf.max_val_patients)
else:
batch_gen['n_val'] = cf.num_val_batches
return batch_gen
def get_test_generator(cf, logger):
"""
wrapper function for creating the test batch generator pipeline.
selects patients according to cv folds (generated by first run/fold of experiment)
If cf.hold_out_test_set is True, gets the data from an external folder instead.
"""
if cf.hold_out_test_set:
- cf.pp_data_path = cf.pp_test_data_path
- cf.pp_name = cf.pp_test_name
+ pp_name = cf.pp_test_name
test_ix = None
else:
+ pp_name = None
with open(os.path.join(cf.exp_dir, 'fold_ids.pickle'), 'rb') as handle:
fold_list = pickle.load(handle)
_, _, test_ix, _ = fold_list[cf.fold]
# warnings.warn('WARNING: using validation set for testing!!!')
- test_data = load_dataset(cf, logger, test_ix)
- logger.info("data set loaded with: {} test patients from {}".format(len(test_data.keys()), cf.pp_data_path))
+ test_data = load_dataset(cf, logger, test_ix, pp_data_path=cf.pp_test_data_path, pp_name=pp_name)
+ logger.info("data set loaded with: {} test patients from {}".format(len(test_data.keys()), cf.pp_test_data_path))
batch_gen = {}
batch_gen['test'] = PatientBatchIterator(test_data, cf=cf)
- batch_gen['n_test'] = len(test_data.keys())
+ batch_gen['n_test'] = len(test_data.keys()) if cf.max_test_patients=="all" else \
+ min(cf.max_test_patients, len(test_data.keys()))
+
return batch_gen
-def load_dataset(cf, logger, subset_ixs=None):
+def load_dataset(cf, logger, subset_ixs=None, pp_data_path=None, pp_name=None):
"""
loads the dataset. if deployed in cloud also copies and unpacks the data to the working directory.
:param subset_ixs: subset indices to be loaded from the dataset. used e.g. for testing to only load the test folds.
:return: data: dictionary with one entry per patient (in this case per patient-breast, since they are treated as
individual images for training) each entry is a dictionary containing respective meta-info as well as paths to the preprocessed
numpy arrays to be loaded during batch-generation
"""
+ if pp_data_path is None:
+ pp_data_path = cf.pp_data_path
+ if pp_name is None:
+ pp_name = cf.pp_name
if cf.server_env:
copy_data = True
- target_dir = os.path.join('/ssd', cf.slurm_job_id, cf.pp_name)
+ target_dir = os.path.join('/ssd', cf.slurm_job_id, pp_name)
if not os.path.exists(target_dir):
- cf.data_source_dir = cf.pp_data_path
+ cf.data_source_dir = pp_data_path
os.makedirs(target_dir)
subprocess.call('rsync -av {} {}'.format(
os.path.join(cf.data_source_dir, cf.input_df_name), os.path.join(target_dir, cf.input_df_name)), shell=True)
logger.info('created target dir and info df at {}'.format(os.path.join(target_dir, cf.input_df_name)))
elif subset_ixs is None:
copy_data = False
- cf.pp_data_path = target_dir
+ pp_data_path = target_dir
- p_df = pd.read_pickle(os.path.join(cf.pp_data_path, cf.input_df_name))
+ p_df = pd.read_pickle(os.path.join(pp_data_path, cf.input_df_name))
if subset_ixs is not None:
subset_pids = [np.unique(p_df.pid.tolist())[ix] for ix in subset_ixs]
p_df = p_df[p_df.pid.isin(subset_pids)]
logger.info('subset: selected {} instances from df'.format(len(p_df)))
if cf.server_env:
if copy_data:
copy_and_unpack_data(logger, p_df.pid.tolist(), cf.fold_dir, cf.data_source_dir, target_dir)
class_targets = p_df['class_id'].tolist()
pids = p_df.pid.tolist()
- imgs = [os.path.join(cf.pp_data_path, '{}.npy'.format(pid)) for pid in pids]
- segs = [os.path.join(cf.pp_data_path,'{}.npy'.format(pid)) for pid in pids]
+ imgs = [os.path.join(pp_data_path, '{}.npy'.format(pid)) for pid in pids]
+ segs = [os.path.join(pp_data_path,'{}.npy'.format(pid)) for pid in pids]
data = OrderedDict()
for ix, pid in enumerate(pids):
data[pid] = {'data': imgs[ix], 'seg': segs[ix], 'pid': pid, 'class_target': [class_targets[ix]]}
return data
def create_data_gen_pipeline(patient_data, cf, do_aug=True):
"""
create mutli-threaded train/val/test batch generation and augmentation pipeline.
:param patient_data: dictionary containing one dictionary per patient in the train/test subset.
:param is_training: (optional) whether to perform data augmentation (training) or not (validation/testing)
:return: multithreaded_generator
"""
# create instance of batch generator as first element in pipeline.
data_gen = BatchGenerator(patient_data, batch_size=cf.batch_size, cf=cf)
# add transformations to pipeline.
my_transforms = []
if do_aug:
mirror_transform = Mirror(axes=np.arange(2, cf.dim+2, 1))
my_transforms.append(mirror_transform)
spatial_transform = SpatialTransform(patch_size=cf.patch_size[:cf.dim],
patch_center_dist_from_border=cf.da_kwargs['rand_crop_dist'],
do_elastic_deform=cf.da_kwargs['do_elastic_deform'],
alpha=cf.da_kwargs['alpha'], sigma=cf.da_kwargs['sigma'],
do_rotation=cf.da_kwargs['do_rotation'], angle_x=cf.da_kwargs['angle_x'],
angle_y=cf.da_kwargs['angle_y'], angle_z=cf.da_kwargs['angle_z'],
do_scale=cf.da_kwargs['do_scale'], scale=cf.da_kwargs['scale'],
random_crop=cf.da_kwargs['random_crop'])
my_transforms.append(spatial_transform)
else:
my_transforms.append(CenterCropTransform(crop_size=cf.patch_size[:cf.dim]))
my_transforms.append(ConvertSegToBoundingBoxCoordinates(cf.dim, get_rois_from_seg_flag=False, class_specific_seg_flag=cf.class_specific_seg_flag))
all_transforms = Compose(my_transforms)
# multithreaded_generator = SingleThreadedAugmenter(data_gen, all_transforms)
multithreaded_generator = MultiThreadedAugmenter(data_gen, all_transforms, num_processes=cf.n_workers, seeds=range(cf.n_workers))
return multithreaded_generator
class BatchGenerator(SlimDataLoaderBase):
"""
creates the training/validation batch generator. Samples n_batch_size patients (draws a slice from each patient if 2D)
from the data set while maintaining foreground-class balance. Returned patches are cropped/padded to pre_crop_size.
Actual patch_size is obtained after data augmentation.
:param data: data dictionary as provided by 'load_dataset'.
:param batch_size: number of patients to sample for the batch
:return dictionary containing the batch data (b, c, x, y, (z)) / seg (b, 1, x, y, (z)) / pids / class_target
"""
def __init__(self, data, batch_size, cf):
super(BatchGenerator, self).__init__(data, batch_size)
self.cf = cf
def generate_train_batch(self):
batch_data, batch_segs, batch_pids, batch_targets = [], [], [], []
class_targets_list = [v['class_target'] for (k, v) in self._data.items()]
#samples patients towards equilibrium of foreground classes on a roi-level (after randomly sampling the ratio "batch_sample_slack).
batch_ixs = dutils.get_class_balanced_patients(
class_targets_list, self.batch_size, self.cf.head_classes - 1, slack_factor=self.cf.batch_sample_slack)
patients = list(self._data.items())
for b in batch_ixs:
patient = patients[b][1]
all_data = np.load(patient['data'], mmap_mode='r')
data = all_data[0]
seg = all_data[1].astype('uint8')
batch_pids.append(patient['pid'])
batch_targets.append(patient['class_target'])
batch_data.append(data[np.newaxis])
batch_segs.append(seg[np.newaxis])
data = np.array(batch_data)
seg = np.array(batch_segs).astype(np.uint8)
class_target = np.array(batch_targets)
return {'data': data, 'seg': seg, 'pid': batch_pids, 'class_target': class_target}
class PatientBatchIterator(SlimDataLoaderBase):
"""
creates a test generator that iterates over entire given dataset returning 1 patient per batch.
Can be used for monitoring if cf.val_mode = 'patient_val' for a monitoring closer to actualy evaluation (done in 3D),
if willing to accept speed-loss during training.
:return: out_batch: dictionary containing one patient with batch_size = n_3D_patches in 3D or
batch_size = n_2D_patches in 2D .
"""
def __init__(self, data, cf): #threads in augmenter
super(PatientBatchIterator, self).__init__(data, 0)
self.cf = cf
self.patient_ix = 0
self.dataset_pids = [v['pid'] for (k, v) in data.items()]
self.patch_size = cf.patch_size
if len(self.patch_size) == 2:
self.patch_size = self.patch_size + [1]
def generate_train_batch(self):
-
pid = self.dataset_pids[self.patient_ix]
patient = self._data[pid]
all_data = np.load(patient['data'], mmap_mode='r')
data = all_data[0]
seg = all_data[1].astype('uint8')
batch_class_targets = np.array([patient['class_target']])
out_data = data[None, None]
out_seg = seg[None, None]
print('check patient data loader', out_data.shape, out_seg.shape)
batch_2D = {'data': out_data, 'seg': out_seg, 'class_target': batch_class_targets, 'pid': pid}
converter = ConvertSegToBoundingBoxCoordinates(dim=2, get_rois_from_seg_flag=False, class_specific_seg_flag=self.cf.class_specific_seg_flag)
batch_2D = converter(**batch_2D)
batch_2D.update({'patient_bb_target': batch_2D['bb_target'],
'patient_roi_labels': batch_2D['roi_labels'],
'original_img_shape': out_data.shape})
self.patient_ix += 1
if self.patient_ix == len(self.dataset_pids):
self.patient_ix = 0
return batch_2D
def copy_and_unpack_data(logger, pids, fold_dir, source_dir, target_dir):
start_time = time.time()
with open(os.path.join(fold_dir, 'file_list.txt'), 'w') as handle:
for pid in pids:
handle.write('{}.npy\n'.format(pid))
subprocess.call('rsync -av --files-from {} {} {}'.format(os.path.join(fold_dir, 'file_list.txt'),
source_dir, target_dir), shell=True)
# dutils.unpack_dataset(target_dir)
copied_files = os.listdir(target_dir)
logger.info("copying and unpacking data set finsihed : {} files in target dir: {}. took {} sec".format(
len(copied_files), target_dir, np.round(time.time() - start_time, 0)))
+if __name__=="__main__":
+ import utils.exp_utils as utils
+ from .configs import Configs
+
+ total_stime = time.time()
+
+
+ cf = Configs()
+ logger = utils.get_logger(0)
+ batch_gen = get_train_generators(cf, logger)
+
+ train_batch = next(batch_gen["train"])
+
+
+ mins, secs = divmod((time.time() - total_stime), 60)
+ h, mins = divmod(mins, 60)
+ t = "{:d}h:{:02d}m:{:02d}s".format(int(h), int(mins), int(secs))
+ print("{} total runtime: {}".format(os.path.split(__file__)[1], t))
\ No newline at end of file
diff --git a/models/detection_unet.py b/models/detection_unet.py
index 0e58fdd..db1025a 100644
--- a/models/detection_unet.py
+++ b/models/detection_unet.py
@@ -1,214 +1,214 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Unet-like Backbone architecture, with non-parametric heuristics for box detection on semantic segmentation outputs.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from scipy.ndimage.measurements import label as lb
import numpy as np
import utils.exp_utils as utils
import utils.model_utils as mutils
class net(nn.Module):
def __init__(self, cf, logger):
super(net, self).__init__()
self.cf = cf
self.logger = logger
backbone = utils.import_module('bbone', cf.backbone_path)
conv = mutils.NDConvGenerator(cf.dim)
# set operate_stride1=True to generate a unet-like FPN.)
self.fpn = backbone.FPN(cf, conv, operate_stride1=True).cuda()
self.conv_final = conv(cf.end_filts, cf.num_seg_classes, ks=1, pad=0, norm=cf.norm, relu=None)
if self.cf.weight_init is not None:
logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
mutils.initialize_weights(self)
else:
logger.info("using default pytorch weight init")
def forward(self, x):
"""
forward pass of network.
:param x: input image. shape (b, c, y, x, (z))
:return: seg_logits: shape (b, n_classes, y, x, (z))
:return: out_box_coords: list over n_classes. elements are arrays(b, n_rois, (y1, x1, y2, x2, (z1), (z2)))
:return: out_max_scores: list over n_classes. elements are arrays(b, n_rois)
"""
out_features = self.fpn(x)[0]
seg_logits = self.conv_final(out_features)
out_box_coords, out_max_scores = [], []
smax = F.softmax(seg_logits, dim=1).detach().cpu().data.numpy()
for cl in range(1, len(self.cf.class_dict.keys()) + 1):
max_scores = [[] for _ in range(x.shape[0])]
hard_mask = np.copy(smax).argmax(1)
hard_mask[hard_mask != cl] = 0
hard_mask[hard_mask == cl] = 1
# perform connected component analysis on argmaxed predictions,
# draw boxes around components and return coordinates.
box_coords, rois = get_coords(hard_mask, self.cf.n_roi_candidates, self.cf.dim)
# for each object, choose the highest softmax score (in the respective class)
# of all pixels in the component as object score.
for bix, broi in enumerate(rois):
for nix, nroi in enumerate(broi):
component_score = np.max(smax[bix, cl][nroi > 0]) if self.cf.aggregation_operation == 'max' \
else np.median(smax[bix, cl][nroi > 0])
max_scores[bix].append(component_score)
out_box_coords.append(box_coords)
out_max_scores.append(max_scores)
return seg_logits, out_box_coords, out_max_scores
def train_forward(self, batch, **kwargs):
"""
train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
for processing, computes losses, and stores outputs in a dictionary.
:param batch: dictionary containing 'data', 'seg', etc.
:param kwargs:
:return: results_dict: dictionary with keys:
'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
[[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
'monitor_values': dict of values to be monitored.
"""
img = batch['data']
seg = batch['seg']
var_img = torch.FloatTensor(img).cuda()
var_seg = torch.FloatTensor(seg).cuda().long()
var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(seg, self.cf.num_seg_classes)).cuda()
results_dict = {}
seg_logits, box_coords, max_scores = self.forward(var_img)
results_dict['boxes'] = [[] for _ in range(img.shape[0])]
for cix in range(len(self.cf.class_dict.keys())):
for bix in range(img.shape[0]):
for rix in range(len(max_scores[cix][bix])):
if max_scores[cix][bix][rix] > self.cf.detection_min_confidence:
results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
'box_score': max_scores[cix][bix][rix],
'box_pred_class_id': cix + 1, # add 0 for background.
'box_type': 'det'})
for bix in range(img.shape[0]):
for tix in range(len(batch['bb_target'][bix])):
results_dict['boxes'][bix].append({'box_coords': batch['bb_target'][bix][tix],
'box_label': batch['roi_labels'][bix][tix],
'box_type': 'gt'})
# compute segmentation loss as either weighted cross entropy, dice loss, or the sum of both.
loss = torch.FloatTensor([0]).cuda()
if self.cf.seg_loss_mode == 'dice' or self.cf.seg_loss_mode == 'dice_wce':
loss += 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1), var_seg_ohe,
false_positive_weight=float(self.cf.fp_dice_weight))
if self.cf.seg_loss_mode == 'wce' or self.cf.seg_loss_mode == 'dice_wce':
loss += F.cross_entropy(seg_logits, var_seg[:, 0], weight=torch.tensor(self.cf.wce_weights).float().cuda())
results_dict['seg_preds'] = np.argmax(F.softmax(seg_logits, 1).cpu().data.numpy(), 1)[:, np.newaxis]
results_dict['torch_loss'] = loss
- results_dict['monitor_extra_values'] = {'loss': loss.item()}
+ results_dict['monitor_values'] = {'loss': loss.item()}
results_dict['logger_string'] = "loss: {0:.2f}".format(loss.item())
return results_dict
def test_forward(self, batch, **kwargs):
"""
test method. wrapper around forward pass of network without usage of any ground truth information.
prepares input data for processing and stores outputs in a dictionary.
:param batch: dictionary containing 'data'
:param kwargs:
:return: results_dict: dictionary with keys:
'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
[[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, n_classes]
"""
img = batch['data']
var_img = torch.FloatTensor(img).cuda()
seg_logits, box_coords, max_scores = self.forward(var_img)
results_dict = {}
results_dict['boxes'] = [[] for _ in range(img.shape[0])]
for cix in range(len(self.cf.class_dict.keys())):
for bix in range(img.shape[0]):
for rix in range(len(max_scores[cix][bix])):
if max_scores[cix][bix][rix] > self.cf.detection_min_confidence:
results_dict['boxes'][bix].append({'box_coords': np.copy(box_coords[cix][bix][rix]),
'box_score': max_scores[cix][bix][rix],
'box_pred_class_id': cix + 1, # add 0 for background.
'box_type': 'det'})
results_dict['seg_preds'] = np.argmax(F.softmax(seg_logits, 1).cpu().data.numpy(), 1)[:, np.newaxis].astype('uint8')
return results_dict
def get_coords(binary_mask, n_components, dim):
"""
loops over batch to perform connected component analysis on binary input mask. computes box coordiantes around
n_components - biggest components (rois).
:param binary_mask: (b, y, x, (z)). binary mask for one specific foreground class.
:param n_components: int. number of components to extract per batch element and class.
:return: coords (b, n, (y1, x1, y2, x2, (z1), (z2))
:return: batch_components (b, n, (y1, x1, y2, x2, (z1), (z2))
"""
binary_mask = binary_mask.astype('uint8')
batch_coords = []
batch_components = []
for ix, b in enumerate(binary_mask):
clusters, n_cands = lb(b) # peforms connected component analysis.
uniques, counts = np.unique(clusters, return_counts=True)
# only keep n_components largest components.
keep_uniques = uniques[1:][np.argsort(counts[1:])[::-1]][:n_components]
# separate clusters and concat.
p_components = np.array([(clusters == ii) * 1 for ii in keep_uniques])
p_coords = []
if p_components.shape[0] > 0:
for roi in p_components:
mask_ixs = np.argwhere(roi != 0)
# get coordinates around component.
roi_coords = [np.min(mask_ixs[:, 0]) - 1, np.min(mask_ixs[:, 1]) - 1, np.max(mask_ixs[:, 0]) + 1,
np.max(mask_ixs[:, 1]) + 1]
if dim == 3:
roi_coords += [np.min(mask_ixs[:, 2]), np.max(mask_ixs[:, 2])+1]
p_coords.append(roi_coords)
p_coords = np.array(p_coords)
# clip coords.
p_coords[p_coords < 0] = 0
p_coords[:, :4][p_coords[:, :4] > binary_mask.shape[-2]] = binary_mask.shape[-2]
if dim == 3:
p_coords[:, 4:][p_coords[:, 4:] > binary_mask.shape[-1]] = binary_mask.shape[-1]
batch_coords.append(p_coords)
batch_components.append(p_components)
return batch_coords, batch_components
diff --git a/models/retina_unet.py b/models/retina_unet.py
index ae43d19..1eec628 100644
--- a/models/retina_unet.py
+++ b/models/retina_unet.py
@@ -1,513 +1,513 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""
Retina Net. According to https://arxiv.org/abs/1708.02002
Retina U-Net. According to https://arxiv.org/abs/1811.08661
"""
import utils.model_utils as mutils
import utils.exp_utils as utils
import sys
sys.path.append('../')
from cuda_functions.nms_2D.pth_nms import nms_gpu as nms_2D
from cuda_functions.nms_3D.pth_nms import nms_gpu as nms_3D
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
############################################################
# Network Heads
############################################################
class Classifier(nn.Module):
def __init__(self, cf, conv):
"""
Builds the classifier sub-network.
"""
super(Classifier, self).__init__()
self.dim = conv.dim
self.n_classes = cf.head_classes
n_input_channels = cf.end_filts
n_features = cf.n_rpn_features
n_output_channels = cf.n_anchors_per_pos * cf.head_classes
anchor_stride = cf.rpn_anchor_stride
self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride, pad=1, relu=None)
def forward(self, x):
"""
:param x: input feature map (b, in_c, y, x, (z))
:return: class_logits (b, n_anchors, n_classes)
"""
x = self.conv_1(x)
x = self.conv_2(x)
x = self.conv_3(x)
x = self.conv_4(x)
class_logits = self.conv_final(x)
axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
class_logits = class_logits.permute(*axes)
class_logits = class_logits.contiguous()
class_logits = class_logits.view(x.size()[0], -1, self.n_classes)
return [class_logits]
class BBRegressor(nn.Module):
def __init__(self, cf, conv):
"""
Builds the bb-regression sub-network.
"""
super(BBRegressor, self).__init__()
self.dim = conv.dim
n_input_channels = cf.end_filts
n_features = cf.n_rpn_features
n_output_channels = cf.n_anchors_per_pos * self.dim * 2
anchor_stride = cf.rpn_anchor_stride
self.conv_1 = conv(n_input_channels, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_2 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_3 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_4 = conv(n_features, n_features, ks=3, stride=anchor_stride, pad=1, relu=cf.relu)
self.conv_final = conv(n_features, n_output_channels, ks=3, stride=anchor_stride,
pad=1, relu=None)
def forward(self, x):
"""
:param x: input feature map (b, in_c, y, x, (z))
:return: bb_logits (b, n_anchors, dim * 2)
"""
x = self.conv_1(x)
x = self.conv_2(x)
x = self.conv_3(x)
x = self.conv_4(x)
bb_logits = self.conv_final(x)
axes = (0, 2, 3, 1) if self.dim == 2 else (0, 2, 3, 4, 1)
bb_logits = bb_logits.permute(*axes)
bb_logits = bb_logits.contiguous()
bb_logits = bb_logits.view(x.size()[0], -1, self.dim * 2)
return [bb_logits]
############################################################
# Loss Functions
############################################################
def compute_class_loss(anchor_matches, class_pred_logits, shem_poolsize=20):
"""
:param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors.
:param class_pred_logits: (n_anchors, n_classes). logits from classifier sub-network.
:param shem_poolsize: int. factor of top-k candidates to draw from per negative sample (online-hard-example-mining).
:return: loss: torch tensor.
:return: np_neg_ix: 1D array containing indices of the neg_roi_logits, which have been sampled for training.
"""
# Positive and Negative anchors contribute to the loss,
# but neutral anchors (match value = 0) don't.
pos_indices = torch.nonzero(anchor_matches > 0)
neg_indices = torch.nonzero(anchor_matches == -1)
# get positive samples and calucalte loss.
if 0 not in pos_indices.size():
pos_indices = pos_indices.squeeze(1)
roi_logits_pos = class_pred_logits[pos_indices]
targets_pos = anchor_matches[pos_indices]
pos_loss = F.cross_entropy(roi_logits_pos, targets_pos.long())
else:
pos_loss = torch.FloatTensor([0]).cuda()
# get negative samples, such that the amount matches the number of positive samples, but at least 1.
# get high scoring negatives by applying online-hard-example-mining.
if 0 not in neg_indices.size():
neg_indices = neg_indices.squeeze(1)
roi_logits_neg = class_pred_logits[neg_indices]
negative_count = np.max((1, pos_indices.size()[0]))
roi_probs_neg = F.softmax(roi_logits_neg, dim=1)
neg_ix = mutils.shem(roi_probs_neg, negative_count, shem_poolsize)
neg_loss = F.cross_entropy(roi_logits_neg[neg_ix], torch.LongTensor([0] * neg_ix.shape[0]).cuda())
# return the indices of negative samples, which contributed to the loss (for monitoring plots).
np_neg_ix = neg_ix.cpu().data.numpy()
else:
neg_loss = torch.FloatTensor([0]).cuda()
np_neg_ix = np.array([]).astype('int32')
loss = (pos_loss + neg_loss) / 2
return loss, np_neg_ix
def compute_bbox_loss(target_deltas, pred_deltas, anchor_matches):
"""
:param target_deltas: (b, n_positive_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd)))).
Uses 0 padding to fill in unsed bbox deltas.
:param pred_deltas: predicted deltas from bbox regression head. (b, n_anchors, (dy, dx, (dz), log(dh), log(dw), (log(dd))))
:param anchor_matches: (n_anchors). [-1, 0, class_id] for negative, neutral, and positive matched anchors.
:return: loss: torch 1D tensor.
"""
if 0 not in torch.nonzero(anchor_matches > 0).size():
indices = torch.nonzero(anchor_matches > 0).squeeze(1)
# Pick bbox deltas that contribute to the loss
pred_deltas = pred_deltas[indices]
# Trim target bounding box deltas to the same length as pred_deltas.
target_deltas = target_deltas[:pred_deltas.size()[0], :]
# Smooth L1 loss
loss = F.smooth_l1_loss(pred_deltas, target_deltas)
else:
loss = torch.FloatTensor([0]).cuda()
return loss
############################################################
# Output Handler
############################################################
def refine_detections(anchors, probs, deltas, batch_ixs, cf):
"""
Refine classified proposals, filter overlaps and return final
detections. n_proposals here is typically a very large number: batch_size * n_anchors.
This function is hence optimized on trimming down n_proposals.
:param anchors: (n_anchors, 2 * dim)
:param probs: (n_proposals, n_classes) softmax probabilities for all rois as predicted by classifier head.
:param deltas: (n_proposals, n_classes, 2 * dim) box refinement deltas as predicted by bbox regressor head.
:param batch_ixs: (n_proposals) batch element assignemnt info for re-allocation.
:return: result: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score))
"""
anchors = anchors.repeat(len(np.unique(batch_ixs)), 1)
# flatten foreground probabilities, sort and trim down to highest confidences by pre_nms limit.
fg_probs = probs[:, 1:].contiguous()
flat_probs, flat_probs_order = fg_probs.view(-1).sort(descending=True)
keep_ix = flat_probs_order[:cf.pre_nms_limit]
# reshape indices to 2D index array with shape like fg_probs.
keep_arr = torch.cat(((keep_ix / fg_probs.shape[1]).unsqueeze(1), (keep_ix % fg_probs.shape[1]).unsqueeze(1)), 1)
pre_nms_scores = flat_probs[:cf.pre_nms_limit]
pre_nms_class_ids = keep_arr[:, 1] + 1 # add background again.
pre_nms_batch_ixs = batch_ixs[keep_arr[:, 0]]
pre_nms_anchors = anchors[keep_arr[:, 0]]
pre_nms_deltas = deltas[keep_arr[:, 0]]
keep = torch.arange(pre_nms_scores.size()[0]).long().cuda()
# apply bounding box deltas. re-scale to image coordinates.
std_dev = torch.from_numpy(np.reshape(cf.rpn_bbox_std_dev, [1, cf.dim * 2])).float().cuda()
scale = torch.from_numpy(cf.scale).float().cuda()
refined_rois = mutils.apply_box_deltas_2D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale \
if cf.dim == 2 else mutils.apply_box_deltas_3D(pre_nms_anchors / scale, pre_nms_deltas * std_dev) * scale
# round and cast to int since we're deadling with pixels now
refined_rois = mutils.clip_to_window(cf.window, refined_rois)
pre_nms_rois = torch.round(refined_rois)
for j, b in enumerate(mutils.unique1d(pre_nms_batch_ixs)):
bixs = torch.nonzero(pre_nms_batch_ixs == b)[:, 0]
bix_class_ids = pre_nms_class_ids[bixs]
bix_rois = pre_nms_rois[bixs]
bix_scores = pre_nms_scores[bixs]
for i, class_id in enumerate(mutils.unique1d(bix_class_ids)):
ixs = torch.nonzero(bix_class_ids == class_id)[:, 0]
# nms expects boxes sorted by score.
ix_rois = bix_rois[ixs]
ix_scores = bix_scores[ixs]
ix_scores, order = ix_scores.sort(descending=True)
ix_rois = ix_rois[order, :]
ix_scores = ix_scores
if cf.dim == 2:
class_keep = nms_2D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold)
else:
class_keep = nms_3D(torch.cat((ix_rois, ix_scores.unsqueeze(1)), dim=1), cf.detection_nms_threshold)
# map indices back.
class_keep = keep[bixs[ixs[order[class_keep]]]]
# merge indices over classes for current batch element
b_keep = class_keep if i == 0 else mutils.unique1d(torch.cat((b_keep, class_keep)))
# only keep top-k boxes of current batch-element.
top_ids = pre_nms_scores[b_keep].sort(descending=True)[1][:cf.model_max_instances_per_batch_element]
b_keep = b_keep[top_ids]
# merge indices over batch elements.
batch_keep = b_keep if j == 0 else mutils.unique1d(torch.cat((batch_keep, b_keep)))
keep = batch_keep
# arrange output.
result = torch.cat((pre_nms_rois[keep],
pre_nms_batch_ixs[keep].unsqueeze(1).float(),
pre_nms_class_ids[keep].unsqueeze(1).float(),
pre_nms_scores[keep].unsqueeze(1)), dim=1)
return result
def get_results(cf, img_shape, detections, seg_logits, box_results_list=None):
"""
Restores batch dimension of merged detections, unmolds detections, creates and fills results dict.
:param img_shape:
:param detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
:param box_results_list: None or list of output boxes for monitoring/plotting.
each element is a list of boxes per batch element.
:return: results_dict: dictionary with keys:
'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
[[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, ..., n_classes] for
retina_unet and dummy array for retina_net.
"""
detections = detections.cpu().data.numpy()
batch_ixs = detections[:, cf.dim*2]
detections = [detections[batch_ixs == ix] for ix in range(img_shape[0])]
# for test_forward, where no previous list exists.
if box_results_list is None:
box_results_list = [[] for _ in range(img_shape[0])]
for ix in range(img_shape[0]):
if 0 not in detections[ix].shape:
boxes = detections[ix][:, :2 * cf.dim].astype(np.int32)
class_ids = detections[ix][:, 2 * cf.dim + 1].astype(np.int32)
scores = detections[ix][:, 2 * cf.dim + 2]
# Filter out detections with zero area. Often only happens in early
# stages of training when the network weights are still a bit random.
if cf.dim == 2:
exclude_ix = np.where((boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) <= 0)[0]
else:
exclude_ix = np.where(
(boxes[:, 2] - boxes[:, 0]) * (boxes[:, 3] - boxes[:, 1]) * (boxes[:, 5] - boxes[:, 4]) <= 0)[0]
if exclude_ix.shape[0] > 0:
boxes = np.delete(boxes, exclude_ix, axis=0)
class_ids = np.delete(class_ids, exclude_ix, axis=0)
scores = np.delete(scores, exclude_ix, axis=0)
if 0 not in boxes.shape:
for ix2, score in enumerate(scores):
if score >= cf.model_min_confidence:
box_results_list[ix].append({'box_coords': boxes[ix2],
'box_score': score,
'box_type': 'det',
'box_pred_class_id': class_ids[ix2]})
results_dict = {'boxes': box_results_list}
if seg_logits is None:
# output dummy segmentation for retina_net.
results_dict['seg_preds'] = np.zeros(img_shape)[:, 0][:, np.newaxis]
else:
# output label maps for retina_unet.
results_dict['seg_preds'] = F.softmax(seg_logits, 1).argmax(1).cpu().data.numpy()[:, np.newaxis].astype('uint8')
return results_dict
############################################################
# Retina (U-)Net Class
############################################################
class net(nn.Module):
def __init__(self, cf, logger):
super(net, self).__init__()
self.cf = cf
self.logger = logger
self.build()
if self.cf.weight_init is not None:
logger.info("using pytorch weight init of type {}".format(self.cf.weight_init))
mutils.initialize_weights(self)
else:
logger.info("using default pytorch weight init")
def build(self):
"""
Build Retina Net architecture.
"""
# Image size must be dividable by 2 multiple times.
h, w = self.cf.patch_size[:2]
if h / 2 ** 5 != int(h / 2 ** 5) or w / 2 ** 5 != int(w / 2 ** 5):
raise Exception("Image size must be dividable by 2 at least 5 times "
"to avoid fractions when downscaling and upscaling."
"For example, use 256, 320, 384, 448, 512, ... etc. ")
# instanciate abstract multi dimensional conv class and backbone model.
conv = mutils.NDConvGenerator(self.cf.dim)
backbone = utils.import_module('bbone', self.cf.backbone_path)
# build Anchors, FPN, Classifier / Bbox-Regressor -head
self.np_anchors = mutils.generate_pyramid_anchors(self.logger, self.cf)
self.anchors = torch.from_numpy(self.np_anchors).float().cuda()
self.Fpn = backbone.FPN(self.cf, conv, operate_stride1=self.cf.operate_stride1)
self.Classifier = Classifier(self.cf, conv)
self.BBRegressor = BBRegressor(self.cf, conv)
- self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=False, relu=None)
+ self.final_conv = conv(self.cf.end_filts, self.cf.num_seg_classes, ks=1, pad=0, norm=None, relu=None)
def train_forward(self, batch, **kwargs):
"""
train method (also used for validation monitoring). wrapper around forward pass of network. prepares input data
for processing, computes losses, and stores outputs in a dictionary.
:param batch: dictionary containing 'data', 'seg', etc.
:return: results_dict: dictionary with keys:
'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
[[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
'seg_preds': pixelwise segmentation output (b, c, y, x, (z)) with values [0, .., n_classes].
'monitor_values': dict of values to be monitored.
"""
img = batch['data']
gt_class_ids = batch['roi_labels']
gt_boxes = batch['bb_target']
var_seg_ohe = torch.FloatTensor(mutils.get_one_hot_encoding(batch['seg'], self.cf.num_seg_classes)).cuda()
var_seg = torch.LongTensor(batch['seg']).cuda()
img = torch.from_numpy(img).float().cuda()
batch_class_loss = torch.FloatTensor([0]).cuda()
batch_bbox_loss = torch.FloatTensor([0]).cuda()
# list of output boxes for monitoring/plotting. each element is a list of boxes per batch element.
box_results_list = [[] for _ in range(img.shape[0])]
detections, class_logits, pred_deltas, seg_logits = self.forward(img)
# loop over batch
for b in range(img.shape[0]):
# add gt boxes to results dict for monitoring.
if len(gt_boxes[b]) > 0:
for ix in range(len(gt_boxes[b])):
box_results_list[b].append({'box_coords': batch['bb_target'][b][ix],
'box_label': batch['roi_labels'][b][ix], 'box_type': 'gt'})
# match gt boxes with anchors to generate targets.
anchor_class_match, anchor_target_deltas = mutils.gt_anchor_matching(
self.cf, self.np_anchors, gt_boxes[b], gt_class_ids[b])
# add positive anchors used for loss to results_dict for monitoring.
pos_anchors = mutils.clip_boxes_numpy(
self.np_anchors[np.argwhere(anchor_class_match > 0)][:, 0], img.shape[2:])
for p in pos_anchors:
box_results_list[b].append({'box_coords': p, 'box_type': 'pos_anchor'})
else:
anchor_class_match = np.array([-1]*self.np_anchors.shape[0])
anchor_target_deltas = np.array([0])
anchor_class_match = torch.from_numpy(anchor_class_match).cuda()
anchor_target_deltas = torch.from_numpy(anchor_target_deltas).float().cuda()
# compute losses.
class_loss, neg_anchor_ix = compute_class_loss(anchor_class_match, class_logits[b])
bbox_loss = compute_bbox_loss(anchor_target_deltas, pred_deltas[b], anchor_class_match)
# add negative anchors used for loss to results_dict for monitoring.
neg_anchors = mutils.clip_boxes_numpy(
self.np_anchors[np.argwhere(anchor_class_match == -1)][0, neg_anchor_ix], img.shape[2:])
for n in neg_anchors:
box_results_list[b].append({'box_coords': n, 'box_type': 'neg_anchor'})
batch_class_loss += class_loss / img.shape[0]
batch_bbox_loss += bbox_loss / img.shape[0]
results_dict = get_results(self.cf, img.shape, detections, seg_logits, box_results_list)
seg_loss_dice = 1 - mutils.batch_dice(F.softmax(seg_logits, dim=1),var_seg_ohe)
seg_loss_ce = F.cross_entropy(seg_logits, var_seg[:, 0])
loss = batch_class_loss + batch_bbox_loss + (seg_loss_dice + seg_loss_ce) / 2
results_dict['torch_loss'] = loss
results_dict['monitor_values'] = {'loss': loss.item(), 'class_loss': batch_class_loss.item()}
results_dict['logger_string'] = \
"loss: {0:.2f}, class: {1:.2f}, bbox: {2:.2f}, seg dice: {3:.3f}, seg ce: {4:.3f}, mean pix. pr.: {5:.5f}"\
.format(loss.item(), batch_class_loss.item(), batch_bbox_loss.item(), seg_loss_dice.item(),
seg_loss_ce.item(), np.mean(results_dict['seg_preds']))
return results_dict
def test_forward(self, batch, **kwargs):
"""
test method. wrapper around forward pass of network without usage of any ground truth information.
prepares input data for processing and stores outputs in a dictionary.
:param batch: dictionary containing 'data'
:return: results_dict: dictionary with keys:
'boxes': list over batch elements. each batch element is a list of boxes. each box is a dictionary:
[[{box_0}, ... {box_n}], [{box_0}, ... {box_n}], ...]
'seg_preds': pixel-wise class predictions (b, 1, y, x, (z)) with values [0, ..., n_classes] for
retina_unet and dummy array for retina_net.
"""
img = batch['data']
img = torch.from_numpy(img).float().cuda()
detections, _, _, seg_logits = self.forward(img)
results_dict = get_results(self.cf, img.shape, detections, seg_logits)
return results_dict
def forward(self, img):
"""
forward pass of the model.
:param img: input img (b, c, y, x, (z)).
:return: rpn_pred_logits: (b, n_anchors, 2)
:return: rpn_pred_deltas: (b, n_anchors, (y, x, (z), log(h), log(w), (log(d))))
:return: batch_proposal_boxes: (b, n_proposals, (y1, x1, y2, x2, (z1), (z2), batch_ix)) only for monitoring/plotting.
:return: detections: (n_final_detections, (y1, x1, y2, x2, (z1), (z2), batch_ix, pred_class_id, pred_score)
:return: detection_masks: (n_final_detections, n_classes, y, x, (z)) raw molded masks as returned by mask-head.
"""
# Feature extraction
fpn_outs = self.Fpn(img)
seg_logits = self.final_conv(fpn_outs[0])
selected_fmaps = [fpn_outs[i + 1] for i in self.cf.pyramid_levels]
# Loop through pyramid layers
class_layer_outputs, bb_reg_layer_outputs = [], [] # list of lists
for p in selected_fmaps:
class_layer_outputs.append(self.Classifier(p))
bb_reg_layer_outputs.append(self.BBRegressor(p))
# Concatenate layer outputs
# Convert from list of lists of level outputs to list of lists
# of outputs across levels.
# e.g. [[a1, b1, c1], [a2, b2, c2]] => [[a1, a2], [b1, b2], [c1, c2]]
class_logits = list(zip(*class_layer_outputs))
class_logits = [torch.cat(list(o), dim=1) for o in class_logits][0]
bb_outputs = list(zip(*bb_reg_layer_outputs))
bb_outputs = [torch.cat(list(o), dim=1) for o in bb_outputs][0]
# merge batch_dimension and store info in batch_ixs for re-allocation.
batch_ixs = torch.arange(class_logits.shape[0]).unsqueeze(1).repeat(1, class_logits.shape[1]).view(-1).cuda()
flat_class_softmax = F.softmax(class_logits.view(-1, class_logits.shape[-1]), 1)
flat_bb_outputs = bb_outputs.view(-1, bb_outputs.shape[-1])
detections = refine_detections(self.anchors, flat_class_softmax, flat_bb_outputs, batch_ixs, self.cf)
return detections, class_logits, bb_outputs, seg_logits
diff --git a/utils/__pycache__/dataloader_utils.cpython-35.pyc b/utils/__pycache__/dataloader_utils.cpython-35.pyc
deleted file mode 100644
index 7f3ab5b..0000000
Binary files a/utils/__pycache__/dataloader_utils.cpython-35.pyc and /dev/null differ
diff --git a/utils/__pycache__/dataloader_utils.cpython-36.pyc b/utils/__pycache__/dataloader_utils.cpython-36.pyc
deleted file mode 100644
index 8d657a5..0000000
Binary files a/utils/__pycache__/dataloader_utils.cpython-36.pyc and /dev/null differ
diff --git a/utils/__pycache__/exp_utils.cpython-35.pyc b/utils/__pycache__/exp_utils.cpython-35.pyc
deleted file mode 100644
index e1d8a6c..0000000
Binary files a/utils/__pycache__/exp_utils.cpython-35.pyc and /dev/null differ
diff --git a/utils/__pycache__/exp_utils.cpython-36.pyc b/utils/__pycache__/exp_utils.cpython-36.pyc
deleted file mode 100644
index a3a7f35..0000000
Binary files a/utils/__pycache__/exp_utils.cpython-36.pyc and /dev/null differ
diff --git a/utils/__pycache__/model_utils.cpython-35.pyc b/utils/__pycache__/model_utils.cpython-35.pyc
deleted file mode 100644
index 661a944..0000000
Binary files a/utils/__pycache__/model_utils.cpython-35.pyc and /dev/null differ
diff --git a/utils/__pycache__/model_utils.cpython-36.pyc b/utils/__pycache__/model_utils.cpython-36.pyc
deleted file mode 100644
index 5d5d3f1..0000000
Binary files a/utils/__pycache__/model_utils.cpython-36.pyc and /dev/null differ
diff --git a/utils/__pycache__/mrcnn_utils.cpython-36.pyc b/utils/__pycache__/mrcnn_utils.cpython-36.pyc
deleted file mode 100644
index 479538d..0000000
Binary files a/utils/__pycache__/mrcnn_utils.cpython-36.pyc and /dev/null differ
diff --git a/utils/exp_utils.py b/utils/exp_utils.py
index 51bbc05..3c7bf41 100644
--- a/utils/exp_utils.py
+++ b/utils/exp_utils.py
@@ -1,349 +1,346 @@
#!/usr/bin/env python
# Copyright 2018 Division of Medical Image Computing, German Cancer Research Center (DKFZ).
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
import numpy as np
import logging
import subprocess
import os
import torch
from collections import OrderedDict
import plotting
import sys
import importlib.util
import pandas as pd
import pickle
def get_logger(exp_dir):
"""
creates logger instance. writing out info to file and to terminal.
:param exp_dir: experiment directory, where exec.log file is stored.
:return: logger instance.
"""
logger = logging.getLogger('medicaldetectiontoolkit')
logger.setLevel(logging.DEBUG)
log_file = exp_dir + '/exec.log'
hdlr = logging.FileHandler(log_file)
print('Logging to {}'.format(log_file))
logger.addHandler(hdlr)
logger.addHandler(ColorHandler())
logger.propagate = False
return logger
def prep_exp(dataset_path, exp_path, server_env, use_stored_settings=True, is_training=True):
"""
I/O handling, creating of experiment folder structure. Also creates a snapshot of configs/model scripts and copies them to the exp_dir.
This way the exp_dir contains all info needed to conduct an experiment, independent to changes in actual source code. Thus, training/inference of this experiment can be started at anytime. Therefore, the model script is copied back to the source code dir as tmp_model (tmp_backbone).
Provides robust structure for cloud deployment.
:param dataset_path: path to source code for specific data set. (e.g. medicaldetectiontoolkit/lidc_exp)
:param exp_path: path to experiment directory.
:param server_env: boolean flag. pass to configs script for cloud deployment.
:param use_stored_settings: boolean flag. When starting training: If True, starts training from snapshot in existing experiment directory, else creates experiment directory on the fly using configs/model scripts from source code.
:param is_training: boolean flag. distinguishes train vs. inference mode.
:return:
"""
if is_training:
# the first process of an experiment creates the directories and copies the config to exp_path.
if not os.path.exists(exp_path):
os.mkdir(exp_path)
os.mkdir(os.path.join(exp_path, 'plots'))
subprocess.call('cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), shell=True)
subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True)
if use_stored_settings:
subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True)
cf_file = import_module('cf', os.path.join(exp_path, 'configs.py'))
cf = cf_file.configs(server_env)
# only the first process copies the model selcted in configs to exp_path.
if not os.path.isfile(os.path.join(exp_path, 'model.py')):
subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True)
subprocess.call('cp {} {}'.format(os.path.join(cf.backbone_path), os.path.join(exp_path, 'backbone.py')), shell=True)
# copy the snapshot model scripts from exp_dir back to the source_dir as tmp_model / tmp_backbone.
tmp_model_path = os.path.join(cf.source_dir, 'models', 'tmp_model.py')
tmp_backbone_path = os.path.join(cf.source_dir, 'models', 'tmp_backbone.py')
subprocess.call('cp {} {}'.format(os.path.join(exp_path, 'model.py'), tmp_model_path), shell=True)
subprocess.call('cp {} {}'.format(os.path.join(exp_path, 'backbone.py'), tmp_backbone_path), shell=True)
cf.model_path = tmp_model_path
cf.backbone_path = tmp_backbone_path
else:
# run training with source code info and copy snapshot of model to exp_dir for later testing (overwrite scripts if exp_dir already exists.)
cf_file = import_module('cf', os.path.join(dataset_path, 'configs.py'))
cf = cf_file.configs(server_env)
subprocess.call('cp {} {}'.format(cf.model_path, os.path.join(exp_path, 'model.py')), shell=True)
subprocess.call('cp {} {}'.format(cf.backbone_path, os.path.join(exp_path, 'backbone.py')), shell=True)
subprocess.call('cp {} {}'.format('default_configs.py', os.path.join(exp_path, 'default_configs.py')), shell=True)
subprocess.call('cp {} {}'.format(os.path.join(dataset_path, 'configs.py'), os.path.join(exp_path, 'configs.py')), shell=True)
else:
- # for testing copy the snapshot model scripts from exp_dir back to the source_dir as tmp_model / tmp_backbone.
+ # for testing, copy the snapshot model scripts from exp_dir back to the source_dir as tmp_model / tmp_backbone.
cf_file = import_module('cf', os.path.join(exp_path, 'configs.py'))
cf = cf_file.configs(server_env)
- if cf.hold_out_test_set:
- cf.pp_data_path = cf.pp_test_data_path
- cf.pp_name = cf.pp_test_name
tmp_model_path = os.path.join(cf.source_dir, 'models', 'tmp_model.py')
tmp_backbone_path = os.path.join(cf.source_dir, 'models', 'tmp_backbone.py')
subprocess.call('cp {} {}'.format(os.path.join(exp_path, 'model.py'), tmp_model_path), shell=True)
subprocess.call('cp {} {}'.format(os.path.join(exp_path, 'backbone.py'), tmp_backbone_path), shell=True)
cf.model_path = tmp_model_path
cf.backbone_path = tmp_backbone_path
cf.exp_dir = exp_path
cf.test_dir = os.path.join(cf.exp_dir, 'test')
cf.plot_dir = os.path.join(cf.exp_dir, 'plots')
cf.experiment_name = exp_path.split("/")[-1]
cf.server_env = server_env
cf.created_fold_id_pickle = False
return cf
def import_module(name, path):
"""
correct way of importing a module dynamically in python 3.
:param name: name given to module instance.
:param path: path to module.
:return: module: returned module instance.
"""
spec = importlib.util.spec_from_file_location(name, path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
return module
class ModelSelector:
'''
saves a checkpoint after each epoch as 'last_state' (can be loaded to continue interrupted training).
saves the top-k (k=cf.save_n_models) ranked epochs. In inference, predictions of multiple epochs can be ensembled to improve performance.
'''
def __init__(self, cf, logger):
self.cf = cf
self.saved_epochs = [-1] * cf.save_n_models
self.logger = logger
def run_model_selection(self, net, optimizer, monitor_metrics, epoch):
# take the mean over all selection criteria in each epoch
non_nan_scores = np.mean(np.array([[0 if ii is None else ii for ii in monitor_metrics['val'][sc]] for sc in self.cf.model_selection_criteria]), 0)
epochs_scores = [ii for ii in non_nan_scores[1:]]
# ranking of epochs according to model_selection_criterion
epoch_ranking = np.argsort(epochs_scores)[::-1] + 1 #epochs start at 1
# if set in configs, epochs < min_save_thresh are discarded from saving process.
epoch_ranking = epoch_ranking[epoch_ranking >= self.cf.min_save_thresh]
# check if current epoch is among the top-k epchs.
if epoch in epoch_ranking[:self.cf.save_n_models]:
save_dir = os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(epoch))
if not os.path.exists(save_dir):
os.mkdir(save_dir)
torch.save(net.state_dict(), os.path.join(save_dir, 'params.pth'))
with open(os.path.join(save_dir, 'monitor_metrics.pickle'), 'wb') as handle:
pickle.dump(monitor_metrics, handle)
# save epoch_ranking to keep info for inference.
np.save(os.path.join(self.cf.fold_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
np.save(os.path.join(save_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
self.logger.info(
"saving current epoch {} at rank {}".format(epoch, np.argwhere(epoch_ranking == epoch)))
# delete params of the epoch that just fell out of the top-k epochs.
for se in [int(ii.split('_')[0]) for ii in os.listdir(self.cf.fold_dir) if 'best_checkpoint' in ii]:
if se in epoch_ranking[self.cf.save_n_models:]:
subprocess.call('rm -rf {}'.format(os.path.join(self.cf.fold_dir, '{}_best_checkpoint'.format(se))), shell=True)
self.logger.info('deleting epoch {} at rank {}'.format(se, np.argwhere(epoch_ranking == se)))
state = {
'epoch': epoch,
'state_dict': net.state_dict(),
'optimizer': optimizer.state_dict(),
}
# save checkpoint of current epoch.
save_dir = os.path.join(self.cf.fold_dir, 'last_checkpoint'.format(epoch))
if not os.path.exists(save_dir):
os.mkdir(save_dir)
torch.save(state, os.path.join(save_dir, 'params.pth'))
np.save(os.path.join(save_dir, 'epoch_ranking'), epoch_ranking[:self.cf.save_n_models])
with open(os.path.join(save_dir, 'monitor_metrics.pickle'), 'wb') as handle:
pickle.dump(monitor_metrics, handle)
def load_checkpoint(checkpoint_path, net, optimizer):
checkpoint_params = torch.load(os.path.join(checkpoint_path, 'params.pth'))
net.load_state_dict(checkpoint_params['state_dict'])
optimizer.load_state_dict(checkpoint_params['optimizer'])
with open(os.path.join(checkpoint_path, 'monitor_metrics.pickle'), 'rb') as handle:
monitor_metrics = pickle.load(handle)
starting_epoch = checkpoint_params['epoch'] + 1
return starting_epoch, monitor_metrics
def prepare_monitoring(cf):
"""
creates dictionaries, where train/val metrics are stored.
"""
metrics = {}
# first entry for loss dict accounts for epoch starting at 1.
metrics['train'] = OrderedDict()
metrics['val'] = OrderedDict()
metric_classes = []
if 'rois' in cf.report_score_level:
metric_classes.extend([v for k, v in cf.class_dict.items()])
if 'patient' in cf.report_score_level:
metric_classes.extend(['patient'])
for cl in metric_classes:
metrics['train'][cl + '_ap'] = [None]
metrics['val'][cl + '_ap'] = [None]
if cl == 'patient':
metrics['train'][cl + '_auc'] = [None]
metrics['val'][cl + '_auc'] = [None]
metrics['train']['monitor_values'] = [[] for _ in range(cf.num_epochs + 1)]
metrics['val']['monitor_values'] = [[] for _ in range(cf.num_epochs + 1)]
# generate isntance of monitor plot class.
TrainingPlot = plotting.TrainingPlot_2Panel(cf)
return metrics, TrainingPlot
def create_csv_output(results_list, cf, logger):
"""
Write out test set predictions to .csv file. output format is one line per prediction:
PatientID | PredictionID | [y1 x1 y2 x2 (z1) (z2)] | score | pred_classID
Note, that prediction coordinates correspond to images as loaded for training/testing and need to be adapted when
plotted over raw data (before preprocessing/resampling).
:param results_list: [[patient_results, patient_id], [patient_results, patient_id], ...]
"""
logger.info('creating csv output file at {}'.format(os.path.join(cf.exp_dir, 'results.csv')))
predictions_df = pd.DataFrame(columns = ['patientID', 'predictionID', 'coords', 'score', 'pred_classID'])
for r in results_list:
pid = r[1]
#optionally load resampling info from preprocessing to match output predictions with raw data.
#with open(os.path.join(cf.exp_dir, 'test_resampling_info', pid), 'rb') as handle:
# resampling_info = pickle.load(handle)
for bix, box in enumerate(r[0][0]):
assert box['box_type'] == 'det', box['box_type']
coords = box['box_coords']
score = box['box_score']
pred_class_id = box['box_pred_class_id']
out_coords = []
if score >= cf.min_det_thresh:
out_coords.append(coords[0]) #* resampling_info['scale'][0])
out_coords.append(coords[1]) #* resampling_info['scale'][1])
out_coords.append(coords[2]) #* resampling_info['scale'][0])
out_coords.append(coords[3]) #* resampling_info['scale'][1])
if len(coords) > 4:
out_coords.append(coords[4]) #* resampling_info['scale'][2] + resampling_info['z_crop'])
out_coords.append(coords[5]) #* resampling_info['scale'][2] + resampling_info['z_crop'])
predictions_df.loc[len(predictions_df)] = [pid, bix, out_coords, score, pred_class_id]
try:
fold = cf.fold
except:
fold = 'hold_out'
predictions_df.to_csv(os.path.join(cf.exp_dir, 'results_{}.csv'.format(fold)), index=False)
class _AnsiColorizer(object):
"""
A colorizer is an object that loosely wraps around a stream, allowing
callers to write text to the stream in a particular color.
Colorizer classes must implement C{supported()} and C{write(text, color)}.
"""
_colors = dict(black=30, red=31, green=32, yellow=33,
blue=34, magenta=35, cyan=36, white=37, default=39)
def __init__(self, stream):
self.stream = stream
@classmethod
def supported(cls, stream=sys.stdout):
"""
A class method that returns True if the current platform supports
coloring terminal output using this method. Returns False otherwise.
"""
if not stream.isatty():
return False # auto color only on TTYs
try:
import curses
except ImportError:
return False
else:
try:
try:
return curses.tigetnum("colors") > 2
except curses.error:
curses.setupterm()
return curses.tigetnum("colors") > 2
except:
raise
# guess false in case of error
return False
def write(self, text, color):
"""
Write the given text to the stream in the given color.
@param text: Text to be written to the stream.
@param color: A string label for a color. e.g. 'red', 'white'.
"""
color = self._colors[color]
self.stream.write('\x1b[%sm%s\x1b[0m' % (color, text))
class ColorHandler(logging.StreamHandler):
def __init__(self, stream=sys.stdout):
super(ColorHandler, self).__init__(_AnsiColorizer(stream))
def emit(self, record):
msg_colors = {
logging.DEBUG: "green",
logging.INFO: "default",
logging.WARNING: "red",
logging.ERROR: "red"
}
color = msg_colors.get(record.levelno, "blue")
self.stream.write(record.msg + "\n", color)