From 4a790cba53c037e6ffa6617fe86f68a79325f164 Mon Sep 17 00:00:00 2001 From: Chitsing KUI Date: Tue, 11 Apr 2023 10:01:52 +0800 Subject: [PATCH] fix c_embedding bug (#52742) --- paddle/fluid/operators/collective/c_embedding_op.cu | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/paddle/fluid/operators/collective/c_embedding_op.cu b/paddle/fluid/operators/collective/c_embedding_op.cu index b44aaf74e49..8b521580c5c 100644 --- a/paddle/fluid/operators/collective/c_embedding_op.cu +++ b/paddle/fluid/operators/collective/c_embedding_op.cu @@ -19,6 +19,8 @@ limitations under the License. */ #include "paddle/fluid/platform/float16.h" #include "paddle/phi/backends/gpu/gpu_primitives.h" +DECLARE_bool(cudnn_deterministic); + namespace paddle { namespace operators { @@ -164,6 +166,10 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel { t.device(*dev_ctx.eigen_device()) = t.constant(static_cast(0)); const auto &index_type = framework::TransToProtoVarType(ids_t->dtype()); + if (FLAGS_cudnn_deterministic) { + VLOG(2) << "Run grad kernel of embedding with single thread."; + blocks = 1; + } if (index_type == framework::proto::VarType::INT32) { CEmbeddingGrad <<>>(d_table, -- GitLab