Initial Sparse Matrix Support (#3695)

* Minimal SparseTensor support for TensorFlow

* Basic Theano support for Sparse dot product

* Sparse Input for Both + Sparse Concat for TF

* Fixed issue with _keras_shape for sparse Inputs

* pep8

* Cleanup + Theano concat (untested)

* Bug fix & pep8

* Fix Theano concat

* Bugfix & simplification

* Next step: Unit tests

* Basic unit test for sparse dot; TF works, TH fails

* Fix KTH is_sparse

* pep8

* more tests, sparse KTH.eval, pep8

* sparse model test

* address code review comments

* make sparse boolean in K.placeholder

* skip sparse tests when TH.sparse import fails

* pep8

* pep8

* fixed flakey test, auto-dense in KTH.eval

* fixed some more len/shape issues for fit_generator

* fixed some more len/shape issues for prediction

* Added better exceptions when theano.sparse fails to import

* betterer

* pep8
Author: kuza55
Date: 2016-09-09 19:26:37 -04:00
Committed by: François Chollet
Commit: 79edae58d5 (parent: 6675776640)
6 changed files with 209 additions and 26 deletions
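
Taken together, the changes below let a model declare a sparse input and be fit directly on a scipy sparse matrix. A minimal end-to-end sketch of the new capability, mirroring the test_sparse_mlp test added at the bottom of this diff (layer sizes are arbitrary):

import numpy as np
import scipy.sparse as sparse
from keras.layers import Dense, Input
from keras.models import Model

# Declare a sparse input placeholder; the batch dimension stays unspecified.
x = Input(batch_shape=(None, 16), sparse=True)
h = Dense(8, activation='relu')(x)
preds = Dense(4, activation='sigmoid')(h)
model = Model(input=[x], output=preds)
model.compile(loss='mse', optimizer='sgd')

# Fit directly on a scipy CSR matrix; dense labels are unchanged.
data = sparse.rand(32, 16, density=0.1, format='csr')
labels = np.random.random((32, 4))
model.fit(data, labels, nb_epoch=1)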

@@ -9,6 +9,7 @@ import os
import copy
import warnings
from .common import _FLOATX, _EPSILON, _IMAGE_DIM_ORDERING, reset_uids
py_all = all
# INTERNAL UTILS
@@ -117,6 +118,17 @@ def _to_tensor(x, dtype):
return x
def is_sparse(tensor):
return isinstance(tensor, tf.SparseTensor)
def to_dense(tensor):
if is_sparse(tensor):
return tf.sparse_tensor_to_dense(tensor)
else:
return tensor
def variable(value, dtype=_FLOATX, name=None):
'''Instantiates a tensor.
@@ -128,6 +140,12 @@ def variable(value, dtype=_FLOATX, name=None):
# Returns
Tensor variable instance.
'''
if hasattr(value, 'tocoo'):
sparse_coo = value.tocoo()
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(sparse_coo.col, 1)), 1)
# SparseTensor doesn't need initialization
return tf.SparseTensor(indices=indices, values=value.data, shape=value.shape)
v = tf.Variable(value, dtype=_convert_string_dtype(dtype), name=name)
if _MANUAL_VAR_INIT:
return v
@@ -148,7 +166,7 @@ def variable(value, dtype=_FLOATX, name=None):
return v
def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
'''Instantiates a placeholder.
# Arguments
@@ -166,7 +184,11 @@ def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
if not shape:
if ndim:
shape = tuple([None for _ in range(ndim)])
x = tf.placeholder(dtype, shape=shape, name=name)
if sparse:
tf_shape = tf.constant(np.zeros(len(shape), dtype=np.int64))
x = tf.sparse_placeholder(dtype, shape=tf_shape, name=name)
else:
x = tf.placeholder(dtype, shape=shape, name=name)
x._keras_shape = shape
x._uses_learning_phase = False
return x
@@ -190,6 +212,9 @@ def int_shape(x):
def ndim(x):
'''Returns the number of axes in a tensor, as an integer.
'''
if is_sparse(x):
return int(x.shape.get_shape()[0])
dims = x.get_shape()._dims
if dims is not None:
return len(dims)
@@ -206,7 +231,7 @@ def eval(x):
'''Evaluates the value of a tensor.
Returns a Numpy array.
'''
return x.eval(session=get_session())
return to_dense(x).eval(session=get_session())
def zeros(shape, dtype=_FLOATX, name=None):
@@ -318,7 +343,10 @@ def dot(x, y):
xt = tf.reshape(x, [-1, x_shape[-1]])
yt = tf.reshape(tf.transpose(y, perm=y_permute_dim), [y_shape[-2], -1])
return tf.reshape(tf.matmul(xt, yt), x_shape[:-1] + y_shape[:-2] + y_shape[-1:])
out = tf.matmul(x, y)
if is_sparse(x):
out = tf.sparse_tensor_dense_matmul(x, y)
else:
out = tf.matmul(x, y)
return out
@@ -676,11 +704,16 @@ def concatenate(tensors, axis=-1):
'''Concatenates a list of tensors along the specified axis.
'''
if axis < 0:
if len(tensors[0].get_shape()):
axis = axis % len(tensors[0].get_shape())
dims = ndim(tensors[0])
if dims:
axis = axis % dims
else:
axis = 0
return tf.concat(axis, tensors)
if py_all([is_sparse(x) for x in tensors]):
return tf.sparse_concat(axis, tensors)
else:
return tf.concat(axis, [to_dense(x) for x in tensors])
def reshape(x, shape):
@@ -969,8 +1002,13 @@ class Function(object):
def __call__(self, inputs):
assert type(inputs) in {list, tuple}
names = [getattr(v, 'name', None) for v in self.inputs]
feed_dict = dict(zip(names, inputs))
feed_dict = {}
for tensor, value in zip(self.inputs, inputs):
if is_sparse(tensor):
sparse_coo = value.tocoo()
indices = np.concatenate((np.expand_dims(sparse_coo.row, 1), np.expand_dims(sparse_coo.col, 1)), 1)
value = (indices, value.data, value.shape)
feed_dict[tensor] = value
session = get_session()
updated = session.run(self.outputs + [self.updates_op], feed_dict=feed_dict)
return updated[:len(self.outputs)]
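
The scipy-COO-to-SparseTensor conversion above appears twice, once in variable() and once in Function.__call__. A standalone sketch of the same transformation, assuming only numpy and scipy:

import numpy as np
import scipy.sparse as sparse

# A small CSR matrix with two nonzeros.
m = sparse.csr_matrix((np.array([7.0, 2.0]),
                       (np.array([0, 1]), np.array([2, 0]))), shape=(2, 3))
coo = m.tocoo()
# Stack row and column indices into the (nnz, 2) index matrix that
# tf.SparseTensor and tf.sparse_placeholder expect.
indices = np.concatenate((np.expand_dims(coo.row, 1),
                          np.expand_dims(coo.col, 1)), 1)
# The triple fed for a sparse placeholder: (indices, values, dense shape).
feed_value = (indices, coo.data, coo.shape)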

@@ -4,6 +4,10 @@ from theano.sandbox.rng_mrg import MRG_RandomStreams as RandomStreams
from theano.tensor.signal import pool
from theano.tensor.nnet import conv3d2d
from theano.printing import Print
try:
import theano.sparse as th_sparse_module
except ImportError:
th_sparse_module = None
try:
from theano.tensor.nnet.nnet import softsign as T_softsign
except ImportError:
@@ -11,6 +15,7 @@ except ImportError:
import inspect
import numpy as np
from .common import _FLOATX, _EPSILON, _IMAGE_DIM_ORDERING
py_all = all
# INTERNAL UTILS
@@ -30,17 +35,38 @@ def set_learning_phase(value):
'0 or 1.')
_LEARNING_PHASE = value
# VARIABLE MANIPULATION
def _assert_sparse_module():
if not th_sparse_module:
raise ImportError("Failed to import theano.sparse\n"
"You probably need to pip install nose-parameterized")
def is_sparse(tensor):
return th_sparse_module and isinstance(tensor.type, th_sparse_module.SparseType)
def to_dense(tensor):
if is_sparse(tensor):
return th_sparse_module.dense_from_sparse(tensor)
else:
return tensor
def variable(value, dtype=_FLOATX, name=None):
'''Instantiate a tensor variable.
'''
value = np.asarray(value, dtype=dtype)
return theano.shared(value=value, name=name, strict=False)
if hasattr(value, 'tocoo'):
_assert_sparse_module()
return th_sparse_module.as_sparse_variable(value)
else:
value = np.asarray(value, dtype=dtype)
return theano.shared(value=value, name=name, strict=False)
def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
def placeholder(shape=None, ndim=None, dtype=_FLOATX, sparse=False, name=None):
'''Instantiate an input data placeholder variable.
'''
if shape is None and ndim is None:
@@ -51,7 +77,11 @@ def placeholder(shape=None, ndim=None, dtype=_FLOATX, name=None):
shape = tuple([None for _ in range(ndim)])
broadcast = (False,) * ndim
x = T.TensorType(dtype, broadcast)(name)
if sparse:
_assert_sparse_module()
x = th_sparse_module.csr_matrix(name=name, dtype=dtype)
else:
x = T.TensorType(dtype, broadcast)(name)
x._keras_shape = shape
x._uses_learning_phase = False
return x
@@ -77,7 +107,7 @@ def dtype(x):
def eval(x):
'''Run a graph.
'''
return x.eval()
return to_dense(x).eval()
def zeros(shape, dtype=_FLOATX, name=None):
@@ -156,7 +186,10 @@ Assumed overridden:
def dot(x, y):
return T.dot(x, y)
if is_sparse(x):
return th_sparse_module.basic.structured_dot(x, y)
else:
return T.dot(x, y)
def batch_dot(x, y, axes=None):
@@ -402,7 +435,16 @@ def batch_normalization(x, mean, var, beta, gamma, epsilon=0.0001):
# SHAPE OPERATIONS
def concatenate(tensors, axis=-1):
return T.concatenate(tensors, axis=axis)
if py_all([is_sparse(x) for x in tensors]):
axis = axis % ndim(tensors[0])
if axis == 0:
return th_sparse_module.basic.vstack(tensors, format='csr')
elif axis == 1:
return th_sparse_module.basic.hstack(tensors, format='csr')
else:
raise Exception('Invalid concat axis for sparse matrix: ' + str(axis))
else:
return T.concatenate([to_dense(x) for x in tensors], axis=axis)
def reshape(x, shape):
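
A corresponding sketch of the Theano path, assuming theano.sparse imports cleanly; this exercises the same structured_dot call the new dot() dispatches to:

import numpy as np
import scipy.sparse as sparse
import theano
import theano.sparse as th_sparse_module

# Symbolic CSR input, as created by the new placeholder(sparse=True).
x = th_sparse_module.csr_matrix(name='x', dtype='float32')
W = theano.shared(np.random.random((3, 2)).astype('float32'))
y = th_sparse_module.basic.structured_dot(x, W)  # sparse x dense dot
f = theano.function([x], y)
out = f(sparse.rand(4, 3, density=0.5, format='csr', dtype='float32'))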

@@ -947,7 +947,7 @@ class InputLayer(Layer):
'''TODO: docstring
'''
def __init__(self, input_shape=None, batch_input_shape=None,
input_dtype=None, input_tensor=None, name=None):
input_dtype=None, input_tensor=None, sparse=False, name=None):
self.input_spec = None
self.supports_masking = False
self.uses_learning_phase = False
@@ -964,6 +964,8 @@
self.regularizers = []
self.constraints = {}
self.sparse = sparse
if not name:
prefix = 'input'
name = prefix + '_' + str(K.get_uid(prefix))
@@ -1004,6 +1006,7 @@
if input_tensor is None:
input_tensor = K.placeholder(shape=batch_input_shape,
dtype=input_dtype,
sparse=self.sparse,
name=self.name)
else:
input_tensor._keras_shape = batch_input_shape
@@ -1025,12 +1028,13 @@
def get_config(self):
config = {'batch_input_shape': self.batch_input_shape,
'input_dtype': self.input_dtype,
'sparse': self.sparse,
'name': self.name}
return config
def Input(shape=None, batch_shape=None,
name=None, dtype=K.floatx(),
name=None, dtype=K.floatx(), sparse=False,
tensor=None):
'''`Input()` is used to instantiate a Keras tensor.
A Keras tensor is a tensor object from the underlying backend
@@ -1063,6 +1067,7 @@
It will be autogenerated if it isn't provided.
dtype: The data type expected by the input, as a string
(`float32`, `float64`, `int32`...)
sparse: a boolean specifying whether this will be a sparse tensor
# Example usage
@@ -1082,6 +1087,7 @@
batch_shape = (None,) + tuple(shape)
input_layer = InputLayer(batch_input_shape=batch_shape,
name=name, input_dtype=dtype,
sparse=sparse,
input_tensor=tensor)
# return tensor including _keras_shape and _keras_history
# note that in this case train_output and test_output are the same pointer.
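
A quick check of what the new sparse flag produces at the backend level, assuming the TensorFlow backend and that the is_sparse helper added above is re-exported by keras.backend:

from keras import backend as K
from keras.layers import Input

x = Input(batch_shape=(None, 16), sparse=True)
print(K.is_sparse(x))   # True: x is a tf.SparseTensor under TensorFlow
print(x._keras_shape)   # (None, 16), per the _keras_shape fix in this commit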

@@ -763,9 +763,9 @@ class Model(Container):
do_validation = True
if verbose:
print('Train on %d samples, validate on %d samples' %
(len(ins[0]), len(val_ins[0])))
(ins[0].shape[0], val_ins[0].shape[0]))
nb_train_sample = len(ins[0])
nb_train_sample = ins[0].shape[0]
index_array = np.arange(nb_train_sample)
self.history = cbks.History()
@@ -859,7 +859,7 @@
or list of arrays of predictions
(if the model has multiple outputs).
'''
nb_sample = len(ins[0])
nb_sample = ins[0].shape[0]
outs = []
if verbose == 1:
progbar = Progbar(target=nb_sample)
@@ -904,7 +904,7 @@
and/or metrics). The attribute `model.metrics_names` will give you
the display labels for the scalar outputs.
'''
nb_sample = len(ins[0])
nb_sample = ins[0].shape[0]
outs = []
if verbose == 1:
progbar = Progbar(target=nb_sample)
@@ -1426,11 +1426,11 @@
# build batch logs
batch_logs = {}
if type(x) is list:
batch_size = len(x[0])
batch_size = x[0].shape[0]
elif type(x) is dict:
batch_size = len(list(x.values())[0])
batch_size = list(x.values())[0].shape[0]
else:
batch_size = len(x)
batch_size = x.shape[0]
batch_logs['batch'] = batch_index
batch_logs['size'] = batch_size
callbacks.on_batch_begin(batch_index, batch_logs)
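
The len() to .shape[0] replacements above are needed because scipy sparse matrices deliberately do not implement len(); a quick illustration:

import scipy.sparse as sparse

x = sparse.rand(32, 16, density=0.1, format='csr')
print(x.shape[0])  # 32: works for dense arrays and sparse matrices alike
try:
    len(x)         # scipy raises TypeError: sparse matrix length is ambiguous
except TypeError as e:
    print(e)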

@@ -2,6 +2,7 @@ import sys
import pytest
from numpy.testing import assert_allclose
import numpy as np
import scipy.sparse as sparse
from keras.backend import theano_backend as KTH
from keras.backend import tensorflow_backend as KTF
@@ -780,6 +781,61 @@ class TestBackend(object):
koh = K.eval(K.one_hot(K.variable(indices, dtype='int32'), nb_classes))
assert np.all(koh == oh)
def test_sparse_dot(self):
x_d = np.array([0, 7, 2, 3], dtype=np.float32)
x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)
x_sparse = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))
x_dense = x_sparse.toarray()
W = np.random.random((5, 4))
backends = [KTF]
if KTH.th_sparse_module:
# Theano has some dependency issues for sparse
backends.append(KTH)
for K in backends:
t_W = K.variable(W)
k_s = K.eval(K.dot(K.variable(x_sparse), t_W))
k_d = K.eval(K.dot(K.variable(x_dense), t_W))
assert k_s.shape == k_d.shape
assert_allclose(k_s, k_d, atol=1e-05)
def test_sparse_concat(self):
x_d = np.array([0, 7, 2, 3], dtype=np.float32)
x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)
x_sparse_1 = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))
x_d = np.array([0, 7, 2, 3], dtype=np.float32)
x_r = np.array([0, 2, 2, 3], dtype=np.int64)
x_c = np.array([4, 3, 2, 3], dtype=np.int64)
x_sparse_2 = sparse.csr_matrix((x_d, (x_r, x_c)), shape=(4, 5))
x_dense_1 = x_sparse_1.toarray()
x_dense_2 = x_sparse_2.toarray()
backends = [KTF]
if KTH.th_sparse_module:
# Theano has some dependency issues for sparse
backends.append(KTH)
for K in backends:
k_s = K.concatenate([K.variable(x_sparse_1), K.variable(x_sparse_2)])
assert K.is_sparse(k_s)
k_s_d = K.eval(k_s)
k_d = K.eval(K.concatenate([K.variable(x_dense_1), K.variable(x_dense_2)]))
assert k_s_d.shape == k_d.shape
assert_allclose(k_s_d, k_d, atol=1e-05)
if __name__ == '__main__':
pytest.main([__file__])

@@ -0,0 +1,41 @@
from __future__ import absolute_import
from __future__ import print_function
import pytest
from keras.models import Model
from keras.layers import Dense, Input
from keras.utils.test_utils import keras_test
from keras import backend as K
from keras.backend import theano_backend as KTH
from keras.backend import tensorflow_backend as KTF
import scipy.sparse as sparse
import numpy as np
np.random.seed(1337)
input_dim = 16
nb_hidden = 8
nb_class = 4
batch_size = 32
nb_epoch = 1
def do_sparse():
return K == KTF or KTH.th_sparse_module
@keras_test
def test_sparse_mlp():
if not do_sparse():
return
input = Input(batch_shape=(None, input_dim), sparse=True)
hidden = Dense(nb_hidden, activation='relu')(input)
hidden = Dense(nb_hidden, activation='relu')(hidden)
predictions = Dense(nb_class, activation='sigmoid')(hidden)
model = Model(input=[input], output=predictions)
model.compile(loss='mse', optimizer='sgd')
x = sparse.rand(batch_size, input_dim, density=0.1, format='csr')
y = np.random.random((batch_size, nb_class))
model.fit(x, y, nb_epoch=1)