未验证 提交 0c40d889 编写于 作者: L Li Min 提交者: GitHub

Add deterministic behavior for embed_grad and index_add. (#46040)

上级 54a43981
...@@ -23,6 +23,8 @@ ...@@ -23,6 +23,8 @@
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/embedding_util.h" #include "paddle/phi/kernels/funcs/embedding_util.h"
DECLARE_bool(cudnn_deterministic);
namespace phi { namespace phi {
template <typename InT, typename OutT> template <typename InT, typename OutT>
...@@ -101,6 +103,11 @@ struct EmbeddingGradCUDAFunctor { ...@@ -101,6 +103,11 @@ struct EmbeddingGradCUDAFunctor {
const int gridx = 2 * dev_ctx_.GetSMCount(); const int gridx = 2 * dev_ctx_.GetSMCount();
dim3 threads(128, 8); dim3 threads(128, 8);
dim3 grids(gridx, 1); dim3 grids(gridx, 1);
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of embedding with single thread.";
grids.x = 1;
}
EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>( EmbeddingGrad<T, IdT><<<grids, threads, 0, dev_ctx_.stream()>>>(
d_table, d_output, ids, N, K, D); d_table, d_output, ids, N, K, D);
} }
......
...@@ -20,6 +20,8 @@ ...@@ -20,6 +20,8 @@
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h" #include "paddle/phi/core/utils/data_type.h"
DECLARE_bool(cudnn_deterministic);
namespace phi { namespace phi {
using paddle::platform::PADDLE_CUDA_NUM_THREADS; using paddle::platform::PADDLE_CUDA_NUM_THREADS;
...@@ -79,6 +81,12 @@ void IndexAddKernel(const Context& ctx, ...@@ -79,6 +81,12 @@ void IndexAddKernel(const Context& ctx,
// todo(@limin29): inplace do not need copy. // todo(@limin29): inplace do not need copy.
phi::Copy(ctx, x, ctx.GetPlace(), false, output); phi::Copy(ctx, x, ctx.GetPlace(), false, output);
if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of index_add with single thread.";
block_dim = 1;
grid_dim.x = 1;
}
if (index_type == phi::DataType::INT64) { if (index_type == phi::DataType::INT64) {
const int64_t* index_data = index.data<int64_t>(); const int64_t* index_data = index.data<int64_t>();
index_add_cuda_kernel<T, int64_t> index_add_cuda_kernel<T, int64_t>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册