diff --git a/.ipynb_checkpoints/gif_writer-checkpoint.ipynb b/.ipynb_checkpoints/gif_writer-checkpoint.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..45e4eb0ec91de3013b1ea1b020564012bb9fd7b8 --- /dev/null +++ b/.ipynb_checkpoints/gif_writer-checkpoint.ipynb @@ -0,0 +1,169 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [ + { + "ename": "ModuleNotFoundError", + "evalue": "No module named 'tqdm'", + "output_type": "error", + "traceback": [ + "\u001b[1;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[1;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m\u001b[0m in \u001b[0;36m\u001b[1;34m()\u001b[0m\n\u001b[0;32m 3\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mcv2\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0;32m 4\u001b[0m \u001b[1;32mfrom\u001b[0m \u001b[0mdatetime\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mdatetime\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[1;32m----> 5\u001b[1;33m \u001b[1;32mfrom\u001b[0m \u001b[0mtqdm\u001b[0m \u001b[1;32mimport\u001b[0m \u001b[0mtqdm_notebook\u001b[0m \u001b[1;32mas\u001b[0m \u001b[0mtqdm\u001b[0m\u001b[1;33m\u001b[0m\u001b[0m\n\u001b[0m", + "\u001b[1;31mModuleNotFoundError\u001b[0m: No module named 'tqdm'" + ] + } + ], + "source": [ + "import imageio\n", + "import os\n", + "import cv2\n", + "from datetime import datetime\n", + "from tqdm import tqdm_notebook as tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "in_dir = r'C:\\Users\\Wheeler\\Desktop\\LCL_software\\Experiments\\green_lysis2'\n", + "out_loc = r'out_loc = rC:\\Users\\Wheeler\\Desktop\\LCL_software\\Experiments\\green_lysis2.avi'\n", + "images = []\n", + "clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))\n", + "files = [os.path.join(in_dir, f) for f in os.listdir(in_dir) if '.tif' in f] \n", + "\n", + "\n", + "def get_time(file_name):\n", + "# t = file_name.split('_')[-2].split('_')[0]\n", + " t=file_name.split('_')[-1].split('.tif')[0]\n", + " t = datetime.strptime(t, '%H.%M.%S.%f')\n", + " t = t.hour*3600 + t.minute*60 + t.second + t.microsecond/10**6\n", + " return t\n", + " \n", + "\n", + "files.sort(key=get_time,reverse = False)\n", + "# for file in files:\n", + "# print('processing...',file.split('/')[-1])\n", + "# img = cv2.imread(file,0)\n", + "# img = cv2.resize(img,(int(img.shape[1]/3),int(img.shape[0]/3)),interpolation = cv2.INTER_CUBIC)\n", + "# img = clahe.apply(img)\n", + "# images.append(img)\n", + "# images = [images[0]] * 2 + images\n", + "# imageio.mimsave(out_loc, images,duration=.2)\n", + "# print('done')" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(822, 1024, 3)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "85c8e11a41554d8ea70ac824e6d02fcd", + "version_major": 2, + "version_minor": 0 + }, + "text/html": [ + "

Failed to display Jupyter Widget of type HBox.</p>
\n" + ], + "text/plain": [ + "HBox(children=(IntProgress(value=0, max=1278), HTML(value='')))" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "done\n" + ] + } + ], + "source": [ + "img = cv2.imread(files[0],1)\n", + "width = img.shape[1]\n", + "height = img.shape[0]\n", + "print(img.shape)\n", + "use_clahe = True\n", + "fourcc = cv2.VideoWriter_fourcc(*'MJPG') # Be sure to use lower case\n", + "out = cv2.VideoWriter(out_loc, fourcc, 6.0, (int(width/1), int(height/1)))\n", + "clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))\n", + "for file in tqdm(files):\n", + " img = cv2.imread(file,1)\n", + " img = cv2.resize(img,(int(width/1),int(height/1)))\n", + " if use_clahe == True:\n", + " r,g,b = img[:,:,0],img[:,:,1],img[:,:,2]\n", + " r = clahe.apply(r)\n", + " g = clahe.apply(g)\n", + " b = clahe.apply(b)\n", + " img[:,:,0] = r\n", + " img[:,:,1] = g\n", + " img[:,:,2] = b\n", + " out.write(img)\n", + "out.release()\n", + "print('done')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/LCL_run.py b/LCL_run.py index 7c74af8f0a32a96e128090a3c8913bd6364282fb..39a7e014dd0273d013a3ae652b9ffcd0cb8ebef3 100755 --- a/LCL_run.py +++ b/LCL_run.py @@ -158,6 +158,7 @@ class main_window(QMainWindow): # self.vid.vid_process_signal.connect(self.autofocuser.vid_process_slot) self.vid.vid_process_signal.connect(self.localizer.vid_process_slot) self.qswitch_screenshot_signal.connect(self.screen_shooter.save_qswitch_fire_slot) + self.localizer.qswitch_screenshot_signal.connect(self.screen_shooter.save_qswitch_fire_slot) # self.start_focus_signal.connect(self.autofocuser.autofocus) self.start_localization_signal.connect(self.localizer.localize) # self.autofocuser.position_and_variance_signal.connect(self.plot_variance_and_position) @@ -249,14 +250,12 @@ class main_window(QMainWindow): comment('stage position during qswitch: {}'.format(stage.get_position_slot())) laser.fire_qswitch() - @QtCore.pyqtSlot('PyQt_PyObject','PyQt_PyObject') - def ai_fire_qswitch_slot(self,num_frames,auto_fire): + @QtCore.pyqtSlot('PyQt_PyObject') + def ai_fire_qswitch_slot(self,auto_fire): comment('automated firing from localizer!') - if auto_fire == True: - self.qswitch_screenshot_signal.emit(num_frames) + if auto_fire == True: laser.qswitch_auto() else: - self.qswitch_screenshot_signal.emit(num_frames) laser.fire_qswitch() @QtCore.pyqtSlot() diff --git a/__pycache__/localizer.cpython-36.pyc b/__pycache__/localizer.cpython-36.pyc index 7778402165313d2048b42efab5582a7d356af382..3b057c350607a408311af233f2f4774183a28fe5 100755 Binary files a/__pycache__/localizer.cpython-36.pyc and b/__pycache__/localizer.cpython-36.pyc differ diff --git a/__pycache__/utils.cpython-36.pyc b/__pycache__/utils.cpython-36.pyc index d137e1bcb2d1d8498ba096a52ebbd2513dba7a78..36b7e5f46aa3fc30e1db27727f88c0f315c67a36 100755 Binary files a/__pycache__/utils.cpython-36.pyc and b/__pycache__/utils.cpython-36.pyc differ diff --git a/gif_writer.ipynb 
b/gif_writer.ipynb new file mode 100755 index 0000000000000000000000000000000000000000..4b46a670b91503317eb5049516b4bbbe7e8bcda9 --- /dev/null +++ b/gif_writer.ipynb @@ -0,0 +1,142 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import imageio\n", + "import os\n", + "import cv2\n", + "from datetime import datetime\n", + "from tqdm import tqdm_notebook as tqdm" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "in_dir = r'C:\\Users\\Wheeler\\Desktop\\LCL_software\\Experiments\\green_lysis2'\n", + "out_loc = r'C:\\Users\\Wheeler\\Desktop\\LCL_software\\Experiments\\green_lysis2.avi'\n", + "images = []\n", + "clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))\n", + "files = [os.path.join(in_dir, f) for f in os.listdir(in_dir) if '.tif' in f] \n", + "\n", + "\n", + "def get_time(file_name):\n", + "# t = file_name.split('_')[-2].split('_')[0]\n", + " t=file_name.split('_')[-1].split('.tif')[0]\n", + " t = datetime.strptime(t, '%H.%M.%S.%f')\n", + " t = t.hour*3600 + t.minute*60 + t.second + t.microsecond/10**6\n", + " return t\n", + " \n", + "\n", + "files.sort(key=get_time,reverse = False)\n", + "# for file in files:\n", + "# print('processing...',file.split('/')[-1])\n", + "# img = cv2.imread(file,0)\n", + "# img = cv2.resize(img,(int(img.shape[1]/3),int(img.shape[0]/3)),interpolation = cv2.INTER_CUBIC)\n", + "# img = clahe.apply(img)\n", + "# images.append(img)\n", + "# images = [images[0]] * 2 + images\n", + "# imageio.mimsave(out_loc, images,duration=.2)\n", + "# print('done')" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "(822, 1024, 3)\n" + ] + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "b1d0002d2a224d13805cf1c7efc51473", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "A Jupyter Widget" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\n", + "done\n" + ] + } + ], + "source": [ + "img = cv2.imread(files[0],1)\n", + "width = img.shape[1]\n", + "height = img.shape[0]\n", + "print(img.shape)\n", + "use_clahe = False\n", + "fourcc = cv2.VideoWriter_fourcc(*'MJPG') # Be sure to use lower case\n", + "out = cv2.VideoWriter(out_loc, fourcc, 6.0, (int(width/1), int(height/1)))\n", + "clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8,8))\n", + "for file in tqdm(files):\n", + " img = cv2.imread(file,1)\n", + " img = cv2.resize(img,(int(width/1),int(height/1)))\n", + " if use_clahe == True:\n", + " r,g,b = img[:,:,0],img[:,:,1],img[:,:,2]\n", + " r = clahe.apply(r)\n", + " g = clahe.apply(g)\n", + " b = clahe.apply(b)\n", + " img[:,:,0] = r\n", + " img[:,:,1] = g\n", + " img[:,:,2] = b\n", + " out.write(img)\n", + "out.release()\n", + "print('done')\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.3" + } + }, + "nbformat": 4, + 
"nbformat_minor": 2 +} diff --git a/localizer.py b/localizer.py index fe3631c7da6f1f0fd3370067049e493939c37260..4a3233c9bdddb86120391ff3a2311ba9bc737adf 100755 --- a/localizer.py +++ b/localizer.py @@ -10,8 +10,13 @@ global graph from PyQt5.QtWidgets import QApplication import pickle from sklearn.preprocessing import StandardScaler +from utils import MeanIoU graph = tf.get_default_graph() +num_classes = 3 +miou_metric = MeanIoU(num_classes) +mean_iou = miou_metric.mean_iou + experiment_folder_location = os.path.join(os.path.dirname(os.path.abspath(__file__)),'models') class Localizer(QtCore.QObject): @@ -19,16 +24,18 @@ class Localizer(QtCore.QObject): get_position_signal = QtCore.pyqtSignal() fire_qswitch_signal = QtCore.pyqtSignal() stop_laser_flash_signal = QtCore.pyqtSignal() - ai_fire_qswitch_signal = QtCore.pyqtSignal('PyQt_PyObject','PyQt_PyObject') + ai_fire_qswitch_signal = QtCore.pyqtSignal('PyQt_PyObject') start_laser_flash_signal = QtCore.pyqtSignal() + qswitch_screenshot_signal = QtCore.pyqtSignal('PyQt_PyObject') # qswitch_screenshot_signal = QtCore.pyqtSignal() def __init__(self, parent = None): - super(Localizer, self).__init__(parent) - self.localizer_model = load_model(os.path.join(experiment_folder_location,'multiclass_localizer14.hdf5')) + super(Localizer, self).__init__(parent) + self.localizer_model = load_model(os.path.join(experiment_folder_location,'multiclass_localizer18_2.hdf5'),custom_objects={'mean_iou': mean_iou}) + # self.localizer_model = load_model(os.path.join(experiment_folder_location,'multiclass_localizer14.hdf5')) + # self.localizer_model = load_model(os.path.join(experiment_folder_location,'binary_localizer6.hdf5')) self.norm = StandardScaler() self.hallucination_img = cv2.imread(os.path.join(experiment_folder_location,'before_qswitch___06_07_2018___11.48.59.274395.tif')) - # self.localizer_model = load_model(os.path.join(experiment_folder_location,'binary_localizer6.hdf5')) self.localizer_model._make_predict_function() self.position = np.zeros((1,2)) self.well_center = np.zeros((1,2)) @@ -140,12 +147,15 @@ class Localizer(QtCore.QObject): if self.lysed_cell_count >= self.cells_to_lyse: self.return_to_original_position(self.well_center) return - time.sleep(2) + time.sleep(.2) self.move_frame(let) + time.sleep(.2) QApplication.processEvents() + time.sleep(.2) self.lyse_all_in_view() self.return_to_original_position(self.well_center) - + + def get_spiral_directions(self,box_size): letters = ['u', 'l', 'd', 'r'] nums = [] @@ -172,8 +182,8 @@ class Localizer(QtCore.QObject): view_center = self.get_stage_position() print('lysing all in view...') self.delay() - # segmented_image = self.get_network_output(self.image,'multi') - segmented_image = self.get_network_output(self.hallucination_img,'multi') + segmented_image = self.get_network_output(self.image,'multi') + # segmented_image = self.get_network_output(self.hallucination_img,'multi') # cv2.imshow('Cell Outlines and Centers',segmented_image) # lyse all cells in view self.lyse_cells(segmented_image,self.cell_type_to_lyse,self.lysis_mode) @@ -234,20 +244,23 @@ class Localizer(QtCore.QObject): contour_start = np.copy(new_center) # now turn on the autofire time.sleep(.1) - self.ai_fire_qswitch_signal.emit(0,True) + self.qswitch_screenshot_signal.emit(2) + self.ai_fire_qswitch_signal.emit(True) # this block is responsible for vectoring around the contour return_vec = np.zeros((2)) for j in range(1,point_number): new_center = contour[j].reshape(2) move_vec = -old_center + new_center - scaled_move_vec = 
move_vec + scaled_move_vec = move_vec*1.5 return_vec = np.add(return_vec,scaled_move_vec) print(return_vec,return_vec.shape,scaled_move_vec.shape) + self.qswitch_screenshot_signal.emit(1) self.move_to_target(scaled_move_vec,False) old_center = new_center time.sleep(.1) self.lysed_cell_count += 1 - self.ai_fire_qswitch_signal.emit(0,False) + self.qswitch_screenshot_signal.emit(2) + self.ai_fire_qswitch_signal.emit(False) time.sleep(.1) if self.auto_lysis == False: self.stop_laser_flash_signal.emit() @@ -264,7 +277,8 @@ class Localizer(QtCore.QObject): old_center = cell_centers[0] self.move_to_target(old_center-window_center,True) self.delay() - self.ai_fire_qswitch_signal.emit(0,False) + self.qswitch_screenshot_signal.emit(10) + self.ai_fire_qswitch_signal.emit(False) self.delay() self.lysed_cell_count += 1 if self.lysed_cell_count >= self.cells_to_lyse: @@ -273,10 +287,11 @@ class Localizer(QtCore.QObject): return if len(cell_centers) > 1: for i in range(1,len(cell_centers)): + self.qswitch_screenshot_signal.emit(15) self.move_to_target(-old_center + cell_centers[i],False) old_center = cell_centers[i] self.delay() - self.ai_fire_qswitch_signal.emit(0,False) + self.ai_fire_qswitch_signal.emit(False) self.delay() self.lysed_cell_count += 1 if self.auto_lysis == False: diff --git a/models/localizer_continue_learning.py b/models/localizer_continue_learning.py new file mode 100755 index 0000000000000000000000000000000000000000..d664ad424f037a274e9fbd10496e3d86704310fa --- /dev/null +++ b/models/localizer_continue_learning.py @@ -0,0 +1,103 @@ +import sys,pickle +sys.path.insert(0,'/home/hedwar/installs') +from keras.models import load_model +from keras.callbacks import ModelCheckpoint,CSVLogger,TensorBoard +from keras.preprocessing.image import ImageDataGenerator +import tensorflow as tf +import numpy as np + +class MeanIoU(object): + # taken from http://www.davidtvs.com/keras-custom-metrics/ + def __init__(self, num_classes): + super().__init__() + self.num_classes = num_classes + + def mean_iou(self, y_true, y_pred): + # Wraps np_mean_iou method and uses it as a TensorFlow op. + # Takes numpy arrays as its arguments and returns numpy arrays as + # its outputs. 
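+        # Because tf.py_func wraps a plain Python function, this metric is not
+        # serialized into the saved HDF5 file; any call to load_model() must
+        # pass custom_objects={'mean_iou': mean_iou}, as done below and in
+        # localizer.py.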
+ return tf.py_func(self.np_mean_iou, [y_true, y_pred], tf.float32) + + def np_mean_iou(self, y_true, y_pred): + # Compute the confusion matrix to get the number of true positives, + # false positives, and false negatives + # Convert predictions and target from categorical to integer format + target = np.argmax(y_true, axis=-1).ravel() + predicted = np.argmax(y_pred, axis=-1).ravel() + + # Trick from torchnet for bincounting 2 arrays together + # https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py + x = predicted + self.num_classes * target + bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes**2) + assert bincount_2d.size == self.num_classes**2 + conf = bincount_2d.reshape((self.num_classes, self.num_classes)) + + # Compute the IoU and mean IoU from the confusion matrix + true_positive = np.diag(conf) + false_positive = np.sum(conf, 0) - true_positive + false_negative = np.sum(conf, 1) - true_positive + + # Just in case we get a division by 0, ignore/hide the error and set the value to 0 + with np.errstate(divide='ignore', invalid='ignore'): + iou = true_positive / (true_positive + false_positive + false_negative) + iou[np.isnan(iou)] = 0 + + return np.mean(iou).astype(np.float32) +num_classes = 3 + +miou_metric = MeanIoU(num_classes) +mean_iou = miou_metric.mean_iou + +data_loc = r'/home/hedwar/cell_localization/training_data_normalized.p' + +initial_epoch = 18 +old_save_loc = '/scratch/hedwar/multiclass_localizer18.hdf5' +x_train, x_val, y_train, y_val = pickle.load(open(data_loc,'rb')) +model = load_model(old_save_loc,custom_objects={'mean_iou': mean_iou}) +print('model loaded') + +seed = 1 +image_gen_args = dict(rotation_range = 90., + width_shift_range = 0.05, + height_shift_range = 0.05, + vertical_flip = True, + horizontal_flip = True) + +mask_gen_args = dict(rotation_range = 90., + width_shift_range = 0.05, + height_shift_range = 0.05, + vertical_flip = True, + horizontal_flip = True) + +image_datagen = ImageDataGenerator(**image_gen_args) +mask_datagen = ImageDataGenerator(**mask_gen_args) + +# image_datagen.fit(x_train, seed=seed) +# mask_datagen.fit(y_train, seed=seed) + + +image_generator = image_datagen.flow(x_train, seed=seed, batch_size=8) +mask_generator = mask_datagen.flow(y_train, seed=seed, batch_size=8) + +x_val_generator = image_datagen.flow(x_val, seed=seed, batch_size=100) +y_val_generator = mask_datagen.flow(y_val, seed=seed, batch_size=100) + +train_generator = zip(image_generator, mask_generator) +val_generator = zip(x_val_generator,y_val_generator) + +validation_generator = zip(image_generator, mask_generator) + +csv_logger = CSVLogger('/scratch/hedwar/multiclass_localizer_training18_2.log') +# tbCallBack = TensorBoard(log_dir='/scratch/a/awheeler/hedwar/tensorboard', histogram_freq=0, write_graph=True, write_images=True) +tbCallBack = TensorBoard(log_dir='/scratch/hedwar/tensorboard', histogram_freq=0, write_graph=True, write_images=True) +# save_loc = '/scratch/a/awheeler/hedwar/transpose_multiclass_localizer.hdf5' +save_loc = '/scratch/hedwar/multiclass_localizer18_2.hdf5' +checkpointer = ModelCheckpoint(filepath=save_loc, verbose=1, save_best_only=True) + +model.fit_generator(train_generator, + validation_data = val_generator, + validation_steps=600, + steps_per_epoch=2000, + epochs=1500, + initial_epoch = initial_epoch, + callbacks = [tbCallBack,checkpointer,csv_logger]) \ No newline at end of file diff --git a/models/multiclass_localizer16.hdf5 b/models/multiclass_localizer16.hdf5 new file mode 100755 index 
0000000000000000000000000000000000000000..d0b3ff670a740d916be2f497b70611b63594e855 Binary files /dev/null and b/models/multiclass_localizer16.hdf5 differ diff --git a/models/multiclass_localizer18_2.hdf5 b/models/multiclass_localizer18_2.hdf5 new file mode 100755 index 0000000000000000000000000000000000000000..a049dd3afd34887cfb35e6c1029b11c97b0d2080 Binary files /dev/null and b/models/multiclass_localizer18_2.hdf5 differ diff --git a/utils.py b/utils.py index aa44d0acfcbbf96207e8d15c615467590912cedf..0f318dc053d667cc237cf4c13910faf716604ded 100755 --- a/utils.py +++ b/utils.py @@ -6,6 +6,7 @@ import time import numpy as np from PyQt5.QtCore import QThread import threading +import tensorflow as tf def now(): return datetime.datetime.now().strftime('%d_%m_%Y___%H.%M.%S.%f') @@ -81,6 +82,46 @@ class screen_shooter(QtCore.QObject): self.image_title = 'during_qswitch_fire' self.requested_frames += num_frames +class MeanIoU(object): + # taken from http://www.davidtvs.com/keras-custom-metrics/ + def __init__(self, num_classes): + super().__init__() + self.num_classes = num_classes + + def mean_iou(self, y_true, y_pred): + # Wraps np_mean_iou method and uses it as a TensorFlow op. + # Takes numpy arrays as its arguments and returns numpy arrays as + # its outputs. + return tf.py_func(self.np_mean_iou, [y_true, y_pred], tf.float32) + + def np_mean_iou(self, y_true, y_pred): + # Compute the confusion matrix to get the number of true positives, + # false positives, and false negatives + # Convert predictions and target from categorical to integer format + target = np.argmax(y_true, axis=-1).ravel() + predicted = np.argmax(y_pred, axis=-1).ravel() + + # Trick from torchnet for bincounting 2 arrays together + # https://github.com/pytorch/tnt/blob/master/torchnet/meter/confusionmeter.py + x = predicted + self.num_classes * target + bincount_2d = np.bincount(x.astype(np.int32), minlength=self.num_classes**2) + assert bincount_2d.size == self.num_classes**2 + conf = bincount_2d.reshape((self.num_classes, self.num_classes)) + + # Compute the IoU and mean IoU from the confusion matrix + true_positive = np.diag(conf) + false_positive = np.sum(conf, 0) - true_positive + false_negative = np.sum(conf, 1) - true_positive + + # Just in case we get a division by 0, ignore/hide the error and set the value to 0 + with np.errstate(divide='ignore', invalid='ignore'): + iou = true_positive / (true_positive + false_positive + false_negative) + iou[np.isnan(iou)] = 0 + + return np.mean(iou).astype(np.float32) + + + experiment_name = 'experiment_{}'.format(now()) experiment_folder_location = os.path.join(os.path.dirname(os.path.abspath(__file__)),'Experiments',experiment_name) log = logging.getLogger(__name__)