diff --git a/.travis.yml b/.travis.yml index c68c4baa8..1651b6b91 100644 --- a/.travis.yml +++ b/.travis.yml @@ -7,6 +7,8 @@ matrix: env: KERAS_BACKEND=tensorflow TEST_MODE=PEP8 - python: 2.7 env: KERAS_BACKEND=tensorflow TEST_MODE=INTEGRATION_TESTS + - python: 3.5 + env: KERAS_BACKEND=tensorflow TEST_MODE=DOC - python: 2.7 env: KERAS_BACKEND=tensorflow - python: 3.5 @@ -61,6 +63,8 @@ script: PYTHONPATH=$PWD:$PYTHONPATH py.test tests/integration_tests; elif [[ "$TEST_MODE" == "PEP8" ]]; then PYTHONPATH=$PWD:$PYTHONPATH py.test --pep8 -m pep8 -n0; + elif [[ "$TEST_MODE" == "DOC" ]]; then + PYTHONPATH=$PWD:$PYTHONPATH py.test tests/test_documentation.py; else - PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --ignore=tests/integration_tests --cov=keras tests/ --cov-fail-under 78 --cov-report term-missing; + PYTHONPATH=$PWD:$PYTHONPATH py.test tests/ --ignore=tests/integration_tests --ignore=tests/test_documentation.py --cov=keras tests/ --cov-fail-under 78 --cov-report term-missing; fi diff --git a/keras/backend/common.py b/keras/backend/common.py index 4cbcd96d1..7fc4bce23 100644 --- a/keras/backend/common.py +++ b/keras/backend/common.py @@ -44,7 +44,7 @@ def set_epsilon(e): def floatx(): - """Returns the default float type, as a string + """Returns the default float type, as a string. (e.g. 'float16', 'float32', 'float64'). # Returns @@ -109,8 +109,7 @@ def cast_to_floatx(x): def image_data_format(): - """Returns the default image data format - convention ('channels_first' or 'channels_last'). + """Returns the default image data format convention ('channels_first' or 'channels_last'). # Returns A string, either `'channels_first'` or `'channels_last'` @@ -181,7 +180,7 @@ def set_image_dim_ordering(dim_ordering): """Legacy setter for `image_data_format`. # Arguments - dim_ordering: string. `'tf'` or `'th'`. + dim_ordering: string. `tf` or `th`. # Example ```python @@ -192,6 +191,9 @@ def set_image_dim_ordering(dim_ordering): >>> K.image_data_format() 'channels_last' ``` + + # Raises + ValueError if invalid `dim_ordering` """ global _IMAGE_DATA_FORMAT if dim_ordering not in {'tf', 'th'}: @@ -205,6 +207,9 @@ def set_image_dim_ordering(dim_ordering): def image_dim_ordering(): """Legacy getter for `image_data_format`. + + # Returns + string, one of `'th'`, `'tf'` """ if _IMAGE_DATA_FORMAT == 'channels_first': return 'th' diff --git a/keras/backend/tensorflow_backend.py b/keras/backend/tensorflow_backend.py index d1f187d81..8e45b2245 100644 --- a/keras/backend/tensorflow_backend.py +++ b/keras/backend/tensorflow_backend.py @@ -43,6 +43,14 @@ _MANUAL_VAR_INIT = False def get_uid(prefix=''): + """Get the uid for the default graph. + + # Arguments + prefix: An optional prefix of the graph. + + # Returns + A unique identifier for the graph. + """ global _GRAPH_UID_DICTS graph = tf.get_default_graph() if graph not in _GRAPH_UID_DICTS: @@ -52,6 +60,7 @@ def get_uid(prefix=''): def reset_uids(): + """Reset graph identifiers.""" global _GRAPH_UID_DICTS _GRAPH_UID_DICTS = {} @@ -169,6 +178,17 @@ def set_session(session): # VARIABLE MANIPULATION def _convert_string_dtype(dtype): + """Get the type from a string. + + # Arguments + dtype: A string representation of a type. + + # Returns: + The type requested. + + # Raises + ValueError if `dtype` is not supported + """ if dtype == 'float16': return tf.float16 if dtype == 'float32': @@ -190,6 +210,15 @@ def _convert_string_dtype(dtype): def _to_tensor(x, dtype): + """Convert the input `x` to a tensor of type `dtype`. + + # Arguments + x: An object to be converted (numpy array, list, tensors). + dtype: The destination type. + + # Returns + A tensor. + """ x = tf.convert_to_tensor(x) if x.dtype != dtype: x = tf.cast(x, dtype) @@ -309,6 +338,17 @@ def _initialize_variables(): def constant(value, dtype=None, shape=None, name=None): + """Creates a constant tensor. + + # Arguments + value: A constant value (or list) + dtype: The type of the elements of the resulting tensor. + shape: Optional dimensions of resulting tensor. + name: Optional name for the tensor. + + # Returns + A Constant Tensor. + """ if dtype is None: dtype = floatx() return tf.constant(value, dtype=dtype, shape=shape, name=name) @@ -773,18 +813,54 @@ def cast(x, dtype): def update(x, new_x): + """Update the value of `x` to `new_x`. + + # Arguments + x: A Variable. + new_x: A tensor of same shape as `x`. + + # Returns + The variable `x` updated. + """ return tf.assign(x, new_x) def update_add(x, increment): + """Update the value of `x` by adding `increment`. + + # Arguments + x: A Variable. + increment: A tensor of same shape as `x`. + + # Returns + The variable `x` updated. + """ return tf.assign_add(x, increment) def update_sub(x, decrement): + """Update the value of `x` by subtracting `decrement`. + + # Arguments + x: A Variable. + decrement: A tensor of same shape as `x`. + + # Returns + The variable `x` updated. + """ return tf.assign_sub(x, decrement) def moving_average_update(x, value, momentum): + """Compute the moving average of a variable. + + # Arguments + x: A Variable. + value: A tensor with the same shape as `variable`. + momentum: The moving average momentum. + + # Returns + An Operation to update the variable.""" return moving_averages.assign_moving_average( x, value, momentum, zero_debias=False) @@ -2795,6 +2871,16 @@ def in_top_k(predictions, targets, k): # CONVOLUTIONS def _preprocess_deconv_output_shape(x, shape, data_format): + """Get the output_shape for the deconvolution. + + # Arguments + x: input tensor. + shape: output shape. + data_format: string, one of 'channels_last', 'channels_first'. + + # Returns + The output shape. + """ if data_format == 'channels_first': shape = (shape[0], shape[2], shape[3], shape[1]) @@ -2805,6 +2891,15 @@ def _preprocess_deconv_output_shape(x, shape, data_format): def _preprocess_conv2d_input(x, data_format): + """Transpose and cast the input before the conv2d. + + # Arguments + x: input tensor. + data_format: string, one of 'channels_last', 'channels_first'. + + # Returns + A tensor. + """ if dtype(x) == 'float64': x = tf.cast(x, 'float32') if data_format == 'channels_first': @@ -2817,6 +2912,15 @@ def _preprocess_conv2d_input(x, data_format): def _preprocess_conv3d_input(x, data_format): + """Transpose and cast the input before the conv3d. + + # Arguments + x: input tensor. + data_format: string, one of 'channels_last', 'channels_first'. + + # Returns + A tensor. + """ if dtype(x) == 'float64': x = tf.cast(x, 'float32') if data_format == 'channels_first': @@ -2825,6 +2929,15 @@ def _preprocess_conv3d_input(x, data_format): def _preprocess_conv2d_kernel(kernel, data_format): + """Transpose and cast the kernel before the conv2d. + + # Arguments + kernel: kernel tensor. + data_format: string, one of 'channels_last', 'channels_first'. + + # Returns + A tensor. + """ if dtype(kernel) == 'float64': kernel = tf.cast(kernel, 'float32') if data_format == 'channels_first': @@ -2833,6 +2946,15 @@ def _preprocess_conv2d_kernel(kernel, data_format): def _preprocess_conv3d_kernel(kernel, data_format): + """Transpose and cast the kernel before the conv3d. + + # Arguments + kernel: kernel tensor. + data_format: string, one of 'channels_last', 'channels_first'. + + # Returns + A tensor. + """ if dtype(kernel) == 'float64': kernel = tf.cast(kernel, 'float32') if data_format == 'channels_first': @@ -2841,16 +2963,37 @@ def _preprocess_conv3d_kernel(kernel, data_format): def _preprocess_padding(padding): + """Convert keras' padding to tensorflow's padding. + + # Arguments + padding: string, one of 'same' , 'valid' + + # Returns + a string, one of 'SAME', 'VALID'. + + # Raises + ValueError if invalid `padding'` + """ if padding == 'same': padding = 'SAME' elif padding == 'valid': padding = 'VALID' else: - raise ValueError('Invalid border mode:', padding) + raise ValueError('Invalid padding:', padding) return padding def _postprocess_conv2d_output(x, data_format): + """Transpose and cast the output from conv2d if needed. + + # Arguments + x: A tensor. + data_format: string, one of "channels_last", "channels_first". + + # Returns + A tensor. + """ + if data_format == 'channels_first': x = tf.transpose(x, (0, 3, 1, 2)) @@ -2860,6 +3003,15 @@ def _postprocess_conv2d_output(x, data_format): def _postprocess_conv3d_output(x, data_format): + """Transpose and cast the output from conv3d if needed. + + # Arguments + x: A tensor. + data_format: string, one of "channels_last", "channels_first". + + # Returns + A tensor. + """ if data_format == 'channels_first': x = tf.transpose(x, (0, 4, 1, 2, 3)) diff --git a/keras/layers/__init__.py b/keras/layers/__init__.py index 010a95341..0ce64c1c7 100644 --- a/keras/layers/__init__.py +++ b/keras/layers/__init__.py @@ -21,6 +21,14 @@ from ..legacy.layers import * def serialize(layer): + """Serialize a layer. + + # Arguments + layer: a Layer object. + + # Returns + dictionary with config. + """ return {'class_name': layer.__class__.__name__, 'config': layer.get_config()} diff --git a/keras/models.py b/keras/models.py index 0fcf593d0..ea723262b 100644 --- a/keras/models.py +++ b/keras/models.py @@ -293,6 +293,9 @@ def model_from_config(config, custom_objects=None): # Returns A Keras model instance (uncompiled). + + # Raises + TypeError if `config` is not a dictionary """ if isinstance(config, list): raise TypeError('`model_from_config` expects a dictionary, not a list. ' @@ -1227,6 +1230,15 @@ class Sequential(Model): @classmethod def legacy_from_config(cls, config, layer_cache=None): + """Load a model from a legacy configuration. + + # Arguments + config: dictionary with configuration. + layer_cache: cache to draw pre-existing layer. + + # Returns + The loaded Model. + """ if not layer_cache: layer_cache = {} diff --git a/tests/test_documentation.py b/tests/test_documentation.py new file mode 100644 index 000000000..d5894a568 --- /dev/null +++ b/tests/test_documentation.py @@ -0,0 +1,151 @@ +import importlib +import inspect +import re +import sys +from itertools import compress + +import pytest + +modules = ['keras.layers', 'keras.models', 'keras', 'keras.backend.tensorflow_backend'] +accepted_name = ['from_config'] +accepted_module = ['keras.legacy.layers', 'keras.utils.generic_utils'] + +# Functions or classes with less than 'MIN_CODE_SIZE' lines can be ignored +MIN_CODE_SIZE = 10 + + +def handle_class(name, member): + if is_accepted(name, member): + return + + if member.__doc__ is None and not member_too_small(member): + raise ValueError("{} class doesn't have any documentation".format(name), + member.__module__, inspect.getmodule(member).__file__) + for n, met in inspect.getmembers(member): + if inspect.ismethod(met): + handle_method(n, met) + + +def handle_function(name, member): + if is_accepted(name, member): + return + doc = member.__doc__ + if doc is None and not member_too_small(member): + raise ValueError("{} function doesn't have any documentation".format(name), + member.__module__, inspect.getmodule(member).__file__) + args = list(inspect.signature(member).parameters.keys()) + assert_args_presence(args, doc, member, name) + assert_function_style(name, member, doc, args) + assert_doc_style(name, member, doc) + + +def assert_doc_style(name, member, doc): + lines = doc.split("\n") + first_line = lines[0] + if len(first_line.strip()) == 0: + raise ValueError("{} the documentation should be on the first line.".format(name), + member.__module__) + if first_line.strip()[-1] != '.': + raise ValueError("{} first line should end with a '.'".format(name), + member.__module__) + + +def assert_function_style(name, member, doc, args): + code = inspect.getsource(member) + has_return = re.findall(r"\s*return \S+", code, re.MULTILINE) + if has_return and "# Returns" not in doc: + innerfunction = [inspect.getsource(x) for x in member.__code__.co_consts if + inspect.iscode(x)] + return_in_sub = [ret for code_inner in innerfunction for ret in + re.findall(r"\s*return \S+", code_inner, re.MULTILINE)] + if len(return_in_sub) < len(has_return): + raise ValueError("{} needs a '# Returns' section".format(name), + member.__module__) + + has_raise = re.findall(r"^\s*raise \S+", code, re.MULTILINE) + if has_raise and "# Raises" not in doc: + innerfunction = [inspect.getsource(x) for x in member.__code__.co_consts if + inspect.iscode(x)] + raise_in_sub = [ret for code_inner in innerfunction for ret in + re.findall(r"\s*raise \S+", code_inner, re.MULTILINE)] + if len(raise_in_sub) < len(has_raise): + raise ValueError("{} needs a '# Raises' section".format(name), + member.__module__) + + if len(args) > 0 and "# Arguments" not in doc: + raise ValueError("{} needs a '# Arguments' section".format(name), + member.__module__) + + assert_blank_before(name, member, doc, ['# Arguments', '# Raises', '# Returns']) + + +def assert_blank_before(name, member, doc, keywords): + doc_lines = [x.strip() for x in doc.split('\n')] + for keyword in keywords: + if keyword in doc_lines: + index = doc_lines.index(keyword) + if doc_lines[index - 1] != '': + raise ValueError( + "{} '{}' should have a blank line above.".format(name, keyword), + member.__module__) + + +def is_accepted(name, member): + if 'keras' not in str(member.__module__): + return True + return name in accepted_name or member.__module__ in accepted_module + + +def member_too_small(member): + code = inspect.getsource(member).split('\n') + return len(code) < MIN_CODE_SIZE + + +def assert_args_presence(args, doc, member, name): + args_not_in_doc = [arg not in doc for arg in args] + if any(args_not_in_doc): + raise ValueError( + "{} {} arguments are not present in documentation ".format(name, list( + compress(args, args_not_in_doc))), member.__module__) + words = doc.replace('*', '').split() + # Check arguments styling + styles = [arg + ":" not in words for arg in args] + if any(styles): + raise ValueError( + "{} {} are not style properly 'argument': documentation".format(name, list( + compress(args, styles))), member.__module__) + + # Check arguments order + indexes = [words.index(arg + ":") for arg in args] + if indexes != sorted(indexes): + raise ValueError( + "{} arguments order is different from the documentation".format(name), + member.__module__) + + +def handle_method(name, member): + if name in accepted_name or member.__module__ in accepted_module: + return + handle_function(name, member) + + +def handle_module(mod): + for name, mem in inspect.getmembers(mod): + if inspect.isclass(mem): + handle_class(name, mem) + elif inspect.isfunction(mem): + handle_function(name, mem) + elif 'keras' in name and inspect.ismodule(mem): + # Only test keras' modules + handle_module(mem) + + +@pytest.mark.skipif(sys.version_info < (3, 3), reason="requires python3.3") +def test_doc(): + for module in modules: + mod = importlib.import_module(module) + handle_module(mod) + + +if __name__ == '__main__': + pytest.main([__file__])