Bux fixes

This commit is contained in:
Francois Chollet 2023-05-30 15:55:13 -07:00
parent f7e73c8b9a
commit 429c2d67a4
7 changed files with 18 additions and 147 deletions

@ -205,35 +205,3 @@ def block_update(inputs, start_indices, updates):
]
inputs[tuple(slices)] = updates
return inputs
def while_loop(
cond,
body,
loop_vars,
maximum_iterations=None,
):
loop_vars = tuple(loop_vars)
if maximum_iterations is not None:
current_iter = 0
loop_vars = loop_vars + (current_iter,)
# Unpack list/tuple args. The last argument is `current_iter`.
def _cond(args):
return cond(*args[:-1]) & (args[-1] < maximum_iterations)
def _body(args):
return tuple(body(*args[:-1])) + (args[-1] + 1,)
else:
def _cond(args):
return cond(*args)
def _body(args):
return tuple(body(*args))
outputs = jax.lax.while_loop(_cond, _body, loop_vars)
if maximum_iterations is not None:
outputs = outputs[:-1]
return outputs

@ -111,17 +111,3 @@ def scatter_update(inputs, indices, updates):
def block_update(inputs, start_indices, updates):
return dynamic_update_slice(inputs, updates, start_indices)
def while_loop(
cond,
body,
loop_vars,
maximum_iterations=None,
):
return tf.while_loop(
cond,
body,
loop_vars,
maximum_iterations=maximum_iterations,
)

@ -130,23 +130,3 @@ def block_update(inputs, start_indices, updates):
]
inputs[slices] = updates
return inputs
def while_loop(
cond,
body,
loop_vars,
maximum_iterations=None,
):
current_iter = 0
iteration_check = (
lambda iter: maximum_iterations is None or iter < maximum_iterations
)
loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
while cond(*loop_vars) and iteration_check(current_iter):
loop_vars = body(*loop_vars)
if not isinstance(loop_vars, (list, tuple)):
loop_vars = (loop_vars,)
loop_vars = tuple(loop_vars)
current_iter += 1
return loop_vars

@ -881,6 +881,14 @@ class BinaryCrossentropyTest(testing.TestCase):
self.assertAlmostEqual(loss, 0.0)
def test_unweighted(self):
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32")
y_pred = np.array(
[[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype="float32"
)
bce_obj = losses.BinaryCrossentropy()
loss = bce_obj(y_true, y_pred)
self.assertAllClose(loss, 0.20046903)
y_true = np.array([1, 0, 1, 0]).reshape([2, 2])
y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])
bce_obj = losses.BinaryCrossentropy()

@ -57,12 +57,15 @@ class Accuracy(reduction_metrics.MeanMetricWrapper):
return {"name": self.name, "dtype": self.dtype}
@keras_core_export("keras_core.metrics.binary_accuracy")
def binary_accuracy(y_true, y_pred, threshold=0.5):
y_true = ops.convert_to_tensor(y_true)
y_pred = ops.convert_to_tensor(y_pred)
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
threshold = ops.cast(threshold, y_pred.dtype)
y_pred = ops.cast(y_pred > threshold, y_pred.dtype)
y_pred = ops.cast(y_pred > threshold, y_true.dtype)
return ops.mean(
ops.cast(ops.equal(y_true, y_pred), backend.floatx()),
ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()),
axis=-1,
)
@ -114,12 +117,14 @@ class BinaryAccuracy(reduction_metrics.MeanMetricWrapper):
return {"name": self.name, "dtype": self.dtype}
@keras_core_export("keras_core.metrics.categorical_accuracy")
def categorical_accuracy(y_true, y_pred):
y_true = ops.argmax(y_true, axis=-1)
reshape_matches = False
y_pred = ops.convert_to_tensor(y_pred)
y_true = ops.convert_to_tensor(y_true, dtype=y_true.dtype)
y_true_org_shape = ops.shape(y_true)
y_pred_rank = len(y_pred.shape)
y_true_rank = len(y_true.shape)
@ -198,6 +203,7 @@ class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
return {"name": self.name, "dtype": self.dtype}
@keras_core_export("keras_core.metrics.sparse_categorical_accuracy")
def sparse_categorical_accuracy(y_true, y_pred):
reshape_matches = False
y_pred = ops.convert_to_tensor(y_pred)
@ -281,6 +287,7 @@ class SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
return {"name": self.name, "dtype": self.dtype}
@keras_core_export("keras_core.metrics.top_k_categorical_accuracy")
def top_k_categorical_accuracy(y_true, y_pred, k=5):
reshape_matches = False
y_pred = ops.convert_to_tensor(y_pred)
@ -357,6 +364,7 @@ class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
return {"name": self.name, "dtype": self.dtype, "k": self.k}
@keras_core_export("keras_core.metrics.sparse_top_k_categorical_accuracy")
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
reshape_matches = False
y_pred = ops.convert_to_tensor(y_pred)

@ -126,63 +126,3 @@ def block_update(inputs, start_indices, updates):
if any_symbolic_tensors((inputs, start_indices, updates)):
return BlockUpdate().symbolic_call(inputs, start_indices, updates)
return backend.core.block_update(inputs, start_indices, updates)
class WhileLoop(Operation):
def __init__(self, cond, body, maximum_iterations):
super().__init__()
self.cond = cond
self.body = body
self.maximum_iterations = maximum_iterations
def call(self, loop_vars):
return backend.core.while_loop(
self.cond,
self.body,
loop_vars,
maximum_iterations=self.maximum_iterations,
)
def compute_output_spec(self, loop_vars):
return [KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars]
@keras_core_export("keras_core.operations.while_loop")
def while_loop(
cond,
body,
loop_vars,
maximum_iterations=None,
):
"""While loop implemetation.
Args:
cond: A callable that represents the termination condition of the loop.
Must have the same number of args as `loop_vars`, and return a bool.
body: A callable that represents the loop body. Must have the same
number of args as `loop_vars`, and return a list/tuple of the same
length, shape and dtype as `loop_vars`.
loop_vars: A list/tuple of tensors, the loop variables.
maximum_iterations: Optional maximum number of iterations of the while
loop to run. If provided, the `cond` output is AND-ed with an
additional condition ensuring the number of iterations executed is
no greater than `maximum_iterations`.
Returns:
A list/tuple of tensors, has the same shape and dtype as `inputs`.
Examples:
>>> i = 0
>>> cond = lambda i: i < 10
>>> body = lambda i: i + 1
>>> keras_core.operations.while_loop(cond, body, [i])[0]
10
"""
return backend.core.while_loop(
cond,
body,
loop_vars,
maximum_iterations=maximum_iterations,
)

@ -139,22 +139,3 @@ class CoreOpsCorrectnessTest(testing.TestCase):
updates = np.zeros([2, 2, 2, 2])
outputs = core.block_update(inputs, start_indices, updates)
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2]))
def test_while_loop(self):
def cond(x, y):
return x[0, 0] < 10
def body(x, y):
return x + 1, y + 1
x = np.ones((2, 3))
y = np.ones((3, 2))
x, y = core.while_loop(cond, body, (x, y))
self.assertAllClose(x, np.ones((2, 3)) * 10)
self.assertAllClose(y, np.ones((3, 2)) * 10)
x = np.ones((2, 3))
y = np.ones((3, 2))
x, y = core.while_loop(cond, body, (x, y), maximum_iterations=5)
self.assertAllClose(x, np.ones((2, 3)) * 6)
self.assertAllClose(y, np.ones((3, 2)) * 6)