未验证 提交 5329187d 编写于 作者: X xiongkun 提交者: GitHub

Fix the thread number to ensure deterministic behavior of the embedding kernel (#48073)

上级 04dcb9d7
...@@ -107,6 +107,7 @@ struct EmbeddingGradCUDAFunctor { ...@@ -107,6 +107,7 @@ struct EmbeddingGradCUDAFunctor {
if (FLAGS_cudnn_deterministic) { if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of embedding with single thread."; VLOG(2) << "Run grad kernel of embedding with single thread.";
grids.x = 1; grids.x = 1;
threads.y = 1;
} }
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>( EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D); d_table, d_output, ids, N, K, D);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册