未验证 提交 2ad1e4c7 编写于 作者: L Leo Chen 提交者: GitHub

[NPU] disable EmbeddingDenseGrad temporarily (#34498)

上级 87148a5c
...@@ -65,8 +65,9 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> { ...@@ -65,8 +65,9 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
int embedding_dim = table_grad_t->dims()[1]; /* EmbeddingDenseGrad has bug on large shape, temporarily disable it.
int embedding_dim = table_grad_t->dims()[1];
if (embedding_dim % 32 == 0) { if (embedding_dim % 32 == 0) {
// NOTE(pangyoki): The embedding_dim of Tensor used in // NOTE(pangyoki): The embedding_dim of Tensor used in
// EmbeddingDenseGrad must be an integer multiple of 32. // EmbeddingDenseGrad must be an integer multiple of 32.
...@@ -77,7 +78,10 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> { ...@@ -77,7 +78,10 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
{"padding_idx", -1}, {"padding_idx", -1},
{"scale_grad_by_freq", false}}); {"scale_grad_by_freq", false}});
runner.Run(stream); runner.Run(stream);
} else { return;
}
*/
const auto &runner_zeros = const auto &runner_zeros =
NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t}); NpuOpRunner("ZerosLike", {*table_grad_t}, {*table_grad_t});
runner_zeros.Run(stream); runner_zeros.Run(stream);
...@@ -90,7 +94,6 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> { ...@@ -90,7 +94,6 @@ class LookupTableV2GradNPUKernel : public framework::OpKernel<T> {
{*table_grad_t}, {{"use_locking", true}}); {*table_grad_t}, {{"use_locking", true}});
runner_scatter.Run(stream); runner_scatter.Run(stream);
} }
}
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册