Added support for CTC in both Theano and Tensorflow along with image OCR example. (#3436)

* Added CTC to Theano and Tensorflow backend along with image OCR example

* Fixed python style issues, made data files remote, and made code more idiomatic to Keras

* Fixed a couple more style issues brought up in the original PR

* Reverted wrappers.py

* Fixed potential training-on-validation issue and removed unused imports

* Fixed PEP8 issue

* Remaining PEP8 issues fixed
This commit is contained in:
Mike Henry 2016-08-16 13:25:26 -07:00 committed by François Chollet
parent 4e155139ca
commit e8190a8d8d
4 changed files with 695 additions and 1 deletions

442
examples/image_ocr.py Normal file

@ -0,0 +1,442 @@
'''This example uses a convolutional stack followed by a recurrent stack
and a CTC logloss function to perform optical character recognition
of generated text images. I have no evidence of whether it actually
learns general shapes of text, or just is able to recognize all
the different fonts thrown at it...the purpose is more to demonstrate CTC
inside of Keras. Note that the font list may need to be updated
for the particular OS in use.
This starts off with 4 letter words. After 10 or so epochs, CTC
learns translational invariance, so longer words and groups of words
with spaces are gradually fed in. This gradual increase in difficulty
is handled using the TextImageGenerator class which is both a generator
class for test/train data and a Keras callback class. Every 10 epochs
the wordlist that the generator draws from increases in difficulty.
The table below shows normalized edit distance values. Theano uses
a slightly different CTC implementation, so some Theano-specific
hyperparameter tuning would be needed to get it to match Tensorflow.
Norm. ED
Epoch | TF | TH
------------------------
10 0.072 0.272
20 0.032 0.115
30 0.024 0.098
40 0.023 0.108
This requires cairo and editdistance packages:
pip install cairocffi
pip install editdistance
Due to the use of a dummy loss function, Theano requires the following flags:
on_unused_input='ignore'
Created by Mike Henry
https://github.com/mbhenry/
'''
import os
import itertools
import re
import datetime
import cairocffi as cairo
import editdistance
import numpy as np
from scipy import ndimage
import pylab
from keras import backend as K
from keras.layers.convolutional import Convolution2D, MaxPooling2D
from keras.layers import Input, Layer, Dense, Activation, Flatten
from keras.layers import Reshape, Lambda, merge, Permute, TimeDistributed
from keras.models import Model
from keras.layers.recurrent import GRU
from keras.optimizers import SGD
from keras.utils import np_utils
from keras.utils.data_utils import get_file
from keras.preprocessing import image
import keras.callbacks
OUTPUT_DIR = "image_ocr"
np.random.seed(55)
# this creates larger "blotches" of noise which look
# more realistic than just adding gaussian noise
# assumes greyscale with pixels ranging from 0 to 1
def speckle(img):
severity = np.random.uniform(0, 0.6)
blur = ndimage.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
img_speck = (img + blur)
img_speck[img_speck > 1] = 1
img_speck[img_speck <= 0] = 0
return img_speck
# paints the string in a random location the bounding box
# also uses a random font, a slight random rotation,
# and a random amount of speckle noise
def paint_text(text, w, h):
surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
with cairo.Context(surface) as context:
context.set_source_rgb(1, 1, 1) # White
context.paint()
# this font list works in Centos 7
fonts = ['Century Schoolbook', 'Courier', 'STIX', 'URW Chancery L', 'FreeMono']
context.select_font_face(np.random.choice(fonts), cairo.FONT_SLANT_NORMAL,
np.random.choice([cairo.FONT_WEIGHT_BOLD, cairo.FONT_WEIGHT_NORMAL]))
context.set_font_size(40)
box = context.text_extents(text)
if box[2] > w or box[3] > h:
raise IOError('Could not fit string into image. Max char count is too large for given image width.')
# teach the RNN translational invariance by
# fitting text box randomly on canvas, with some room to rotate
border_w_h = (10, 16)
max_shift_x = w - box[2] - border_w_h[0]
max_shift_y = h - box[3] - border_w_h[1]
top_left_x = np.random.randint(0, int(max_shift_x))
top_left_y = np.random.randint(0, int(max_shift_y))
context.move_to(top_left_x - int(box[0]), top_left_y - int(box[1]))
context.set_source_rgb(0, 0, 0)
context.show_text(text)
buf = surface.get_data()
a = np.frombuffer(buf, np.uint8)
a.shape = (h, w, 4)
a = a[:, :, 0] # grab single channel
a /= 255
a = np.expand_dims(a, 0)
a = speckle(a)
a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
return a
def shuffle_mats_or_lists(matrix_list, stop_ind=None):
ret = []
assert all([len(i) == len(matrix_list[0]) for i in matrix_list])
len_val = len(matrix_list[0])
if stop_ind is None:
stop_ind = len_val
assert stop_ind <= len_val
a = range(stop_ind)
np.random.shuffle(a)
a += range(stop_ind, len_val)
for mat in matrix_list:
if isinstance(mat, np.ndarray):
ret.append(mat[a])
elif isinstance(mat, list):
ret.append([mat[i] for i in a])
else:
raise TypeError('shuffle_mats_or_lists only supports numpy.array and list objects')
return ret
def text_to_labels(text, num_classes):
ret = []
for char in text:
if char >= 'a' and char <= 'z':
ret.append(ord(char) - ord('a'))
elif char == ' ':
ret.append(26)
return ret
# only a-z and space..probably not to difficult
# to expand to uppercase and symbols
def is_valid_str(in_str):
search = re.compile(r'[^a-z\ ]').search
return not bool(search(in_str))
# Uses generator functions to supply train/test with
# data. Image renderings are text are created on the fly
# each time with random perturbations
class TextImageGenerator(keras.callbacks.Callback):
def __init__(self, monogram_file, bigram_file, minibatch_size, img_w,
img_h, downsample_width, val_split,
absolute_max_string_len=16):
self.minibatch_size = minibatch_size
self.img_w = img_w
self.img_h = img_h
self.monogram_file = monogram_file
self.bigram_file = bigram_file
self.downsample_width = downsample_width
self.val_split = val_split
self.blank_label = self.get_output_size() - 1
self.absolute_max_string_len = absolute_max_string_len
def get_output_size(self):
return 28
# num_words can be independent of the epoch size due to the use of generators
# as max_string_len grows, num_words can grow
def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
assert max_string_len <= self.absolute_max_string_len
assert num_words % self.minibatch_size == 0
assert (self.val_split * num_words) % self.minibatch_size == 0
self.num_words = num_words
self.string_list = []
self.max_string_len = max_string_len
self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
self.X_text = []
self.Y_len = [0] * self.num_words
# monogram file is sorted by frequency in english speech
with open(self.monogram_file, 'rt') as f:
for line in f:
if len(self.string_list) == int(self.num_words * mono_fraction):
break
word = line.rstrip()
if max_string_len == -1 or max_string_len is None or len(word) <= max_string_len:
self.string_list.append(word)
# bigram file contains common word pairings in english speech
with open(self.bigram_file, 'rt') as f:
lines = f.readlines()
for line in lines:
if len(self.string_list) == self.num_words:
break
columns = line.lower().split()
word = columns[0] + ' ' + columns[1]
if is_valid_str(word) and \
(max_string_len == -1 or max_string_len is None or len(word) <= max_string_len):
self.string_list.append(word)
if len(self.string_list) != self.num_words:
raise IOError('Could not pull enough words from supplied monogram and bigram files. ')
for i, word in enumerate(self.string_list):
self.Y_len[i] = len(word)
self.Y_data[i, 0:len(word)] = text_to_labels(word, self.get_output_size())
self.X_text.append(word)
self.Y_len = np.expand_dims(np.array(self.Y_len), 1)
self.cur_val_index = self.val_split
self.cur_train_index = 0
# each time an image is requested from train/val/test, a new random
# painting of the text is performed
def get_batch(self, index, size, train):
X_data = np.ones([size, 1, self.img_h, self.img_w])
labels = np.ones([size, self.absolute_max_string_len])
input_length = np.zeros([size, 1])
label_length = np.zeros([size, 1])
source_str = []
for i in range(0, size):
# Mix in some blank inputs. This seems to be important for
# achieving translational invariance
if train and i > size - 4:
X_data[i, 0, :, :] = paint_text('', self.img_w, self.img_h)
labels[i, 0] = self.blank_label
input_length[i] = self.downsample_width
label_length[i] = 1
source_str.append('')
else:
X_data[i, 0, :, :] = paint_text(self.X_text[index + i], self.img_w, self.img_h)
labels[i, :] = self.Y_data[index + i]
input_length[i] = self.downsample_width
label_length[i] = self.Y_len[index + i]
source_str.append(self.X_text[index + i])
inputs = {'the_input': X_data,
'the_labels': labels,
'input_length': input_length,
'label_length': label_length,
'source_str': source_str # used for visualization only
}
outputs = {'ctc': np.zeros([size])} # dummy data for dummy loss function
return (inputs, outputs)
def next_train(self):
while 1:
ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
self.cur_train_index += self.minibatch_size
if self.cur_train_index >= self.val_split:
self.cur_train_index = self.cur_train_index % 32
(self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
[self.X_text, self.Y_data, self.Y_len], self.val_split)
yield ret
def next_val(self):
while 1:
ret = self.get_batch(self.cur_val_index, self.minibatch_size, train=False)
self.cur_val_index += self.minibatch_size
if self.cur_val_index >= self.num_words:
self.cur_val_index = self.val_split + self.cur_val_index % 32
yield ret
def on_train_begin(self, logs={}):
# translational invariance seems to be the hardest thing
# for the RNN to learn, so start with <= 4 letter words.
self.build_word_list(16000, 4, 1)
def on_epoch_begin(self, epoch, logs={}):
# After 10 epochs, translational invariance should be learned
# so start feeding longer words and eventually multiple words with spaces
if epoch == 10:
self.build_word_list(32000, 8, 1)
if epoch == 20:
self.build_word_list(32000, 8, 0.6)
if epoch == 30:
self.build_word_list(64000, 12, 0.5)
# the actual loss calc occurs here despite it not being
# an internal Keras loss function
def ctc_lambda_func(args):
y_pred, labels, input_length, label_length = args
# the 2 is critical here since the first couple outputs of the RNN
# tend to be garbage:
y_pred = y_pred[:, 2:, :]
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
# For a real OCR application, this should be beam search with a dictionary
# and language model. For this example, best path is sufficient.
def decode_batch(test_func, word_batch):
out = test_func([word_batch])[0]
ret = []
for j in range(out.shape[0]):
out_best = list(np.argmax(out[j, 2:], 1))
out_best = [k for k, g in itertools.groupby(out_best)]
# 26 is space, 27 is CTC blank char
outstr = ''
for c in out_best:
if c >= 0 and c < 26:
outstr += chr(c + ord('a'))
elif c == 26:
outstr += ' '
ret.append(outstr)
return ret
class VizCallback(keras.callbacks.Callback):
def __init__(self, test_func, text_img_gen, num_display_words = 6):
self.test_func = test_func
self.output_dir = os.path.join(
OUTPUT_DIR, datetime.datetime.now().strftime('%A, %d. %B %Y %I.%M%p'))
self.text_img_gen = text_img_gen
self.num_display_words = num_display_words
os.makedirs(self.output_dir)
def show_edit_distance(self, num):
num_left = num
mean_norm_ed = 0.0
mean_ed = 0.0
while num_left > 0:
word_batch = next(self.text_img_gen)[0]
num_proc = min(word_batch['the_input'].shape[0], num_left)
decoded_res = decode_batch(self.test_func, word_batch['the_input'][0:num_proc])
for j in range(0, num_proc):
edit_dist = editdistance.eval(decoded_res[j], word_batch['source_str'][j])
mean_ed += float(edit_dist)
mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j])
num_left -= num_proc
mean_norm_ed = mean_norm_ed / num
mean_ed = mean_ed / num
print('\nOut of %d samples: Mean edit distance: %.3f Mean normalized edit distance: %0.3f'
% (num, mean_ed, mean_norm_ed))
def on_epoch_end(self, epoch, logs={}):
self.model.save_weights(os.path.join(self.output_dir, 'weights%02d.h5' % epoch))
self.show_edit_distance(256)
word_batch = next(self.text_img_gen)[0]
res = decode_batch(self.test_func, word_batch['the_input'][0:self.num_display_words])
for i in range(self.num_display_words):
pylab.subplot(self.num_display_words, 1, i + 1)
pylab.imshow(word_batch['the_input'][i, 0, :, :], cmap='Greys_r')
pylab.xlabel('Truth = \'%s\' Decoded = \'%s\'' % (word_batch['source_str'][i], res[i]))
fig = pylab.gcf()
fig.set_size_inches(10, 12)
pylab.savefig(os.path.join(self.output_dir, 'e%02d.png' % epoch))
pylab.close()
# Input Parameters
img_h = 64
img_w = 512
nb_epoch = 50
minibatch_size = 32
words_per_epoch = 16000
val_split = 0.2
val_words = int(words_per_epoch * (val_split))
# Network parameters
conv_num_filters = 16
filter_size = 3
pool_size_1 = 4
pool_size_2 = 2
time_dense_size = 32
rnn_size = 512
time_steps = img_w / (pool_size_1 * pool_size_2)
fdir = os.path.dirname(get_file('wordlists.tgz',
origin='http://www.isosemi.com/datasets/wordlists.tgz', untar=True))
img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
minibatch_size=32,
img_w=img_w,
img_h=img_h,
downsample_width=img_w / (pool_size_1 * pool_size_2) - 2,
val_split=words_per_epoch - val_words)
act = 'relu'
input_data = Input(name='the_input', shape=(1, img_h, img_w), dtype='float32')
inner = Convolution2D(conv_num_filters, filter_size, filter_size, border_mode='same',
activation=act, input_shape=(1, img_h, img_w), name='conv1')(input_data)
inner = MaxPooling2D(pool_size=(pool_size_1, pool_size_1), name='max1')(inner)
inner = Convolution2D(conv_num_filters, filter_size, filter_size, border_mode='same',
activation=act, name='conv2')(inner)
inner = MaxPooling2D(pool_size=(pool_size_2, pool_size_2), name='max2')(inner)
conv_to_rnn_dims = ((img_h / (pool_size_1 * pool_size_2)) * conv_num_filters, img_w / (pool_size_1 * pool_size_2))
inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)
inner = Permute(dims=(2, 1), name='permute')(inner)
# cuts down input size going into RNN:
inner = TimeDistributed(Dense(time_dense_size, activation=act, name='dense1'))(inner)
# Two layers of bidirecitonal GRUs
# GRU seems to work as well, if not better than LSTM:
gru_1 = GRU(rnn_size, return_sequences=True, name='gru1')(inner)
gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True, name='gru1_b')(inner)
gru1_merged = merge([gru_1, gru_1b], mode='sum')
gru_2 = GRU(rnn_size, return_sequences=True, name='gru2')(gru1_merged)
gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True)(gru1_merged)
# transforms RNN output to character activations:
inner = TimeDistributed(Dense(img_gen.get_output_size(), name='dense2'))(merge([gru_2, gru_2b], mode='concat'))
y_pred = Activation('softmax', name='softmax')(inner)
Model(input=[input_data], output=y_pred).summary()
labels = Input(name='the_labels', shape=[img_gen.absolute_max_string_len], dtype='float32')
input_length = Input(name='input_length', shape=[1], dtype='int64')
label_length = Input(name='label_length', shape=[1], dtype='int64')
# Keras doesn't currently support loss funcs with extra parameters
# so CTC loss is implemented in a lambda layer
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name="ctc")([y_pred, labels, input_length, label_length])
lr = 0.03
# clipnorm seems to speeds up convergence
clipnorm = 5
sgd = SGD(lr=lr, decay=3e-7, momentum=0.9, nesterov=True, clipnorm=clipnorm)
model = Model(input=[input_data, labels, input_length, label_length], output=[loss_out])
# the loss calc occurs elsewhere, so use a dummy lambda func for the loss
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
# captures output of softmax so we can decode the output during visualization
test_func = K.function([input_data], [y_pred])
viz_cb = VizCallback(test_func, img_gen.next_val())
model.fit_generator(generator=img_gen.next_train(), samples_per_epoch=(words_per_epoch - val_words),
nb_epoch=nb_epoch, validation_data=img_gen.next_val(), nb_val_samples=val_words,
callbacks=[viz_cb, img_gen])

