f73df5a4eb
* Add `result_dtype` and some refactor of `ops.numpy` * Fix keras_export * Refactor `result_dtype` * Update `result_type` * Revert `ops.numpy.*` changes * ensure consistent dtype inference * add dtype test * fix dropout rnn test * Fix symbolic test * Fix torch test * keep `"int64"` when using tensorflow * Fix test * Simplify `result_type` for tensorflow * Add `pre_canonicalize` option, rename to `result_type` * Align the behavior of `ops.add` * Fix test * Match `backend.result_type` to JAX with `JAX_ENABLE_X64=true` and `JAX_DEFAULT_DTYPE_BITS=32` * Use `dtype or config.floatx()` * Fix symbolic ops * Remove `result_type` in jax and torch * Address comments * Apply `result_type` to `ops.numpy.arange` * Apply `backend.result_type` to `ops.numpy.sqrt` * Skip float16 test for torch's sqrt
35 lines
1.0 KiB
Python
35 lines
1.0 KiB
Python
import os

# Keep JAX's default dtype width at 32 bits, matching Keras's default, even
# in unit tests that switch on 64-bit mode via jax.experimental.enable_x64.
# NOTE(review): this presumably has to happen before jax is first imported
# (directly or through keras), so keep it at the very top of the file.
os.environ.update(JAX_DEFAULT_DTYPE_BITS="32")
# Import order matters here. When both torch and tensorflow are installed,
# torch must be imported first, otherwise the process segfaults on import.
# Importing torch this early in conftest forces that order for every test.
# torch is optional, so a missing install is silently tolerated.
try:
    import torch  # noqa: F401
except ImportError:
    # torch is not installed: there is no ordering to enforce.
    pass

# These imports deliberately come after the module-level setup above (the
# environment variable and the early torch import). That is why E402
# (module-level import not at top of file) is suppressed on each of them.
# NOTE(review): keras.backend presumably may load jax/tensorflow, so do not
# hoist it above the torch import - confirm before reordering.
import pytest  # noqa: E402

from keras.backend import backend  # noqa: E402


def pytest_configure(config):
    """Register the custom ``requires_trainable_backend`` pytest marker.

    Appends a ``name: description`` entry to the ``markers`` ini value,
    which is how pytest learns about custom markers.

    Args:
        config: The pytest ``Config`` object passed to this hook.
    """
    marker_spec = (
        "requires_trainable_backend: mark test for trainable backend only"
    )
    config.addinivalue_line("markers", marker_spec)


def pytest_collection_modifyitems(config, items):
    """Skip ``requires_trainable_backend`` tests when running on NumPy.

    pytest calls this hook after collection. Every collected item whose
    keywords contain ``requires_trainable_backend`` gets a ``skipif``
    marker. The marker's condition is ``backend() == "numpy"``, so on any
    other backend it is False and the test runs normally.

    Args:
        config: The pytest ``Config`` object (unused; part of the hook
            signature).
        items: The collected test items; they are marked in place.
    """
    skip_on_numpy = pytest.mark.skipif(
        backend() == "numpy",
        reason="Trainer not implemented for NumPy backend.",
    )
    needs_trainer = [
        item
        for item in items
        if "requires_trainable_backend" in item.keywords
    ]
    for item in needs_trainer:
        item.add_marker(skip_on_numpy)