Add missing docstrings

This commit is contained in:
Francois Chollet 2023-05-26 14:11:03 -07:00
parent 93f5c6cd7d
commit 083c817788
7 changed files with 232 additions and 25 deletions

@ -10,9 +10,8 @@ class Callback:
`predict()` in order to hook into the various stages of the model training,
evaluation, and inference lifecycle.
To create a custom callback, subclass `keras.callbacks.Callback` and
override the method associated with the stage of interest. See
https://www.tensorflow.org/guide/keras/custom_callback for more information.
To create a custom callback, subclass `keras_core.callbacks.Callback` and
override the method associated with the stage of interest.
Example:

@ -7,7 +7,7 @@ from keras_core.api_export import keras_core_export
class Initializer:
"""Initializer base class: all Keras initializers inherit from this class.
Initializers should implement a `__call__` method with the following
Initializers should implement a `__call__()` method with the following
signature:
```python
@ -29,7 +29,7 @@ class Initializer:
self.stddev = stddev
def __call__(self, shape, dtype=None, **kwargs):
return knp.random.normal(
return keras_core.random.normal(
shape, mean=self.mean, stddev=self.stddev, dtype=dtype
)

@ -9,6 +9,22 @@ from keras_core.utils.naming import auto_name
@keras_core_export(["keras_core.Loss", "keras_core.losses.Loss"])
class Loss:
"""Loss base class.
To be implemented by subclasses:
* `call()`: Contains the logic for loss calculation using `y_true`,
`y_pred`.
Example subclass implementation:
```python
class MeanSquaredError(Loss):
def call(self, y_true, y_pred):
return ops.mean(ops.square(y_pred - y_true), axis=-1)
```
"""
def __init__(self, name=None, reduction="sum_over_batch_size"):
self.name = name or auto_name(self.__class__.__name__)
self.reduction = standardize_reduction(reduction)

@ -59,7 +59,7 @@ class Metric:
self.true_positives = self.add_variable(
shape=(),
initializer='zeros',
name='tp'
name='true_positives'
)
def update_state(self, y_true, y_pred, sample_weight=None):

@ -15,17 +15,75 @@ from keras_core.utils import tracking
class Functional(Function, Model):
"""
Add support for extra call arguments compared to Function:
training, masks
"""A `Functional` model is a `Model` defined as a directed graph of layers.
Add support for arg standardization:
- list/dict duality
- upranking
Three types of `Model` exist: subclassed `Model`, `Functional` model,
and `Sequential` (a special case of `Functional`).
Override .layers
A `Functional` model can be instantiated by passing two arguments to
`__init__()`. The first argument is the `keras_core.Input` objects
that represent the inputs to the model.
The second argument specifies the output tensors that represent
the outputs of this model. Both arguments can be a nested structure
of tensors.
Symbolic add_loss
Example:
```
inputs = {'x1': keras_core.Input(shape=(10,)),
'x2': keras_core.Input(shape=(1,))}
t = keras_core.layers.Dense(1, activation='relu')(inputs['x1'])
outputs = keras_core.layers.Add()([t, inputs['x2'])
model = keras_core.Model(inputs, outputs)
```
A `Functional` model constructed using the Functional API can also
include raw Keras Core ops.
Example:
```python
inputs = keras_core.Input(shape=(10,))
x = keras_core.layers.Dense(1)(inputs)
outputs = ops.nn.relu(x)
model = keras_core.Model(inputs, outputs)
```
A new `Functional` model can also be created by using the
intermediate tensors. This enables you to quickly extract sub-components
of the model.
Example:
```python
inputs = keras_core.Input(shape=(None, None, 3))
processed = keras_core.layers.RandomCrop(width=32, height=32)(inputs)
conv = keras_core.layers.Conv2D(filters=2, kernel_size=3)(processed)
pooling = keras_core.layers.GlobalAveragePooling2D()(conv)
feature = keras_core.layers.Dense(10)(pooling)
full_model = keras_core.Model(inputs, feature)
backbone = keras_core.Model(processed, conv)
activations = keras_core.Model(conv, feature)
```
Note that the `backbone` and `activations` models are not
created with `keras_core.Input` objects, but with the tensors
that are originated from `keras_core.Input` objects.
Under the hood, the layers and weights will
be shared across these models, so that user can train the `full_model`, and
use `backbone` or `activations` to do feature extraction.
The inputs and outputs of the model can be nested structures of tensors as
well, and the created models are standard `Functional` model that support
all the existing API.
Args:
inputs: List of input tensors (must be created via `keras_core.Input()`
or originated from `keras_core.Input()`).
outputs: List of output tensors.
name: String, optional. Name of the model.
trainable: Boolean, optional. If the model's variables should be
trainable.
"""
@tracking.no_automatic_dependency_tracking

