From 5b86b9996001a7365511c09a6cba9cd5003df167 Mon Sep 17 00:00:00 2001 From: Qi Li Date: Fri, 13 Aug 2021 14:17:25 +0800 Subject: [PATCH] [NPU] fix bce_loss_npu, test=develop (#34876) --- paddle/fluid/operators/bce_loss_op_npu.cc | 7 ++++--- .../paddle/fluid/tests/unittests/npu/test_bce_loss_npu.py | 4 ++-- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/fluid/operators/bce_loss_op_npu.cc b/paddle/fluid/operators/bce_loss_op_npu.cc index f6b0f7b3fbb..3136c02af41 100644 --- a/paddle/fluid/operators/bce_loss_op_npu.cc +++ b/paddle/fluid/operators/bce_loss_op_npu.cc @@ -34,8 +34,9 @@ class BCELossNPUKernel : public framework::OpKernel { ctx.template device_context() .stream(); - const auto& runner = NpuOpRunner("BinaryCrossEntropy", {*x, *labels}, - {*out}, {{"reduction", "none"}}); + const auto& runner = + NpuOpRunner("BinaryCrossEntropy", {*x, *labels}, {*out}, + {{"reduction", static_cast("none")}}); runner.Run(stream); } }; @@ -57,7 +58,7 @@ class BCELossGradNPUKernel : public framework::OpKernel { const auto& runner = NpuOpRunner("BinaryCrossEntropyGrad", {*x, *labels, *dout}, {*dx}, - {{"reduction", "none"}}); + {{"reduction", static_cast("none")}}); runner.Run(stream); } }; diff --git a/python/paddle/fluid/tests/unittests/npu/test_bce_loss_npu.py b/python/paddle/fluid/tests/unittests/npu/test_bce_loss_npu.py index 16db9525334..7c3d32647ae 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_bce_loss_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_bce_loss_npu.py @@ -96,7 +96,7 @@ def test_dygraph_layer(place, label_np, reduction='mean', weight_np=None): - paddle.disable_static() + paddle.disable_static(place) if weight_np is not None: weight = paddle.to_tensor(weight_np) bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction) @@ -113,7 +113,7 @@ def test_dygraph_functional(place, label_np, reduction='mean', weight_np=None): - paddle.disable_static() + paddle.disable_static(place) input = paddle.to_tensor(input_np) label = paddle.to_tensor(label_np) -- GitLab