未验证 提交 28aa0c61 编写于 作者: Y Yuang Liu 提交者: GitHub

[DCU] Fix NaN problem when training BERT on DCU platform (#44643)

上级 e7c7280f
......@@ -166,7 +166,11 @@ static void MultiTensorL2Norm(const platform::CUDAPlace &place,
constexpr int kNumTensor = MaxTensorNumPerLaunch;
constexpr int kNumChunk = MaxChunkNumPerLaunch;
#ifdef PADDLE_WITH_HIP
constexpr int kBlockDim = 256;
#else
constexpr int kBlockDim = 512;
#endif
int max_chunk_num = -1;
int vec_size = 8;
......@@ -805,7 +809,11 @@ static void MultiTensorUpdateLambParamAndBetaPows(
platform::errors::InvalidArgument("Beta2Pow should be nullptr."));
}
#ifdef PADDLE_WITH_HIP
const int block_dim = 256;
#else
const int block_dim = 512;
#endif
int vec_size = 8;
for (int i = 0; i < n; ++i) {
......
......@@ -134,7 +134,11 @@ __device__ T reduceSum(T val, int tid, int len) {
// This reduction uses warp-level parallelism with a compile-time warp size:
// 64 when built with PADDLE_WITH_HIP and 32 otherwise. The real warp size
// may differ between GPUs, so these values are assumptions.
#ifdef PADDLE_WITH_HIP
const int warpSize = 64;
#else
const int warpSize = 32;
#endif
__shared__ T shm[warpSize];
unsigned mask = 0u;
CREATE_SHFL_MASK(mask, tid < len);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册