From 110613256898b2431654ab21cbd0ba869f99ec40 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Mon, 11 Oct 2021 12:17:21 +0800 Subject: [PATCH] [NPU] fix softmax_with_cross_entropy in dygraph, test=develop (#36297) --- .../operators/softmax_with_cross_entropy_op.cc | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc index 0c2d39e7519..78e813edda9 100644 --- a/paddle/fluid/operators/softmax_with_cross_entropy_op.cc +++ b/paddle/fluid/operators/softmax_with_cross_entropy_op.cc @@ -13,10 +13,6 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/softmax_with_cross_entropy_op.h" -#include -#include -#include -#include #include "paddle/fluid/framework/op_version_registry.h" namespace paddle { @@ -54,8 +50,7 @@ class SoftmaxWithCrossEntropyOpMaker "exp(logits -max_logits) / sum(exp(logits - max_logits)) - labels, " "where labels is ont-hot." "Currently, the tensor is generated and used in npu kernel only. ") - .AsIntermediate() - .AsDispensable(); + .AsIntermediate(); #endif AddOutput("Loss", "(Tensor, default: Tensor), A tensor in same shape with " @@ -136,6 +131,11 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->HasOutput("Softmax"), true, platform::errors::InvalidArgument( "Output(Softmax) should be not null.")); +#ifdef PADDLE_WITH_ASCEND_CL + PADDLE_ENFORCE_EQ(ctx->HasOutput("Backprop"), true, + platform::errors::InvalidArgument( + "Output(Backprop) should be not null.")); +#endif PADDLE_ENFORCE_EQ( ctx->HasOutput("Loss"), true, platform::errors::InvalidArgument("Output(Loss) should be not null.")); @@ -225,6 +225,11 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel { PADDLE_ENFORCE_EQ(ctx->HasInput("Softmax"), true, platform::errors::InvalidArgument( "Input(Softmax) should be not null.")); +#ifdef PADDLE_WITH_ASCEND_CL + PADDLE_ENFORCE_EQ(ctx->HasInput("Backprop"), true, + platform::errors::InvalidArgument( + "Input(Backprop) should be not null.")); +#endif PADDLE_ENFORCE_EQ( ctx->HasInput("Label"), true, platform::errors::InvalidArgument("Input(Label) should be not null.")); -- GitLab