Add missing docstrings
This commit is contained in:
parent
93f5c6cd7d
commit
083c817788
@ -10,9 +10,8 @@ class Callback:
|
||||
`predict()` in order to hook into the various stages of the model training,
|
||||
evaluation, and inference lifecycle.
|
||||
|
||||
To create a custom callback, subclass `keras.callbacks.Callback` and
|
||||
override the method associated with the stage of interest. See
|
||||
https://www.tensorflow.org/guide/keras/custom_callback for more information.
|
||||
To create a custom callback, subclass `keras_core.callbacks.Callback` and
|
||||
override the method associated with the stage of interest.
|
||||
|
||||
Example:
|
||||
|
||||
|
@ -7,7 +7,7 @@ from keras_core.api_export import keras_core_export
|
||||
class Initializer:
|
||||
"""Initializer base class: all Keras initializers inherit from this class.
|
||||
|
||||
Initializers should implement a `__call__` method with the following
|
||||
Initializers should implement a `__call__()` method with the following
|
||||
signature:
|
||||
|
||||
```python
|
||||
@ -29,7 +29,7 @@ class Initializer:
|
||||
self.stddev = stddev
|
||||
|
||||
def __call__(self, shape, dtype=None, **kwargs):
|
||||
return knp.random.normal(
|
||||
return keras_core.random.normal(
|
||||
shape, mean=self.mean, stddev=self.stddev, dtype=dtype
|
||||
)
|
||||
|
||||
|
@ -9,6 +9,22 @@ from keras_core.utils.naming import auto_name
|
||||
|
||||
@keras_core_export(["keras_core.Loss", "keras_core.losses.Loss"])
|
||||
class Loss:
|
||||
"""Loss base class.
|
||||
|
||||
To be implemented by subclasses:
|
||||
|
||||
* `call()`: Contains the logic for loss calculation using `y_true`,
|
||||
`y_pred`.
|
||||
|
||||
Example subclass implementation:
|
||||
|
||||
```python
|
||||
class MeanSquaredError(Loss):
|
||||
def call(self, y_true, y_pred):
|
||||
return ops.mean(ops.square(y_pred - y_true), axis=-1)
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self, name=None, reduction="sum_over_batch_size"):
|
||||
self.name = name or auto_name(self.__class__.__name__)
|
||||
self.reduction = standardize_reduction(reduction)
|
||||
|
@ -59,7 +59,7 @@ class Metric:
|
||||
self.true_positives = self.add_variable(
|
||||
shape=(),
|
||||
initializer='zeros',
|
||||
name='tp'
|
||||
name='true_positives'
|
||||
)
|
||||
|
||||
def update_state(self, y_true, y_pred, sample_weight=None):
|
||||
|
@ -15,17 +15,75 @@ from keras_core.utils import tracking
|
||||
|
||||
|
||||
class Functional(Function, Model):
|
||||
"""
|
||||
Add support for extra call arguments compared to Function:
|
||||
training, masks
|
||||
"""A `Functional` model is a `Model` defined as a directed graph of layers.
|
||||
|
||||
Add support for arg standardization:
|
||||
- list/dict duality
|
||||
- upranking
|
||||
Three types of `Model` exist: subclassed `Model`, `Functional` model,
|
||||
and `Sequential` (a special case of `Functional`).
|
||||
|
||||
Override .layers
|
||||
A `Functional` model can be instantiated by passing two arguments to
|
||||
`__init__()`. The first argument is the `keras_core.Input` objects
|
||||
that represent the inputs to the model.
|
||||
The second argument specifies the output tensors that represent
|
||||
the outputs of this model. Both arguments can be a nested structure
|
||||
of tensors.
|
||||
|
||||
Symbolic add_loss
|
||||
Example:
|
||||
|
||||
```
|
||||
inputs = {'x1': keras_core.Input(shape=(10,)),
|
||||
'x2': keras_core.Input(shape=(1,))}
|
||||
t = keras_core.layers.Dense(1, activation='relu')(inputs['x1'])
|
||||
outputs = keras_core.layers.Add()([t, inputs['x2'])
|
||||
model = keras_core.Model(inputs, outputs)
|
||||
```
|
||||
|
||||
A `Functional` model constructed using the Functional API can also
|
||||
include raw Keras Core ops.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
inputs = keras_core.Input(shape=(10,))
|
||||
x = keras_core.layers.Dense(1)(inputs)
|
||||
outputs = ops.nn.relu(x)
|
||||
model = keras_core.Model(inputs, outputs)
|
||||
```
|
||||
|
||||
A new `Functional` model can also be created by using the
|
||||
intermediate tensors. This enables you to quickly extract sub-components
|
||||
of the model.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
inputs = keras_core.Input(shape=(None, None, 3))
|
||||
processed = keras_core.layers.RandomCrop(width=32, height=32)(inputs)
|
||||
conv = keras_core.layers.Conv2D(filters=2, kernel_size=3)(processed)
|
||||
pooling = keras_core.layers.GlobalAveragePooling2D()(conv)
|
||||
feature = keras_core.layers.Dense(10)(pooling)
|
||||
|
||||
full_model = keras_core.Model(inputs, feature)
|
||||
backbone = keras_core.Model(processed, conv)
|
||||
activations = keras_core.Model(conv, feature)
|
||||
```
|
||||
|
||||
Note that the `backbone` and `activations` models are not
|
||||
created with `keras_core.Input` objects, but with the tensors
|
||||
that are originated from `keras_core.Input` objects.
|
||||
Under the hood, the layers and weights will
|
||||
be shared across these models, so that user can train the `full_model`, and
|
||||
use `backbone` or `activations` to do feature extraction.
|
||||
The inputs and outputs of the model can be nested structures of tensors as
|
||||
well, and the created models are standard `Functional` model that support
|
||||
all the existing API.
|
||||
|
||||
Args:
|
||||
inputs: List of input tensors (must be created via `keras_core.Input()`
|
||||
or originated from `keras_core.Input()`).
|
||||
outputs: List of output tensors.
|
||||
name: String, optional. Name of the model.
|
||||
trainable: Boolean, optional. If the model's variables should be
|
||||
trainable.
|
||||
"""
|
||||
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
|
@ -26,19 +26,108 @@ else:
|
||||
|
||||
@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
|
||||
class Model(Trainer, Layer):
|
||||
"""
|
||||
"""A model grouping layers into an object with training/inference features.
|
||||
|
||||
Combination of a Layer and Trainer. Adds:
|
||||
There are three ways to instantiate a `Model`:
|
||||
|
||||
- layer surfacing
|
||||
- saving
|
||||
- export
|
||||
- summary
|
||||
## With the "Functional API"
|
||||
|
||||
Limitations:
|
||||
You start from `Input`,
|
||||
you chain layer calls to specify the model's forward pass,
|
||||
and finally you create your model from inputs and outputs:
|
||||
|
||||
- call must have a single inputs argument
|
||||
- no masking support
|
||||
```python
|
||||
inputs = keras_core.Input(shape=(37,))
|
||||
x = keras_core.layers.Dense(32, activation="relu")(inputs)
|
||||
outputs = keras_core.layers.Dense(5, activation="softmax")(x)
|
||||
model = keras_core.Model(inputs=inputs, outputs=outputs)
|
||||
```
|
||||
|
||||
Note: Only dicts, lists, and tuples of input tensors are supported. Nested
|
||||
inputs are not supported (e.g. lists of list or dicts of dict).
|
||||
|
||||
A new Functional API model can also be created by using the
|
||||
intermediate tensors. This enables you to quickly extract sub-components
|
||||
of the model.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
inputs = keras_core.Input(shape=(None, None, 3))
|
||||
processed = keras_core.layers.RandomCrop(width=128, height=128)(inputs)
|
||||
conv = keras_core.layers.Conv2D(filters=32, kernel_size=3)(processed)
|
||||
pooling = keras_core.layers.GlobalAveragePooling2D()(conv)
|
||||
feature = keras_core.layers.Dense(10)(pooling)
|
||||
|
||||
full_model = keras_core.Model(inputs, feature)
|
||||
backbone = keras_core.Model(processed, conv)
|
||||
activations = keras_core.Model(conv, feature)
|
||||
```
|
||||
|
||||
Note that the `backbone` and `activations` models are not
|
||||
created with `keras_core.Input` objects, but with the tensors that originate
|
||||
from `keras_core.Input` objects. Under the hood, the layers and weights will
|
||||
be shared across these models, so that user can train the `full_model`, and
|
||||
use `backbone` or `activations` to do feature extraction.
|
||||
The inputs and outputs of the model can be nested structures of tensors as
|
||||
well, and the created models are standard Functional API models that support
|
||||
all the existing APIs.
|
||||
|
||||
## By subclassing the `Model` class
|
||||
|
||||
In that case, you should define your
|
||||
layers in `__init__()` and you should implement the model's forward pass
|
||||
in `call()`.
|
||||
|
||||
```python
|
||||
class MyModel(keras_core.Model):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dense1 = keras_core.layers.Dense(32, activation="relu")
|
||||
self.dense2 = keras_core.layers.Dense(5, activation="softmax")
|
||||
|
||||
def call(self, inputs):
|
||||
x = self.dense1(inputs)
|
||||
return self.dense2(x)
|
||||
|
||||
model = MyModel()
|
||||
```
|
||||
|
||||
If you subclass `Model`, you can optionally have
|
||||
a `training` argument (boolean) in `call()`, which you can use to specify
|
||||
a different behavior in training and inference:
|
||||
|
||||
```python
|
||||
class MyModel(keras_core.Model):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.dense1 = keras_core.layers.Dense(32, activation="relu")
|
||||
self.dense2 = keras_core.layers.Dense(5, activation="softmax")
|
||||
self.dropout = keras_core.layers.Dropout(0.5)
|
||||
|
||||
def call(self, inputs, training=False):
|
||||
x = self.dense1(inputs)
|
||||
x = self.dropout(x, training=training)
|
||||
return self.dense2(x)
|
||||
|
||||
model = MyModel()
|
||||
```
|
||||
|
||||
Once the model is created, you can config the model with losses and metrics
|
||||
with `model.compile()`, train the model with `model.fit()`, or use the model
|
||||
to do prediction with `model.predict()`.
|
||||
|
||||
## With the `Sequential` class
|
||||
|
||||
In addition, `keras_core.Sequential` is a special case of model where
|
||||
the model is purely a stack of single-input, single-output layers.
|
||||
|
||||
```python
|
||||
model = keras_core.Sequential([
|
||||
keras_core.Input(shape=(None, None, 3)),
|
||||
keras_core.layers.Conv2D(filters=32, kernel_size=3),
|
||||
])
|
||||
```
|
||||
"""
|
||||
|
||||
def __new__(cls, *args, **kwargs):
|
||||
|
@ -12,6 +12,51 @@ from keras_core.utils import tracking
|
||||
|
||||
@keras_core_export(["keras_core.Sequential", "keras_core.models.Sequential"])
|
||||
class Sequential(Model):
|
||||
"""`Sequential` groups a linear stack of layers into a `Model`.
|
||||
|
||||
Examples:
|
||||
|
||||
```python
|
||||
model = keras_core.Sequential()
|
||||
model.add(keras_core.Input(shape=(16,)))
|
||||
model.add(keras_core.layers.Dense(8))
|
||||
|
||||
# Note that you can also omit the initial `Input`.
|
||||
# In that case the model doesn't have any weights until the first call
|
||||
# to a training/evaluation method (since it isn't yet built):
|
||||
model = keras_core.Sequential()
|
||||
model.add(keras_core.layers.Dense(8))
|
||||
model.add(keras_core.layers.Dense(4))
|
||||
# model.weights not created yet
|
||||
|
||||
# Whereas if you specify an `Input`, the model gets built
|
||||
# continuously as you are adding layers:
|
||||
model = keras_core.Sequential()
|
||||
model.add(keras_core.Input(shape=(16,)))
|
||||
model.add(keras_core.layers.Dense(8))
|
||||
len(model.weights) # Returns "2"
|
||||
|
||||
# When using the delayed-build pattern (no input shape specified), you can
|
||||
# choose to manually build your model by calling
|
||||
# `build(batch_input_shape)`:
|
||||
model = keras_core.Sequential()
|
||||
model.add(keras_core.layers.Dense(8))
|
||||
model.add(keras_core.layers.Dense(4))
|
||||
model.build((None, 16))
|
||||
len(model.weights) # Returns "4"
|
||||
|
||||
# Note that when using the delayed-build pattern (no input shape specified),
|
||||
# the model gets built the first time you call `fit`, `eval`, or `predict`,
|
||||
# or the first time you call the model on some input data.
|
||||
model = keras_core.Sequential()
|
||||
model.add(keras_core.layers.Dense(8))
|
||||
model.add(keras_core.layers.Dense(1))
|
||||
model.compile(optimizer='sgd', loss='mse')
|
||||
# This builds the model for the first time:
|
||||
model.fit(x, y, batch_size=32, epochs=10)
|
||||
```
|
||||
"""
|
||||
|
||||
@tracking.no_automatic_dependency_tracking
|
||||
def __init__(self, layers=None, trainable=True, name=None):
|
||||
super().__init__(trainable=trainable, name=name)
|
||||
|
Loading…
Reference in New Issue
Block a user