Add traceback filtering.
This commit is contained in:
parent
43e33ab9ab
commit
aaa13fd9c6
@ -1,6 +1,8 @@
|
||||
import json
|
||||
import os
|
||||
|
||||
from keras_core.api_export import keras_core_export
|
||||
|
||||
# The type of float to use throughout a session.
|
||||
_FLOATX = "float32"
|
||||
|
||||
@ -14,6 +16,7 @@ _IMAGE_DATA_FORMAT = "channels_last"
|
||||
_BACKEND = "tensorflow"
|
||||
|
||||
|
||||
@keras_core_export(["keras_core.config.floatx", "keras_core.backend.floatx"])
|
||||
def floatx():
|
||||
"""Return the default float type, as a string.
|
||||
|
||||
@ -23,32 +26,34 @@ def floatx():
|
||||
String, the current default float type.
|
||||
|
||||
Example:
|
||||
>>> keras_core.backend.floatx()
|
||||
>>> keras_core.config.floatx()
|
||||
'float32'
|
||||
"""
|
||||
return _FLOATX
|
||||
|
||||
|
||||
@keras_core_export(
|
||||
["keras_core.config.set_floatx", "keras_core.backend.set_floatx"]
|
||||
)
|
||||
def set_floatx(value):
|
||||
"""Set the default float type.
|
||||
"""Set the default float dtype.
|
||||
|
||||
Note: It is not recommended to set this to float16 for training, as this
|
||||
will likely cause numeric stability issues. Instead, mixed precision, which
|
||||
is using a mix of float16 and float32, can be used by calling
|
||||
`keras_core.mixed_precision.set_global_policy('mixed_float16')`. See the
|
||||
[mixed precision guide](
|
||||
https://www.tensorflow.org/guide/keras/mixed_precision) for details.
|
||||
Note: It is not recommended to set this to `"float16"` for training,
|
||||
as this will likely cause numeric stability issues.
|
||||
Instead, mixed precision, which leverages
|
||||
a mix of `float16` and `float32`. It can be configured by calling
|
||||
`keras_core.mixed_precision.set_global_policy('mixed_float16')`.
|
||||
|
||||
Args:
|
||||
value: String; `'float16'`, `'float32'`, or `'float64'`.
|
||||
|
||||
Example:
|
||||
>>> keras_core.backend.floatx()
|
||||
>>> keras_core.config.floatx()
|
||||
'float32'
|
||||
>>> keras_core.backend.set_floatx('float64')
|
||||
>>> keras_core.backend.floatx()
|
||||
>>> keras_core.config.set_floatx('float64')
|
||||
>>> keras_core.config.floatx()
|
||||
'float64'
|
||||
>>> keras_core.backend.set_floatx('float32')
|
||||
>>> keras_core.config.set_floatx('float32')
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid value.
|
||||
@ -63,6 +68,7 @@ def set_floatx(value):
|
||||
_FLOATX = str(value)
|
||||
|
||||
|
||||
@keras_core_export(["keras_core.config.epsilon", "keras_core.backend.epsilon"])
|
||||
def epsilon():
|
||||
"""Return the value of the fuzz factor used in numeric expressions.
|
||||
|
||||
@ -70,12 +76,15 @@ def epsilon():
|
||||
A float.
|
||||
|
||||
Example:
|
||||
>>> keras_core.backend.epsilon()
|
||||
>>> keras_core.config.epsilon()
|
||||
1e-07
|
||||
"""
|
||||
return _EPSILON
|
||||
|
||||
|
||||
@keras_core_export(
|
||||
["keras_core.config.set_epsilon", "keras_core.backend.set_epsilon"]
|
||||
)
|
||||
def set_epsilon(value):
|
||||
"""Set the value of the fuzz factor used in numeric expressions.
|
||||
|
||||
@ -83,30 +92,42 @@ def set_epsilon(value):
|
||||
value: float. New value of epsilon.
|
||||
|
||||
Example:
|
||||
>>> keras_core.backend.epsilon()
|
||||
>>> keras_core.config.epsilon()
|
||||
1e-07
|
||||
>>> keras_core.backend.set_epsilon(1e-5)
|
||||
>>> keras_core.backend.epsilon()
|
||||
>>> keras_core.config.set_epsilon(1e-5)
|
||||
>>> keras_core.config.epsilon()
|
||||
1e-05
|
||||
>>> keras_core.backend.set_epsilon(1e-7)
|
||||
>>> keras_core.config.set_epsilon(1e-7)
|
||||
"""
|
||||
global _EPSILON
|
||||
_EPSILON = value
|
||||
|
||||
|
||||
@keras_core_export(
|
||||
[
|
||||
"keras_core.config.image_data_format",
|
||||
"keras_core.backend.image_data_format",
|
||||
]
|
||||
)
|
||||
def image_data_format():
|
||||
"""Return the default image data format convention.
|
||||
|
||||
Returns:
|
||||
A string, either `'channels_first'` or `'channels_last'`
|
||||
A string, either `'channels_first'` or `'channels_last'`.
|
||||
|
||||
Example:
|
||||
>>> keras_core.backend.image_data_format()
|
||||
>>> keras_core.config.image_data_format()
|
||||
'channels_last'
|
||||
"""
|
||||
return _IMAGE_DATA_FORMAT
|
||||
|
||||
|
||||
@keras_core_export(
|
||||
[
|
||||
"keras_core.config.set_image_data_format",
|
||||
"keras_core.backend.set_image_data_format",
|
||||
]
|
||||
)
|
||||
def set_image_data_format(data_format):
|
||||
"""Set the value of the image data format convention.
|
||||
|
||||
@ -114,15 +135,12 @@ def set_image_data_format(data_format):
|
||||
data_format: string. `'channels_first'` or `'channels_last'`.
|
||||
|
||||
Example:
|
||||
>>> keras_core.backend.image_data_format()
|
||||
>>> keras_core.config.image_data_format()
|
||||
'channels_last'
|
||||
>>> keras_core.backend.set_image_data_format('channels_first')
|
||||
>>> keras_core.backend.image_data_format()
|
||||
>>> keras_core.config.set_image_data_format('channels_first')
|
||||
>>> keras_core.config.image_data_format()
|
||||
'channels_first'
|
||||
>>> keras_core.backend.set_image_data_format('channels_last')
|
||||
|
||||
Raises:
|
||||
ValueError: In case of invalid `data_format` value.
|
||||
>>> keras_core.config.set_image_data_format('channels_last')
|
||||
"""
|
||||
global _IMAGE_DATA_FORMAT
|
||||
accepted_formats = {"channels_last", "channels_first"}
|
||||
@ -196,6 +214,7 @@ if "KERAS_BACKEND" in os.environ:
|
||||
_BACKEND = _backend
|
||||
|
||||
|
||||
@keras_core_export("keras_core.backend.backend")
|
||||
def backend():
|
||||
"""Publicly accessible method for determining the current backend.
|
||||
|
||||
|
@ -34,6 +34,7 @@ from keras_core.layers import input_spec
|
||||
from keras_core.metrics.metric import Metric
|
||||
from keras_core.operations.operation import Operation
|
||||
from keras_core.utils import summary_utils
|
||||
from keras_core.utils import traceback_utils
|
||||
from keras_core.utils.tracking import Tracker
|
||||
|
||||
|
||||
@ -277,6 +278,7 @@ class Layer(Operation):
|
||||
def compute_mask(self, inputs, previous_mask):
|
||||
return previous_mask
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def __call__(self, *args, **kwargs):
|
||||
self._check_super_called()
|
||||
|
||||
@ -414,6 +416,7 @@ class Layer(Operation):
|
||||
def call(self, *args, **kwargs):
|
||||
raise NotImplementedError
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def stateless_call(
|
||||
self,
|
||||
trainable_variables,
|
||||
@ -659,8 +662,8 @@ class Layer(Operation):
|
||||
if not self.built:
|
||||
raise ValueError(
|
||||
"You tried to call `count_params` "
|
||||
f"on layer '{self.name}'"
|
||||
", but the layer isn't built. "
|
||||
f"on layer '{self.name}', "
|
||||
"but the layer isn't built. "
|
||||
"You can build it manually via: "
|
||||
f"`layer.build(input_shape)`."
|
||||
)
|
||||
@ -714,7 +717,9 @@ class Layer(Operation):
|
||||
"then the build signature should be "
|
||||
"`def build(self, x1_shape, x2_shape)`. "
|
||||
"Keras will not build this layer automatically "
|
||||
"since it does not conform to this."
|
||||
"since it does not conform to this. "
|
||||
"Expected the following build keys: "
|
||||
f"{list(shapes_dict.keys())}"
|
||||
)
|
||||
if failure:
|
||||
raise ValueError(
|
||||
|
@ -8,6 +8,7 @@ from keras_core.backend.common.keras_tensor import any_symbolic_tensors
|
||||
from keras_core.operations.node import Node
|
||||
from keras_core.saving import serialization_lib
|
||||
from keras_core.utils import python_utils
|
||||
from keras_core.utils import traceback_utils
|
||||
from keras_core.utils.naming import auto_name
|
||||
|
||||
|
||||
@ -24,7 +25,21 @@ class Operation:
|
||||
self._inbound_nodes = []
|
||||
self._outbound_nodes = []
|
||||
|
||||
@traceback_utils.filter_traceback
|
||||
def __call__(self, *args, **kwargs):
|
||||
if traceback_utils.is_traceback_filtering_enabled():
|
||||
# Wrap self.call to provide helpful info in case of exception
|
||||
if any_symbolic_tensors(args, kwargs):
|
||||
call_fn = self.symbolic_call
|
||||
else:
|
||||
call_fn = self.call
|
||||
call_fn = traceback_utils.inject_argument_info_in_traceback(
|
||||
call_fn,
|
||||
object_name=(f"{self.__class__.__name__}.call()"),
|
||||
)
|
||||
return call_fn(*args, **kwargs)
|
||||
|
||||
# Plain flow.
|
||||
if any_symbolic_tensors(args, kwargs):
|
||||
return self.symbolic_call(*args, **kwargs)
|
||||
return self.call(*args, **kwargs)
|
||||
@ -50,13 +65,15 @@ class Operation:
|
||||
try:
|
||||
return backend.compute_output_spec(self.call, *args, **kwargs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
new_e = RuntimeError(
|
||||
"Could not automatically infer the output shape / dtype of "
|
||||
f"operation '{self.name}'. "
|
||||
"Please implement the `compute_output_spec()` method "
|
||||
f"on your object ({self.__class__.__name__}). "
|
||||
f"Error encountered: {e}"
|
||||
f"'{self.name}' (of type {self.__class__.__name__}). "
|
||||
f"Either the `{self.__class__.__name__}.call()` method "
|
||||
f"is incorrect, or you need to implement the "
|
||||
f"`{self.__class__.__name__}.compute_output_spec()` method. "
|
||||
f"Error encountered:\n\n{e}"
|
||||
)
|
||||
raise new_e.with_traceback(e.__traceback__) from None
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
"""We override __new__ to saving serializable constructor arguments.
|
||||
|
@ -6,7 +6,7 @@ from keras_core.api_export import keras_core_export
|
||||
from keras_core.backend.common import global_state
|
||||
|
||||
|
||||
@keras_core_export("keras_core.utils.enable_interactive_logging")
|
||||
@keras_core_export("keras_core.config.enable_interactive_logging")
|
||||
def enable_interactive_logging():
|
||||
"""Turn on interactive logging.
|
||||
|
||||
@ -17,7 +17,7 @@ def enable_interactive_logging():
|
||||
global_state.set_global_setting("interactive_logging", True)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.utils.disable_interactive_logging")
|
||||
@keras_core_export("keras_core.config.disable_interactive_logging")
|
||||
def disable_interactive_logging():
|
||||
"""Turn off interactive logging.
|
||||
|
||||
@ -28,16 +28,17 @@ def disable_interactive_logging():
|
||||
global_state.set_global_setting("interactive_logging", False)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.utils.is_interactive_logging_enabled")
|
||||
@keras_core_export("keras_core.config.is_interactive_logging_enabled")
|
||||
def is_interactive_logging_enabled():
|
||||
"""Check if interactive logging is enabled.
|
||||
|
||||
To switch between writing logs to stdout and `absl.logging`, you may use
|
||||
`keras.utils.enable_interactive_logging()` and
|
||||
`keras.utils.disable_interactie_logging()`.
|
||||
`keras.config.enable_interactive_logging()` and
|
||||
`keras.config.disable_interactie_logging()`.
|
||||
|
||||
Returns:
|
||||
Boolean (True if interactive logging is enabled and False otherwise).
|
||||
Boolean, `True` if interactive logging is enabled,
|
||||
and `False` otherwise.
|
||||
"""
|
||||
return global_state.get_global_setting("interactive_logging", True)
|
||||
|
||||
|
228
keras_core/utils/traceback_utils.py
Normal file
228
keras_core/utils/traceback_utils.py
Normal file
@ -0,0 +1,228 @@
|
||||
import inspect
|
||||
import os
|
||||
import traceback
|
||||
import types
|
||||
|
||||
from tensorflow import errors as tf_errors
|
||||
from tensorflow import nest
|
||||
|
||||
from keras_core import backend
|
||||
from keras_core.api_export import keras_core_export
|
||||
from keras_core.backend.common import global_state
|
||||
|
||||
_EXCLUDED_PATHS = (
|
||||
os.path.abspath(os.path.join(__file__, "..", "..")),
|
||||
os.path.join("tensorflow", "python"),
|
||||
)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.config.enable_traceback_filtering")
|
||||
def enable_traceback_filtering():
|
||||
"""Turn on traceback filtering.
|
||||
|
||||
Raw Keras tracebacks (also known as stack traces)
|
||||
involve many internal frames, which can be
|
||||
challenging to read through, while not being actionable for end users.
|
||||
By default, Keras filters internal frames in most exceptions that it
|
||||
raises, to keep traceback short, readable, and focused on what's
|
||||
actionable for you (your own code).
|
||||
|
||||
See also `keras_core.config.disable_traceback_filtering()` and
|
||||
`keras_core.config.is_traceback_filtering_enabled()`.
|
||||
|
||||
If you have previously disabled traceback filtering via
|
||||
`keras_core.config.disable_traceback_filtering()`, you can re-enable it via
|
||||
`keras_core.config.enable_traceback_filtering()`.
|
||||
"""
|
||||
global_state.set_global_setting("traceback_filtering", True)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.config.disable_traceback_filtering")
|
||||
def disable_traceback_filtering():
|
||||
"""Turn off traceback filtering.
|
||||
|
||||
Raw Keras tracebacks (also known as stack traces)
|
||||
involve many internal frames, which can be
|
||||
challenging to read through, while not being actionable for end users.
|
||||
By default, Keras filters internal frames in most exceptions that it
|
||||
raises, to keep traceback short, readable, and focused on what's
|
||||
actionable for you (your own code).
|
||||
|
||||
See also `keras_core.config.enable_traceback_filtering()` and
|
||||
`keras_core.config.is_traceback_filtering_enabled()`.
|
||||
|
||||
If you have previously disabled traceback filtering via
|
||||
`keras_core.config.disable_traceback_filtering()`, you can re-enable it via
|
||||
`keras_core.config.enable_traceback_filtering()`.
|
||||
"""
|
||||
global_state.set_global_setting("traceback_filtering", False)
|
||||
|
||||
|
||||
@keras_core_export("keras_core.config.is_traceback_filtering_enabled")
|
||||
def is_traceback_filtering_enabled():
|
||||
"""Check if traceback filtering is enabled.
|
||||
|
||||
Raw Keras tracebacks (also known as stack traces)
|
||||
involve many internal frames, which can be
|
||||
challenging to read through, while not being actionable for end users.
|
||||
By default, Keras filters internal frames in most exceptions that it
|
||||
raises, to keep traceback short, readable, and focused on what's
|
||||
actionable for you (your own code).
|
||||
|
||||
See also `keras_core.config.enable_traceback_filtering()` and
|
||||
`keras_core.config.disable_traceback_filtering()`.
|
||||
|
||||
If you have previously disabled traceback filtering via
|
||||
`keras_core.config.disable_traceback_filtering()`, you can re-enable it via
|
||||
`keras_core.config.enable_traceback_filtering()`.
|
||||
|
||||
Returns:
|
||||
Boolean, `True` if traceback filtering is enabled,
|
||||
and `False` otherwise.
|
||||
"""
|
||||
return global_state.get_global_setting("traceback_filtering", True)
|
||||
|
||||
|
||||
def include_frame(fname):
|
||||
for exclusion in _EXCLUDED_PATHS:
|
||||
if exclusion in fname:
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def _process_traceback_frames(tb):
|
||||
"""Iterate through traceback frames and return a new, filtered traceback."""
|
||||
last_tb = None
|
||||
tb_list = list(traceback.walk_tb(tb))
|
||||
for f, line_no in reversed(tb_list):
|
||||
if include_frame(f.f_code.co_filename):
|
||||
last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no)
|
||||
if last_tb is None and tb_list:
|
||||
# If no frames were kept during filtering, create a new traceback
|
||||
# from the outermost function.
|
||||
f, line_no = tb_list[-1]
|
||||
last_tb = types.TracebackType(last_tb, f, f.f_lasti, line_no)
|
||||
return last_tb
|
||||
|
||||
|
||||
def filter_traceback(fn):
|
||||
"""Filter out Keras-internal traceback frames in exceptions raised by fn."""
|
||||
|
||||
def error_handler(*args, **kwargs):
|
||||
if not is_traceback_filtering_enabled():
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
filtered_tb = None
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
filtered_tb = _process_traceback_frames(e.__traceback__)
|
||||
# To get the full stack trace, call:
|
||||
# `keras_core.config.disable_traceback_filtering()`
|
||||
raise e.with_traceback(filtered_tb) from None
|
||||
finally:
|
||||
del filtered_tb
|
||||
|
||||
return error_handler
|
||||
|
||||
|
||||
def inject_argument_info_in_traceback(fn, object_name=None):
|
||||
"""Add information about call argument values to an error message.
|
||||
|
||||
Arguments:
|
||||
fn: Function to wrap. Exceptions raised by the this function will be
|
||||
re-raised with additional information added to the error message,
|
||||
displaying the values of the different arguments that the function
|
||||
was called with.
|
||||
object_name: String, display name of the class/function being called,
|
||||
e.g. `'layer "layer_name" (LayerClass)'`.
|
||||
|
||||
Returns:
|
||||
A wrapped version of `fn`.
|
||||
"""
|
||||
|
||||
def error_handler(*args, **kwargs):
|
||||
signature = None
|
||||
bound_signature = None
|
||||
try:
|
||||
return fn(*args, **kwargs)
|
||||
except Exception as e:
|
||||
if hasattr(e, "_keras_call_info_injected"):
|
||||
# Only inject info for the innermost failing call
|
||||
raise e
|
||||
signature = inspect.signature(fn)
|
||||
try:
|
||||
# The first argument is `self`, so filter it out
|
||||
bound_signature = signature.bind(*args, **kwargs)
|
||||
except TypeError:
|
||||
# Likely unbindable arguments
|
||||
raise e
|
||||
|
||||
# Add argument context
|
||||
arguments_context = []
|
||||
for arg in list(signature.parameters.values()):
|
||||
if arg.name in bound_signature.arguments:
|
||||
value = nest.map_structure(
|
||||
format_argument_value,
|
||||
bound_signature.arguments[arg.name],
|
||||
)
|
||||
else:
|
||||
value = arg.default
|
||||
arguments_context.append(f" • {arg.name}={value}")
|
||||
|
||||
if arguments_context:
|
||||
arguments_context = "\n".join(arguments_context)
|
||||
# Get original error message and append information to it.
|
||||
if isinstance(e, tf_errors.OpError):
|
||||
message = e.message
|
||||
elif e.args:
|
||||
# Canonically, the 1st argument in an exception is the error
|
||||
# message. This works for all built-in Python exceptions.
|
||||
message = e.args[0]
|
||||
else:
|
||||
message = ""
|
||||
display_name = f"{object_name if object_name else fn.__name__}"
|
||||
message = (
|
||||
f"Exception encountered when calling {display_name}.\n\n"
|
||||
f"\x1b[1m{message}\x1b[0m\n\n"
|
||||
f"Arguments received by {display_name}:\n"
|
||||
f"{arguments_context}"
|
||||
)
|
||||
|
||||
# Reraise exception, with added context
|
||||
if isinstance(e, tf_errors.OpError):
|
||||
new_e = e.__class__(e.node_def, e.op, message, e.error_code)
|
||||
else:
|
||||
try:
|
||||
# For standard exceptions such as ValueError, TypeError,
|
||||
# etc.
|
||||
new_e = e.__class__(message)
|
||||
except TypeError:
|
||||
# For any custom error that doesn't have a standard
|
||||
# signature.
|
||||
new_e = RuntimeError(message)
|
||||
new_e._keras_call_info_injected = True
|
||||
else:
|
||||
new_e = e
|
||||
raise new_e.with_traceback(e.__traceback__) from None
|
||||
finally:
|
||||
del signature
|
||||
del bound_signature
|
||||
|
||||
return error_handler
|
||||
|
||||
|
||||
def format_argument_value(value):
|
||||
if backend.is_tensor(value):
|
||||
# Simplified representation for eager / graph tensors
|
||||
# to keep messages readable
|
||||
if backend.backend() == "tensorflow":
|
||||
tensor_cls = "tf.Tensor"
|
||||
elif backend.backend() == "jax":
|
||||
tensor_cls = "jnp.ndarray"
|
||||
elif backend.backend() == "pytorch":
|
||||
tensor_cls = "torch.Tensor"
|
||||
else:
|
||||
tensor_cls = "array"
|
||||
return f"{tensor_cls}(shape={value.shape}, dtype={value.dtype.name})"
|
||||
return repr(value)
|
Loading…
Reference in New Issue
Block a user