未验证 提交 2ad1e4c7 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] disable EmbeddingDenseGrad temporarily (#34498)

上级 87148a5c
......@@ -65,8 +65,9 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
int embedding_dim = table_grad_t->dims()[1];
/* EmbeddingDenseGrad has bug on large shape, temporarily disable it.
int embedding_dim = table_grad_t->dims()[1];
if (embedding_dim % 32 == 0) {
// NOTE(pangyoki): The embedding_dim of Tensor used in
// EmbeddingDenseGrad must be an integer multiple of 32.
......@@ -77,7 +78,10 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
{"padding_idx", -1},
{"scale_grad_by_freq", false}});
runner.Run(stream);
} else {
return;
}
*/
const auto &runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream);
......@@ -90,7 +94,6 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
{*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream);
}
}
};
} // namespace operators
} // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册