From a93741815b875b10c5de4ad30e0dad8a02eed87e Mon Sep 17 00:00:00 2001 From: Megvii Engine Team Date: Wed, 8 Dec 2021 16:35:18 +0800 Subject: [PATCH] feat(mgb/opr): add layernorm forward and backward kernel GitOrigin-RevId: 0cd484e753a4fbfb88cf81ddbde6ad80b844e69d --- dnn/include/megdnn/oprs/nn.h | 69 ++ dnn/scripts/opr_param_defs.py | 7 + dnn/src/common/handle_impl.h | 5 +- dnn/src/common/layer_norm.cpp | 180 +++++ dnn/src/common/opr_trait.h | 2 + dnn/src/cuda/handle_create.cpp | 1 + dnn/src/cuda/layer_norm/layer_norm_cuda.cu | 664 ++++++++++++++++++ dnn/src/cuda/layer_norm/layer_norm_cuda.cuh | 34 + dnn/src/cuda/layer_norm/opr_impl.cpp | 94 +++ dnn/src/cuda/layer_norm/opr_impl.h | 53 ++ dnn/src/naive/handle.cpp | 1 + dnn/src/naive/layer_norm/opr_impl.cpp | 170 +++++ dnn/src/naive/layer_norm/opr_impl.h | 51 ++ dnn/test/common/deduce_layout_proxy.h | 9 + dnn/test/cuda/layer_norm.cpp | 94 +++ imperative/python/megengine/functional/nn.py | 87 +-- .../test/unit/functional/test_functional.py | 55 -- .../python/test/unit/functional/test_loss.py | 2 +- imperative/src/impl/ops/specializations.cpp | 26 + src/core/include/megbrain/ir/ops.td | 2 + src/opr/impl/dnn/dnn.sereg.h | 44 ++ src/opr/impl/dnn/layer_norm.cpp | 248 +++++++ src/opr/include/megbrain/opr/dnn/layer_norm.h | 78 ++ src/opr/test/dnn/layer_norm.cpp | 108 +++ src/serialization/impl/schema.fbs | 1 + 25 files changed, 1960 insertions(+), 125 deletions(-) create mode 100644 dnn/src/common/layer_norm.cpp create mode 100644 dnn/src/cuda/layer_norm/layer_norm_cuda.cu create mode 100644 dnn/src/cuda/layer_norm/layer_norm_cuda.cuh create mode 100644 dnn/src/cuda/layer_norm/opr_impl.cpp create mode 100644 dnn/src/cuda/layer_norm/opr_impl.h create mode 100644 dnn/src/naive/layer_norm/opr_impl.cpp create mode 100644 dnn/src/naive/layer_norm/opr_impl.h create mode 100644 dnn/test/cuda/layer_norm.cpp create mode 100644 src/opr/impl/dnn/layer_norm.cpp create mode 100644 src/opr/include/megbrain/opr/dnn/layer_norm.h create mode 100644 src/opr/test/dnn/layer_norm.cpp diff --git a/dnn/include/megdnn/oprs/nn.h b/dnn/include/megdnn/oprs/nn.h index e5ea399ca..cfac433fb 100644 --- a/dnn/include/megdnn/oprs/nn.h +++ b/dnn/include/megdnn/oprs/nn.h @@ -1936,6 +1936,75 @@ protected: const TensorLayout& grad_s, size_t workspace_in_bytes); }; +class LayerNormBase : public OperatorBase { + DEF_OPR_IMPL_CTOR(LayerNormBase, OperatorBase); + DEF_OPR_PARAM(LayerNorm); + +protected: + void deduce_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, + TensorLayout& rstd); + void check_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, + const TensorLayout& rstd); +}; + +class LayerNormForward : public LayerNormBase { + DEF_OPR_IMPL(LayerNormForward, LayerNormBase, 3, 3); + +public: + virtual void exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, TensorLayout& dst, TensorLayout& mean, + TensorLayout& rstd); + virtual size_t get_workspace_in_bytes( + const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, + const TensorLayout& rstd) = 0; + +protected: + void check_exec( + 
const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& bias, const TensorLayout& dst, const TensorLayout& mean, + const TensorLayout& rstd, size_t workspace_in_bytes); +}; +using LayerNorm = LayerNormForward; + +class LayerNormBackward : public LayerNormBase { + DEF_OPR_IMPL(LayerNormBackward, LayerNormBase, 5, 3); + +public: + virtual void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) = 0; + void deduce_layout( + const TensorLayout& diff, const TensorLayout& data, + const TensorLayout& weight, const TensorLayout& mean, + const TensorLayout& rstd, TensorLayout& ddata, TensorLayout& dweight, + TensorLayout& dbias); + virtual size_t get_workspace_in_bytes( + const TensorLayout& diff, const TensorLayout& data, + const TensorLayout& weight, const TensorLayout& mean, + const TensorLayout& rstd, const TensorLayout& ddata, + const TensorLayout& dweight, const TensorLayout& dbias) = 0; + +protected: + void check_exec( + const TensorLayout& diff, const TensorLayout& data, + const TensorLayout& weight, const TensorLayout& mean, + const TensorLayout& rstd, const TensorLayout& ddata, + const TensorLayout& dweight, const TensorLayout& dbias, + size_t workspace_in_bytes); +}; + } // namespace megdnn #include "megdnn/internal/opr_header_epilogue.h" diff --git a/dnn/scripts/opr_param_defs.py b/dnn/scripts/opr_param_defs.py index 76220c99e..08828e8e0 100755 --- a/dnn/scripts/opr_param_defs.py +++ b/dnn/scripts/opr_param_defs.py @@ -1212,3 +1212,10 @@ PADDING_MODES = [Doc('REPLICATE = 0', 'aaaaaa|abcdefgh|hhhhhhh'), member_alias=[(i, 'PADDING_{}'.format(i)) for i in PADDING_MODES] ) ) + +(pdef('LayerNorm') + .add_fields('bool', 'affine', 'true') + .add_fields('float32', 'eps', '1e-5f') + .add_fields('uint64', 'normalized_dim', '1') + .add_fields('uint64', 'normalized_size', '1') + ) diff --git a/dnn/src/common/handle_impl.h b/dnn/src/common/handle_impl.h index 7c3e01a1a..2b5206f96 100644 --- a/dnn/src/common/handle_impl.h +++ b/dnn/src/common/handle_impl.h @@ -209,7 +209,10 @@ private: cb(LSQBackward) \ cb(Fill) \ cb(PaddingForward) \ - cb(PaddingBackward) + cb(PaddingBackward) \ + cb(LayerNormForward) \ + cb(LayerNormBackward) + // clang-format on /*! diff --git a/dnn/src/common/layer_norm.cpp b/dnn/src/common/layer_norm.cpp new file mode 100644 index 000000000..44bb16e11 --- /dev/null +++ b/dnn/src/common/layer_norm.cpp @@ -0,0 +1,180 @@ +/** + * \file dnn/src/common/layer_norm.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "megdnn/oprs.h" + +#include "src/common/utils.h" + +namespace megdnn { + +void LayerNormBase::deduce_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { + MEGDNN_MARK_USED_VAR(weight); + MEGDNN_MARK_USED_VAR(bias); + auto p = param(); + TensorShape unnormalized_shape; + unnormalized_shape.ndim = data.ndim - p.normalized_dim; + for (size_t i = 0; i < unnormalized_shape.ndim; ++i) { + unnormalized_shape.shape[i] = data.shape[i]; + } + TensorLayout unnormalized_layout = + TensorLayout(unnormalized_shape, dtype::Float32()); + dst = data; + mean = unnormalized_layout; + rstd = unnormalized_layout; +} + +void LayerNormBase::check_layout_fwd( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd) { + megdnn_assert_contiguous(data); + megdnn_assert_contiguous(weight); + megdnn_assert_contiguous(bias); + megdnn_assert_contiguous(dst); + megdnn_assert_contiguous(mean); + megdnn_assert_contiguous(rstd); + auto errmsg = [&]() { + return megdnn_layout_msg(data) + ", " + megdnn_layout_msg(weight) + ", " + + megdnn_layout_msg(bias) + ", " + megdnn_layout_msg(dst) + ", " + + megdnn_layout_msg(mean) + ", " + megdnn_layout_msg(rstd); + }; + MEGDNN_MARK_USED_VAR(errmsg); + + auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { + if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && + lhs.format == rhs.format)) + return false; + for (size_t i = 0; i < lhs.ndim; ++i) { + if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { + return false; + } + } + return true; + }; + + megdnn_assert(equal_layout(data, dst), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(weight, bias), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str()); + + auto p = param(); + uint64_t normalized_dim = p.normalized_dim; + size_t unnormalized_dim = data.ndim - normalized_dim; + megdnn_assert( + normalized_dim < data.ndim, + "the dims of normalized shape should smaller than input dims"); + + for (size_t i = 0; i < unnormalized_dim; ++i) { + megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str()); + } + if (p.affine) { + for (size_t i = 0; i < normalized_dim; ++i) { + megdnn_assert( + data.shape[unnormalized_dim + i] == weight.shape[i], "%s", + errmsg().c_str()); + } + } +} + +void LayerNormForward::deduce_layout( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + TensorLayout& dst, TensorLayout& mean, TensorLayout& rstd) { + deduce_layout_fwd(data, weight, bias, dst, mean, rstd); +} + +void LayerNormForward::check_exec( + const TensorLayout& data, const TensorLayout& weight, const TensorLayout& bias, + const TensorLayout& dst, const TensorLayout& mean, const TensorLayout& rstd, + size_t workspace_in_bytes) { + check_layout_fwd(data, weight, bias, dst, mean, rstd); + auto required_workspace_in_bytes = + get_workspace_in_bytes(data, weight, bias, dst, mean, rstd); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); +} + +void LayerNormBackward::deduce_layout( + const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& mean, const TensorLayout& rstd, TensorLayout& ddata, + TensorLayout& dweight, TensorLayout& dbias) { + MEGDNN_MARK_USED_VAR(diff); + MEGDNN_MARK_USED_VAR(mean); + MEGDNN_MARK_USED_VAR(rstd); + ddata = data; + 
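+    // grad w.r.t. data keeps the input layout; the grads w.r.t. the affine
+    // parameters below reuse the weight layout (bias shares weight's shape),
+    // e.g. data of (N, C, H, W) with normalized_dim=2 gives dweight/dbias of (H, W)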
dweight = weight; + dbias = weight; +} + +void LayerNormBackward::check_exec( + const TensorLayout& diff, const TensorLayout& data, const TensorLayout& weight, + const TensorLayout& mean, const TensorLayout& rstd, const TensorLayout& ddata, + const TensorLayout& dweight, const TensorLayout& dbias, + size_t workspace_in_bytes) { + auto p = param(); + auto required_workspace_in_bytes = get_workspace_in_bytes( + diff, data, weight, mean, rstd, ddata, dweight, dbias); + megdnn_assert(workspace_in_bytes >= required_workspace_in_bytes); + + megdnn_assert_contiguous(diff); + megdnn_assert_contiguous(data); + megdnn_assert_contiguous(mean); + megdnn_assert_contiguous(rstd); + megdnn_assert_contiguous(ddata); + if (p.affine) { + megdnn_assert_contiguous(weight); + megdnn_assert_contiguous(dweight); + megdnn_assert_contiguous(dbias); + } + + auto errmsg = [&]() { + return megdnn_layout_msg(diff) + ", " + megdnn_layout_msg(data) + ", " + + megdnn_layout_msg(weight) + ", " + megdnn_layout_msg(mean) + ", " + + megdnn_layout_msg(rstd) + ", " + megdnn_layout_msg(ddata) + ", " + + megdnn_layout_msg(dweight) + ", " + megdnn_layout_msg(dbias); + }; + MEGDNN_MARK_USED_VAR(errmsg); + + auto equal_layout = [](const TensorLayout& lhs, const TensorLayout& rhs) -> bool { + if (!(lhs.ndim == rhs.ndim && lhs.dtype == rhs.dtype && + lhs.format == rhs.format)) + return false; + for (size_t i = 0; i < lhs.ndim; ++i) { + if (lhs.shape[i] != rhs.shape[i] || lhs.stride[i] != rhs.stride[i]) { + return false; + } + } + return true; + }; + + megdnn_assert(equal_layout(data, ddata), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(mean, rstd), "%s", errmsg().c_str()); + if (p.affine) { + megdnn_assert(equal_layout(weight, dweight), "%s", errmsg().c_str()); + megdnn_assert(equal_layout(weight, dbias), "%s", errmsg().c_str()); + } + + size_t normalized_dim = p.normalized_dim; + size_t unnormalized_dim = data.ndim - normalized_dim; + + for (size_t i = 0; i < unnormalized_dim; ++i) { + megdnn_assert(data.shape[i] == mean.shape[i], "%s", errmsg().c_str()); + } + if (p.affine) { + for (size_t i = 0; i < normalized_dim; ++i) { + megdnn_assert( + data.shape[unnormalized_dim + i] == weight.shape[i], "%s", + errmsg().c_str()); + } + } +} + +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/common/opr_trait.h b/dnn/src/common/opr_trait.h index 8999b736d..851b5d8ed 100644 --- a/dnn/src/common/opr_trait.h +++ b/dnn/src/common/opr_trait.h @@ -135,6 +135,8 @@ DEF(CheckNonFinite, 2, true, true); DEF(LSQForward, 5, true, true); DEF(LSQBackward, 7, true, false); DEF(Fill, 1, true, false); +DEF(LayerNormForward, 6, true, true); +DEF(LayerNormBackward, 8, true, true); } // namespace megdnn // vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/handle_create.cpp b/dnn/src/cuda/handle_create.cpp index 03858f5f6..6738740a5 100644 --- a/dnn/src/cuda/handle_create.cpp +++ b/dnn/src/cuda/handle_create.cpp @@ -45,6 +45,7 @@ #include "src/cuda/images2neibs/opr_impl.h" #include "src/cuda/indexing_multi_axis_vec/opr_impl.h" #include "src/cuda/indexing_one_hot/opr_impl.h" +#include "src/cuda/layer_norm/opr_impl.h" #include "src/cuda/linspace/opr_impl.h" #include "src/cuda/local/opr_impl.h" #include "src/cuda/local_share/opr_impl.h" diff --git a/dnn/src/cuda/layer_norm/layer_norm_cuda.cu b/dnn/src/cuda/layer_norm/layer_norm_cuda.cu new file mode 100644 index 000000000..2cca694a4 --- /dev/null +++ b/dnn/src/cuda/layer_norm/layer_norm_cuda.cu @@ -0,0 +1,664 @@ +/** + * \file dnn/src/cuda/layer_norm/layer_norm_cuda.cu + * MegEngine 
is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#include +#include +#include +#include "megdnn/arch.h" +#include "megdnn/dtype.h" +#include "src/cuda/cuda_shfl_compat.cuh" +#include "src/cuda/layer_norm/layer_norm_cuda.cuh" +#include "src/cuda/utils.cuh" + +namespace megdnn { +namespace cuda { +namespace layer_norm { + +constexpr int kCUDANumThreads = 256; +constexpr int vec_size = 4; + +// warp size may be used as array length, or used in host function, +// so we define WARP_SIZE rather than using warpSize +#define WARP_SIZE 32 + +#if defined(__clang__) +#define __ubsan_ignore_float_divide_by_zero__ \ + __attribute__((no_sanitize("float-divide-by-zero"))) +#else +#define __ubsan_ignore_float_divide_by_zero__ +#endif + +struct WelfordStat { + float mean; + float sigma2; + float count; + MEGDNN_HOST MEGDNN_DEVICE WelfordStat() : mean(0.f), sigma2(0.f), count(0.f) {} + MEGDNN_HOST MEGDNN_DEVICE WelfordStat(float mean, float sigma2, float count) + : mean(mean), sigma2(sigma2), count(count) {} +}; + +template +struct WelfordData { + T mean; + T sigma2; + combine_t count; + + MEGDNN_HOST MEGDNN_DEVICE WelfordData() : mean(0), sigma2(0), count(0) {} + + MEGDNN_HOST MEGDNN_DEVICE WelfordData(T mean, T sigma2, combine_t count) + : mean(mean), sigma2(sigma2), count(count) {} +}; + +template +struct WelfordOps { +public: + using WelfordData_T = WelfordData; + inline MEGDNN_DEVICE WelfordData_T reduce(WelfordData_T acc, T data) const { + T delta = data - acc.mean; + T new_mean = static_cast(acc.mean + delta / (acc.count + 1)); + T new_delta = static_cast(data - new_mean); + return { + new_mean, + acc.sigma2 + delta * new_delta, + combine_t(acc.count + 1), + }; + } + inline MEGDNN_DEVICE WelfordData_T + combine(WelfordData_T lhs, WelfordData_T rhs) const { + if (lhs.count != 0 && rhs.count != 0) { + T delta = rhs.mean - lhs.mean; + combine_t new_count = lhs.count + rhs.count; + T nb_over_n = rhs.count / new_count; + return {lhs.mean + delta * nb_over_n, + lhs.sigma2 + rhs.sigma2 + delta * delta * lhs.count * nb_over_n, + new_count}; + } else { + return (lhs.count != 0) ? 
lhs : rhs; + } + } + inline MEGDNN_DEVICE res_t + project(WelfordData_T acc) const __ubsan_ignore_float_divide_by_zero__ { + const auto mean = static_cast(acc.mean); + const combine_t divisor = static_cast(acc.count); + const auto var = acc.sigma2 / divisor; + res_t results(var, mean); + return results; + } + +#if defined(__CUDACC__) || defined(__HIPCC__) + inline MEGDNN_DEVICE WelfordData_T + warp_shfl_down(WelfordData_T acc, int offset) const { + return {__shfl_down(acc.mean, offset, warpSize), + __shfl_down(acc.sigma2, offset, warpSize), + __shfl_down(acc.count, offset, warpSize)}; + } +#endif + MEGDNN_HOST MEGDNN_DEVICE WelfordOps() {} +}; + +template +struct alignas(sizeof(T) * vec_size) aligned_vector { + T val[vec_size]; +}; + +template +using acc_type = T; + +template +MEGDNN_DEVICE WelfordStat +update_welford_stat_online(const U val, const WelfordStat& curr_sum) { + U delta = static_cast(val - curr_sum.mean); + U new_count = static_cast(curr_sum.count + 1.f); + U new_mean = static_cast(curr_sum.mean + delta * (1.f / new_count)); + return {new_mean, curr_sum.sigma2 + delta * (val - new_mean), new_count}; +} + +MEGDNN_DEVICE WelfordStat +combine_welford_stat(const WelfordStat lhs, const WelfordStat rhs) { + using U = decltype(lhs.count); + U delta = lhs.mean - rhs.mean; + U count = rhs.count + lhs.count; + U mean, sigma2; + if (count > decltype(lhs.count){0}) { + auto coef = 1.f / count; + auto nA = rhs.count * coef; + auto nB = lhs.count * coef; + mean = nA * rhs.mean + nB * lhs.mean; + sigma2 = rhs.sigma2 + lhs.sigma2 + delta * delta * rhs.count * nB; + } else { + mean = U(0); + sigma2 = U(0); + } + return {mean, sigma2, count}; +} + +template +MEGDNN_DEVICE WelfordStat +compute_stats(const T* __restrict__ X, const int slice_len, float* buf) { + using vec_t = aligned_vector; + using acc_t = acc_type; + const vec_t* X_vec = reinterpret_cast(X); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = slice_len / vec_size; + WelfordStat w_stat(0.f, 0.f, 0.f); + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + w_stat = update_welford_stat_online( + static_cast(data.val[ii]), w_stat); + } + } + // intra-warp reduction +#pragma unroll + for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { + WelfordStat w_tmp{ + __shfl_down(w_stat.mean, offset, warpSize), + __shfl_down(w_stat.sigma2, offset, warpSize), + __shfl_down(w_stat.count, offset, warpSize)}; + w_stat = combine_welford_stat(w_stat, w_tmp); + } + + // threadIdx.x == 0 has correct values for each warp + // inter-warp reductions + if (blockDim.y > 1) { + float* mean_sigma_buf = buf; + float* count_buf = buf + blockDim.y; + for (int offset = blockDim.y / 2; offset > 0; offset /= 2) { + // upper half of warps write to shared + if (threadIdx.x == 0 && threadIdx.y >= offset && threadIdx.y < 2 * offset) { + const int wrt_y = threadIdx.y - offset; + mean_sigma_buf[2 * wrt_y] = w_stat.mean; + mean_sigma_buf[2 * wrt_y + 1] = w_stat.sigma2; + count_buf[wrt_y] = w_stat.count; + } + __syncthreads(); + + // lower half merges + if (threadIdx.x == 0 && threadIdx.y < offset) { + WelfordStat w_tmp{ + mean_sigma_buf[2 * threadIdx.y], + mean_sigma_buf[2 * threadIdx.y + 1], count_buf[threadIdx.y]}; + w_stat = combine_welford_stat(w_stat, w_tmp); + } + __syncthreads(); + } + if (threadIdx.x == 0 && threadIdx.y == 0) { + mean_sigma_buf[0] = w_stat.mean; + mean_sigma_buf[1] = 
w_stat.sigma2 / float(slice_len); + } + __syncthreads(); + return WelfordStat{mean_sigma_buf[0], mean_sigma_buf[1], 0.f}; + + } else { + return WelfordStat{ + __shfl(w_stat.mean, 0, warpSize), + __shfl(w_stat.sigma2, 0, warpSize) / float(slice_len), 0.f}; + } +} + +template +__global__ void vectorized_layer_norm_forward_affine_kernel( + const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, + const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { + // if we made smem WelfordStat type, there would be bank conflicts, + // as one thread would have to write 3 consecutive floats + extern __shared__ float s_data[]; + + auto slice_id = blockIdx.x; + const T* slice = X + slice_id * slice_len; + WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(slice); + vec_t* Y_vec = reinterpret_cast(Y + slice_id * slice_len); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = slice_len / vec_size; + T_ACC rstd_val = static_cast(rsqrt(slice_w_stat.sigma2 + eps)); + + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; + // computation is performed in T_ACC, X is cast to T_ACC and result is + // implicitly cast to T + +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + out.val[ii] = static_cast(weight[i * vec_size + ii]) * + (rstd_val * (static_cast(data.val[ii]) - + slice_w_stat.mean)) + + static_cast(bias[i * vec_size + ii]); + } + Y_vec[i] = out; + } + if (thrx == 0) { + mean[slice_id] = slice_w_stat.mean; + rstd[slice_id] = rstd_val; + } +} + +template +__global__ void vectorized_layer_norm_forward_kernel( + const int slice_len, T_ACC eps, const T* __restrict__ X, const T* weight, + const T* bias, T_ACC* mean, T_ACC* rstd, T* Y) { + extern __shared__ float s_data[]; + + auto slice_id = blockIdx.x; + const T* slice = X + slice_id * slice_len; + WelfordStat slice_w_stat = compute_stats(slice, slice_len, s_data); + using vec_t = aligned_vector; + const vec_t* X_vec = reinterpret_cast(slice); + vec_t* Y_vec = reinterpret_cast(Y + slice_id * slice_len); + const int numx = blockDim.x * blockDim.y; + const int thrx = threadIdx.x + threadIdx.y * blockDim.x; + const int n_vec_to_read = slice_len / vec_size; + T_ACC rstd_val = static_cast(rsqrt(slice_w_stat.sigma2 + eps)); + + for (int i = thrx; i < n_vec_to_read; i += numx) { + vec_t data = X_vec[i]; + vec_t out; + +#pragma unroll + for (int ii = 0; ii < vec_size; ii++) { + out.val[ii] = + rstd_val * (static_cast(data.val[ii]) - slice_w_stat.mean); + } + Y_vec[i] = out; + } + if (thrx == 0) { + mean[slice_id] = slice_w_stat.mean; + rstd[slice_id] = rstd_val; + } +} + +template +void launch_vectorized_layer_norm_forward_kernel( + int64_t slice_len, int64_t slice_num, T_ACC eps, const T* X_data, + const T* weight_data, const T* bias_data, T* Y_data, T_ACC* mean_data, + T_ACC* rstd_data, cudaStream_t stream) { + const int num_threads = 128; + const dim3 threads(WARP_SIZE, num_threads / WARP_SIZE, 1); + const dim3 blocks(slice_num); + int nshared = threads.y > 1 ? 
threads.y * 3 / 2 * sizeof(T_ACC) : 0; + + if (weight_data == nullptr && bias_data == nullptr) { + vectorized_layer_norm_forward_kernel<<>>( + slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, + Y_data); + } else { + vectorized_layer_norm_forward_affine_kernel<<< + blocks, threads, nshared, stream>>>( + slice_len, eps, X_data, weight_data, bias_data, mean_data, rstd_data, + Y_data); + } + after_kernel_launch(); +} + +template +__inline__ MEGDNN_DEVICE T welford_warp_reduce(T val, const ReduceOp& op) { +#pragma unroll + for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { + val = op.combine(val, op.warp_shfl_down(val, offset)); + } + return val; +} + +template +__inline__ MEGDNN_DEVICE T +welford_block_reduce(T val, const ReduceOp& op, const T& identity_element, T* shared) { + const int lid = threadIdx.x % warpSize; + const int wid = threadIdx.x / warpSize; + val = welford_warp_reduce(val, op); + __syncthreads(); + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : identity_element; + if (wid == 0) { + val = welford_warp_reduce(val, op); + } + return val; +} + +template +__global__ void get_input_mean_and_rstd_kernel( + int64_t slice_len, T_ACC eps, const T* X, T_ACC* mean, T_ACC* rstd) { + using WelfordType = WelfordData; + using WelfordOp = WelfordOps>; + + __shared__ typename std::aligned_storage< + sizeof(WelfordType), alignof(WelfordType)>::type val_shared[WARP_SIZE]; + WelfordType* val_shared_ptr = reinterpret_cast(val_shared); + + const int64_t i = blockIdx.x; + WelfordOp welford_op; + WelfordType val( + static_cast(0), static_cast(0), static_cast(0)); + + for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { + const int64_t index = i * slice_len + j; + val = welford_op.reduce(val, static_cast(X[index])); + } + val = welford_block_reduce( + val, welford_op, + WelfordType( + static_cast(0), static_cast(0), + static_cast(0)), + val_shared_ptr); + + if (threadIdx.x == 0) { + T_ACC slice_mean; + T_ACC slice_sigma2; + thrust::tie(slice_sigma2, slice_mean) = welford_op.project(val); + mean[i] = slice_mean; + rstd[i] = rsqrt(slice_sigma2 + eps); + } +} + +template +__global__ void layer_norm_forward_kernel( + int64_t slice_len, const T* X, const T_ACC* mean, const T_ACC* rstd, + const T* weight, const T* bias, T* Y) { + const int64_t i = blockIdx.x; + for (int64_t j = threadIdx.x; j < slice_len; j += blockDim.x) { + const int64_t index = i * slice_len + j; + const T_ACC weight_v = + weight == nullptr ? T_ACC(1) : static_cast(weight[j]); + const T_ACC bias_v = bias == nullptr ? 
T_ACC(0) : static_cast(bias[j]); + Y[index] = (static_cast(X[index]) - static_cast(mean[i])) * + static_cast(rstd[i]) * weight_v + + bias_v; + } +} + +template +void forward( + T* X, T* weight, T* bias, int64_t slice_num, int64_t slice_len, T_ACC eps, T* Y, + T_ACC* mean, T_ACC* rstd, cudaStream_t stream) { + auto can_vectorize = [&](const T* ptr, int alignment) { + uint64_t addr = reinterpret_cast(ptr); + return addr % alignment == 0; + }; + constexpr int num_vec_elems = vec_size; + constexpr int alignment = num_vec_elems * sizeof(T); + if ((std::is_same::value || std::is_same::value || + std::is_same::value) && + slice_len <= static_cast(1ULL << std::numeric_limits::digits) && + slice_len % num_vec_elems == 0 && can_vectorize(X, alignment) && + can_vectorize(Y, alignment)) { + launch_vectorized_layer_norm_forward_kernel( + slice_len, slice_num, static_cast(eps), X, weight, bias, Y, mean, + rstd, stream); + after_kernel_launch(); + } else { + get_input_mean_and_rstd_kernel + <<>>(slice_len, eps, X, mean, rstd); + after_kernel_launch(); + layer_norm_forward_kernel<<>>( + slice_len, X, mean, rstd, weight, bias, Y); + after_kernel_launch(); + } +} + +template +__inline__ MEGDNN_DEVICE T warp_reduce_sum(T val) { +#pragma unroll + for (int offset = (warpSize >> 1); offset > 0; offset >>= 1) { + val += __shfl_down(val, offset, warpSize); + } + return val; +} + +template +__inline__ MEGDNN_DEVICE T block_reduce_sum(T val, T* shared) { + const int lid = threadIdx.x % warpSize; + const int wid = threadIdx.x / warpSize; + val = warp_reduce_sum(val); + __syncthreads(); + if (lid == 0) { + shared[wid] = val; + } + __syncthreads(); + val = (threadIdx.x < blockDim.x / warpSize) ? shared[lid] : T(0); + if (wid == 0) { + val = warp_reduce_sum(val); + } + return val; +} + +template +__inline__ MEGDNN_DEVICE void layer_norm_grad_input_kernel_impl( + const T* __restrict__ dY, const T* __restrict__ X, + const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, + const T* __restrict__ weight, T* dX, const int slice_len, T_ACC* buf) { + const auto slice_id = blockIdx.x; + const T_ACC mean_val = mean[slice_id]; + const T_ACC rstd_val = rstd[slice_id]; + T_ACC stats_x1{0}, stats_x2{0}; + constexpr int unroll = 4; + auto l = unroll * threadIdx.x; + const T* X_i = X + slice_id * slice_len; + const T* dY_i = dY + slice_id * slice_len; + T* dX_i = dX + slice_id * slice_len; + // vectorized reads don't improve perf, so use regular unrolling + + for (; l + unroll - 1 < slice_len; l += blockDim.x * unroll) { +#pragma unroll + for (int k = 0; k < unroll; k++) { + T_ACC weight_val = + (weight != nullptr) ? static_cast(weight[l + k]) : T_ACC(1); + const T_ACC c_h = static_cast(X_i[l + k]); + const T_ACC c_loss = static_cast(dY_i[l + k]); + stats_x1 += c_loss * weight_val; + stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; + } + } + for (; l < slice_len; l++) { + T_ACC weight_val = + (weight != nullptr) ? 
static_cast(weight[l]) : T_ACC(1); + const T_ACC c_h = static_cast(X_i[l]); + const T_ACC c_loss = static_cast(dY_i[l]); + stats_x1 += c_loss * weight_val; + stats_x2 += c_loss * weight_val * (c_h - mean_val) * rstd_val; + } + + stats_x1 = block_reduce_sum(stats_x1, buf); + stats_x2 = block_reduce_sum(stats_x2, buf); + if (threadIdx.x == 0) { + buf[0] = stats_x1; + buf[1] = stats_x2; + } + __syncthreads(); + stats_x1 = buf[0]; + stats_x2 = buf[1]; + T_ACC fH = slice_len; + T_ACC term1 = (T_ACC(1) / fH) * rstd_val; + + for (int l = threadIdx.x; l < slice_len; l += blockDim.x) { + const T_ACC x = X_i[l]; + const T_ACC dy = dY_i[l]; + T_ACC weight_val = + (weight != nullptr) ? static_cast(weight[l]) : T_ACC(1); + T_ACC f_grad_input = fH * weight_val * dy; + f_grad_input -= (x - mean_val) * rstd_val * stats_x2; + f_grad_input -= stats_x1; + f_grad_input *= term1; + dX_i[l] = f_grad_input; + } +} + +template +__global__ void layer_norm_grad_input_kernel( + const T* __restrict__ dY, const T* __restrict__ X, + const T_ACC* __restrict__ mean, const T_ACC* __restrict__ rstd, + const T* __restrict__ weight, T* dX, const int slice_len) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* buf = reinterpret_cast(&s_data1); + + layer_norm_grad_input_kernel_impl(dY, X, mean, rstd, weight, dX, slice_len, buf); +} + +template +__global__ void layer_norm_grad_weight_bias_simple_kernel( + int64_t slice_num, int64_t slice_len, const T* dY, const T* X, + const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { + const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + if (j < slice_len) { + T_ACC sum1 = 0; + T_ACC sum2 = 0; + for (int64_t i = 0; i < slice_num; ++i) { + const int64_t index = i * slice_len + j; + sum1 += dweight == nullptr ? T_ACC(0) + : static_cast(dY[index]) * + (static_cast(X[index]) - + static_cast(mean[i])) * + static_cast(rstd[i]); + sum2 += dbias == nullptr ? 
T_ACC(0) : static_cast(dY[index]); + } + if (dweight != nullptr) { + dweight[j] = sum1; + } + if (dbias != nullptr) { + dbias[j] = sum2; + } + } +} + +template +__global__ void layer_norm_grad_weight_bias_kernel( + int64_t slice_num, int64_t slice_len, const T* dY, const T* X, + const T_ACC* mean, const T_ACC* rstd, T* dweight, T* dbias) { + alignas(sizeof(double)) extern __shared__ char s_data1[]; + T_ACC* s_data_typed = reinterpret_cast(&s_data1); + const int64_t j = blockIdx.x * blockDim.x + threadIdx.x; + constexpr int unroll = 8; + T dYs[unroll]; + T Xs[unroll]; + T_ACC* means = s_data_typed; + T_ACC* rstds = s_data_typed + unroll * blockDim.y; + T_ACC dg_sum = 0; + T_ACC db_sum = 0; + if (j < slice_len) { + int bcounter; + for (bcounter = 0; bcounter < slice_num / (blockDim.y * unroll); bcounter++) { + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; +#pragma unroll + for (int ii = 0; ii < unroll; ii++) { + if (threadIdx.x == 0) { + means[ii * blockDim.y + threadIdx.y] = mean[offset + ii]; + rstds[ii * blockDim.y + threadIdx.y] = rstd[offset + ii]; + } + dYs[ii] = dY[(offset + ii) * slice_len + j]; + Xs[ii] = X[(offset + ii) * slice_len + j]; + } + __syncthreads(); +#pragma unroll + for (int ii = 0; ii < unroll; ii++) { + dg_sum += dYs[ii] * (Xs[ii] - means[ii * blockDim.y + threadIdx.y]) * + rstds[ii * blockDim.y + threadIdx.y]; + db_sum += dYs[ii]; + } + __syncthreads(); + } + int offset = (bcounter * blockDim.y + threadIdx.y) * unroll; + for (int ii = 0; ii < 8; ii++) { + T_ACC mean_val, rstd_val; // we don't use smem in the tail to avoid awkward + // synchronizations, perf penalty is negligible + if ((offset + ii) < slice_num) { + mean_val = mean[offset + ii]; + rstd_val = rstd[offset + ii]; + dYs[0] = dY[(offset + ii) * slice_len + j]; + Xs[0] = X[(offset + ii) * slice_len + j]; + dg_sum += dYs[0] * (Xs[0] - mean_val) * rstd_val; + db_sum += dYs[0]; + } + } + s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] = dg_sum; + s_data_typed[blockDim.x * blockDim.y + threadIdx.y * blockDim.x + threadIdx.x] = + db_sum; + __syncthreads(); + for (int offset = blockDim.y / 2; offset >= 1; offset /= 2) { + if (threadIdx.y < offset) { + s_data_typed[threadIdx.y * blockDim.x + threadIdx.x] += + s_data_typed[(threadIdx.y + offset) * blockDim.x + threadIdx.x]; + s_data_typed + [blockDim.x * blockDim.y + threadIdx.y * blockDim.x + + threadIdx.x] += s_data_typed + [blockDim.x * blockDim.y + + (threadIdx.y + offset) * blockDim.x + threadIdx.x]; + } + __syncthreads(); + } + if (threadIdx.y == 0) { + if (dweight) { + dweight[j] = s_data_typed[threadIdx.x]; + } + if (dbias) { + dbias[j] = s_data_typed[threadIdx.x + blockDim.x * blockDim.y]; + } + } + } +} + +template +void backward( + const T* dY_data, const T* X_data, const T_ACC* mean_data, + const T_ACC* rstd_data, const T* weight_data, int64_t slice_num, + int64_t slice_len, T* dX_data, T* dweight_data, T* dbias_data, + cudaStream_t stream) { + if (dX_data != nullptr) { + const int num_threads = 128; + const dim3 blocks(slice_num); + int nshared = (num_threads / WARP_SIZE) * sizeof(T_ACC); + layer_norm_grad_input_kernel<<>>( + dY_data, X_data, mean_data, rstd_data, weight_data, dX_data, slice_len); + after_kernel_launch(); + } + if (dweight_data || dbias_data) { + if (slice_num < 512) { + const int64_t B = (slice_len + kCUDANumThreads - 1) / kCUDANumThreads; + layer_norm_grad_weight_bias_simple_kernel + <<>>( + slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, + dweight_data, dbias_data); + after_kernel_launch(); + } else 
{ + dim3 threads{16, 32}; + int blocks = (slice_len + threads.x - 1) / threads.x; + layer_norm_grad_weight_bias_kernel + <<>>( + slice_num, slice_len, dY_data, X_data, mean_data, rstd_data, + dweight_data, dbias_data); + after_kernel_launch(); + } + } +} + +#define INST(T, T_ACC) \ + template void forward( \ + T*, T*, T*, int64_t, int64_t, T_ACC, T*, T_ACC*, T_ACC*, cudaStream_t); \ + template void backward( \ + const T*, const T*, const T_ACC*, const T_ACC*, const T*, int64_t, \ + int64_t, T*, T*, T*, cudaStream_t); + +INST(dt_float32, dt_float32) +INST(dt_float16, dt_float32) +INST(dt_bfloat16, dt_float32) +#undef INST + +} // namespace layer_norm +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/layer_norm/layer_norm_cuda.cuh b/dnn/src/cuda/layer_norm/layer_norm_cuda.cuh new file mode 100644 index 000000000..8e14de34e --- /dev/null +++ b/dnn/src/cuda/layer_norm/layer_norm_cuda.cuh @@ -0,0 +1,34 @@ +/** + * \file dnn/src/cuda/layer_norm/layer_norm.cuh + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. + */ +#pragma once +#include + +namespace megdnn { +namespace cuda { +namespace layer_norm { + +template +void forward( + T* X, T* gamma, T* beta, int64_t M, int64_t N, T_ACC eps, T* Y, T_ACC* mean, + T_ACC* rstd, cudaStream_t stream); + +template +void backward( + const T* dY_data, const T* X_data, const T_ACC* mean_data, + const T_ACC* rstd_data, const T* gamma_data, int64_t M, int64_t N, T* dX_data, + T* dgamma_data, T* dbeta_data, cudaStream_t stream); + +} // namespace layer_norm +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/layer_norm/opr_impl.cpp b/dnn/src/cuda/layer_norm/opr_impl.cpp new file mode 100644 index 000000000..426de5273 --- /dev/null +++ b/dnn/src/cuda/layer_norm/opr_impl.cpp @@ -0,0 +1,94 @@ +/** + * \file dnn/src/cuda/layer_norm/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ + +#include "src/cuda/layer_norm/opr_impl.h" +#include "src/cuda/layer_norm/layer_norm_cuda.cuh" +#include "src/cuda/utils.h" + +namespace megdnn { +namespace cuda { + +void LayerNormForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) { + check_exec( + data.layout, weight.layout, bias.layout, dst.layout, mean.layout, + rstd.layout, workspace.size); + + auto p = param(); + float eps = p.eps; + bool affine = p.affine; + uint64_t slice_length = p.normalized_size; + uint64_t slice_dim = p.normalized_dim; + uint64_t n_slices = 1; + for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { + n_slices = n_slices * data.layout.shape[i]; + } + + auto stream = cuda_stream(handle()); + using namespace ::megdnn::cuda::layer_norm; + +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + using T_ACC = float; \ + forward( \ + data.ptr(), affine ? weight.ptr() : nullptr, \ + affine ? bias.ptr() : nullptr, static_cast(n_slices), \ + static_cast(slice_length), static_cast(eps), \ + dst.ptr(), mean.ptr(), rstd.ptr(), stream); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +void LayerNormBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) { + check_exec( + diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, + ddata.layout, dweight.layout, dbias.layout, workspace.size); + auto p = param(); + bool affine = p.affine; + uint64_t slice_length = p.normalized_size; + uint64_t slice_dim = p.normalized_dim; + uint64_t n_slices = 1; + for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { + n_slices = n_slices * data.layout.shape[i]; + } + + auto stream = cuda_stream(handle()); + using namespace ::megdnn::cuda::layer_norm; +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + using T = typename DTypeTrait::ctype; \ + using T_ACC = float; \ + backward( \ + diff.ptr(), data.ptr(), mean.ptr(), rstd.ptr(), \ + affine ? weight.ptr() : nullptr, n_slices, slice_length, \ + ddata.ptr(), affine ? dweight.ptr() : nullptr, \ + affine ? dbias.ptr() : nullptr, stream); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +} // namespace cuda +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/cuda/layer_norm/opr_impl.h b/dnn/src/cuda/layer_norm/opr_impl.h new file mode 100644 index 000000000..8bca6a75d --- /dev/null +++ b/dnn/src/cuda/layer_norm/opr_impl.h @@ -0,0 +1,53 @@ +/** + * \file dnn/src/cuda/layer_norm/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +#include "src/cuda/cudnn_wrapper.h" + +namespace megdnn { +namespace cuda { + +class LayerNormForwardImpl final : public LayerNormForward { +public: + using LayerNormForward::LayerNormForward; + void exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +class LayerNormBackwardImpl final : public LayerNormBackward { +public: + using LayerNormBackward::LayerNormBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +} // namespace cuda +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/handle.cpp b/dnn/src/naive/handle.cpp index e38bfead2..1ff6675f8 100644 --- a/dnn/src/naive/handle.cpp +++ b/dnn/src/naive/handle.cpp @@ -47,6 +47,7 @@ #include "src/naive/images2neibs/opr_impl.h" #include "src/naive/indexing_multi_axis_vec/opr_impl.h" #include "src/naive/indexing_one_hot/opr_impl.h" +#include "src/naive/layer_norm/opr_impl.h" #include "src/naive/linspace/opr_impl.h" #include "src/naive/local/opr_impl.h" #include "src/naive/local_share/opr_impl.h" diff --git a/dnn/src/naive/layer_norm/opr_impl.cpp b/dnn/src/naive/layer_norm/opr_impl.cpp new file mode 100644 index 000000000..cc9670358 --- /dev/null +++ b/dnn/src/naive/layer_norm/opr_impl.cpp @@ -0,0 +1,170 @@ +/** + * \file dnn/src/naive/layer_norm/opr_impl.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "src/naive/layer_norm/opr_impl.h" +#include +#include "src/common/utils.h" +#include "src/naive/handle.h" + +using namespace megdnn; +using namespace naive; + +namespace { + +using Param = megdnn::LayerNorm::Param; + +template +void forward( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + const Param& param) { + float eps = param.eps; + bool affine = param.affine; + uint64_t slice_length = param.normalized_size; + uint64_t slice_dim = param.normalized_dim; + uint64_t n_slices = 1; + for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { + n_slices = n_slices * data.layout.shape[i]; + } + + for (size_t i = 0; i < n_slices; i++) { + T_ACC slice_sum = static_cast(0.0f); + for (size_t j = 0; j < slice_length; j++) { + auto value = data.ptr()[i * slice_length + j]; + slice_sum += value; + } + T_ACC slice_mean = static_cast(slice_sum / slice_length); + + T_ACC slice_var = static_cast(0.0f); + for (size_t j = 0; j < slice_length; j++) { + slice_var += (data.ptr()[i * slice_length + j] - slice_mean) * + (data.ptr()[i * slice_length + j] - slice_mean); + } + slice_var = slice_var / slice_length; + + T_ACC slice_std = static_cast(sqrt(slice_var + eps)); + for (size_t j = 0; j < slice_length; j++) { + dst.ptr()[i * slice_length + j] = + (data.ptr()[i * slice_length + j] - slice_mean) / slice_std; + if (affine) { + dst.ptr()[i * slice_length + j] = + dst.ptr()[i * slice_length + j] * weight.ptr()[j] + + bias.ptr()[j]; + } + } + mean.ptr()[i] = static_cast(slice_mean); + rstd.ptr()[i] = static_cast(1.0 / slice_std); + } +} + +template +void backward( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, const Param& param) { + bool affine = param.affine; + uint64_t slice_length = param.normalized_size; + uint64_t slice_dim = param.normalized_dim; + uint64_t n_slices = 1; + for (size_t i = 0; i < data.layout.ndim - slice_dim; ++i) { + n_slices = n_slices * data.layout.shape[i]; + } + + if (affine) { + for (size_t i = 0; i < slice_length; ++i) { + dweight.ptr()[i] = 0; + dbias.ptr()[i] = 0; + } + + for (size_t i = 0; i < n_slices; ++i) { + for (size_t j = 0; j < slice_length; ++j) { + dweight.ptr()[j] += + (data.ptr()[i * slice_length + j] - mean.ptr()[i]) * + rstd.ptr()[i] * diff.ptr()[i * slice_length + j]; + + dbias.ptr()[j] += diff.ptr()[i * slice_length + j]; + } + } + } + + for (size_t i = 0; i < n_slices; ++i) { + T_ACC ds = static_cast(0.0f); + T_ACC db = static_cast(0.0f); + T_ACC a = static_cast(0.0f); + T_ACC b = static_cast(0.0f); + T_ACC c = static_cast(0.0f); + + for (size_t j = 0; j < slice_length; ++j) { + auto value = data.ptr()[i * slice_length + j]; + auto diff_v = diff.ptr()[i * slice_length + j]; + auto weight_v = affine ? weight.ptr()[j] : static_cast(1.0f); + db += diff_v * weight_v; + ds += diff_v * value * weight_v; + } + + a = rstd.ptr()[i]; + b = (db * mean.ptr()[i] - ds) * a * a * a / slice_length; + c = -b * mean.ptr()[i] - db * a / slice_length; + + for (uint64_t j = 0; j < slice_length; j++) { + auto weight_v = affine ? 
weight.ptr()[j] : static_cast(1.0f); + ddata.ptr()[i * slice_length + j] = + diff.ptr()[i * slice_length + j] * a * weight_v + + data.ptr()[i * slice_length + j] * b + c; + } + } +} + +} // namespace + +namespace megdnn { +namespace naive { + +void LayerNormForwardImpl::exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) { + check_exec( + data.layout, weight.layout, bias.layout, dst.layout, mean.layout, + rstd.layout, workspace.size); +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR(forward::ctype>( \ + data, weight, bias, dst, mean, rstd, param())); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +void LayerNormBackwardImpl::exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) { + check_exec( + diff.layout, data.layout, weight.layout, mean.layout, rstd.layout, + ddata.layout, dweight.layout, dbias.layout, workspace.size); +#define cb(DType) \ + if (data.layout.dtype == DType()) { \ + MEGDNN_DISPATCH_CPU_KERN_OPR(backward::ctype>( \ + diff, data, weight, mean, rstd, ddata, dweight, dbias, param())); \ + return; \ + } + MEGDNN_FOREACH_COMPUTING_DTYPE_FLOAT(cb) +#undef cb + megdnn_throw("bad dtype"); +} + +} // namespace naive +} // namespace megdnn +// vim: syntax=cpp.doxygen diff --git a/dnn/src/naive/layer_norm/opr_impl.h b/dnn/src/naive/layer_norm/opr_impl.h new file mode 100644 index 000000000..99d93e799 --- /dev/null +++ b/dnn/src/naive/layer_norm/opr_impl.h @@ -0,0 +1,51 @@ +/** + * \file dnn/src/naive/layer_norm/opr_impl.h + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#pragma once +#include "megdnn/oprs.h" + +namespace megdnn { +namespace naive { + +class LayerNormForwardImpl final : public LayerNormForward { +public: + using LayerNormForward::LayerNormForward; + void exec( + _megdnn_tensor_in data, _megdnn_tensor_in weight, _megdnn_tensor_in bias, + _megdnn_tensor_out dst, _megdnn_tensor_out mean, _megdnn_tensor_out rstd, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +class LayerNormBackwardImpl final : public LayerNormBackward { +public: + using LayerNormBackward::LayerNormBackward; + void exec( + _megdnn_tensor_in diff, _megdnn_tensor_in data, _megdnn_tensor_in weight, + _megdnn_tensor_in mean, _megdnn_tensor_in rstd, _megdnn_tensor_out ddata, + _megdnn_tensor_out dweight, _megdnn_tensor_out dbias, + _megdnn_workspace workspace) override; + size_t get_workspace_in_bytes( + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&, const TensorLayout&, + const TensorLayout&, const TensorLayout&) override { + return 0; + } +}; + +} // namespace naive +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/dnn/test/common/deduce_layout_proxy.h b/dnn/test/common/deduce_layout_proxy.h index 17afc1ddd..f1067aec3 100644 --- a/dnn/test/common/deduce_layout_proxy.h +++ b/dnn/test/common/deduce_layout_proxy.h @@ -57,6 +57,15 @@ struct DeduceLayoutProxy { } }; +template +struct DeduceLayoutProxy { + static void deduce_layout(Opr* opr, TensorLayoutArray& layouts) { + megdnn_assert(layouts.size() == 6); + opr->deduce_layout( + layouts[0], layouts[1], layouts[2], layouts[3], layouts[4], layouts[5]); + } +}; + template struct DeduceLayoutProxy { static void deduce_layout(Opr*, TensorLayoutArray&) {} diff --git a/dnn/test/cuda/layer_norm.cpp b/dnn/test/cuda/layer_norm.cpp new file mode 100644 index 000000000..b4d042043 --- /dev/null +++ b/dnn/test/cuda/layer_norm.cpp @@ -0,0 +1,94 @@ +/** + * \file dnn/test/cuda/layer_norm.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
+ */ +#include "test/cuda/fixture.h" + +#include "test/common/checker.h" + +namespace megdnn { +namespace test { + +TEST_F(CUDA, LAYERNORM_FORWARD) { + using Param = LayerNormForward::Param; + Param param; + param.affine = true; + param.eps = 1e-6; + param.normalized_dim = 1; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-2); + + auto run = [&](DType d) { + for (size_t n_slices : {10, 30}) + for (size_t slice_len : {10, 30}) { + param.normalized_size = slice_len; + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, d) + .set_dtype(2, d) + .set_dtype(3, d) + .set_dtype(4, dtype::Float32()) + .set_dtype(5, dtype::Float32()) + .execs({{n_slices, slice_len}, + {slice_len}, + {slice_len}, + {n_slices, slice_len}, + {n_slices}, + {n_slices}}); + } + }; + + run(dtype::Float32()); + run(dtype::Float16()); + run(dtype::BFloat16()); +} + +TEST_F(CUDA, LAYERNORM_BACKWARD) { + using Param = LayerNormBackward::Param; + Param param; + param.affine = true; + param.eps = 1e-6; + param.normalized_dim = 1; + Checker checker(handle_cuda()); + checker.set_epsilon(1e-1); + + auto run = [&](DType d) { + for (size_t n_slices : {10, 30}) + for (size_t slice_len : {10, 30}) { + param.normalized_size = slice_len; + checker.set_param(param) + .set_dtype(0, d) + .set_dtype(1, d) + .set_dtype(2, d) + .set_dtype(3, dtype::Float32()) + .set_dtype(4, dtype::Float32()) + .set_dtype(5, d) + .set_dtype(6, d) + .set_dtype(7, d) + .execs({{n_slices, slice_len}, + {n_slices, slice_len}, + {slice_len}, + {n_slices}, + {n_slices}, + {n_slices, slice_len}, + {slice_len}, + {slice_len}}); + } + }; + + run(dtype::Float32()); + run(dtype::Float16()); + run(dtype::BFloat16()); +} + +} // namespace test +} // namespace megdnn + +// vim: syntax=cpp.doxygen diff --git a/imperative/python/megengine/functional/nn.py b/imperative/python/megengine/functional/nn.py index 9ababcebe..9dfd42984 100644 --- a/imperative/python/megengine/functional/nn.py +++ b/imperative/python/megengine/functional/nn.py @@ -1066,57 +1066,6 @@ def softmax(inp: Tensor, axis: Optional[int] = None) -> Tensor: return cached / down -@lru_cache(maxsize=None) -def _get_layerNorm(device, dtype, dim, gopt_level=2): - @subgraph("LayerNormAffine", dtype, device, 5, gopt_level=gopt_level) - def layerNormAffine(inputs, f, c): - inp, eps, _flatten_shape, weight, bias = inputs - inp_shape = f(GetVarShape(), inp) - - inp = f(Reshape(axis=dim), inp, _flatten_shape) - mean = f(Reduce(mode="mean", axis=-1), inp) - x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) - reduce_shape = f(GetVarShape(), x2s) - reduce_size = f( - "//", - f(Reduce(mode="product", axis=0), inp_shape), - f(Reduce(mode="product", axis=0), reduce_shape), - ) - reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) - var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) - inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) - oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) - affine_oup = f(Reshape(), oup, inp_shape) - affine_oup = f("fma3", affine_oup, weight, bias) - - # NOTE: return oup make backward faster but take more memory - return (affine_oup, oup, mean, x2s), (True, False, False, False) - - @subgraph("LayerNorm", dtype, device, 3, gopt_level=gopt_level) - def layerNorm(inputs, f, c): - inp, eps, _flatten_shape = inputs - inp_shape = f(GetVarShape(), inp) - - inp = f(Reshape(axis=dim), inp, _flatten_shape) - mean = f(Reduce(mode="mean", axis=-1), inp) - x2s = f(Reduce(mode="sum_sqr", axis=-1), inp) - reduce_shape = f(GetVarShape(), x2s) - reduce_size = f( - "//", 
- f(Reduce(mode="product", axis=0), inp_shape), - f(Reduce(mode="product", axis=0), reduce_shape), - ) - reduce_size_f = f(TypeCvt(dtype=dtype), reduce_size) - var = f("-", f("/", x2s, reduce_size_f), f("**", mean, c(2))) - inv_sqrt_var = f("**", f("+", var, eps), c(-0.5)) - oup = f("fma3", inp, inv_sqrt_var, f("*", f("-", mean), inv_sqrt_var)) - oup = f(Reshape(), oup, inp_shape) - - return (oup,), (True,) - - return (layerNorm, layerNormAffine) - - def layer_norm( inp: Tensor, normalized_shape: tuple, @@ -1133,32 +1082,34 @@ def layer_norm( normalized_shape: the shape that you want to be normalizated affine: whether to use weight and bias weight: must not be None when the affine is true - bias: must not be None when the bias is true + bias: must not be None when the affine is true eps: a value added to the denominator for numerical stability. Default: 1e-5 """ - if amp._enabled: inp, weight, bias = cast_tensors(inp, weight, bias, promote=True) - _device = inp.device - _dtype = inp.dtype - _dim = len(inp.shape) - len(normalized_shape) + if isinstance(normalized_shape, int): + normalized_shape = [normalized_shape] - _flatten_shape = concat( - ( - convert_single_value(inp.shape[:_dim], dtype="int32", device=inp.device), - convert_single_value(-1, dtype="int32", device=inp.device), - ) - ) - (layerNorm, layerNormAffine) = _get_layerNorm(_device, _dtype, _dim) + normalized_dim = len(normalized_shape) + assert normalized_dim > 0 - eps = convert_single_value(eps, dtype=inp.dtype, device=inp.device) + normalized_size = 1 + for i in range(normalized_dim): + normalized_size = normalized_size * normalized_shape[i] + + op = builtin.LayerNorm( + affine=affine, + eps=eps, + normalized_dim=normalized_dim, + normalized_size=normalized_size, + ) if affine: - outvar, *_ = apply(layerNormAffine(), inp, eps, _flatten_shape, weight, bias) + assert weight is not None and bias is not None + return apply(op, inp, weight, bias)[0] else: - outvar, *_ = apply(layerNorm(), inp, eps, _flatten_shape) - - return outvar + # assert weight is None and bias is None + return apply(op, inp)[0] def batch_norm( diff --git a/imperative/python/test/unit/functional/test_functional.py b/imperative/python/test/unit/functional/test_functional.py index 1a82b30db..39b3eeba3 100644 --- a/imperative/python/test/unit/functional/test_functional.py +++ b/imperative/python/test/unit/functional/test_functional.py @@ -865,61 +865,6 @@ def test_conv1d(): ) -def test_layer_norm(): - def _layer_norm(x, normalized_shape, affine, weight=None, bias=None, eps=1e-5): - __layer_norm = LayerNorm(normalized_shape=normalized_shape, affine=affine) - __layer_norm.weight = weight - __layer_norm.bias = bias - return __layer_norm(x) - - def _layer_norm_numpy( - x, normalized_shape, affine, weight=None, bias=None, eps=1e-5 - ): - x_shape = x.shape - dim_delta = len(x_shape) - len(normalized_shape) - non_flatten_shape = x_shape[:dim_delta] - x = x.reshape(*non_flatten_shape, -1) - - mean = x.mean(axis=-1, keepdims=True) - var = (x ** 2).mean(axis=-1, keepdims=True) - mean * mean - - x = (x - mean) / F.sqrt(var + eps) - x = x.reshape(x_shape) - if affine: - x = weight * x + bias - - return x - - normalized_shape = (28, 28) - inp_feat = Tensor(np.random.randn(32, 64, 28, 28), dtype="float32") - weight = Tensor(np.random.randn(28, 28), dtype="float32") - bias = Tensor(np.random.randn(28, 28), dtype="float32") - - inp_feat = inp_feat + 1 - weight = weight + 1 - bias = bias - - affine = False - - outvar = F.nn.layer_norm(inp_feat, normalized_shape, affine, weight, 
diff --git a/imperative/python/test/unit/functional/test_loss.py b/imperative/python/test/unit/functional/test_loss.py
index d46f40b67..abf4b2fec 100644
--- a/imperative/python/test/unit/functional/test_loss.py
+++ b/imperative/python/test/unit/functional/test_loss.py
@@ -43,7 +43,7 @@ def test_cross_entropy():
     x = softmax(x)
     l_ref = ref(x, y)
     l = F.nn.cross_entropy(tensor(x, "float32"), tensor(y, "int32"), with_logits=False)
-    np.testing.assert_allclose(l.numpy(), l_ref)
+    np.testing.assert_allclose(l.numpy(), l_ref, 1e-6, 1e-6)
 
 
 def test_cross_entropy_reduction():
diff --git a/imperative/src/impl/ops/specializations.cpp b/imperative/src/impl/ops/specializations.cpp
index d153b429a..5d3562a2b 100644
--- a/imperative/src/impl/ops/specializations.cpp
+++ b/imperative/src/impl/ops/specializations.cpp
@@ -20,6 +20,7 @@
 #include "megbrain/opr/dnn/correlation.h"
 #include "megbrain/opr/dnn/fake_quant.h"
 #include "megbrain/opr/dnn/images2neibs.h"
+#include "megbrain/opr/dnn/layer_norm.h"
 #include "megbrain/opr/dnn/local.h"
 #include "megbrain/opr/dnn/lrn.h"
 #include "megbrain/opr/dnn/lsq.h"
@@ -636,4 +637,29 @@ auto apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
 }
 OP_TRAIT_REG(LRN, LRN).apply_on_var_node(apply_on_var_node).fallback();
 } // namespace lrn
+
+namespace layer_norm {
+
+cg::OperatorNodeBase* apply_on_var_node(const OpDef& def, const VarNodeArray& inputs) {
+    auto&& op = static_cast<const LayerNorm&>(def);
+    size_t nr_inp = inputs.size();
+    auto p = op.param();
+    mgb_assert((nr_inp == 3 && p.affine) || (nr_inp == 1 && !p.affine));
+    OperatorNodeConfig config{op.make_name()};
+    if (nr_inp == 3) {
+        return opr::LayerNorm::make(
+                       inputs[0], inputs[1], inputs[2], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    } else {
+        return opr::LayerNorm::make(inputs[0], op.param(), config)[0]
+                .node()
+                ->owner_opr();
+    }
+}
+
+OP_TRAIT_REG(LayerNorm, LayerNorm).apply_on_var_node(apply_on_var_node).fallback();
+
+} // namespace layer_norm
+
 } // namespace mgb::imperative
diff --git a/src/core/include/megbrain/ir/ops.td b/src/core/include/megbrain/ir/ops.td
index 233c99f32..a300ccf8b 100644
--- a/src/core/include/megbrain/ir/ops.td
+++ b/src/core/include/megbrain/ir/ops.td
@@ -431,4 +431,6 @@ def Padding: MgbHashableOp<"Padding", [PaddingParam]>;
 
 def LRN: MgbHashableOp<"LRN", [LRNParam]>;
 
+def LayerNorm: MgbHashableOp<"LayerNorm", [LayerNormParam]>;
+
 #endif // MGB_OPS
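The `apply_on_var_node` trait registered above asserts that the op sees either three inputs (`affine=True`) or a single input (`affine=False`). The same convention is visible from Python when the builtin op is applied directly; a hedged sketch (the exact location of `apply` differs between MegEngine versions):

```python
import numpy as np
import megengine as mge
from megengine.core.ops import builtin
from megengine.core._imperative_rt.core2 import apply  # import path may differ by version

x = mge.tensor(np.random.randn(2, 3, 16).astype("float32"))
w = mge.tensor(np.ones((16,), dtype="float32"))
b = mge.tensor(np.zeros((16,), dtype="float32"))

op = builtin.LayerNorm(affine=True, eps=1e-5, normalized_dim=1, normalized_size=16)
dst, mean, rstd = apply(op, x, w, b)      # three inputs when affine=True
print(dst.shape, mean.shape, rstd.shape)  # (2, 3, 16) (2, 3) (2, 3)

op = builtin.LayerNorm(affine=False, eps=1e-5, normalized_dim=1, normalized_size=16)
dst, mean, rstd = apply(op, x)            # single input when affine=False
```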
#include "megbrain/opr/dnn/lrn.h" #include "megbrain/opr/dnn/lsq.h" @@ -420,6 +421,47 @@ struct OprMaker { } }; +template <> +struct OprMaker { + using Param = opr::LayerNorm::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + if (i.size() == 3) { + return opr::LayerNorm::make(i[0], i[1], i[2], param, config)[0] + .node() + ->owner_opr(); + } else { + mgb_assert(i.size() == 1); + return opr::LayerNorm::make(i[0], param, config)[0].node()->owner_opr(); + } + } +}; + +// OprMaker in MGB_SEREG_OPR only support unique output opr +template <> +struct OprMaker { + using Param = opr::LayerNormBackward::Param; + static cg::OperatorNodeBase* make( + const Param& param, const cg::VarNodeArray& i, ComputingGraph& graph, + const OperatorNodeConfig& config) { + MGB_MARK_USED_VAR(graph); + if (i.size() == 5) { + return opr::LayerNormBackward::make( + i[0], i[1], i[2], i[3], i[4], param, config)[0] + .node() + ->owner_opr(); + } else { + mgb_assert(i.size() == 4); + return opr::LayerNormBackward::make( + i[0], i[1], i[2], i[3], param, config)[0] + .node() + ->owner_opr(); + } + } +}; + template struct MakeLocalShareCaller2 { template @@ -641,6 +683,8 @@ MGB_SEREG_OPR(TQT, 2); MGB_SEREG_OPR(TQTBackward, 3); MGB_SEREG_OPR(LSQ, 4); MGB_SEREG_OPR(LSQBackward, 5); +MGB_SEREG_OPR(LayerNorm, 0); +MGB_SEREG_OPR(LayerNormBackward, 0); } // namespace opr } // namespace mgb diff --git a/src/opr/impl/dnn/layer_norm.cpp b/src/opr/impl/dnn/layer_norm.cpp new file mode 100644 index 000000000..3506111a7 --- /dev/null +++ b/src/opr/impl/dnn/layer_norm.cpp @@ -0,0 +1,248 @@ +/** + * \file src/opr/impl/dnn/layer_norm.cpp + * MegEngine is Licensed under the Apache License, Version 2.0 (the "License") + * + * Copyright (c) 2014-2021 Megvii Inc. All rights reserved. + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or + * implied. 
diff --git a/src/opr/impl/dnn/layer_norm.cpp b/src/opr/impl/dnn/layer_norm.cpp
new file mode 100644
index 000000000..3506111a7
--- /dev/null
+++ b/src/opr/impl/dnn/layer_norm.cpp
@@ -0,0 +1,248 @@
+/**
+ * \file src/opr/impl/dnn/layer_norm.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/opr/dnn/layer_norm.h"
+
+#include "megbrain/graph/grad_impl.h"
+#include "megbrain/opr/internal/out_shape_by_sym_var.h"
+#include "megbrain/opr/utility.h"
+
+#include "../internal/megdnn_opr_wrapper.inl"
+
+using namespace mgb;
+using namespace opr;
+
+/* ==================== LayerNormForward ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormForward);
+
+LayerNormForward::LayerNormForward(
+        VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super{data->owner_graph(), config, "layer_norm", {data, weight, bias}} {
+    init_megdnn_opr(*this, param);
+
+    add_input({data, weight, bias});
+    output(0)->dtype(data->dtype());
+    output(1)->dtype(dtype::Float32());
+    output(2)->dtype(dtype::Float32());
+}
+
+LayerNormForward::LayerNormForward(
+        VarNode* data, const Param& param, const OperatorNodeConfig& config)
+        : Super{data->owner_graph(), config, "layer_norm", {data}} {
+    init_megdnn_opr(*this, param);
+
+    add_input({data});
+    output(0)->dtype(data->dtype());
+    output(1)->dtype(dtype::Float32());
+    output(2)->dtype(dtype::Float32());
+}
+
+SymbolVarArray LayerNormForward::make(
+        SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param,
+        const OperatorNodeConfig& config) {
+    auto outs = data.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<LayerNormForward>(
+                                data.node(), weight.node(), bias.node(), param, config))
+                        ->output();
+    SymbolVarArray ret;
+    for (auto&& out : outs) {
+        ret.emplace_back(out);
+    }
+    return ret;
+}
+
+SymbolVarArray LayerNormForward::make(
+        SymbolVar data, const Param& param, const OperatorNodeConfig& config) {
+    auto outs = data.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<LayerNormForward>(
+                                data.node(), param, config))
+                        ->output();
+    SymbolVarArray ret;
+    for (auto&& out : outs) {
+        ret.emplace_back(out);
+    }
+    return ret;
+}
+
+void LayerNormForward::get_output_var_shape(
+        const TensorShapeArray& inp_shape, TensorShapeArray& out_shape) const {
+    uint64_t normalized_dim = param().normalized_dim;
+    out_shape[0] = inp_shape[0];
+    TensorShape unnormalized_shape;
+    unnormalized_shape.ndim = inp_shape[0].ndim - normalized_dim;
+    for (size_t i = 0; i < unnormalized_shape.ndim; ++i) {
+        unnormalized_shape.shape[i] = inp_shape[0].shape[i];
+    }
+    out_shape[1] = unnormalized_shape;
+    out_shape[2] = unnormalized_shape;
+}
+
+size_t LayerNormForward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return 0;
+}
+
+void LayerNormForward::scn_do_execute() {
+    if (param().affine) {
+        megdnn_opr()->exec(
+                input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+                input(2)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+                output(1)->dev_tensor().as_megdnn(),
+                output(2)->dev_tensor().as_megdnn(), {});
+    } else {
+        megdnn_opr()->exec(
+                input(0)->dev_tensor().as_megdnn(), {}, {},
+                output(0)->dev_tensor().as_megdnn(),
+                output(1)->dev_tensor().as_megdnn(),
+                output(2)->dev_tensor().as_megdnn(), {});
+    }
+}
+
+#if MGB_ENABLE_GRAD
+MGB_IMPL_OPR_GRAD(LayerNormForward) {
+    auto p = opr.param();
+    SymbolVarArray grad;
+    VarNodeArray ret;
+    if (p.affine) {
+        mgb_assert(wrt_idx < 3, "wrt_idx %zu is out of range", wrt_idx);
+        grad = LayerNormBackward::make(
+                out_grad[0], opr.input(0), opr.input(1), opr.output(1), opr.output(2),
+                opr.param());
+    } else {
+        mgb_assert(wrt_idx < 1, "wrt_idx %zu is out of range", wrt_idx);
+        grad = LayerNormBackward::make(
+                out_grad[0], opr.input(0), opr.output(1), opr.output(2), opr.param());
+    }
+
+    uint32_t nr_ret = p.affine ? 3 : 1;
+    for (uint32_t i = 0; i < nr_ret; ++i) {
+        ret.push_back(grad[i].node());
+    }
+    return ret;
+}
+#endif
+
+/* ==================== LayerNormBackward ==================== */
+MGB_DYN_TYPE_OBJ_FINAL_IMPL(LayerNormBackward);
+
+LayerNormBackward::LayerNormBackward(
+        VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
+        const Param& param, const OperatorNodeConfig& config)
+        : Super({diff->owner_graph(),
+                 config,
+                 "layer_norm_backward",
+                 {diff, data, weight, mean, rstd}},
+                0, true) {
+    init_megdnn_opr(*this, param);
+    add_input({diff, data, weight, mean, rstd});
+}
+
+LayerNormBackward::LayerNormBackward(
+        VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd, const Param& param,
+        const OperatorNodeConfig& config)
+        : Super({diff->owner_graph(),
+                 config,
+                 "layer_norm_backward",
+                 {diff, data, mean, rstd}},
+                0, true) {
+    init_megdnn_opr(*this, param);
+    add_input({diff, data, mean, rstd});
+    auto mark_empty_var = [&](VarNode* var) {
+        var->add_flag(VarNode::Flag::ALLOW_EMPTY_SHAPE)
+                .add_flag(VarNode::Flag::VOLATILE_CONTENT);
+    };
+    mark_empty_var(output(1));
+    mark_empty_var(output(2));
+}
+
+SymbolVarArray LayerNormBackward::make(
+        SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
+        SymbolVar rstd, const Param& param, const OperatorNodeConfig& config) {
+    auto outs = diff.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<LayerNormBackward>(
+                                diff.node(), data.node(), weight.node(), mean.node(),
+                                rstd.node(), param, config))
+                        ->output();
+    SymbolVarArray ret;
+    for (auto&& out : outs) {
+        ret.emplace_back(out);
+    }
+    return ret;
+}
+
+SymbolVarArray LayerNormBackward::make(
+        SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
+        const Param& param, const OperatorNodeConfig& config) {
+    auto outs = diff.node()
+                        ->owner_graph()
+                        ->insert_opr(std::make_unique<LayerNormBackward>(
+                                diff.node(), data.node(), mean.node(), rstd.node(),
+                                param, config))
+                        ->output();
+    SymbolVarArray ret;
+    for (auto&& out : outs) {
+        ret.emplace_back(out);
+    }
+    return ret;
+}
+
+void LayerNormBackward::init_output_static_infer_desc() {
+    using namespace cg::static_infer;
+    auto&& mgr = owner_graph()->static_infer_manager();
+    mgr.register_shape_infer(output(0), ShapeInferDesc::make_identity(input(1)));
+    if (param().affine) {
+        mgr.register_shape_infer(output(1), ShapeInferDesc::make_identity(input(2)));
+        mgr.register_shape_infer(output(2), ShapeInferDesc::make_identity(input(2)));
+    } else {
+        TensorShape empty;
+        empty.ndim = 0;
+        mgr.register_shape_infer(output(1), ShapeInferDesc::make_const(empty));
+        mgr.register_shape_infer(output(2), ShapeInferDesc::make_const(empty));
+    }
+    this->init_output_static_infer_desc_workspace(false);
+}
+
+void LayerNormBackward::init_output_dtype() {
+    output(0)->dtype(input(1)->dtype());
+    output(1)->dtype(input(2)->dtype());
+    output(2)->dtype(input(2)->dtype());
+}
+
+size_t LayerNormBackward::get_workspace_size_bytes(
+        const TensorShapeArray& input_shapes,
+        const TensorShapeArray& output_shapes) const {
+    return 0;
+}
+
+void LayerNormBackward::scn_do_execute() {
+    if (param().affine) {
+        megdnn_opr()->exec(
+                input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+                input(2)->dev_tensor().as_megdnn(), input(3)->dev_tensor().as_megdnn(),
+                input(4)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+                output(1)->dev_tensor().as_megdnn(),
+                output(2)->dev_tensor().as_megdnn(), {});
+    } else {
+        megdnn_opr()->exec(
+                input(0)->dev_tensor().as_megdnn(), input(1)->dev_tensor().as_megdnn(),
+                {}, input(2)->dev_tensor().as_megdnn(),
+                input(3)->dev_tensor().as_megdnn(), output(0)->dev_tensor().as_megdnn(),
+                {}, {}, {});
+    }
+}
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
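`LayerNormBackward` consumes `diff` plus the saved `mean`/`rstd` (and the weight in the affine case) and produces `ddata`, `dweight`, `dbias`. The arithmetic it has to implement is the standard layer-norm backward; a NumPy sketch over a 2D `(rows, normalized_size)` view, with illustrative names:

```python
import numpy as np

def layer_norm_backward_ref(dy, x, weight, mean, rstd):
    """dy, x: (rows, N); weight: (N,); mean, rstd: (rows,). Returns (dx, dweight, dbias)."""
    x_hat = (x - mean[:, None]) * rstd[:, None]
    dy_hat = dy * weight                    # skip this scaling when affine=False
    a = dy_hat.mean(axis=1, keepdims=True)  # per-row mean of the upstream gradient
    b = (dy_hat * x_hat).mean(axis=1, keepdims=True)
    dx = (dy_hat - a - x_hat * b) * rstd[:, None]
    dweight = (dy * x_hat).sum(axis=0)      # reduced over the un-normalized rows
    dbias = dy.sum(axis=0)
    return dx, dweight, dbias
```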
diff --git a/src/opr/include/megbrain/opr/dnn/layer_norm.h b/src/opr/include/megbrain/opr/dnn/layer_norm.h
new file mode 100644
index 000000000..29712de00
--- /dev/null
+++ b/src/opr/include/megbrain/opr/dnn/layer_norm.h
@@ -0,0 +1,78 @@
+/**
+ * \file src/opr/include/megbrain/opr/dnn/layer_norm.h
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ */
+
+#pragma once
+
+#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
+#include "megdnn/oprs.h"
+
+namespace mgb {
+namespace opr {
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        LayerNormForward, intl::MegDNNOprWrapperFwd<megdnn::LayerNormForward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC LayerNormForward(
+            VarNode* data, VarNode* weight, VarNode* bias, const Param& param,
+            const OperatorNodeConfig& config);
+    MGE_WIN_DECLSPEC_FUC LayerNormForward(
+            VarNode* data, const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar data, SymbolVar weight, SymbolVar bias, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar data, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+
+private:
+    void get_output_var_shape(
+            const TensorShapeArray& inp_shape,
+            TensorShapeArray& out_shape) const override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+using LayerNorm = LayerNormForward;
+
+MGB_DEFINE_OPR_CLASS_WITH_EXPORT(
+        LayerNormBackward, intl::MegDNNOprWrapperBwd<megdnn::LayerNormBackward>) // {
+public:
+    MGE_WIN_DECLSPEC_FUC LayerNormBackward(
+            VarNode* diff, VarNode* data, VarNode* weight, VarNode* mean, VarNode* rstd,
+            const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC LayerNormBackward(
+            VarNode* diff, VarNode* data, VarNode* mean, VarNode* rstd,
+            const Param& param, const OperatorNodeConfig& config);
+
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar diff, SymbolVar data, SymbolVar weight, SymbolVar mean,
+            SymbolVar rstd, const Param& param = {},
+            const OperatorNodeConfig& config = {});
+    MGE_WIN_DECLSPEC_FUC static SymbolVarArray make(
+            SymbolVar diff, SymbolVar data, SymbolVar mean, SymbolVar rstd,
+            const Param& param = {}, const OperatorNodeConfig& config = {});
+
+private:
+    void init_output_static_infer_desc() override;
+    void init_output_dtype() override;
+    size_t get_workspace_size_bytes(
+            const TensorShapeArray& input_shapes,
+            const TensorShapeArray& output_shapes) const override;
+    void scn_do_execute() override;
+};
+
+} // namespace opr
+} // namespace mgb
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
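With the gradient rule registered through `MGB_IMPL_OPR_GRAD` in `layer_norm.cpp` above, gradients reach data, weight and bias through the ordinary autodiff path. A small sketch using `GradManager` (tensor names are illustrative):

```python
import numpy as np
import megengine as mge
import megengine.functional as F
from megengine.autodiff import GradManager

x = mge.tensor(np.random.randn(8, 32).astype("float32"))
w = mge.tensor(np.ones((32,), dtype="float32"))
b = mge.tensor(np.zeros((32,), dtype="float32"))

gm = GradManager()
gm.attach([x, w, b])
with gm:
    y = F.nn.layer_norm(x, (32,), True, w, b)
    loss = (y ** 2).mean()
    gm.backward(loss)  # populates x.grad, w.grad and b.grad via LayerNormBackward

print(x.grad.shape, w.grad.shape, b.grad.shape)
```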
diff --git a/src/opr/test/dnn/layer_norm.cpp b/src/opr/test/dnn/layer_norm.cpp
new file mode 100644
index 000000000..15db672c8
--- /dev/null
+++ b/src/opr/test/dnn/layer_norm.cpp
@@ -0,0 +1,108 @@
+/**
+ * \file src/opr/test/dnn/layer_norm.cpp
+ * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
+ *
+ * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ * implied.
+ */
+
+#include "megbrain/opr/dnn/layer_norm.h"
+#include "megbrain/comp_node_env.h"
+#include "megbrain/test/autocheck.h"
+#include "megbrain/test/helper.h"
+#include "megbrain/test/megdnn_helper.h"
+
+#include "megdnn/oprs.h"
+
+#include
+#include
+#include
+#include
+
+using namespace mgb;
+
+namespace {
+using Param = opr::LayerNormForward::Param;
+
+void run_forward(bool is_affine, size_t normalized_size) {
+    using Checker = AutoOprChecker<3, 3>;
+
+    Param param;
+    param.eps = 1e-5;
+    param.affine = is_affine;
+    param.normalized_dim = 1;
+    param.normalized_size = normalized_size;
+
+    auto make_graph = [&](const Checker::SymInpArray& inputs) -> Checker::SymOutArray {
+        auto out = opr::LayerNormForward::make(inputs[0], inputs[1], inputs[2], param);
+        return {out[0], out[1], out[2]};
+    };
+
+    auto fwd = [&](Checker::NumOutArray& dest, Checker::NumInpArray inp) {
+        auto opr =
+                MegDNNHandle::get(CompNodeEnv::from_comp_node(CompNode::default_cpu()))
+                        ->create_operator<megdnn::LayerNormForward>();
+        auto inp_shape = inp[0]->shape();
+        auto n_slices = inp_shape[0];
+        auto slice_len = inp_shape[1];
+
+        opr->param() = param;
+
+        dest[0].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize({n_slices, slice_len});
+        dest[1].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize({n_slices});
+        dest[2].dtype(dtype::Float32())
+                .comp_node(inp[0]->comp_node())
+                .resize({n_slices});
+        opr->exec(
+                inp[0]->as_megdnn(), inp[1]->as_megdnn(), inp[2]->as_megdnn(),
+                dest[0].as_megdnn(), dest[1].as_megdnn(), dest[2].as_megdnn(), {});
+    };
+
+    auto gen = [&](HostTensorND& src) {
+        HostTensorGenerator src_gen(0.f);
+        src = *src_gen(src.shape(), src.comp_node());
+    };
+
+    Checker::RunOptions option;
+    option.numdiff_max_err = 1e-4;
+    Checker checker{make_graph, fwd};
+
+    checker.set_input_generator(0, gen);
+    checker.set_input_generator(1, gen);
+    checker.set_input_generator(2, gen);
+    checker.set_input_allow_grad(0, false);
+    checker.set_input_allow_grad(1, false);
+    checker.set_input_allow_grad(2, false);
+    checker.set_output_allow_grad(0, false);
+    checker.set_output_allow_grad(1, false);
+    checker.set_output_allow_grad(2, false);
+
+    checker.run({TensorShape{normalized_size, normalized_size},
+                 TensorShape{normalized_size}, TensorShape{normalized_size}},
+                option)
+            .run({TensorShape{normalized_size, normalized_size},
+                  TensorShape{normalized_size}, TensorShape{normalized_size}},
+                 option)
+            .run({TensorShape{normalized_size, normalized_size},
+                  TensorShape{normalized_size}, TensorShape{normalized_size}},
+                 option);
+}
+
+TEST(TestOprDNN, LayerNormForwardAffine) {
+    REQUIRE_GPU(1);
+    run_forward(true, 1);
+    run_forward(true, 16);
+    run_forward(true, 17);
+}
+
+} // anonymous namespace
+
+// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}
diff --git a/src/serialization/impl/schema.fbs b/src/serialization/impl/schema.fbs
index f91477e63..868948913 100644
--- a/src/serialization/impl/schema.fbs
+++ b/src/serialization/impl/schema.fbs
@@ -116,6 +116,7 @@ union OperatorParam {
     param.Padding = 82,
     param.ShuffleRNG = 83,
     param.CheckNonFinite = 84,
+    param.LayerNorm = 85,
 }
 
 table Operator {
--
GitLab
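Finally, the `AutoOprChecker` test above exercises the kernels through the graph API; the same numerical contract can also be spot-checked from Python against a NumPy reference, mirroring the reference computation of the deleted `test_layer_norm` (tolerances are illustrative):

```python
import numpy as np
import megengine as mge
import megengine.functional as F

x = np.random.randn(32, 64, 28, 28).astype("float32")
w = np.random.randn(28, 28).astype("float32")
b = np.random.randn(28, 28).astype("float32")
eps = 1e-5

out = F.nn.layer_norm(mge.tensor(x), (28, 28), True, mge.tensor(w), mge.tensor(b), eps=eps)
mean = x.mean(axis=(2, 3), keepdims=True)
var = x.var(axis=(2, 3), keepdims=True)
ref = (x - mean) / np.sqrt(var + eps) * w + b
np.testing.assert_allclose(out.numpy(), ref, rtol=1e-4, atol=1e-4)
```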