Unverified commit cdeffff4, authored by zhiboniu, committed by GitHub

fix gpt2 train loss NaN problem by adding a __syncthreads() in BlockReduceSum (#33659)

Parent commit: 18043ab5
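The NaN comes from a race in the block-level sum reduction: when blockReduceSum is called twice in a row inside one kernel, a warp that finishes the first call early can re-enter the function and overwrite its slot in the shared scratch array before a slower warp has read the value written during the first call. The __syncthreads() added before the write closes that window, because no thread may publish a new partial until every thread has consumed the previous round's partials. The sketch below illustrates the fixed pattern; it is not the PaddlePaddle source, and the helper and kernel names are made up for illustration.

// Illustrative sketch of the fixed reduction pattern; not the PaddlePaddle
// source. Assumes blockDim.x is a multiple of warpSize and at most 1024.
template <typename T>
__forceinline__ __device__ T WarpReduceSumSketch(T val) {
  for (int offset = warpSize / 2; offset > 0; offset /= 2)
    val += __shfl_down_sync(0xffffffffu, val, offset);
  return val;
}

template <typename T>
__forceinline__ __device__ T BlockReduceSumSketch(T val) {
  static __shared__ T shared[32];    // one slot per warp, reused by every call
  int lane = threadIdx.x % warpSize;
  int wid = threadIdx.x / warpSize;
  val = WarpReduceSumSketch(val);    // per-warp partial sum
  __syncthreads();                   // the added barrier: no thread may write
                                     // until every thread has finished reading
                                     // the previous call's contents of `shared`
  if (lane == 0) shared[wid] = val;  // publish one partial per warp
  __syncthreads();                   // make all partials visible
  val = (threadIdx.x < blockDim.x / warpSize) ? shared[lane]
                                              : static_cast<T>(0);
  if (wid == 0) val = WarpReduceSumSketch(val);  // warp 0 sums the partials
  return val;                        // thread 0 holds the full sum
}

// Two back-to-back reductions over the same static buffer, as in the kernels
// patched here: without the first __syncthreads() above, a fast warp entering
// the second call could overwrite a slot a slower warp still has to read.
__global__ void SumAndSquareSumSketch(const float *x, float *out, int n) {
  float s = 0.f, sq = 0.f;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    s += x[i];
    sq += x[i] * x[i];
  }
  s = BlockReduceSumSketch(s);
  sq = BlockReduceSumSketch(sq);
  if (threadIdx.x == 0) {
    out[0] = s;
    out[1] = sq;
  }
}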
@@ -42,6 +42,7 @@ __forceinline__ __device__ T blockReduceSum(T val) {
   int wid = threadIdx.x / warpSize;
   val = warpReduceSum(val);
+  __syncthreads();
   if (lane == 0) shared[wid] = val;
   __syncthreads();
...
@@ -64,17 +64,16 @@ static __forceinline__ __device__ U WarpReduceSum(U val) {
 }
 template <typename U>
-__forceinline__ __device__ U BlockReduceSum(U val) {
-  static __shared__ U shared[32];
+__forceinline__ __device__ U BlockReduceSum(U val, U *shared) {
   int lane = threadIdx.x % warpSize;
   int wid = threadIdx.x / warpSize;
   val = WarpReduceSum(val);  // Each warp performs partial reduction
+  __syncthreads();
   if (lane == 0) shared[wid] = val;  // Write reduced value to shared memory
   __syncthreads();  // Wait for all partial reductions
   // read from shared memory only if that warp existed
   val =
       (threadIdx.x < blockDim.x / warpSize) ? shared[lane] : static_cast<U>(0);
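This hunk also changes the BlockReduceSum interface: the static __shared__ scratch array inside the function is removed and the caller now supplies the buffer. A kernel that performs two consecutive reductions can then give each one its own scratch space, which is exactly what the LayerNormForward hunk below does with shared_mean and shared_var. Here is a minimal caller-side sketch under that assumption; the kernel name, arguments, and output layout are illustrative, not the Paddle kernel.

// Illustrative caller for the patched BlockReduceSum(U val, U *shared) shown
// above; mirrors what LayerNormForward does for mean and variance, but the
// kernel name, arguments, and output layout here are assumptions.
template <typename U>
__global__ void MeanVarSketch(const U *x, U *mean, U *var,
                              int64_t feature_size) {
  __shared__ U shared_mean[32];  // scratch for the first reduction
  __shared__ U shared_var[32];   // separate scratch for the second reduction
  U mean_val = 0;
  U var_val = 0;
  int64_t row = blockIdx.x * feature_size;
  for (int64_t i = threadIdx.x; i < feature_size; i += blockDim.x) {
    U tmp = x[row + i];
    mean_val += tmp;
    var_val += tmp * tmp;
  }
  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
  var_val = BlockReduceSum<U>(var_val, shared_var);
  if (threadIdx.x == 0) {
    U inv_n = static_cast<U>(1) / static_cast<U>(feature_size);
    U m = mean_val * inv_n;
    mean[blockIdx.x] = m;
    var[blockIdx.x] = var_val * inv_n - m * m;  // E[x^2] - (E[x])^2
  }
}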
@@ -183,6 +182,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     int64_t feature_size) {
   __shared__ U mean_share;
   __shared__ U var_share;
+  __shared__ U shared_mean[32];
+  __shared__ U shared_var[32];
   int64_t beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int64_t end_idx = (blockIdx.x + 1) * feature_size;
@@ -196,8 +197,8 @@ __global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
     var_val += (tmp * tmp);
   }
-  mean_val = BlockReduceSum<U>(mean_val);
-  var_val = BlockReduceSum<U>(var_val);
+  mean_val = BlockReduceSum<U>(mean_val, shared_mean);
+  var_val = BlockReduceSum<U>(var_val, shared_var);
   if (threadIdx.x == 0) {
     auto scale = static_cast<float>(1.) / static_cast<float>(feature_size);
@@ -541,8 +542,10 @@ __global__ void LayerNormBackwardGradientAll(
     }
   }
-  d_scale_partial = BlockReduceSum<U>(d_scale_partial);
-  d_bias_partial = BlockReduceSum<U>(d_bias_partial);
+  __shared__ U shared_scale[32];
+  __shared__ U shared_bias[32];
+  d_scale_partial = BlockReduceSum<U>(d_scale_partial, shared_scale);
+  d_bias_partial = BlockReduceSum<U>(d_bias_partial, shared_bias);
   if (threadIdx.x == 0) {
     d_scale[blockIdx.x + col_offset] = d_scale_partial;
...
@@ -188,6 +188,7 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   val = warpReduceSum<T>(val, mask);
+  __syncthreads();
   if (lane == 0) shared[wid] = val;
   __syncthreads();
...