Improve jit autoresolution

This commit is contained in:
Francois Chollet 2023-06-09 12:37:31 -07:00
parent 1a9e850af2
commit f235ef6598

@ -1,3 +1,4 @@
import platform
import warnings import warnings
from keras_core import backend from keras_core import backend
@ -357,3 +358,30 @@ class Trainer:
else: else:
msg += f"calling `{method_name}()`." msg += f"calling `{method_name}()`."
raise ValueError(msg) raise ValueError(msg)
@property
def supports_jit(self):
if (
platform.system() == "Darwin"
and "arm" in platform.processor().lower()
):
if backend.backend() == "tensorflow":
import tensorflow as tf
if tf.config.list_physical_devices("GPU"):
warnings.warn(
"XLA (`jit_compile`) is not yet supported "
"on GPU on Apple M1/M2 ARM processors with "
"TensorFlow-Metal. "
"Falling back to `jit_compile=False`.",
stacklevel=1,
)
return False
if all(x.supports_jit for x in self._flatten_layers()):
return True
return False
@supports_jit.setter
def supports_jit(self, _):
# The property is computed, rather than set.
pass