未验证 提交 5b86b999 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix bce_loss_npu, test=develop (#34876)

上级 17a99760
......@@ -34,8 +34,9 @@ class BCELossNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
const auto& runner = NpuOpRunner("BinaryCrossEntropy", {*x, *labels},
{*out}, {{"reduction", "none"}});
const auto& runner =
NpuOpRunner("BinaryCrossEntropy", {*x, *labels}, {*out},
{{"reduction", static_cast<std::string>("none")}});
runner.Run(stream);
}
};
......@@ -57,7 +58,7 @@ class BCELossGradNPUKernel : public framework::OpKernel<T> {
const auto& runner =
NpuOpRunner("BinaryCrossEntropyGrad", {*x, *labels, *dout}, {*dx},
{{"reduction", "none"}});
{{"reduction", static_cast<std::string>("none")}});
runner.Run(stream);
}
};
......
......@@ -96,7 +96,7 @@ def test_dygraph_layer(place,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
paddle.disable_static(place)
if weight_np is not None:
weight = paddle.to_tensor(weight_np)
bce_loss = paddle.nn.loss.BCELoss(weight=weight, reduction=reduction)
......@@ -113,7 +113,7 @@ def test_dygraph_functional(place,
label_np,
reduction='mean',
weight_np=None):
paddle.disable_static()
paddle.disable_static(place)
input = paddle.to_tensor(input_np)
label = paddle.to_tensor(label_np)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册