From 1e3e17df33654a2f291bbf3daf9e2d40b07dd967 Mon Sep 17 00:00:00 2001 From: HydrogenSulfate <490868991@qq.com> Date: Sun, 26 Dec 2021 13:46:08 +0800 Subject: [PATCH] Remove the labels range check under the dynamic graph --- python/paddle/nn/functional/loss.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/python/paddle/nn/functional/loss.py b/python/paddle/nn/functional/loss.py index 8eb6e05fc04..f13f14cdde1 100755 --- a/python/paddle/nn/functional/loss.py +++ b/python/paddle/nn/functional/loss.py @@ -1696,6 +1696,13 @@ def cross_entropy(input, out = _C_ops.elementwise_mul(out, weight_gather_reshape) else: + if input.shape[axis] != weight.shape[-1]: + raise ValueError( + "input's class_dimension({}) must equal to " + "weight's class_dimension({}) " + "when weight is provided" \ + .format(input.shape[axis], weight.shape[-1])) + valid_label = paddle.where(label == ignore_index, paddle.zeros_like(label), label) # TODO: Temporarily use paddle.nonzero instead of paddle.max @@ -1715,12 +1722,6 @@ def cross_entropy(input, raise ValueError( "Target({}) is out of class_dimension's upper bound({})". format(invalid_label[0], input.shape[axis] - 1)) - if input.shape[axis] != weight.shape[-1]: - raise ValueError( - "input's class_dimension({}) must equal to " - "weight's class_dimension({}) " - "when weight is provided" \ - .format(input.shape[axis], weight.shape[-1])) ignore_weight_mask = paddle.cast((label != ignore_index), out.dtype) -- GitLab