Keeping Up With The Latest Techniques


Second Annual Data Science Bowl – Part 3 – Automatically Finding the Heart Location in an MRI Image

08 Tuesday Mar 2016

Posted by Colin Priest in Automation, Convolutional Neural Networks, Deep Learning, Image Processing, Kaggle, Machine Learning, Medical Imaging, Python


Tags

Automation, Deep Learning, Image Processing, Kaggle, Machine Learning, Medical Imaging

My last blog wasn’t so sexy, what with all the data cleansing, and no predictive modelling. But in this blog I do something really cool – I train a machine learning model to find the left ventricle of the heart in an MRI image. And I couldn’t have done it without all of that boring data cleansing. #kaggle @kaggle

Aside from being a really cool thing to do, there is a purpose to this modelling. I want to find the boundaries of the heart chamber, and that is much easier and faster to do when I remove distractions. Once I have found the location of the heart chamber, I can crop the image to a much smaller square.
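
As a rough illustration of that cropping step (this is a sketch, not the exact code used later in the pipeline; the square size and the border handling are arbitrary choices, and it assumes the image is at least as big as the square), a square can be cut out around a predicted centroid like this:

import numpy as np

def crop_square(img, cx, cy, size=96):
    # img is a 2-D numpy array; (cx, cy) is the predicted centroid expressed
    # as fractions of the image width and height
    rows, cols = img.shape
    row = int(round(cy * rows))
    col = int(round(cx * cols))
    half = size // 2
    # clamp the top-left corner so the square stays inside the image
    r0 = min(max(row - half, 0), rows - size)
    c0 = min(max(col - half, 0), cols - size)
    return img[r0:r0 + size, c0:c0 + size]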

The input to the model will be a set of images. To simplify what the model has to learn, I only gave it training images from sax locations near the centre of the heart.

[Figure: four example training images from sax slices near the centre of the heart, each with the left ventricle centroid marked by a red dot]

The output from the model will be the row number and column number of the centroid of the left ventricle heart chamber (the red dot in the images above).

I had to manually define those centroid locations for a training set of a few hundred images. This was laborious and time consuming, even after I automated some of the process. But it needed to be done, because otherwise the machine learning algorithm would have no way of knowing what the true answers should be.

Even though I am much more comfortable coding in R than in Python, I used Python for this step because I wanted to use Daniel Nouri‘s nolearn library, which sits on top of lasagne and theano, and these libraries are not available in R. The convolutional neural network architecture was based upon the architecture in Daniel Nouri’s tutorial for the Facial Keypoints Detection competition in Kaggle.

Step 1: Importing of all the Required Libraries

OK, so this part isn’t all that sexy either. But it’s the engine for all the cool modelling that is about to be done.


import sys
import numpy as np
import csv
import random
import math
import os
import cv2
import itertools
import matplotlib.pyplot as plt
import pandas as pd

from lasagne import layers
from lasagne.updates import nesterov_momentum
from lasagne.nonlinearities import softmax
from lasagne.nonlinearities import sigmoid
from nolearn.lasagne import BatchIterator
from nolearn.lasagne import NeuralNet
from nolearn.lasagne import TrainSplit
from nolearn.lasagne import PrintLayerInfo
from nolearn.lasagne.visualize import plot_loss
from nolearn.lasagne.visualize import plot_conv_weights
from nolearn.lasagne.visualize import plot_conv_activity
from nolearn.lasagne.visualize import plot_occlusion

%pylab inline
from lasagne.layers import DenseLayer
from lasagne.layers import InputLayer
from lasagne.layers import DropoutLayer
from lasagne.layers import Conv2DLayer
from lasagne.layers import MaxPool2DLayer
from lasagne.nonlinearities import softmax
from lasagne.updates import adam
from lasagne.layers import get_all_params
from nolearn.lasagne import NeuralNet
from nolearn.lasagne import TrainSplit
from nolearn.lasagne import objective

import theano
import theano.tensor as T

 

Step 2: Defining the Helper Functions

I used a Jupyter notebook as the development environment to set up and run my Python scripts. While there’s a lot to like about Jupyter, one thing that annoys me is that text printed during a long-running cell doesn’t immediately show up on the screen. But here’s a trick to work around that:


def printQ(s):
 print(s)
 sys.stdout.flush()

Using this helper function instead of the print function results in text immediately appearing in the output. This is particularly helpful for progress messages on long training runs.

I like to use R’s expand.grid function, but it isn’t built in to Python. So I wrote my own helper function in Python that mimics the functionality:


def product2(*args, repeat=1):
    # product('ABCD', 'xy') --> Ax Ay Bx By Cx Cy Dx Dy
    # product(range(2), repeat=3) --> 000 001 010 011 100 101 110 111
    pools = [tuple(pool) for pool in args] * repeat
    result = [[]]
    for pool in pools:
        result = [x + [y] for x in result for y in pool]
    for prod in result:
        yield tuple(prod)

def expand_grid(dictionary):
    return pd.DataFrame([row for row in product2(*dictionary.values())],
                        columns=dictionary.keys())
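
For example (with made-up values, just to show the behaviour), expand_grid enumerates every combination of the supplied lists, much like R’s expand.grid:

demo = expand_grid({'reflection': [-1, 0, 1], 'rotation': [0, 180]})
print(demo)   # 6 rows - one for each combination of reflection and rotation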

This next function reads a cleaned-up image, checks that it is in portrait orientation (transposing it if it isn’t), then resizes it to 96 x 96 pixels. The resizing is done to reduce memory usage on my GPU, and to speed up the training and scoring.


def load_image(path):
 # check that the file exists
 if not os.path.isfile(path):
   printQ('BAD PATH: ' + path)
 image = cv2.imread(path, cv2.IMREAD_GRAYSCALE)
 # ensure portrait aspect ratio
 if (image.shape[0] < image.shape[1]):
   image = cv2.transpose(image)
 # resize to 96x96
 resized_image = cv2.resize(image, (96, 96))
 # check that the image isn't empty
 s = sum(sum(resized_image))
 if np.isnan(s):
   print(path)
 return resized_image

My initial models were not performing as well as I would like, so to force the model to generalise, I added some image transformations (rotation and reflection) to the training data. This required helper functions:


def rotate_image(img, degrees):
 rows,cols = img.shape
 M = cv2.getRotationMatrix2D((cols/2,rows/2),degrees,1)
 dst = cv2.warpAffine(img,M,(cols,rows))
 return dst

def transform_image(img, equalise, gammaAdjust, reflection, rotation):
 if equalise == 1:
   img = cv2.equalizeHist(img)
 if gammaAdjust != 1:
   img = pow(img / 255.0, gammaAdjust) * 255.0
 if reflection == 0 or reflection == 1:
   img = cv2.flip(img, int(reflection))
 if rotation != 0:
   img = rotate_image(img, rotation)
 return img

def transform_xy(x, y, equalise, gammaAdjust, reflection, rotation):
 # if equalise, then no change to x and y
 # if gamma adjustment, then no change to x and y
 # reflection
 if reflection == 0:
   y = 1.0 - y
 if reflection == 1:
   x = 1.0 - x
 if rotation == 180:
   x = 1.0 - x
   y = 1.0 - y
 if rotation != 0 and rotation != 180:
   x1 = x - 0.5
   y1 = y - 0.5
   theta = rotation / 180 * pi
   x2 = x1 * cos(theta) + y1 * sin(theta)
   y2 = -x1 * sin(theta) + y1 * cos(theta)
   x = x2 + 0.5
   y = y2 + 0.5
 return numpy.array([x, y])

# set up the image adjustments
#equalise, gammaAdjust, reflection
dAdjustShort = { 'equalise': [0],
 'gammaAdjust': [1],
 'reflection': [-1],
 'rotation': [0]}
dAdjust = { 'equalise': [0, 1],
 'gammaAdjust': [1, 0.75, 1.5],
 'reflection': [-1, 0, 1],
 'rotation': [0, 180, 3, -3]}
# 'rotation': [0, 180, 3, -3, 10, -10]}
imgAdj = expand_grid(dAdjust)
#imgAdj = expand_grid(dAdjustShort)
print (imgAdj.shape)
# can't have both reflection AND rotation by 180 degrees
imgAdj = imgAdj.query('not (reflection > -1 and rotation == 180)')
# can't have equalise AND gamma adjust
imgAdj = imgAdj.query('not (equalise == 1 and gammaAdjust != 1)')
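
As a quick sanity check (illustrative only – the image here is a random stand-in and the row number is arbitrary), the image transform and the target transform are always applied as a pair, so that the centroid stays attached to the same point in the transformed image:

img = (np.random.rand(96, 96) * 255).astype('uint8')   # stand-in for a real MRI frame
row = imgAdj.iloc[5]                                    # one arbitrary augmentation setting
aug_img = transform_image(img, row['equalise'], row['gammaAdjust'],
                          row['reflection'], row['rotation'])
aug_xy = transform_xy(0.45, 0.55, row['equalise'], row['gammaAdjust'],
                      row['reflection'], row['rotation'])
print(aug_xy)   # the (x, y) target after the same reflection / rotation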

There are two feature sets that enable the machine learning model to find the heart:

  1. shape – what the heart chamber looks like within a single image
  2. movement between images at different points in time – the heart is moving, but most of the chest is stationary

So I set up the network architecture to use two channels. The first channel is the image itself, and the second channel is the difference between the image at this point in time and the image a few time periods later (plusTime in the code below).


def load_train_set():
    #numTimes = 8 # how many time periods to use
    plusTime = 5 # gaps between time periods for comparison
    xs = []
    ys = []
    ids = []
    all_y = '/home/colin/data/Second-Annual-Data-Science-Bowl/working/centroids-20160218-R.csv'
    with open(all_y) as f:
        rows = csv.reader(f, delimiter=',', quotechar='"')
        iRow = 1
        for line in rows:
            # first line is column headers, so ignore it
            if iRow > 1:
                # parse the line
                patient = line[0]
                x = float(line[1])
                y = float(line[2])
                sax = int(line[4])
                firstTime = int(line[5])
                if int(patient) % 25 == 0 and firstTime == 1:
                    printQ(patient)
                # enhance the training data with rotations, reflections, NOT histogram equalisation
                for index, row in imgAdj.iterrows():
                    # append the target values
                    xy = transform_xy(x, y, row['equalise'], row['gammaAdjust'], row['reflection'], row['rotation'])
                    ys.append(xy.astype('float32').reshape((1, 2)))
                    # read the images
                    folder = '/home/colin/data/Second-Annual-Data-Science-Bowl/train-cleaned/' + patient + '/study/sax_0' + str(sax) + '/'
                    xm = np.zeros([1, 2, 96, 96])
                    # current frame
                    path = folder + 'image-' + ('%0*d' % (2, firstTime)) + '.png'
                    img = load_image(path)
                    # transform the image - rotation, reflection etc
                    img = transform_image(img, row['equalise'], row['gammaAdjust'], row['reflection'], row['rotation'])
                    # get the pixels into the range [-1, 1]
                    img = (img / 128.0 - 1.0).astype('float32').reshape((96, 96))
                    xm[0, 0, :, :] = img
                    # find movement of current frame to future frame
                    path = folder + 'image-' + ('%0*d' % (2, firstTime + plusTime)) + '.png'
                    img = load_image(path)
                    # transform the image - rotation, reflection etc
                    img = transform_image(img, row['equalise'], row['gammaAdjust'], row['reflection'], row['rotation'])
                    # get the pixels into the range [-1, 1]
                    img = (img / 128.0 - 1.0).astype('float32').reshape((96, 96))
                    # first channel is the complete image at the first time
                    # second channel is the difference between the two frames
                    xm[0, 1, :, :] = img - xm[0, 0, :, :]
                    xs.append(xm.astype('float32').reshape((1, 2, 96, 96)))
                    ids.append(patient)
            iRow = iRow + 1
    return np.vstack(xs), np.vstack(ys), np.vstack(ids)

I used early stopping to help reduce overfitting:


class EarlyStopping(object):
    def __init__(self, patience=100):
        self.patience = patience
        self.best_valid = np.inf
        self.best_valid_epoch = 0
        self.best_weights = None

    def __call__(self, nn, train_history):
        current_valid = train_history[-1]['valid_loss']
        current_epoch = train_history[-1]['epoch']
        if current_valid < self.best_valid:
            self.best_valid = current_valid
            self.best_valid_epoch = current_epoch
            self.best_weights = nn.get_all_params_values()
        elif self.best_valid_epoch + self.patience < current_epoch:
            print('Early stopping')
            print('Best valid loss was {:.6f} at epoch {}.'.format(
                self.best_valid, self.best_valid_epoch))
            nn.load_params_from(self.best_weights)
            raise StopIteration()

Step 3: Read the Training Data

As well as reading the training data, I shuffled the order of the data. This allowed me to use batch training.


# read the training data

printQ('reading the training data')
train_x, train_y, train_id = load_train_set()

printQ ('shuffling training rows')
random.seed(1234)
rows = random.choice(arange(0, train_x.shape[0]), train_x.shape[0])
t_x = train_x[rows,:,:,:]
t_y = train_y[rows,:]
t_id = train_id[rows]

printQ('finished')

Step 4: Train the Model

The network architecture used deep convolutional layers to find features in the image, then fully connected layers to convert these features into the centroid location:

[Figure: diagram of the network architecture – convolutional layers followed by fully connected layers and a two-value output]


# fit the models

# set up the model
printQ ('setting up the model structure')
layers0 = [
 # layer dealing with the input data
 (InputLayer, {'shape': (None, 2, 96, 96)}),

# first stage of our convolutional layers
 (Conv2DLayer, {'num_filters': 32, 'filter_size': 5}),
 #(DropoutLayer, {'p': 0.2}),
 (Conv2DLayer, {'num_filters': 32, 'filter_size': 3}),
 #(DropoutLayer, {'p': 0.2}),
 (Conv2DLayer, {'num_filters': 32, 'filter_size': 3}),
 #(DropoutLayer, {'p': 0.2}),
 (Conv2DLayer, {'num_filters': 32, 'filter_size': 3}),
 #(DropoutLayer, {'p': 0.2}),
 (Conv2DLayer, {'num_filters': 32, 'filter_size': 3}),
 (MaxPool2DLayer, {'pool_size': 2}),
 (DropoutLayer, {'p': 0.2}),

# second stage of our convolutional layers
 (Conv2DLayer, {'num_filters': 64, 'filter_size': 3}),
 #(DropoutLayer, {'p': 0.3}),
 (Conv2DLayer, {'num_filters': 64, 'filter_size': 3}),
 #(DropoutLayer, {'p': 0.3}),
 (Conv2DLayer, {'num_filters': 64, 'filter_size': 3}),
 (MaxPool2DLayer, {'pool_size': 2}),
 (DropoutLayer, {'p': 0.3}),

# two dense layers with dropout
 (DenseLayer, {'num_units': 128}),
 (DropoutLayer, {'p': 0.5}),
 (DenseLayer, {'num_units': 128}),

# the output layer
 (DenseLayer, {'num_units': 2, 'nonlinearity': sigmoid}),
]

printQ ('creating and training the networks architectures')
numNets = 1
NNs = list()
for iNet in arange(numNets):
 nn = NeuralNet(
 layers = layers0,
 max_epochs = 2000,
 update=adam,
 update_learning_rate=0.0002,
 regression=True, # flag to indicate we're dealing with regression problem
 batch_iterator_train=BatchIterator(batch_size=100),
 on_epoch_finished=[EarlyStopping(patience=10),],
 train_split=TrainSplit(eval_size=0.25),
 verbose=1,
 )
 result = nn.fit(t_x, t_y)
 NNs.append(nn)

printQ('finished')

Based upon how quickly the training converged, the network could possibly have been simplified, reducing the number of layers, or using fewer neurons in the fully connected layers. But I didn’t have time to experiment with different architectures.

[Figure: training output showing how quickly the training converged]

The GPU quickly ran out of RAM unless I used the batch iterator. I found the batch size via trial and error: large batch sizes caused the GPU to run out of RAM, while small batch sizes ran much slower.

Step 5: Review the Training Errors

Just like humans, all models make mistakes. The heart chamber segmentation algorithms I used later in this project were sensitive to how well the heart chamber was centred in the image, but as long as the model output a centroid that was inside the heart chamber, things usually went OK. Early versions of my model made mistakes that placed the centroid outside the heart chamber, sometimes even far away from the heart. Tweaks to the training data (especially enhancing it with rotations and reflections) and to the architecture (especially the dropout layers) improved the performance.


def getHeartLocation(trainX):
    # get the heart locations from each network
    heartLocs = zeros(numNets * trainX.shape[0] * 2).reshape((numNets, trainX.shape[0], 2))
    for j in arange(numNets):
        nn = NNs[j]
        heartLocs[j, :, :] = nn.predict(trainX)

    # use median as an ensembler
    heartLocsMedian = zeros(trainX.shape[0] * 2).reshape((trainX.shape[0], 2))
    heartLocsMedian[:, 0] = median(heartLocs[:, :, 0], axis = 0)
    heartLocsMedian[:, 1] = median(heartLocs[:, :, 1], axis = 0)

    # use a 'max distance from centre' ensembler
    heartLocsDist = zeros(trainX.shape[0] * 2).reshape((trainX.shape[0], 2))
    distance = abs(heartLocs - 0.5)
    am0 = distance[:, :, 0].argmax(0)
    am1 = distance[:, :, 1].argmax(0)
    heartLocsDist[:, 0] = heartLocs[am0, arange(trainX.shape[0]), 0]
    heartLocsDist[:, 1] = heartLocs[am1, arange(trainX.shape[0]), 1]

    # combine the two using an arithmetic average
    heartLocations = 0.5 * heartLocsMedian + 0.5 * heartLocsDist

    return heartLocations

heartLocations = getHeartLocation(train_x)

# review the training errors to check for model improvements
def plot_sample(x, y, predicted, axis):
  img = x[0, :, :].reshape(96, 96)
  axis.imshow(img, cmap='gray')
  axis.scatter(y[0::2] * 96, y[1::2] * 96, marker='x', s=10)
  axis.scatter(predicted[0::2] * 96, predicted[1::2] * 96, marker='x', s=10, color='red')

nTrain = train_x.shape[0]
errors = np.zeros(nTrain)

for i in arange(0, nTrain):
  errors[i] = sqrt( square(heartLocations[i, 0] - train_y[i, 0]) + square(heartLocations[i, 1] - train_y[i, 1]) )

print('Prob(error > 0.05)' + str(mean(errors > 0.05)))
print('Mean: ' + str(mean(errors)))
print('Percentiles: ' + str(percentile(errors, [50, 75, 90, 95, 99, 100])))

for i in arange(0, nTrain):
  error = sqrt( square(heartLocations[i, 0] - train_y[i, 0]) + square(heartLocations[i, 1] - train_y[i, 1]) )
  if (error > 0.04):
    if train_id[i] != train_id[i-1]:
      #print(i)
      print(train_id[i]) # only errors on the original images - not the altered images
      fig = pyplot.figure(figsize=(6, 3))
      ax = fig.add_subplot(1, 2, 1, xticks=[], yticks=[])
      plot_sample(train_x[i,:,:,:], train_y[i, :], heartLocations[i, :], ax)
      pyplot.show()

print('error review completed')

[Figure: the worst remaining training errors, with the true and predicted centroids marked on each image]

After many failed models, I was excited when the two worst training errors were still close to the centre of the heart chamber 🙂

Step 6: Find the Left Ventricle Locations for the Submission Data

The main point of building a heart finder machine learning model is to automate the process of finding the left ventricle in the test images that will be used as part of the competition submission. These are images that the model has never seen before.


def load_submission_set():
  numTimes = 2 # how many time periods to use
  plusTime = 5
  xs = []
  ys = []
  ids = []
  paths = []
  times = []
  saxes = []
  all_y = '/home/colin/data/Second-Annual-Data-Science-Bowl/working/centroids-submission-R.csv'
  with open(all_y) as f:
    rows = csv.reader(f, delimiter=',', quotechar='"')
    iRow = 1
    for line in rows:
      # first line is column headers
      if iRow > 1:
        # parse the line
        patient = line[0]
        x = 0
        y = 0
        sax = int(line[1])
        firstTime = int(line[2])
        # save the targets
        xy = np.asarray([x, y])
        ys.append(xy.astype('float32').reshape((1, 2)))
        # read the images
        folder = '/home/colin/data/Second-Annual-Data-Science-Bowl/validate-cleaned/' + patient + '/study/sax_0' + str(sax) + '/'
        xm = np.zeros([1, 2, 96, 96])
        #
        #
        # current frame
        path0 = folder + 'image-' + ('%0*d' % (2, firstTime)) + '.png'
        img = load_image(path0)
        # transform the image - rotation, reflection etc
        #img = transform_image(img, row['equalise'], row['gammaAdjust'], row['reflection'], row['rotation'])
        # get the pixels into the range [-1, 1]
        img = (img / 128.0 - 1.0).astype('float32').reshape((96, 96))
        xm[0, 0, :, :] = img
        #
        #
        # find movement of current frame to future frame
        path5 = folder + 'image-' + ('%0*d' % (2, firstTime + plusTime)) + '.png'
        img = load_image(path5)
        # transform the image - rotation, reflection etc
        #img = transform_image(img, row['equalise'], row['gammaAdjust'], row['reflection'], row['rotation'])
        # get the pixels into the range [-1, 1]
        img = (img / 128.0 - 1.0).astype('float32').reshape((96, 96))
        # first time is the complete image at time 1
        # subsequent frames are the differences between frames
        xm[0, 1, :, :] = img - xm[0, 0, :, :]
        xs.append(xm.astype('float32').reshape((1, numTimes, 96, 96)))
        ids.append(patient)
        paths.append(path0)
        times.append(firstTime)
        saxes.append(sax)
      iRow = iRow + 1
  return np.vstack(xs), np.vstack(ids), np.vstack(paths), np.vstack(times), np.vstack(saxes)

printQ('reading the submission data')
test_x, test_ids, test_paths, test_times, test_sax = load_submission_set()

printQ('getting the predictions')
predicted_y = getHeartLocation(test_x)

printQ('creating the output table')
fullIDs = []
fullPaths = []
fullX = []
fullY = []
fullSax = []
fullTime = []
nTest = test_x.shape[0]
iTime = 1
for i in arange(0, nTest):
  patient = (test_ids[i])[0]
  sax = int((test_sax[i])[0])
  path = (test_paths[i])[0]
  iTime = int((test_times[i])[0])
  fullIDs.append(patient)
  fullPaths.append(path)
  fullX.append(predicted_y[i, 0])
  fullY.append(predicted_y[i, 1])
  fullSax.append(sax)
  fullTime.append(iTime)
  outPath = '/home/colin/data/Second-Annual-Data-Science-Bowl/predicted-heart-location-submission/'
  outPath = outPath + patient
  outPath = outPath + '-' + str(sax)
  outPath = outPath + '-' + str(iTime) + '.png'
  img = load_image256x192(path)
  x = int(round(predicted_y[i, 0] * 192))
  y = int(round(predicted_y[i, 1] * 256))
  img[y, x] = 255
  img[y-1, x-1] = 255
  img[y-1, x+1] = 255
  img[y+1, x-1] = 255
  img[y+1, x+1] = 255
  write_image(img, outPath)

fullIDs = array(fullIDs)
fullPaths = array(fullPaths)
fullX = array(fullX)
fullY = array(fullY)
fullSax = array(fullSax)
fullTime = array(fullTime)

printQ('saving results table')
d = { 'patient': fullIDs, 'path' : fullPaths, 'x' : fullX, 'y' : fullY, 'iTime' : fullTime, 'sax' : fullSax}
import pandas as pd
d = pd.DataFrame(d)
d.to_csv('/home/colin/data/Second-Annual-Data-Science-Bowl/working/heartfinderV4b-centroids-submission.csv', index = False)

In the animated gif below, you can see the left ventricle centroid location that has been automatically fitted, displayed as a dark rectangle moving around near the centre of the heart chamber. The machine learning algorithm was not trained on this patient’s images – so what you see here is artificial intelligence in action!

[Animation: a submission patient’s MRI frames with the automatically fitted left ventricle centroid marked near the centre of the heart chamber]


Second Annual Data Science Bowl – Part 2

07 Monday Mar 2016

Posted by Colin Priest in Automation, Image Processing, Kaggle, Medical Imaging, R


Tags

Automation, Image Processing, Kaggle, Medical Imaging, R

In Part 1 of this blog series, I described how to fix the brightness and contrast of the MRI images. In this blog we finish cleaning up the input data.

Other than brightness and contrast, we need to fix up the following problems:

  • different image sizes
  • different image rotations – portrait versus landscape
  • different pixel spacing
  • different short axis slice spacing
  • image sets from duplicate locations
  • sax sets aren’t in the same order as their locations
  • some sax image sets have multiple locations

Different Image Sizes and Rotations

There are approximately a dozen different image sizes, including rotated images. Not all of the image sizes scale to the same 4:3 aspect ratio that is the most common across the training set. Some of the machine learning algorithms I used later need fixed-dimension images, so I compromised and decided to use a standard size of 256 x 192 pixels in portrait orientation. This meant that I wasn’t always scaling the x-axis and the y-axis by the same amount, and occasionally I was even upscaling an image.


library(pacman)
pacman::p_load(EBImage)
rescaleImage = function(img)
{
 imgOut = img
 # check for landscape aspect ratio and correct
 if (nrow(img) > ncol(img)) imgOut = t(img)
 imgOut = resize(imgOut, 256, 192) # standardise image size to 256 x 192
 return (imgOut)
}

[Figure: example images after standardising to 256 x 192 portrait orientation]

Note that I used matrix transpose to rotate the image 90 degrees. I could also have used the rotate function in EBImage. Either way could work, but I felt that matrix transpose would be a faster operation.

Image Locations and Spacing

The DICOM images contain information about image slice location, pixel spacing (how far apart are adjacent pixel centroids), and slice spacing (how far apart are adjacent sax slices).


library(pacman)
pacman::p_load(oro.dicom)

# function to extract the dicom header info
getDicomHeaderInfo = function(path)
{
 img = readDICOMFile(path)
 width = ncol(img$img)
 height = nrow(img$img)
 headers = img$hdr
 patientID = extractHeader(img$hdr, 'PatientID', numeric = TRUE)
 patientAge = as.integer(substr(extractHeader(img$hdr, 'PatientsAge', numeric = FALSE), 1, 3))
 patientGender = as.character(extractHeader(img$hdr, 'PatientsSex', numeric = FALSE))
 ps = extractHeader(img$hdr, 'PixelSpacing', numeric = FALSE)
 pixelSpacingX = as.numeric(unlist(strsplit(ps, ' '))[1])
 pixelSpacingY = as.numeric(unlist(strsplit(ps, ' '))[2])
 seriesNum = extractHeader(img$hdr, 'SeriesNumber', numeric = TRUE)
 location = round(extractHeader(img$hdr, 'SliceLocation', numeric = TRUE), digits = 1)
 pathSplit = unlist(strsplit(path, '/', fixed = TRUE))
 filename = pathSplit[length(pathSplit)]
 prefix = unlist(strsplit(filename, '.', fixed = TRUE))[1]
 frame = as.integer(unlist(strsplit(prefix, '-'))[3])
 sliceNum = 0
 if (length(unlist(strsplit(prefix, '-'))) > 3) sliceNum = as.integer(unlist(strsplit(prefix, '-'))[4])
 time = extractHeader(img$hdr, 'InstanceCreationTime', numeric = FALSE)
 #
 return (list(
 id = patientID,
 age = patientAge,
 gender = patientGender,
 pixelSpacingX = pixelSpacingX,
 pixelSpacingY = pixelSpacingY,
 width = width,
 height = height,
 series = seriesNum,
 location = location,
 frame = frame,
 sliceNum = sliceNum,
 time = time,
 path = path
 ))
}

I’m not a medical expert, so when there are images from duplicate locations, I don’t know which image sets are the best to use. Therefore I just assumed that the medical specialist repeated the MRI scans until the image quality was adequate, and I chose the image set with the latest time stamps (taken from the header information inside the DICOM files, not the file system date).

Then I just

  1. searched for DICOM images in all of the subfolders for each patient
  2. filtered to use only sax slices, ignoring those images that did not come from a folder that contained the substring “sax”
  3. read the DICOM header information from each image to find its location and time
  4. searched for the highest time stamp at each location and kept that image (sketched in pandas just after this list)
  5. wrote out a new folder structure where sax_01, sax_02, … were the sax slice image sets for that patient, ordered by their location along the long axis of the heart
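
For step 4, the underlying idea (shown here as a Python/pandas sketch with assumed column names, mirroring the data.table logic in the script below) is simply: within each slice location and frame, keep the row with the latest DICOM time stamp.

import pandas as pd

def keep_latest(imgTable):
    # imgTable is assumed to have one row per DICOM file, with columns
    # 'location', 'frame' and 'time' extracted from the DICOM headers
    imgTable = imgTable.sort_values('time')
    # the last row within each (location, frame) group has the latest time stamp
    return imgTable.groupby(['location', 'frame'], as_index=False).last()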

The R script to do this isn’t too difficult. I’ve pasted it below:


# this script creates a cleaner version of the image data
# plus an extract of the file headers
# 1) removes duplicate images
# 2) reorders images by location

# read a poor image and translate the pixel brightnesses
rebalanceImage = function(badImage)
{
v = matrix(badImage, nrow(badImage) * ncol(badImage))
o2 = order(v)
vIn = sample(vAll, nrow(badImage) * ncol(badImage))
vIn = vIn[order(vIn)]
v2 = v
v2[o2] = vIn
cleanImage = matrix(v2, nrow(badImage), ncol(badImage))
return (cleanImage)
}

# check whether a folder exists and make it if it doesn't exist
checkFolder = function(mainDir, subDir)
{
if(!dir.exists(file.path(mainDir, subDir)))
dir.create(file.path(mainDir, subDir))
}

# function to extract the dicom header info
getDicomHeaderInfo = function(path)
{
img = readDICOMFile(path)
width = ncol(img$img)
height = nrow(img$img)
headers = img$hdr
patientID = extractHeader(img$hdr, 'PatientID', numeric = TRUE)
patientAge = as.integer(substr(extractHeader(img$hdr, 'PatientsAge', numeric = FALSE), 1, 3))
patientGender = as.character(extractHeader(img$hdr, 'PatientsSex', numeric = FALSE))
ps = extractHeader(img$hdr, 'PixelSpacing', numeric = FALSE)
pixelSpacingX = as.numeric(unlist(strsplit(ps, ' '))[1])
pixelSpacingY = as.numeric(unlist(strsplit(ps, ' '))[2])
seriesNum = extractHeader(img$hdr, 'SeriesNumber', numeric = TRUE)
location = round(extractHeader(img$hdr, 'SliceLocation', numeric = TRUE), digits = 1)
pathSplit = unlist(strsplit(path, '/', fixed = TRUE))
filename = pathSplit[length(pathSplit)]
prefix = unlist(strsplit(filename, '.', fixed = TRUE))[1]
frame = as.integer(unlist(strsplit(prefix, '-'))[3])
sliceNum = 0
if (length(unlist(strsplit(prefix, '-'))) > 3) sliceNum = as.integer(unlist(strsplit(prefix, '-'))[4])
time = extractHeader(img$hdr, 'InstanceCreationTime', numeric = FALSE)
#
return (list(
id = patientID,
age = patientAge,
gender = patientGender,
pixelSpacingX = pixelSpacingX,
pixelSpacingY = pixelSpacingY,
width = width,
height = height,
series = seriesNum,
location = location,
frame = frame,
sliceNum = sliceNum,
time = time,
path = path
))
}
############################################################################################################################

library(pacman)
pacman::p_load(oro.dicom, raster, data.table, png, flexclust, foreach, doParallel, snowfall)

# whether this run is for the training set (FALSE) or the validation set (TRUE)
useValidation = FALSE
#useValidation = TRUE

# create a benchmark histogram from an exemplar image
dicomBenchmark = readDICOM('/home/colin/data/Second-Annual-Data-Science-Bowl/train/1/study/sax_13')
images = dicomBenchmark[[2]]
img = images[[1]]
vAll = unname(unlist(images))
vAll = vAll[order(vAll)]

# loop through all of the patients
rootFolder = '/home/colin/data/Second-Annual-Data-Science-Bowl/train'
outFolder = '/home/colin/data/Second-Annual-Data-Science-Bowl/train-cleaned'
if (useValidation)
{
rootFolder = '/home/colin/data/Second-Annual-Data-Science-Bowl/validate'
outFolder = '/home/colin/data/Second-Annual-Data-Science-Bowl/validate-cleaned'
}
cases = list.dirs(rootFolder, recursive=FALSE)
#cases = cases[grep('123', cases)]
simpleData = matrix(0, 500, 3)
simpleFeatures = data.frame(id = rep(0, 500), age = rep(0, 500), gender = rep('U', 500),
pixelSpacingX = integer(500), pixelSpacingY = integer(500),
width = integer(500), height = integer(500),
numSlices = rep(0, 500), stringsAsFactors = FALSE)
allImages = NULL

sfInit(parallel=TRUE, cpus=14)
sfLibrary(oro.dicom)

for (patient in cases)
{
patientFolder = paste0(patient, '/study')
imgSequences = list.dirs(patientFolder, recursive=FALSE)
# filter for 'sax' folders
imgSequences = imgSequences[grep('sax', imgSequences, fixed = TRUE)]
nMax = 10000
imgTable = data.table(id = integer(nMax), age = integer(nMax), gender = character(nMax),
pixelSpacingX = integer(nMax), pixelSpacingY = integer(nMax),
width = integer(nMax), height = integer(nMax),
location = numeric(nMax), frame = integer(nMax), series = integer(nMax),
sliceNum = integer(nMax), time = numeric(nMax),
path = character(nMax))
for (imgFolder in imgSequences)
{
# find the first file in imgs
imgFiles = list.files(imgFolder)
if (length(imgFiles) < 30)
{
print(paste0('length = ', length(imgFiles), '! in ', imgFolder))
} else {
if (length(imgFiles) %% 30 != 0)
{
print(paste0('length = ', length(imgFiles), '! in ', imgFolder))
}
}
# do the next part regardless of the number of images
paths = unlist(lapply(imgFiles, function(x) return (paste0(imgFolder, '/', x))))
result <- sfLapply(paths, getDicomHeaderInfo)
for (row in result)
{
n = sum(imgTable$id > 0) + 1
imgTable$id[n] = row$id
imgTable$age[n] = row$age
imgTable$gender[n] = row$gender
imgTable$pixelSpacingX[n] = row$pixelSpacingX
imgTable$pixelSpacingY[n] = row$pixelSpacingY
imgTable$width[n] = row$width
imgTable$height[n] = row$height
imgTable$series[n] = row$series
imgTable$location[n] = row$location
imgTable$frame[n] = row$frame
imgTable$sliceNum[n] = row$sliceNum
imgTable$time[n] = row$time
imgTable$path[n] = row$path
}
}
# remove surplus records from table
imgTable= imgTable[imgTable$id > 0]
# grab the latest image for each location and frame
latestImages = imgTable[order(id, age, gender, pixelSpacingX, pixelSpacingY, width, height, location, frame, time, sliceNum, series), .SD[c(.N)], by=c("id", "age", "gender", "pixelSpacingX", "pixelSpacingY", "width", "height", "location", "frame")]
#
print(paste0('Patient: ', latestImages$id[1]))
uniqueLocs = sort(unique(latestImages$location[abs(latestImages$location) > 0.000001]))
simpleFeatures[latestImages$id[1], ] = list(latestImages$id[1], latestImages$age[1], latestImages$gender[1],
latestImages$pixelSpacingX[1], latestImages$pixelSpacingY[1],
latestImages$width[1], latestImages$height[1],
length(uniqueLocs))
# create the cleaned-up images for CNN training
imgTable$cleanPath = '---'
iSax = 1
patientID = latestImages$id[1]
for (loc in uniqueLocs)
{
imgset = latestImages$path[latestImages$location == loc]
iImage = 1
for (imgPath in imgset)
{
dicomImage = readDICOMFile(imgPath)
# fix the contrast and brightness
fixedImage = rebalanceImage(dicomImage$img)
# get the details of this image
###patientID = extractHeader(dicomImage$hdr, 'PatientID', numeric = TRUE)
sliceID = iSax
imageID = iImage
outPath = paste0(outFolder, '/', patientID, '/study/sax_', formatC(sliceID, width=2, flag='0'), '/image-', formatC(imageID, width=2, flag='0'), '.png')
#print(outPath)
#plot(raster(fixedImage))
checkFolder(outFolder, as.character(patientID))
checkFolder(paste0(outFolder, '/', patientID), 'study')
checkFolder(paste0(outFolder, '/', patientID, '/study'), paste0('sax_', formatC(sliceID, width=2, flag='0')))
writePNG(fixedImage / max(fixedImage), outPath)
imgTable$cleanPath[imgTable$path == imgPath] = outPath
#
iImage = iImage + 1
}
iSax = iSax + 1
}

if (patient == cases[1])
{
allImages = imgTable
} else
{
allImages = data.frame(rbind(allImages, imgTable))
}
}

sfStop()

I stored all of the DICOM header information in a table. Some of this will be required later, when calculating the volume of the left ventricle chamber.
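
As a rough sketch of why those header fields matter later (my own illustration, in Python for brevity, not the code I used): the chamber volume is approximately the sum, over sax slices, of the segmented chamber area in each slice times the pixel area times the slice spacing.

def ventricle_volume_ml(areas_in_pixels, pixel_spacing_x_mm, pixel_spacing_y_mm, slice_spacing_mm):
    # areas_in_pixels: the segmented chamber area (in pixels) for each sax slice
    pixel_area_mm2 = pixel_spacing_x_mm * pixel_spacing_y_mm
    volume_mm3 = sum(a * pixel_area_mm2 * slice_spacing_mm for a in areas_in_pixels)
    return volume_mm3 / 1000.0   # 1 mL = 1000 cubic mm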

The next step, which will be described in my next blog, is to design a convolutional neural network that will automatically find the left ventricle in an image.


Second Annual Data Science Bowl – Part 1

06 Sunday Mar 2016

Posted by Colin Priest in Automation, Image Processing, Kaggle, Medical Imaging, R


Tags

Automation, Image Processing, Kaggle, Medical Imaging, R

I’m currently competing in the Second Annual Data Science Bowl at Kaggle. This is by far the most difficult competition that I have entered to date. At the time of writing I am placed 62nd out of 755 entries, with only a day remaining to lock down my methodology. There’s a lot more I’d like to do to improve my model, but alas, I don’t have the time!

Here’s the problem that we are solving:

  1. We are given a set of medical images taken by MRI, across 30 time periods, and a variable number of location slices through the body.
  2. We are also given the volume of the left ventricle of the heart at times of diastole and systole.
  3. Our task is to design an automatic algorithm that takes the DICOM images as input and outputs a cumulative distribution function for the likelihood of different volumes at both diastole and systole (a rough sketch of this kind of output follows below).
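
To make that output format concrete, here is a minimal sketch of one way to turn a single volume estimate into a cumulative distribution function – a smooth step (a sigmoid) centred on the estimate. The volume grid, the width parameter and the function name are my own illustrative choices, not the competition’s exact submission format:

import numpy as np

def volume_cdf(predicted_volume_ml, width=10.0, volumes=np.arange(0, 600)):
    # P(true volume <= v) for each volume v on the grid, modelled as a sigmoid
    # centred on the point estimate; 'width' reflects how uncertain we are
    return 1.0 / (1.0 + np.exp(-(volumes - predicted_volume_ml) / width))

cdf_diastole = volume_cdf(150.0)   # e.g. a 150 mL estimate at diastole
cdf_systole = volume_cdf(65.0)     # e.g. a 65 mL estimate at systole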

The medical image files are in DICOM format, containing information about the patient (e.g. age and gender) plus a set of monochrome images giving a four-dimensional view of that patient’s chest. The key images are the “sax” (short axis) images: slices perpendicular to the line that runs through the length of the heart (a heart isn’t circular, but more ovoid in shape). There are typically 30 images for each sax location, each covering 1/30th of a heartbeat, so together they show one complete cycle of the heart. The number of sax sets varies from patient to patient, depending upon the length of the patient’s heart, and sometimes there are also repeated sax sets, where the scanning was repeated in an attempt to improve the image quality.

The image quality varies greatly between patients, with differing image resolutions, brightness, contrast and aspect ratio / rotation.

[Animation: two example sax_8 image sequences, showing the wide variation in image quality]

As you can see in the animated gifs above, some of the images are of such poor quality that it is difficult for the human eye to discern the details. So my first challenge was to improve the brightness and contrast. One way to do this is to apply a linear transformation to each image so that its pixels have a preset mean and standard deviation.
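
Here is a minimal sketch of that linear transformation (in Python for brevity; the target mean and standard deviation are arbitrary choices, and as explained below I ended up using a different approach):

import numpy as np

def normalise_brightness(img, target_mean=128.0, target_sd=48.0):
    # linearly rescale the pixel values so the image has a preset mean and spread
    img = img.astype('float64')
    scaled = (img - img.mean()) / (img.std() + 1e-9)
    return np.clip(scaled * target_sd + target_mean, 0, 255).astype('uint8')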

[Figure: example images after the linear brightness and contrast transformation]

As you can see, while this approach helped, it did not work well enough for the problem images. I also tried non-linear transformations, without much more success. No transformation function was flexible enough for the wide range of image qualities, and frequently part of a problem image was washed out.

At the National Heart Centre Singapore, I spoke with Assistant Professor Calvin Chin about how doctors use imaging to assess heart volume. He explained that the very bright, washed out sections in some of the images are the result of fat deposits within the patient’s body. He also explained how to find the left ventricle chamber in an image (it is round with a thick lining surrounding it) and what to do about the dark patches inside the chamber (include them in the area of the chamber because they are blood vessels). This was really helpful. It pays to bring in some domain knowledge to a machine learning problem.

[Figure: example images showing washed-out fat deposits and the left ventricle chamber]

What I wanted was for all images to have similar brightness histograms. After much experimentation, I couldn’t find a transformation function that achieved this for me. But then I realised that I didn’t need a function at all – I could just use an empirical histogram as my target, mapped to my original image via the brightness ranking of each pixel. All I needed to do was select an exemplar image (or multiple exemplar images), order its pixel brightnesses, and then map across. Here’s how I did it using R:


library(pacman)
pacman::p_load(oro.dicom)

# a function to turn the image into a vector
img2vec = function(img)
{
return (matrix(img, nrow(img) * ncol(img), 1))
}

# read a poor image and translate the pixel brightnesses
rebalanceImage = function(badImage)
{
# get the pixel brightnesses and get an index that sorts them
v = img2vec(badImage)
o2 = order(v)

# get a target histogram, allowing for the size of the bad image
vIn = sample(vAll, nrow(badImage) * ncol(badImage))
vIn = vIn[order(vIn)]

#
v2 = v
v2[o2] = vIn

# turn the pixel vector back into an image
cleanImage = matrix(v2, nrow(badImage), ncol(badImage))

return (cleanImage)
}
# create a benchmark histogram from an exemplar image
dicomBenchmark = readDICOM('C:/Users/Colin/Dropbox/blogging/20160306 Second Annual Data Science Bowl Part 1/SADSB/1/study/sax_13')
images = dicomBenchmark[[2]]
img = images[[1]]
vAll = unname(unlist(images))
vAll = vAll[order(vAll)]

# read the raw image
dicomImage = readDICOMFile('C:/Users/Colin/Dropbox/blogging/20160306 Second Annual Data Science Bowl Part 1/SADSB/1/study/sax_8/IM-4560-0001.dcm')
# fix the contrast and brightness
fixedImage = rebalanceImage(dicomImage$img)
# read the raw image
dicomImage2 = readDICOMFile('C:/Users/Colin/Dropbox/blogging/20160306 Second Annual Data Science Bowl Part 1/SADSB/6/study/sax_8/IM-9548-0001.dcm')
# fix the contrast and brightness
fixedImage2 = rebalanceImage(dicomImage2$img)

This gave me fairly consistent brightnesses and contrast, regardless of the quality of the original images, and also prevented washed out regions. You can see the results below:

[Figure: example images after the empirical histogram transformation]

Standardising the input data helps machine learning algorithms perform better because they don’t have to waste resources figuring out how to adjust for varying inputs where that variation is not a predictive feature.

In my next blog I will describe how I used the DICOM header information to further improve the model inputs, and to create extra features for my final model.

