Improve jit autoresolution
This commit is contained in:
parent
1a9e850af2
commit
f235ef6598
@ -1,3 +1,4 @@
|
|||||||
|
import platform
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from keras_core import backend
|
from keras_core import backend
|
||||||
@ -357,3 +358,30 @@ class Trainer:
|
|||||||
else:
|
else:
|
||||||
msg += f"calling `{method_name}()`."
|
msg += f"calling `{method_name}()`."
|
||||||
raise ValueError(msg)
|
raise ValueError(msg)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def supports_jit(self):
|
||||||
|
if (
|
||||||
|
platform.system() == "Darwin"
|
||||||
|
and "arm" in platform.processor().lower()
|
||||||
|
):
|
||||||
|
if backend.backend() == "tensorflow":
|
||||||
|
import tensorflow as tf
|
||||||
|
|
||||||
|
if tf.config.list_physical_devices("GPU"):
|
||||||
|
warnings.warn(
|
||||||
|
"XLA (`jit_compile`) is not yet supported "
|
||||||
|
"on GPU on Apple M1/M2 ARM processors with "
|
||||||
|
"TensorFlow-Metal. "
|
||||||
|
"Falling back to `jit_compile=False`.",
|
||||||
|
stacklevel=1,
|
||||||
|
)
|
||||||
|
return False
|
||||||
|
if all(x.supports_jit for x in self._flatten_layers()):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
@supports_jit.setter
|
||||||
|
def supports_jit(self, _):
|
||||||
|
# The property is computed, rather than set.
|
||||||
|
pass
|
||||||
|
Loading…
Reference in New Issue
Block a user