提交 dbd94839 编写于 作者: M Megvii Engine Team

feat(dnn,src,imperative): add groupnorm op

GitOrigin-RevId: de3c3d10e5b2c4986542b045340f0e4185661ccb
上级 069e4e07
......@@ -1869,6 +1869,13 @@ table LayerNorm {
normalized_size:ulong = 1;
}
table GroupNorm {
affine:bool = true;
eps:float = 1e-5;
group:uint = 1;
format:ConvolutionFormat = NCHW;
}
table Dropout {
drop_prob:float = 0;
seed:ulong = 0;
......
......@@ -140,6 +140,7 @@ union OperatorParam {
param.LSTM = 89,
param.Softmax = 90,
param.Diag = 91,
param.GroupNorm = 92,
}
table Operator {
......
......@@ -2430,6 +2430,76 @@ protected:
const TensorLayout& dhx, const TensorLayout& dcx, const TensorLayout& dw,
size_t workspace_in_bytes);
};
class GroupNormBase : public OperatorBase {
DEF_OPR_IMPL_CTOR(GroupNormBase, OperatorBase);
DEF_OPR_PARAM(GroupNorm);
protected:
void deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd);
void check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd);
};
class GroupNormForward : public GroupNormBase {
DEF_OPR_IMPL(GroupNormForward, GroupNormBase, 3, 3);
public:
virtual void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) = 0;
MGE_WIN_DECLSPEC_FUC void deduce_layout(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean,
TensorLayout& rstd);
virtual size_t get_workspace_in_bytes(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd) = 0;
protected:
void check_exec(
const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean,
const TensorLayout& rstd, size_t workspace_in_bytes);
};
using GroupNorm = GroupNormForward;
class GroupNormBackward : public GroupNormBase {
DEF_OPR_IMPL(GroupNormBackward, GroupNormBase, 5, 3);
public:
virtual void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) = 0;
void deduce_layout(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight,
TensorLayout& dbias);
virtual size_t get_workspace_in_bytes(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias) = 0;
protected:
void check_exec(
const TensorLayout& diff, const TensorLayout& data,
const TensorLayout& weight, const TensorLayout& mean,
const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias,
size_t workspace_in_bytes);
};
} // namespace megdnn
#include "megdnn/internal/opr_header_epilogue.h"
......
......@@ -1247,6 +1247,13 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'),
.add_fields('uint64', 'normalized_size', '1')
)
(pdef('GroupNorm')
.add_fields('bool', 'affine', 'true')
.add_fields('float32', 'eps', '1e-5f')
.add_fields('uint32', 'group', '1')
.add_enum_alias('Format', 'Convolution')
)
(pdef('Dropout')
.add_fields('float32', 'drop_prob', '0')
.add_fields('uint64', 'seed', '0')
......
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
using Param = GroupNormBase::Param;
void GroupNormBase::deduce_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
MEGDNN_MARK_USED_VAR(weight);
MEGDNN_MARK_USED_VAR(bias);
size_t N = data.shape[0];
size_t group = param().group;
TensorLayout unnormalized_layout({N, group}, dtype::Float32());
dst = data;
mean = unnormalized_layout;
rstd = unnormalized_layout;
}
void GroupNormBase::check_layout_fwd(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) {
megdnn_assert_contiguous(data);
megdnn_assert_contiguous(weight);
megdnn_assert_contiguous(bias);
megdnn_assert_contiguous(dst);
megdnn_assert_contiguous(mean);
megdnn_assert_contiguous(rstd);
auto errmsg = [&]() {
return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " +
megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " +
megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd);
};
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert(data.eq_layout(dst), "%s", errmsg().c_str());
megdnn_assert(weight.eq_layout(bias), "%s", errmsg().c_str());
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str());
auto p = param();
size_t C = data.shape[1];
size_t group = p.group;
megdnn_assert(
group > 0, "Expected num groups to be greater than 0, got %zu", group);
megdnn_assert(
C % group == 0,
"Expected number of channels in input to be divisible by num_groups, but "
"got Channel of shape %zu and num_groups= %zu",
C, group);
}
void GroupNormForward::deduce_layout(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) {
deduce_layout_fwd(data, weight, bias, dst, mean, rstd);
}
void GroupNormForward::check_exec(
const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias,
const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd,
size_t workspace_in_bytes) {
check_layout_fwd(data, weight, bias, dst, mean, rstd);
auto required_workspace_in_bytes =
get_workspace_in_bytes(data, weight, bias, dst, mean, rstd);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
}
void GroupNormBackward::deduce_layout(
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata,
TensorLayout& dweight, TensorLayout& dbias) {
MEGDNN_MARK_USED_VAR(diff);
MEGDNN_MARK_USED_VAR(mean);
MEGDNN_MARK_USED_VAR(rstd);
ddata = data;
dweight = weight;
dbias = weight;
}
void GroupNormBackward::check_exec(
const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight,
const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata,
const TensorLayout& dweight, const TensorLayout& dbias,
size_t workspace_in_bytes) {
auto p = param();
auto required_workspace_in_bytes = get_workspace_in_bytes(
diff, data, weight, mean, rstd, ddata, dweight, dbias);
megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes);
megdnn_assert_contiguous(diff);
megdnn_assert_contiguous(data);
megdnn_assert_contiguous(mean);
megdnn_assert_contiguous(rstd);
megdnn_assert_contiguous(ddata);
if (p.affine) {
megdnn_assert_contiguous(weight);
megdnn_assert_contiguous(dweight);
megdnn_assert_contiguous(dbias);
}
auto errmsg = [&]() {
return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " +
megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " +
megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " +
megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias);
};
MEGDNN_MARK_USED_VAR(errmsg);
megdnn_assert(data.eq_layout(ddata), "%s", errmsg().c_str());
megdnn_assert(mean.eq_layout(rstd), "%s", errmsg().c_str());
if (p.affine) {
megdnn_assert(weight.eq_layout(dweight), "%s", errmsg().c_str());
megdnn_assert(weight.eq_layout(dbias), "%s", errmsg().c_str());
}
}
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -216,7 +216,9 @@ private:
cb(NormForward) \
cb(RegionRestrictedConvolutionForward) \
cb(RegionRestrictedConvolutionBackwardData) \
cb(RegionRestrictedConvolutionBackwardFilter)
cb(RegionRestrictedConvolutionBackwardFilter) \
cb(GroupNormForward) \
cb(GroupNormBackward)
// clang-format on
/*!
......
......@@ -142,6 +142,8 @@ DEF(SoftmaxBackward, 3, true, false);
DEF(RegionRestrictedConvolutionForward, 5, true, true);
DEF(RegionRestrictedConvolutionBackwardData, 5, true, false);
DEF(RegionRestrictedConvolutionBackwardFilter, 5, true, false);
DEF(GroupNormForward, 6, true, true);
DEF(GroupNormBackward, 8, true, true);
} // namespace megdnn
// vim: syntax=cpp.doxygen
#include <stdio.h>
#include <thrust/pair.h>
#include <thrust/tuple.h>
#include <cfloat>
#include "megdnn/arch.h"
#include "megdnn/basic_types.h"
#include "megdnn/dtype.h"
#include "src/cuda/cuda_shfl_compat.cuh"
#include "src/cuda/group_norm/group_norm_cuda.cuh"
#include "src/cuda/utils.cuh"
namespace megdnn {
namespace cuda {
namespace group_norm {
// warp size may be used as array length, or used in host function,
// so we define WARP_SIZE rather than using warpSize
#define WARP_SIZE 32
template <size_t kStart, size_t kEnd, bool kStop>
struct Compare {
template <typename T>
__host__ __device__ inline static bool Run(const T* d1, const T* d2) {
return d1[kStart] == d2[kStart] &&
Compare<kStart + 1, kEnd, kStart + 1 == kEnd>::Run(d1, d2);
}
};
template <size_t kStart, size_t kEnd>
struct Compare<kStart, kEnd, true> {
template <typename T>
__host__ __device__ inline constexpr static bool Run(const T* d1, const T* d2) {
return true;
}
};
template <size_t N>
using UnrollCompare = Compare<0, N, N == 0>;
template <typename T, size_t kStart, size_t kEnd, bool kStop>
struct UnrollVarArgsAssignImpl {
template <typename... Args>
__host__ __device__ inline static void Run(T* d, T val, Args... args) {
static_assert(sizeof...(args) + 1 == kEnd - kStart, "Wrong argument");
d[kStart] = val;
UnrollVarArgsAssignImpl<T, kStart + 1, kEnd, kStart + 1 == kEnd>::Run(
d, args...);
}
};
template <typename T, size_t kStart, size_t kEnd>
struct UnrollVarArgsAssignImpl<T, kStart, kEnd, true> {
__host__ __device__ inline static void Run(T* d) {}
};
template <typename T>
struct UnrollVarArgsAssign {
template <typename... Args>
__host__ __device__ inline static void Run(T* d, Args... args) {
UnrollVarArgsAssignImpl<T, 0, sizeof...(Args), sizeof...(Args) == 0>::Run(
d, args...);
}
};
template <typename T, size_t N>
class Array {
public:
static constexpr size_t kSize = N;
__host__ __device__ inline Array() {}
template <typename... Args>
__host__ __device__ inline explicit Array(const T& val, Args... args) {
static_assert(N == sizeof...(Args) + 1, "Invalid argument");
UnrollVarArgsAssign<T>::Run(data_, val, args...);
}
__host__ __device__ inline T& operator[](size_t i) { return *(data_ + i); }
__host__ __device__ inline const T& operator[](size_t i) const {
return *(data_ + i);
}
private:
template <typename U>
__host__ __device__ static inline U* advance(U* ptr, size_t i) {
return ptr + i;
}
T data_[N];
};
// ================================ group_norm forward ===========================
// implementation of groupnorm_forward from
// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_kernel.cu#L115
template <typename T>
__forceinline__ __device__ T
CudaShuffleDownSync(T val, int delta, int width = warpSize) {
return __shfl_down(val, static_cast<unsigned>(delta), width);
}
template <>
__forceinline__ __device__ dt_float16
CudaShuffleDownSync(dt_float16 val, int delta, int width) {
return dt_float16(__shfl_down(val, static_cast<unsigned>(delta), width));
}
template <>
__forceinline__ __device__ dt_bfloat16
CudaShuffleDownSync(dt_bfloat16 val, int delta, int width) {
return dt_bfloat16(__shfl_down(val, static_cast<unsigned>(delta), width));
}
template <typename T, int VecSize>
struct alignas(sizeof(T) * VecSize) VectorType {
T val[VecSize];
};
template <typename T>
struct AddFunctor {
inline T initial() { return static_cast<T>(0.0f); }
__device__ __forceinline__ T operator()(const T a, const T b) const {
return b + a;
}
};
template <typename T, typename ReduceOp>
__device__ __forceinline__ T WarpReduce(T val, ReduceOp reducer) {
for (int stride = WARP_SIZE / 2; stride > 0; stride >>= 1) {
T temp = CudaShuffleDownSync<T>(val, stride);
val = reducer(val, temp);
}
return val;
}
template <typename T, typename ReduceOp>
__device__ __forceinline__ T BlockXReduce(T val, ReduceOp reducer) {
__syncthreads();
__shared__ T shared[64];
int block_dim_x = blockDim.x;
if (blockDim.x > WARP_SIZE) {
block_dim_x = blockDim.x / WARP_SIZE;
int lane = threadIdx.x % WARP_SIZE;
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int wid = tid / WARP_SIZE;
int bid = threadIdx.y;
val = WarpReduce<T, ReduceOp>(val, reducer);
if (lane == 0) {
shared[wid] = val;
}
__syncthreads();
val = shared[bid * block_dim_x + lane];
}
for (int stride = 1; stride < block_dim_x; stride <<= 1) {
T temp = CudaShuffleDownSync(val, stride);
val = reducer(val, temp);
}
if (threadIdx.x == 0) {
shared[threadIdx.y] = val;
}
__syncthreads();
return shared[threadIdx.y];
}
template <typename T>
__device__ __forceinline__ void ReduceMeanAndVar(
T* mean, T* var, T x_mean, T x_var, int size) {
const int nc = blockIdx.x;
x_mean = BlockXReduce<T, AddFunctor<T>>(x_mean, AddFunctor<T>());
x_var = BlockXReduce<T, AddFunctor<T>>(x_var, AddFunctor<T>());
__syncthreads();
if (threadIdx.x == 0) {
mean[nc] = static_cast<T>(x_mean / size);
var[nc] = static_cast<T>(x_var / size);
}
}
template <typename T, typename T_ACC, int VecSize, int Num>
__device__ __forceinline__ void ThreadReduce(
Array<const T*, Num> arrs, int size, const int offset, T_ACC* out_mean,
T_ACC* out_var) {
const T* x = arrs[0];
const T* y;
if (Num == 2) {
y = arrs[1];
}
using VecT = VectorType<T, VecSize>;
int tid = threadIdx.x;
if (offset > 0) {
x -= offset;
if (Num == 2) {
y -= offset;
}
size += offset;
if (tid >= offset) {
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
size -= blockDim.x;
x += blockDim.x;
if (Num == 2) {
y += blockDim.x;
}
}
int remain = size % (VecSize * blockDim.x);
T ins_x[VecSize];
T ins_y[VecSize];
VecT* ins_vec_x = reinterpret_cast<VecT*>(&ins_x);
VecT* ins_vec_y = reinterpret_cast<VecT*>(&ins_y);
// vector part
for (; VecSize * tid < (size - remain); tid += blockDim.x) {
*ins_vec_x = reinterpret_cast<const VecT*>(x)[tid];
if (Num == 2) {
*ins_vec_y = reinterpret_cast<const VecT*>(y)[tid];
}
#pragma unroll
for (int i = 0; i < VecSize; ++i) {
if (Num == 1) {
*out_mean += ins_x[i];
*out_var += ins_x[i] * ins_x[i];
} else if (Num == 2) {
*out_mean += ins_y[i];
*out_var += ins_y[i] * ins_x[i];
}
}
}
// scalar part
tid = size - remain + threadIdx.x;
for (; tid < size; tid += blockDim.x) {
if (Num == 1) {
*out_mean += x[tid];
*out_var += x[tid] * x[tid];
} else if (Num == 2) {
*out_mean += y[tid];
*out_var += y[tid] * x[tid];
}
}
}
template <typename T, typename T_ACC>
__global__ void ScalarGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) {
int i = blockIdx.x;
T_ACC x_mean = static_cast<T_ACC>(0);
T_ACC x_var = static_cast<T_ACC>(0);
for (int j = threadIdx.x; j < size; j += blockDim.x) {
T val;
val = x[i * size + j];
x_mean += val;
x_var += val * val;
}
ReduceMeanAndVar<T_ACC>(mean, var, x_mean, x_var, size);
}
template <typename T, typename T_ACC, int VecSize>
__global__ void VectorizedGetMeanAndVar(const T* x, T_ACC* mean, T_ACC* var, int size) {
int i = blockIdx.x;
T_ACC x_mean = static_cast<T_ACC>(0);
T_ACC x_var = static_cast<T_ACC>(0);
x += i * size;
const int input_offset = ((uint64_t)x) % 16 / sizeof(T);
Array<const T*, 1> ins;
ins[0] = x;
ThreadReduce<T, T_ACC, VecSize, 1>(ins, size, input_offset, &x_mean, &x_var);
ReduceMeanAndVar<T_ACC>(mean, var, x_mean, x_var, size);
}
template <typename T, typename T_ACC>
__global__ void GroupNormForward(
const T* x, const T_ACC* mean, const T_ACC* var, const T* scale, const T* bias,
int N, int C, int W, int imsize, int groups, int group_size, float epsilon,
T* y, T_ACC* real_var) {
int gid = blockIdx.y;
int cid = blockIdx.x;
int bid = blockIdx.z;
int ccid = gid * group_size + cid;
if (ccid >= C)
return;
auto ng = bid * groups + gid;
T_ACC x_mean = mean[ng];
T_ACC x_var = var[ng];
x_var = x_var - x_mean * x_mean;
T_ACC var_inv = rsqrt(x_var + epsilon);
if (cid == 0 && threadIdx.x == 0) {
real_var[ng] = x_var;
}
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
T val;
int index = (bid * C + ccid) * imsize + imid;
val = x[index];
val = (val - x_mean) * var_inv;
if (scale != nullptr) {
val *= scale[ccid];
}
if (bias != nullptr) {
val += bias[ccid];
}
y[index] = val;
}
}
template <typename T, typename T_ACC>
void forward(
T* src, T* weight, T* bias, T* dst, T_ACC* mean, T_ACC* rstd, T_ACC* temp_rstd,
T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream) {
auto group_size = C / group;
int block_size = std::min(1024, imsize);
dim3 grid(group_size, group, N);
dim3 threads(block_size, 1, 1);
int size = group_size * imsize;
constexpr int vec_size = sizeof(float4) / sizeof(T);
int max_block_size = std::min(size / vec_size, 1024);
int block_size_temp = 1;
while (block_size_temp < max_block_size) {
block_size_temp *= 2;
}
block_size_temp = std::max(block_size_temp, WARP_SIZE);
dim3 grids(N * group);
dim3 blocks(block_size_temp);
if (size < vec_size * block_size_temp) {
ScalarGetMeanAndVar<T, T_ACC>
<<<grids, blocks, 0, stream>>>(src, mean, temp_rstd, size);
after_kernel_launch();
} else {
VectorizedGetMeanAndVar<T, T_ACC, vec_size>
<<<grids, blocks, 0, stream>>>(src, mean, temp_rstd, size);
after_kernel_launch();
}
GroupNormForward<T, T_ACC><<<grid, threads, 0, stream>>>(
src, mean, temp_rstd, weight, bias, N, C, W, imsize, group, group_size, eps,
dst, rstd);
after_kernel_launch();
}
// ================================ group_norm backward ===========================
// implementation of groupnorm_backward from
// https://github.com/PaddlePaddle/Paddle/blob/develop/paddle/phi/kernels/gpu/group_norm_grad_kernel.cu#L253
template <typename T, typename T_ACC>
__global__ void GetDsDbCUDAKernel(int imsize, const T* x, const T* dy, T* ds, T* db) {
const int nc = blockIdx.x;
T ds_sum = static_cast<T>(0);
T db_sum = static_cast<T>(0);
for (int i = threadIdx.x; i < imsize; i += blockDim.x) {
const int index = nc * imsize + i;
ds_sum += dy[index] * x[index];
db_sum += dy[index];
}
ReduceMeanAndVar<T>(db, ds, db_sum, ds_sum, 1);
}
template <typename T, typename T_ACC>
__global__ void GetBiasGradientCUDAKernel(
int N, int C, int group, T_ACC epsilon, const T_ACC* mean, const T_ACC* var,
const T* ds, const T* db, T* d_scale, T* d_bias) {
const int c = blockIdx.x * blockDim.x + threadIdx.x;
if (c < C) {
const int G = group;
const int D = C / G;
T sum1 = static_cast<T>(0);
T sum2 = static_cast<T>(0);
for (int n = 0; n < N; ++n) {
const int nc = n * C + c;
const int ng = n * G + c / D;
sum1 += (d_scale == nullptr)
? T(0)
: ((ds[nc] - db[nc] * static_cast<T>(mean[ng])) *
static_cast<T>(rsqrt((float)(var[ng] + epsilon))));
sum2 += (d_bias == nullptr) ? T(0) : db[nc];
}
if (d_scale != nullptr) {
d_scale[c] = sum1;
}
if (d_bias != nullptr) {
d_bias[c] = sum2;
}
}
}
template <typename T>
__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) {
#pragma unroll
for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) {
val += __shfl_down(val, offset, warpSize);
}
return val;
}
template <typename T>
__inline__ MEGDNN_DEVICE T BlockReduceSum(T val, T* shared) {
const int lid = threadIdx.x % warpSize;
const int wid = threadIdx.x / warpSize;
val = warp_reduce_sum(val);
__syncthreads();
if (lid == 0) {
shared[wid] = val;
}
__syncthreads();
val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0);
if (wid == 0) {
val = warp_reduce_sum(val);
}
return val;
}
template <typename T, typename T_ACC, int BlockDim>
__global__ void GetBackwardParamsCUDAKernel(
int imsize, int groups, int group_size, T_ACC epsilon, const T_ACC* mean,
const T_ACC* var, const T* scale, const T* ds, const T* db, T* p1, T* p2,
T* p3) {
const int n = blockIdx.x;
const int g = blockIdx.y;
const int ng = n * groups + g;
T sum1 = static_cast<T>(0);
T sum2 = static_cast<T>(0);
T var_inv = static_cast<T>(rsqrt(var[ng] + epsilon));
for (int64_t i = threadIdx.x; i < group_size; i += blockDim.x) {
const int64_t index = ng * group_size + i;
const int64_t c = g * group_size + i;
const T scale_v = scale == nullptr ? T(1) : static_cast<T>(scale[c]);
sum1 += ds[index] * scale_v;
sum2 += db[index] * scale_v;
const T scale_c = scale == nullptr ? T(0) : static_cast<T>(scale[c]);
p1[index] = scale_c * var_inv;
}
__shared__ T ds_shared[WARP_SIZE];
__shared__ T db_shared[WARP_SIZE];
sum1 = BlockReduceSum<T>(sum1, ds_shared);
sum2 = BlockReduceSum<T>(sum2, db_shared);
if (threadIdx.x == 0) {
const T s = T(1) / static_cast<T>(group_size * imsize);
const T x = (sum2 * static_cast<T>(mean[ng]) - sum1) * static_cast<T>(var_inv) *
static_cast<T>(var_inv) * static_cast<T>(var_inv) * s;
p2[ng] = x;
p3[ng] = -x * static_cast<T>(mean[ng]) - sum2 * static_cast<T>(var_inv) * s;
}
}
template <typename T, typename T_ACC>
__global__ void GetXGradientCUDAKernel(
int imsize, int C, int group_size, int groups, T* p1, T* p2, T* p3, const T* x,
const T* dy, T* dx) {
int cid = blockIdx.x;
int gid = blockIdx.y;
int bid = blockIdx.z;
int ccid = bid * C + gid * group_size + cid;
int ng = bid * groups + gid;
int nc = gid * group_size + cid;
for (int imid = threadIdx.x; imid < imsize; imid += blockDim.x) {
int index = (bid * C + nc) * imsize + imid;
dx[index] = p1[ccid] * dy[index] + p2[ng] * x[index] + p3[ng];
}
}
template <typename T, typename T_ACC>
void backward(
const T* dY_data, const T* X_data, const T_ACC* mean_data,
const T_ACC* rstd_data, const T* weight_data, T* dX_data, T* dweight_data,
T* dbias_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db,
T* p1, T* p2, T* p3, cudaStream_t stream) {
auto group_size = C / group;
int block_size = std::min(1024, imsize);
const int block_dims = 1024;
dim3 grid(group_size, group, N);
dim3 threads(block_size, 1, 1);
const int max_num_threads = 1024;
int max_block_size = std::min(imsize, max_num_threads);
int block_size_temp = 1;
while (block_size_temp < max_block_size) {
block_size_temp *= 2;
}
block_size_temp = std::max(block_size_temp, WARP_SIZE);
dim3 blocks(block_size_temp);
GetDsDbCUDAKernel<T, T_ACC>
<<<N * C, blocks, 0, stream>>>(imsize, X_data, dY_data, ds, db);
after_kernel_launch();
bool flag = weight_data != nullptr ? true : false;
if (flag) {
const int block = 256;
GetBiasGradientCUDAKernel<T, T_ACC>
<<<(C + block - 1) / block, block, 0, stream>>>(
N, C, group, eps, mean_data, rstd_data, ds, db, dweight_data,
dbias_data);
after_kernel_launch();
}
GetBackwardParamsCUDAKernel<T, T_ACC, block_dims>
<<<dim3(N, group), block_dims, 0, stream>>>(
imsize, group, group_size, eps, mean_data, rstd_data, weight_data,
ds, db, p1, p2, p3);
after_kernel_launch();
GetXGradientCUDAKernel<T, T_ACC><<<grid, threads, 0, stream>>>(
imsize, C, group_size, group, p1, p2, p3, X_data, dY_data, dX_data);
after_kernel_launch();
}
#define INST(T, T_ACC) \
template void forward<T, T_ACC>( \
T*, T*, T*, T*, T_ACC*, T_ACC*, T_ACC*, T_ACC, int, int, int, int, int, \
cudaStream_t); \
template void backward<T, T_ACC>( \
const T*, const T*, const T_ACC*, const T_ACC*, const T*, T*, T*, T*, \
T_ACC, int, int, int, int, T*, T*, T*, T*, T*, cudaStream_t);
INST(dt_float32, dt_float32)
INST(dt_float16, dt_float32)
INST(dt_bfloat16, dt_float32)
#undef INST
} // namespace group_norm
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include <cuda_runtime_api.h>
namespace megdnn {
namespace cuda {
namespace group_norm {
template <typename T, typename T_ACC>
void forward(
T* X, T* gamma, T* beta, T* Y, T_ACC* mean, T_ACC* rstd, T_ACC* tesmp_rstd,
T_ACC eps, int group, int N, int C, int W, int imsize, cudaStream_t stream);
template <typename T, typename T_ACC>
void backward(
const T* dY_data, const T* X_data, const T_ACC* mean_data,
const T_ACC* rstd_data, const T* gamma_data, T* dX_data, T* dgamma_data,
T* dbeta_data, T_ACC eps, int group, int N, int C, int imsize, T* ds, T* db,
T* p1, T* p2, T* p3, cudaStream_t stream);
} // namespace group_norm
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#include "src/cuda/group_norm/opr_impl.h"
#include "src/cuda/group_norm/group_norm_cuda.cuh"
#include "src/cuda/utils.h"
namespace megdnn {
namespace cuda {
size_t GroupNormForwardImpl::get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout& rstd) {
size_t N = rstd.shape[0];
size_t G = rstd.shape[1];
return get_workspace_bundle(N, G, rstd.dtype.size()).total_size_in_bytes();
}
WorkspaceBundle GroupNormForwardImpl::get_workspace_bundle(
size_t N, size_t G, size_t dtype_size, void* raw_ptr) {
return {raw_ptr, {N * G * dtype_size}, handle()->alignment_requirement()};
}
void GroupNormForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) {
check_exec(
data.layout, weight.layout, bias.layout, dst.layout, mean.layout,
rstd.layout, workspace.size);
auto p = param();
using Format = param::GroupNorm::Format;
float eps = p.eps;
int group = p.group;
bool affine = p.affine;
auto layout = data.layout;
auto format = p.format;
size_t N, C, H, W, imsize;
if (data.layout.ndim == 4 && format == Format::NCHW) {
N = layout.shape[0];
C = layout.shape[1];
H = layout.shape[2];
W = layout.shape[3];
imsize = H * W;
} else {
megdnn_throw(ssprintf("Unspport groupnorm input"));
}
auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::group_norm;
auto wbundle =
get_workspace_bundle(N, group, rstd.layout.dtype.size(), workspace.raw_ptr);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
using T_ACC = float; \
T_ACC* temp_rstd = wbundle.get_workspace(0).ptr<T_ACC>(); \
forward<T, T_ACC>( \
data.ptr<T>(), affine ? weight.ptr<T>() : nullptr, \
affine ? bias.ptr<T>() : nullptr, dst.ptr<T>(), mean.ptr<T_ACC>(), \
rstd.ptr<T_ACC>(), temp_rstd, static_cast<T_ACC>(eps), \
static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \
static_cast<int>(W), static_cast<int>(imsize), stream); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
size_t GroupNormBackwardImpl::get_workspace_in_bytes(
const TensorLayout&, const TensorLayout& data, const TensorLayout&,
const TensorLayout& mean, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) {
size_t N = data.shape[0];
size_t C = data.shape[1];
size_t G = mean.shape[1];
return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes();
}
WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle(
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) {
return {raw_ptr,
{N * C * dtype_size, N * C * dtype_size, N * C * dtype_size,
N * G * dtype_size, N * G * dtype_size},
handle()->alignment_requirement()};
}
void GroupNormBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) {
check_exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size);
auto p = param();
using Format = param::GroupNorm::Format;
bool affine = p.affine;
float eps = p.eps;
int group = p.group;
auto layout = data.layout;
auto format = p.format;
size_t N, C, H, W, imsize;
if (layout.ndim == 4 && format == Format::NCHW) {
N = layout.shape[0];
C = layout.shape[1];
H = layout.shape[2];
W = layout.shape[3];
imsize = H * W;
} else {
megdnn_throw(ssprintf("Unspport groupnorm input"));
}
auto stream = cuda_stream(handle());
using namespace ::megdnn::cuda::group_norm;
auto wbundle = get_workspace_bundle(
N, C, group, data.layout.dtype.size(), workspace.raw_ptr);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
using T = typename DTypeTrait<DType>::ctype; \
using T_ACC = float; \
T* ds = wbundle.get_workspace(0).ptr<T>(); \
T* db = wbundle.get_workspace(1).ptr<T>(); \
T* p1 = wbundle.get_workspace(2).ptr<T>(); \
T* p2 = wbundle.get_workspace(3).ptr<T>(); \
T* p3 = wbundle.get_workspace(4).ptr<T>(); \
backward<T, T_ACC>( \
diff.ptr<T>(), data.ptr<T>(), mean.ptr<T_ACC>(), rstd.ptr<T_ACC>(), \
affine ? weight.ptr<T>() : nullptr, ddata.ptr<T>(), \
affine ? dweight.ptr<T>() : nullptr, \
affine ? dbias.ptr<T>() : nullptr, static_cast<T_ACC>(eps), \
static_cast<int>(group), static_cast<int>(N), static_cast<int>(C), \
static_cast<int>(imsize), ds, db, p1, p2, p3, stream); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"
#include "src/cuda/cudnn_wrapper.h"
namespace megdnn {
namespace cuda {
class GroupNormForwardImpl final : public GroupNormForward {
public:
using GroupNormForward::GroupNormForward;
void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&,
const TensorLayout& rstd) override;
private:
WorkspaceBundle get_workspace_bundle(
size_t N, size_t G, size_t dtype_size, void* raw_ptr = nullptr);
};
class GroupNormBackwardImpl final : public GroupNormBackward {
public:
using GroupNormBackward::GroupNormBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout& data, const TensorLayout&,
const TensorLayout& mean, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override;
private:
WorkspaceBundle get_workspace_bundle(
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr);
};
} // namespace cuda
} // namespace megdnn
// vim: syntax=cpp.doxygen
......@@ -32,6 +32,7 @@
#include "src/cuda/flip/opr_impl.h"
#include "src/cuda/gaussian_blur/opr_impl.h"
#include "src/cuda/group_local/opr_impl.h"
#include "src/cuda/group_norm/opr_impl.h"
#include "src/cuda/images2neibs/opr_impl.h"
#include "src/cuda/indexing_multi_axis_vec/opr_impl.h"
#include "src/cuda/indexing_one_hot/opr_impl.h"
......@@ -163,6 +164,8 @@ MEGDNN_SPECIALIZE_CREATE_OPERATOR(BNBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardData);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupLocalBackwardFilter);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormForward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(GroupNormBackward);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Flip);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(Rotate);
MEGDNN_SPECIALIZE_CREATE_OPERATOR(ROICopy);
......
#include "src/naive/group_norm/opr_impl.h"
#include <algorithm>
#include "src/common/utils.h"
#include "src/naive/handle.h"
using namespace megdnn;
using namespace naive;
namespace {
using Param = megdnn::GroupNorm::Param;
template <typename T, typename T_ACC = float>
void forward(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
const Param& param) {
float eps = param.eps;
bool affine = param.affine;
size_t N = data.layout.shape[0];
size_t C = data.layout.shape[1];
size_t HxW = data.layout.shape[2] * data.layout.shape[3];
const int64_t G = param.group;
size_t D = C / G;
size_t inner_size = D * HxW;
for (size_t i = 0; i < N * G; i++) {
T_ACC slice_sum = static_cast<T>(0.0f);
for (size_t j = 0; j < inner_size; j++) {
auto value = data.ptr<T>()[i * inner_size + j];
slice_sum += value;
}
T_ACC slice_mean = static_cast<T>(slice_sum / inner_size);
T_ACC slice_var = static_cast<T>(0.0f);
for (size_t j = 0; j < inner_size; j++) {
slice_var += (data.ptr<T>()[i * inner_size + j] - slice_mean) *
(data.ptr<T>()[i * inner_size + j] - slice_mean);
}
slice_var = slice_var / inner_size;
T_ACC slice_std = static_cast<T>(1.0f) / static_cast<T>(sqrt(slice_var + eps));
if (affine) {
const int64_t g = i % G;
for (size_t j = 0; j < D; j++) {
const int64_t c = g * D + j;
T_ACC s = slice_std * weight.ptr<T>()[c];
T_ACC b = -s * slice_mean + bias.ptr<T>()[c];
for (size_t k = 0; k < HxW; k++) {
dst.ptr<T>()[(i * D + j) * HxW + k] =
s * data.ptr<T>()[(i * D + j) * HxW + k] + b;
}
}
} else {
for (size_t j = 0; j < inner_size; j++) {
dst.ptr<T>()[i * inner_size + j] =
(data.ptr<T>()[i * inner_size + j] - slice_mean) / slice_std;
}
}
mean.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_mean);
rstd.ptr<T_ACC>()[i] = static_cast<T_ACC>(slice_var);
}
}
template <typename T, typename T_ACC = float>
void backward(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param,
WorkspaceBundle wbundle) {
bool affine = param.affine;
size_t N = data.layout.shape[0];
size_t C = data.layout.shape[1];
size_t G = param.group;
float eps = param.eps;
size_t HxW = data.layout.shape[2] * data.layout.shape[3];
T* ds = wbundle.get_workspace(0).ptr<T>();
T* db = wbundle.get_workspace(1).ptr<T>();
T* slice_std = wbundle.get_workspace(2).ptr<T>();
for (size_t i = 0; i < N * G; i++) {
slice_std[i] =
static_cast<T>(1.0f) / static_cast<T>(sqrt(rstd.ptr<T_ACC>()[i] + eps));
}
for (size_t i = 0; i < N * C; i++) {
T ds_data = static_cast<T>(0.0f);
T db_data = static_cast<T>(0.0f);
for (size_t j = 0; j < HxW; j++) {
db_data += diff.ptr<T>()[i * HxW + j];
ds_data += data.ptr<T>()[i * HxW + j] * diff.ptr<T>()[i * HxW + j];
}
ds[i] = ds_data;
db[i] = db_data;
}
size_t D = C / G;
const T s = T(1) / static_cast<T>(D * HxW);
for (size_t i = 0; i < N * G; i++) {
const int64_t g = i % G;
T ds_v = static_cast<T>(0.0f);
T db_v = static_cast<T>(0.0f);
for (size_t j = 0; j < D; j += 1) {
auto weight_v = affine ? weight.ptr<T>()[g * D + j] : static_cast<T>(1.0f);
ds_v += ds[i * D + j] * weight_v;
db_v += db[i * D + j] * weight_v;
}
auto c2 = (db_v * mean.ptr<T_ACC>()[i] - ds_v) * slice_std[i] * slice_std[i] *
slice_std[i] * s;
auto c3 = -c2 * mean.ptr<T_ACC>()[i] - db_v * slice_std[i] * s;
for (size_t j = 0; j < D; j++) {
const int64_t c = g * D + j;
auto weight_v = affine ? weight.ptr<T>()[c] : static_cast<T>(1.0f);
auto c1 = slice_std[i] * weight_v;
for (size_t k = 0; k < HxW; k++) {
ddata.ptr<T>()[(i * D + j) * HxW + k] =
c1 * diff.ptr<T>()[(i * D + j) * HxW + k] +
c2 * data.ptr<T>()[(i * D + j) * HxW + k] + c3;
}
}
}
if (affine) {
for (size_t i = 0; i < C; ++i) {
dweight.ptr<T>()[i] = 0;
dbias.ptr<T>()[i] = 0;
}
for (size_t i = 0; i < N * G; i++) {
auto g = i % G;
for (size_t j = 0; j < D; j++) {
auto c = g * D + j;
dweight.ptr<T>()[c] +=
(ds[i * D + j] - db[i * D + j] * mean.ptr<T_ACC>()[i]) *
slice_std[i];
}
}
for (size_t i = 0; i < N; i++) {
for (size_t j = 0; j < C; j++) {
dbias.ptr<T>()[j] += db[i * C + j];
}
}
}
}
} // namespace
namespace megdnn {
namespace naive {
size_t GroupNormBackwardImpl::get_workspace_in_bytes(
const TensorLayout&, const TensorLayout& data, const TensorLayout&,
const TensorLayout&, const TensorLayout& rstd, const TensorLayout&,
const TensorLayout&, const TensorLayout&) {
size_t N = data.shape[0];
size_t C = data.shape[1];
size_t G = rstd.shape[1];
return get_workspace_bundle(N, C, G, data.dtype.size()).total_size_in_bytes();
}
WorkspaceBundle GroupNormBackwardImpl::get_workspace_bundle(
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr) {
return {raw_ptr,
{N * C * dtype_size, N * C * dtype_size, N * G * dtype_size},
handle()->alignment_requirement()};
}
void GroupNormForwardImpl::exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) {
check_exec(
data.layout, weight.layout, bias.layout, dst.layout, mean.layout,
rstd.layout, workspace.size);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
MEGDNN_DISPATCH_CPU_KERN_OPR(forward<typename DTypeTrait<DType>::ctype>( \
data, weight, bias, dst, mean, rstd, param())); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
void GroupNormBackwardImpl::exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) {
check_exec(
diff.layout, data.layout, weight.layout, mean.layout, rstd.layout,
ddata.layout, dweight.layout, dbias.layout, workspace.size);
#define cb(DType) \
if (data.layout.dtype == DType()) { \
auto wbundle = get_workspace_bundle( \
data.layout.shape[0], data.layout.shape[1], rstd.layout.shape[1], \
sizeof(DTypeTrait<DType>::ctype), workspace.raw_ptr); \
MEGDNN_DISPATCH_CPU_KERN_OPR(backward<typename DTypeTrait<DType>::ctype>( \
diff, data, weight, mean, rstd, ddata, dweight, dbias, param(), \
wbundle)); \
return; \
}
MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb)
#undef cb
megdnn_throw("bad dtype");
}
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
#pragma once
#include "megdnn/oprs.h"
#include "src/common/utils.h"
namespace megdnn {
namespace naive {
class GroupNormForwardImpl final : public GroupNormForward {
public:
using GroupNormForward::GroupNormForward;
void exec(
_megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias,
_megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout&, const TensorLayout&,
const TensorLayout&, const TensorLayout&, const TensorLayout&) override {
return 0;
}
};
class GroupNormBackwardImpl final : public GroupNormBackward {
public:
using GroupNormBackward::GroupNormBackward;
void exec(
_megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight,
_megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata,
_megdnn_tensor_out dweight, _megdnn_tensor_out dbias,
_megdnn_workspace workspace) override;
size_t get_workspace_in_bytes(
const TensorLayout&, const TensorLayout& data, const TensorLayout&,
const TensorLayout&, const TensorLayout& rstd, const TensorLayout&,
const TensorLayout&, const TensorLayout&) override;
private:
WorkspaceBundle get_workspace_bundle(
size_t N, size_t C, size_t G, size_t dtype_size, void* raw_ptr = nullptr);
};
} // namespace naive
} // namespace megdnn
// vim: syntax=cpp.doxygen
\ No newline at end of file
......@@ -34,6 +34,7 @@
#include "src/naive/flip/opr_impl.h"
#include "src/naive/gaussian_blur/opr_impl.h"
#include "src/naive/group_local/opr_impl.h"
#include "src/naive/group_norm/opr_impl.h"
#include "src/naive/images2neibs/opr_impl.h"
#include "src/naive/indexing_multi_axis_vec/opr_impl.h"
#include "src/naive/indexing_one_hot/opr_impl.h"
......
#include "test/cuda/fixture.h"
#include "test/common/checker.h"
namespace megdnn {
namespace test {
TEST_F(CUDA, GROUPNORM_FORWARD) {
using Param = GroupNormForward::Param;
Param param;
param.affine = true;
param.eps = 1e-6;
Checker<GroupNormForward> checker(handle_cuda());
checker.set_epsilon(1e-2);
auto run = [&](DType d) {
for (size_t group : {1, 3})
for (size_t C : {6, 9}) {
param.group = group;
checker.set_param(param)
.set_dtype(0, d)
.set_dtype(1, d)
.set_dtype(2, d)
.set_dtype(3, d)
.set_dtype(4, dtype::Float32())
.set_dtype(5, dtype::Float32())
.execs({{2, C, 2, 1},
{C},
{C},
{2, C, 2, 1},
{2, group},
{2, group}});
}
};
run(dtype::Float32());
run(dtype::Float16());
run(dtype::BFloat16());
}
} // namespace test
} // namespace megdnn
// vim: syntax=cpp.doxygen
#include "megdnn/dtype.h"
#include "megdnn/oprs.h"
#include "test/common/checker.h"
#include "test/naive/fixture.h"
namespace megdnn {
namespace test {
TEST_F(NAIVE, GROUPNORM_FORWARD) {
Checker<GroupNorm> checker(handle(), true);
GroupNorm::Param param;
param.affine = true;
param.group = 3;
checker.set_param(param).exect(
Testcase{
TensorValue(
{2, 3, 2, 1}, dtype::Float32(),
{3.3179, 0.109, -0.5855, 0.2566, -1.2897, 1.2683, -2.0587,
0.0711, -0.1169, 0.2509, -0.2393, 0.0876}), // input
TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // hx
TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // cx
{},
{},
{}},
Testcase{
{},
{},
{},
TensorValue(
{2, 3, 2, 1}, dtype::Float32(),
{1., -1., -1., 1., -1., 1., -1., 1., -0.9999, 0.9999,
-0.9998, 0.9998}), // output
TensorValue(
{2, 3}, dtype::Float32(),
{1.7135, -0.1645, -0.0107, -0.9938, 0.067,
-0.0758}), // mean
TensorValue(
{2, 3}, dtype::Float32(),
{2.5742, 0.1772, 1.6358, 1.1340, 0.0338, 0.0267}), // var
});
checker.set_param(param).exect(
Testcase{
TensorValue(
{1, 3, 1, 2}, dtype::Float32(),
{-2.4348, -1.7948, 0.5223, 0.0932, -0.2955,
-0.0492}), // input
TensorValue({3}, dtype::Float32(), {1., 1., 1.}), // hx
TensorValue({3}, dtype::Float32(), {0., 0., 0.}), // cx
{},
{},
{}},
Testcase{
{},
{},
{},
TensorValue(
{1, 3, 1, 2}, dtype::Float32(),
{-0.9999, 0.9999, 0.9999, -0.9999, -0.9997,
0.9997}), // output
TensorValue(
{1, 3}, dtype::Float32(),
{-2.1148, 0.3077, -0.1724}), // mean
TensorValue(
{1, 3}, dtype::Float32(), {0.1023, 0.0460, 0.0151}), // var
});
}
} // namespace test
} // namespace megdnn
\ No newline at end of file
......@@ -60,6 +60,7 @@ __all__ = [
"dropout",
"embedding",
"gelu",
"group_norm",
"hsigmoid",
"hswish",
"indexing_one_hot",
......@@ -1202,6 +1203,33 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor:
return output
def group_norm(
inp: Tensor,
num_groups: int,
affine: bool,
weight: Optional[Tensor] = None,
bias: Optional[Tensor] = None,
eps: float = 1e-5,
):
r"""Applies Group Normalization over a mini-batch of inputs as described in
the paper `Group Normalization <https://arxiv.org/abs/1803.08494>`__
Args:
inp: input tensor.
num_groups: number of groups to separate the channels into
affine: whether to use weight and bias
weight: must not be None when the affine is true
bias: must not be None when the affine is true
eps: a value added to the denominator for numerical stability. Default: 1e-5
"""
op = builtin.GroupNorm(affine=affine, eps=eps, group=num_groups,)
if affine:
assert weight is not None and bias is not None
return apply(op, inp, weight, bias)[0]
else:
return apply(op, inp)[0]
def layer_norm(
inp: Tensor,
normalized_shape: tuple,
......
......@@ -34,21 +34,9 @@ class GroupNorm(Module):
zeros_(self.bias)
def forward(self, x):
N, C, H, W = x.shape
format = x.format
assert C == self.num_channels
x = x.reshape(N, self.num_groups, -1)
mean = x.mean(axis=2, keepdims=True)
var = (x * x).mean(axis=2, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + self.eps)
x = x.reshape(N, C, H, W)
if self.affine:
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(1, -1, 1, 1)
# FIXME(czh): remove this after making it a builtin op.
if format == "nhwc":
x = mge.amp.convert_tensor_format(x, inplace=False)
x = F.nn.group_norm(
x, self.num_groups, self.affine, self.weight, self.bias, self.eps
)
return x
def _module_info_string(self) -> str:
......
......@@ -8,12 +8,14 @@ import pytest
import megengine as mge
import megengine.functional as F
from megengine import Parameter, Tensor, tensor
from megengine.device import get_device_count
from megengine.module import (
BatchNorm1d,
BatchNorm2d,
Conv1d,
Conv2d,
Dropout,
GroupNorm,
Linear,
MaxPool2d,
Module,
......@@ -698,3 +700,67 @@ def test_module_compatible():
assert (
old_attributes == current_attributes
), "Add or delete attributes in Module class may break compatibility of pickle serialization"
def test_grou_norm():
class OriginGroupNormFunc(Module):
def __init__(self, num_groups, num_channels, eps=1e-5, affine=True, **kwargs):
super().__init__(**kwargs)
assert num_channels % num_groups == 0
self.num_groups = num_groups
self.num_channels = num_channels
self.eps = eps
self.affine = affine
if self.affine:
self.weight = Parameter(np.ones(num_channels, dtype=np.float32))
self.bias = Parameter(np.zeros(num_channels, dtype=np.float32))
else:
self.weight = None
self.bias = None
def forward(self, x):
N, C, H, W = x.shape
x = x.reshape(N, self.num_groups, -1)
mean = x.mean(axis=2, keepdims=True)
var = (x * x).mean(axis=2, keepdims=True) - mean * mean
x = (x - mean) / F.sqrt(var + self.eps)
x = x.reshape(N, C, H, W)
if self.affine:
x = self.weight.reshape(1, -1, 1, 1) * x + self.bias.reshape(
1, -1, 1, 1
)
return x
inp = np.random.randn(2, 256, 10, 16).astype("float32")
mge_inp = Tensor(inp)
mge_m = GroupNorm(32, 256)
ori_inp = Tensor(inp)
ori_m = OriginGroupNormFunc(32, 256)
targets = np.array(2)
mge_gm = mge.autodiff.GradManager().attach(mge_m.parameters())
ori_gm = mge.autodiff.GradManager().attach(ori_m.parameters())
for i in range(2):
with mge_gm:
mge_output = mge_m(mge_inp)
loss = F.loss.square_loss(
mge_output.sum(), mge.tensor(targets, dtype=np.float32)
)
mge_gm.backward(loss)
with ori_gm:
ori_output = ori_m(ori_inp)
loss = F.loss.square_loss(
ori_output.sum(), mge.tensor(targets, dtype=np.float32)
)
ori_gm.backward(loss)
np.testing.assert_allclose(mge_output.numpy(), ori_output.numpy(), atol=1e-05)
np.testing.assert_allclose(
mge_m.weight.grad.numpy(), ori_m.weight.grad.numpy(), rtol=1e-03
)
np.testing.assert_allclose(
mge_m.bias.grad.numpy(), ori_m.bias.grad.numpy(), rtol=1e-03
)
#include "megbrain/opr/dnn/group_norm.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace mgb::imperative {
namespace group_norm {
cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
auto&& op = static_cast<const GroupNorm&>(def);
size_t nr_inp = inputs.size();
auto p = op.param();
mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
OperatorNodeConfig config{op.make_name()};
if (nr_inp == 3) {
return opr::GroupNorm::make(
inputs[0], inputs[1], inputs[2], op.param(), config)[0]
.node()
->owner_opr();
} else {
return opr::GroupNorm::make(inputs[0], op.param(), config)[0]
.node()
->owner_opr();
}
}
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
auto&& group_norm = def.cast_final_safe<GroupNorm>();
size_t nr_inp = inputs.size();
auto affine = group_norm.affine;
mgb_assert(
(nr_inp == 3 && affine) || (nr_inp == 1 && !affine),
"num of inputs of pooling should be 1 or 3 but you give %zu",
inputs.size());
auto&& inp = inputs[0];
auto& inp_cn = inp.comp_node;
if (inp.layout.ndim == 0) {
return {{{TensorLayout{inp.layout.dtype}, inp_cn, {}},
{TensorLayout{dtype::Float32()}, inp_cn, {}},
{TensorLayout{dtype::Float32()}, inp_cn, {}}},
false};
}
DnnOprHelper<megdnn::GroupNorm> dnn_opr(group_norm.param());
auto&& [oup_layout, mean_layout, rstd_layout] =
dnn_opr.deduce_layouts<3>(inp.layout, TensorLayout{}, TensorLayout{});
return {{{oup_layout, inp_cn, {}},
{mean_layout, inp_cn, {}},
{rstd_layout, inp_cn, {}}},
true};
}
SmallVector<TensorPtr> apply_on_physical_tensor(
const OpDef& def, const SmallVector<TensorPtr>& inputs,
SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
auto&& op_def = def.cast_final_safe<GroupNorm>();
size_t nr_inp = inputs.size();
auto p = op_def.param();
mgb_assert(
(nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine),
"num of inputs of groupnorm should be 1 or 3 but you give %zu",
inputs.size());
auto cn = inputs[0]->comp_node();
DnnOprCaller<megdnn::GroupNorm> caller(cn, op_def.param());
auto&& [oup_layout, mean_layout, rstd_layout] = caller.deduce_layouts<3>(
inputs[0]->layout(), TensorLayout{}, TensorLayout{});
auto out = Tensor::make(oup_layout, cn);
auto mean = Tensor::make(mean_layout, cn);
auto rstd = Tensor::make(rstd_layout, cn);
if (p.affine) {
caller.exec_with_ws(inputs[0], inputs[1], inputs[2], out, mean, rstd);
} else {
megdnn::TensorND empty_dnn;
caller.exec_with_ws(inputs[0], empty_dnn, empty_dnn, out, mean, rstd);
}
return {out, mean, rstd};
}
OP_TRAIT_REG(GroupNorm, GroupNorm)
.apply_on_var_node(apply_on_var_node)
.infer_output_attrs_fallible(infer_output_attrs_fallible)
.apply_on_physical_tensor(apply_on_physical_tensor)
.fallback();
} // namespace group_norm
} // namespace mgb::imperative
905bdf78e5413b06873be64b4ba55db9 ../../dnn/scripts/opr_param_defs.py
da03ffe2a15411f902cd88920d3d47ec ../../src/core/include/megbrain/ir/ops.td
5756619f37e4dc130e1b049d7706d4eb generated/opdef.h.inl
98d1291eed73970ee087f898b6241358 generated/opdef.cpp.inl
b1a9c7569392942294c2168d40939eb5 generated/opdef.py.inl
3d88d5358d15a39219957f5257e32f5b generated/opdef.cpy.inl
e38b68be4e2aaf3de2f22e3dddbeaac4 ../../dnn/scripts/opr_param_defs.py
cf864561de125ab559c0035158656682 ../../src/core/include/megbrain/ir/ops.td
9248d42a9b3e770693306992156f6015 generated/opdef.h.inl
5c7e7ac49d1338d70ac84ba309e6732b generated/opdef.cpp.inl
30b669eec36876a65717e0c68dd76c83 generated/opdef.py.inl
d10455217f5f01e3d2668e5689068920 generated/opdef.cpy.inl
71e1462bf4d882e2615c3c632cb671cc generated/enum_macro.h
......@@ -3775,6 +3775,110 @@ OP_TRAIT_REG(GroupLocal, GroupLocal)
.props(GroupLocal_props_impl)
.make_name(GroupLocal_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNorm);
namespace {
size_t GroupNorm_hash_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<GroupNorm>();
static_cast<void>(op_);
size_t val = mgb::hash(op_.dyn_typeinfo());
val = mgb::hash_pair_combine(val, mgb::hash(op_.affine));
val = mgb::hash_pair_combine(val, mgb::hash(op_.eps));
val = mgb::hash_pair_combine(val, mgb::hash(op_.group));
val = mgb::hash_pair_combine(val, mgb::enumhash()(op_.format));
return val;
}
bool GroupNorm_is_same_st_impl(const OpDef& lhs_, const OpDef& rhs_) {
auto &&a_ = lhs_.cast_final_safe<GroupNorm>(),
&&b_ = rhs_.cast_final_safe<GroupNorm>();
static_cast<void>(a_);
static_cast<void>(b_);
if (a_.affine != b_.affine) return false;
if (a_.eps != b_.eps) return false;
if (a_.group != b_.group) return false;
if (a_.format != b_.format) return false;
return true;
}
std::vector<std::pair<const char*, std::string>> GroupNorm_props_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<GroupNorm>();
static_cast<void>(op_);
std::vector<std::pair<const char*, std::string>> props_;
props_.emplace_back("affine", std::to_string(op_.affine));
props_.emplace_back("eps", std::to_string(op_.eps));
props_.emplace_back("group", std::to_string(op_.group));
switch (op_.format){
case GroupNorm::Format::NCHW:
props_.emplace_back("format", "NCHW");
break;
case GroupNorm::Format::NHWC:
props_.emplace_back("format", "NHWC");
break;
case GroupNorm::Format::NHWCD4:
props_.emplace_back("format", "NHWCD4");
break;
case GroupNorm::Format::NCHW4:
props_.emplace_back("format", "NCHW4");
break;
case GroupNorm::Format::NCHW8:
props_.emplace_back("format", "NCHW8");
break;
case GroupNorm::Format::NCHW32:
props_.emplace_back("format", "NCHW32");
break;
case GroupNorm::Format::NCHW88:
props_.emplace_back("format", "NCHW88");
break;
case GroupNorm::Format::NCHW44:
props_.emplace_back("format", "NCHW44");
break;
case GroupNorm::Format::NCHW44_DOT:
props_.emplace_back("format", "NCHW44_DOT");
break;
case GroupNorm::Format::NCHW4_NCHW32:
props_.emplace_back("format", "NCHW4_NCHW32");
break;
case GroupNorm::Format::NCHW32_NCHW4:
props_.emplace_back("format", "NCHW32_NCHW4");
break;
case GroupNorm::Format::NCHW4_NCHW:
props_.emplace_back("format", "NCHW4_NCHW");
break;
case GroupNorm::Format::NHWC_NCHW:
props_.emplace_back("format", "NHWC_NCHW");
break;
case GroupNorm::Format::NHWC_NCHW4_IC_SMALL:
props_.emplace_back("format", "NHWC_NCHW4_IC_SMALL");
break;
case GroupNorm::Format::NCHW_NCHW4_IC_SMALL:
props_.emplace_back("format", "NCHW_NCHW4_IC_SMALL");
break;
case GroupNorm::Format::CHWN4:
props_.emplace_back("format", "CHWN4");
break;
case GroupNorm::Format::NCHW64:
props_.emplace_back("format", "NCHW64");
break;
case GroupNorm::Format::NCHW4_NHWC:
props_.emplace_back("format", "NCHW4_NHWC");
break;
default:
props_.emplace_back("format", "INVALID");
break;
}
return props_;
}
std::string GroupNorm_make_name_impl(const OpDef& def_) {
auto&& op_ = def_.cast_final_safe<GroupNorm>();
static_cast<void>(op_);
return "GroupNorm";
}
} // anonymous namespace
OP_TRAIT_REG(GroupNorm, GroupNorm)
.hash(GroupNorm_hash_impl)
.is_same_st(GroupNorm_is_same_st_impl)
.props(GroupNorm_props_impl)
.make_name(GroupNorm_make_name_impl);
MGB_DYN_TYPE_OBJ_FINAL_IMPL(Identity);
namespace {
......
......@@ -10075,6 +10075,158 @@ void _init_py_GroupLocal(py::module m) {
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupLocal::typeinfo(), &py_type).second);
}
void _init_py_GroupNorm_Format(PyTypeObject& py_type) {
auto& e_type = EnumWrapper<GroupNorm::Format>::type;
Py_INCREF(e_type);
mgb_assert(PyDict_SetItemString(
py_type.tp_dict, "Format", reinterpret_cast<PyObject*>(e_type)) >= 0);
}
PyOpDefBegin(GroupNorm) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
static PyObject* getstate(PyObject* self, PyObject*) {
auto& opdef = reinterpret_cast<PyOp(GroupNorm)*>(self)->inst();
static_cast<void>(opdef);
std::unordered_map<std::string, py::object> state {
{"affine", serialization<decltype(opdef.affine)>::dump(opdef.affine)},
{"eps", serialization<decltype(opdef.eps)>::dump(opdef.eps)},
{"group", serialization<decltype(opdef.group)>::dump(opdef.group)},
{"format", serialization<decltype(opdef.format)>::dump(opdef.format)}
};
return py::cast(state).release().ptr();
}
static PyObject* setstate(PyObject* self, PyObject* args) {
PyObject* dict = PyTuple_GetItem(args, 0);
if (!dict) return NULL;
auto state = py::cast<std::unordered_map<std::string, py::object>>(dict);
auto& opdef = reinterpret_cast<PyOp(GroupNorm)*>(self)->inst();
static_cast<void>(opdef);
{
auto&& iter = state.find("affine");
if (iter != state.end()) {
opdef.affine = serialization<decltype(opdef.affine)>::load(iter->second);
}
}
{
auto&& iter = state.find("eps");
if (iter != state.end()) {
opdef.eps = serialization<decltype(opdef.eps)>::load(iter->second);
}
}
{
auto&& iter = state.find("group");
if (iter != state.end()) {
opdef.group = serialization<decltype(opdef.group)>::load(iter->second);
}
}
{
auto&& iter = state.find("format");
if (iter != state.end()) {
opdef.format = serialization<decltype(opdef.format)>::load(iter->second);
}
}
Py_RETURN_NONE;
}
static int py_init(PyObject *self, PyObject *args, PyObject *kwds);
// };
PyOpDefEnd(GroupNorm)
int PyOp(GroupNorm)::py_init(PyObject *self, PyObject *args, PyObject *kwds) {
static const char* kwlist[] = {"affine", "eps", "group", "format", "scope", NULL};
PyObject *affine = NULL, *eps = NULL, *group = NULL, *format = NULL, *scope = NULL;
if (!PyArg_ParseTupleAndKeywords(args, kwds, "|OOOOO", const_cast<char**>(kwlist), &affine, &eps, &group, &format, &scope))
return -1;
if (affine) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().affine =
py::cast<decltype(GroupNorm::affine)>(py::handle(affine));
} CATCH_ALL(-1)
}
if (eps) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().eps =
py::cast<decltype(GroupNorm::eps)>(py::handle(eps));
} CATCH_ALL(-1)
}
if (group) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().group =
py::cast<decltype(GroupNorm::group)>(py::handle(group));
} CATCH_ALL(-1)
}
if (format) {
try {
// TODO: remove this guard which is used for pybind11 implicit conversion
py::detail::loader_life_support guard{};
reinterpret_cast<PyOp(GroupNorm)*>(self)->inst().format =
py::cast<decltype(GroupNorm::format)>(py::handle(format));
} CATCH_ALL(-1)
}
if (scope) {
try {
reinterpret_cast<PyOp(OpDef)*>(self)->op
->set_scope(py::cast<std::string>(py::handle(scope)));
} CATCH_ALL(-1)
}
return 0;
}
PyGetSetDef PyOp(GroupNorm)::py_getsetters[] = {
{const_cast<char*>("affine"), py_get_generic(GroupNorm, affine), py_set_generic(GroupNorm, affine), const_cast<char*>("affine"), NULL},
{const_cast<char*>("eps"), py_get_generic(GroupNorm, eps), py_set_generic(GroupNorm, eps), const_cast<char*>("eps"), NULL},
{const_cast<char*>("group"), py_get_generic(GroupNorm, group), py_set_generic(GroupNorm, group), const_cast<char*>("group"), NULL},
{const_cast<char*>("format"), py_get_generic(GroupNorm, format), py_set_generic(GroupNorm, format), const_cast<char*>("format"), NULL},
{NULL} /* Sentinel */
};
PyMethodDef PyOp(GroupNorm)::tp_methods[] = {
{const_cast<char*>("__getstate__"), PyOp(GroupNorm)::getstate, METH_NOARGS, "GroupNorm getstate"},
{const_cast<char*>("__setstate__"), PyOp(GroupNorm)::setstate, METH_VARARGS, "GroupNorm setstate"},
{NULL} /* Sentinel */
};
void _init_py_GroupNorm(py::module m) {
using py_op = PyOp(GroupNorm);
auto& py_type = PyOpType(GroupNorm);
py_type = {PyVarObject_HEAD_INIT(NULL, 0)};
py_type.tp_name = "megengine.core._imperative_rt.ops.GroupNorm";
py_type.tp_basicsize = sizeof(PyOp(GroupNorm));
py_type.tp_flags = Py_TPFLAGS_DEFAULT | Py_TPFLAGS_BASETYPE;
py_type.tp_doc = "GroupNorm";
py_type.tp_base = &PyOpType(OpDef);
py_type.tp_dealloc = py_dealloc_generic<py_op>;
py_type.tp_new = py_new_generic<py_op>;
py_type.tp_init = py_op::py_init;
py_type.tp_methods = py_op::tp_methods;
py_type.tp_getset = py_op::py_getsetters;
mgb_assert(PyType_Ready(&py_type) >= 0);
_init_py_GroupNorm_Format(py_type);
PyType_Modified(&py_type);
m.add_object("GroupNorm", reinterpret_cast<PyObject*>(&py_type));
mgb_assert(PyOp(OpDef)::ctype2pytype.emplace(GroupNorm::typeinfo(), &py_type).second);
}
PyOpDefBegin(Identity) // {
static PyGetSetDef py_getsetters[];
static PyMethodDef tp_methods[];
......@@ -19237,6 +19389,7 @@ void _init_py_WarpPerspectiveBackwardMat(py::module m) {
_init_py_GaussianRNG(m); \
_init_py_GetVarShape(m); \
_init_py_GroupLocal(m); \
_init_py_GroupNorm(m); \
_init_py_Identity(m); \
_init_py_Images2Neibs(m); \
_init_py_IncrMeshIndexing(m); \
......
......@@ -988,6 +988,23 @@ public:
}
};
class GroupNorm : public OpDefImplBase<GroupNorm> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
public:
using Format = ::megdnn::param::GroupNorm::Format;
bool affine = true;
float eps = 1e-5f;
uint32_t group = 1;
Format format = ::megdnn::param::GroupNorm::Format::NCHW;
GroupNorm() = default;
GroupNorm(bool affine_, float eps_, uint32_t group_, Format format_, std::string scope_ = {}): affine(affine_), eps(eps_), group(group_), format(format_) { set_scope(scope_); }
GroupNorm(::megdnn::param::GroupNorm packed_param_0): affine(packed_param_0.affine), eps(packed_param_0.eps), group(packed_param_0.group), format(packed_param_0.format) {}
::megdnn::param::GroupNorm param() const {
return {affine, eps, group, format};
}
};
class Identity : public OpDefImplBase<Identity> {
MGB_DYN_TYPE_OBJ_FINAL_DECL;
......
......@@ -1193,6 +1193,17 @@ GroupLocalInst
.def_readwrite("format", &GroupLocal::format)
.def_readwrite("compute_mode", &GroupLocal::compute_mode);
py::class_<GroupNorm, std::shared_ptr<GroupNorm>, OpDef> GroupNormInst(m, "GroupNorm");
GroupNormInst.attr("Format") = AdaptivePoolingInst.attr("Format");
GroupNormInst
.def(py::init<bool, float, uint32_t, ::megdnn::param::GroupNorm::Format, std::string>(), py::arg("affine") = true, py::arg("eps") = 1e-5f, py::arg("group") = 1, py::arg("format") = ::megdnn::param::GroupNorm::Format::NCHW, py::arg("scope") = {})
.def_readwrite("affine", &GroupNorm::affine)
.def_readwrite("eps", &GroupNorm::eps)
.def_readwrite("group", &GroupNorm::group)
.def_readwrite("format", &GroupNorm::format);
py::class_<Identity, std::shared_ptr<Identity>, OpDef> IdentityInst(m, "Identity");
IdentityInst
......
......@@ -490,6 +490,8 @@ def LRN: MgbHashableOp<"LRN", [LRNParam]>;
def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>;
def GroupNorm: MgbHashableOp<"GroupNorm", [GroupNormParam]>;
def LAMBUpdate: MgbHashableOp<"LAMBUpdate", [LAMBUpdateParam]>;
def RNNCell: MgbHashableOp<"RNNCell", [RNNCellParam]>;
......
#include "megbrain/opr/basic_arith.h"
#include "megbrain/opr/dnn/adaptive_pooling.h"
#include "megbrain/opr/dnn/batch_norm.h"
#include "megbrain/opr/dnn/convolution.h"
#include "megbrain/opr/dnn/correlation.h"
#include "megbrain/opr/dnn/fake_quant.h"
#include "megbrain/opr/dnn/group_norm.h"
#include "megbrain/opr/dnn/images2neibs.h"
#include "megbrain/opr/dnn/layer_norm.h"
#include "megbrain/opr/dnn/local.h"
......@@ -15,6 +17,9 @@
#include "megbrain/opr/dnn/sliding_window_transpose.h"
#include "megbrain/opr/dnn/softmax.h"
#include "megbrain/opr/dnn/tqt.h"
#include "megbrain/opr/io.h"
#include "megbrain/opr/tensor_manip.h"
#include "megbrain/serialization/oss_opr_load_dump.h"
#include "megbrain/serialization/sereg.h"
#include "megdnn/opr_param_defs.h"
#include "megdnn/oprs/nn.h"
......@@ -524,6 +529,213 @@ struct OprMaker<opr::LayerNormBackward, 0> {
}
};
template <>
struct OprMaker<opr::GroupNorm, 0> {
using Param = opr::GroupNorm::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 3) {
return opr::GroupNorm::make(i[0], i[1], i[2], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 1);
return opr::GroupNorm::make(i[0], param, config)[0].node()->owner_opr();
}
}
};
template <>
struct OprLoadDumpImplV2<opr::GroupNorm, 0> {
using Opr = opr::GroupNorm;
using Param = opr::GroupNorm::Param;
using ElemwiseParam = opr::Elemwise::Param;
using ReduceParam = opr::Reduce::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
auto graph = inputs[0]->owner_graph();
auto comp_node = inputs[0]->comp_node();
// std::unique_ptr<StaticInferManager> m_static_infer_manager;
auto opr_param = opr->cast_final_safe<opr::GroupNorm>().param();
float eps = opr_param.eps;
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5));
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps));
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node});
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node});
auto origin_shape = opr::GetVarShape::make(inputs[0]).node();
TensorShape input_shape =
inputs[0]->owner_graph()->static_infer_manager().infer_shape(inputs[0]);
size_t N = input_shape[0];
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3];
int group = opr_param.group;
int size = inner_size / group;
HostTensorND hv = HostTensorND(inputs[0]->comp_node(), {3}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = N;
ptr[1] = group;
ptr[2] = size;
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node});
auto inp = opr::Reshape::make(inputs[0], target_shape);
auto mean = opr::Reduce::make(inp, {ReduceParam::Mode::MEAN, 2});
auto elemwise1 = opr::Elemwise::make({inp, inp}, {ElemwiseParam::Mode::MUL});
auto temp_var = opr::Reduce::make(elemwise1, {ReduceParam::Mode::MEAN, 2});
auto elemwise2 = opr::Elemwise::make({mean, mean}, {ElemwiseParam::Mode::MUL});
auto var =
opr::Elemwise::make({temp_var, elemwise2}, {ElemwiseParam::Mode::SUB});
auto add_var = opr::Elemwise::make({var, eps_node}, {ElemwiseParam::Mode::ADD});
auto sqrt =
opr::Elemwise::make({add_var, half_node}, {ElemwiseParam::Mode::POW});
auto div = opr::Elemwise::make({inp, mean}, {ElemwiseParam::Mode::SUB});
auto temp_inp =
opr::Elemwise::make({div, sqrt}, {ElemwiseParam::Mode::TRUE_DIV});
auto res = opr::Reshape::make(temp_inp, origin_shape);
if (inputs.size() == 3) {
auto mul_temp =
opr::Elemwise::make({res, inputs[1]}, {ElemwiseParam::Mode::MUL});
auto res = opr::Elemwise::make(
{mul_temp, inputs[2]}, {ElemwiseParam::Mode::ADD});
return res.node()->owner_opr();
} else {
return res.node()->owner_opr();
}
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
// auto& fbs_ctx = CAST_TO_FBS_V2_CTX(ctx);
return OprMaker<opr::GroupNorm, 0>::make(
ctx.read_param<Param>(), inputs, ctx.graph(), config);
}
};
// OprMaker in MGB_SEREG_OPR only support unique output opr
template <>
struct OprMaker<opr::GroupNormBackward, 0> {
using Param = opr::GroupNormBackward::Param;
static cg::OperatorNodeBase* make(
const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph,
const OperatorNodeConfig& config) {
MGB_MARK_USED_VAR(graph);
if (i.size() == 5) {
return opr::GroupNormBackward::make(
i[0], i[1], i[2], i[3], i[4], param, config)[0]
.node()
->owner_opr();
} else {
mgb_assert(i.size() == 4);
return opr::GroupNormBackward::make(
i[0], i[1], i[2], i[3], param, config)[0]
.node()
->owner_opr();
}
}
};
template <>
struct OprLoadDumpImplV2<opr::GroupNormBackward, 0> {
using Opr = opr::GroupNormBackward;
using Param = opr::GroupNormBackward::Param;
using ElemwiseParam = opr::Elemwise::Param;
using ReduceParam = opr::Reduce::Param;
static void dump(OprDumpContext& ctx, const cg::OperatorNodeBase& opr) {
ctx.write_param<Param>(opr.cast_final_safe<Opr>().param());
}
static cg::OperatorNodeBase* replace_opr(
cg::OperatorNodeBase* opr, const VarNodeArray& inputs) {
auto rstd = inputs[4];
auto graph = inputs[1]->owner_graph();
auto comp_node = inputs[1]->comp_node();
auto opr_param = opr->cast_final_safe<opr::GroupNormBackward>().param();
float eps = opr_param.eps;
auto half = DTypeScalar(static_cast<megdnn::dt_float32>(0.5));
auto param_eps = DTypeScalar(static_cast<megdnn::dt_float32>(eps));
auto half_node = opr::ImmutableTensor::make(*graph, half, {comp_node});
auto eps_node = opr::ImmutableTensor::make(*graph, param_eps, {comp_node});
auto const_node =
opr::ImmutableTensor::make(*graph, DTypeScalar(1), {comp_node});
TensorShape input_shape =
inputs[1]->owner_graph()->static_infer_manager().infer_shape(inputs[0]);
auto origin_shape = opr::GetVarShape::make(inputs[1]).node();
size_t N = input_shape[0];
size_t C = input_shape[1];
size_t inner_size = input_shape[1] * input_shape[2] * input_shape[3];
int group = opr_param.group;
int size = inner_size / group;
HostTensorND hv = HostTensorND(inputs[1]->comp_node(), {3}, dtype::Int32());
auto* ptr = hv.ptr<dt_int32>();
ptr[0] = N;
ptr[1] = group;
ptr[2] = size;
auto target_shape = opr::ImmutableTensor::make(*graph, hv, {comp_node});
auto inp = opr::Reshape::make(inputs[1], target_shape);
auto temp_rstd =
opr::Elemwise::make({rstd, eps_node}, {ElemwiseParam::Mode::ADD});
auto sqrt =
opr::Elemwise::make({temp_rstd, half_node}, {ElemwiseParam::Mode::POW});
auto slice_std = opr::Elemwise::make(
{const_node, sqrt}, {ElemwiseParam::Mode::TRUE_DIV});
auto sub_mean =
opr::Elemwise::make({inp, inputs[3]}, {ElemwiseParam::Mode::SUB});
auto x_hat =
opr::Elemwise::make({sub_mean, slice_std}, {ElemwiseParam::Mode::MUL});
x_hat = opr::Reshape::make(x_hat, origin_shape);
auto size_node =
opr::ImmutableTensor::make(*graph, DTypeScalar(size), {comp_node});
auto temp1 = opr::Elemwise::make(
{slice_std, size_node}, {ElemwiseParam::Mode::TRUE_DIV});
auto dx_hat =
opr::Elemwise::make({inputs[0], inputs[2]}, {ElemwiseParam::Mode::MUL});
HostTensorND tshape = HostTensorND(inputs[1]->comp_node(), {5}, dtype::Int32());
auto* ptr2 = tshape.ptr<dt_int32>();
ptr2[0] = N;
ptr2[1] = group;
ptr2[2] = C / group;
ptr2[3] = input_shape[2];
ptr2[4] = input_shape[3];
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node});
x_hat = opr::Reshape::make(x_hat, target_shape);
dx_hat = opr::Reshape::make(dx_hat, target_shape);
auto temp2 =
opr::Elemwise::make({size_node, dx_hat}, {ElemwiseParam::Mode::MUL});
ptr2[2] = 1;
ptr2[3] = 1;
ptr2[4] = 1;
target_shape = opr::ImmutableTensor::make(*graph, tshape, {comp_node});
auto temp3 = opr::Reduce::make(dx_hat, {ReduceParam::Mode::SUM}, target_shape);
auto sum_dx_hat =
opr::Reduce::make(temp2, {ReduceParam::Mode::SUM}, target_shape);
auto temp4 =
opr::Elemwise::make({x_hat, sum_dx_hat}, {ElemwiseParam::Mode::MUL});
auto temp5 = opr::Elemwise::make({temp2, temp3}, {ElemwiseParam::Mode::SUB});
auto temp6 = opr::Elemwise::make({temp5, temp4}, {ElemwiseParam::Mode::SUB});
auto dx_temp = opr::Elemwise::make({temp1, temp6}, {ElemwiseParam::Mode::MUL});
auto dx = opr::Reshape::make(dx_temp, origin_shape);
return dx.node()->owner_opr();
}
static cg::OperatorNodeBase* load(
OprLoadContext& ctx, const cg::VarNodeArray& inputs,
const OperatorNodeConfig& config) {
return OprMaker<opr::GroupNormBackward, 0>::make(
ctx.read_param<Param>(), inputs, ctx.graph(), config);
}
};
template <class MegDNNConv = megdnn::LocalShare>
struct MakeLocalShareCaller2 {
template <typename Opr>
......@@ -747,6 +959,8 @@ MGB_SEREG_OPR(LSQ, 4);
MGB_SEREG_OPR(LSQBackward, 5);
MGB_SEREG_OPR(LayerNorm, 0);
MGB_SEREG_OPR(LayerNormBackward, 0);
MGB_SEREG_OPR(GroupNorm, 0);
MGB_SEREG_OPR(GroupNormBackward, 0);
MGB_SEREG_OPR(RNNCellForward, 6);
MGB_SEREG_OPR(LSTMCellForward, 7);
MGB_SEREG_OPR(RNNForward, 3);
......@@ -755,6 +969,14 @@ MGB_SEREG_OPR(LSTMForward, 4);
MGB_SEREG_OPR(LSTMBackward, 9);
MGB_SEREG_OPR(Softmax, 1);
MGB_SEREG_OPR(SoftmaxBackward, 2);
MGB_SEREG_OPR_V2(
GroupNorm, 0,
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNorm, 0>::replace_opr),
VERSION_2, CURRENT_VERSION);
MGB_SEREG_OPR_V2(
GroupNormBackward, 0,
(mgb::serialization::OprLoadDumpImplV2<opr::GroupNormBackward, 0>::replace_opr),
VERSION_2, CURRENT_VERSION);
} // namespace opr
} // namespace mgb
......
#include "megbrain/opr/dnn/group_norm.h"
#include "megbrain/graph/grad_impl.h"
#include "megbrain/opr/internal/out_shape_by_sym_var.h"
#include "megbrain/opr/utility.h"
#include "../internal/megdnn_opr_wrapper.inl"
using namespace mgb;
using namespace opr;
/* ==================== GroupNormForward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormForward);
GroupNormForward::GroupNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "group_norm", {data, weight, bias}} {
init_megdnn_opr(*this, param);
add_input({data, weight, bias});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}
GroupNormForward::GroupNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config)
: Super{data->owner_graph(), config, "group_norm", {data}} {
init_megdnn_opr(*this, param);
add_input({data});
output(0)->dtype(data->dtype());
output(1)->dtype(dtype::Float32());
output(2)->dtype(dtype::Float32());
}
SymbolVarArray GroupNormForward::make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param,
const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<GroupNormForward>(
data.node(), weight.node(), bias.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
SymbolVarArray GroupNormForward::make(
SymbolVar data, const Param& param, const OperatorNodeConfig& config) {
auto outs = data.node()
->owner_graph()
->insert_opr(std::make_unique<GroupNormForward>(
data.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
void GroupNormForward::get_output_var_shape(
const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
size_t group = param().group;
out_shape[0] = inp_shape[0];
size_t N = inp_shape[0].shape[0];
TensorShape unnormalized_shape{N, group};
out_shape[1] = unnormalized_shape;
out_shape[2] = unnormalized_shape;
}
size_t GroupNormForward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return intl::MegDNNOprMethInvoker<megdnn::GroupNormForward>::get_workspace_in_bytes(
megdnn_opr(), this, input_shapes, output_shapes);
}
void GroupNormForward::scn_do_execute() {
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), {}, {},
output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output().back()));
}
}
#if MGB_ENABLE_GRAD
MGB_IMPL_OPR_GRAD(GroupNormForward) {
auto p = opr.param();
SymbolVarArray grad;
VarNodeArray ret;
if (p.affine) {
mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx);
grad = GroupNormBackward::make(
out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2),
opr.param());
} else {
mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx);
grad = GroupNormBackward::make(
out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param());
}
uint32_t nr_ret = p.affine ? 3 : 1;
for (uint32_t i = 0; i < nr_ret; ++i) {
ret.push_back(grad[i].node());
}
return ret;
}
#endif
/* ==================== GroupNormBackward ==================== */
MGB_DYN_TYPE_OBJ_FINAL_IMPL(GroupNormBackward);
GroupNormBackward::GroupNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"group_norm_backward",
{diff, data, weight, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, weight, mean, rstd});
}
GroupNormBackward::GroupNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param,
const OperatorNodeConfig& config)
: Super({diff->owner_graph(),
config,
"group_norm_backward",
{diff, data, mean, rstd}},
0, true) {
init_megdnn_opr(*this, param);
add_input({diff, data, mean, rstd});
auto mark_empty_var = [&](VarNode* var) {
var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
.add_flag(VarNode::Flag::VOLATILE_CONTENT);
};
mark_empty_var(output(1));
mark_empty_var(output(2));
}
SymbolVarArray GroupNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<GroupNormBackward>(
diff.node(), data.node(), weight.node(), mean.node(),
rstd.node(), param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
SymbolVarArray GroupNormBackward::make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param, const OperatorNodeConfig& config) {
auto outs = diff.node()
->owner_graph()
->insert_opr(std::make_unique<GroupNormBackward>(
diff.node(), data.node(), mean.node(), rstd.node(),
param, config))
->output();
SymbolVarArray ret;
for (auto&& out : outs) {
ret.emplace_back(out);
}
return ret;
}
void GroupNormBackward::init_output_static_infer_desc() {
using namespace cg::static_infer;
auto&& mgr = owner_graph()->static_infer_manager();
mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
if (param().affine) {
mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2)));
} else {
TensorShape empty;
empty.ndim = 0;
mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty));
mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty));
}
this->init_output_static_infer_desc_workspace(
intl::AutoAddWorkspaceNeedLimitGetter<megdnn::GroupNormBackward>::val);
}
void GroupNormBackward::init_output_dtype() {
output(0)->dtype(input(1)->dtype());
output(1)->dtype(input(2)->dtype());
output(2)->dtype(input(2)->dtype());
}
size_t GroupNormBackward::get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const {
return intl::MegDNNOprMethInvoker<megdnn::GroupNormBackward>::
get_workspace_in_bytes(megdnn_opr(), this, input_shapes, output_shapes);
}
void GroupNormBackward::scn_do_execute() {
if (param().affine) {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
output(1)->dev_tensor().as_megdnn(),
output(2)->dev_tensor().as_megdnn(),
intl::get_megdnn_workspace_from_var(output(3)));
} else {
megdnn_opr()->exec(
input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
{}, input(2)->dev_tensor().as_megdnn(),
input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
{}, {}, intl::get_megdnn_workspace_from_var(output(3)));
}
}
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#pragma once
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "megdnn/oprs.h"
namespace mgb {
namespace opr {
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
GroupNormForward, intl::MegDNNOprWrapperFwd<megdnn::GroupNormForward>) // {
public:
MGE_WIN_DECLSPEC_FUC GroupNormForward(
VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC GroupNormForward(
VarNode* data, const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar data, const Param& param = {},
const OperatorNodeConfig& config = {});
private:
void get_output_var_shape(
const TensorShapeArray& inp_shape,
TensorShapeArray& out_shape) const override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
using GroupNorm = GroupNormForward;
MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
GroupNormBackward, intl::MegDNNOprWrapperBwd<megdnn::GroupNormBackward>) // {
public:
MGE_WIN_DECLSPEC_FUC GroupNormBackward(
VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC GroupNormBackward(
VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd,
const Param& param, const OperatorNodeConfig& config);
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
SymbolVar rstd, const Param& param = {},
const OperatorNodeConfig& config = {});
MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
const Param& param = {}, const OperatorNodeConfig& config = {});
private:
void init_output_static_infer_desc() override;
void init_output_dtype() override;
size_t get_workspace_size_bytes(
const TensorShapeArray& input_shapes,
const TensorShapeArray& output_shapes) const override;
void scn_do_execute() override;
};
} // namespace opr
} // namespace mgb
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
#include "megbrain/opr/dnn/group_norm.h"
#include "megbrain/comp_node_env.h"
#include "megbrain/test/autocheck.h"
#include "megbrain/test/helper.h"
#include "megbrain/test/megdnn_helper.h"
#include "megdnn/oprs.h"
#include <cmath>
#include <iomanip>
#include <random>
#include <sstream>
using namespace mgb;
namespace {
using Param = opr::GroupNormForward::Param;
void run_forward(bool is_affine) {
using Checker = AutoOprChecker<3, 3>;
Param param;
param.eps = 1e-5;
param.affine = is_affine;
param.group = 3;
auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
auto out = opr::GroupNormForward::make(inputs[0], inputs[1], inputs[2], param);
return {out[0], out[1], out[2]};
};
auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
auto opr =
MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
->create_operator<megdnn::GroupNormForward>();
auto inp_shape = inp[0]->shape();
auto n_slices = inp_shape[0];
opr->param() = param;
dest[0].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize(inp_shape);
dest[1].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices, param.group});
dest[2].dtype(dtype::Float32())
.comp_node(inp[0]->comp_node())
.resize({n_slices, param.group});
std::vector<dt_byte> workspace(opr->get_workspace_in_bytes(
inp[0]->layout(), inp[1]->layout(), inp[2]->layout(), dest[0].layout(),
dest[1].layout(), dest[2].layout()));
opr->exec(
inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(),
{workspace.data(), workspace.size()});
};
auto gen = [&](HostTensorND& src) {
HostTensorGenerator<dtype::Float32, RandomDistribution::GAUSSIAN> src_gen(0.f);
src = *src_gen(src.shape(), src.comp_node());
};
Checker::RunOptions option;
option.numdiff_max_err = 1e-4;
Checker checker{make_graph, fwd};
checker.set_input_generator(0, gen);
checker.set_input_generator(1, gen);
checker.set_input_generator(2, gen);
checker.set_input_allow_grad(0, false);
checker.set_input_allow_grad(1, false);
checker.set_input_allow_grad(2, false);
checker.set_output_allow_grad(0, false);
checker.set_output_allow_grad(1, false);
checker.set_output_allow_grad(2, false);
checker.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option)
.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option)
.run({TensorShape{2, 6, 2, 1}, TensorShape{6}, TensorShape{6}}, option);
}
TEST(TestOprDNN, GroupNormForward) {
REQUIRE_GPU(1);
run_forward(true);
}
} // anonymous namespace
// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
......@@ -123,6 +123,7 @@ union OperatorParam {
param.LSTM = 89,
param.Softmax = 90,
param.Diag = 91,
param.GroupNorm = 92,
}
table Operator {
......
......@@ -140,6 +140,7 @@ union OperatorParam {
param.LSTM = 89,
param.Softmax = 90,
param.Diag = 91,
param.GroupNorm = 92,
}
table Operator {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册