未验证 提交 d094cd02 编写于 作者: C chajchaj 提交者: GitHub

change shape of output in cross_entropy (#29414)

上级 401cc1e0
......@@ -219,6 +219,47 @@ class CrossEntropyLoss(unittest.TestCase):
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_1d_with_weight_none_func(self):
input_np = np.random.random([100, 200]).astype(np.float64) #N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) #N
weight_np = np.random.random([200]).astype(np.float64) #C
paddle.enable_static()
prog = fluid.Program()
startup_prog = fluid.Program()
place = fluid.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.CPUPlace()
with fluid.program_guard(prog, startup_prog):
input = fluid.data(name='input', shape=[100, 200], dtype='float64')
label = fluid.data(name='label', shape=[100], dtype='int64')
weight = fluid.data(name='weight', shape=[200], dtype='float64')
ret = paddle.nn.functional.cross_entropy(
input, label, weight=weight, reduction='none')
exe = fluid.Executor(place)
static_ret = exe.run(prog,
feed={
'input': input_np,
'label': label_np,
"weight": weight_np
},
fetch_list=[ret])
static_ret = np.squeeze(static_ret)
self.assertIsNotNone(static_ret)
with fluid.dygraph.guard():
dy_ret = paddle.nn.functional.cross_entropy(
fluid.dygraph.to_variable(input_np),
fluid.dygraph.to_variable(label_np),
weight=fluid.dygraph.to_variable(weight_np),
reduction='none')
dy_ret_value = dy_ret.numpy()
dy_ret_value = np.squeeze(dy_ret_value)
self.assertIsNotNone(dy_ret_value)
expected = cross_entropy_loss_1d(
input_np, label_np, weight=weight_np, reduction='none')
self.assertTrue(np.allclose(static_ret, dy_ret_value))
self.assertTrue(np.allclose(static_ret, expected))
self.assertTrue(np.allclose(dy_ret_value, expected))
def test_cross_entropy_loss_1d_mean(self):
input_np = np.random.random([100, 200]).astype(np.float64) #N,C
label_np = np.random.randint(0, 100, size=(100)).astype(np.int64) #N,1
......
......@@ -1236,6 +1236,8 @@ def cross_entropy(input,
else:
return core.ops.mean(out)
else:
if input_dims - 1 == label_dims:
out = paddle.squeeze(out, axis=axis)
return out
fluid.data_feeder.check_variable_and_dtype(
......@@ -1267,6 +1269,8 @@ def cross_entropy(input,
else:
return paddle.mean(out, name=name)
else:
if input_dims - 1 == label_dims:
out = paddle.squeeze(out, axis=axis)
return out
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册