Improve jit autoresolution
This commit is contained in:
parent
1a9e850af2
commit
f235ef6598
@ -1,3 +1,4 @@
|
||||
import platform
|
||||
import warnings
|
||||
|
||||
from keras_core import backend
|
||||
@ -357,3 +358,30 @@ class Trainer:
|
||||
else:
|
||||
msg += f"calling `{method_name}()`."
|
||||
raise ValueError(msg)
|
||||
|
||||
@property
|
||||
def supports_jit(self):
|
||||
if (
|
||||
platform.system() == "Darwin"
|
||||
and "arm" in platform.processor().lower()
|
||||
):
|
||||
if backend.backend() == "tensorflow":
|
||||
import tensorflow as tf
|
||||
|
||||
if tf.config.list_physical_devices("GPU"):
|
||||
warnings.warn(
|
||||
"XLA (`jit_compile`) is not yet supported "
|
||||
"on GPU on Apple M1/M2 ARM processors with "
|
||||
"TensorFlow-Metal. "
|
||||
"Falling back to `jit_compile=False`.",
|
||||
stacklevel=1,
|
||||
)
|
||||
return False
|
||||
if all(x.supports_jit for x in self._flatten_layers()):
|
||||
return True
|
||||
return False
|
||||
|
||||
@supports_jit.setter
|
||||
def supports_jit(self, _):
|
||||
# The property is computed, rather than set.
|
||||
pass
|
||||
|
Loading…
Reference in New Issue
Block a user