Unverified · Commit 28aa0c61 · Authored by Yuang Liu · Committed by GitHub

[DCU] Fix NaN problem when training BERT on the DCU platform (#44643)

Parent e7c7280f
@@ -166,7 +166,11 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
   constexpr int kNumTensor = MaxTensorNumPerLaunch;
   constexpr int kNumChunk = MaxChunkNumPerLaunch;
+#ifdef PADDLE_WITH_HIP
+  constexpr int kBlockDim = 256;
+#else
   constexpr int kBlockDim = 512;
+#endif
   int max_chunk_num = -1;
   int vec_size = 8;
@@ -805,7 +809,11 @@ static void MultiTensorUpdateLambParamAndBetaPows(
         platform::errors::InvalidArgument("Beta2Pow should be nullptr."));
   }
+#ifdef PADDLE_WITH_HIP
+  const int block_dim = 256;
+#else
   const int block_dim = 512;
+#endif
   int vec_size = 8;
   for (int i = 0; i < n; ++i) {
...
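Both hunks above follow the same pattern: the thread-block size is chosen at compile time, with a smaller block (256) on HIP/DCU builds and 512 on CUDA builds. Below is a minimal, self-contained sketch of that pattern; ScaleKernel and LaunchScale are hypothetical illustrations, not Paddle code, and only the launch configuration is the point.

#ifdef PADDLE_WITH_HIP
constexpr int kBlockDim = 256;  // DCU/ROCm builds: 256 threads per block
#else
constexpr int kBlockDim = 512;  // NVIDIA CUDA builds: 512 threads per block
#endif

// Hypothetical elementwise kernel used only to show the launch shape.
__global__ void ScaleKernel(const float *x, float *y, int n, float a) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;  // global element index
  if (i < n) y[i] = a * x[i];
}

void LaunchScale(const float *x, float *y, int n, float a) {
  int grid = (n + kBlockDim - 1) / kBlockDim;     // enough blocks to cover n elements
  ScaleKernel<<<grid, kBlockDim>>>(x, y, n, a);
}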
@@ -134,7 +134,11 @@ __device__ T reduceSum(T val, int tid, int len) {
   // I use Warp-Level Parallelism and assume the Warp size
   // is 32 which may be different for different GPU,
   // but most card's warp size is 32.
+#ifdef PADDLE_WITH_HIP
+  const int warpSize = 64;
+#else
   const int warpSize = 32;
+#endif
   __shared__ T shm[warpSize];
   unsigned mask = 0u;
   CREATE_SHFL_MASK(mask, tid < len);
...
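The third hunk sets the warp width to 64 for DCU, whose wavefronts are 64 lanes wide, versus 32 on NVIDIA GPUs. The sketch below is not Paddle's reduceSum; the names (kWarpSize, ShflDown, BlockReduceSum) are hypothetical. It shows a block-level sum reduction built on a compile-time warp width: if that width were left at 32 on a 64-lane device, partial sums can be grouped against the wrong lanes and the shared-memory stage can read slots no warp wrote, which is one plausible way garbage values, and eventually NaNs, can surface in training.

#ifdef PADDLE_WITH_HIP
constexpr int kWarpSize = 64;  // DCU/ROCm wavefront width
#else
constexpr int kWarpSize = 32;  // NVIDIA warp width
#endif

// Shuffle-down wrapper: HIP builds use the maskless shuffle, CUDA the _sync variant.
__device__ inline float ShflDown(float val, int offset) {
#ifdef PADDLE_WITH_HIP
  return __shfl_down(val, offset, kWarpSize);
#else
  return __shfl_down_sync(0xffffffffu, val, offset, kWarpSize);
#endif
}

// Block-wide sum; assumes blockDim.x is a multiple of kWarpSize and at most
// kWarpSize * kWarpSize threads per block. Thread 0 ends up with the total.
__device__ float BlockReduceSum(float val) {
  __shared__ float shm[kWarpSize];        // one partial sum per warp
  int lane = threadIdx.x % kWarpSize;
  int wid = threadIdx.x / kWarpSize;

  // Stage 1: tree reduction inside each warp.
  for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
    val += ShflDown(val, offset);
  if (lane == 0) shm[wid] = val;          // warp leader publishes its partial sum
  __syncthreads();

  // Stage 2: the first warp reduces the per-warp partial sums.
  int num_warps = blockDim.x / kWarpSize;
  val = (threadIdx.x < num_warps) ? shm[lane] : 0.0f;
  if (wid == 0) {
    for (int offset = kWarpSize / 2; offset > 0; offset /= 2)
      val += ShflDown(val, offset);
  }
  return val;
}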