未验证 提交 fb7590d4 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] refine lookup_table_v2_grad npu_kernel (#32497)

* use ZerosLike instead of NPUMemsetAsync

* fix compile
上级 136ef09d
......@@ -55,19 +55,19 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto *ids_t = ctx.Input<framework::LoDTensor>("Ids");
auto *output_grad_t =
ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *table_grad_t =
ctx.Output<framework::LoDTensor>(framework::GradVarName("W"));
auto *p = table_grad_t->mutable_data<T>(ctx.GetPlace());
table_grad_t->mutable_data<T>(ctx.GetPlace());
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
platform::NPUMemsetAsync(static_cast<void *>(p), 0,
table_grad_t->numel() * sizeof(T), stream);
auto runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);
// NOTE(zhiqiu): It seems that in CANN 20.1 the first input and the output
// can be different tensors, but in CANN 20.2+ the op operates in place.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册