@ -1586,3 +1586,112 @@ def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
tf.ones(shape, dtype=dtype),
tf.zeros(shape, dtype=dtype))
# CTC
# tensorflow has a native implemenation, but it uses sparse tensors
# and therefore requires a wrapper for Keras. The functions below convert
# dense to sparse tensors and also wraps up the beam search code that is
# in tensorflow's CTC implementation
def ctc_label_dense_to_sparse(labels, label_lengths):
# undocumented feature soon to be made public
from tensorflow.python.ops import functional_ops
label_shape = tf.shape(labels)
num_batches_tns = tf.pack([label_shape[0]])
max_num_labels_tns = tf.pack([label_shape[1]])
def range_less_than(previous_state, current_input):
return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
dense_mask = functional_ops.scan(range_less_than, label_lengths,
initializer=init, parallel_iterations=1)
dense_mask = dense_mask[:, 0, :]
label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns),
label_shape)
label_ind = tf.boolean_mask(label_array, dense_mask)
batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]),
max_num_labels_tns), tf.reverse(label_shape, [True])))
batch_ind = tf.boolean_mask(batch_array, dense_mask)
indices = tf.transpose(tf.reshape(tf.concat(0, [batch_ind, label_ind]), [2,-1]))
vals_sparse = tf.gather_nd(labels, indices)
return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
'''Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor (samples, max_string_length) containing the truth labels
y_pred: tensor (samples, time_steps, num_categories) containing the prediction,
or output of the softmax
input_length: tensor (samples,1) containing the sequence length for
each batch item in y_pred
label_length: tensor (samples,1) containing the sequence length for
each batch item in y_true
# Returns
Tensor with shape (samples,1) containing the
CTC loss of each element
'''
label_length = tf.to_int32(tf.squeeze(label_length))
input_length = tf.to_int32(tf.squeeze(input_length))
sparse_labels = tf.to_int32(ctc_label_dense_to_sparse(y_true, label_length))
y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
return tf.expand_dims(tf.contrib.ctc.ctc_loss(inputs = y_pred,
labels = sparse_labels,
sequence_length = input_length), 1)
def ctc_decode(y_pred, input_length, greedy = True, beam_width = None,
dict_seq_lens = None, dict_values = None):
'''Decodes the output of a softmax using either
greedy (also known as best path) or a constrained dictionary
search.
# Arguments
y_pred: tensor (samples, time_steps, num_categories) containing the prediction,
or output of the softmax
input_length: tensor (samples,1) containing the sequence length for
each batch item in y_pred
greedy: perform much faster best-path search if true. This does
not use a dictionary
beam_width: if greedy is false and this value is not none, then
the constrained dictionary search uses a beam of this width
dict_seq_lens: the length of each element in the dict_values list
dict_values: list of lists representing the dictionary.
# Returns
Tensor with shape (samples,time_steps,num_categories) containing the
path probabilities (in softmax output format). Note that a function that
pulls out the argmax and collapses blank labels is still needed.
'''
y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
input_length = tf.to_int32(tf.squeeze(input_length))
if greedy:
(decoded, log_prob) = tf.contrib.ctc.ctc_greedy_decoder(
inputs = y_pred,
sequence_length = input_length)
else:
if beam_width is not None:
(decoded, log_prob) = tf.contrib.ctc.ctc_beam_search_decoder(
inputs = y_pred,
sequence_length = input_length,
dict_seq_lens = dict_seq_lens, dict_values = dict_values)
else:
(decoded, log_prob) = tf.contrib.ctc.ctc_beam_search_decoder(
inputs = y_pred,
sequence_length = input_length, beam_width = beam_width,
dict_seq_lens = dict_seq_lens, dict_values = dict_values)
decoded_dense = [tf.sparse_to_dense(st.indices, st.shape, st.values, default_value = -1)
for st in decoded]
return (decoded_dense, log_prob)

