keras/keras_core/utils/backend_utils.py
Aritra Roy Gosthipaty 2481069ed4 Adding: Numpy Backend (#483)
* chore: adding numpy backend

* review comments

* review comments

* chore: adding math

* chore: adding random module

* chore: adding random in init

* review comments

* chore: adding numpy and nn for numpy backend

* chore: adding generic pool, max, and average pool

* chore: adding the conv ops

* chore: reformat code and using jax for conv and pool

* chore:  added self value

* chore: activation tests pass

* chore: adding post build method

* chore: adding necessary methods to the numpy trainer

* chore: fixing utils test

* chore: fixing losses test suite

* chore: fix backend tests

* chore: fixing initializers test

* chore: fixing accuracy metrics test

* chore: fixing ops test

* chore: review comments

* chore: init with image and fixing random tests

* chore: skipping random seed set for numpy backend

* chore: adding single resize image method

* chore: skipping tests for applications and layers

* chore: skipping tests for models

* chore: skipping tests for saving

* chore: skipping tests for trainers

* chore: fixing one hot

* chore: fixing vmap in numpy and metrics test

* chore: adding a wrapper to numpy sum, started fixing layer tests

* fix: is_tensor now accepts numpy scalars

* chore: adding draw seed

* fix: warn message for numpy masking

* fix: checking whether kernel are tensors

* chore: adding rnn

* chore: adding dynamic backend for numpy

* fix: axis cannot be None for normalize

* chore: adding jax resize for numpy image

* chore: adding rnn implementation in numpy

* chore: using pytest fixtures

* change: numpy import string

* chore: review comments

* chore: adding numpy to backend list of github actions

* chore: remove debug print statements
2023-07-19 01:08:48 +05:30

60 lines
1.7 KiB
Python

import sys
from keras_core import backend as backend_module
def in_tf_graph():
    """Report whether TensorFlow is loaded and not executing eagerly.

    TensorFlow is only consulted when it already appears in `sys.modules`,
    so calling this function never triggers a TensorFlow import by itself.

    Returns:
        `False` when `"tensorflow"` is absent from `sys.modules`; otherwise
        the negation of `tf.executing_eagerly()`.
    """
    # Guard clause: without TensorFlow loaded there is nothing to query.
    if "tensorflow" not in sys.modules:
        return False

    from keras_core.utils.module_utils import tensorflow as tf

    eager = tf.executing_eagerly()
    return not eager
class DynamicBackend:
    """A class that can be used to switch from one backend to another.

    Attribute access on an instance is forwarded to the backend module
    selected by the currently configured backend name, so one object can
    dispatch to different backends over its lifetime.

    Usage:

    ```python
    backend = DynamicBackend("tensorflow")
    y = backend.square(tf.constant(...))
    backend.set_backend("jax")
    y = backend.square(jax.numpy.array(...))
    ```

    Args:
        backend: Initial backend to use (string). When omitted (or falsy),
            the value returned by `keras_core.backend.backend()` is used.
    """

    def __init__(self, backend=None):
        self._backend = backend or backend_module.backend()

    def set_backend(self, backend):
        """Select the backend name used for subsequent attribute lookups.

        Args:
            backend: Backend name (string). It is stored as-is; an
                unsupported name is reported on the next attribute lookup.
        """
        self._backend = backend

    def reset(self):
        """Switch back to the name given by `keras_core.backend.backend()`."""
        self._backend = backend_module.backend()

    def __getattr__(self, name):
        """Resolve `name` on the module of the currently selected backend.

        Python calls this only after regular attribute lookup has failed.

        Args:
            name: Name of the attribute to fetch from the backend module.

        Returns:
            The attribute `name` of the selected backend module.

        Raises:
            AttributeError: If `name` is `"_backend"` (the instance has not
                been initialised), if the selected backend name is not one
                of the supported ones, or if the backend module has no
                attribute `name`.
        """
        # If `_backend` itself is missing (e.g. the instance was created
        # without running `__init__`, as `copy` and `pickle` do), reading
        # `self._backend` below would re-enter this method without end.
        if name == "_backend":
            raise AttributeError(
                f"'{type(self).__name__}' object has no attribute '_backend'"
            )

        # Backend modules are imported lazily so that only the backend
        # actually in use has to be importable.
        if self._backend == "tensorflow":
            from keras_core.backend import tensorflow as tf_backend

            return getattr(tf_backend, name)
        if self._backend == "jax":
            from keras_core.backend import jax as jax_backend

            return getattr(jax_backend, name)
        if self._backend == "torch":
            from keras_core.backend import torch as torch_backend

            return getattr(torch_backend, name)
        if self._backend == "numpy":
            # TODO (ariG23498):
            # The import `from keras_core.backend import numpy as numpy_backend`
            # is not working. This is a temporary fix.
            # The import is redirected to `keras_core.backend.numpy.numpy.py`
            # NOTE: as written, this branch resolves `name` on the
            # `keras_core.backend` package itself.
            from keras_core import backend as numpy_backend

            return getattr(numpy_backend, name)

        # Previously an unknown backend fell through and every lookup
        # silently evaluated to `None`; fail loudly instead.
        raise AttributeError(
            f"Cannot resolve attribute '{name}': unsupported backend "
            f"{self._backend!r}. Expected one of 'tensorflow', 'jax', "
            "'torch' or 'numpy'."
        )