提交 e5cf807e 编写于 作者: P Pavithra Vijay 提交者: TensorFlower Gardener

Automated rollback of commit 11b33344. Revert #30053.

PiperOrigin-RevId: 258813889
上级 4f16da51
......@@ -4328,7 +4328,6 @@ def binary_crossentropy(target, output, from_logits=False):
if not from_logits:
if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
output.op.type != 'Sigmoid'):
target.get_shape().assert_is_compatible_with(output.get_shape())
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
......
......@@ -793,7 +793,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
with self.cached_session():
with distribution.scope():
input_img = keras.layers.Input([64, 64, 3], name='img')
input_lbl = keras.layers.Input([64, 64, 2], name='lbl')
input_lbl = keras.layers.Input([64, 64, 1], name='lbl')
input_weight = keras.layers.Input([64, 64], name='weight')
predict = keras.layers.Conv2D(2, [1, 1], padding='same')(input_img)
loss_lambda = keras.layers.Lambda(
......@@ -811,7 +811,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
return inputs, targets
fake_imgs = np.ones([50, 64, 64, 3], dtype=np.float32)
fake_lbls = np.ones([50, 64, 64, 2], dtype=np.float32)
fake_lbls = np.ones([50, 64, 64, 1], dtype=np.float32)
fake_weights = np.ones([50, 64, 64], dtype=np.float32)
data = dataset_ops.Dataset.from_tensor_slices(
......
......@@ -825,18 +825,6 @@ class BinaryCrossentropyTest(test.TestCase):
expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
def test_shape_mismatch(self):
y_true = np.array([[1.], [1.], [1.], [0.], [1.], [0.], [0.], [1.], [1.],
[0.]]).astype(np.float32)
y_pred = np.array([[0.], [0.], [0.], [1.], [1.], [0.], [0.], [1.], [0.],
[1.]]).astype(np.float32)
bce_obj = keras.losses.BinaryCrossentropy()
loss = bce_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), 9.23662, 3)
with self.assertRaisesRegexp(ValueError, 'Shapes .+ are incompatible'):
loss = bce_obj(np.squeeze(y_true), y_pred)
@test_util.run_all_in_graph_and_eager_modes
class CategoricalCrossentropyTest(test.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册