@ -26,19 +26,108 @@ else:
@keras_core_export(["keras_core.Model", "keras_core.models.Model"])
class Model(Trainer, Layer):
"""
"""A model grouping layers into an object with training/inference features.
Combination of a Layer and Trainer. Adds:
There are three ways to instantiate a `Model`:
- layer surfacing
- saving
- export
- summary
## With the "Functional API"
Limitations:
You start from `Input`,
you chain layer calls to specify the model's forward pass,
and finally you create your model from inputs and outputs:
- call must have a single inputs argument
- no masking support
```python
inputs = keras_core.Input(shape=(37,))
x = keras_core.layers.Dense(32, activation="relu")(inputs)
outputs = keras_core.layers.Dense(5, activation="softmax")(x)
model = keras_core.Model(inputs=inputs, outputs=outputs)
```
Note: Only dicts, lists, and tuples of input tensors are supported. Nested
inputs are not supported (e.g. lists of list or dicts of dict).
A new Functional API model can also be created by using the
intermediate tensors. This enables you to quickly extract sub-components
of the model.
Example:
```python
inputs = keras_core.Input(shape=(None, None, 3))
processed = keras_core.layers.RandomCrop(width=128, height=128)(inputs)
conv = keras_core.layers.Conv2D(filters=32, kernel_size=3)(processed)
pooling = keras_core.layers.GlobalAveragePooling2D()(conv)
feature = keras_core.layers.Dense(10)(pooling)
full_model = keras_core.Model(inputs, feature)
backbone = keras_core.Model(processed, conv)
activations = keras_core.Model(conv, feature)
```
Note that the `backbone` and `activations` models are not
created with `keras_core.Input` objects, but with the tensors that originate
from `keras_core.Input` objects. Under the hood, the layers and weights will
be shared across these models, so that user can train the `full_model`, and
use `backbone` or `activations` to do feature extraction.
The inputs and outputs of the model can be nested structures of tensors as
well, and the created models are standard Functional API models that support
all the existing APIs.
## By subclassing the `Model` class
In that case, you should define your
layers in `__init__()` and you should implement the model's forward pass
in `call()`.
```python
class MyModel(keras_core.Model):
def __init__(self):
super().__init__()
self.dense1 = keras_core.layers.Dense(32, activation="relu")
self.dense2 = keras_core.layers.Dense(5, activation="softmax")
def call(self, inputs):
x = self.dense1(inputs)
return self.dense2(x)
model = MyModel()
```
If you subclass `Model`, you can optionally have
a `training` argument (boolean) in `call()`, which you can use to specify
a different behavior in training and inference:
```python
class MyModel(keras_core.Model):
def __init__(self):
super().__init__()
self.dense1 = keras_core.layers.Dense(32, activation="relu")
self.dense2 = keras_core.layers.Dense(5, activation="softmax")
self.dropout = keras_core.layers.Dropout(0.5)
def call(self, inputs, training=False):
x = self.dense1(inputs)
x = self.dropout(x, training=training)
return self.dense2(x)
model = MyModel()
```
Once the model is created, you can config the model with losses and metrics
with `model.compile()`, train the model with `model.fit()`, or use the model
to do prediction with `model.predict()`.
## With the `Sequential` class
In addition, `keras_core.Sequential` is a special case of model where
the model is purely a stack of single-input, single-output layers.
```python
model = keras_core.Sequential([
keras_core.Input(shape=(None, None, 3)),
keras_core.layers.Conv2D(filters=32, kernel_size=3),
])
```
"""
def __new__(cls, *args, **kwargs):

@ -12,6 +12,51 @@ from keras_core.utils import tracking
@keras_core_export(["keras_core.Sequential", "keras_core.models.Sequential"])
class Sequential(Model):
"""`Sequential` groups a linear stack of layers into a `Model`.
Examples:
```python
model = keras_core.Sequential()
model.add(keras_core.Input(shape=(16,)))
model.add(keras_core.layers.Dense(8))
# Note that you can also omit the initial `Input`.
# In that case the model doesn't have any weights until the first call
# to a training/evaluation method (since it isn't yet built):
model = keras_core.Sequential()
model.add(keras_core.layers.Dense(8))
model.add(keras_core.layers.Dense(4))
# model.weights not created yet
# Whereas if you specify an `Input`, the model gets built
# continuously as you are adding layers:
model = keras_core.Sequential()
model.add(keras_core.Input(shape=(16,)))
model.add(keras_core.layers.Dense(8))
len(model.weights) # Returns "2"
# When using the delayed-build pattern (no input shape specified), you can
# choose to manually build your model by calling
# `build(batch_input_shape)`:
model = keras_core.Sequential()
model.add(keras_core.layers.Dense(8))
model.add(keras_core.layers.Dense(4))
model.build((None, 16))
len(model.weights) # Returns "4"
# Note that when using the delayed-build pattern (no input shape specified),
# the model gets built the first time you call `fit`, `eval`, or `predict`,
# or the first time you call the model on some input data.
model = keras_core.Sequential()
model.add(keras_core.layers.Dense(8))
model.add(keras_core.layers.Dense(1))
model.compile(optimizer='sgd', loss='mse')
# This builds the model for the first time:
model.fit(x, y, batch_size=32, epochs=10)
```
"""
@tracking.no_automatic_dependency_tracking
def __init__(self, layers=None, trainable=True, name=None):
super().__init__(trainable=trainable, name=name)