keras/conftest.py
2023-09-22 09:29:36 -07:00

29 lines
807 B
Python

try:
# When using torch and tensorflow, torch needs to be imported first,
# otherwise it will segfault upon import. This should force the torch
# import to happen first for all tests.
import torch # noqa: F401
except ImportError:
pass
import pytest
from keras.backend import backend
def pytest_configure(config):
config.addinivalue_line(
"markers",
"requires_trainable_backend: mark test for trainable backend only",
)
def pytest_collection_modifyitems(config, items):
requires_trainable_backend = pytest.mark.skipif(
backend() == "numpy",
reason="Trainer not implemented for NumPy backend.",
)
for item in items:
if "requires_trainable_backend" in item.keywords:
item.add_marker(requires_trainable_backend)