Unverified commit 5720537e authored by crystal, committed by GitHub

optimize group_norm op backward (#39944)

* optimize backward

* optimize group_norm backward

* Add vectorized code

* move assignment code

* merge function

* move code

* optimize code

* Modify function name
Parent 9e1f762c
@@ -167,9 +167,11 @@ class GroupNormGradOp : public framework::OperatorWithKernel {
void InferShape(framework::InferShapeContext *ctx) const override {
// check input
OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "GroupNormGrad");
OP_INOUT_CHECK(ctx->HasInput("Y"), "Input", "Y", "GroupNormGrad");
OP_INOUT_CHECK(ctx->HasInput("Variance"), "Input", "Variance",
"GroupNormGrad");
OP_INOUT_CHECK(ctx->HasInput("Mean"), "Input", "Mean", "GroupNormGrad");
OP_INOUT_CHECK(ctx->HasInput(framework::GradVarName("Y")), "Input",
framework::GradVarName("Y"), "GroupNormGrad");
@@ -216,10 +218,12 @@ class GroupNormGradMaker : public framework::SingleGradOpMaker<T> {
void Apply(GradOpPtr<T> op) const override {
op->SetType("group_norm_grad");
op->SetInput("X", this->Input("X"));
op->SetInput("Scale", this->Input("Scale"));
op->SetInput("Bias", this->Input("Bias"));
op->SetInput(framework::GradVarName("Y"), this->OutputGrad("Y"));
op->SetInput("Y", this->Output("Y"));
op->SetInput("Mean", this->Output("Mean"));
op->SetInput("Variance", this->Output("Variance"));
op->SetOutput(framework::GradVarName("X"), this->InputGrad("X"));
@@ -81,46 +81,74 @@ __global__ void GroupNormForwardGetMeanAndVar(const T* x, int N, int C, int W,
CudaAtomicAddWithWarp(&var[bid * groups + gid], x_var);
}
template <typename T, typename AccT, int VecSize>
__device__ __forceinline__ void ThreadReduce(const T* input, int size,
const int offset, AccT* mean,
AccT* var) {
template <typename T, typename AccT, int VecSize, int Num>
__device__ __forceinline__ void ThreadReduce(phi::Array<const T*, Num> arrs,
int size, const int offset,
AccT* out_mean, AccT* out_var) {
const T* x = arrs[0];
const T* y;
if (Num == 2) {
y = arrs[1];
}
using VecT = kps::details::VectorType<T, VecSize>;
int tid = threadIdx.x;
if (offset > 0) {
input -= offset;
x -= offset;
if (Num == 2) {
y -= offset;
}
size += offset;
if (tid >= offset) {
AccT temp = input[tid];
*mean += temp;
*var += temp * temp;
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
size -= blockDim.x;
input += blockDim.x;
x += blockDim.x;
if (Num == 2) {
y += blockDim.x;
}
}
int remain = size % (VecSize * blockDim.x);
T ins[VecSize];
VecT* ins_vec = reinterpret_cast<VecT*>(&ins);
T ins_x[VecSize];
T ins_y[VecSize];
VecT* ins_vec_x = reinterpret_cast<VecT*>(&ins_x);
VecT* ins_vec_y = reinterpret_cast<VecT*>(&ins_y);
// vector part
for (; VecSize * tid < (size - remain); tid += blockDim.x) {
*ins_vec = reinterpret_cast<const VecT*>(input)[tid];
*ins_vec_x = reinterpret_cast<const VecT*>(x)[tid];
if (Num == 2) {
*ins_vec_y = reinterpret_cast<const VecT*>(y)[tid];
}
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
AccT temp = ins[i];
*mean += temp;
*var += temp * temp;
if (Num == 1) {
*out_mean += ins_x[i];
*out_var += ins_x[i] * ins_x[i];
} else if (Num == 2) {
*out_mean += ins_y[i];
*out_var += ins_y[i] * ins_x[i];
}
}
}
// scalar part
tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) {
AccT temp = input[tid];
*mean += temp;
*var += temp * temp;
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
}
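For intuition, here is a minimal, hypothetical sketch (not part of this patch) of the vectorized-load plus block-reduce pattern that ThreadReduce generalizes to one or two inputs. It assumes a float input whose length is a multiple of 4, a 16-byte-aligned pointer, and a power-of-two blockDim.x of at most 256; the real ThreadReduce additionally handles the misaligned prefix and the scalar tail shown above.

__global__ void SumAndSquareSumSketch(const float* __restrict__ in, int size,
                                      float* out_sum, float* out_sqsum) {
  float local_sum = 0.f;
  float local_sqsum = 0.f;
  // Vectorized part: each thread loads 4 consecutive floats per iteration.
  const float4* in4 = reinterpret_cast<const float4*>(in);
  for (int i = threadIdx.x; i < size / 4; i += blockDim.x) {
    float4 v = in4[i];
    local_sum += v.x + v.y + v.z + v.w;
    local_sqsum += v.x * v.x + v.y * v.y + v.z * v.z + v.w * v.w;
  }
  // Block-wide tree reduction in shared memory (the patch uses
  // kps::details::BlockXReduce for this step).
  __shared__ float s_sum[256];
  __shared__ float s_sqsum[256];
  s_sum[threadIdx.x] = local_sum;
  s_sqsum[threadIdx.x] = local_sqsum;
  __syncthreads();
  for (int stride = blockDim.x / 2; stride > 0; stride >>= 1) {
    if (threadIdx.x < stride) {
      s_sum[threadIdx.x] += s_sum[threadIdx.x + stride];
      s_sqsum[threadIdx.x] += s_sqsum[threadIdx.x + stride];
    }
    __syncthreads();
  }
  if (threadIdx.x == 0) {
    *out_sum = s_sum[0];
    *out_sqsum = s_sqsum[0];
  }
}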
@@ -148,7 +176,10 @@ __global__ void VectorizedGetMeanAndVarNCHW(const T* x, T* mean, T* var,
AccT x_var = static_cast<AccT>(0);
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
x += i * size;
ThreadReduce<T, AccT, VecSize>(x, size, input_offset, &x_mean, &x_var);
phi::Array<const T*, 1> ins;
ins[0] = x;
ThreadReduce<T, AccT, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
x_mean = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
x_mean, kps::AddFunctor<AccT>());
x_var = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
@@ -310,10 +341,12 @@ class GroupNormKernel<platform::CUDADeviceContext, T>
};
template <typename T, int flags>
__global__ void GroupNormBackwardGetMeanAndVar(
const T* x, const T* scale, const T* bias, const T* d_y, int N, int C,
int W, int imsize, int groups, int group_size, T epsilon, T* d_mean,
T* d_var, T* d_scale, T* d_bias, const DataLayout data_layout) {
__global__ void GroupNormBackwardGetMeanAndVar(const T* x, const T* scale,
const T* bias, const T* d_y,
int N, int C, int W, int imsize,
int groups, int group_size,
T epsilon, T* d_mean, T* d_var,
T* d_scale, T* d_bias) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
@@ -329,15 +362,11 @@ __global__ void GroupNormBackwardGetMeanAndVar(
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val, dval;
if (data_layout == DataLayout::kNCHW) {
val = x[(bid * C + ccid) * imsize + imid] - x_bias;
dval = d_y[(bid * C + ccid) * imsize + imid];
} else {
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias;
dval = d_y[(bid * H + hid) * W * C + wid * C + ccid];
}
int hid = imid / W;
int wid = imid % W;
val = x[(bid * H + hid) * W * C + wid * C + ccid] - x_bias;
dval = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_var_data += val * dval;
d_mean_data += dval * x_scale;
@@ -357,8 +386,7 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
const T* bias, const T* var, const T* d_mean,
const T* d_var, int N, int C, int W,
int imsize, int groups, int group_size,
T epsilon, T* d_x,
const DataLayout data_layout) {
T epsilon, T* d_x) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
@@ -379,26 +407,142 @@ __global__ void GroupNormBackward(const T* x, const T* d_y, const T* scale,
if (x_scale != 0) x_scale_inv = 1.0 / x_scale;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
if (data_layout == DataLayout::kNCHW) {
T tmp = x[(bid * C + ccid) * imsize + imid];
T v_y = (tmp - x_bias) * x_scale_inv;
T dly = d_y[(bid * C + ccid) * imsize + imid];
d_x[(bid * C + ccid) * imsize + imid] =
x_var_inv *
(dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
} else {
int hid = imid / W;
int wid = imid % W;
T tmp = x[(bid * H + hid) * W * C + wid * C + ccid];
T v_y = (tmp - x_bias) * x_scale_inv;
T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_x[(bid * H + hid) * W * C + wid * C + ccid] =
x_var_inv *
(dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
int hid = imid / W;
int wid = imid % W;
T tmp = x[(bid * H + hid) * W * C + wid * C + ccid];
T v_y = (tmp - x_bias) * x_scale_inv;
T dly = d_y[(bid * H + hid) * W * C + wid * C + ccid];
d_x[(bid * H + hid) * W * C + wid * C + ccid] =
x_var_inv *
(dly * x_scale - number_inv * d_x_var * v_y - number_inv * d_x_mean);
}
}
template <typename T, typename AccT, int VecSize>
__global__ void VectorizedGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
T* ds, T* db) {
int i = blockIdx.x;
AccT ds_sum = static_cast<AccT>(0);
AccT db_sum = static_cast<AccT>(0);
const int input_offset = ((uint64_t)x) % ALIGN_BYTES / sizeof(T);
x += i * imsize;
phi::Array<const T*, 2> ins;
ins[0] = x;
ins[1] = dy;
ThreadReduce<T, AccT, VecSize, 2>(ins, imsize, input_offset, &db_sum,
&ds_sum);
ds_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
ds_sum, kps::AddFunctor<AccT>());
db_sum = kps::details::BlockXReduce<AccT, kps::AddFunctor<AccT>>(
db_sum, kps::AddFunctor<AccT>());
__syncthreads();
if (threadIdx.x == 0) {
ds[i] = ds_sum;
db[i] = db_sum;
}
}
template <typename T>
__global__ void ScalarGetDsDbCUDAKernel(int imsize, const T* x, const T* dy,
T* ds, T* db) {
const int nc = blockIdx.x;
T ds_sum = 0;
T db_sum = 0;
for (int i = threadIdx.x; i < imsize; i += blockDim.x) {
const int index = nc * imsize + i;
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
CudaAtomicAddWithWarp(&ds[nc], ds_sum);
CudaAtomicAddWithWarp(&db[nc], db_sum);
}
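Both GetDsDb kernels accumulate the same two per-(n, c) statistics; a hypothetical CPU reference (not part of the patch, NCHW layout assumed) makes the contract explicit:

// ds[n * C + c] = sum_i dy[n][c][i] * x[n][c][i]
// db[n * C + c] = sum_i dy[n][c][i]
void GetDsDbReference(int N, int C, int imsize, const float* x,
                      const float* dy, float* ds, float* db) {
  for (int nc = 0; nc < N * C; ++nc) {
    float ds_sum = 0.f;
    float db_sum = 0.f;
    for (int i = 0; i < imsize; ++i) {
      ds_sum += dy[nc * imsize + i] * x[nc * imsize + i];
      db_sum += dy[nc * imsize + i];
    }
    ds[nc] = ds_sum;
    db[nc] = db_sum;
  }
}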
template <typename T>
__global__ void GetScaleBiasGradientCUDAKernel(int N, int C, int group,
T epsilon, const T* mean,
const T* var, const T* ds,
const T* db, T* d_scale,
T* d_bias) {
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < C) {
const int G = group;
const int D = C / G;
T sum1 = 0;
T sum2 = 0;
for (int n = 0; n < N; ++n) {
const int nc = n * C + c;
const int ng = n * G + c / D;
sum1 += (d_scale == nullptr)
? T(0)
: ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) *
static_cast<T>(rsqrt(var[ng] + epsilon)));
sum2 += (d_bias == nullptr) ? T(0) : db[nc];
}
if (d_scale != nullptr) {
d_scale[c] = sum1;
}
if (d_bias != nullptr) {
d_bias[c] = sum2;
}
}
}
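In conventional notation (γ = Scale, β = Bias, μ and σ² the per-(n, g) mean and variance, D = C / G channels per group), the kernel above corresponds to

$$ \frac{\partial L}{\partial \gamma_c} = \sum_{n=0}^{N-1} \frac{ds_{n,c} - db_{n,c}\,\mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \epsilon}}, \qquad \frac{\partial L}{\partial \beta_c} = \sum_{n=0}^{N-1} db_{n,c}, \qquad g = \lfloor c / D \rfloor. $$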
template <typename T, int BlockDim>
__global__ void GetBackwardParamsCUDAKernel(int imsize, int groups,
int group_size, T epsilon,
const T* mean, const T* var,
const T* scale, const T* ds,
const T* db, T* p1, T* p2, T* p3) {
const int n = blockIdx.x;
const int g = blockIdx.y;
const int ng = n * groups + g;
T sum1 = 0;
T sum2 = 0;
T var_inv = rsqrt(var[ng] + epsilon);
for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
const int64_t index = ng * group_size + i;
const int64_t c = g * group_size + i;
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
sum1 += ds[index] * scale_v;
sum2 += db[index] * scale_v;
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]);
p1[index] = scale_c * var_inv;
}
typedef cub::BlockReduce<T, BlockDim> BlockReduce;
__shared__ typename BlockReduce::TempStorage ds_storage;
__shared__ typename BlockReduce::TempStorage db_storage;
sum1 = BlockReduce(ds_storage).Reduce(sum1, cub::Sum());
sum2 = BlockReduce(db_storage).Reduce(sum2, cub::Sum());
if (threadIdx.x == 0) {
const T s = T(1) / static_cast<T>(group_size * imsize);
const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) *
static_cast<T>(var_inv) * static_cast<T>(var_inv) *
static_cast<T>(var_inv) * s;
p2[ng] = x;
p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s;
}
}
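Written out, the parameters produced here (and consumed by GetXGradientCUDAKernel below as well as in the comment block further down) correspond to, with m = group_size × imsize:

$$ p_1[n,c] = \frac{\gamma_c}{\sqrt{\sigma^2_{n,g} + \epsilon}}, \qquad p_2[n,g] = \frac{\mu_{n,g}\sum_{c \in g} db_{n,c}\,\gamma_c - \sum_{c \in g} ds_{n,c}\,\gamma_c}{m\,(\sigma^2_{n,g} + \epsilon)^{3/2}}, \qquad p_3[n,g] = -p_2[n,g]\,\mu_{n,g} - \frac{\sum_{c \in g} db_{n,c}\,\gamma_c}{m\,\sqrt{\sigma^2_{n,g} + \epsilon}}, $$

so that d_x = p1 · dy + p2 · x + p3. When Scale is absent, the code substitutes γ_c = 1 inside the sums and γ_c = 0 in p1.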
template <typename T>
__global__ void GetXGradientCUDAKernel(int imsize, int C, int group_size,
int groups, T* p1, T* p2, T* p3,
const T* x, const T* dy, T* dx) {
int cid = blockIdx.x;
int gid = blockIdx.y;
int bid = blockIdx.z;
int ccid = bid * C + gid * group_size + cid;
int ng = bid * groups + gid;
int nc = gid * group_size + cid;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
int index = (bid * C + nc) * imsize + imid;
dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng];
}
}
template <typename T>
class GroupNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
@@ -408,7 +552,9 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
const DataLayout data_layout =
framework::StringToDataLayout(data_layout_str);
const float epsilon = ctx.Attr<float>("epsilon");
auto* x = ctx.Input<Tensor>("Y");
auto* x = ctx.Input<Tensor>("X");
auto* y = ctx.Input<Tensor>("Y");
auto* mean = ctx.Input<Tensor>("Mean");
auto* var = ctx.Input<Tensor>("Variance");
auto* scale = ctx.Input<Tensor>("Scale");
auto* bias = ctx.Input<Tensor>("Bias");
@@ -433,31 +579,27 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
phi::funcs::SetConstant<platform::CUDADeviceContext, T> set_zero;
auto& dev_ctx = ctx.template device_context<platform::CUDADeviceContext>();
Tensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
T* temp_var_data = temp_var.data<T>();
Tensor temp_mean;
temp_mean.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
T* temp_mean_data = temp_mean.data<T>();
Tensor ds, db;
ds.mutable_data<T>({x_dims[0], C}, ctx.GetPlace());
db.mutable_data<T>({x_dims[0], C}, ctx.GetPlace());
T* ds_data = ds.data<T>();
T* db_data = db.data<T>();
auto* y_data = y->data<T>();
auto* x_data = x->data<T>();
T* d_x_data = nullptr;
if (d_x) d_x_data = d_x->data<T>();
auto* y_data = d_y->data<T>();
auto* dy_data = d_y->data<T>();
auto* var_data = var->data<T>();
auto* mean_data = mean->data<T>();
T* d_scale_data = nullptr;
if (d_scale) {
d_scale->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_scale, static_cast<T>(0));
d_scale_data = d_scale->data<T>();
}
T* d_bias_data = nullptr;
if (d_bias) {
d_bias->mutable_data<T>(ctx.GetPlace());
set_zero(dev_ctx, d_bias, static_cast<T>(0));
d_bias_data = d_bias->data<T>();
}
@@ -479,22 +621,103 @@ class GroupNormGradKernel<platform::CUDADeviceContext, T>
#ifdef __HIPCC__
int block_size = std::max(std::min(256, imsize), 64);
const int block_dims = 256;
#else
int block_size = std::min(1024, imsize);
const int block_dims = 1024;
#endif
dim3 grid(group_size, groups, x_dims[0]);
dim3 threads(block_size, 1, 1);
int flags =
(scale_data != nullptr) * kHasScale + (bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, x_data, scale_data,
bias_data, y_data, x_dims[0], C, W, imsize, groups,
group_size, epsilon, temp_mean_data, temp_var_data,
d_scale_data, d_bias_data, data_layout);
if (d_x_data != nullptr) {
UNROLL_ALL_CASES(flags, GroupNormBackward, x_data, y_data, scale_data,
bias_data, var_data, temp_mean_data, temp_var_data,
x_dims[0], C, W, imsize, groups, group_size, epsilon,
d_x_data, data_layout);
if (data_layout == DataLayout::kNCHW) {
using AccT = typename details::MPTypeTrait<T>::Type;
constexpr int vec_size = sizeof(float4) / sizeof(T);
const int max_num_threads = 1024;
int max_block_size = std::min(imsize / vec_size, max_num_threads);
int block_size_nchw = 1;
while (block_size_nchw < max_block_size) {
block_size_nchw *= 2;
}
block_size_nchw = std::max(block_size_nchw, kps::details::kWarpSize);
dim3 blocks(block_size_nchw);
if (imsize < vec_size) {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));
}
if (d_bias) {
set_zero(dev_ctx, d_bias, static_cast<T>(0));
}
ScalarGetDsDbCUDAKernel<
T><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
} else {
VectorizedGetDsDbCUDAKernel<
T, AccT, vec_size><<<x_dims[0] * C, blocks, 0, dev_ctx.stream()>>>(
imsize, x_data, dy_data, ds_data, db_data);
}
if (d_scale || d_bias) {
const int block = 256;
GetScaleBiasGradientCUDAKernel<
T><<<(C + block - 1) / block, block, 0, dev_ctx.stream()>>>(
x_dims[0], C, groups, epsilon, mean_data, var_data, ds_data,
db_data, d_scale_data, d_bias_data);
}
if (d_x_data != nullptr) {
// d_x = p1 * dy + p2 * x + p3,
// where p1, p2 and p3 are temporary variables of the backward computation:
// p1 = scale * var_inv
// p2 = (db * scale * mean - ds * scale) * pow(var_inv, 3) * (1/n)
// p3 = -p2 * mean[ng] - db * scale * var_inv * (1/n), with n = group_size * imsize
Tensor p1, p2, p3;
p1.mutable_data<T>({x_dims[0] * C}, ctx.GetPlace());
p2.mutable_data<T>({x_dims[0], groups}, ctx.GetPlace());
p3.mutable_data<T>({x_dims[0], groups}, ctx.GetPlace());
T* p1_data = p1.data<T>();
T* p2_data = p2.data<T>();
T* p3_data = p3.data<T>();
GetBackwardParamsCUDAKernel<T, block_dims><<<
dim3(x_dims[0], groups), block_dims, 0, dev_ctx.stream()>>>(
imsize, groups, group_size, epsilon, mean_data, var_data,
scale_data, ds_data, db_data, p1_data, p2_data, p3_data);
GetXGradientCUDAKernel<T><<<grid, threads, 0, dev_ctx.stream()>>>(
imsize, C, group_size, groups, p1_data, p2_data, p3_data, x_data,
dy_data, d_x_data);
}
} else {
if (d_scale) {
set_zero(dev_ctx, d_scale, static_cast<T>(0));
}
if (d_bias) {
set_zero(dev_ctx, d_bias, static_cast<T>(0));
}
Tensor temp_var;
temp_var.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_var, static_cast<T>(0));
T* temp_var_data = temp_var.data<T>();
Tensor temp_mean;
temp_mean.mutable_data<T>(var->dims(), ctx.GetPlace());
set_zero(dev_ctx, &temp_mean, static_cast<T>(0));
T* temp_mean_data = temp_mean.data<T>();
int flags = (scale_data != nullptr) * kHasScale +
(bias_data != nullptr) * kHasBias;
UNROLL_ALL_CASES(flags, GroupNormBackwardGetMeanAndVar, y_data,
scale_data, bias_data, dy_data, x_dims[0], C, W, imsize,
groups, group_size, epsilon, temp_mean_data,
temp_var_data, d_scale_data, d_bias_data);
if (d_x_data != nullptr) {
UNROLL_ALL_CASES(flags, GroupNormBackward, y_data, dy_data, scale_data,
bias_data, var_data, temp_mean_data, temp_var_data,
x_dims[0], C, W, imsize, groups, group_size, epsilon,
d_x_data);
}
}
}
};