Bux fixes
This commit is contained in:
parent
f7e73c8b9a
commit
429c2d67a4
@ -205,35 +205,3 @@ def block_update(inputs, start_indices, updates):
|
||||
]
|
||||
inputs[tuple(slices)] = updates
|
||||
return inputs
|
||||
|
||||
|
||||
def while_loop(
|
||||
cond,
|
||||
body,
|
||||
loop_vars,
|
||||
maximum_iterations=None,
|
||||
):
|
||||
loop_vars = tuple(loop_vars)
|
||||
if maximum_iterations is not None:
|
||||
current_iter = 0
|
||||
loop_vars = loop_vars + (current_iter,)
|
||||
|
||||
# Unpack list/tuple args. The last argument is `current_iter`.
|
||||
def _cond(args):
|
||||
return cond(*args[:-1]) & (args[-1] < maximum_iterations)
|
||||
|
||||
def _body(args):
|
||||
return tuple(body(*args[:-1])) + (args[-1] + 1,)
|
||||
|
||||
else:
|
||||
|
||||
def _cond(args):
|
||||
return cond(*args)
|
||||
|
||||
def _body(args):
|
||||
return tuple(body(*args))
|
||||
|
||||
outputs = jax.lax.while_loop(_cond, _body, loop_vars)
|
||||
if maximum_iterations is not None:
|
||||
outputs = outputs[:-1]
|
||||
return outputs
|
||||
|
@ -111,17 +111,3 @@ def scatter_update(inputs, indices, updates):
|
||||
|
||||
def block_update(inputs, start_indices, updates):
|
||||
return dynamic_update_slice(inputs, updates, start_indices)
|
||||
|
||||
|
||||
def while_loop(
|
||||
cond,
|
||||
body,
|
||||
loop_vars,
|
||||
maximum_iterations=None,
|
||||
):
|
||||
return tf.while_loop(
|
||||
cond,
|
||||
body,
|
||||
loop_vars,
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
@ -130,23 +130,3 @@ def block_update(inputs, start_indices, updates):
|
||||
]
|
||||
inputs[slices] = updates
|
||||
return inputs
|
||||
|
||||
|
||||
def while_loop(
|
||||
cond,
|
||||
body,
|
||||
loop_vars,
|
||||
maximum_iterations=None,
|
||||
):
|
||||
current_iter = 0
|
||||
iteration_check = (
|
||||
lambda iter: maximum_iterations is None or iter < maximum_iterations
|
||||
)
|
||||
loop_vars = tuple([convert_to_tensor(v) for v in loop_vars])
|
||||
while cond(*loop_vars) and iteration_check(current_iter):
|
||||
loop_vars = body(*loop_vars)
|
||||
if not isinstance(loop_vars, (list, tuple)):
|
||||
loop_vars = (loop_vars,)
|
||||
loop_vars = tuple(loop_vars)
|
||||
current_iter += 1
|
||||
return loop_vars
|
||||
|
@ -881,6 +881,14 @@ class BinaryCrossentropyTest(testing.TestCase):
|
||||
self.assertAlmostEqual(loss, 0.0)
|
||||
|
||||
def test_unweighted(self):
|
||||
y_true = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype="float32")
|
||||
y_pred = np.array(
|
||||
[[0.9, 0.1, 0.2], [0.3, 0.8, 0.1], [0.1, 0.2, 0.7]], dtype="float32"
|
||||
)
|
||||
bce_obj = losses.BinaryCrossentropy()
|
||||
loss = bce_obj(y_true, y_pred)
|
||||
self.assertAllClose(loss, 0.20046903)
|
||||
|
||||
y_true = np.array([1, 0, 1, 0]).reshape([2, 2])
|
||||
y_pred = np.array([1, 1, 1, 0], dtype=np.float32).reshape([2, 2])
|
||||
bce_obj = losses.BinaryCrossentropy()
|
||||
|
@ -57,12 +57,15 @@ class Accuracy(reduction_metrics.MeanMetricWrapper):
|
||||
return {"name": self.name, "dtype": self.dtype}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.binary_accuracy")
|
||||
def binary_accuracy(y_true, y_pred, threshold=0.5):
|
||||
y_true = ops.convert_to_tensor(y_true)
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
y_true, y_pred = squeeze_to_same_rank(y_true, y_pred)
|
||||
threshold = ops.cast(threshold, y_pred.dtype)
|
||||
y_pred = ops.cast(y_pred > threshold, y_pred.dtype)
|
||||
y_pred = ops.cast(y_pred > threshold, y_true.dtype)
|
||||
return ops.mean(
|
||||
ops.cast(ops.equal(y_true, y_pred), backend.floatx()),
|
||||
ops.cast(ops.equal(y_true, y_pred), dtype=backend.floatx()),
|
||||
axis=-1,
|
||||
)
|
||||
|
||||
@ -114,12 +117,14 @@ class BinaryAccuracy(reduction_metrics.MeanMetricWrapper):
|
||||
return {"name": self.name, "dtype": self.dtype}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.categorical_accuracy")
|
||||
def categorical_accuracy(y_true, y_pred):
|
||||
y_true = ops.argmax(y_true, axis=-1)
|
||||
|
||||
reshape_matches = False
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
y_true = ops.convert_to_tensor(y_true, dtype=y_true.dtype)
|
||||
|
||||
y_true_org_shape = ops.shape(y_true)
|
||||
y_pred_rank = len(y_pred.shape)
|
||||
y_true_rank = len(y_true.shape)
|
||||
@ -198,6 +203,7 @@ class CategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
|
||||
return {"name": self.name, "dtype": self.dtype}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.sparse_categorical_accuracy")
|
||||
def sparse_categorical_accuracy(y_true, y_pred):
|
||||
reshape_matches = False
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
@ -281,6 +287,7 @@ class SparseCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
|
||||
return {"name": self.name, "dtype": self.dtype}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.top_k_categorical_accuracy")
|
||||
def top_k_categorical_accuracy(y_true, y_pred, k=5):
|
||||
reshape_matches = False
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
@ -357,6 +364,7 @@ class TopKCategoricalAccuracy(reduction_metrics.MeanMetricWrapper):
|
||||
return {"name": self.name, "dtype": self.dtype, "k": self.k}
|
||||
|
||||
|
||||
@keras_core_export("keras_core.metrics.sparse_top_k_categorical_accuracy")
|
||||
def sparse_top_k_categorical_accuracy(y_true, y_pred, k=5):
|
||||
reshape_matches = False
|
||||
y_pred = ops.convert_to_tensor(y_pred)
|
||||
|
@ -126,63 +126,3 @@ def block_update(inputs, start_indices, updates):
|
||||
if any_symbolic_tensors((inputs, start_indices, updates)):
|
||||
return BlockUpdate().symbolic_call(inputs, start_indices, updates)
|
||||
return backend.core.block_update(inputs, start_indices, updates)
|
||||
|
||||
|
||||
class WhileLoop(Operation):
|
||||
def __init__(self, cond, body, maximum_iterations):
|
||||
super().__init__()
|
||||
self.cond = cond
|
||||
self.body = body
|
||||
self.maximum_iterations = maximum_iterations
|
||||
|
||||
def call(self, loop_vars):
|
||||
return backend.core.while_loop(
|
||||
self.cond,
|
||||
self.body,
|
||||
loop_vars,
|
||||
maximum_iterations=self.maximum_iterations,
|
||||
)
|
||||
|
||||
def compute_output_spec(self, loop_vars):
|
||||
return [KerasTensor(v.shape, dtype=v.dtype) for v in loop_vars]
|
||||
|
||||
|
||||
@keras_core_export("keras_core.operations.while_loop")
|
||||
def while_loop(
|
||||
cond,
|
||||
body,
|
||||
loop_vars,
|
||||
maximum_iterations=None,
|
||||
):
|
||||
"""While loop implemetation.
|
||||
|
||||
Args:
|
||||
cond: A callable that represents the termination condition of the loop.
|
||||
Must have the same number of args as `loop_vars`, and return a bool.
|
||||
body: A callable that represents the loop body. Must have the same
|
||||
number of args as `loop_vars`, and return a list/tuple of the same
|
||||
length, shape and dtype as `loop_vars`.
|
||||
loop_vars: A list/tuple of tensors, the loop variables.
|
||||
maximum_iterations: Optional maximum number of iterations of the while
|
||||
loop to run. If provided, the `cond` output is AND-ed with an
|
||||
additional condition ensuring the number of iterations executed is
|
||||
no greater than `maximum_iterations`.
|
||||
|
||||
Returns:
|
||||
A list/tuple of tensors, has the same shape and dtype as `inputs`.
|
||||
|
||||
Examples:
|
||||
|
||||
>>> i = 0
|
||||
>>> cond = lambda i: i < 10
|
||||
>>> body = lambda i: i + 1
|
||||
>>> keras_core.operations.while_loop(cond, body, [i])[0]
|
||||
10
|
||||
"""
|
||||
|
||||
return backend.core.while_loop(
|
||||
cond,
|
||||
body,
|
||||
loop_vars,
|
||||
maximum_iterations=maximum_iterations,
|
||||
)
|
||||
|
@ -139,22 +139,3 @@ class CoreOpsCorrectnessTest(testing.TestCase):
|
||||
updates = np.zeros([2, 2, 2, 2])
|
||||
outputs = core.block_update(inputs, start_indices, updates)
|
||||
self.assertAllClose(outputs[1:3, 1:3, 2:4, 2:4], np.zeros([2, 2, 2, 2]))
|
||||
|
||||
def test_while_loop(self):
|
||||
def cond(x, y):
|
||||
return x[0, 0] < 10
|
||||
|
||||
def body(x, y):
|
||||
return x + 1, y + 1
|
||||
|
||||
x = np.ones((2, 3))
|
||||
y = np.ones((3, 2))
|
||||
x, y = core.while_loop(cond, body, (x, y))
|
||||
self.assertAllClose(x, np.ones((2, 3)) * 10)
|
||||
self.assertAllClose(y, np.ones((3, 2)) * 10)
|
||||
|
||||
x = np.ones((2, 3))
|
||||
y = np.ones((3, 2))
|
||||
x, y = core.while_loop(cond, body, (x, y), maximum_iterations=5)
|
||||
self.assertAllClose(x, np.ones((2, 3)) * 6)
|
||||
self.assertAllClose(y, np.ones((3, 2)) * 6)
|
||||
|
Loading…
Reference in New Issue
Block a user