Unverified commit 56692f66, authored by Leo Chen, committed by GitHub

Fix a bug where the CUDA kernel launch config exceeds the maximum grid dimension (#33748)

Parent: 6801b6e2
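Context for the change: CUDA caps the y and z dimensions of a launch grid at 65535 blocks, while the x dimension allows up to 2^31 - 1 blocks on compute capability 3.0 and later. The grad_input backward kernel was launched with dim3 blocks1(1, batch_size, 1), i.e. one block per batch row along y, so any batch larger than 65535 rows produced an invalid launch configuration. The fix below strides over rows along the x dimension instead and launches a plain 1-D grid of batch_size blocks.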
@@ -400,9 +400,9 @@ __global__ void LayerNormBackwardComputeGradInput(
     const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
     const U *gamma, T *grad_input) {
 #ifdef __HIPCC__
-  for (auto i1 = hipBlockIdx_y; i1 < n1; i1 += hipGridDim_y) {
+  for (auto i1 = hipBlockIdx_x; i1 < n1; i1 += hipGridDim_x) {
 #else
-  for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
+  for (auto i1 = blockIdx.x; i1 < n1; i1 += gridDim.x) {
 #endif
     U sum_loss1 = U(0);
     U sum_loss2 = U(0);
@@ -869,9 +869,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
       constexpr int BDIMX1 = 32;
       constexpr int BDIMY1 = 4;
       dim3 threads1(BDIMX1, BDIMY1, 1);
-      const dim3 blocks1(1, batch_size, 1);
       LayerNormBackwardComputeGradInput<
-          T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
+          T, U, BDIMX1, BDIMY1><<<batch_size, threads1, 0, stream>>>(
           d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
       break;
     }
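For illustration, here is a minimal standalone sketch of the same pattern. This is not Paddle's actual kernel; the names row_kernel, n_rows, and n_cols are made up for this example.

// Minimal sketch of the launch-config pattern used by the fix above.
// Blocks stride over rows along gridDim.x, which supports up to 2^31 - 1
// blocks, unlike gridDim.y (capped at 65535).
#include <cstdio>
#include <cuda_runtime.h>

__global__ void row_kernel(const float *in, float *out, int n_rows,
                           int n_cols) {
  // Grid-stride loop over rows, indexed along x instead of y.
  for (int row = blockIdx.x; row < n_rows; row += gridDim.x) {
    for (int col = threadIdx.x; col < n_cols; col += blockDim.x) {
      out[row * n_cols + col] = in[row * n_cols + col] * 2.0f;
    }
  }
}

int main() {
  const int n_rows = 70000;  // > 65535: invalid as gridDim.y, fine as gridDim.x
  const int n_cols = 256;
  float *in, *out;
  cudaMalloc(&in, sizeof(float) * n_rows * n_cols);
  cudaMalloc(&out, sizeof(float) * n_rows * n_cols);

  // Buggy pattern: dim3 blocks(1, n_rows, 1) exceeds the y-dimension limit.
  // Fixed pattern: put the row count in the x dimension of a 1-D grid.
  row_kernel<<<n_rows, 256>>>(in, out, n_rows, n_cols);
  printf("launch: %s\n", cudaGetErrorString(cudaGetLastError()));

  cudaFree(in);
  cudaFree(out);
  return 0;
}

Passing the row count as the first <<<...>>> argument keeps the grid one-dimensional, so batch sizes beyond 65535 remain a valid configuration.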