From 5720537e84ef38c2c5c94786839491700edd64db Mon Sep 17 00:00:00 2001 From: crystal <62974595+Zjq9409@users.noreply.github.com> Date: Mon, 14 Mar 2022 19:39:24 +0800 Subject: [PATCH] optimize group_norm op backward (#39944) * optimize backwad * optimize group_norm backward * Add vectorized code * move assignment code * merge function * move code * optimize code * Modify function name --- paddle/fluid/operators/group_norm_op.cc | 4 + paddle/fluid/operators/group_norm_op.cu | 367 +++++++++++++++++++----- 2 files changed, 299 insertions(+), 72 deletions(-) diff --git a/paddle/fluid/operators/group_norm_op.cc b/paddle/fluid/operators/group_norm_op.cc index 2d284fb516e..4331523d26e 100644 --- a/paddle/fluid/operators/group_norm_op.cc +++ b/paddle/fluid/operators/group_norm_op.cc @@ -167,9 +167,11 @@ class GroupNormGradOp : public framework::OperatorWithKernel { void InferShape(framework::InferShapeContext *ctx) const override { // check input + OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GroupNormGrad"); OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "GroupNormGrad"); OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance", "GroupNormGrad"); + OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "GroupNormGrad"); OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input", framework::GradVarName("Y"), "GroupNormGrad"); @@ -216,10 +218,12 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker { void Apply(GradOpPtr op) const override { op->SetType("group_norm_grad"); + op->SetInput("X", this->Input("X")); op->SetInput("Scale", this->Input("Scale")); op->SetInput("Bias", this->Input("Bias")); op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y")); op->SetInput("Y", this->Output("Y")); + op->SetInput("Mean", this->Output("Mean")); op->SetInput("Variance", this->Output("Variance")); op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); diff --git a/paddle/fluid/operators/group_norm_op.cu b/paddle/fluid/operators/group_norm_op.cu index b376334f1e9..ab8c50d90b8 100644 --- a/paddle/fluid/operators/group_norm_op.cu +++ b/paddle/fluid/operators/group_norm_op.cu @@ -81,46 +81,74 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W, CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var); } -template -__device__ __forceinline__ void ThreadReduce(const T* input, int size, - const int offset, AccT* mean, - AccT* var) { +template +__device__ __forceinline__ void ThreadReduce(phi::Array arrs, + int size, const int offset, + AccT* out_mean, AccT* out_var) { + const T* x = arrs[0]; + const T* y; + if (Num == 2) { + y = arrs[1]; + } using VecT = kps::details::VectorType; int tid = threadIdx.x; if (offset > 0) { - input -= offset; + x -= offset; + if (Num == 2) { + y -= offset; + } size += offset; if (tid >= offset) { - AccT temp = input[tid]; - *mean += temp; - *var += temp * temp; + if (Num == 1) { + *out_mean += x[tid]; + *out_var += x[tid] * x[tid]; + } else if (Num == 2) { + *out_mean += y[tid]; + *out_var += y[tid] * x[tid]; + } } size -= blockDim.x; - input += blockDim.x; + x += blockDim.x; + if (Num == 2) { + y += blockDim.x; + } } int remain = size % (VecSize * blockDim.x); - T ins[VecSize]; - VecT* ins_vec = reinterpret_cast(&ins); + T ins_x[VecSize]; + T ins_y[VecSize]; + VecT* ins_vec_x = reinterpret_cast(&ins_x); + VecT* ins_vec_y = reinterpret_cast(&ins_y); // vector part for (; VecSize * tid < (size - remain); tid += blockDim.x) { - *ins_vec = reinterpret_cast(input)[tid]; + *ins_vec_x = reinterpret_cast(x)[tid]; + if (Num == 2) { + *ins_vec_y = reinterpret_cast(y)[tid]; + } #pragma unroll for (int i = 0; i < VecSize; ++i) { - AccT temp = ins[i]; - *mean += temp; - *var += temp * temp; + if (Num == 1) { + *out_mean += ins_x[i]; + *out_var += ins_x[i] * ins_x[i]; + } else if (Num == 2) { + *out_mean += ins_y[i]; + *out_var += ins_y[i] * ins_x[i]; + } } } // scalar part tid = size - remain + threadIdx.x; for (; tid < size; tid += blockDim.x) { - AccT temp = input[tid]; - *mean += temp; - *var += temp * temp; + if (Num == 1) { + *out_mean += x[tid]; + *out_var += x[tid] * x[tid]; + } else if (Num == 2) { + *out_mean += y[tid]; + *out_var += y[tid] * x[tid]; + } } } @@ -148,7 +176,10 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var, AccT x_var = static_cast(0); const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T); x += i * size; - ThreadReduce(x, size, input_offset, &x_mean, &x_var); + phi::Array ins; + ins[0] = x; + ThreadReduce(ins, size, input_offset, &x_mean, &x_var); + x_mean = kps::details::BlockXReduce>( x_mean, kps::AddFunctor()); x_var = kps::details::BlockXReduce>( @@ -310,10 +341,12 @@ class GroupNormKernel }; template -__global__ void GroupNormBackwardGetMeanAndVar( - const T* x, const T* scale, const T* bias, const T* d_y, int N, int C, - int W, int imsize, int groups, int group_size, T epsilon, T* d_mean, - T* d_var, T* d_scale, T* d_bias, const DataLayout data_layout) { +__global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale, + const T* bias, const T* d_y, + int N, int C, int W, int imsize, + int groups, int group_size, + T epsilon, T* d_mean, T* d_var, + T* d_scale, T* d_bias) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; @@ -329,15 +362,11 @@ __global__ void GroupNormBackwardGetMeanAndVar( for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { T val, dval; - if (data_layout == DataLayout::kNCHW) { - val = x[(bid * C + ccid) * imsize + imid] - x_bias; - dval = d_y[(bid * C + ccid) * imsize + imid]; - } else { - int hid = imid / W; - int wid = imid % W; - val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias; - dval = d_y[(bid * H + hid) * W * C + wid * C + ccid]; - } + + int hid = imid / W; + int wid = imid % W; + val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias; + dval = d_y[(bid * H + hid) * W * C + wid * C + ccid]; d_var_data += val * dval; d_mean_data += dval * x_scale; @@ -357,8 +386,7 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale, const T* bias, const T* var, const T* d_mean, const T* d_var, int N, int C, int W, int imsize, int groups, int group_size, - T epsilon, T* d_x, - const DataLayout data_layout) { + T epsilon, T* d_x) { int gid = blockIdx.y; int cid = blockIdx.x; int bid = blockIdx.z; @@ -379,26 +407,142 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale, if (x_scale != 0) x_scale_inv = 1.0 / x_scale; for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { - if (data_layout == DataLayout::kNCHW) { - T tmp = x[(bid * C + ccid) * imsize + imid]; - T v_y = (tmp - x_bias) * x_scale_inv; - T dly = d_y[(bid * C + ccid) * imsize + imid]; - d_x[(bid * C + ccid) * imsize + imid] = - x_var_inv * - (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); - } else { - int hid = imid / W; - int wid = imid % W; - T tmp = x[(bid * H + hid) * W * C + wid * C + ccid]; - T v_y = (tmp - x_bias) * x_scale_inv; - T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid]; - d_x[(bid * H + hid) * W * C + wid * C + ccid] = - x_var_inv * - (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); + int hid = imid / W; + int wid = imid % W; + T tmp = x[(bid * H + hid) * W * C + wid * C + ccid]; + T v_y = (tmp - x_bias) * x_scale_inv; + T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid]; + d_x[(bid * H + hid) * W * C + wid * C + ccid] = + x_var_inv * + (dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean); + } +} + +template +__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy, + T* ds, T* db) { + int i = blockIdx.x; + AccT ds_sum = static_cast(0); + AccT db_sum = static_cast(0); + const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T); + x += i * imsize; + + phi::Array ins; + ins[0] = x; + ins[1] = dy; + ThreadReduce(ins, imsize, input_offset, &db_sum, + &ds_sum); + + ds_sum = kps::details::BlockXReduce>( + ds_sum, kps::AddFunctor()); + db_sum = kps::details::BlockXReduce>( + db_sum, kps::AddFunctor()); + __syncthreads(); + if (threadIdx.x == 0) { + ds[i] = ds_sum; + db[i] = db_sum; + } +} + +template +__global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy, + T* ds, T* db) { + const int nc = blockIdx.x; + T ds_sum = 0; + T db_sum = 0; + for (int i = threadIdx.x; i < imsize; i += blockDim.x) { + const int index = nc * imsize + i; + ds_sum += dy[index] * x[index]; + db_sum += dy[index]; + } + CudaAtomicAddWithWarp(&ds[nc], ds_sum); + CudaAtomicAddWithWarp(&db[nc], db_sum); +} + +template +__global__ void GetScaleBiasGradientCUDAKernel(int N, int C, int group, + T epsilon, const T* mean, + const T* var, const T* ds, + const T* db, T* d_scale, + T* d_bias) { + const int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c < C) { + const int G = group; + const int D = C / G; + T sum1 = 0; + T sum2 = 0; + for (int n = 0; n < N; ++n) { + const int nc = n * C + c; + const int ng = n * G + c / D; + sum1 += (d_scale == nullptr) + ? T(0) + : ((ds[nc] - db[nc] * static_cast(mean[ng])) * + static_cast(rsqrt(var[ng] + epsilon))); + sum2 += (d_bias == nullptr) ? T(0) : db[nc]; + } + if (d_scale != nullptr) { + d_scale[c] = sum1; + } + if (d_bias != nullptr) { + d_bias[c] = sum2; } } } +template +__global__ void GetBackwardParamsCUDAKernel(int imsize, int groups, + int group_size, T epsilon, + const T* mean, const T* var, + const T* scale, const T* ds, + const T* db, T* p1, T* p2, T* p3) { + const int n = blockIdx.x; + const int g = blockIdx.y; + const int ng = n * groups + g; + T sum1 = 0; + T sum2 = 0; + T var_inv = rsqrt(var[ng] + epsilon); + for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) { + const int64_t index = ng * group_size + i; + const int64_t c = g * group_size + i; + const T scale_v = scale == nullptr ? T(1) : static_cast(scale[c]); + sum1 += ds[index] * scale_v; + sum2 += db[index] * scale_v; + const T scale_c = scale == nullptr ? T(0) : static_cast(scale[c]); + p1[index] = scale_c * var_inv; + } + + typedef cub::BlockReduce BlockReduce; + __shared__ typename BlockReduce::TempStorage ds_storage; + __shared__ typename BlockReduce::TempStorage db_storage; + sum1 = BlockReduce(ds_storage).Reduce(sum1, cub::Sum()); + sum2 = BlockReduce(db_storage).Reduce(sum2, cub::Sum()); + + if (threadIdx.x == 0) { + const T s = T(1) / static_cast(group_size * imsize); + const T x = (sum2 * static_cast(mean[ng]) - sum1) * + static_cast(var_inv) * static_cast(var_inv) * + static_cast(var_inv) * s; + p2[ng] = x; + p3[ng] = -x * static_cast(mean[ng]) - sum2 * static_cast(var_inv) * s; + } +} + +template +__global__ void GetXGradientCUDAKernel(int imsize, int C, int group_size, + int groups, T* p1, T* p2, T* p3, + const T* x, const T* dy, T* dx) { + int cid = blockIdx.x; + int gid = blockIdx.y; + int bid = blockIdx.z; + int ccid = bid * C + gid * group_size + cid; + int ng = bid * groups + gid; + int nc = gid * group_size + cid; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + int index = (bid * C + nc) * imsize + imid; + dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng]; + } +} + template class GroupNormGradKernel : public framework::OpKernel { @@ -408,7 +552,9 @@ class GroupNormGradKernel const DataLayout data_layout = framework::StringToDataLayout(data_layout_str); const float epsilon = ctx.Attr("epsilon"); - auto* x = ctx.Input("Y"); + auto* x = ctx.Input("X"); + auto* y = ctx.Input("Y"); + auto* mean = ctx.Input("Mean"); auto* var = ctx.Input("Variance"); auto* scale = ctx.Input("Scale"); auto* bias = ctx.Input("Bias"); @@ -433,31 +579,27 @@ class GroupNormGradKernel phi::funcs::SetConstant set_zero; auto& dev_ctx = ctx.template device_context(); - Tensor temp_var; - temp_var.mutable_data(var->dims(), ctx.GetPlace()); - set_zero(dev_ctx, &temp_var, static_cast(0)); - T* temp_var_data = temp_var.data(); - - Tensor temp_mean; - temp_mean.mutable_data(var->dims(), ctx.GetPlace()); - set_zero(dev_ctx, &temp_mean, static_cast(0)); - T* temp_mean_data = temp_mean.data(); + Tensor ds, db; + ds.mutable_data({x_dims[0], C}, ctx.GetPlace()); + db.mutable_data({x_dims[0], C}, ctx.GetPlace()); + T* ds_data = ds.data(); + T* db_data = db.data(); + auto* y_data = y->data(); auto* x_data = x->data(); T* d_x_data = nullptr; if (d_x) d_x_data = d_x->data(); - auto* y_data = d_y->data(); + auto* dy_data = d_y->data(); auto* var_data = var->data(); + auto* mean_data = mean->data(); T* d_scale_data = nullptr; if (d_scale) { d_scale->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, d_scale, static_cast(0)); d_scale_data = d_scale->data(); } T* d_bias_data = nullptr; if (d_bias) { d_bias->mutable_data(ctx.GetPlace()); - set_zero(dev_ctx, d_bias, static_cast(0)); d_bias_data = d_bias->data(); } @@ -479,22 +621,103 @@ class GroupNormGradKernel #ifdef __HIPCC__ int block_size = std::max(std::min(256, imsize), 64); + const int block_dims = 256; #else int block_size = std::min(1024, imsize); + const int block_dims = 1024; #endif dim3 grid(group_size, groups, x_dims[0]); dim3 threads(block_size, 1, 1); int flags = (scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias; - UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, x_data, scale_data, - bias_data, y_data, x_dims[0], C, W, imsize, groups, - group_size, epsilon, temp_mean_data, temp_var_data, - d_scale_data, d_bias_data, data_layout); - if (d_x_data != nullptr) { - UNROLL_ALL_CASES(flags, GroupNormBackward, x_data, y_data, scale_data, - bias_data, var_data, temp_mean_data, temp_var_data, - x_dims[0], C, W, imsize, groups, group_size, epsilon, - d_x_data, data_layout); + if (data_layout == DataLayout::kNCHW) { + using AccT = typename details::MPTypeTrait::Type; + constexpr int vec_size = sizeof(float4) / sizeof(T); + const int max_num_threads = 1024; + int max_block_size = std::min(imsize / vec_size, max_num_threads); + int block_size_nchw = 1; + while (block_size_nchw < max_block_size) { + block_size_nchw *= 2; + } + block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize); + dim3 blocks(block_size_nchw); + if (imsize < vec_size) { + if (d_scale) { + set_zero(dev_ctx, d_scale, static_cast(0)); + } + if (d_bias) { + set_zero(dev_ctx, d_bias, static_cast(0)); + } + ScalarGetDsDbCUDAKernel< + T><<>>( + imsize, x_data, dy_data, ds_data, db_data); + } else { + VectorizedGetDsDbCUDAKernel< + T, AccT, vec_size><<>>( + imsize, x_data, dy_data, ds_data, db_data); + } + + if (d_scale || d_bias) { + const int block = 256; + GetScaleBiasGradientCUDAKernel< + T><<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>( + x_dims[0], C, groups, epsilon, mean_data, var_data, ds_data, + db_data, d_scale_data, d_bias_data); + } + + if (d_x_data != nullptr) { + // p1 * dy + p2 * x + p3, + // p1, p2, p3 represent the reverse calculation of temporary variables + // p1 = scale * var_inv + // p2 = (db * scale * mean - ds * scale) * pow(var_inv, 3) * (1/n) + // p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n); + Tensor p1, p2, p3; + p1.mutable_data({x_dims[0] * C}, ctx.GetPlace()); + p2.mutable_data({x_dims[0], groups}, ctx.GetPlace()); + p3.mutable_data({x_dims[0], groups}, ctx.GetPlace()); + T* p1_data = p1.data(); + T* p2_data = p2.data(); + T* p3_data = p3.data(); + + GetBackwardParamsCUDAKernel<<< + dim3(x_dims[0], groups), block_dims, 0, dev_ctx.stream()>>>( + imsize, groups, group_size, epsilon, mean_data, var_data, + scale_data, ds_data, db_data, p1_data, p2_data, p3_data); + GetXGradientCUDAKernel<<>>( + imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data, + dy_data, d_x_data); + } + + } else { + if (d_scale) { + set_zero(dev_ctx, d_scale, static_cast(0)); + } + if (d_bias) { + set_zero(dev_ctx, d_bias, static_cast(0)); + } + + Tensor temp_var; + temp_var.mutable_data(var->dims(), ctx.GetPlace()); + set_zero(dev_ctx, &temp_var, static_cast(0)); + T* temp_var_data = temp_var.data(); + + Tensor temp_mean; + temp_mean.mutable_data(var->dims(), ctx.GetPlace()); + set_zero(dev_ctx, &temp_mean, static_cast(0)); + T* temp_mean_data = temp_mean.data(); + + int flags = (scale_data != nullptr) * kHasScale + + (bias_data != nullptr) * kHasBias; + UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, y_data, + scale_data, bias_data, dy_data, x_dims[0], C, W, imsize, + groups, group_size, epsilon, temp_mean_data, + temp_var_data, d_scale_data, d_bias_data); + if (d_x_data != nullptr) { + UNROLL_ALL_CASES(flags, GroupNormBackward, y_data, dy_data, scale_data, + bias_data, var_data, temp_mean_data, temp_var_data, + x_dims[0], C, W, imsize, groups, group_size, epsilon, + d_x_data); + } } } }; -- GitLab