Added support for CTC in both Theano and Tensorflow along with image OCR example. (#3436)
* Added CTC to Theano and Tensorflow backend along with image OCR example * Fixed python style issues, made data files remote, and made code more idiomatic to Keras * Fixed a couple more style issues brought up in the original PR * Reverted wrappers.py * Fixed potential training-on-validation issue and removed unused imports * Fixed PEP8 issue * Remaining PEP8 issues fixed
This commit is contained in:
parent
4e155139ca
commit
e8190a8d8d
442
examples/image_ocr.py
Normal file
442
examples/image_ocr.py
Normal file
@ -0,0 +1,442 @@
|
|||||||
|
'''This example uses a convolutional stack followed by a recurrent stack
|
||||||
|
and a CTC logloss function to perform optical character recognition
|
||||||
|
of generated text images. I have no evidence of whether it actually
|
||||||
|
learns general shapes of text, or just is able to recognize all
|
||||||
|
the different fonts thrown at it...the purpose is more to demonstrate CTC
|
||||||
|
inside of Keras. Note that the font list may need to be updated
|
||||||
|
for the particular OS in use.
|
||||||
|
|
||||||
|
This starts off with 4 letter words. After 10 or so epochs, CTC
|
||||||
|
learns translational invariance, so longer words and groups of words
|
||||||
|
with spaces are gradually fed in. This gradual increase in difficulty
|
||||||
|
is handled using the TextImageGenerator class which is both a generator
|
||||||
|
class for test/train data and a Keras callback class. Every 10 epochs
|
||||||
|
the wordlist that the generator draws from increases in difficulty.
|
||||||
|
|
||||||
|
The table below shows normalized edit distance values. Theano uses
|
||||||
|
a slightly different CTC implementation, so some Theano-specific
|
||||||
|
hyperparameter tuning would be needed to get it to match Tensorflow.
|
||||||
|
|
||||||
|
Norm. ED
|
||||||
|
Epoch | TF | TH
|
||||||
|
------------------------
|
||||||
|
10 0.072 0.272
|
||||||
|
20 0.032 0.115
|
||||||
|
30 0.024 0.098
|
||||||
|
40 0.023 0.108
|
||||||
|
|
||||||
|
This requires cairo and editdistance packages:
|
||||||
|
pip install cairocffi
|
||||||
|
pip install editdistance
|
||||||
|
|
||||||
|
Due to the use of a dummy loss function, Theano requires the following flags:
|
||||||
|
on_unused_input='ignore'
|
||||||
|
|
||||||
|
Created by Mike Henry
|
||||||
|
https://github.com/mbhenry/
|
||||||
|
'''
|
||||||
|
|
||||||
|
import os
|
||||||
|
import itertools
|
||||||
|
import re
|
||||||
|
import datetime
|
||||||
|
import cairocffi as cairo
|
||||||
|
import editdistance
|
||||||
|
import numpy as np
|
||||||
|
from scipy import ndimage
|
||||||
|
import pylab
|
||||||
|
from keras import backend as K
|
||||||
|
from keras.layers.convolutional import Convolution2D, MaxPooling2D
|
||||||
|
from keras.layers import Input, Layer, Dense, Activation, Flatten
|
||||||
|
from keras.layers import Reshape, Lambda, merge, Permute, TimeDistributed
|
||||||
|
from keras.models import Model
|
||||||
|
from keras.layers.recurrent import GRU
|
||||||
|
from keras.optimizers import SGD
|
||||||
|
from keras.utils import np_utils
|
||||||
|
from keras.utils.data_utils import get_file
|
||||||
|
from keras.preprocessing import image
|
||||||
|
import keras.callbacks
|
||||||
|
|
||||||
|
OUTPUT_DIR = "image_ocr"
|
||||||
|
|
||||||
|
np.random.seed(55)
|
||||||
|
|
||||||
|
# this creates larger "blotches" of noise which look
|
||||||
|
# more realistic than just adding gaussian noise
|
||||||
|
# assumes greyscale with pixels ranging from 0 to 1
|
||||||
|
|
||||||
|
def speckle(img):
|
||||||
|
severity = np.random.uniform(0, 0.6)
|
||||||
|
blur = ndimage.gaussian_filter(np.random.randn(*img.shape) * severity, 1)
|
||||||
|
img_speck = (img + blur)
|
||||||
|
img_speck[img_speck > 1] = 1
|
||||||
|
img_speck[img_speck <= 0] = 0
|
||||||
|
return img_speck
|
||||||
|
|
||||||
|
# paints the string in a random location the bounding box
|
||||||
|
# also uses a random font, a slight random rotation,
|
||||||
|
# and a random amount of speckle noise
|
||||||
|
|
||||||
|
def paint_text(text, w, h):
|
||||||
|
surface = cairo.ImageSurface(cairo.FORMAT_RGB24, w, h)
|
||||||
|
with cairo.Context(surface) as context:
|
||||||
|
context.set_source_rgb(1, 1, 1) # White
|
||||||
|
context.paint()
|
||||||
|
# this font list works in Centos 7
|
||||||
|
fonts = ['Century Schoolbook', 'Courier', 'STIX', 'URW Chancery L', 'FreeMono']
|
||||||
|
context.select_font_face(np.random.choice(fonts), cairo.FONT_SLANT_NORMAL,
|
||||||
|
np.random.choice([cairo.FONT_WEIGHT_BOLD, cairo.FONT_WEIGHT_NORMAL]))
|
||||||
|
context.set_font_size(40)
|
||||||
|
box = context.text_extents(text)
|
||||||
|
if box[2] > w or box[3] > h:
|
||||||
|
raise IOError('Could not fit string into image. Max char count is too large for given image width.')
|
||||||
|
|
||||||
|
# teach the RNN translational invariance by
|
||||||
|
# fitting text box randomly on canvas, with some room to rotate
|
||||||
|
border_w_h = (10, 16)
|
||||||
|
max_shift_x = w - box[2] - border_w_h[0]
|
||||||
|
max_shift_y = h - box[3] - border_w_h[1]
|
||||||
|
top_left_x = np.random.randint(0, int(max_shift_x))
|
||||||
|
top_left_y = np.random.randint(0, int(max_shift_y))
|
||||||
|
|
||||||
|
context.move_to(top_left_x - int(box[0]), top_left_y - int(box[1]))
|
||||||
|
context.set_source_rgb(0, 0, 0)
|
||||||
|
context.show_text(text)
|
||||||
|
|
||||||
|
buf = surface.get_data()
|
||||||
|
a = np.frombuffer(buf, np.uint8)
|
||||||
|
a.shape = (h, w, 4)
|
||||||
|
a = a[:, :, 0] # grab single channel
|
||||||
|
a /= 255
|
||||||
|
a = np.expand_dims(a, 0)
|
||||||
|
a = speckle(a)
|
||||||
|
a = image.random_rotation(a, 3 * (w - top_left_x) / w + 1)
|
||||||
|
|
||||||
|
return a
|
||||||
|
|
||||||
|
def shuffle_mats_or_lists(matrix_list, stop_ind=None):
|
||||||
|
ret = []
|
||||||
|
assert all([len(i) == len(matrix_list[0]) for i in matrix_list])
|
||||||
|
len_val = len(matrix_list[0])
|
||||||
|
if stop_ind is None:
|
||||||
|
stop_ind = len_val
|
||||||
|
assert stop_ind <= len_val
|
||||||
|
|
||||||
|
a = range(stop_ind)
|
||||||
|
np.random.shuffle(a)
|
||||||
|
a += range(stop_ind, len_val)
|
||||||
|
for mat in matrix_list:
|
||||||
|
if isinstance(mat, np.ndarray):
|
||||||
|
ret.append(mat[a])
|
||||||
|
elif isinstance(mat, list):
|
||||||
|
ret.append([mat[i] for i in a])
|
||||||
|
else:
|
||||||
|
raise TypeError('shuffle_mats_or_lists only supports numpy.array and list objects')
|
||||||
|
return ret
|
||||||
|
|
||||||
|
def text_to_labels(text, num_classes):
|
||||||
|
ret = []
|
||||||
|
for char in text:
|
||||||
|
if char >= 'a' and char <= 'z':
|
||||||
|
ret.append(ord(char) - ord('a'))
|
||||||
|
elif char == ' ':
|
||||||
|
ret.append(26)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
# only a-z and space..probably not to difficult
|
||||||
|
# to expand to uppercase and symbols
|
||||||
|
|
||||||
|
def is_valid_str(in_str):
|
||||||
|
search = re.compile(r'[^a-z\ ]').search
|
||||||
|
return not bool(search(in_str))
|
||||||
|
|
||||||
|
# Uses generator functions to supply train/test with
|
||||||
|
# data. Image renderings are text are created on the fly
|
||||||
|
# each time with random perturbations
|
||||||
|
|
||||||
|
class TextImageGenerator(keras.callbacks.Callback):
|
||||||
|
|
||||||
|
def __init__(self, monogram_file, bigram_file, minibatch_size, img_w,
|
||||||
|
img_h, downsample_width, val_split,
|
||||||
|
absolute_max_string_len=16):
|
||||||
|
|
||||||
|
self.minibatch_size = minibatch_size
|
||||||
|
self.img_w = img_w
|
||||||
|
self.img_h = img_h
|
||||||
|
self.monogram_file = monogram_file
|
||||||
|
self.bigram_file = bigram_file
|
||||||
|
self.downsample_width = downsample_width
|
||||||
|
self.val_split = val_split
|
||||||
|
self.blank_label = self.get_output_size() - 1
|
||||||
|
self.absolute_max_string_len = absolute_max_string_len
|
||||||
|
|
||||||
|
def get_output_size(self):
|
||||||
|
return 28
|
||||||
|
|
||||||
|
# num_words can be independent of the epoch size due to the use of generators
|
||||||
|
# as max_string_len grows, num_words can grow
|
||||||
|
def build_word_list(self, num_words, max_string_len=None, mono_fraction=0.5):
|
||||||
|
assert max_string_len <= self.absolute_max_string_len
|
||||||
|
assert num_words % self.minibatch_size == 0
|
||||||
|
assert (self.val_split * num_words) % self.minibatch_size == 0
|
||||||
|
self.num_words = num_words
|
||||||
|
self.string_list = []
|
||||||
|
self.max_string_len = max_string_len
|
||||||
|
self.Y_data = np.ones([self.num_words, self.absolute_max_string_len]) * -1
|
||||||
|
self.X_text = []
|
||||||
|
self.Y_len = [0] * self.num_words
|
||||||
|
|
||||||
|
# monogram file is sorted by frequency in english speech
|
||||||
|
with open(self.monogram_file, 'rt') as f:
|
||||||
|
for line in f:
|
||||||
|
if len(self.string_list) == int(self.num_words * mono_fraction):
|
||||||
|
break
|
||||||
|
word = line.rstrip()
|
||||||
|
if max_string_len == -1 or max_string_len is None or len(word) <= max_string_len:
|
||||||
|
self.string_list.append(word)
|
||||||
|
|
||||||
|
# bigram file contains common word pairings in english speech
|
||||||
|
with open(self.bigram_file, 'rt') as f:
|
||||||
|
lines = f.readlines()
|
||||||
|
for line in lines:
|
||||||
|
if len(self.string_list) == self.num_words:
|
||||||
|
break
|
||||||
|
columns = line.lower().split()
|
||||||
|
word = columns[0] + ' ' + columns[1]
|
||||||
|
if is_valid_str(word) and \
|
||||||
|
(max_string_len == -1 or max_string_len is None or len(word) <= max_string_len):
|
||||||
|
self.string_list.append(word)
|
||||||
|
if len(self.string_list) != self.num_words:
|
||||||
|
raise IOError('Could not pull enough words from supplied monogram and bigram files. ')
|
||||||
|
|
||||||
|
for i, word in enumerate(self.string_list):
|
||||||
|
self.Y_len[i] = len(word)
|
||||||
|
self.Y_data[i, 0:len(word)] = text_to_labels(word, self.get_output_size())
|
||||||
|
self.X_text.append(word)
|
||||||
|
self.Y_len = np.expand_dims(np.array(self.Y_len), 1)
|
||||||
|
|
||||||
|
self.cur_val_index = self.val_split
|
||||||
|
self.cur_train_index = 0
|
||||||
|
|
||||||
|
# each time an image is requested from train/val/test, a new random
|
||||||
|
# painting of the text is performed
|
||||||
|
def get_batch(self, index, size, train):
|
||||||
|
X_data = np.ones([size, 1, self.img_h, self.img_w])
|
||||||
|
labels = np.ones([size, self.absolute_max_string_len])
|
||||||
|
input_length = np.zeros([size, 1])
|
||||||
|
label_length = np.zeros([size, 1])
|
||||||
|
source_str = []
|
||||||
|
|
||||||
|
for i in range(0, size):
|
||||||
|
# Mix in some blank inputs. This seems to be important for
|
||||||
|
# achieving translational invariance
|
||||||
|
if train and i > size - 4:
|
||||||
|
X_data[i, 0, :, :] = paint_text('', self.img_w, self.img_h)
|
||||||
|
labels[i, 0] = self.blank_label
|
||||||
|
input_length[i] = self.downsample_width
|
||||||
|
label_length[i] = 1
|
||||||
|
source_str.append('')
|
||||||
|
else:
|
||||||
|
X_data[i, 0, :, :] = paint_text(self.X_text[index + i], self.img_w, self.img_h)
|
||||||
|
labels[i, :] = self.Y_data[index + i]
|
||||||
|
input_length[i] = self.downsample_width
|
||||||
|
label_length[i] = self.Y_len[index + i]
|
||||||
|
source_str.append(self.X_text[index + i])
|
||||||
|
|
||||||
|
inputs = {'the_input': X_data,
|
||||||
|
'the_labels': labels,
|
||||||
|
'input_length': input_length,
|
||||||
|
'label_length': label_length,
|
||||||
|
'source_str': source_str # used for visualization only
|
||||||
|
}
|
||||||
|
outputs = {'ctc': np.zeros([size])} # dummy data for dummy loss function
|
||||||
|
return (inputs, outputs)
|
||||||
|
|
||||||
|
def next_train(self):
|
||||||
|
while 1:
|
||||||
|
ret = self.get_batch(self.cur_train_index, self.minibatch_size, train=True)
|
||||||
|
self.cur_train_index += self.minibatch_size
|
||||||
|
if self.cur_train_index >= self.val_split:
|
||||||
|
self.cur_train_index = self.cur_train_index % 32
|
||||||
|
(self.X_text, self.Y_data, self.Y_len) = shuffle_mats_or_lists(
|
||||||
|
[self.X_text, self.Y_data, self.Y_len], self.val_split)
|
||||||
|
yield ret
|
||||||
|
|
||||||
|
def next_val(self):
|
||||||
|
while 1:
|
||||||
|
ret = self.get_batch(self.cur_val_index, self.minibatch_size, train=False)
|
||||||
|
self.cur_val_index += self.minibatch_size
|
||||||
|
if self.cur_val_index >= self.num_words:
|
||||||
|
self.cur_val_index = self.val_split + self.cur_val_index % 32
|
||||||
|
yield ret
|
||||||
|
|
||||||
|
def on_train_begin(self, logs={}):
|
||||||
|
# translational invariance seems to be the hardest thing
|
||||||
|
# for the RNN to learn, so start with <= 4 letter words.
|
||||||
|
self.build_word_list(16000, 4, 1)
|
||||||
|
|
||||||
|
def on_epoch_begin(self, epoch, logs={}):
|
||||||
|
# After 10 epochs, translational invariance should be learned
|
||||||
|
# so start feeding longer words and eventually multiple words with spaces
|
||||||
|
if epoch == 10:
|
||||||
|
self.build_word_list(32000, 8, 1)
|
||||||
|
if epoch == 20:
|
||||||
|
self.build_word_list(32000, 8, 0.6)
|
||||||
|
if epoch == 30:
|
||||||
|
self.build_word_list(64000, 12, 0.5)
|
||||||
|
|
||||||
|
# the actual loss calc occurs here despite it not being
|
||||||
|
# an internal Keras loss function
|
||||||
|
|
||||||
|
def ctc_lambda_func(args):
|
||||||
|
y_pred, labels, input_length, label_length = args
|
||||||
|
# the 2 is critical here since the first couple outputs of the RNN
|
||||||
|
# tend to be garbage:
|
||||||
|
y_pred = y_pred[:, 2:, :]
|
||||||
|
return K.ctc_batch_cost(labels, y_pred, input_length, label_length)
|
||||||
|
|
||||||
|
# For a real OCR application, this should be beam search with a dictionary
|
||||||
|
# and language model. For this example, best path is sufficient.
|
||||||
|
|
||||||
|
def decode_batch(test_func, word_batch):
|
||||||
|
out = test_func([word_batch])[0]
|
||||||
|
ret = []
|
||||||
|
for j in range(out.shape[0]):
|
||||||
|
out_best = list(np.argmax(out[j, 2:], 1))
|
||||||
|
out_best = [k for k, g in itertools.groupby(out_best)]
|
||||||
|
# 26 is space, 27 is CTC blank char
|
||||||
|
outstr = ''
|
||||||
|
for c in out_best:
|
||||||
|
if c >= 0 and c < 26:
|
||||||
|
outstr += chr(c + ord('a'))
|
||||||
|
elif c == 26:
|
||||||
|
outstr += ' '
|
||||||
|
ret.append(outstr)
|
||||||
|
return ret
|
||||||
|
|
||||||
|
class VizCallback(keras.callbacks.Callback):
|
||||||
|
|
||||||
|
def __init__(self, test_func, text_img_gen, num_display_words = 6):
|
||||||
|
self.test_func = test_func
|
||||||
|
self.output_dir = os.path.join(
|
||||||
|
OUTPUT_DIR, datetime.datetime.now().strftime('%A, %d. %B %Y %I.%M%p'))
|
||||||
|
self.text_img_gen = text_img_gen
|
||||||
|
self.num_display_words = num_display_words
|
||||||
|
os.makedirs(self.output_dir)
|
||||||
|
|
||||||
|
def show_edit_distance(self, num):
|
||||||
|
num_left = num
|
||||||
|
mean_norm_ed = 0.0
|
||||||
|
mean_ed = 0.0
|
||||||
|
while num_left > 0:
|
||||||
|
word_batch = next(self.text_img_gen)[0]
|
||||||
|
num_proc = min(word_batch['the_input'].shape[0], num_left)
|
||||||
|
decoded_res = decode_batch(self.test_func, word_batch['the_input'][0:num_proc])
|
||||||
|
for j in range(0, num_proc):
|
||||||
|
edit_dist = editdistance.eval(decoded_res[j], word_batch['source_str'][j])
|
||||||
|
mean_ed += float(edit_dist)
|
||||||
|
mean_norm_ed += float(edit_dist) / len(word_batch['source_str'][j])
|
||||||
|
num_left -= num_proc
|
||||||
|
mean_norm_ed = mean_norm_ed / num
|
||||||
|
mean_ed = mean_ed / num
|
||||||
|
print('\nOut of %d samples: Mean edit distance: %.3f Mean normalized edit distance: %0.3f'
|
||||||
|
% (num, mean_ed, mean_norm_ed))
|
||||||
|
|
||||||
|
def on_epoch_end(self, epoch, logs={}):
|
||||||
|
self.model.save_weights(os.path.join(self.output_dir, 'weights%02d.h5' % epoch))
|
||||||
|
self.show_edit_distance(256)
|
||||||
|
word_batch = next(self.text_img_gen)[0]
|
||||||
|
res = decode_batch(self.test_func, word_batch['the_input'][0:self.num_display_words])
|
||||||
|
|
||||||
|
for i in range(self.num_display_words):
|
||||||
|
pylab.subplot(self.num_display_words, 1, i + 1)
|
||||||
|
pylab.imshow(word_batch['the_input'][i, 0, :, :], cmap='Greys_r')
|
||||||
|
pylab.xlabel('Truth = \'%s\' Decoded = \'%s\'' % (word_batch['source_str'][i], res[i]))
|
||||||
|
fig = pylab.gcf()
|
||||||
|
fig.set_size_inches(10, 12)
|
||||||
|
pylab.savefig(os.path.join(self.output_dir, 'e%02d.png' % epoch))
|
||||||
|
pylab.close()
|
||||||
|
|
||||||
|
# Input Parameters
|
||||||
|
img_h = 64
|
||||||
|
img_w = 512
|
||||||
|
nb_epoch = 50
|
||||||
|
minibatch_size = 32
|
||||||
|
words_per_epoch = 16000
|
||||||
|
val_split = 0.2
|
||||||
|
val_words = int(words_per_epoch * (val_split))
|
||||||
|
|
||||||
|
# Network parameters
|
||||||
|
conv_num_filters = 16
|
||||||
|
filter_size = 3
|
||||||
|
pool_size_1 = 4
|
||||||
|
pool_size_2 = 2
|
||||||
|
time_dense_size = 32
|
||||||
|
rnn_size = 512
|
||||||
|
time_steps = img_w / (pool_size_1 * pool_size_2)
|
||||||
|
|
||||||
|
fdir = os.path.dirname(get_file('wordlists.tgz',
|
||||||
|
origin='http://www.isosemi.com/datasets/wordlists.tgz', untar=True))
|
||||||
|
|
||||||
|
img_gen = TextImageGenerator(monogram_file=os.path.join(fdir, 'wordlist_mono_clean.txt'),
|
||||||
|
bigram_file=os.path.join(fdir, 'wordlist_bi_clean.txt'),
|
||||||
|
minibatch_size=32,
|
||||||
|
img_w=img_w,
|
||||||
|
img_h=img_h,
|
||||||
|
downsample_width=img_w / (pool_size_1 * pool_size_2) - 2,
|
||||||
|
val_split=words_per_epoch - val_words)
|
||||||
|
|
||||||
|
act = 'relu'
|
||||||
|
input_data = Input(name='the_input', shape=(1, img_h, img_w), dtype='float32')
|
||||||
|
inner = Convolution2D(conv_num_filters, filter_size, filter_size, border_mode='same',
|
||||||
|
activation=act, input_shape=(1, img_h, img_w), name='conv1')(input_data)
|
||||||
|
inner = MaxPooling2D(pool_size=(pool_size_1, pool_size_1), name='max1')(inner)
|
||||||
|
inner = Convolution2D(conv_num_filters, filter_size, filter_size, border_mode='same',
|
||||||
|
activation=act, name='conv2')(inner)
|
||||||
|
inner = MaxPooling2D(pool_size=(pool_size_2, pool_size_2), name='max2')(inner)
|
||||||
|
|
||||||
|
conv_to_rnn_dims = ((img_h / (pool_size_1 * pool_size_2)) * conv_num_filters, img_w / (pool_size_1 * pool_size_2))
|
||||||
|
inner = Reshape(target_shape=conv_to_rnn_dims, name='reshape')(inner)
|
||||||
|
inner = Permute(dims=(2, 1), name='permute')(inner)
|
||||||
|
|
||||||
|
# cuts down input size going into RNN:
|
||||||
|
inner = TimeDistributed(Dense(time_dense_size, activation=act, name='dense1'))(inner)
|
||||||
|
|
||||||
|
# Two layers of bidirecitonal GRUs
|
||||||
|
# GRU seems to work as well, if not better than LSTM:
|
||||||
|
gru_1 = GRU(rnn_size, return_sequences=True, name='gru1')(inner)
|
||||||
|
gru_1b = GRU(rnn_size, return_sequences=True, go_backwards=True, name='gru1_b')(inner)
|
||||||
|
gru1_merged = merge([gru_1, gru_1b], mode='sum')
|
||||||
|
gru_2 = GRU(rnn_size, return_sequences=True, name='gru2')(gru1_merged)
|
||||||
|
gru_2b = GRU(rnn_size, return_sequences=True, go_backwards=True)(gru1_merged)
|
||||||
|
|
||||||
|
# transforms RNN output to character activations:
|
||||||
|
inner = TimeDistributed(Dense(img_gen.get_output_size(), name='dense2'))(merge([gru_2, gru_2b], mode='concat'))
|
||||||
|
y_pred = Activation('softmax', name='softmax')(inner)
|
||||||
|
Model(input=[input_data], output=y_pred).summary()
|
||||||
|
|
||||||
|
labels = Input(name='the_labels', shape=[img_gen.absolute_max_string_len], dtype='float32')
|
||||||
|
input_length = Input(name='input_length', shape=[1], dtype='int64')
|
||||||
|
label_length = Input(name='label_length', shape=[1], dtype='int64')
|
||||||
|
# Keras doesn't currently support loss funcs with extra parameters
|
||||||
|
# so CTC loss is implemented in a lambda layer
|
||||||
|
loss_out = Lambda(ctc_lambda_func, output_shape=(1,), name="ctc")([y_pred, labels, input_length, label_length])
|
||||||
|
|
||||||
|
lr = 0.03
|
||||||
|
# clipnorm seems to speeds up convergence
|
||||||
|
clipnorm = 5
|
||||||
|
sgd = SGD(lr=lr, decay=3e-7, momentum=0.9, nesterov=True, clipnorm=clipnorm)
|
||||||
|
|
||||||
|
model = Model(input=[input_data, labels, input_length, label_length], output=[loss_out])
|
||||||
|
|
||||||
|
# the loss calc occurs elsewhere, so use a dummy lambda func for the loss
|
||||||
|
model.compile(loss={'ctc': lambda y_true, y_pred: y_pred}, optimizer=sgd)
|
||||||
|
|
||||||
|
# captures output of softmax so we can decode the output during visualization
|
||||||
|
test_func = K.function([input_data], [y_pred])
|
||||||
|
|
||||||
|
viz_cb = VizCallback(test_func, img_gen.next_val())
|
||||||
|
|
||||||
|
model.fit_generator(generator=img_gen.next_train(), samples_per_epoch=(words_per_epoch - val_words),
|
||||||
|
nb_epoch=nb_epoch, validation_data=img_gen.next_val(), nb_val_samples=val_words,
|
||||||
|
callbacks=[viz_cb, img_gen])
|
@ -1586,3 +1586,112 @@ def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
|
|||||||
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
|
return tf.select(tf.random_uniform(shape, dtype=dtype, seed=seed) <= p,
|
||||||
tf.ones(shape, dtype=dtype),
|
tf.ones(shape, dtype=dtype),
|
||||||
tf.zeros(shape, dtype=dtype))
|
tf.zeros(shape, dtype=dtype))
|
||||||
|
|
||||||
|
# CTC
|
||||||
|
# tensorflow has a native implemenation, but it uses sparse tensors
|
||||||
|
# and therefore requires a wrapper for Keras. The functions below convert
|
||||||
|
# dense to sparse tensors and also wraps up the beam search code that is
|
||||||
|
# in tensorflow's CTC implementation
|
||||||
|
|
||||||
|
def ctc_label_dense_to_sparse(labels, label_lengths):
|
||||||
|
# undocumented feature soon to be made public
|
||||||
|
from tensorflow.python.ops import functional_ops
|
||||||
|
label_shape = tf.shape(labels)
|
||||||
|
num_batches_tns = tf.pack([label_shape[0]])
|
||||||
|
max_num_labels_tns = tf.pack([label_shape[1]])
|
||||||
|
|
||||||
|
def range_less_than(previous_state, current_input):
|
||||||
|
return tf.expand_dims(tf.range(label_shape[1]), 0) < current_input
|
||||||
|
|
||||||
|
init = tf.cast(tf.fill(max_num_labels_tns, 0), tf.bool)
|
||||||
|
dense_mask = functional_ops.scan(range_less_than, label_lengths,
|
||||||
|
initializer=init, parallel_iterations=1)
|
||||||
|
dense_mask = dense_mask[:, 0, :]
|
||||||
|
|
||||||
|
label_array = tf.reshape(tf.tile(tf.range(0, label_shape[1]), num_batches_tns),
|
||||||
|
label_shape)
|
||||||
|
label_ind = tf.boolean_mask(label_array, dense_mask)
|
||||||
|
|
||||||
|
batch_array = tf.transpose(tf.reshape(tf.tile(tf.range(0, label_shape[0]),
|
||||||
|
max_num_labels_tns), tf.reverse(label_shape, [True])))
|
||||||
|
batch_ind = tf.boolean_mask(batch_array, dense_mask)
|
||||||
|
indices = tf.transpose(tf.reshape(tf.concat(0, [batch_ind, label_ind]), [2,-1]))
|
||||||
|
|
||||||
|
vals_sparse = tf.gather_nd(labels, indices)
|
||||||
|
|
||||||
|
return tf.SparseTensor(tf.to_int64(indices), vals_sparse, tf.to_int64(label_shape))
|
||||||
|
|
||||||
|
|
||||||
|
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
|
||||||
|
|
||||||
|
'''Runs CTC loss algorithm on each batch element.
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
y_true: tensor (samples, max_string_length) containing the truth labels
|
||||||
|
y_pred: tensor (samples, time_steps, num_categories) containing the prediction,
|
||||||
|
or output of the softmax
|
||||||
|
input_length: tensor (samples,1) containing the sequence length for
|
||||||
|
each batch item in y_pred
|
||||||
|
label_length: tensor (samples,1) containing the sequence length for
|
||||||
|
each batch item in y_true
|
||||||
|
|
||||||
|
# Returns
|
||||||
|
Tensor with shape (samples,1) containing the
|
||||||
|
CTC loss of each element
|
||||||
|
'''
|
||||||
|
label_length = tf.to_int32(tf.squeeze(label_length))
|
||||||
|
input_length = tf.to_int32(tf.squeeze(input_length))
|
||||||
|
sparse_labels = tf.to_int32(ctc_label_dense_to_sparse(y_true, label_length))
|
||||||
|
|
||||||
|
y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
|
||||||
|
|
||||||
|
return tf.expand_dims(tf.contrib.ctc.ctc_loss(inputs = y_pred,
|
||||||
|
labels = sparse_labels,
|
||||||
|
sequence_length = input_length), 1)
|
||||||
|
|
||||||
|
def ctc_decode(y_pred, input_length, greedy = True, beam_width = None,
|
||||||
|
dict_seq_lens = None, dict_values = None):
|
||||||
|
'''Decodes the output of a softmax using either
|
||||||
|
greedy (also known as best path) or a constrained dictionary
|
||||||
|
search.
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
y_pred: tensor (samples, time_steps, num_categories) containing the prediction,
|
||||||
|
or output of the softmax
|
||||||
|
input_length: tensor (samples,1) containing the sequence length for
|
||||||
|
each batch item in y_pred
|
||||||
|
greedy: perform much faster best-path search if true. This does
|
||||||
|
not use a dictionary
|
||||||
|
beam_width: if greedy is false and this value is not none, then
|
||||||
|
the constrained dictionary search uses a beam of this width
|
||||||
|
dict_seq_lens: the length of each element in the dict_values list
|
||||||
|
dict_values: list of lists representing the dictionary.
|
||||||
|
|
||||||
|
# Returns
|
||||||
|
Tensor with shape (samples,time_steps,num_categories) containing the
|
||||||
|
path probabilities (in softmax output format). Note that a function that
|
||||||
|
pulls out the argmax and collapses blank labels is still needed.
|
||||||
|
'''
|
||||||
|
y_pred = tf.log(tf.transpose(y_pred, perm=[1, 0, 2]) + 1e-8)
|
||||||
|
input_length = tf.to_int32(tf.squeeze(input_length))
|
||||||
|
|
||||||
|
if greedy:
|
||||||
|
(decoded, log_prob) = tf.contrib.ctc.ctc_greedy_decoder(
|
||||||
|
inputs = y_pred,
|
||||||
|
sequence_length = input_length)
|
||||||
|
else:
|
||||||
|
if beam_width is not None:
|
||||||
|
(decoded, log_prob) = tf.contrib.ctc.ctc_beam_search_decoder(
|
||||||
|
inputs = y_pred,
|
||||||
|
sequence_length = input_length,
|
||||||
|
dict_seq_lens = dict_seq_lens, dict_values = dict_values)
|
||||||
|
else:
|
||||||
|
(decoded, log_prob) = tf.contrib.ctc.ctc_beam_search_decoder(
|
||||||
|
inputs = y_pred,
|
||||||
|
sequence_length = input_length, beam_width = beam_width,
|
||||||
|
dict_seq_lens = dict_seq_lens, dict_values = dict_values)
|
||||||
|
|
||||||
|
decoded_dense = [tf.sparse_to_dense(st.indices, st.shape, st.values, default_value = -1)
|
||||||
|
for st in decoded]
|
||||||
|
|
||||||
|
return (decoded_dense, log_prob)
|
||||||
|
@ -1319,3 +1319,105 @@ def random_binomial(shape, p=0.0, dtype=_FLOATX, seed=None):
|
|||||||
seed = np.random.randint(1, 10e6)
|
seed = np.random.randint(1, 10e6)
|
||||||
rng = RandomStreams(seed=seed)
|
rng = RandomStreams(seed=seed)
|
||||||
return rng.binomial(shape, p=p, dtype=dtype)
|
return rng.binomial(shape, p=p, dtype=dtype)
|
||||||
|
|
||||||
|
# Theano implementation of CTC
|
||||||
|
# Used with permission from Shawn Tan
|
||||||
|
# https://github.com/shawntan/
|
||||||
|
# Note that tensorflow's native CTC code is significantly
|
||||||
|
# faster than this
|
||||||
|
|
||||||
|
def ctc_interleave_blanks(Y):
|
||||||
|
Y_ = T.alloc(-1, Y.shape[0] * 2 + 1)
|
||||||
|
Y_ = T.set_subtensor(Y_[T.arange(Y.shape[0]) * 2 + 1], Y)
|
||||||
|
return Y_
|
||||||
|
|
||||||
|
def ctc_create_skip_idxs(Y):
|
||||||
|
skip_idxs = T.arange((Y.shape[0] - 3) // 2) * 2 + 1
|
||||||
|
non_repeats = T.neq(Y[skip_idxs], Y[skip_idxs + 2])
|
||||||
|
return skip_idxs[non_repeats.nonzero()]
|
||||||
|
|
||||||
|
def ctc_update_log_p(skip_idxs, zeros, active, log_p_curr, log_p_prev):
|
||||||
|
active_skip_idxs = skip_idxs[(skip_idxs < active).nonzero()]
|
||||||
|
active_next = T.cast(T.minimum(
|
||||||
|
T.maximum(
|
||||||
|
active + 1,
|
||||||
|
T.max(T.concatenate([active_skip_idxs, [-1]])) + 2 + 1
|
||||||
|
), log_p_curr.shape[0]), 'int32')
|
||||||
|
|
||||||
|
common_factor = T.max(log_p_prev[:active])
|
||||||
|
p_prev = T.exp(log_p_prev[:active] - common_factor)
|
||||||
|
_p_prev = zeros[:active_next]
|
||||||
|
# copy over
|
||||||
|
_p_prev = T.set_subtensor(_p_prev[:active], p_prev)
|
||||||
|
# previous transitions
|
||||||
|
_p_prev = T.inc_subtensor(_p_prev[1:], _p_prev[:-1])
|
||||||
|
# skip transitions
|
||||||
|
_p_prev = T.inc_subtensor(_p_prev[active_skip_idxs + 2], p_prev[active_skip_idxs])
|
||||||
|
updated_log_p_prev = T.log(_p_prev) + common_factor
|
||||||
|
|
||||||
|
log_p_next = T.set_subtensor(
|
||||||
|
zeros[:active_next],
|
||||||
|
log_p_curr[:active_next] + updated_log_p_prev
|
||||||
|
)
|
||||||
|
return active_next, log_p_next
|
||||||
|
|
||||||
|
def ctc_path_probs(predict, Y, alpha=1e-4):
|
||||||
|
smoothed_predict = (1 - alpha) * predict[:, Y] + alpha * np.float32(1.) / Y.shape[0]
|
||||||
|
L = T.log(smoothed_predict)
|
||||||
|
zeros = T.zeros_like(L[0])
|
||||||
|
base = T.set_subtensor(zeros[:1], np.float32(1))
|
||||||
|
log_first = zeros
|
||||||
|
|
||||||
|
f_skip_idxs = ctc_create_skip_idxs(Y)
|
||||||
|
b_skip_idxs = ctc_create_skip_idxs(Y[::-1]) # there should be a shortcut to calculating this
|
||||||
|
|
||||||
|
def step(log_f_curr, log_b_curr, f_active, log_f_prev, b_active, log_b_prev):
|
||||||
|
f_active_next, log_f_next = ctc_update_log_p(f_skip_idxs, zeros, f_active, log_f_curr, log_f_prev)
|
||||||
|
b_active_next, log_b_next = ctc_update_log_p(b_skip_idxs, zeros, b_active, log_b_curr, log_b_prev)
|
||||||
|
return f_active_next, log_f_next, b_active_next, log_b_next
|
||||||
|
|
||||||
|
[f_active, log_f_probs, b_active, log_b_probs], _ = theano.scan(
|
||||||
|
step, sequences=[L, L[::-1, ::-1]], outputs_info=[np.int32(1), log_first, np.int32(1), log_first])
|
||||||
|
|
||||||
|
idxs = T.arange(L.shape[1]).dimshuffle('x', 0)
|
||||||
|
mask = (idxs < f_active.dimshuffle(0, 'x')) & (idxs < b_active.dimshuffle(0, 'x'))[::-1, ::-1]
|
||||||
|
log_probs = log_f_probs + log_b_probs[::-1, ::-1] - L
|
||||||
|
return log_probs, mask
|
||||||
|
|
||||||
|
def ctc_cost(predict, Y):
|
||||||
|
log_probs, mask = ctc_path_probs(predict, ctc_interleave_blanks(Y))
|
||||||
|
common_factor = T.max(log_probs)
|
||||||
|
total_log_prob = T.log(T.sum(T.exp(log_probs - common_factor)[mask.nonzero()])) + common_factor
|
||||||
|
return -total_log_prob
|
||||||
|
|
||||||
|
# batchifies original CTC code
|
||||||
|
def ctc_batch_cost(y_true, y_pred, input_length, label_length):
|
||||||
|
'''Runs CTC loss algorithm on each batch element.
|
||||||
|
|
||||||
|
# Arguments
|
||||||
|
y_true: tensor (samples, max_string_length) containing the truth labels
|
||||||
|
y_pred: tensor (samples, time_steps, num_categories) containing the prediction,
|
||||||
|
or output of the softmax
|
||||||
|
input_length: tensor (samples,1) containing the sequence length for
|
||||||
|
each batch item in y_pred
|
||||||
|
label_length: tensor (samples,1) containing the sequence length for
|
||||||
|
each batch item in y_true
|
||||||
|
|
||||||
|
# Returns
|
||||||
|
Tensor with shape (samples,1) containing the
|
||||||
|
CTC loss of each element
|
||||||
|
'''
|
||||||
|
|
||||||
|
def ctc_step(y_true_step, y_pred_step, input_length_step, label_length_step):
|
||||||
|
y_pred_step = y_pred_step[0: input_length_step[0]]
|
||||||
|
y_true_step = y_true_step[0:label_length_step[0]]
|
||||||
|
return ctc_cost(y_pred_step, y_true_step)
|
||||||
|
|
||||||
|
ret, _ = theano.scan(
|
||||||
|
fn = ctc_step,
|
||||||
|
outputs_info=None,
|
||||||
|
sequences=[y_true, y_pred, input_length, label_length]
|
||||||
|
)
|
||||||
|
|
||||||
|
ret = ret.dimshuffle('x', 0)
|
||||||
|
return ret
|
||||||
|
@ -581,6 +581,48 @@ class TestBackend(object):
|
|||||||
assert(np.max(rand) == 1)
|
assert(np.max(rand) == 1)
|
||||||
assert(np.min(rand) == 0)
|
assert(np.min(rand) == 0)
|
||||||
|
|
||||||
|
def test_ctc(self):
|
||||||
|
# simplified version of TensorFlow's test
|
||||||
|
|
||||||
|
label_lens = np.expand_dims(np.asarray([5, 4]), 1)
|
||||||
|
input_lens = np.expand_dims(np.asarray([5, 5]), 1) # number of timesteps
|
||||||
|
|
||||||
|
# the Theano and Tensorflow CTC code use different methods to ensure
|
||||||
|
# numerical stability. The Theano code subtracts out the max
|
||||||
|
# before the final log, so the results are different but scale
|
||||||
|
# identically and still train properly
|
||||||
|
loss_log_probs_tf = [3.34211, 5.42262]
|
||||||
|
loss_log_probs_th = [1.73308, 3.81351]
|
||||||
|
|
||||||
|
# dimensions are batch x time x categories
|
||||||
|
labels = np.asarray([[0, 1, 2, 1, 0], [0, 1, 1, 0, -1]])
|
||||||
|
inputs = np.asarray(
|
||||||
|
[[[0.633766, 0.221185, 0.0917319, 0.0129757, 0.0142857, 0.0260553],
|
||||||
|
[0.111121, 0.588392, 0.278779, 0.0055756, 0.00569609, 0.010436],
|
||||||
|
[0.0357786, 0.633813, 0.321418, 0.00249248, 0.00272882, 0.0037688],
|
||||||
|
[0.0663296, 0.643849, 0.280111, 0.00283995, 0.0035545, 0.00331533],
|
||||||
|
[0.458235, 0.396634, 0.123377, 0.00648837, 0.00903441, 0.00623107]],
|
||||||
|
[[0.30176, 0.28562, 0.0831517, 0.0862751, 0.0816851, 0.161508],
|
||||||
|
[0.24082, 0.397533, 0.0557226, 0.0546814, 0.0557528, 0.19549],
|
||||||
|
[0.230246, 0.450868, 0.0389607, 0.038309, 0.0391602, 0.202456],
|
||||||
|
[0.280884, 0.429522, 0.0326593, 0.0339046, 0.0326856, 0.190345],
|
||||||
|
[0.423286, 0.315517, 0.0338439, 0.0393744, 0.0339315, 0.154046]]],
|
||||||
|
dtype=np.float32)
|
||||||
|
|
||||||
|
labels_tf = KTF.variable(labels, dtype="int32")
|
||||||
|
inputs_tf = KTF.variable(inputs, dtype="float32")
|
||||||
|
input_lens_tf = KTF.variable(input_lens, dtype="int32")
|
||||||
|
label_lens_tf = KTF.variable(label_lens, dtype="int32")
|
||||||
|
res = KTF.eval(KTF.ctc_batch_cost(labels_tf, inputs_tf, input_lens_tf, label_lens_tf))
|
||||||
|
assert_allclose(res[:, 0], loss_log_probs_tf, atol=1e-05)
|
||||||
|
|
||||||
|
labels_th = KTH.variable(labels, dtype="int32")
|
||||||
|
inputs_th = KTH.variable(inputs, dtype="float32")
|
||||||
|
input_lens_th = KTH.variable(input_lens, dtype="int32")
|
||||||
|
label_lens_th = KTH.variable(label_lens, dtype="int32")
|
||||||
|
res = KTH.eval(KTH.ctc_batch_cost(labels_th, inputs_th, input_lens_th, label_lens_th))
|
||||||
|
assert_allclose(res[0, :], loss_log_probs_th, atol=1e-05)
|
||||||
|
|
||||||
def test_one_hot(self):
|
def test_one_hot(self):
|
||||||
input_length = 10
|
input_length = 10
|
||||||
nb_classes = 20
|
nb_classes = 20
|
||||||
@ -591,6 +633,5 @@ class TestBackend(object):
|
|||||||
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
|
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
|
||||||
assert np.all(koh == oh)
|
assert np.all(koh == oh)
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
pytest.main([__file__])
|
pytest.main([__file__])
|
||||||
|
Loading…
Reference in New Issue
Block a user