diff --git a/ci/compatibility/fbs/V2-backup/opr_param_defs.fbs b/ci/compatibility/fbs/V2-backup/opr_param_defs.fbs index 1ef3ac1a14034971f88bdec9d8d4dad5de5b5ff6..75ed1971585079ddea1920cdce3818aad04e1d5b 100644 --- a/ci/compatibility/fbs/V2-backup/opr_param_defs.fbs +++ b/ci/compatibility/fbs/V2-backup/opr_param_defs.fbs @@ -1869,6 +1869,13 @@ table LayerNorm { normalized_size:ulong = 1; } +table GroupNorm { + affine:bool = true; + eps:float = 1e-5; + group:uint = 1; + format:ConvolutionFormat = NCHW; +} + table Dropout { drop_prob:float = 0; seed:ulong = 0; diff --git a/ci/compatibility/fbs/V2-backup/schema_v2.fbs b/ci/compatibility/fbs/V2-backup/schema_v2.fbs index 7bbb847e8480590e5efbddd0c39dabb03acfecda..f6f0732c652b75e36f88fd2af0475dace0a21de0 100644 --- a/ci/compatibility/fbs/V2-backup/schema_v2.fbs +++ b/ci/compatibility/fbs/V2-backup/schema_v2.fbs @@ -140,6 +140,7 @@ union OperatorParam { param.LSTM = 89, param.Softmax = 90, param.Diag = 91, + param.GroupNorm = 92, } table Operator { diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index 4468a0cc2ae724d93ffa62452a1b37c0441ad534..5d7a6ac5998fcd8bed318255fbd5b3e3839f0a47 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -2430,6 +2430,76 @@ protected: const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw, size_t workspace_in_bytes); }; + +class GroupNormBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(GroupNormBase, OperatorBase); + DEF_OPR_PARAM(GroupNorm); + +protected: + void deduce_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, + TensorLayout& rstd); + void check_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, + const TensorLayout& rstd); +}; + +class GroupNormForward : public GroupNormBase { + DEF_OPR_IMPL(GroupNormForward, GroupNormBase, 3, 3); + +public: + virtual void exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) = 0; + MGE_WIN_DECLSPEC_FUC void deduce_layout( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, + TensorLayout& rstd); + virtual size_t get_workspace_in_bytes( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, + const TensorLayout& rstd) = 0; + +protected: + void check_exec( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, + const TensorLayout& rstd, size_t workspace_in_bytes); +}; +using GroupNorm = GroupNormForward; + +class GroupNormBackward : public GroupNormBase { + DEF_OPR_IMPL(GroupNormBackward, GroupNormBase, 5, 3); + +public: + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& diff, const TensorLayout& data, + const TensorLayout& weight, const TensorLayout& mean, + const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight, + TensorLayout& dbias); + virtual size_t get_workspace_in_bytes( + const TensorLayout& 
diff, const TensorLayout& data, + const TensorLayout& weight, const TensorLayout& mean, + const TensorLayout& rstd, const TensorLayout& ddata, + const TensorLayout& dweight, const TensorLayout& dbias) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& data, + const TensorLayout& weight, const TensorLayout& mean, + const TensorLayout& rstd, const TensorLayout& ddata, + const TensorLayout& dweight, const TensorLayout& dbias, + size_t workspace_in_bytes); +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 1c7dd1938f6cdec17e2c655e5456777a390a22d7..3ab046e458b41dc9a6deaf85147ebadc977e4f2a 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1247,6 +1247,13 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), .add_fields('uint64', 'normalized_size', '1') ) +(pdef('GroupNorm') + .add_fields('bool', 'affine', 'true') + .add_fields('float32', 'eps', '1e-5f') + .add_fields('uint32', 'group', '1') + .add_enum_alias('Format', 'Convolution') +) + (pdef('Dropout') .add_fields('float32', 'drop_prob', '0') .add_fields('uint64', 'seed', '0') diff --git a/dnn/src/common/group_norm.cpp b/dnn/src/common/group_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b4d391da85235b9aaa318231e6d20299655fb476 --- /dev/null +++ b/dnn/src/common/group_norm.cpp @@ -0,0 +1,121 @@ +#include "megdnn/oprs.h" + +#include "src/common/utils.h" + +namespace megdnn { + +using Param = GroupNormBase::Param; + +void GroupNormBase::deduce_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { + MEGDNN_MARK_USED_VAR(weight); + MEGDNN_MARK_USED_VAR(bias); + size_t N = data.shape[0]; + size_t group = param().group; + TensorLayout unnormalized_layout({N, group}, dtype::Float32()); + dst = data; + mean = unnormalized_layout; + rstd = unnormalized_layout; +} + +void GroupNormBase::check_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { + megdnn_assert_contiguous(data); + megdnn_assert_contiguous(weight); + megdnn_assert_contiguous(bias); + megdnn_assert_contiguous(dst); + megdnn_assert_contiguous(mean); + megdnn_assert_contiguous(rstd); + auto errmsg = [&]() { + return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " + + megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " + + megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd); + }; + MEGDNN_MARK_USED_VAR(errmsg); + + megdnn_assert(data.eq_layout(dst), "%s", errmsg().c_str()); + megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str()); + megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); + + auto p = param(); + size_t C = data.shape[1]; + size_t group = p.group; + megdnn_assert( + group > 0, "Expected num groups to be greater than 0, got %zu", group); + megdnn_assert( + C % group == 0, + "Expected number of channels in input to be divisible by num_groups, but " + "got Channel of shape %zu and num_groups= %zu", + C, group); +} + +void GroupNormForward::deduce_layout( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { + deduce_layout_fwd(data, weight, bias, dst, mean, rstd); +} + +void GroupNormForward::check_exec( + 
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd, + size_t workspace_in_bytes) { + check_layout_fwd(data, weight, bias, dst, mean, rstd); + auto required_workspace_in_bytes = + get_workspace_in_bytes(data, weight, bias, dst, mean, rstd); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void GroupNormBackward::deduce_layout( + const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata, + TensorLayout& dweight, TensorLayout& dbias) { + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(mean); + MEGDNN_MARK_USED_VAR(rstd); + ddata = data; + dweight = weight; + dbias = weight; +} + +void GroupNormBackward::check_exec( + const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata, + const TensorLayout& dweight, const TensorLayout& dbias, + size_t workspace_in_bytes) { + auto p = param(); + auto required_workspace_in_bytes = get_workspace_in_bytes( + diff, data, weight, mean, rstd, ddata, dweight, dbias); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + + megdnn_assert_contiguous(diff); + megdnn_assert_contiguous(data); + megdnn_assert_contiguous(mean); + megdnn_assert_contiguous(rstd); + megdnn_assert_contiguous(ddata); + if (p.affine) { + megdnn_assert_contiguous(weight); + megdnn_assert_contiguous(dweight); + megdnn_assert_contiguous(dbias); + } + + auto errmsg = [&]() { + return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " + + megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " + + megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " + + megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias); + }; + MEGDNN_MARK_USED_VAR(errmsg); + + megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str()); + megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str()); + if (p.affine) { + megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str()); + megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str()); + } +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 375d36cafc2242676c50b1e0c0cec36873ace438..6c5a04da31039b0bba62a30f543990ed74529d21 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -216,7 +216,9 @@ private: cb(NormForward) \ cb(RegionRestrictedConvolutionForward) \ cb(RegionRestrictedConvolutionBackwardData) \ - cb(RegionRestrictedConvolutionBackwardFilter) + cb(RegionRestrictedConvolutionBackwardFilter) \ + cb(GroupNormForward) \ + cb(GroupNormBackward) // clang-format on /*! 
diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 87e7708dc9e26d8a5978237e268db9d4d8a2ec4f..39a28089f172c4d495a2d73f3e20942f6f315e22 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -142,6 +142,8 @@ DEF(SoftmaxBackward, 3, true, false); DEF(RegionRestrictedConvolutionForward, 5, true, true); DEF(RegionRestrictedConvolutionBackwardData, 5, true, false); DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false); +DEF(GroupNormForward, 6, true, true); +DEF(GroupNormBackward, 8, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/group_norm/group_norm_cuda.cu b/dnn/src/cuda/group_norm/group_norm_cuda.cu new file mode 100644 index 0000000000000000000000000000000000000000..8240218d6ad71b8c0508554e74a349166bbc1361 --- /dev/null +++ b/dnn/src/cuda/group_norm/group_norm_cuda.cu @@ -0,0 +1,529 @@ +#include +#include +#include +#include +#include "megdnn/arch.h" +#include "megdnn/basic_types.h" +#include "megdnn/dtype.h" +#include "src/cuda/cuda_shfl_compat.cuh" +#include "src/cuda/group_norm/group_norm_cuda.cuh" +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { +namespace group_norm { + +// warp size may be used as array length, or used in host function, +// so we define WARP_SIZE rather than using warpSize +#define WARP_SIZE 32 + +template +struct Compare { + template + __host__ __device__ inline static bool Run(const T* d1, const T* d2) { + return d1[kStart] == d2[kStart] && + Compare::Run(d1, d2); + } +}; + +template +struct Compare { + template + __host__ __device__ inline constexpr static bool Run(const T* d1, const T* d2) { + return true; + } +}; + +template +using UnrollCompare = Compare<0, N, N == 0>; + +template +struct UnrollVarArgsAssignImpl { + template + __host__ __device__ inline static void Run(T* d, T val, Args... args) { + static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong argument"); + d[kStart] = val; + UnrollVarArgsAssignImpl::Run( + d, args...); + } +}; + +template +struct UnrollVarArgsAssignImpl { + __host__ __device__ inline static void Run(T* d) {} +}; + +template +struct UnrollVarArgsAssign { + template + __host__ __device__ inline static void Run(T* d, Args... args) { + UnrollVarArgsAssignImpl::Run( + d, args...); + } +}; + +template +class Array { +public: + static constexpr size_t kSize = N; + + __host__ __device__ inline Array() {} + + template + __host__ __device__ inline explicit Array(const T& val, Args... 
args) { + static_assert(N == sizeof...(Args) + 1, "Invalid argument"); + UnrollVarArgsAssign::Run(data_, val, args...); + } + + __host__ __device__ inline T& operator[](size_t i) { return *(data_ + i); } + + __host__ __device__ inline const T& operator[](size_t i) const { + return *(data_ + i); + } + +private: + template + __host__ __device__ static inline U* advance(U* ptr, size_t i) { + return ptr + i; + } + + T data_[N]; +}; + +// ================================ group_norm forward =========================== + +// implementation of groupnorm_forward from +// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_kernel.cu#L115 + +template +__forceinline__ __device__ T +CudaShuffleDownSync(T val, int delta, int width = warpSize) { + return __shfl_down(val, static_cast(delta), width); +} + +template <> +__forceinline__ __device__ dt_float16 +CudaShuffleDownSync(dt_float16 val, int delta, int width) { + return dt_float16(__shfl_down(val, static_cast(delta), width)); +} + +template <> +__forceinline__ __device__ dt_bfloat16 +CudaShuffleDownSync(dt_bfloat16 val, int delta, int width) { + return dt_bfloat16(__shfl_down(val, static_cast(delta), width)); +} + +template +struct alignas(sizeof(T) * VecSize) VectorType { + T val[VecSize]; +}; + +template +struct AddFunctor { + inline T initial() { return static_cast(0.0f); } + + __device__ __forceinline__ T operator()(const T a, const T b) const { + return b + a; + } +}; + +template +__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) { + for (int stride = WARP_SIZE / 2; stride > 0; stride >>= 1) { + T temp = CudaShuffleDownSync(val, stride); + val = reducer(val, temp); + } + return val; +} + +template +__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) { + __syncthreads(); + __shared__ T shared[64]; + int block_dim_x = blockDim.x; + if (blockDim.x > WARP_SIZE) { + block_dim_x = blockDim.x / WARP_SIZE; + int lane = threadIdx.x % WARP_SIZE; + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int wid = tid / WARP_SIZE; + int bid = threadIdx.y; + val = WarpReduce(val, reducer); + if (lane == 0) { + shared[wid] = val; + } + __syncthreads(); + val = shared[bid * block_dim_x + lane]; + } + + for (int stride = 1; stride < block_dim_x; stride <<= 1) { + T temp = CudaShuffleDownSync(val, stride); + val = reducer(val, temp); + } + if (threadIdx.x == 0) { + shared[threadIdx.y] = val; + } + __syncthreads(); + return shared[threadIdx.y]; +} + +template +__device__ __forceinline__ void ReduceMeanAndVar( + T* mean, T* var, T x_mean, T x_var, int size) { + const int nc = blockIdx.x; + x_mean = BlockXReduce>(x_mean, AddFunctor()); + x_var = BlockXReduce>(x_var, AddFunctor()); + __syncthreads(); + if (threadIdx.x == 0) { + mean[nc] = static_cast(x_mean / size); + var[nc] = static_cast(x_var / size); + } +} + +template +__device__ __forceinline__ void ThreadReduce( + Array arrs, int size, const int offset, T_ACC* out_mean, + T_ACC* out_var) { + const T* x = arrs[0]; + const T* y; + if (Num == 2) { + y = arrs[1]; + } + using VecT = VectorType; + int tid = threadIdx.x; + if (offset > 0) { + x -= offset; + if (Num == 2) { + y -= offset; + } + size += offset; + if (tid >= offset) { + if (Num == 1) { + *out_mean += x[tid]; + *out_var += x[tid] * x[tid]; + } else if (Num == 2) { + *out_mean += y[tid]; + *out_var += y[tid] * x[tid]; + } + } + size -= blockDim.x; + x += blockDim.x; + if (Num == 2) { + y += blockDim.x; + } + } + int remain = size % (VecSize * blockDim.x); + + T ins_x[VecSize]; + T ins_y[VecSize]; 
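+    // annotation (inferred from the surrounding code, not part of the original
+    // patch): ins_x / ins_y are per-thread scratch buffers; the VecT casts below
+    // let each loop iteration load VecSize elements with a single float4-sized
+    // access, and the scalar tail loop afterwards handles the remainder.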
+ VecT* ins_vec_x = reinterpret_cast(&ins_x); + VecT* ins_vec_y = reinterpret_cast(&ins_y); + + // vector part + for (; VecSize * tid < (size - remain); tid += blockDim.x) { + *ins_vec_x = reinterpret_cast(x)[tid]; + if (Num == 2) { + *ins_vec_y = reinterpret_cast(y)[tid]; + } + +#pragma unroll + for (int i = 0; i < VecSize; ++i) { + if (Num == 1) { + *out_mean += ins_x[i]; + *out_var += ins_x[i] * ins_x[i]; + } else if (Num == 2) { + *out_mean += ins_y[i]; + *out_var += ins_y[i] * ins_x[i]; + } + } + } + + // scalar part + tid = size - remain + threadIdx.x; + for (; tid < size; tid += blockDim.x) { + if (Num == 1) { + *out_mean += x[tid]; + *out_var += x[tid] * x[tid]; + } else if (Num == 2) { + *out_mean += y[tid]; + *out_var += y[tid] * x[tid]; + } + } +} + +template +__global__ void ScalarGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) { + int i = blockIdx.x; + T_ACC x_mean = static_cast(0); + T_ACC x_var = static_cast(0); + for (int j = threadIdx.x; j < size; j += blockDim.x) { + T val; + val = x[i * size + j]; + x_mean += val; + x_var += val * val; + } + ReduceMeanAndVar(mean, var, x_mean, x_var, size); +} + +template +__global__ void VectorizedGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) { + int i = blockIdx.x; + T_ACC x_mean = static_cast(0); + T_ACC x_var = static_cast(0); + x += i * size; + const int input_offset = ((uint64_t)x) % 16 / sizeof(T); + Array ins; + ins[0] = x; + ThreadReduce(ins, size, input_offset, &x_mean, &x_var); + ReduceMeanAndVar(mean, var, x_mean, x_var, size); +} + +template +__global__ void GroupNormForward( + const T* x, const T_ACC* mean, const T_ACC* var, const T* scale, const T* bias, + int N, int C, int W, int imsize, int groups, int group_size, float epsilon, + T* y, T_ACC* real_var) { + int gid = blockIdx.y; + int cid = blockIdx.x; + int bid = blockIdx.z; + int ccid = gid * group_size + cid; + if (ccid >= C) + return; + auto ng = bid * groups + gid; + T_ACC x_mean = mean[ng]; + T_ACC x_var = var[ng]; + x_var = x_var - x_mean * x_mean; + T_ACC var_inv = rsqrt(x_var + epsilon); + if (cid == 0 && threadIdx.x == 0) { + real_var[ng] = x_var; + } + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + T val; + int index = (bid * C + ccid) * imsize + imid; + val = x[index]; + val = (val - x_mean) * var_inv; + if (scale != nullptr) { + val *= scale[ccid]; + } + if (bias != nullptr) { + val += bias[ccid]; + } + y[index] = val; + } +} + +template +void forward( + T* src, T* weight, T* bias, T* dst, T_ACC* mean, T_ACC* rstd, T_ACC* temp_rstd, + T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream) { + auto group_size = C / group; + int block_size = std::min(1024, imsize); + dim3 grid(group_size, group, N); + dim3 threads(block_size, 1, 1); + int size = group_size * imsize; + constexpr int vec_size = sizeof(float4) / sizeof(T); + int max_block_size = std::min(size / vec_size, 1024); + int block_size_temp = 1; + while (block_size_temp < max_block_size) { + block_size_temp *= 2; + } + block_size_temp = std::max(block_size_temp, WARP_SIZE); + dim3 grids(N * group); + dim3 blocks(block_size_temp); + if (size < vec_size * block_size_temp) { + ScalarGetMeanAndVar + <<>>(src, mean, temp_rstd, size); + after_kernel_launch(); + } else { + VectorizedGetMeanAndVar + <<>>(src, mean, temp_rstd, size); + after_kernel_launch(); + } + GroupNormForward<<>>( + src, mean, temp_rstd, weight, bias, N, C, W, imsize, group, group_size, eps, + dst, rstd); + after_kernel_launch(); +} + +// ================================ 
group_norm backward =========================== + +// implementation of groupnorm_backward from +// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu#L253 + +template +__global__ void GetDsDbCUDAKernel(int imsize, const T* x, const T* dy, T* ds, T* db) { + const int nc = blockIdx.x; + T ds_sum = static_cast(0); + T db_sum = static_cast(0); + for (int i = threadIdx.x; i < imsize; i += blockDim.x) { + const int index = nc * imsize + i; + ds_sum += dy[index] * x[index]; + db_sum += dy[index]; + } + ReduceMeanAndVar(db, ds, db_sum, ds_sum, 1); +} + +template +__global__ void GetBiasGradientCUDAKernel( + int N, int C, int group, T_ACC epsilon, const T_ACC* mean, const T_ACC* var, + const T* ds, const T* db, T* d_scale, T* d_bias) { + const int c = blockIdx.x * blockDim.x + threadIdx.x; + if (c < C) { + const int G = group; + const int D = C / G; + T sum1 = static_cast(0); + T sum2 = static_cast(0); + for (int n = 0; n < N; ++n) { + const int nc = n * C + c; + const int ng = n * G + c / D; + sum1 += (d_scale == nullptr) + ? T(0) + : ((ds[nc] - db[nc] * static_cast(mean[ng])) * + static_cast(rsqrt((float)(var[ng] + epsilon)))); + sum2 += (d_bias == nullptr) ? T(0) : db[nc]; + } + if (d_scale != nullptr) { + d_scale[c] = sum1; + } + if (d_bias != nullptr) { + d_bias[c] = sum2; + } + } +} + +template +__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { +#pragma unroll + for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { + val += __shfl_down(val, offset, warpSize); + } + return val; +} + +template +__inline__ MEGDNN_DEVICE T BlockReduceSum(T val, T* shared) { + const int lid = threadIdx.x % warpSize; + const int wid = threadIdx.x / warpSize; + val = warp_reduce_sum(val); + __syncthreads(); + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); + if (wid == 0) { + val = warp_reduce_sum(val); + } + return val; +} + +template +__global__ void GetBackwardParamsCUDAKernel( + int imsize, int groups, int group_size, T_ACC epsilon, const T_ACC* mean, + const T_ACC* var, const T* scale, const T* ds, const T* db, T* p1, T* p2, + T* p3) { + const int n = blockIdx.x; + const int g = blockIdx.y; + const int ng = n * groups + g; + T sum1 = static_cast(0); + T sum2 = static_cast(0); + T var_inv = static_cast(rsqrt(var[ng] + epsilon)); + for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) { + const int64_t index = ng * group_size + i; + const int64_t c = g * group_size + i; + const T scale_v = scale == nullptr ? T(1) : static_cast(scale[c]); + sum1 += ds[index] * scale_v; + sum2 += db[index] * scale_v; + const T scale_c = scale == nullptr ? 
T(0) : static_cast(scale[c]); + p1[index] = scale_c * var_inv; + } + + __shared__ T ds_shared[WARP_SIZE]; + __shared__ T db_shared[WARP_SIZE]; + sum1 = BlockReduceSum(sum1, ds_shared); + sum2 = BlockReduceSum(sum2, db_shared); + + if (threadIdx.x == 0) { + const T s = T(1) / static_cast(group_size * imsize); + const T x = (sum2 * static_cast(mean[ng]) - sum1) * static_cast(var_inv) * + static_cast(var_inv) * static_cast(var_inv) * s; + p2[ng] = x; + p3[ng] = -x * static_cast(mean[ng]) - sum2 * static_cast(var_inv) * s; + } +} + +template +__global__ void GetXGradientCUDAKernel( + int imsize, int C, int group_size, int groups, T* p1, T* p2, T* p3, const T* x, + const T* dy, T* dx) { + int cid = blockIdx.x; + int gid = blockIdx.y; + int bid = blockIdx.z; + int ccid = bid * C + gid * group_size + cid; + int ng = bid * groups + gid; + int nc = gid * group_size + cid; + for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) { + int index = (bid * C + nc) * imsize + imid; + dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng]; + } +} + +template +void backward( + const T* dY_data, const T* X_data, const T_ACC* mean_data, + const T_ACC* rstd_data, const T* weight_data, T* dX_data, T* dweight_data, + T* dbias_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db, + T* p1, T* p2, T* p3, cudaStream_t stream) { + auto group_size = C / group; + int block_size = std::min(1024, imsize); + const int block_dims = 1024; + dim3 grid(group_size, group, N); + dim3 threads(block_size, 1, 1); + const int max_num_threads = 1024; + int max_block_size = std::min(imsize, max_num_threads); + int block_size_temp = 1; + while (block_size_temp < max_block_size) { + block_size_temp *= 2; + } + block_size_temp = std::max(block_size_temp, WARP_SIZE); + dim3 blocks(block_size_temp); + GetDsDbCUDAKernel + <<>>(imsize, X_data, dY_data, ds, db); + after_kernel_launch(); + bool flag = weight_data != nullptr ? 
true : false; + if (flag) { + const int block = 256; + GetBiasGradientCUDAKernel + <<<(C + block - 1) / block, block, 0, stream>>>( + N, C, group, eps, mean_data, rstd_data, ds, db, dweight_data, + dbias_data); + after_kernel_launch(); + } + + GetBackwardParamsCUDAKernel + <<>>( + imsize, group, group_size, eps, mean_data, rstd_data, weight_data, + ds, db, p1, p2, p3); + after_kernel_launch(); + GetXGradientCUDAKernel<<>>( + imsize, C, group_size, group, p1, p2, p3, X_data, dY_data, dX_data); + after_kernel_launch(); +} + +#define INST(T, T_ACC) \ + template void forward( \ + T*, T*, T*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC, int, int, int, int, int, \ + cudaStream_t); \ + template void backward( \ + const T*, const T*, const T_ACC*, const T_ACC*, const T*, T*, T*, T*, \ + T_ACC, int, int, int, int, T*, T*, T*, T*, T*, cudaStream_t); + +INST(dt_float32, dt_float32) +INST(dt_float16, dt_float32) +INST(dt_bfloat16, dt_float32) +#undef INST + +} // namespace group_norm +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/group_norm/group_norm_cuda.cuh b/dnn/src/cuda/group_norm/group_norm_cuda.cuh new file mode 100644 index 0000000000000000000000000000000000000000..ce4f999ac0d8a98b66893b82e52bde0d71d58680 --- /dev/null +++ b/dnn/src/cuda/group_norm/group_norm_cuda.cuh @@ -0,0 +1,24 @@ +#pragma once +#include + +namespace megdnn { +namespace cuda { +namespace group_norm { + +template +void forward( + T* X, T* gamma, T* beta, T* Y, T_ACC* mean, T_ACC* rstd, T_ACC* tesmp_rstd, + T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream); + +template +void backward( + const T* dY_data, const T* X_data, const T_ACC* mean_data, + const T_ACC* rstd_data, const T* gamma_data, T* dX_data, T* dgamma_data, + T* dbeta_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db, + T* p1, T* p2, T* p3, cudaStream_t stream); + +} // namespace group_norm +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/group_norm/opr_impl.cpp b/dnn/src/cuda/group_norm/opr_impl.cpp new file mode 100644 index 0000000000000000000000000000000000000000..c10a0e9bbd6762690d4ceceebe7ed4d3f935aa46 --- /dev/null +++ b/dnn/src/cuda/group_norm/opr_impl.cpp @@ -0,0 +1,143 @@ +#include "src/cuda/group_norm/opr_impl.h" +#include "src/cuda/group_norm/group_norm_cuda.cuh" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +size_t GroupNormForwardImpl::get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout& rstd) { + size_t N = rstd.shape[0]; + size_t G = rstd.shape[1]; + return get_workspace_bundle(N, G, rstd.dtype.size()).total_size_in_bytes(); +} + +WorkspaceBundle GroupNormForwardImpl::get_workspace_bundle( + size_t N, size_t G, size_t dtype_size, void* raw_ptr) { + return {raw_ptr, {N * G * dtype_size}, handle()->alignment_requirement()}; +} + +void GroupNormForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) { + check_exec( + data.layout, weight.layout, bias.layout, dst.layout, mean.layout, + rstd.layout, workspace.size); + + auto p = param(); + using Format = param::GroupNorm::Format; + float eps = p.eps; + int group = p.group; + bool affine = p.affine; + auto layout = data.layout; + auto format = p.format; + size_t N, C, H, W, imsize; + if (data.layout.ndim 
== 4 && format == Format::NCHW) {
+        N = layout.shape[0];
+        C = layout.shape[1];
+        H = layout.shape[2];
+        W = layout.shape[3];
+        imsize = H * W;
+    } else {
+        megdnn_throw(ssprintf("Unsupported groupnorm input"));
+    }
+
+    auto stream = cuda_stream(handle());
+    using namespace ::megdnn::cuda::group_norm;
+    auto wbundle =
+            get_workspace_bundle(N, group, rstd.layout.dtype.size(), workspace.raw_ptr);
+
+#define cb(DType) \
+    if (data.layout.dtype == DType()) { \
+        using T = typename DTypeTrait<DType>::ctype; \
+        using T_ACC = float; \
+        T_ACC* temp_rstd = wbundle.get_workspace(0).ptr<T_ACC>(); \
+        forward<T, T_ACC>( \
+                data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \
+                affine ? bias.ptr<T>() : nullptr, dst.ptr<T>(), mean.ptr<T_ACC>(), \
+                rstd.ptr<T_ACC>(), temp_rstd, static_cast<T_ACC>(eps), \
+                static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \
+                static_cast<int>(W), static_cast<int>(imsize), stream); \
+        return; \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw("bad dtype");
+}
+
+size_t GroupNormBackwardImpl::get_workspace_in_bytes(
+        const TensorLayout&, const TensorLayout& data, const TensorLayout&,
+        const TensorLayout& mean, const TensorLayout&, const TensorLayout&,
+        const TensorLayout&, const TensorLayout&) {
+    size_t N = data.shape[0];
+    size_t C = data.shape[1];
+    size_t G = mean.shape[1];
+    return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes();
+}
+
+WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle(
+        size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) {
+    return {raw_ptr,
+            {N * C * dtype_size, N * C * dtype_size, N * C * dtype_size,
+             N * G * dtype_size, N * G * dtype_size},
+            handle()->alignment_requirement()};
+}
+
+void GroupNormBackwardImpl::exec(
+        _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
+        _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
+        _megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
+        _megdnn_workspace workspace) {
+    check_exec(
+            diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
+            ddata.layout, dweight.layout, dbias.layout, workspace.size);
+    auto p = param();
+    using Format = param::GroupNorm::Format;
+    bool affine = p.affine;
+    float eps = p.eps;
+    int group = p.group;
+    auto layout = data.layout;
+    auto format = p.format;
+    size_t N, C, H, W, imsize;
+    if (layout.ndim == 4 && format == Format::NCHW) {
+        N = layout.shape[0];
+        C = layout.shape[1];
+        H = layout.shape[2];
+        W = layout.shape[3];
+        imsize = H * W;
+    } else {
+        megdnn_throw(ssprintf("Unsupported groupnorm input"));
+    }
+
+    auto stream = cuda_stream(handle());
+    using namespace ::megdnn::cuda::group_norm;
+    auto wbundle = get_workspace_bundle(
+            N, C, group, data.layout.dtype.size(), workspace.raw_ptr);
+#define cb(DType) \
+    if (data.layout.dtype == DType()) { \
+        using T = typename DTypeTrait<DType>::ctype; \
+        using T_ACC = float; \
+        T* ds = wbundle.get_workspace(0).ptr<T>(); \
+        T* db = wbundle.get_workspace(1).ptr<T>(); \
+        T* p1 = wbundle.get_workspace(2).ptr<T>(); \
+        T* p2 = wbundle.get_workspace(3).ptr<T>(); \
+        T* p3 = wbundle.get_workspace(4).ptr<T>(); \
+        backward<T, T_ACC>( \
+                diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \
+                affine ? weight.ptr<T>() : nullptr, ddata.ptr<T>(), \
+                affine ? dweight.ptr<T>() : nullptr, \
+                affine ? dbias.ptr<T>() : nullptr, static_cast<T_ACC>(eps), \
+                static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \
+                static_cast<int>(imsize), ds, db, p1, p2, p3, stream); \
+        return; \
+    }
+    MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
+#undef cb
+    megdnn_throw("bad dtype");
+}
+
+} // namespace cuda
+} // namespace megdnn
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/group_norm/opr_impl.h b/dnn/src/cuda/group_norm/opr_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..41cd2c67849e77de919f930961b34d3e81634982
--- /dev/null
+++ b/dnn/src/cuda/group_norm/opr_impl.h
@@ -0,0 +1,47 @@
+#pragma once
+#include "megdnn/oprs.h"
+#include "src/common/utils.h"
+#include "src/cuda/cudnn_wrapper.h"
+
+namespace megdnn {
+namespace cuda {
+
+class GroupNormForwardImpl final : public GroupNormForward {
+public:
+    using GroupNormForward::GroupNormForward;
+    void exec(
+            _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
+            _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout&, const TensorLayout&, const TensorLayout&,
+            const TensorLayout&, const TensorLayout&,
+            const TensorLayout& rstd) override;
+
+private:
+    WorkspaceBundle get_workspace_bundle(
+            size_t N, size_t G, size_t dtype_size, void* raw_ptr = nullptr);
+};
+
+class GroupNormBackwardImpl final : public GroupNormBackward {
+public:
+    using GroupNormBackward::GroupNormBackward;
+    void exec(
+            _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
+            _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
+            _megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
+            _megdnn_workspace workspace) override;
+    size_t get_workspace_in_bytes(
+            const TensorLayout&, const TensorLayout& data, const TensorLayout&,
+            const TensorLayout& mean, const TensorLayout&, const TensorLayout&,
+            const TensorLayout&, const TensorLayout&) override;
+
+private:
+    WorkspaceBundle get_workspace_bundle(
+            size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr);
+};
+
+} // namespace cuda
+} // namespace megdnn
+
+// vim: syntax=cpp.doxygen
diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp
index 7b35f36c6b8007282157e8d6ba03fc2524c0d830..d49a3b9b73ef96ad8a281389791e7f3bc45cc9c0 100644
--- a/dnn/src/cuda/handle_create.cpp
+++ b/dnn/src/cuda/handle_create.cpp
@@ -32,6 +32,7 @@
 #include "src/cuda/flip/opr_impl.h"
 #include "src/cuda/gaussian_blur/opr_impl.h"
 #include "src/cuda/group_local/opr_impl.h"
+#include "src/cuda/group_norm/opr_impl.h"
 #include "src/cuda/images2neibs/opr_impl.h"
 #include "src/cuda/indexing_multi_axis_vec/opr_impl.h"
 #include "src/cuda/indexing_one_hot/opr_impl.h"
@@ -163,6 +164,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalForward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardData);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardFilter);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormForward);
+MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormBackward);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(Flip);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate);
 MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROICopy);
diff --git a/dnn/src/naive/group_norm/opr_impl.cpp b/dnn/src/naive/group_norm/opr_impl.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..a0b1ba02b510cedea7aba8869d5d748aafb55216
--- /dev/null
+++ b/dnn/src/naive/group_norm/opr_impl.cpp
@@ -0,0 +1,206 @@
+#include 
"src/naive/group_norm/opr_impl.h" +#include +#include "src/common/utils.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace naive; + +namespace { + +using Param = megdnn::GroupNorm::Param; + +template +void forward( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + const Param& param) { + float eps = param.eps; + bool affine = param.affine; + size_t N = data.layout.shape[0]; + size_t C = data.layout.shape[1]; + size_t HxW = data.layout.shape[2] * data.layout.shape[3]; + const int64_t G = param.group; + size_t D = C / G; + size_t inner_size = D * HxW; + + for (size_t i = 0; i < N * G; i++) { + T_ACC slice_sum = static_cast(0.0f); + for (size_t j = 0; j < inner_size; j++) { + auto value = data.ptr()[i * inner_size + j]; + slice_sum += value; + } + T_ACC slice_mean = static_cast(slice_sum / inner_size); + + T_ACC slice_var = static_cast(0.0f); + for (size_t j = 0; j < inner_size; j++) { + slice_var += (data.ptr()[i * inner_size + j] - slice_mean) * + (data.ptr()[i * inner_size + j] - slice_mean); + } + slice_var = slice_var / inner_size; + + T_ACC slice_std = static_cast(1.0f) / static_cast(sqrt(slice_var + eps)); + if (affine) { + const int64_t g = i % G; + for (size_t j = 0; j < D; j++) { + const int64_t c = g * D + j; + T_ACC s = slice_std * weight.ptr()[c]; + T_ACC b = -s * slice_mean + bias.ptr()[c]; + for (size_t k = 0; k < HxW; k++) { + dst.ptr()[(i * D + j) * HxW + k] = + s * data.ptr()[(i * D + j) * HxW + k] + b; + } + } + } else { + for (size_t j = 0; j < inner_size; j++) { + dst.ptr()[i * inner_size + j] = + (data.ptr()[i * inner_size + j] - slice_mean) / slice_std; + } + } + mean.ptr()[i] = static_cast(slice_mean); + rstd.ptr()[i] = static_cast(slice_var); + } +} + +template +void backward( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param, + WorkspaceBundle wbundle) { + bool affine = param.affine; + size_t N = data.layout.shape[0]; + size_t C = data.layout.shape[1]; + size_t G = param.group; + float eps = param.eps; + size_t HxW = data.layout.shape[2] * data.layout.shape[3]; + T* ds = wbundle.get_workspace(0).ptr(); + T* db = wbundle.get_workspace(1).ptr(); + T* slice_std = wbundle.get_workspace(2).ptr(); + for (size_t i = 0; i < N * G; i++) { + slice_std[i] = + static_cast(1.0f) / static_cast(sqrt(rstd.ptr()[i] + eps)); + } + for (size_t i = 0; i < N * C; i++) { + T ds_data = static_cast(0.0f); + T db_data = static_cast(0.0f); + for (size_t j = 0; j < HxW; j++) { + db_data += diff.ptr()[i * HxW + j]; + ds_data += data.ptr()[i * HxW + j] * diff.ptr()[i * HxW + j]; + } + ds[i] = ds_data; + db[i] = db_data; + } + size_t D = C / G; + const T s = T(1) / static_cast(D * HxW); + for (size_t i = 0; i < N * G; i++) { + const int64_t g = i % G; + T ds_v = static_cast(0.0f); + T db_v = static_cast(0.0f); + for (size_t j = 0; j < D; j += 1) { + auto weight_v = affine ? weight.ptr()[g * D + j] : static_cast(1.0f); + ds_v += ds[i * D + j] * weight_v; + db_v += db[i * D + j] * weight_v; + } + auto c2 = (db_v * mean.ptr()[i] - ds_v) * slice_std[i] * slice_std[i] * + slice_std[i] * s; + auto c3 = -c2 * mean.ptr()[i] - db_v * slice_std[i] * s; + for (size_t j = 0; j < D; j++) { + const int64_t c = g * D + j; + auto weight_v = affine ? 
weight.ptr()[c] : static_cast(1.0f); + auto c1 = slice_std[i] * weight_v; + for (size_t k = 0; k < HxW; k++) { + ddata.ptr()[(i * D + j) * HxW + k] = + c1 * diff.ptr()[(i * D + j) * HxW + k] + + c2 * data.ptr()[(i * D + j) * HxW + k] + c3; + } + } + } + if (affine) { + for (size_t i = 0; i < C; ++i) { + dweight.ptr()[i] = 0; + dbias.ptr()[i] = 0; + } + for (size_t i = 0; i < N * G; i++) { + auto g = i % G; + for (size_t j = 0; j < D; j++) { + auto c = g * D + j; + dweight.ptr()[c] += + (ds[i * D + j] - db[i * D + j] * mean.ptr()[i]) * + slice_std[i]; + } + } + for (size_t i = 0; i < N; i++) { + for (size_t j = 0; j < C; j++) { + dbias.ptr()[j] += db[i * C + j]; + } + } + } +} + +} // namespace + +namespace megdnn { +namespace naive { + +size_t GroupNormBackwardImpl::get_workspace_in_bytes( + const TensorLayout&, const TensorLayout& data, const TensorLayout&, + const TensorLayout&, const TensorLayout& rstd, const TensorLayout&, + const TensorLayout&, const TensorLayout&) { + size_t N = data.shape[0]; + size_t C = data.shape[1]; + size_t G = rstd.shape[1]; + return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes(); +} + +WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle( + size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) { + return {raw_ptr, + {N * C * dtype_size, N * C * dtype_size, N * G * dtype_size}, + handle()->alignment_requirement()}; +} + +void GroupNormForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) { + check_exec( + data.layout, weight.layout, bias.layout, dst.layout, mean.layout, + rstd.layout, workspace.size); +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR(forward::ctype>( \ + data, weight, bias, dst, mean, rstd, param())); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +void GroupNormBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) { + check_exec( + diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, + ddata.layout, dweight.layout, dbias.layout, workspace.size); +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + auto wbundle = get_workspace_bundle( \ + data.layout.shape[0], data.layout.shape[1], rstd.layout.shape[1], \ + sizeof(DTypeTrait::ctype), workspace.raw_ptr); \ + MEGDNN_DISPATCH_CPU_KERN_OPR(backward::ctype>( \ + diff, data, weight, mean, rstd, ddata, dweight, dbias, param(), \ + wbundle)); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/group_norm/opr_impl.h b/dnn/src/naive/group_norm/opr_impl.h new file mode 100644 index 0000000000000000000000000000000000000000..7691ce2cccbb33e4ede00bfc90d044b5ea90ba33 --- /dev/null +++ b/dnn/src/naive/group_norm/opr_impl.h @@ -0,0 +1,44 @@ +#pragma once +#include "megdnn/oprs.h" +#include "src/common/utils.h" + +namespace megdnn { +namespace naive { + +class GroupNormForwardImpl final : public GroupNormForward { +public: + using GroupNormForward::GroupNormForward; + void exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in 
bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +class GroupNormBackwardImpl final : public GroupNormBackward { +public: + using GroupNormBackward::GroupNormBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) override; + + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout& data, const TensorLayout&, + const TensorLayout&, const TensorLayout& rstd, const TensorLayout&, + const TensorLayout&, const TensorLayout&) override; + +private: + WorkspaceBundle get_workspace_bundle( + size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr); +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen \ No newline at end of file diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index f916b9d05ca501e29de668246b9b20738b8a99c0..6c5e50c77672661d29e5a1425656c8e3a272febe 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -34,6 +34,7 @@ #include "src/naive/flip/opr_impl.h" #include "src/naive/gaussian_blur/opr_impl.h" #include "src/naive/group_local/opr_impl.h" +#include "src/naive/group_norm/opr_impl.h" #include "src/naive/images2neibs/opr_impl.h" #include "src/naive/indexing_multi_axis_vec/opr_impl.h" #include "src/naive/indexing_one_hot/opr_impl.h" diff --git a/dnn/test/cuda/group_norm.cpp b/dnn/test/cuda/group_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..7cc3ddeb45b271bb5428ab3a480dd5c8c0c3765f --- /dev/null +++ b/dnn/test/cuda/group_norm.cpp @@ -0,0 +1,44 @@ +#include "test/cuda/fixture.h" + +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, GROUPNORM_FORWARD) { + using Param = GroupNormForward::Param; + Param param; + param.affine = true; + param.eps = 1e-6; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-2); + + auto run = [&](DType d) { + for (size_t group : {1, 3}) + for (size_t C : {6, 9}) { + param.group = group; + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, d) + .set_dtype(2, d) + .set_dtype(3, d) + .set_dtype(4, dtype::Float32()) + .set_dtype(5, dtype::Float32()) + .execs({{2, C, 2, 1}, + {C}, + {C}, + {2, C, 2, 1}, + {2, group}, + {2, group}}); + } + }; + + run(dtype::Float32()); + run(dtype::Float16()); + run(dtype::BFloat16()); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/naive/group_norm.cpp b/dnn/test/naive/group_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..6a341f3b6be3e6f1b985a190a7bd02bd5c271b0a --- /dev/null +++ b/dnn/test/naive/group_norm.cpp @@ -0,0 +1,70 @@ +#include "megdnn/dtype.h" +#include "megdnn/oprs.h" +#include "test/common/checker.h" +#include "test/naive/fixture.h" + +namespace megdnn { +namespace test { + +TEST_F(NAIVE, GROUPNORM_FORWARD) { + Checker checker(handle(), true); + + GroupNorm::Param param; + param.affine = true; + param.group = 3; + checker.set_param(param).exect( + Testcase{ + TensorValue( + {2, 3, 2, 1}, dtype::Float32(), + {3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587, + 0.0711, -0.1169, 0.2509, 
-0.2393, 0.0876}),  // input
+                    TensorValue({3}, dtype::Float32(), {1., 1., 1.}),  // weight
+                    TensorValue({3}, dtype::Float32(), {0., 0., 0.}),  // bias
+                    {},
+                    {},
+                    {}},
+            Testcase{
+                    {},
+                    {},
+                    {},
+                    TensorValue(
+                            {2, 3, 2, 1}, dtype::Float32(),
+                            {1., -1., -1., 1., -1., 1., -1., 1., -0.9999, 0.9999,
+                             -0.9998, 0.9998}),  // output
+                    TensorValue(
+                            {2, 3}, dtype::Float32(),
+                            {1.7135, -0.1645, -0.0107, -0.9938, 0.067,
+                             -0.0758}),  // mean
+                    TensorValue(
+                            {2, 3}, dtype::Float32(),
+                            {2.5742, 0.1772, 1.6358, 1.1340, 0.0338, 0.0267}),  // var
+            });
+    checker.set_param(param).exect(
+            Testcase{
+                    TensorValue(
+                            {1, 3, 1, 2}, dtype::Float32(),
+                            {-2.4348, -1.7948, 0.5223, 0.0932, -0.2955,
+                             -0.0492}),  // input
+                    TensorValue({3}, dtype::Float32(), {1., 1., 1.}),  // weight
+                    TensorValue({3}, dtype::Float32(), {0., 0., 0.}),  // bias
+                    {},
+                    {},
+                    {}},
+            Testcase{
+                    {},
+                    {},
+                    {},
+                    TensorValue(
+                            {1, 3, 1, 2}, dtype::Float32(),
+                            {-0.9999, 0.9999, 0.9999, -0.9999, -0.9997,
+                             0.9997}),  // output
+                    TensorValue(
+                            {1, 3}, dtype::Float32(),
+                            {-2.1148, 0.3077, -0.1724}),  // mean
+                    TensorValue(
+                            {1, 3}, dtype::Float32(), {0.1023, 0.0460, 0.0151}),  // var
+            });
+}
+
+} // namespace test
+} // namespace megdnn
\ No newline at end of file
diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py
index a1f11f9f6304171d85ff60a5922ef667c1d29459..f5742d8b55896a434dbf8292c39525e6c489639c 100644
--- a/imperative/python/megengine/functional/nn.py
+++ b/imperative/python/megengine/functional/nn.py
@@ -60,6 +60,7 @@ __all__ = [
     "dropout",
     "embedding",
     "gelu",
+    "group_norm",
     "hsigmoid",
     "hswish",
     "indexing_one_hot",
@@ -1202,6 +1203,33 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
     return output
+
+def group_norm(
+    inp: Tensor,
+    num_groups: int,
+    affine: bool,
+    weight: Optional[Tensor] = None,
+    bias: Optional[Tensor] = None,
+    eps: float = 1e-5,
+):
+    r"""Applies Group Normalization over a mini-batch of inputs as described in
+    the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
+
+    Args:
+        inp: input tensor.
+        num_groups: number of groups to separate the channels into
+        affine: whether to use weight and bias
+        weight: must not be None when affine is True
+        bias: must not be None when affine is True
+        eps: a value added to the denominator for numerical stability. Default: 1e-5
+    """
+    op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,)
+    if affine:
+        assert weight is not None and bias is not None
+        return apply(op, inp, weight, bias)[0]
+    else:
+        return apply(op, inp)[0]
+
+
 def layer_norm(
     inp: Tensor,
     normalized_shape: tuple,
diff --git a/imperative/python/megengine/module/normalization.py b/imperative/python/megengine/module/normalization.py
index ae6aa72be37448007d619dd0cfb4e74068499b1c..e9a9bdb9ae735fc3db8c53a978e7bdb3c462a12b 100644
--- a/imperative/python/megengine/module/normalization.py
+++ b/imperative/python/megengine/module/normalization.py
@@ -34,21 +34,9 @@ class GroupNorm(Module):
         zeros_(self.bias)
 
     def forward(self, x):
-        N, C, H, W = x.shape
-        format = x.format
-        assert C == self.num_channels
-
-        x = x.reshape(N, self.num_groups, -1)
-        mean = x.mean(axis=2, keepdims=True)
-        var = (x * x).mean(axis=2, keepdims=True) - mean * mean
-
-        x = (x - mean) / F.sqrt(var + self.eps)
-        x = x.reshape(N, C, H, W)
-        if self.affine:
-            x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
-        # FIXME(czh): remove this after making it a builtin op.
- if format == "nhwc": - x = mge.amp.convert_tensor_format(x, inplace=False) + x = F.nn.group_norm( + x, self.num_groups, self.affine, self.weight, self.bias, self.eps + ) return x def _module_info_string(self) -> str: diff --git a/imperative/python/test/unit/module/test_module.py b/imperative/python/test/unit/module/test_module.py index e7c6a553b73f6840ce4eca602b86b8b65803f8ff..6ba7171b82f5a0ebc2fb149f28a077540f1bb160 100644 --- a/imperative/python/test/unit/module/test_module.py +++ b/imperative/python/test/unit/module/test_module.py @@ -8,12 +8,14 @@ import pytest import megengine as mge import megengine.functional as F from megengine import Parameter, Tensor, tensor +from megengine.device import get_device_count from megengine.module import ( BatchNorm1d, BatchNorm2d, Conv1d, Conv2d, Dropout, + GroupNorm, Linear, MaxPool2d, Module, @@ -698,3 +700,67 @@ def test_module_compatible(): assert ( old_attributes == current_attributes ), "Add or delete attributes in Module class may break compatibility of pickle serialization" + + +def test_grou_norm(): + class OriginGroupNormFunc(Module): + def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs): + super().__init__(**kwargs) + assert num_channels % num_groups == 0 + self.num_groups = num_groups + self.num_channels = num_channels + self.eps = eps + self.affine = affine + if self.affine: + self.weight = Parameter(np.ones(num_channels, dtype=np.float32)) + self.bias = Parameter(np.zeros(num_channels, dtype=np.float32)) + else: + self.weight = None + self.bias = None + + def forward(self, x): + N, C, H, W = x.shape + x = x.reshape(N, self.num_groups, -1) + mean = x.mean(axis=2, keepdims=True) + var = (x * x).mean(axis=2, keepdims=True) - mean * mean + x = (x - mean) / F.sqrt(var + self.eps) + x = x.reshape(N, C, H, W) + if self.affine: + x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape( + 1, -1, 1, 1 + ) + return x + + inp = np.random.randn(2, 256, 10, 16).astype("float32") + mge_inp = Tensor(inp) + mge_m = GroupNorm(32, 256) + + ori_inp = Tensor(inp) + ori_m = OriginGroupNormFunc(32, 256) + + targets = np.array(2) + mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters()) + ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters()) + + for i in range(2): + with mge_gm: + mge_output = mge_m(mge_inp) + loss = F.loss.square_loss( + mge_output.sum(), mge.tensor(targets, dtype=np.float32) + ) + mge_gm.backward(loss) + + with ori_gm: + ori_output = ori_m(ori_inp) + loss = F.loss.square_loss( + ori_output.sum(), mge.tensor(targets, dtype=np.float32) + ) + ori_gm.backward(loss) + + np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05) + np.testing.assert_allclose( + mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03 + ) + np.testing.assert_allclose( + mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03 + ) diff --git a/imperative/src/impl/ops/group_norm.cpp b/imperative/src/impl/ops/group_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..299f23ac1f22c566277fa83833be55470f409694 --- /dev/null +++ b/imperative/src/impl/ops/group_norm.cpp @@ -0,0 +1,97 @@ +#include "megbrain/opr/dnn/group_norm.h" +#include "megbrain/imperative/ops/autogen.h" +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" + +#include "../blob_manager_impl.h" +#include "../dnn_op_helper.h" +#include "../op_trait.h" + +namespace mgb::imperative { +namespace group_norm { + +cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) { + 
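+    // annotation (descriptive note, not part of the original patch):
+    // apply_on_var_node lowers the imperative GroupNorm OpDef onto the computing
+    // graph; with affine=true the weight and bias VarNodes are forwarded as extra
+    // inputs, otherwise only the data input is used.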
auto&& op = static_cast<const GroupNorm&>(def);
+    size_t nr_inp = inputs.size();
+    auto p = op.param();
+    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
+    OperatorNodeConfig config{op.make_name()};
+    if (nr_inp == 3) {
+        return opr::GroupNorm::make(
+                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    } else {
+        return opr::GroupNorm::make(inputs[0], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    }
+}
+
+std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
+        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
+    auto&& group_norm = def.cast_final_safe<GroupNorm>();
+    size_t nr_inp = inputs.size();
+    auto affine = group_norm.affine;
+    mgb_assert(
+            (nr_inp == 3 && affine) || (nr_inp == 1 && !affine),
+            "num of inputs of groupnorm should be 1 or 3 but you give %zu",
+            inputs.size());
+
+    auto&& inp = inputs[0];
+    auto& inp_cn = inp.comp_node;
+
+    if (inp.layout.ndim == 0) {
+        return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}},
+                 {TensorLayout{dtype::Float32()}, inp_cn, {}}},
+                false};
+    }
+
+    DnnOprHelper<megdnn::GroupNormForward> dnn_opr(group_norm.param());
+    auto&& [oup_layout, mean_layout, rstd_layout] =
+            dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{});
+    return {{{oup_layout, inp_cn, {}},
+             {mean_layout, inp_cn, {}},
+             {rstd_layout, inp_cn, {}}},
+            true};
+}
+
+SmallVector<TensorPtr> apply_on_physical_tensor(
+        const OpDef& def, const SmallVector<TensorPtr>& inputs,
+        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
+    auto&& op_def = def.cast_final_safe<GroupNorm>();
+    size_t nr_inp = inputs.size();
+    auto p = op_def.param();
+
+    mgb_assert(
+            (nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
+            "num of inputs of groupnorm should be 1 or 3 but you give %zu",
+            inputs.size());
+
+    auto cn = inputs[0]->comp_node();
+    DnnOprCaller<megdnn::GroupNormForward> caller(cn, op_def.param());
+
+    auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>(
+            inputs[0]->layout(), TensorLayout{}, TensorLayout{});
+
+    auto out = Tensor::make(oup_layout, cn);
+    auto mean = Tensor::make(mean_layout, cn);
+    auto rstd = Tensor::make(rstd_layout, cn);
+
+    if (p.affine) {
+        caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
+    } else {
+        megdnn::TensorND empty_dnn;
+        caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd);
+    }
+    return {out, mean, rstd};
+}
+
+OP_TRAIT_REG(GroupNorm, GroupNorm)
+        .apply_on_var_node(apply_on_var_node)
+        .infer_output_attrs_fallible(infer_output_attrs_fallible)
+        .apply_on_physical_tensor(apply_on_physical_tensor)
+        .fallback();
+
+} // namespace group_norm
+} // namespace mgb::imperative
diff --git a/imperative/tablegen/generated/hash.txt b/imperative/tablegen/generated/hash.txt
index a5c170f904aba8141cb3d27936dde30b3ee6cd29..ed0040f70084d78c72c17e1363d1f40252ceae09 100644
--- a/imperative/tablegen/generated/hash.txt
+++ b/imperative/tablegen/generated/hash.txt
@@ -1,7 +1,7 @@
-905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
-da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td
-5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl
-98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl
-b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl
-3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl
+e38b68be4e2aaf3de2f22e3dddbeaac4 ../../dnn/scripts/opr_param_defs.py
+cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td
+9248d42a9b3e770693306992156f6015 generated/opdef.h.inl
+5c7e7ac49d1338d70ac84ba309e6732b generated/opdef.cpp.inl 
+30b669eec36876a65717e0c68dd76c83 generated/opdef.py.inl +d10455217f5f01e3d2668e5689068920 generated/opdef.cpy.inl 71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h diff --git a/imperative/tablegen/generated/opdef.cpp.inl b/imperative/tablegen/generated/opdef.cpp.inl index 5fdd308af2121abafcd430e333c6e7f8a5fb8e6b..b9c5878d94de7e36ba0570fac624e18b1bab10ed 100644 --- a/imperative/tablegen/generated/opdef.cpp.inl +++ b/imperative/tablegen/generated/opdef.cpp.inl @@ -3775,6 +3775,110 @@ OP_TRAIT_REG(GroupLocal, GroupLocal) .props(GroupLocal_props_impl) .make_name(GroupLocal_make_name_impl); +MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNorm); + +namespace { +size_t GroupNorm_hash_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + size_t val = mgb::hash(op_.dyn_typeinfo()); + val = mgb::hash_pair_combine(val, mgb::hash(op_.affine)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.eps)); + val = mgb::hash_pair_combine(val, mgb::hash(op_.group)); + val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format)); + return val; +} +bool GroupNorm_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) { + auto &&a_ = lhs_.cast_final_safe(), + &&b_ = rhs_.cast_final_safe(); + static_cast(a_); + static_cast(b_); + if (a_.affine != b_.affine) return false; + if (a_.eps != b_.eps) return false; + if (a_.group != b_.group) return false; + if (a_.format != b_.format) return false; + return true; +} +std::vector> GroupNorm_props_impl(const OpDef& def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + std::vector> props_; + props_.emplace_back("affine", std::to_string(op_.affine)); + props_.emplace_back("eps", std::to_string(op_.eps)); + props_.emplace_back("group", std::to_string(op_.group)); + switch (op_.format){ + case GroupNorm::Format::NCHW: + props_.emplace_back("format", "NCHW"); + break; + case GroupNorm::Format::NHWC: + props_.emplace_back("format", "NHWC"); + break; + case GroupNorm::Format::NHWCD4: + props_.emplace_back("format", "NHWCD4"); + break; + case GroupNorm::Format::NCHW4: + props_.emplace_back("format", "NCHW4"); + break; + case GroupNorm::Format::NCHW8: + props_.emplace_back("format", "NCHW8"); + break; + case GroupNorm::Format::NCHW32: + props_.emplace_back("format", "NCHW32"); + break; + case GroupNorm::Format::NCHW88: + props_.emplace_back("format", "NCHW88"); + break; + case GroupNorm::Format::NCHW44: + props_.emplace_back("format", "NCHW44"); + break; + case GroupNorm::Format::NCHW44_DOT: + props_.emplace_back("format", "NCHW44_DOT"); + break; + case GroupNorm::Format::NCHW4_NCHW32: + props_.emplace_back("format", "NCHW4_NCHW32"); + break; + case GroupNorm::Format::NCHW32_NCHW4: + props_.emplace_back("format", "NCHW32_NCHW4"); + break; + case GroupNorm::Format::NCHW4_NCHW: + props_.emplace_back("format", "NCHW4_NCHW"); + break; + case GroupNorm::Format::NHWC_NCHW: + props_.emplace_back("format", "NHWC_NCHW"); + break; + case GroupNorm::Format::NHWC_NCHW4_IC_SMALL: + props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL"); + break; + case GroupNorm::Format::NCHW_NCHW4_IC_SMALL: + props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL"); + break; + case GroupNorm::Format::CHWN4: + props_.emplace_back("format", "CHWN4"); + break; + case GroupNorm::Format::NCHW64: + props_.emplace_back("format", "NCHW64"); + break; + case GroupNorm::Format::NCHW4_NHWC: + props_.emplace_back("format", "NCHW4_NHWC"); + break; + default: + props_.emplace_back("format", "INVALID"); + break; + } + return props_; +} +std::string GroupNorm_make_name_impl(const OpDef& 
def_) { + auto&& op_ = def_.cast_final_safe(); + static_cast(op_); + return "GroupNorm"; +} +} // anonymous namespace +OP_TRAIT_REG(GroupNorm, GroupNorm) + .hash(GroupNorm_hash_impl) + .is_same_st(GroupNorm_is_same_st_impl) + .props(GroupNorm_props_impl) + .make_name(GroupNorm_make_name_impl); + MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity); namespace { diff --git a/imperative/tablegen/generated/opdef.cpy.inl b/imperative/tablegen/generated/opdef.cpy.inl index 35f24bada091ff0b7ee7e3cf67c9b3e29b1440bc..8d0511a1dd47b6e1e23b31254aa088b2e4f7b5c0 100644 --- a/imperative/tablegen/generated/opdef.cpy.inl +++ b/imperative/tablegen/generated/opdef.cpy.inl @@ -10075,6 +10075,158 @@ void _init_py_GroupLocal(py::module m) { mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupLocal::typeinfo(), &py_type).second); } +void _init_py_GroupNorm_Format(PyTypeObject& py_type) { + auto& e_type = EnumWrapper::type; + + Py_INCREF(e_type); + mgb_assert(PyDict_SetItemString( + py_type.tp_dict, "Format", reinterpret_cast(e_type)) >= 0); +} + +PyOpDefBegin(GroupNorm) // { + static PyGetSetDef py_getsetters[]; + static PyMethodDef tp_methods[]; + + static PyObject* getstate(PyObject* self, PyObject*) { + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + std::unordered_map state { + + {"affine", serialization::dump(opdef.affine)}, + {"eps", serialization::dump(opdef.eps)}, + {"group", serialization::dump(opdef.group)}, + {"format", serialization::dump(opdef.format)} + }; + return py::cast(state).release().ptr(); + } + static PyObject* setstate(PyObject* self, PyObject* args) { + PyObject* dict = PyTuple_GetItem(args, 0); + if (!dict) return NULL; + auto state = py::cast>(dict); + auto& opdef = reinterpret_cast(self)->inst(); + static_cast(opdef); + + { + auto&& iter = state.find("affine"); + if (iter != state.end()) { + opdef.affine = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("eps"); + if (iter != state.end()) { + opdef.eps = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("group"); + if (iter != state.end()) { + opdef.group = serialization::load(iter->second); + } + } + + { + auto&& iter = state.find("format"); + if (iter != state.end()) { + opdef.format = serialization::load(iter->second); + } + } + Py_RETURN_NONE; + } + static int py_init(PyObject *self, PyObject *args, PyObject *kwds); +// }; +PyOpDefEnd(GroupNorm) + +int PyOp(GroupNorm)::py_init(PyObject *self, PyObject *args, PyObject *kwds) { + static const char* kwlist[] = {"affine", "eps", "group", "format", "scope", NULL}; + PyObject *affine = NULL, *eps = NULL, *group = NULL, *format = NULL, *scope = NULL; + if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast(kwlist), &affine, &eps, &group, &format, &scope)) + return -1; + + if (affine) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().affine = + py::cast(py::handle(affine)); + } CATCH_ALL(-1) + } + + if (eps) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().eps = + py::cast(py::handle(eps)); + } CATCH_ALL(-1) + } + + if (group) { + try { + // TODO: remove this guard which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().group = + py::cast(py::handle(group)); + } CATCH_ALL(-1) + } + + if (format) { + try { + // TODO: remove this guard 
which is used for pybind11 implicit conversion + py::detail::loader_life_support guard{}; + reinterpret_cast(self)->inst().format = + py::cast(py::handle(format)); + } CATCH_ALL(-1) + } + + if (scope) { + try { + reinterpret_cast(self)->op + ->set_scope(py::cast(py::handle(scope))); + } CATCH_ALL(-1) + } + + return 0; +} + +PyGetSetDef PyOp(GroupNorm)::py_getsetters[] = { + {const_cast("affine"), py_get_generic(GroupNorm, affine), py_set_generic(GroupNorm, affine), const_cast("affine"), NULL}, + {const_cast("eps"), py_get_generic(GroupNorm, eps), py_set_generic(GroupNorm, eps), const_cast("eps"), NULL}, + {const_cast("group"), py_get_generic(GroupNorm, group), py_set_generic(GroupNorm, group), const_cast("group"), NULL}, + {const_cast("format"), py_get_generic(GroupNorm, format), py_set_generic(GroupNorm, format), const_cast("format"), NULL}, + {NULL} /* Sentinel */ +}; + + PyMethodDef PyOp(GroupNorm)::tp_methods[] = { + {const_cast("__getstate__"), PyOp(GroupNorm)::getstate, METH_NOARGS, "GroupNorm getstate"}, + {const_cast("__setstate__"), PyOp(GroupNorm)::setstate, METH_VARARGS, "GroupNorm setstate"}, + {NULL} /* Sentinel */ + }; + +void _init_py_GroupNorm(py::module m) { + using py_op = PyOp(GroupNorm); + auto& py_type = PyOpType(GroupNorm); + py_type = {PyVarObject_HEAD_INIT(NULL, 0)}; + py_type.tp_name = "megengine.core._imperative_rt.ops.GroupNorm"; + py_type.tp_basicsize = sizeof(PyOp(GroupNorm)); + py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE; + py_type.tp_doc = "GroupNorm"; + py_type.tp_base = &PyOpType(OpDef); + py_type.tp_dealloc = py_dealloc_generic; + py_type.tp_new = py_new_generic; + py_type.tp_init = py_op::py_init; + py_type.tp_methods = py_op::tp_methods; + py_type.tp_getset = py_op::py_getsetters; + mgb_assert(PyType_Ready(&py_type) >= 0); + _init_py_GroupNorm_Format(py_type); + + PyType_Modified(&py_type); + m.add_object("GroupNorm", reinterpret_cast(&py_type)); + mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupNorm::typeinfo(), &py_type).second); +} + PyOpDefBegin(Identity) // { static PyGetSetDef py_getsetters[]; static PyMethodDef tp_methods[]; @@ -19237,6 +19389,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) { _init_py_GaussianRNG(m); \ _init_py_GetVarShape(m); \ _init_py_GroupLocal(m); \ + _init_py_GroupNorm(m); \ _init_py_Identity(m); \ _init_py_Images2Neibs(m); \ _init_py_IncrMeshIndexing(m); \ diff --git a/imperative/tablegen/generated/opdef.h.inl b/imperative/tablegen/generated/opdef.h.inl index 32493046f9e4d19e6ce37171ea6eefd8ab076dd3..b674b127126f4d4d8343c57ea2f7a8275794aa3d 100644 --- a/imperative/tablegen/generated/opdef.h.inl +++ b/imperative/tablegen/generated/opdef.h.inl @@ -988,6 +988,23 @@ public: } }; +class GroupNorm : public OpDefImplBase { + MGB_DYN_TYPE_OBJ_FINAL_DECL; + +public: + using Format = ::megdnn::param::GroupNorm::Format; + bool affine = true; + float eps = 1e-5f; + uint32_t group = 1; + Format format = ::megdnn::param::GroupNorm::Format::NCHW; + GroupNorm() = default; + GroupNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), format(format_) { set_scope(scope_); } + GroupNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {} + ::megdnn::param::GroupNorm param() const { + return {affine, eps, group, format}; + } +}; + class Identity : public OpDefImplBase { MGB_DYN_TYPE_OBJ_FINAL_DECL; diff --git 
a/imperative/tablegen/generated/opdef.py.inl b/imperative/tablegen/generated/opdef.py.inl index f93be1bd4c1c20d01031ec1d3a971c46abc9839d..1788ff1576bcc5fe2568e2f67bc9b38126135ec7 100644 --- a/imperative/tablegen/generated/opdef.py.inl +++ b/imperative/tablegen/generated/opdef.py.inl @@ -1193,6 +1193,17 @@ GroupLocalInst .def_readwrite("format", &GroupLocal::format) .def_readwrite("compute_mode", &GroupLocal::compute_mode); +py::class_, OpDef> GroupNormInst(m, "GroupNorm"); + +GroupNormInst.attr("Format") = AdaptivePoolingInst.attr("Format"); + +GroupNormInst + .def(py::init(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("group") = 1, py::arg("format") = ::megdnn::param::GroupNorm::Format::NCHW, py::arg("scope") = {}) + .def_readwrite("affine", &GroupNorm::affine) + .def_readwrite("eps", &GroupNorm::eps) + .def_readwrite("group", &GroupNorm::group) + .def_readwrite("format", &GroupNorm::format); + py::class_, OpDef> IdentityInst(m, "Identity"); IdentityInst diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td index ecf403e0681245b0cd5ef7c9441873e47ebcd50b..221ee8769ecc13abee38ed9a661214bb587726c9 100644 --- a/src/core/include/megbrain/ir/ops.td +++ b/src/core/include/megbrain/ir/ops.td @@ -490,6 +490,8 @@ def LRN: MgbHashableOp<"LRN", [LRNParam]>; def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>; +def GroupNorm: MgbHashableOp<"GroupNorm", [GroupNormParam]>; + def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>; def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>; diff --git a/src/opr/impl/dnn/dnn.sereg.h b/src/opr/impl/dnn/dnn.sereg.h index 54a05d9ea4cdbf73c19be62931ca7a27add28a5f..9fddafa3b80de1fae1d5da6011b657e3c3f434a2 100644 --- a/src/opr/impl/dnn/dnn.sereg.h +++ b/src/opr/impl/dnn/dnn.sereg.h @@ -1,8 +1,10 @@ +#include "megbrain/opr/basic_arith.h" #include "megbrain/opr/dnn/adaptive_pooling.h" #include "megbrain/opr/dnn/batch_norm.h" #include "megbrain/opr/dnn/convolution.h" #include "megbrain/opr/dnn/correlation.h" #include "megbrain/opr/dnn/fake_quant.h" +#include "megbrain/opr/dnn/group_norm.h" #include "megbrain/opr/dnn/images2neibs.h" #include "megbrain/opr/dnn/layer_norm.h" #include "megbrain/opr/dnn/local.h" @@ -15,6 +17,9 @@ #include "megbrain/opr/dnn/sliding_window_transpose.h" #include "megbrain/opr/dnn/softmax.h" #include "megbrain/opr/dnn/tqt.h" +#include "megbrain/opr/io.h" +#include "megbrain/opr/tensor_manip.h" +#include "megbrain/serialization/oss_opr_load_dump.h" #include "megbrain/serialization/sereg.h" #include "megdnn/opr_param_defs.h" #include "megdnn/oprs/nn.h" @@ -524,6 +529,213 @@ struct OprMaker { } }; +template <> +struct OprMaker { + using Param = opr::GroupNorm::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + if (i.size() == 3) { + return opr::GroupNorm::make(i[0], i[1], i[2], param, config)[0] + .node() + ->owner_opr(); + } else { + mgb_assert(i.size() == 1); + return opr::GroupNorm::make(i[0], param, config)[0].node()->owner_opr(); + } + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::GroupNorm; + using Param = opr::GroupNorm::Param; + using ElemwiseParam = opr::Elemwise::Param; + using ReduceParam = opr::Reduce::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const 
VarNodeArray& inputs) { + auto graph = inputs[0]->owner_graph(); + auto comp_node = inputs[0]->comp_node(); + // std::unique_ptr m_static_infer_manager; + auto opr_param = opr->cast_final_safe().param(); + float eps = opr_param.eps; + auto half = DTypeScalar(static_cast(0.5)); + auto param_eps = DTypeScalar(static_cast(eps)); + auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); + auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); + + auto origin_shape = opr::GetVarShape::make(inputs[0]).node(); + + TensorShape input_shape = + inputs[0]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); + size_t N = input_shape[0]; + size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; + int group = opr_param.group; + int size = inner_size / group; + HostTensorND hv = HostTensorND(inputs[0]->comp_node(), {3}, dtype::Int32()); + auto* ptr = hv.ptr(); + ptr[0] = N; + ptr[1] = group; + ptr[2] = size; + auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); + auto inp = opr::Reshape::make(inputs[0], target_shape); + + auto mean = opr::Reduce::make(inp, {ReduceParam::Mode::MEAN, 2}); + auto elemwise1 = opr::Elemwise::make({inp, inp}, {ElemwiseParam::Mode::MUL}); + auto temp_var = opr::Reduce::make(elemwise1, {ReduceParam::Mode::MEAN, 2}); + auto elemwise2 = opr::Elemwise::make({mean, mean}, {ElemwiseParam::Mode::MUL}); + auto var = + opr::Elemwise::make({temp_var, elemwise2}, {ElemwiseParam::Mode::SUB}); + auto add_var = opr::Elemwise::make({var, eps_node}, {ElemwiseParam::Mode::ADD}); + auto sqrt = + opr::Elemwise::make({add_var, half_node}, {ElemwiseParam::Mode::POW}); + auto div = opr::Elemwise::make({inp, mean}, {ElemwiseParam::Mode::SUB}); + auto temp_inp = + opr::Elemwise::make({div, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); + auto res = opr::Reshape::make(temp_inp, origin_shape); + + if (inputs.size() == 3) { + auto mul_temp = + opr::Elemwise::make({res, inputs[1]}, {ElemwiseParam::Mode::MUL}); + auto res = opr::Elemwise::make( + {mul_temp, inputs[2]}, {ElemwiseParam::Mode::ADD}); + return res.node()->owner_opr(); + } else { + return res.node()->owner_opr(); + } + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + // auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx); + return OprMaker::make( + ctx.read_param(), inputs, ctx.graph(), config); + } +}; + +// OprMaker in MGB_SEREG_OPR only support unique output opr +template <> +struct OprMaker { + using Param = opr::GroupNormBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + if (i.size() == 5) { + return opr::GroupNormBackward::make( + i[0], i[1], i[2], i[3], i[4], param, config)[0] + .node() + ->owner_opr(); + } else { + mgb_assert(i.size() == 4); + return opr::GroupNormBackward::make( + i[0], i[1], i[2], i[3], param, config)[0] + .node() + ->owner_opr(); + } + } +}; + +template <> +struct OprLoadDumpImplV2 { + using Opr = opr::GroupNormBackward; + using Param = opr::GroupNormBackward::Param; + using ElemwiseParam = opr::Elemwise::Param; + using ReduceParam = opr::Reduce::Param; + static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) { + ctx.write_param(opr.cast_final_safe().param()); + } + + static cg::OperatorNodeBase* replace_opr( + cg::OperatorNodeBase* opr, const VarNodeArray& inputs) { + auto rstd = inputs[4]; + auto graph = 
inputs[1]->owner_graph(); + auto comp_node = inputs[1]->comp_node(); + auto opr_param = opr->cast_final_safe().param(); + float eps = opr_param.eps; + auto half = DTypeScalar(static_cast(0.5)); + auto param_eps = DTypeScalar(static_cast(eps)); + auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node}); + auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node}); + auto const_node = + opr::ImmutableTensor::make(*graph, DTypeScalar(1), {comp_node}); + + TensorShape input_shape = + inputs[1]->owner_graph()->static_infer_manager().infer_shape(inputs[0]); + auto origin_shape = opr::GetVarShape::make(inputs[1]).node(); + size_t N = input_shape[0]; + size_t C = input_shape[1]; + size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3]; + int group = opr_param.group; + int size = inner_size / group; + HostTensorND hv = HostTensorND(inputs[1]->comp_node(), {3}, dtype::Int32()); + auto* ptr = hv.ptr(); + ptr[0] = N; + ptr[1] = group; + ptr[2] = size; + auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node}); + auto inp = opr::Reshape::make(inputs[1], target_shape); + + auto temp_rstd = + opr::Elemwise::make({rstd, eps_node}, {ElemwiseParam::Mode::ADD}); + auto sqrt = + opr::Elemwise::make({temp_rstd, half_node}, {ElemwiseParam::Mode::POW}); + auto slice_std = opr::Elemwise::make( + {const_node, sqrt}, {ElemwiseParam::Mode::TRUE_DIV}); + auto sub_mean = + opr::Elemwise::make({inp, inputs[3]}, {ElemwiseParam::Mode::SUB}); + auto x_hat = + opr::Elemwise::make({sub_mean, slice_std}, {ElemwiseParam::Mode::MUL}); + x_hat = opr::Reshape::make(x_hat, origin_shape); + auto size_node = + opr::ImmutableTensor::make(*graph, DTypeScalar(size), {comp_node}); + auto temp1 = opr::Elemwise::make( + {slice_std, size_node}, {ElemwiseParam::Mode::TRUE_DIV}); + + auto dx_hat = + opr::Elemwise::make({inputs[0], inputs[2]}, {ElemwiseParam::Mode::MUL}); + HostTensorND tshape = HostTensorND(inputs[1]->comp_node(), {5}, dtype::Int32()); + auto* ptr2 = tshape.ptr(); + ptr2[0] = N; + ptr2[1] = group; + ptr2[2] = C / group; + ptr2[3] = input_shape[2]; + ptr2[4] = input_shape[3]; + target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); + x_hat = opr::Reshape::make(x_hat, target_shape); + dx_hat = opr::Reshape::make(dx_hat, target_shape); + auto temp2 = + opr::Elemwise::make({size_node, dx_hat}, {ElemwiseParam::Mode::MUL}); + ptr2[2] = 1; + ptr2[3] = 1; + ptr2[4] = 1; + target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node}); + auto temp3 = opr::Reduce::make(dx_hat, {ReduceParam::Mode::SUM}, target_shape); + auto sum_dx_hat = + opr::Reduce::make(temp2, {ReduceParam::Mode::SUM}, target_shape); + auto temp4 = + opr::Elemwise::make({x_hat, sum_dx_hat}, {ElemwiseParam::Mode::MUL}); + auto temp5 = opr::Elemwise::make({temp2, temp3}, {ElemwiseParam::Mode::SUB}); + auto temp6 = opr::Elemwise::make({temp5, temp4}, {ElemwiseParam::Mode::SUB}); + auto dx_temp = opr::Elemwise::make({temp1, temp6}, {ElemwiseParam::Mode::MUL}); + auto dx = opr::Reshape::make(dx_temp, origin_shape); + return dx.node()->owner_opr(); + } + + static cg::OperatorNodeBase* load( + OprLoadContext& ctx, const cg::VarNodeArray& inputs, + const OperatorNodeConfig& config) { + return OprMaker::make( + ctx.read_param(), inputs, ctx.graph(), config); + } +}; + template struct MakeLocalShareCaller2 { template @@ -747,6 +959,8 @@ MGB_SEREG_OPR(LSQ, 4); MGB_SEREG_OPR(LSQBackward, 5); MGB_SEREG_OPR(LayerNorm, 0); MGB_SEREG_OPR(LayerNormBackward, 0); +MGB_SEREG_OPR(GroupNorm, 0); 
+MGB_SEREG_OPR(GroupNormBackward, 0); MGB_SEREG_OPR(RNNCellForward, 6); MGB_SEREG_OPR(LSTMCellForward, 7); MGB_SEREG_OPR(RNNForward, 3); @@ -755,6 +969,14 @@ MGB_SEREG_OPR(LSTMForward, 4); MGB_SEREG_OPR(LSTMBackward, 9); MGB_SEREG_OPR(Softmax, 1); MGB_SEREG_OPR(SoftmaxBackward, 2); +MGB_SEREG_OPR_V2( + GroupNorm, 0, + (mgb::serialization::OprLoadDumpImplV2::replace_opr), + VERSION_2, CURRENT_VERSION); +MGB_SEREG_OPR_V2( + GroupNormBackward, 0, + (mgb::serialization::OprLoadDumpImplV2::replace_opr), + VERSION_2, CURRENT_VERSION); } // namespace opr } // namespace mgb diff --git a/src/opr/impl/dnn/group_norm.cpp b/src/opr/impl/dnn/group_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..b70544136a9455f86d21a32de424da3fe0079daa --- /dev/null +++ b/src/opr/impl/dnn/group_norm.cpp @@ -0,0 +1,239 @@ +#include "megbrain/opr/dnn/group_norm.h" + +#include "megbrain/graph/grad_impl.h" +#include "megbrain/opr/internal/out_shape_by_sym_var.h" +#include "megbrain/opr/utility.h" + +#include "../internal/megdnn_opr_wrapper.inl" + +using namespace mgb; +using namespace opr; + +/* ==================== GroupNormForward ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormForward); + +GroupNormForward::GroupNormForward( + VarNode* data, VarNode* weight, VarNode* bias, const Param& param, + const OperatorNodeConfig& config) + : Super{data->owner_graph(), config, "group_norm", {data, weight, bias}} { + init_megdnn_opr(*this, param); + + add_input({data, weight, bias}); + output(0)->dtype(data->dtype()); + output(1)->dtype(dtype::Float32()); + output(2)->dtype(dtype::Float32()); +} + +GroupNormForward::GroupNormForward( + VarNode* data, const Param& param, const OperatorNodeConfig& config) + : Super{data->owner_graph(), config, "group_norm", {data}} { + init_megdnn_opr(*this, param); + + add_input({data}); + output(0)->dtype(data->dtype()); + output(1)->dtype(dtype::Float32()); + output(2)->dtype(dtype::Float32()); +} + +SymbolVarArray GroupNormForward::make( + SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param, + const OperatorNodeConfig& config) { + auto outs = data.node() + ->owner_graph() + ->insert_opr(std::make_unique( + data.node(), weight.node(), bias.node(), param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& out : outs) { + ret.emplace_back(out); + } + return ret; +} + +SymbolVarArray GroupNormForward::make( + SymbolVar data, const Param& param, const OperatorNodeConfig& config) { + auto outs = data.node() + ->owner_graph() + ->insert_opr(std::make_unique( + data.node(), param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& out : outs) { + ret.emplace_back(out); + } + return ret; +} + +void GroupNormForward::get_output_var_shape( + const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const { + size_t group = param().group; + out_shape[0] = inp_shape[0]; + size_t N = inp_shape[0].shape[0]; + TensorShape unnormalized_shape{N, group}; + out_shape[1] = unnormalized_shape; + out_shape[2] = unnormalized_shape; +} + +size_t GroupNormForward::get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const { + return intl::MegDNNOprMethInvoker::get_workspace_in_bytes( + megdnn_opr(), this, input_shapes, output_shapes); +} + +void GroupNormForward::scn_do_execute() { + if (param().affine) { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + 
output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); + } else { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), {}, {}, + output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output().back())); + } +} + +#if MGB_ENABLE_GRAD +MGB_IMPL_OPR_GRAD(GroupNormForward) { + auto p = opr.param(); + SymbolVarArray grad; + VarNodeArray ret; + if (p.affine) { + mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx); + grad = GroupNormBackward::make( + out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2), + opr.param()); + } else { + mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx); + grad = GroupNormBackward::make( + out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param()); + } + + uint32_t nr_ret = p.affine ? 3 : 1; + for (uint32_t i = 0; i < nr_ret; ++i) { + ret.push_back(grad[i].node()); + } + return ret; +} +#endif + +/* ==================== GroupNormBackward ==================== */ +MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormBackward); + +GroupNormBackward::GroupNormBackward( + VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, + const Param& param, const OperatorNodeConfig& config) + : Super({diff->owner_graph(), + config, + "group_norm_backward", + {diff, data, weight, mean, rstd}}, + 0, true) { + init_megdnn_opr(*this, param); + add_input({diff, data, weight, mean, rstd}); +} + +GroupNormBackward::GroupNormBackward( + VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param, + const OperatorNodeConfig& config) + : Super({diff->owner_graph(), + config, + "group_norm_backward", + {diff, data, mean, rstd}}, + 0, true) { + init_megdnn_opr(*this, param); + add_input({diff, data, mean, rstd}); + auto mark_empty_var = [&](VarNode* var) { + var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE) + .add_flag(VarNode::Flag::VOLATILE_CONTENT); + }; + mark_empty_var(output(1)); + mark_empty_var(output(2)); +} + +SymbolVarArray GroupNormBackward::make( + SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, + SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) { + auto outs = diff.node() + ->owner_graph() + ->insert_opr(std::make_unique( + diff.node(), data.node(), weight.node(), mean.node(), + rstd.node(), param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& out : outs) { + ret.emplace_back(out); + } + return ret; +} + +SymbolVarArray GroupNormBackward::make( + SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd, + const Param& param, const OperatorNodeConfig& config) { + auto outs = diff.node() + ->owner_graph() + ->insert_opr(std::make_unique( + diff.node(), data.node(), mean.node(), rstd.node(), + param, config)) + ->output(); + SymbolVarArray ret; + for (auto&& out : outs) { + ret.emplace_back(out); + } + return ret; +} + +void GroupNormBackward::init_output_static_infer_desc() { + using namespace cg::static_infer; + auto&& mgr = owner_graph()->static_infer_manager(); + mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1))); + if (param().affine) { + mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2))); + mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2))); + } else { + TensorShape empty; + empty.ndim = 0; + mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty)); + 
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty)); + } + this->init_output_static_infer_desc_workspace( + intl::AutoAddWorkspaceNeedLimitGetter::val); +} + +void GroupNormBackward::init_output_dtype() { + output(0)->dtype(input(1)->dtype()); + output(1)->dtype(input(2)->dtype()); + output(2)->dtype(input(2)->dtype()); +} + +size_t GroupNormBackward::get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const { + return intl::MegDNNOprMethInvoker:: + get_workspace_in_bytes(megdnn_opr(), this, input_shapes, output_shapes); +} + +void GroupNormBackward::scn_do_execute() { + if (param().affine) { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(), + input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + output(1)->dev_tensor().as_megdnn(), + output(2)->dev_tensor().as_megdnn(), + intl::get_megdnn_workspace_from_var(output(3))); + } else { + megdnn_opr()->exec( + input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(), + {}, input(2)->dev_tensor().as_megdnn(), + input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(), + {}, {}, intl::get_megdnn_workspace_from_var(output(3))); + } +} + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/include/megbrain/opr/dnn/group_norm.h b/src/opr/include/megbrain/opr/dnn/group_norm.h new file mode 100644 index 0000000000000000000000000000000000000000..32a7da09adb8f83b9bfdb1340a7c6795321f7b88 --- /dev/null +++ b/src/opr/include/megbrain/opr/dnn/group_norm.h @@ -0,0 +1,67 @@ +#pragma once + +#include "megbrain/opr/internal/megdnn_opr_wrapper.h" +#include "megdnn/oprs.h" + +namespace mgb { +namespace opr { + +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( + GroupNormForward, intl::MegDNNOprWrapperFwd) // { +public: + MGE_WIN_DECLSPEC_FUC GroupNormForward( + VarNode* data, VarNode* weight, VarNode* bias, const Param& param, + const OperatorNodeConfig& config); + MGE_WIN_DECLSPEC_FUC GroupNormForward( + VarNode* data, const Param& param, const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( + SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {}, + const OperatorNodeConfig& config = {}); + MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( + SymbolVar data, const Param& param = {}, + const OperatorNodeConfig& config = {}); + +private: + void get_output_var_shape( + const TensorShapeArray& inp_shape, + TensorShapeArray& out_shape) const override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; + void scn_do_execute() override; +}; +using GroupNorm = GroupNormForward; + +MGB_DEFINE_OPR_CLASS_WITH_EXPORT( + GroupNormBackward, intl::MegDNNOprWrapperBwd) // { +public: + MGE_WIN_DECLSPEC_FUC GroupNormBackward( + VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd, + const Param& param, const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC GroupNormBackward( + VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, + const Param& param, const OperatorNodeConfig& config); + + MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( + SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean, + SymbolVar rstd, const Param& param = {}, + const OperatorNodeConfig& config = {}); + MGE_WIN_DECLSPEC_FUC static SymbolVarArray make( + SymbolVar diff, 
SymbolVar data, SymbolVar mean, SymbolVar rstd, + const Param& param = {}, const OperatorNodeConfig& config = {}); + +private: + void init_output_static_infer_desc() override; + void init_output_dtype() override; + size_t get_workspace_size_bytes( + const TensorShapeArray& input_shapes, + const TensorShapeArray& output_shapes) const override; + void scn_do_execute() override; +}; + +} // namespace opr +} // namespace mgb + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/opr/test/dnn/group_norm.cpp b/src/opr/test/dnn/group_norm.cpp new file mode 100644 index 0000000000000000000000000000000000000000..be832a7cf790cae644b30039dbd19b4a465351e2 --- /dev/null +++ b/src/opr/test/dnn/group_norm.cpp @@ -0,0 +1,90 @@ +#include "megbrain/opr/dnn/group_norm.h" +#include "megbrain/comp_node_env.h" +#include "megbrain/test/autocheck.h" +#include "megbrain/test/helper.h" +#include "megbrain/test/megdnn_helper.h" + +#include "megdnn/oprs.h" + +#include +#include +#include +#include + +using namespace mgb; + +namespace { +using Param = opr::GroupNormForward::Param; + +void run_forward(bool is_affine) { + using Checker = AutoOprChecker<3, 3>; + + Param param; + param.eps = 1e-5; + param.affine = is_affine; + param.group = 3; + + auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray { + auto out = opr::GroupNormForward::make(inputs[0], inputs[1], inputs[2], param); + return {out[0], out[1], out[2]}; + }; + + auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) { + auto opr = + MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu())) + ->create_operator(); + auto inp_shape = inp[0]->shape(); + auto n_slices = inp_shape[0]; + + opr->param() = param; + + dest[0].dtype(dtype::Float32()) + .comp_node(inp[0]->comp_node()) + .resize(inp_shape); + dest[1].dtype(dtype::Float32()) + .comp_node(inp[0]->comp_node()) + .resize({n_slices, param.group}); + dest[2].dtype(dtype::Float32()) + .comp_node(inp[0]->comp_node()) + .resize({n_slices, param.group}); + std::vector workspace(opr->get_workspace_in_bytes( + inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), dest[0].layout(), + dest[1].layout(), dest[2].layout())); + opr->exec( + inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(), + dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), + {workspace.data(), workspace.size()}); + }; + + auto gen = [&](HostTensorND& src) { + HostTensorGenerator src_gen(0.f); + src = *src_gen(src.shape(), src.comp_node()); + }; + + Checker::RunOptions option; + option.numdiff_max_err = 1e-4; + Checker checker{make_graph, fwd}; + + checker.set_input_generator(0, gen); + checker.set_input_generator(1, gen); + checker.set_input_generator(2, gen); + checker.set_input_allow_grad(0, false); + checker.set_input_allow_grad(1, false); + checker.set_input_allow_grad(2, false); + checker.set_output_allow_grad(0, false); + checker.set_output_allow_grad(1, false); + checker.set_output_allow_grad(2, false); + + checker.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option) + .run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option) + .run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option); +} + +TEST(TestOprDNN, GroupNormForward) { + REQUIRE_GPU(1); + run_forward(true); +} + +} // anonymous namespace + +// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}} diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs index 
a7ac763bd9383faaf9f765f8207b36bdd20fe875..9a94aea7354e9e1874dd3b07745ad65e25a9ac27 100644 --- a/src/serialization/impl/schema.fbs +++ b/src/serialization/impl/schema.fbs @@ -123,6 +123,7 @@ union OperatorParam { param.LSTM = 89, param.Softmax = 90, param.Diag = 91, + param.GroupNorm = 92, } table Operator { diff --git a/src/serialization/impl/schema_v2.fbs b/src/serialization/impl/schema_v2.fbs index d931d79471217206ab80ce13eb14deb5db705023..87c2ef36af3fb32b38eab6ee5c4322bf35bfca66 100644 --- a/src/serialization/impl/schema_v2.fbs +++ b/src/serialization/impl/schema_v2.fbs @@ -140,6 +140,7 @@ union OperatorParam { param.LSTM = 89, param.Softmax = 90, param.Diag = 91, + param.GroupNorm = 92, } table Operator {
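
---

Note on the serialization fallback: `OprLoadDumpImplV2<opr::GroupNorm, 0>::replace_opr` (registered via `MGB_SEREG_OPR_V2` for `VERSION_2`..`CURRENT_VERSION`) expands the operator into Reshape/Reduce/Elemwise primitives so that dumps containing GroupNorm stay loadable without the dedicated kernel. Under the NCHW layout used there, the input is reshaped to (N, G, (C/G)·H·W), per-(n, g) statistics are computed with variance obtained as E[x²] − E[x]², and the result is reshaped back before the optional affine step. Written out (indices and m = (C/G)·H·W are notation introduced here, not identifiers from the patch):

```latex
% Forward decomposition implemented by replace_opr (NCHW, m = (C/G) * H * W)
\begin{aligned}
\mu_{n,g} &= \frac{1}{m}\sum_{i=1}^{m} x_{n,g,i}, &
\sigma^2_{n,g} &= \frac{1}{m}\sum_{i=1}^{m} x_{n,g,i}^{2} - \mu_{n,g}^{2}, \\
\hat{x}_{n,g,i} &= \frac{x_{n,g,i} - \mu_{n,g}}{\sqrt{\sigma^2_{n,g} + \varepsilon}}, &
y_{n,c,h,w} &= \gamma_{c}\,\hat{x}_{n,c,h,w} + \beta_{c}
\quad \text{(affine term only when weight/bias inputs are present).}
\end{aligned}
```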
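A minimal graph-level usage sketch of the new forward operator, mirroring the shapes used in `src/opr/test/dnn/group_norm.cpp`. The demo function name, the tensor shapes, and the `HostTensorGenerator` test helper are illustrative assumptions, not part of this change; only `opr::GroupNorm::Param` and the `make` overloads come from the patch. `make` returns three vars: the normalized output (same layout as the input) plus per-(N, group) mean and rstd.

```cpp
// Illustrative sketch only: wiring opr::GroupNorm::make into a computing graph.
#include "megbrain/graph.h"
#include "megbrain/opr/dnn/group_norm.h"
#include "megbrain/opr/io.h"
#include "megbrain/test/helper.h"  // HostTensorGenerator (test helper, assumed available)

using namespace mgb;

SymbolVarArray group_norm_demo() {
    HostTensorGenerator<> gen;
    auto graph = ComputingGraph::make();
    // NCHW input: 6 channels normalized in 3 groups of 2 channels each.
    auto host_x = gen({2, 6, 4, 4}), host_w = gen({6}), host_b = gen({6});
    auto x = opr::Host2DeviceCopy::make(*graph, host_x),
         w = opr::Host2DeviceCopy::make(*graph, host_w),
         b = opr::Host2DeviceCopy::make(*graph, host_b);

    opr::GroupNorm::Param param;
    param.affine = true;
    param.eps = 1e-5f;
    param.group = 3;

    // outs[0]: normalized output, shape (2, 6, 4, 4)
    // outs[1]: mean, shape (N, group) = (2, 3)
    // outs[2]: rstd, shape (N, group) = (2, 3)
    return opr::GroupNorm::make(x, w, b, param);
}
```

Omitting the weight/bias inputs and calling the single-input `make(data, param)` overload corresponds to `affine = false`, in which case the kernel is invoked with empty weight/bias tensors as in `GroupNormForward::scn_do_execute`.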