From 1d660eb6767b990f8a5760e7b766a880f88d2d03 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Mon, 11 Oct 2021 17:42:25 +0800 Subject: [PATCH] Fix the bug when axis is specified and weight is provided --- .../unittests/test_cross_entropy_loss.py | 48 +++++++++++++++++++ python/paddle/nn/functional/loss.py | 46 +++++++++++------- 2 files changed, 78 insertions(+), 16 deletions(-) diff --git a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py index d2eae1cce5..6a0d955040 100644 --- a/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py +++ b/python/paddle/fluid/tests/unittests/test_cross_entropy_loss.py @@ -1175,6 +1175,54 @@ class CrossEntropyLoss(unittest.TestCase): self.assertTrue(np.allclose(static_ret, expected)) self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_2d_with_weight_axis_change_mean(self): + input_np = np.random.random(size=(2, 3, 2, 2)).astype(self.dtype) #NCHW + label_np = np.random.randint( + 0, 3, size=(2, 2, 2)).astype(np.int64) #NHW + weight_np = np.random.random(size=(3, )).astype(self.dtype) #C + + paddle.enable_static() + prog = fluid.Program() + startup_prog = fluid.Program() + place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda( + ) else fluid.CPUPlace() + with fluid.program_guard(prog, startup_prog): + input = fluid.data( + name='input', shape=[2, 3, 2, 2], dtype=self.dtype) + label = fluid.data(name='label', shape=[2, 2, 2], dtype='int64') + weight = fluid.data(name='weight', shape=[3], dtype=self.dtype) + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=weight, reduction='mean', axis=1) + # specify the class channels to axis 1 + ret = cross_entropy_loss(input, label) + + exe = fluid.Executor(place) + static_ret = exe.run(prog, + feed={ + 'input': input_np, + 'label': label_np, + "weight": weight_np + }, + fetch_list=[ret]) + + self.assertIsNotNone(static_ret) + with fluid.dygraph.guard(): + cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss( + weight=fluid.dygraph.to_variable(weight_np), reduction='mean') + dy_ret = cross_entropy_loss( + fluid.dygraph.to_variable(input_np), + fluid.dygraph.to_variable(label_np)) + dy_ret_value = dy_ret.numpy() + self.assertIsNotNone(dy_ret_value) + expected = cross_entropy_loss_2d( + np.transpose(input_np, [0, 2, 3, 1]), + label_np, + weight=weight_np, + reduction='mean')[0] + self.assertTrue(np.allclose(static_ret, dy_ret_value)) + self.assertTrue(np.allclose(static_ret, expected)) + self.assertTrue(np.allclose(dy_ret_value, expected)) + def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self): N = 4 C = 3 diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index da2d010c32..f4e8711a23 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1700,19 +1700,26 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) else: - if input.shape[-1] != weight.shape[-1]: + if input.shape[axis] != weight.shape[-1]: raise ValueError( - "input's class_dimension({}) must equal to \ - weight's class_dimension({}) \ - when weight is provided" - .format(input.shape[-1], weight.shape[-1])) + "input's class_dimension({}) must equal to " + "weight's class_dimension({}) " + "when weight is provided"\ + .format(input.shape[axis], weight.shape[-1])) ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ - -1] == 1: - ignore_weight_mask.squeeze_(-1) - weight_gather = _C_ops.gather_nd(weight, valid_label) + axis] == 1: + ignore_weight_mask.squeeze_(axis) + if axis != -1: + temp_perm = list(range(axis % valid_label.ndim)) \ + + list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \ + + [axis%valid_label.ndim] + weight_gather = _C_ops.gather_nd( + weight, valid_label.transpose(temp_perm)) + else: + weight_gather = _C_ops.gather_nd(weight, valid_label) weight_gather = _C_ops.elementwise_mul(weight_gather, ignore_weight_mask) input_shape = list(label.shape) @@ -1807,20 +1814,27 @@ def cross_entropy(input, weight_gather_reshape = reshape(weight_gather, shape=out_shape) out = paddle.cast(out, weight_gather_reshape.dtype) else: - if input.shape[-1] != weight.shape[-1]: - raise ValueError("input's class_dimension({}) must equal to "\ - "weight's class_dimension({}) "\ - "when weight is provided" - .format(input.shape[-1], weight.shape[-1])) + if input.shape[axis] != weight.shape[-1]: + raise ValueError("input's class_dimension({}) must equal to " + "weight's class_dimension({}) " + "when weight is provided"\ + .format(input.shape[axis], weight.shape[-1])) valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) ignore_weight_mask = paddle.cast((label != ignore_index), input.dtype) if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[ - -1] == 1: - ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1) - weight_gather = paddle.gather_nd(weight, valid_label) + axis] == 1: + ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis) + if axis != -1: + temp_perm = list(range(axis % valid_label.ndim)) \ + + list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \ + + [axis % valid_label.ndim] + weight_gather = paddle.gather_nd( + weight, paddle.transpose(valid_label, temp_perm)) + else: + weight_gather = paddle.gather_nd(weight, valid_label) weight_gather = paddle.multiply(weight_gather, ignore_weight_mask) input_shape = list(label.shape) -- GitLab