Commit 1d660eb6 authored by HydrogenSulfate, committed by chajchaj

Fix the bug when axis is specified and weight is provided

Parent 8cc7146d
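
When cross_entropy is given a per-class weight tensor, the weights are looked up with gather_nd, which expects the class-index dimension of the label to be the last dimension. The old code hard-coded that assumption (shape[-1], squeeze(-1)), so the weighted path misbehaved whenever a non-default axis was specified. The diff below builds a permutation, temp_perm, that moves the label dimension at axis to the end before the gather, and replaces the hard-coded -1 with axis in the shape check and in the squeeze of ignore_weight_mask. A minimal numpy-only sketch of the permutation idea (shapes borrowed from the new test; the fancy indexing stands in for gather_nd and is illustration only, not the shipped implementation):

    # Illustration only: numpy stand-in for the gather_nd path in the diff.
    import numpy as np

    axis = 1
    # the label carries a size-1 class-index dimension at `axis` (N, 1, H, W here)
    label = np.random.randint(0, 3, size=(2, 1, 2, 2))
    weight = np.random.random(size=(3, ))        # one weight per class, C = 3

    ndim = label.ndim
    temp_perm = list(range(axis % ndim)) \
        + list(range((axis + 1) % ndim, ndim)) \
        + [axis % ndim]                          # [0, 2, 3, 1]: class index moved last

    label_t = np.transpose(label, temp_perm)     # shape (2, 2, 2, 1)
    weight_gather = weight[label_t[..., 0]]      # numpy equivalent of the gather_nd lookup
    print(weight_gather.shape)                   # (2, 2, 2): one weight per position

When axis is -1 the class-index dimension is already last, so the diff keeps the original gather_nd(weight, valid_label) call in the else branch.
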
...@@ -1175,6 +1175,54 @@ class CrossEntropyLoss(unittest.TestCase):
             self.assertTrue(np.allclose(static_ret, expected))
             self.assertTrue(np.allclose(dy_ret_value, expected))
+
+    def test_cross_entropy_loss_2d_with_weight_axis_change_mean(self):
+        input_np = np.random.random(size=(2, 3, 2, 2)).astype(self.dtype)  #NCHW
+        label_np = np.random.randint(
+            0, 3, size=(2, 2, 2)).astype(np.int64)  #NHW
+        weight_np = np.random.random(size=(3, )).astype(self.dtype)  #C
+        paddle.enable_static()
+        prog = fluid.Program()
+        startup_prog = fluid.Program()
+        place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
+        ) else fluid.CPUPlace()
+        with fluid.program_guard(prog, startup_prog):
+            input = fluid.data(
+                name='input', shape=[2, 3, 2, 2], dtype=self.dtype)
+            label = fluid.data(name='label', shape=[2, 2, 2], dtype='int64')
+            weight = fluid.data(name='weight', shape=[3], dtype=self.dtype)
+            # the class dimension is specified as axis 1
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                weight=weight, reduction='mean', axis=1)
+            ret = cross_entropy_loss(input, label)
+            exe = fluid.Executor(place)
+            static_ret = exe.run(prog,
+                                 feed={
+                                     'input': input_np,
+                                     'label': label_np,
+                                     "weight": weight_np
+                                 },
+                                 fetch_list=[ret])
+            self.assertIsNotNone(static_ret)
+        with fluid.dygraph.guard():
+            cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
+                weight=fluid.dygraph.to_variable(weight_np),
+                reduction='mean',
+                axis=1)
+            dy_ret = cross_entropy_loss(
+                fluid.dygraph.to_variable(input_np),
+                fluid.dygraph.to_variable(label_np))
+            dy_ret_value = dy_ret.numpy()
+            self.assertIsNotNone(dy_ret_value)
+        expected = cross_entropy_loss_2d(
+            np.transpose(input_np, [0, 2, 3, 1]),
+            label_np,
+            weight=weight_np,
+            reduction='mean')[0]
+        self.assertTrue(np.allclose(static_ret, dy_ret_value))
+        self.assertTrue(np.allclose(static_ret, expected))
+        self.assertTrue(np.allclose(dy_ret_value, expected))

     def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self):
         N = 4
         C = 3
......
...@@ -1700,18 +1700,25 @@ def cross_entropy(input,
                 out = _C_ops.elementwise_mul(out, weight_gather_reshape)
             else:
-                if input.shape[-1] != weight.shape[-1]:
+                if input.shape[axis] != weight.shape[-1]:
                     raise ValueError(
-                        "input's class_dimension({}) must equal to \
-                        weight's class_dimension({}) \
-                            when weight is provided"
-                        .format(input.shape[-1], weight.shape[-1]))
+                        "input's class_dimension({}) must equal to "
+                        "weight's class_dimension({}) "
+                        "when weight is provided"\
+                        .format(input.shape[axis], weight.shape[-1]))
                 ignore_weight_mask = paddle.cast((label != ignore_index),
                                                  out.dtype)
                 if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-                        -1] == 1:
-                    ignore_weight_mask.squeeze_(-1)
-                weight_gather = _C_ops.gather_nd(weight, valid_label)
+                        axis] == 1:
+                    ignore_weight_mask.squeeze_(axis)
+                if axis != -1:
+                    temp_perm = list(range(axis % valid_label.ndim)) \
+                        + list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \
+                        + [axis % valid_label.ndim]
+                    weight_gather = _C_ops.gather_nd(
+                        weight, valid_label.transpose(temp_perm))
+                else:
+                    weight_gather = _C_ops.gather_nd(weight, valid_label)
                 weight_gather = _C_ops.elementwise_mul(weight_gather,
                                                        ignore_weight_mask)
...@@ -1807,19 +1814,26 @@ def cross_entropy(input,
             weight_gather_reshape = reshape(weight_gather, shape=out_shape)
             out = paddle.cast(out, weight_gather_reshape.dtype)
         else:
-            if input.shape[-1] != weight.shape[-1]:
-                raise ValueError("input's class_dimension({}) must equal to "\
-                                 "weight's class_dimension({}) "\
-                                 "when weight is provided"
-                                 .format(input.shape[-1], weight.shape[-1]))
+            if input.shape[axis] != weight.shape[-1]:
+                raise ValueError("input's class_dimension({}) must equal to "
+                                 "weight's class_dimension({}) "
+                                 "when weight is provided"\
+                                 .format(input.shape[axis], weight.shape[-1]))
             valid_label = paddle.where(label == ignore_index,
                                        paddle.zeros_like(label), label)
             ignore_weight_mask = paddle.cast((label != ignore_index),
                                              input.dtype)
             if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-                    -1] == 1:
-                ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1)
-            weight_gather = paddle.gather_nd(weight, valid_label)
+                    axis] == 1:
+                ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
+            if axis != -1:
+                temp_perm = list(range(axis % valid_label.ndim)) \
+                    + list(range((axis + 1) % valid_label.ndim, valid_label.ndim)) \
+                    + [axis % valid_label.ndim]
+                weight_gather = paddle.gather_nd(
+                    weight, paddle.transpose(valid_label, temp_perm))
+            else:
+                weight_gather = paddle.gather_nd(weight, valid_label)
             weight_gather = paddle.multiply(weight_gather, ignore_weight_mask)
......
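
After this change, calling the loss with per-class weights and a non-default class axis, as the new test does through paddle.nn.loss.CrossEntropyLoss (which forwards to the functional cross_entropy edited above), goes through the transposed gather path. A hedged usage sketch with the same shapes as the test:

    import numpy as np
    import paddle
    import paddle.nn.functional as F

    x = paddle.to_tensor(np.random.random((2, 3, 2, 2)).astype('float64'))    # NCHW logits
    y = paddle.to_tensor(np.random.randint(0, 3, (2, 2, 2)).astype('int64'))  # NHW hard labels
    w = paddle.to_tensor(np.random.random((3, )).astype('float64'))           # per-class weights

    # the class dimension is axis 1; each spatial position is weighted by its class
    loss = F.cross_entropy(x, y, weight=w, axis=1, reduction='mean')
    print(loss.numpy())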