未验证 提交 5b86b999 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix bce_loss_npu, test=develop (#34876)

上级 17a99760
...@@ -34,8 +34,9 @@ class BCELossNPUKernel : public framework::OpKernel<T> { ...@@ -34,8 +34,9 @@ class BCELossNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
const auto& runner = NpuOpRunner("BinaryCrossEntropy", {*x, *labels}, const auto& runner =
{*out}, {{"reduction", "none"}}); NpuOpRunner("BinaryCrossEntropy", {*x, *labels}, {*out},
{{"reduction", static_cast<std::string>("none")}});
runner.Run(stream); runner.Run(stream);
} }
}; };
...@@ -57,7 +58,7 @@ class BCELossGradNPUKernel : public framework::OpKernel<T> { ...@@ -57,7 +58,7 @@ class BCELossGradNPUKernel : public framework::OpKernel<T> {
const auto& runner = const auto& runner =
NpuOpRunner("BinaryCrossEntropyGrad", {*x, *labels, *dout}, {*dx}, NpuOpRunner("BinaryCrossEntropyGrad", {*x, *labels, *dout}, {*dx},
{{"reduction", "none"}}); {{"reduction", static_cast<std::string>("none")}});
runner.Run(stream); runner.Run(stream);
} }
}; };
......
...@@ -96,7 +96,7 @@ def test_dygraph_layer(place, ...@@ -96,7 +96,7 @@ def test_dygraph_layer(place,
label_np, label_np,
reduction='mean', reduction='mean',
weight_np=None): weight_np=None):
paddle.disable_static() paddle.disable_static(place)
if weight_np is not None: if weight_np is not None:
weight = paddle.to_tensor(weight_np) weight = paddle.to_tensor(weight_np)
bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction) bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction)
...@@ -113,7 +113,7 @@ def test_dygraph_functional(place, ...@@ -113,7 +113,7 @@ def test_dygraph_functional(place,
label_np, label_np,
reduction='mean', reduction='mean',
weight_np=None): weight_np=None):
paddle.disable_static() paddle.disable_static(place)
input = paddle.to_tensor(input_np) input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np) label = paddle.to_tensor(label_np)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册