diff --git a/keras_core/backend/config.py b/keras_core/backend/config.py index 232c01cba..17479129d 100644 --- a/keras_core/backend/config.py +++ b/keras_core/backend/config.py @@ -1,6 +1,8 @@ import json import os +from keras_core.api_export import keras_core_export + # The type of float to use throughout a session. _FLOATX = "float32" @@ -14,6 +16,7 @@ _IMAGE_DATA_FORMAT = "channels_last" _BACKEND = "tensorflow" +@keras_core_export(["keras_core.config.floatx", "keras_core.backend.floatx"]) def floatx(): """Return the default float type, as a string. @@ -23,32 +26,34 @@ def floatx(): String, the current default float type. Example: - >>> keras_core.backend.floatx() + >>> keras_core.config.floatx() 'float32' """ return _FLOATX +@keras_core_export( + ["keras_core.config.set_floatx", "keras_core.backend.set_floatx"] +) def set_floatx(value): - """Set the default float type. + """Set the default float dtype. - Note: It is not recommended to set this to float16 for training, as this - will likely cause numeric stability issues. Instead, mixed precision, which - is using a mix of float16 and float32, can be used by calling - `keras_core.mixed_precision.set_global_policy('mixed_float16')`. See the - [mixed precision guide]( - https://www.tensorflow.org/guide/keras/mixed_precision) for details. + Note: It is not recommended to set this to `"float16"` for training, + as this will likely cause numeric stability issues. + Instead, mixed precision, which leverages + a mix of `float16` and `float32`. It can be configured by calling + `keras_core.mixed_precision.set_global_policy('mixed_float16')`. Args: value: String; `'float16'`, `'float32'`, or `'float64'`. Example: - >>> keras_core.backend.floatx() + >>> keras_core.config.floatx() 'float32' - >>> keras_core.backend.set_floatx('float64') - >>> keras_core.backend.floatx() + >>> keras_core.config.set_floatx('float64') + >>> keras_core.config.floatx() 'float64' - >>> keras_core.backend.set_floatx('float32') + >>> keras_core.config.set_floatx('float32') Raises: ValueError: In case of invalid value. @@ -63,6 +68,7 @@ def set_floatx(value): _FLOATX = str(value) +@keras_core_export(["keras_core.config.epsilon", "keras_core.backend.epsilon"]) def epsilon(): """Return the value of the fuzz factor used in numeric expressions. @@ -70,12 +76,15 @@ def epsilon(): A float. Example: - >>> keras_core.backend.epsilon() + >>> keras_core.config.epsilon() 1e-07 """ return _EPSILON +@keras_core_export( + ["keras_core.config.set_epsilon", "keras_core.backend.set_epsilon"] +) def set_epsilon(value): """Set the value of the fuzz factor used in numeric expressions. @@ -83,30 +92,42 @@ def set_epsilon(value): value: float. New value of epsilon. Example: - >>> keras_core.backend.epsilon() + >>> keras_core.config.epsilon() 1e-07 - >>> keras_core.backend.set_epsilon(1e-5) - >>> keras_core.backend.epsilon() + >>> keras_core.config.set_epsilon(1e-5) + >>> keras_core.config.epsilon() 1e-05 - >>> keras_core.backend.set_epsilon(1e-7) + >>> keras_core.config.set_epsilon(1e-7) """ global _EPSILON _EPSILON = value +@keras_core_export( + [ + "keras_core.config.image_data_format", + "keras_core.backend.image_data_format", + ] +) def image_data_format(): """Return the default image data format convention. Returns: - A string, either `'channels_first'` or `'channels_last'` + A string, either `'channels_first'` or `'channels_last'`. Example: - >>> keras_core.backend.image_data_format() + >>> keras_core.config.image_data_format() 'channels_last' """ return _IMAGE_DATA_FORMAT +@keras_core_export( + [ + "keras_core.config.set_image_data_format", + "keras_core.backend.set_image_data_format", + ] +) def set_image_data_format(data_format): """Set the value of the image data format convention. @@ -114,15 +135,12 @@ def set_image_data_format(data_format): data_format: string. `'channels_first'` or `'channels_last'`. Example: - >>> keras_core.backend.image_data_format() + >>> keras_core.config.image_data_format() 'channels_last' - >>> keras_core.backend.set_image_data_format('channels_first') - >>> keras_core.backend.image_data_format() + >>> keras_core.config.set_image_data_format('channels_first') + >>> keras_core.config.image_data_format() 'channels_first' - >>> keras_core.backend.set_image_data_format('channels_last') - - Raises: - ValueError: In case of invalid `data_format` value. + >>> keras_core.config.set_image_data_format('channels_last') """ global _IMAGE_DATA_FORMAT accepted_formats = {"channels_last", "channels_first"} @@ -196,6 +214,7 @@ if "KERAS_BACKEND" in os.environ: _BACKEND = _backend +@keras_core_export("keras_core.backend.backend") def backend(): """Publicly accessible method for determining the current backend. diff --git a/keras_core/layers/layer.py b/keras_core/layers/layer.py index 79db92050..5b8cf4189 100644 --- a/keras_core/layers/layer.py +++ b/keras_core/layers/layer.py @@ -34,6 +34,7 @@ from keras_core.layers import input_spec from keras_core.metrics.metric import Metric from keras_core.operations.operation import Operation from keras_core.utils import summary_utils +from keras_core.utils import traceback_utils from keras_core.utils.tracking import Tracker @@ -277,6 +278,7 @@ class Layer(Operation): def compute_mask(self, inputs, previous_mask): return previous_mask + @traceback_utils.filter_traceback def __call__(self, *args, **kwargs): self._check_super_called() @@ -414,6 +416,7 @@ class Layer(Operation): def call(self, *args, **kwargs): raise NotImplementedError + @traceback_utils.filter_traceback def stateless_call( self, trainable_variables, @@ -659,8 +662,8 @@ class Layer(Operation): if not self.built: raise ValueError( "You tried to call `count_params` " - f"on layer '{self.name}'" - ", but the layer isn't built. " + f"on layer '{self.name}', " + "but the layer isn't built. " "You can build it manually via: " f"`layer.build(input_shape)`." ) @@ -714,7 +717,9 @@ class Layer(Operation): "then the build signature should be " "`def build(self, x1_shape, x2_shape)`. " "Keras will not build this layer automatically " - "since it does not conform to this." + "since it does not conform to this. " + "Expected the following build keys: " + f"{list(shapes_dict.keys())}" ) if failure: raise ValueError( diff --git a/keras_core/operations/operation.py b/keras_core/operations/operation.py index 76c5406de..f26a3c285 100644 --- a/keras_core/operations/operation.py +++ b/keras_core/operations/operation.py @@ -8,6 +8,7 @@ from keras_core.backend.common.keras_tensor import any_symbolic_tensors from keras_core.operations.node import Node from keras_core.saving import serialization_lib from keras_core.utils import python_utils +from keras_core.utils import traceback_utils from keras_core.utils.naming import auto_name @@ -24,7 +25,21 @@ class Operation: self._inbound_nodes = [] self._outbound_nodes = [] + @traceback_utils.filter_traceback def __call__(self, *args, **kwargs): + if traceback_utils.is_traceback_filtering_enabled(): + # Wrap self.call to provide helpful info in case of exception + if any_symbolic_tensors(args, kwargs): + call_fn = self.symbolic_call + else: + call_fn = self.call + call_fn = traceback_utils.inject_argument_info_in_traceback( + call_fn, + object_name=(f"{self.__class__.__name__}.call()"), + ) + return call_fn(*args, **kwargs) + + # Plain flow. if any_symbolic_tensors(args, kwargs): return self.symbolic_call(*args, **kwargs) return self.call(*args, **kwargs) @@ -50,13 +65,15 @@ class Operation: try: return backend.compute_output_spec(self.call, *args, **kwargs) except Exception as e: - raise RuntimeError( + new_e = RuntimeError( "Could not automatically infer the output shape / dtype of " - f"operation '{self.name}'. " - "Please implement the `compute_output_spec()` method " - f"on your object ({self.__class__.__name__}). " - f"Error encountered: {e}" + f"'{self.name}' (of type {self.__class__.__name__}). " + f"Either the `{self.__class__.__name__}.call()` method " + f"is incorrect, or you need to implement the " + f"`{self.__class__.__name__}.compute_output_spec()` method. " + f"Error encountered:\n\n{e}" ) + raise new_e.with_traceback(e.__traceback__) from None def __new__(cls, *args, **kwargs): """We override __new__ to saving serializable constructor arguments. diff --git a/keras_core/utils/io_utils.py b/keras_core/utils/io_utils.py index 0941d3c2a..f47b8c29f 100644 --- a/keras_core/utils/io_utils.py +++ b/keras_core/utils/io_utils.py @@ -6,7 +6,7 @@ from keras_core.api_export import keras_core_export from keras_core.backend.common import global_state -@keras_core_export("keras_core.utils.enable_interactive_logging") +@keras_core_export("keras_core.config.enable_interactive_logging") def enable_interactive_logging(): """Turn on interactive logging. @@ -17,7 +17,7 @@ def enable_interactive_logging(): global_state.set_global_setting("interactive_logging", True) -@keras_core_export("keras_core.utils.disable_interactive_logging") +@keras_core_export("keras_core.config.disable_interactive_logging") def disable_interactive_logging(): """Turn off interactive logging. @@ -28,16 +28,17 @@ def disable_interactive_logging(): global_state.set_global_setting("interactive_logging", False) -@keras_core_export("keras_core.utils.is_interactive_logging_enabled") +@keras_core_export("keras_core.config.is_interactive_logging_enabled") def is_interactive_logging_enabled(): """Check if interactive logging is enabled. To switch between writing logs to stdout and `absl.logging`, you may use - `keras.utils.enable_interactive_logging()` and - `keras.utils.disable_interactie_logging()`. + `keras.config.enable_interactive_logging()` and + `keras.config.disable_interactie_logging()`. Returns: - Boolean (True if interactive logging is enabled and False otherwise). + Boolean, `True` if interactive logging is enabled, + and `False` otherwise. """ return global_state.get_global_setting("interactive_logging", True) diff --git a/keras_core/utils/traceback_utils.py b/keras_core/utils/traceback_utils.py new file mode 100644 index 000000000..7b05dd240 --- /dev/null +++ b/keras_core/utils/traceback_utils.py @@ -0,0 +1,228 @@ +import inspect +import os +import traceback +import types + +from tensorflow import errors as tf_errors +from tensorflow import nest + +from keras_core import backend +from keras_core.api_export import keras_core_export +from keras_core.backend.common import global_state + +_EXCLUDED_PATHS = ( + os.path.abspath(os.path.join(__file__, "..", "..")), + os.path.join("tensorflow", "python"), +) + + +@keras_core_export("keras_core.config.enable_traceback_filtering") +def enable_traceback_filtering(): + """Turn on traceback filtering. + + Raw Keras tracebacks (also known as stack traces) + involve many internal frames, which can be + challenging to read through, while not being actionable for end users. + By default, Keras filters internal frames in most exceptions that it + raises, to keep traceback short, readable, and focused on what's + actionable for you (your own code). + + See also `keras_core.config.disable_traceback_filtering()` and + `keras_core.config.is_traceback_filtering_enabled()`. + + If you have previously disabled traceback filtering via + `keras_core.config.disable_traceback_filtering()`, you can re-enable it via + `keras_core.config.enable_traceback_filtering()`. + """ + global_state.set_global_setting("traceback_filtering", True) + + +@keras_core_export("keras_core.config.disable_traceback_filtering") +def disable_traceback_filtering(): + """Turn off traceback filtering. + + Raw Keras tracebacks (also known as stack traces) + involve many internal frames, which can be + challenging to read through, while not being actionable for end users. + By default, Keras filters internal frames in most exceptions that it + raises, to keep traceback short, readable, and focused on what's + actionable for you (your own code). + + See also `keras_core.config.enable_traceback_filtering()` and + `keras_core.config.is_traceback_filtering_enabled()`. + + If you have previously disabled traceback filtering via + `keras_core.config.disable_traceback_filtering()`, you can re-enable it via + `keras_core.config.enable_traceback_filtering()`. + """ + global_state.set_global_setting("traceback_filtering", False) + + +@keras_core_export("keras_core.config.is_traceback_filtering_enabled") +def is_traceback_filtering_enabled(): + """Check if traceback filtering is enabled. + + Raw Keras tracebacks (also known as stack traces) + involve many internal frames, which can be + challenging to read through, while not being actionable for end users. + By default, Keras filters internal frames in most exceptions that it + raises, to keep traceback short, readable, and focused on what's + actionable for you (your own code). + + See also `keras_core.config.enable_traceback_filtering()` and + `keras_core.config.disable_traceback_filtering()`. + + If you have previously disabled traceback filtering via + `keras_core.config.disable_traceback_filtering()`, you can re-enable it via + `keras_core.config.enable_traceback_filtering()`. + + Returns: + Boolean, `True` if traceback filtering is enabled, + and `False` otherwise. + """ + return global_state.get_global_setting("traceback_filtering", True) + + +def include_frame(fname): + for exclusion in _EXCLUDED_PATHS: + if exclusion in fname: + return False + return True + + +def _process_traceback_frames(tb): + """Iterate through traceback frames and return a new, filtered traceback.""" + last_tb = None + tb_list = list(traceback.walk_tb(tb)) + for f, line_no in reversed(tb_list): + if include_frame(f.f_code.co_filename): + last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no) + if last_tb is None and tb_list: + # If no frames were kept during filtering, create a new traceback + # from the outermost function. + f, line_no = tb_list[-1] + last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no) + return last_tb + + +def filter_traceback(fn): + """Filter out Keras-internal traceback frames in exceptions raised by fn.""" + + def error_handler(*args, **kwargs): + if not is_traceback_filtering_enabled(): + return fn(*args, **kwargs) + + filtered_tb = None + try: + return fn(*args, **kwargs) + except Exception as e: + filtered_tb = _process_traceback_frames(e.__traceback__) + # To get the full stack trace, call: + # `keras_core.config.disable_traceback_filtering()` + raise e.with_traceback(filtered_tb) from None + finally: + del filtered_tb + + return error_handler + + +def inject_argument_info_in_traceback(fn, object_name=None): + """Add information about call argument values to an error message. + + Arguments: + fn: Function to wrap. Exceptions raised by the this function will be + re-raised with additional information added to the error message, + displaying the values of the different arguments that the function + was called with. + object_name: String, display name of the class/function being called, + e.g. `'layer "layer_name" (LayerClass)'`. + + Returns: + A wrapped version of `fn`. + """ + + def error_handler(*args, **kwargs): + signature = None + bound_signature = None + try: + return fn(*args, **kwargs) + except Exception as e: + if hasattr(e, "_keras_call_info_injected"): + # Only inject info for the innermost failing call + raise e + signature = inspect.signature(fn) + try: + # The first argument is `self`, so filter it out + bound_signature = signature.bind(*args, **kwargs) + except TypeError: + # Likely unbindable arguments + raise e + + # Add argument context + arguments_context = [] + for arg in list(signature.parameters.values()): + if arg.name in bound_signature.arguments: + value = nest.map_structure( + format_argument_value, + bound_signature.arguments[arg.name], + ) + else: + value = arg.default + arguments_context.append(f" • {arg.name}={value}") + + if arguments_context: + arguments_context = "\n".join(arguments_context) + # Get original error message and append information to it. + if isinstance(e, tf_errors.OpError): + message = e.message + elif e.args: + # Canonically, the 1st argument in an exception is the error + # message. This works for all built-in Python exceptions. + message = e.args[0] + else: + message = "" + display_name = f"{object_name if object_name else fn.__name__}" + message = ( + f"Exception encountered when calling {display_name}.\n\n" + f"\x1b[1m{message}\x1b[0m\n\n" + f"Arguments received by {display_name}:\n" + f"{arguments_context}" + ) + + # Reraise exception, with added context + if isinstance(e, tf_errors.OpError): + new_e = e.__class__(e.node_def, e.op, message, e.error_code) + else: + try: + # For standard exceptions such as ValueError, TypeError, + # etc. + new_e = e.__class__(message) + except TypeError: + # For any custom error that doesn't have a standard + # signature. + new_e = RuntimeError(message) + new_e._keras_call_info_injected = True + else: + new_e = e + raise new_e.with_traceback(e.__traceback__) from None + finally: + del signature + del bound_signature + + return error_handler + + +def format_argument_value(value): + if backend.is_tensor(value): + # Simplified representation for eager / graph tensors + # to keep messages readable + if backend.backend() == "tensorflow": + tensor_cls = "tf.Tensor" + elif backend.backend() == "jax": + tensor_cls = "jnp.ndarray" + elif backend.backend() == "pytorch": + tensor_cls = "torch.Tensor" + else: + tensor_cls = "array" + return f"{tensor_cls}(shape={value.shape}, dtype={value.dtype.name})" + return repr(value)