1937d487b0
* API Generator for Keras * API Generator for Keras * Generates API Gen via api_gen.sh * Remove recursive import of _tf_keras * Generate API Files via api_gen.sh
35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
import os

# When using jax.experimental.enable_x64 in unit tests, we want to keep the
# default dtype at 32 bits, aligning it with Keras's default.
# NOTE(review): this is set before the backend import below, presumably so
# the value is already in place when JAX is first loaded -- keep the order.
os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"

try:
    # When using torch and tensorflow, torch needs to be imported first,
    # otherwise it will segfault upon import. This should force the torch
    # import to happen first for all tests.
    import torch  # noqa: F401
except ImportError:
    # torch is not installed, so there is no import order to enforce.
    pass

import pytest  # noqa: E402

from keras.src.backend import backend  # noqa: E402
|
|
|
|
|
|
def pytest_configure(config):
    """Register the custom `requires_trainable_backend` marker.

    Appends one line to the `markers` ini value of the given pytest config,
    declaring the marker that `pytest_collection_modifyitems` looks for.

    Args:
        config: The pytest config object; only its `addinivalue_line`
            method is used.
    """
    config.addinivalue_line(
        "markers",
        "requires_trainable_backend: mark test for trainable backend only",
    )
|
|
|
|
|
|
def pytest_collection_modifyitems(config, items):
    """Add a conditional skip to tests marked `requires_trainable_backend`.

    Every collected item whose keywords contain `requires_trainable_backend`
    receives a `skipif` marker whose condition is true when the active
    backend is `"numpy"`.

    Args:
        config: The pytest config object (unused here).
        items: The collected test items; matching items are modified in
            place via `add_marker`.
    """
    # The backend is queried once, up front; the same marker object is then
    # shared by every matching item.
    requires_trainable_backend = pytest.mark.skipif(
        backend() == "numpy",
        reason="Trainer not implemented for NumPy backend.",
    )
    for item in items:
        if "requires_trainable_backend" in item.keywords:
            item.add_marker(requires_trainable_backend)
|