未验证 提交 32fe5a49 编写于 作者: H HydrogenSulfate 提交者: GitHub

cherry pick CrossEntropy's bug fix (#36647)

上级 d2be870a
......@@ -1175,6 +1175,56 @@ class CrossEntropyLoss(unittest.TestCase):
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_2d_with_weight_axis_change_mean(self):
input_np = np.random.random(size=(2, 3, 2, 2)).astype(self.dtype) #NCHW
label_np = np.random.randint(
0, 3, size=(2, 2, 2)).astype(np.int64) #NHW
weight_np = np.random.random(size=(3, )).astype(self.dtype) #C
paddle.enable_static()
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(
name='input', shape=[2, 3, 2, 2], dtype=self.dtype)
label = fluid.data(name='label', shape=[2, 2, 2], dtype='int64')
weight = fluid.data(name='weight', shape=[3], dtype=self.dtype)
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
weight=weight, reduction='mean', axis=1)
# specify the class channels to axis 1
ret = cross_entropy_loss(input, label)
exe = fluid.Executor(place)
static_ret = exe.run(prog,
feed={
'input': input_np,
'label': label_np,
"weight": weight_np
},
fetch_list=[ret])
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
cross_entropy_loss = paddle.nn.loss.CrossEntropyLoss(
weight=fluid.dygraph.to_variable(weight_np),
reduction='mean',
axis=1)
dy_ret = cross_entropy_loss(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np))
dy_ret_value = dy_ret.numpy()
self.assertIsNotNone(dy_ret_value)
expected = cross_entropy_loss_2d(
np.transpose(input_np, [0, 2, 3, 1]),
label_np,
weight=weight_np,
reduction='mean')[0]
self.assertTrue(np.allclose(static_ret, dy_ret_value))
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_2d_with_weight_mean_ignore_exceedlabel(self):
N = 4
C = 3
......
......@@ -1668,12 +1668,13 @@ def cross_entropy(input,
format(invalid_label[0], 0))
# TODO: Temporarily use paddle.nonzero instead of paddle.max
# to detect and find out possible illegal label values
if len(paddle.nonzero(valid_label >= input.shape[-1])) > 0:
if len(paddle.nonzero(valid_label >= input.shape[axis])) > 0:
invalid_label = paddle.gather_nd(
valid_label, paddle.nonzero(valid_label >= input.shape[-1]))
valid_label,
paddle.nonzero(valid_label >= input.shape[axis]))
raise ValueError(
"Target({}) is out of class_dimension's upper bound({})".
format(invalid_label[0], input.shape[-1] - 1))
format(invalid_label[0], input.shape[axis] - 1))
_, out = _C_ops.softmax_with_cross_entropy(
input, label, 'soft_label', soft_label, 'ignore_index',
......@@ -1700,19 +1701,28 @@ def cross_entropy(input,
out = _C_ops.elementwise_mul(out, weight_gather_reshape)
else:
if input.shape[-1] != weight.shape[-1]:
if input.shape[axis] != weight.shape[-1]:
raise ValueError(
"input's class_dimension({}) must equal to \
weight's class_dimension({}) \
when weight is provided"
.format(input.shape[-1], weight.shape[-1]))
"input's class_dimension({}) must equal to "
"weight's class_dimension({}) "
"when weight is provided"\
.format(input.shape[axis], weight.shape[-1]))
ignore_weight_mask = paddle.cast((label != ignore_index),
out.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-1] == 1:
ignore_weight_mask.squeeze_(-1)
weight_gather = _C_ops.gather_nd(weight, valid_label)
axis] == 1:
# TODO: Temporarily use squeeze instead of squeeze_
ignore_weight_mask = paddle.squeeze(ignore_weight_mask,
axis)
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1) , valid_label.ndim)) \
+ [axis % valid_label.ndim]
weight_gather = _C_ops.gather_nd(
weight, valid_label.transpose(temp_perm))
else:
weight_gather = _C_ops.gather_nd(weight, valid_label)
weight_gather = _C_ops.elementwise_mul(weight_gather,
ignore_weight_mask)
input_shape = list(label.shape)
......@@ -1807,20 +1817,27 @@ def cross_entropy(input,
weight_gather_reshape = reshape(weight_gather, shape=out_shape)
out = paddle.cast(out, weight_gather_reshape.dtype)
else:
if input.shape[-1] != weight.shape[-1]:
raise ValueError("input's class_dimension({}) must equal to "\
"weight's class_dimension({}) "\
"when weight is provided"
.format(input.shape[-1], weight.shape[-1]))
if input.shape[axis] != weight.shape[-1]:
raise ValueError("input's class_dimension({}) must equal to "
"weight's class_dimension({}) "
"when weight is provided"\
.format(input.shape[axis], weight.shape[-1]))
valid_label = paddle.where(label == ignore_index,
paddle.zeros_like(label), label)
ignore_weight_mask = paddle.cast((label != ignore_index),
input.dtype)
if ignore_weight_mask.ndim > 1 and ignore_weight_mask.shape[
-1] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, -1)
weight_gather = paddle.gather_nd(weight, valid_label)
axis] == 1:
ignore_weight_mask = paddle.squeeze(ignore_weight_mask, axis)
if axis != -1 and axis != valid_label.ndim - 1:
temp_perm = list(range(axis % valid_label.ndim)) \
+ list(range((axis % valid_label.ndim + 1), valid_label.ndim)) \
+ [axis % valid_label.ndim]
weight_gather = paddle.gather_nd(
weight, paddle.transpose(valid_label, temp_perm))
else:
weight_gather = paddle.gather_nd(weight, valid_label)
weight_gather = paddle.multiply(weight_gather, ignore_weight_mask)
input_shape = list(label.shape)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册