Improve error messages in data validation checks.

This commit is contained in:
Francois Chollet 2017-04-11 13:42:18 -07:00
parent 17ef113ed7
commit b2f0dd4cb2

@ -50,6 +50,8 @@ def _standardize_input_data(data, names, shapes=None,
# Raises
ValueError: in case of improperly formatted user-provided data.
"""
if not names:
return []
if data is None:
return [None for _ in range(len(names))]
if isinstance(data, dict):
@ -63,7 +65,8 @@ def _standardize_input_data(data, names, shapes=None,
elif isinstance(data, list):
if len(data) != len(names):
if data and hasattr(data[0], 'shape'):
raise ValueError('Error when checking ' + exception_prefix +
raise ValueError('Error when checking model ' +
exception_prefix +
': the list of Numpy arrays '
'that you are passing to your model '
'is not the size the model expected. '
@ -77,7 +80,8 @@ def _standardize_input_data(data, names, shapes=None,
data = [np.asarray(data)]
else:
raise ValueError(
'Error when checking ' + exception_prefix +
'Error when checking model ' +
exception_prefix +
': you are passing a list as '
'input to your model, '
'but the model expects '
@ -88,15 +92,17 @@ def _standardize_input_data(data, names, shapes=None,
arrays = data
else:
if not hasattr(data, 'shape'):
raise TypeError('Error when checking ' + exception_prefix +
raise TypeError('Error when checking model ' +
exception_prefix +
': data should be a Numpy array, '
'or list/dict of Numpy arrays. '
'Found: ' + str(data)[:200] + '...')
if len(names) != 1:
if len(names) > 1:
# Case: model expects multiple inputs but only received
# a single Numpy array.
raise ValueError('The model expects ' + str(len(names)) +
' input arrays, but only received one array. '
exception_prefix +
' arrays, but only received one array. '
'Found: array with shape ' + str(data.shape))
arrays = [data]
@ -1291,11 +1297,11 @@ class Model(Container):
x = _standardize_input_data(x, self._feed_input_names,
self._feed_input_shapes,
check_batch_axis=False,
exception_prefix='model input')
exception_prefix='input')
y = _standardize_input_data(y, self._feed_output_names,
output_shapes,
check_batch_axis=False,
exception_prefix='model target')
exception_prefix='target')
sample_weights = _standardize_sample_weights(sample_weight,
self._feed_output_names)
class_weights = _standardize_class_weights(class_weight,