未验证 提交 4a790cba 编写于 作者: C Chitsing KUI 提交者: GitHub

fix c_embedding bug (#52742)

上级 94a8177f
......@@ -19,6 +19,8 @@ limitations under the License. */
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
DECLARE_bool(cudnn_deterministic);
namespace paddle {
namespace operators {
......@@ -164,6 +166,10 @@ class CEmbeddingGradCUDAKernel : public framework::OpKernel<T> {
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
const auto &index_type = framework::TransToProtoVarType(ids_t->dtype());
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
blocks = 1;
}
if (index_type == framework::proto::VarType::INT32) {
CEmbeddingGrad<T, int32_t>
<<<blocks, threads, 0, dev_ctx.stream()>>>(d_table,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册