keras/conftest.py
HongYu f73df5a4eb
Introduce backend.result_type (#18482)
* Add `result_dtype` and some refactor of `ops.numpy`

* Fix keras_export

* Refactor `result_dtype`

* Update `result_type`

* Revert `ops.numpy.*` changes

* ensure consistent dtype inference

* add dtype test

* fix dropout rnn test

* Fix symbolic test

* Fix torch test

* keep `"int64"` when using tensorflow

* Fix test

* Simplify `result_type` for tensorflow

* Add `pre_canonicalize` option, rename to `result_type`

* Align the behavior of `ops.add`

* Fix test

* Match `backend.result_type` to JAX with `JAX_ENABLE_X64=true` and `JAX_DEFAULT_DTYPE_BITS=32`

* Use `dtype or config.floatx()`

* Fix symbolic ops

* Remove `result_type` in jax and torch

* Address comments

* Apply `result_type` to `ops.numpy.arange`

* Apply `backend.result_type` to `ops.numpy.sqrt`

* Skip float16 test for torch's sqrt
2023-09-28 10:00:24 -07:00

import os

# When using jax.experimental.enable_x64 in unit tests, we want to keep the
# default dtype with 32 bits, aligning it with Keras's default.
os.environ["JAX_DEFAULT_DTYPE_BITS"] = "32"

try:
    # When using torch and tensorflow, torch needs to be imported first;
    # otherwise it will segfault upon import. This should force the torch
    # import to happen first for all tests.
    import torch  # noqa: F401
except ImportError:
    pass

import pytest  # noqa: E402

from keras.backend import backend  # noqa: E402


def pytest_configure(config):
    config.addinivalue_line(
        "markers",
        "requires_trainable_backend: mark test for trainable backend only",
    )


def pytest_collection_modifyitems(config, items):
    requires_trainable_backend = pytest.mark.skipif(
        backend() == "numpy",
        reason="Trainer not implemented for NumPy backend.",
    )
    for item in items:
        if "requires_trainable_backend" in item.keywords:
            item.add_marker(requires_trainable_backend)
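The collection hook in this conftest tags every test carrying the `requires_trainable_backend` marker with a `skipif` whose condition is whether `backend()` reports `"numpy"`. The following is a minimal stdlib-only sketch of that decision logic, so it can be exercised without pytest or Keras installed. `SKIP_RULES`, `FakeItem`, and `apply_skip_rules` are hypothetical names, and the table-driven layout is an assumed generalization for illustration, not part of Keras. One deliberate difference: the real hook always attaches a `skipif` (evaluated once per session), while this sketch only records a marker when the skip actually applies.

```python
from dataclasses import dataclass, field

# Hypothetical rule table: marker name -> (backends to skip on, reason).
# The single entry mirrors the conftest hook; more markers could be added.
SKIP_RULES = {
    "requires_trainable_backend": (
        frozenset({"numpy"}),
        "Trainer not implemented for NumPy backend.",
    ),
}


@dataclass
class FakeItem:
    """Hypothetical stand-in for a collected pytest item."""

    name: str
    keywords: set
    markers: list = field(default_factory=list)

    def add_marker(self, marker):
        # A real pytest item attaches a mark object; here we just record it.
        self.markers.append(marker)


def apply_skip_rules(backend_name, items, rules=SKIP_RULES):
    """Mark items whose markers are unsupported on `backend_name`.

    Returns the names of the items that received a skip marker.
    """
    skipped = []
    for item in items:
        for marker_name, (backends, reason) in rules.items():
            if marker_name in item.keywords and backend_name in backends:
                # Record the skip as a (kind, reason) tuple; the real hook
                # attaches pytest.mark.skipif(...) instead.
                item.add_marker(("skip", reason))
                skipped.append(item.name)
                break  # one skip reason per item is enough
    return skipped


items = [
    FakeItem("test_fit", {"requires_trainable_backend"}),
    FakeItem("test_add", set()),
]
print(apply_skip_rules("numpy", items))  # ['test_fit']

other = [FakeItem("test_fit", {"requires_trainable_backend"})]
print(apply_skip_rules("jax", other))  # []
```

Because the marker is also registered in `pytest_configure`, pytest will not warn about it as unknown, and it can be used for deselection on the command line, e.g. `pytest -m "not requires_trainable_backend"`.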