@ -1319,3 +1319,105 @@ def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
seed = np.random.randint(1, 10e6)
rng = RandomStreams(seed=seed)
return rng.binomial(shape, p=p, dtype=dtype)
# Theano implementation of CTC
# Used with permission from Shawn Tan
# https://github.com/shawntan/
# Note that tensorflow's native CTC code is significantly
# faster than this
def ctc_interleave_blanks(Y):
Y_ = T.alloc(-1, Y.shape[0] * 2 + 1)
Y_ = T.set_subtensor(Y_[T.arange(Y.shape[0]) * 2 + 1], Y)
return Y_
def ctc_create_skip_idxs(Y):
skip_idxs = T.arange((Y.shape[0] - 3) // 2) * 2 + 1
non_repeats = T.neq(Y[skip_idxs], Y[skip_idxs + 2])
return skip_idxs[non_repeats.nonzero()]
def ctc_update_log_p(skip_idxs, zeros, active, log_p_curr, log_p_prev):
active_skip_idxs = skip_idxs[(skip_idxs < active).nonzero()]
active_next = T.cast(T.minimum(
T.maximum(
active + 1,
T.max(T.concatenate([active_skip_idxs, [-1]])) + 2 + 1
), log_p_curr.shape[0]), 'int32')
common_factor = T.max(log_p_prev[:active])
p_prev = T.exp(log_p_prev[:active] - common_factor)
_p_prev = zeros[:active_next]
# copy over
_p_prev = T.set_subtensor(_p_prev[:active], p_prev)
# previous transitions
_p_prev = T.inc_subtensor(_p_prev[1:], _p_prev[:-1])
# skip transitions
_p_prev = T.inc_subtensor(_p_prev[active_skip_idxs + 2], p_prev[active_skip_idxs])
updated_log_p_prev = T.log(_p_prev) + common_factor
log_p_next = T.set_subtensor(
zeros[:active_next],
log_p_curr[:active_next] + updated_log_p_prev
)
return active_next, log_p_next
def ctc_path_probs(predict, Y, alpha=1e-4):
smoothed_predict = (1 - alpha) * predict[:, Y] + alpha * np.float32(1.) / Y.shape[0]
L = T.log(smoothed_predict)
zeros = T.zeros_like(L[0])
base = T.set_subtensor(zeros[:1], np.float32(1))
log_first = zeros
f_skip_idxs = ctc_create_skip_idxs(Y)
b_skip_idxs = ctc_create_skip_idxs(Y[::-1]) # there should be a shortcut to calculating this
def step(log_f_curr, log_b_curr, f_active, log_f_prev, b_active, log_b_prev):
f_active_next, log_f_next = ctc_update_log_p(f_skip_idxs, zeros, f_active, log_f_curr, log_f_prev)
b_active_next, log_b_next = ctc_update_log_p(b_skip_idxs, zeros, b_active, log_b_curr, log_b_prev)
return f_active_next, log_f_next, b_active_next, log_b_next
[f_active, log_f_probs, b_active, log_b_probs], _ = theano.scan(
step, sequences=[L, L[::-1, ::-1]], outputs_info=[np.int32(1), log_first, np.int32(1), log_first])
idxs = T.arange(L.shape[1]).dimshuffle('x', 0)
mask = (idxs < f_active.dimshuffle(0, 'x')) & (idxs < b_active.dimshuffle(0, 'x'))[::-1, ::-1]
log_probs = log_f_probs + log_b_probs[::-1, ::-1] - L
return log_probs, mask
def ctc_cost(predict, Y):
log_probs, mask = ctc_path_probs(predict, ctc_interleave_blanks(Y))
common_factor = T.max(log_probs)
total_log_prob = T.log(T.sum(T.exp(log_probs - common_factor)[mask.nonzero()])) + common_factor
return -total_log_prob
# batchifies original CTC code
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
'''Runs CTC loss algorithm on each batch element.
# Arguments
y_true: tensor (samples, max_string_length) containing the truth labels
y_pred: tensor (samples, time_steps, num_categories) containing the prediction,
or output of the softmax
input_length: tensor (samples,1) containing the sequence length for
each batch item in y_pred
label_length: tensor (samples,1) containing the sequence length for
each batch item in y_true
# Returns
Tensor with shape (samples,1) containing the
CTC loss of each element
'''
def ctc_step(y_true_step, y_pred_step, input_length_step, label_length_step):
y_pred_step = y_pred_step[0: input_length_step[0]]
y_true_step = y_true_step[0:label_length_step[0]]
return ctc_cost(y_pred_step, y_true_step)
ret, _ = theano.scan(
fn = ctc_step,
outputs_info=None,
sequences=[y_true, y_pred, input_length, label_length]
)
ret = ret.dimshuffle('x', 0)
return ret

@ -581,6 +581,48 @@ class TestBackend(object):
assert(np.max(rand) == 1)
assert(np.min(rand) == 0)
def test_ctc(self):
# simplified version of TensorFlow's test
label_lens = np.expand_dims(np.asarray([5, 4]), 1)
input_lens = np.expand_dims(np.asarray([5, 5]), 1) # number of timesteps
# the Theano and Tensorflow CTC code use different methods to ensure
# numerical stability. The Theano code subtracts out the max
# before the final log, so the results are different but scale
# identically and still train properly
loss_log_probs_tf = [3.34211, 5.42262]
loss_log_probs_th = [1.73308, 3.81351]
# dimensions are batch x time x categories
labels = np.asarray([[0, 1, 2, 1, 0], [0, 1, 1, 0, -1]])
inputs = np.asarray(
[[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
[[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
[0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
[0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
[0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
[0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]]],
dtype=np.float32)
labels_tf = KTF.variable(labels, dtype="int32")
inputs_tf = KTF.variable(inputs, dtype="float32")
input_lens_tf = KTF.variable(input_lens, dtype="int32")
label_lens_tf = KTF.variable(label_lens, dtype="int32")
res = KTF.eval(KTF.ctc_batch_cost(labels_tf, inputs_tf, input_lens_tf, label_lens_tf))
assert_allclose(res[:, 0], loss_log_probs_tf, atol=1e-05)
labels_th = KTH.variable(labels, dtype="int32")
inputs_th = KTH.variable(inputs, dtype="float32")
input_lens_th = KTH.variable(input_lens, dtype="int32")
label_lens_th = KTH.variable(label_lens, dtype="int32")
res = KTH.eval(KTH.ctc_batch_cost(labels_th, inputs_th, input_lens_th, label_lens_th))
assert_allclose(res[0, :], loss_log_probs_th, atol=1e-05)
def test_one_hot(self):
input_length = 10
nb_classes = 20
@ -591,6 +633,5 @@ class TestBackend(object):
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
assert np.all(koh == oh)
if __name__ == '__main__':
pytest.main([__file__])