keras/examples/demo_custom_layer_backend_agnostic.py
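"""Demo: custom Keras Core layers and a custom Model written entirely with
backend-agnostic APIs (`keras_core.operations`, `keras_core.random`,
`keras_core.initializers`), then trained with the built-in `fit()` loop.

The script is intended to run unchanged on any installed backend; Keras Core
picks the backend from the `KERAS_BACKEND` environment variable
(e.g. "tensorflow", "jax", or "torch").
"""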

import numpy as np
import keras_core
from keras_core import Model
from keras_core import backend
from keras_core import initializers
from keras_core import layers
from keras_core import losses
from keras_core import metrics
from keras_core import operations as ops
from keras_core import optimizers


class MyDense(layers.Layer):
    def __init__(self, units, name=None):
        super().__init__(name=name)
        self.units = units

    def build(self, input_shape):
        input_dim = input_shape[-1]
        self.w = self.add_weight(
            shape=(input_dim, self.units),
            initializer=initializers.GlorotNormal(),
            name="kernel",
            trainable=True,
        )
        self.b = self.add_weight(
            shape=(self.units,),
            initializer=initializers.Zeros(),
            name="bias",
            trainable=True,
        )

    def call(self, inputs):
        # Use Keras ops to create backend-agnostic layers/metrics/etc.
        return ops.matmul(inputs, self.w) + self.b
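
# A quick way to see the backend-agnostic layer in action (illustrative
# sketch, not part of the original demo): calling the layer on a dummy
# batch triggers `build()` and returns a tensor of shape (batch, units)
# on whichever backend is active.
#
#   dense = MyDense(units=4)
#   out = dense(np.ones((2, 8)))
#   print(out.shape)  # (2, 4)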


class MyDropout(layers.Layer):
    def __init__(self, rate, name=None):
        super().__init__(name=name)
        self.rate = rate
        # Use seed_generator for managing RNG state.
        # It is a state element and its seed variable is
        # tracked as part of `layer.variables`.
        self.seed_generator = keras_core.random.SeedGenerator(1337)

    def call(self, inputs):
        # Use `keras_core.random` for random ops.
        return keras_core.random.dropout(
            inputs, self.rate, seed=self.seed_generator
        )
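
# Illustrative check (not in the original script): since the SeedGenerator
# is created in `__init__`, its seed state should appear among the layer's
# tracked variables even before the layer is called.
#
#   dp = MyDropout(rate=0.5)
#   print(dp.variables)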


class MyModel(Model):
    def __init__(self, hidden_dim, output_dim):
        super().__init__()
        self.dense1 = MyDense(hidden_dim)
        self.dense2 = MyDense(hidden_dim)
        self.dense3 = MyDense(output_dim)
        self.dp = MyDropout(0.5)

    def call(self, x):
        x1 = self.dense1(x)
        x2 = self.dense2(x)
        # Why not use some ops here as well
        x = ops.concatenate([x1, x2], axis=-1)
        x = self.dp(x)
        return self.dense3(x)
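
# Optional smoke test (illustrative, not part of the original demo): a
# single forward pass on dummy data builds every sublayer and should
# produce an output of shape (batch, output_dim) regardless of backend.
#
#   tiny = MyModel(hidden_dim=8, output_dim=4)
#   print(tiny(np.ones((2, 16))).shape)  # (2, 4)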


model = MyModel(hidden_dim=256, output_dim=16)

x = np.random.random((50000, 128))
y = np.random.random((50000, 16))
batch_size = 32
epochs = 5

model.compile(
    optimizer=optimizers.SGD(learning_rate=0.001),
    loss=losses.MeanSquaredError(),
    metrics=[metrics.MeanSquaredError()],
)
history = model.fit(x, y, batch_size=batch_size, epochs=epochs)

model.summary()

print("History:")
print(history.history)
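
# Inference after training (a minimal follow-up sketch, not part of the
# original demo), using the standard Keras `predict()` API:
preds = model.predict(x[:batch_size])
print("Predictions shape:", preds.shape)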