提交 11b33344 编写于 作者: T TensorFlower Gardener

Merge pull request #30053 from yongtang:30040-BinaryCrossEntropy

PiperOrigin-RevId: 258598105
......@@ -4328,6 +4328,7 @@ def binary_crossentropy(target, output, from_logits=False):
if not from_logits:
if (isinstance(output, (ops.EagerTensor, variables_module.Variable)) or
output.op.type != 'Sigmoid'):
target.get_shape().assert_is_compatible_with(output.get_shape())
epsilon_ = _constant_to_tensor(epsilon(), output.dtype.base_dtype)
output = clip_ops.clip_by_value(output, epsilon_, 1. - epsilon_)
......
......@@ -846,7 +846,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
with self.cached_session():
with distribution.scope():
input_img = keras.layers.Input([64, 64, 3], name='img')
input_lbl = keras.layers.Input([64, 64, 1], name='lbl')
input_lbl = keras.layers.Input([64, 64, 2], name='lbl')
input_weight = keras.layers.Input([64, 64], name='weight')
predict = keras.layers.Conv2D(2, [1, 1], padding='same')(input_img)
loss_lambda = keras.layers.Lambda(
......@@ -864,7 +864,7 @@ class TestDistributionStrategyWithDatasets(test.TestCase,
return inputs, targets
fake_imgs = np.ones([50, 64, 64, 3], dtype=np.float32)
fake_lbls = np.ones([50, 64, 64, 1], dtype=np.float32)
fake_lbls = np.ones([50, 64, 64, 2], dtype=np.float32)
fake_weights = np.ones([50, 64, 64], dtype=np.float32)
data = dataset_ops.Dataset.from_tensor_slices(
......
......@@ -825,6 +825,18 @@ class BinaryCrossentropyTest(test.TestCase):
expected_value = (100.0 + 50.0 * label_smoothing) / 3.0
self.assertAlmostEqual(self.evaluate(loss), expected_value, 3)
def test_shape_mismatch(self):
y_true = np.array([[1.], [1.], [1.], [0.], [1.], [0.], [0.], [1.], [1.],
[0.]]).astype(np.float32)
y_pred = np.array([[0.], [0.], [0.], [1.], [1.], [0.], [0.], [1.], [0.],
[1.]]).astype(np.float32)
bce_obj = keras.losses.BinaryCrossentropy()
loss = bce_obj(y_true, y_pred)
self.assertAlmostEqual(self.evaluate(loss), 9.23662, 3)
with self.assertRaisesRegexp(ValueError, 'Shapes .+ are incompatible'):
loss = bce_obj(np.squeeze(y_true), y_pred)
@test_util.run_all_in_graph_and_eager_modes
class CategoricalCrossentropyTest(test.TestCase):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册