Unverified commit 44b81e63, authored by furnace, committed by GitHub

[Cherry-pick] Layer norm fp16 and Nvidia optimize (#29169 #29434 #29522 #29576) (#30110)

* Layer norm fp16 (#29169)

* add fp16 for layer_norm op

* revert layernorm api

* fix forward

* fix forward

* fix backward for layernorm with fp16

* fix unit test for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16

* 1. revert to PADDLE_ENFORCE_NOT_NULL, 2. change static_cast<float> to static_cast<U>

* fix with_mkldnn compile error for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16
Co-authored-by: zhiqiu <chenqiuliang@baidu.com>

* fix layer_norm accuracy (#29434)

* Layernorm opt (#29522)

* layernorm fw opt

* layernorm bw opt

* fix typo, test=develop

* remove const dim3 for windows CI compatibility

* merge develop
Co-authored-by: zlsh80826 <zlsh80826@gmail.com>

* Fix compile problem when cuda_arch < 6000 (#29576)

* fix compile problem when cuda_arch < 6000

* refine code

* refine code
Co-authored-by: zhiqiu <chenqiuliang@baidu.com>
Co-authored-by: zlsh80826 <zlsh80826@gmail.com>
Parent: cb71fea0
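The core numerical change in this patch is that, for float16 inputs, the layer_norm kernels accumulate the statistics and the normalization in float32 (the parameter type `U`) and only cast the result back to float16. A NumPy sketch of that reference behavior (illustration only, not part of the patch; the helper name is hypothetical):

```python
import numpy as np

def layer_norm_fp16_ref(x_fp16, scale_fp32, bias_fp32, eps=1e-5):
    # Accumulate in float32 (U), store the normalized result back in float16 (T).
    x = x_fp16.astype(np.float32)
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    invvar = 1.0 / np.sqrt(var + eps)
    y = scale_fp32 * (x - mean) * invvar + bias_fp32
    return y.astype(np.float16), mean.squeeze(-1), var.squeeze(-1)
```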
@@ -14,6 +14,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/layer_norm_op.h"
 #include <memory>
+#include <string>
 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
-      const framework::ExecutionContext &ctx) const {
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
+    // By default, the type of the scale, bias, mean,
+    // and var tensors should both be float. (For float or float16 input tensor)
+    // or double (For double input tensor).
+    auto ln_param_type = framework::proto::VarType::FP32;
+    if (input_data_type == framework::proto::VarType::FP64) {
+      ln_param_type = framework::proto::VarType::FP64;
+    }
+    if (ctx.HasInput("Scale")) {
+      PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
+                        platform::errors::InvalidArgument(
+                            "Scale input should be of float type"));
+    }
+    if (ctx.HasInput("Bias")) {
+      PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
+                        platform::errors::InvalidArgument(
+                            "Bias input should be of float type"));
+    }
+
     framework::LibraryType library = framework::LibraryType::kPlain;
     framework::DataLayout layout = framework::DataLayout::kAnyLayout;
@@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel {
     }
 #endif
-    return framework::OpKernelType(
-        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
-        layout, library);
+    return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
+                                   library);
   }
 };
@@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
     }
     PADDLE_ENFORCE_NOT_NULL(
         t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
-    return framework::OpKernelType(t->type(), ctx.GetPlace());
+
+    framework::LibraryType library = framework::LibraryType::kPlain;
+    framework::DataLayout layout = framework::DataLayout::kAnyLayout;
+
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
+        layout, library);
   }
 };
...
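The dtype rule enforced by `GetExpectedKernelType` above: the kernel runs in the dtype of `X`, while `Scale`/`Bias` (and the `Mean`/`Variance` outputs) use the parameter type, which is float32 unless the input is float64. Summarized as a small mapping (illustration only; the name is not from the patch):

```python
# Parameter ("U") dtype used for Scale/Bias/Mean/Variance,
# keyed by the input ("T") dtype.
LN_PARAM_DTYPE = {
    'float16': 'float32',
    'float32': 'float32',
    'float64': 'float64',
}
```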
@@ -15,12 +15,22 @@ limitations under the License. */
 #include <cub/cub.cuh>
 #include <memory>
 #include <vector>
 #include "paddle/fluid/framework/ddim.h"
 #include "paddle/fluid/operators/layer_norm_op.h"
+#include "paddle/fluid/platform/cudnn_helper.h"
+#include "paddle/fluid/platform/float16.h"
 namespace paddle {
 namespace operators {
+
+using Tensor = framework::Tensor;
+using DataLayout = framework::DataLayout;
+template <typename T>
+using CudnnDataType = platform::CudnnDataType<T>;
+template <typename T>
+using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
+
 inline static int GetDesiredBlockDim(int block_dim) {
   const int kMaxBlockDim = 512;
   return block_dim >= kMaxBlockDim
@@ -97,59 +107,350 @@ struct PairForLayerNormAddFunctor {
   }
 };
 
-template <typename T, int BlockDim>
-__global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
-                                 T *y, T *mean, T *var, float epsilon,
+template <typename T>
+__inline__ __device__ T rsqrt(const T val) {
+  return static_cast<T>(1) / sqrt(val);
+}
+
+template <>
+__inline__ __device__ float rsqrt(const float val) {
+  return rsqrtf(val);
+}
+
+template <>
+__inline__ __device__ double rsqrt(const double val) {
+  return rsqrt(val);
+}
+
+#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
+template <>
+__inline__ __device__ half rsqrt(const half val) {
+  return hrsqrt(val);
+}
+#endif
+
+template <typename T, typename U, int BlockDim>
+__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
+                                 T *y, U *mean, U *var, float epsilon,
                                  int feature_size) {
-  using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
+  using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
   __shared__ typename BlockReduce::TempStorage temp_storage;
+  __shared__ U mean_share;
+  __shared__ U var_share;
 
   int beg_idx = blockIdx.x * feature_size + threadIdx.x;
   int end_idx = (blockIdx.x + 1) * feature_size;
 
   // Step 1: Reduce to calculate mean and var
-  double mean_val = 0;
-  double var_val = 0;
+  U mean_val = 0;
+  U var_val = 0;
   for (int i = beg_idx; i < end_idx; i += BlockDim) {
-    T tmp = x[i];
+    U tmp = static_cast<U>(x[i]);
     mean_val += tmp;
     var_val += (tmp * tmp);
   }
   auto pair = BlockReduce(temp_storage)
-                  .Reduce(PairForLayerNorm<double>(mean_val, var_val),
-                          PairForLayerNormAddFunctor<double>());
+                  .Reduce(PairForLayerNorm<U>(mean_val, var_val),
+                          PairForLayerNormAddFunctor<U>());
   if (threadIdx.x == 0) {
     auto tmp = pair.first_ / feature_size;
-    mean[blockIdx.x] = static_cast<T>(tmp);
-    var[blockIdx.x] = static_cast<T>(pair.second_ / feature_size - tmp * tmp);
+    mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
+    var[blockIdx.x] = var_share =
+        static_cast<U>(pair.second_ / feature_size - tmp * tmp);
   }
   __syncthreads();
-  mean_val = mean[blockIdx.x];
-  var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
+
+  mean_val = mean_share;
+  U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));
 
   // Step 2: Calculate y
   if (scale != nullptr) {
     if (bias != nullptr) {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
-        y[i] = scale[j] * (x[i] - mean_val) / var_val + bias[j];
+        y[i] = static_cast<T>(
+            scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
       }
     } else {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
-        y[i] = scale[j] * (x[i] - mean_val) / var_val;
+        y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
+                              invvar);
       }
     }
   } else {  // scale == nullptr
     if (bias != nullptr) {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
-        y[i] = (x[i] - mean_val) / var_val + bias[j];
+        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
+                              bias[j]);
       }
     } else {
       for (int i = beg_idx, j = threadIdx.x; i < end_idx;
            i += BlockDim, j += BlockDim) {
-        y[i] = (x[i] - mean_val) / var_val;
+        y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
       }
     }
   }
 }
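The forward kernel above computes the row statistics in a single pass by block-reducing the pair (sum(x), sum(x*x)) and then deriving mean and variance, rather than making two passes over the row. In NumPy terms (illustration only; the helper name is hypothetical):

```python
import numpy as np

def row_stats_single_pass(row_fp16):
    x = row_fp16.astype(np.float32)   # accumulate in the parameter type U
    s1, s2 = x.sum(), (x * x).sum()   # the pair reduced by cub::BlockReduce
    mean = s1 / x.size
    var = s2 / x.size - mean * mean   # E[x^2] - E[x]^2
    return mean, var
```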
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ var,
const float epsilon) {
const int i1 = i1_block + thr_load_row_off;
if (i1 >= i1_end) return;
U curr_mean = mean[i1];
U curr_invvar = rsqrt<U>(var[i1] + epsilon);
for (int k = 0; k < VPT; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1 * n2 + i2;
const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2, const U *__restrict__ mean, const U *__restrict__ var,
float epsilon, U *part_grad_gamma, U *part_grad_beta) {
// VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX
// -> blockDim.x
// template for compile time optimizations
constexpr int row_stride = BDIMX + 1;
const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
const int thr_load_row_off =
(threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;
constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
? BDIMX * BDIMY
: 2 * VPTX * BDIMY * row_stride;
__shared__ U buf[shared_cap];
U *warp_buf1 = reinterpret_cast<U *>(buf);
U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;
for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
buf[idx] = U(0);
}
__syncthreads();
for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
i1_block += VPTX * BDIMY * gridDim.y) {
cuLoadAddStridedInputs<T, U, VPTX>(
i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < VPTX; ++k) {
int row1 = threadIdx.y + k * VPTX;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
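LayerNormBackwardPartGradGammaBeta above accumulates per-block partial sums of the scale and bias gradients. Numerically it is computing the following (a NumPy reference, illustration only; the helper name is hypothetical):

```python
import numpy as np

def grad_gamma_beta_ref(dout, x, mean, var, eps=1e-5):
    # dout, x: (n1, n2); mean, var: (n1,)
    invvar = 1.0 / np.sqrt(var + eps)             # (n1,)
    xhat = (x - mean[:, None]) * invvar[:, None]  # normalized input
    grad_gamma = (dout * xhat).sum(axis=0)        # (n2,)
    grad_beta = dout.sum(axis=0)                  # (n2,)
    return grad_gamma, grad_beta
```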
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
// const int n1, const int n2, T* grad_gamma, T* grad_beta) {
const int n1, const int n2, U *grad_gamma, U *grad_beta) {
// sum partial gradients for gamma and beta
__shared__ U buf[BDIMX * BDIMY];
int i2 = blockIdx.x * BDIMX + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / BDIMY;
U sum_gamma = U(0);
U sum_beta = U(0);
const U *part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U *part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
constexpr int nbsize3 = BDIMX * BDIMY / 2;
for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
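The kernel above is the second stage of a two-stage reduction: the first kernel produces `part_grad_gamma`/`part_grad_beta` of shape `(part_size, n2)`, and this one collapses the leading dimension. In NumPy terms (illustration only; the helper name is hypothetical):

```python
def sum_partial_grads(part_grad_gamma, part_grad_beta):
    # part_grad_*: (part_size, n2) partial sums produced by the first kernel
    return part_grad_gamma.sum(axis=0), part_grad_beta.sum(axis=0)
```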
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2,
// const U* __restrict__ mean, const U* __restrict__ var, const float
// epsilon, const T* gamma,
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = rsqrt<U>(var[i1] + epsilon);
const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2;
constexpr int numx = BDIMX * BDIMY;
const int thrx = threadIdx.x + threadIdx.y * BDIMX;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
sum_loss1 +=
__shfl_xor_sync(0xffffffff, sum_loss1, mask,
warpSize); // WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 +=
__shfl_xor_sync(0xffffffff, sum_loss2, mask,
warpSize); // WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (BDIMY > 1) {
__shared__ U buf[BDIMX * BDIMY];
for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T *k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
} }
} }
} }
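LayerNormBackwardComputeGradInput above evaluates the standard layer-norm input gradient: it first reduces sum(dy*gamma) and sum(dy*gamma*xhat) over each row, then combines them. A NumPy reference of the same formula (illustration only; the helper name is hypothetical):

```python
import numpy as np

def grad_input_ref(dout, x, mean, var, gamma=None, eps=1e-5):
    # dout, x: (n1, n2); mean, var: (n1,); gamma: (n2,) or None
    n2 = x.shape[1]
    invvar = 1.0 / np.sqrt(var + eps)[:, None]
    xhat = (x - mean[:, None]) * invvar
    g = dout * gamma if gamma is not None else dout
    sum_loss1 = g.sum(axis=1, keepdims=True)           # sum of dy * gamma
    sum_loss2 = (g * xhat).sum(axis=1, keepdims=True)  # sum of dy * gamma * xhat
    return (invvar / n2) * (n2 * g - sum_loss1 - xhat * sum_loss2)
```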
@@ -157,35 +458,37 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
// Make sure that d_scale != nullptr && d_bias != nullptr // Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr // Since d_scale != nullptr, scale would not be nullptr
template <typename T, int BlockDim, bool HasDx> template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y, __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
T *d_scale, T *d_bias, T *d_x, U *d_scale, U *d_bias, T *d_x,
const T *mean, const T *var, const U *mean, const U *var,
const T *scale, float epsilon, const U *scale, float epsilon,
int batch_size, int feature_size, int batch_size, int feature_size,
int col_offset) { int col_offset) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>; using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset); int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
int end_idx = batch_size * feature_size + (blockIdx.x + col_offset); int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
int stride = BlockDim * feature_size; int stride = BlockDim * feature_size;
T d_scale_partial = 0, d_bias_partial = 0; U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += stride) { for (int i = beg_idx; i < end_idx; i += stride) {
int row_idx = i / feature_size; int row_idx = i / feature_size;
auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon)); auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val; d_scale_partial += static_cast<U>(d_y[i]) *
d_bias_partial += d_y[i]; (static_cast<U>(x[i]) - mean[row_idx]) / var_val;
d_bias_partial += static_cast<U>(d_y[i]);
if (HasDx) { if (HasDx) {
d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val; d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
} }
} }
auto pair = BlockReduce(temp_storage) auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial), .Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
PairForLayerNormAddFunctor<T>()); PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
d_scale[blockIdx.x + col_offset] = pair.first_; d_scale[blockIdx.x + col_offset] = pair.first_;
@@ -196,32 +499,36 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
// Make sure that there is only one true expression: d_scale != nullptr // Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr // or d_bias != nullptr
// Notice: scale may be nullptr // Notice: scale may be nullptr
template <typename T, int BlockDim, bool HasDx, bool HasDScale> template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias( __global__ void LayerNormBackwardGradientScaleOrBias(
const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean, const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
const T *var, const T *scale, float epsilon, int batch_size, const U *var, const U *scale, float epsilon, int batch_size,
int feature_size, int col_offset) { int feature_size, int col_offset) {
using BlockReduce = cub::BlockReduce<T, BlockDim>; using BlockReduce = cub::BlockReduce<U, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset; int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
int end_idx = batch_size * feature_size + blockIdx.x + col_offset; int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
int stride = BlockDim * feature_size; int stride = BlockDim * feature_size;
T d_scale_or_d_bias_partial = 0; U d_scale_or_d_bias_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += stride) { for (int i = beg_idx; i < end_idx; i += stride) {
int row_idx = i / feature_size; int row_idx = i / feature_size;
auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon)); auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
if (HasDScale) { if (HasDScale) {
d_scale_or_d_bias_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val; d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
(static_cast<U>(x[i]) - mean[row_idx]) /
var_val;
} else { // d_bias != nullptr } else { // d_bias != nullptr
d_scale_or_d_bias_partial += d_y[i]; d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
} }
if (HasDx) { if (HasDx) {
if (scale != nullptr) { if (scale != nullptr) {
d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val; d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
} else { } else {
d_x[i] = d_y[i] / var_val; d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
} }
} }
} }
@@ -238,121 +545,138 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
} }
} }
template <typename T, int BlockDim> template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x, __global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
const T *mean, const U *mean,
const T *var, const U *var,
float epsilon, float epsilon,
int feature_size) { int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>; using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T d_x_reduce_tmp[2]; __shared__ U d_x_reduce_tmp[2];
int beg_idx = blockIdx.x * feature_size + threadIdx.x; int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size; int end_idx = (blockIdx.x + 1) * feature_size;
T block_mean = mean[blockIdx.x]; U block_mean = mean[blockIdx.x];
T block_var = var[blockIdx.x]; U block_var = var[blockIdx.x];
T d_x_mean_partial = 0, d_x_var_partial = 0; U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += BlockDim) { for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x_mean_partial += d_x[i]; d_x_mean_partial += static_cast<U>(d_x[i]);
d_x_var_partial += d_x[i] * (x[i] - block_mean); d_x_var_partial +=
static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
} }
auto pair = auto pair =
BlockReduce(temp_storage) BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial), .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<T>()); PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
d_x_reduce_tmp[0] = pair.first_ / feature_size; d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon)); d_x_reduce_tmp[1] =
static_cast<float>(pair.second_) /
(feature_size * (static_cast<float>(block_var) + epsilon));
} }
__syncthreads(); __syncthreads();
d_x_mean_partial = d_x_reduce_tmp[0]; d_x_mean_partial = d_x_reduce_tmp[0];
d_x_var_partial = d_x_reduce_tmp[1]; d_x_var_partial = d_x_reduce_tmp[1];
for (int i = beg_idx; i < end_idx; i += BlockDim) { for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x[i] -= d_x_mean_partial; d_x[i] -= static_cast<T>(d_x_mean_partial);
d_x[i] -= (x[i] - block_mean) * d_x_var_partial; d_x[i] -=
static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
} }
} }
// Here, we only calculate d_x // Here, we only calculate d_x
template <typename T, int BlockDim> template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y, __global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
T *d_x, const T *mean, T *d_x, const U *mean,
const T *var, const T *scale, const U *var, const U *scale,
float epsilon, float epsilon,
int feature_size) { int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>; using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage; __shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T d_x_reduce_tmp[2]; __shared__ U d_x_reduce_tmp[2];
int beg_idx = blockIdx.x * feature_size + threadIdx.x; int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size; int end_idx = (blockIdx.x + 1) * feature_size;
T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x]; U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
T d_x_mean_partial = 0, d_x_var_partial = 0; U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += BlockDim) { for (int i = beg_idx; i < end_idx; i += BlockDim) {
auto var_val = static_cast<T>(real_sqrt(block_var + epsilon)); auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
if (scale != nullptr) { if (scale != nullptr) {
int col_idx = i % feature_size; int col_idx = i % feature_size;
d_x[i] = d_y[i] * scale[col_idx] / var_val; d_x[i] =
static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
} else { } else {
d_x[i] = d_y[i] / var_val; d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
} }
d_x_mean_partial += d_x[i]; d_x_mean_partial += static_cast<U>(d_x[i]);
d_x_var_partial += d_x[i] * (x[i] - block_mean); d_x_var_partial +=
static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
} }
auto pair = auto pair =
BlockReduce(temp_storage) BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial), .Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<T>()); PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) { if (threadIdx.x == 0) {
d_x_reduce_tmp[0] = pair.first_ / feature_size; d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon)); d_x_reduce_tmp[1] =
static_cast<float>(pair.second_) /
(feature_size * (static_cast<float>(block_var) + epsilon));
} }
__syncthreads(); __syncthreads();
d_x_mean_partial = d_x_reduce_tmp[0]; d_x_mean_partial = d_x_reduce_tmp[0];
d_x_var_partial = d_x_reduce_tmp[1]; d_x_var_partial = d_x_reduce_tmp[1];
for (int i = beg_idx; i < end_idx; i += BlockDim) { for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x[i] -= d_x_mean_partial; d_x[i] -= static_cast<T>(d_x_mean_partial);
d_x[i] -= (x[i] - block_mean) * d_x_var_partial; d_x[i] -=
static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
} }
} }
template <typename T> template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne( __global__ void LayerNormBackwardWhenBatchSizeIsOne(
const T *x, const T *d_y, T *d_x, T *d_scale, T *d_bias, const T *mean, const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
const T *var, const T *scale, float epsilon, int feature_size) { const U *var, const U *scale, float epsilon, int feature_size) {
int idx = threadIdx.x + blockIdx.x * blockDim.x; int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < feature_size) { if (idx < feature_size) {
auto var_val = static_cast<T>(real_sqrt(var[idx] + epsilon)); auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
if (d_x != nullptr) { if (d_x != nullptr) {
if (d_scale == nullptr) { if (d_scale == nullptr) {
d_x[idx] = d_y[idx] / var_val; d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
} else { } else {
d_x[idx] = d_y[idx] * scale[idx] / var_val; d_x[idx] =
static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
} }
} }
if (d_scale != nullptr) { if (d_scale != nullptr) {
d_scale[idx] = d_y[idx] * (x[idx] - mean[idx]) / var_val; d_scale[idx] = static_cast<U>(d_y[idx]) *
(static_cast<U>(x[idx]) - mean[idx]) / var_val;
} }
if (d_bias != nullptr) d_bias[idx] = d_y[idx]; if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
} }
} }
template <typename T> template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const T *scale, static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
const T *mean, const T *var, T *d_x, T *d_scale, const U *mean, const U *var, T *d_x, U *d_scale,
T *d_bias, float epsilon, int batch_size, U *d_bias, float epsilon, int batch_size,
int feature_size, cudaStream_t stream) { int feature_size,
const framework::ExecutionContext &ctx) {
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
const int kMaxBlockDim = 512; const int kMaxBlockDim = 512;
const int kMaxBlockNum = 128; const int kMaxBlockNum = 128;
int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) | int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
@@ -362,14 +686,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
if (batch_size == 1) { if (batch_size == 1) {
LayerNormBackwardWhenBatchSizeIsOne< LayerNormBackwardWhenBatchSizeIsOne<
T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0, T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon, 0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
feature_size); epsilon, feature_size);
if (d_x != nullptr) { if (d_x != nullptr) {
switch (GetDesiredBlockDim(feature_size)) { switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX< FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<1, kBlockDim, 0, stream>>>( T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size)); x, d_x, mean, var, epsilon, feature_size));
} }
} }
@@ -383,7 +707,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum, feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias< LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, false, T, U, kBlockDim, false,
false><<<block_num, kBlockDim, 0, stream>>>( false><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset)); batch_size, feature_size, col_offset));
@@ -394,7 +718,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum, feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias< LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, false, true><<<block_num, kBlockDim, 0, stream>>>( T, U, kBlockDim, false,
true><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset)); batch_size, feature_size, col_offset));
} }
@@ -404,7 +729,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum, feature_size, kMaxBlockNum,
LayerNormBackwardGradientAll< LayerNormBackwardGradientAll<
T, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>( T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset)); batch_size, feature_size, col_offset));
} }
@@ -413,7 +738,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
switch (GetDesiredBlockDim(feature_size)) { switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
LayerNormBackwardGradientOnlyDX< LayerNormBackwardGradientOnlyDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>( T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_y, d_x, mean, var, scale, epsilon, feature_size)); x, d_y, d_x, mean, var, scale, epsilon, feature_size));
} }
break; break;
@@ -422,14 +747,15 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum, feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias< LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, true, false><<<block_num, kBlockDim, 0, stream>>>( T, U, kBlockDim, true,
false><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset)); batch_size, feature_size, col_offset));
} }
switch (GetDesiredBlockDim(feature_size)) { switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX< LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>( T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size)); x, d_x, mean, var, epsilon, feature_size));
} }
break; break;
@@ -438,33 +764,57 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum, feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias< LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, true, true><<<block_num, kBlockDim, 0, stream>>>( T, U, kBlockDim, true,
true><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset)); batch_size, feature_size, col_offset));
} }
switch (GetDesiredBlockDim(feature_size)) { switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX< LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>( T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size)); x, d_x, mean, var, epsilon, feature_size));
} }
break; break;
case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
switch (block_dim) { {
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE( constexpr int VPT = 4;
feature_size, kMaxBlockNum, constexpr int BDIMX2 = 32;
LayerNormBackwardGradientAll< constexpr int BDIMY2 = 4;
T, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>( dim3 threads2(BDIMX2, BDIMY2, 1);
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon, constexpr int part_size = BDIMY2 * VPT;
batch_size, feature_size, col_offset)); const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);
}
switch (GetDesiredBlockDim(feature_size)) { auto part_grad_gamma_ptr =
FIXED_BLOCK_DIM_CASE( memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
LayerNormBackwardPostProcessToCalculateDX< auto part_grad_beta_ptr =
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>( memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
x, d_x, mean, var, epsilon, feature_size)); U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
} U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
VPT><<<blocks2, threads2, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
part_grad_beta); // compute part_grad_gamma, beta
constexpr int BDIMX3 = 32;
constexpr int BDIMY3 = 8;
dim3 threads3(BDIMX3, BDIMY3, 1);
const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
LayerNormBackwardSumGradGammaBeta<
T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
d_scale, d_bias);
constexpr int BDIMX1 = 32;
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
const dim3 blocks1(1, batch_size, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break; break;
}
default: default:
break; break;
} }
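The dispatch above encodes which gradients were requested as a 3-bit flag, so `case 7` (all of d_x, d_scale, d_bias are non-null, per the comment in the code) takes the new three-kernel path: PartGradGammaBeta, then SumGradGammaBeta, then ComputeGradInput. A small Python illustration of the flag encoding (not part of the patch; the exact bit positions of d_scale and d_bias are assumed from the case labels):

```python
def gradient_flag(has_dx, has_dscale, has_dbias):
    # bit 2: d_x, bit 1: d_scale, bit 0: d_bias
    return (int(has_dx) << 2) | (int(has_dscale) << 1) | int(has_dbias)

assert gradient_flag(True, True, True) == 7  # the fully fused backward path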
@@ -483,7 +833,7 @@ void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
int feature_size = static_cast<int>(matrix_dim[1]); int feature_size = static_cast<int>(matrix_dim[1]);
switch (GetDesiredBlockDim(feature_size)) { switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>( LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size)); input, scale, bias, output, mean, variance, eps, feature_size));
default: default:
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
@@ -498,6 +848,7 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale"); auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias"); auto *bias = ctx.Input<Tensor>("Bias");
@@ -511,10 +862,10 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
const auto x_dims = x->dims(); const auto x_dims = x->dims();
auto *x_data = x->data<T>(); auto *x_data = x->data<T>();
auto *y_data = y->mutable_data<T>(ctx.GetPlace()); auto *y_data = y->mutable_data<T>(ctx.GetPlace());
auto *mean_data = mean->mutable_data<T>(ctx.GetPlace()); auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
auto *var_data = var->mutable_data<T>(ctx.GetPlace()); auto *var_data = var->mutable_data<U>(ctx.GetPlace());
auto *scale_data = (scale == nullptr ? nullptr : scale->data<T>()); auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *bias_data = (bias == nullptr ? nullptr : bias->data<T>()); auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis); auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int batch_size = static_cast<int>(matrix_dim[0]); int batch_size = static_cast<int>(matrix_dim[0]);
@@ -524,7 +875,8 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
switch (GetDesiredBlockDim(feature_size)) { switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE( FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>( LayerNormForward<T, U,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x_data, scale_data, bias_data, y_data, mean_data, var_data, x_data, scale_data, bias_data, y_data, mean_data, var_data,
epsilon, feature_size)); epsilon, feature_size));
default: default:
@@ -540,6 +892,7 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> { : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
// d_x, d_scale, d_bias may be nullptr // d_x, d_scale, d_bias may be nullptr
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X")); auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
@@ -554,14 +907,15 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
auto *x_data = x->data<T>(); auto *x_data = x->data<T>();
auto *d_y_data = d_y->data<T>(); auto *d_y_data = d_y->data<T>();
auto *mean_data = mean->data<T>(); auto *mean_data = mean->data<U>();
auto *var_data = var->data<T>(); auto *var_data = var->data<U>();
auto *scale_data = (scale == nullptr ? nullptr : scale->data<T>());
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *d_scale_data = auto *d_scale_data =
(d_scale == nullptr ? nullptr (d_scale == nullptr ? nullptr
: d_scale->mutable_data<T>(ctx.GetPlace())); : d_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_bias_data = auto *d_bias_data =
(d_bias == nullptr ? nullptr : d_bias->mutable_data<T>(ctx.GetPlace())); (d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_x_data = auto *d_x_data =
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace())); (d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));
@@ -571,14 +925,14 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
int batch_size = static_cast<int>(matrix_dim[0]); int batch_size = static_cast<int>(matrix_dim[0]);
int feature_size = static_cast<int>(matrix_dim[1]); int feature_size = static_cast<int>(matrix_dim[1]);
auto stream = ctx.cuda_device_context().stream(); LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
LayerNormBackward<T>(x_data, d_y_data, scale_data, mean_data, var_data,
d_x_data, d_scale_data, d_bias_data, epsilon, d_x_data, d_scale_data, d_bias_data, epsilon,
batch_size, feature_size, stream); batch_size, feature_size, ctx);
} }
}; };
template class LayerNormDirectCUDAFunctor<float>; template class LayerNormDirectCUDAFunctor<float>;
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE #undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE #undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE #undef FIXED_BLOCK_DIM_CASE_BASE
@@ -587,11 +941,15 @@ template class LayerNormDirectCUDAFunctor<float>;
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
layer_norm, layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>, ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>); ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
layer_norm_grad, layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>, ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>); ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
@@ -109,9 +109,11 @@ gray_list = {
     'elementwise_mod',
     'elementwise_floordiv',
     'batch_norm',
+    'layer_norm',
     'tanh',
     'sigmoid',
     'lookup_table',
+    'lookup_table_v2',
     'top_k',
     'pool2d',
     'pool3d',
@@ -123,6 +125,7 @@ gray_list = {
     'flatten2',
     'stack',
     'unstack',
+    'uniform_random',
     'uniform_random_batch_size_like',
     'gaussian_random',
     'gaussian_random_batch_size_like',
@@ -192,7 +195,6 @@ unsupported_fp16_list = {
     'sequence_concat',
     'sequence_slice',
     'data_norm',
-    'layer_norm',
     'group_norm',
     'spectral_norm',
     'depthwise_conv2d_transpose',
...
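Moving layer_norm from the unsupported list to the gray list means the AMP pass no longer forces it to fp32; like batch_norm, it follows the precision of its inputs. A rough sketch of that decision logic (illustration only, not the actual Paddle implementation):

```python
def amp_op_dtype(op_type, input_dtypes, white_list, black_list, gray_list):
    if op_type in white_list:
        return 'float16'   # always cast to fp16
    if op_type in black_list:
        return 'float32'   # always keep fp32
    if op_type in gray_list:  # e.g. 'layer_norm' after this patch
        # follow the inputs: run in fp16 only when every input is fp16
        return 'float16' if all(d == 'float16' for d in input_dtypes) else 'float32'
    return 'float32'
```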
@@ -76,7 +76,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     for in_name in op.input_names:
         if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
-                'batch_norm', 'fused_bn_add_activation'
+                'batch_norm', 'fused_bn_add_activation', 'layer_norm'
         ]:
             if in_name not in {'X', 'Z'}:
                 continue
@@ -110,7 +110,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
     op._set_attr('in_dtype', dest_dtype)
     if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
         for out_name in op.output_names:
-            if op.type in ['batch_norm', 'fused_bn_add_activation'
+            if op.type in [
+                    'batch_norm', 'fused_bn_add_activation', 'layer_norm'
             ] and out_name != 'Y':
                 continue
             for out_var_name in op.output(out_name):
...
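The change above treats layer_norm like batch_norm when inserting casts: only the 'X' input and the 'Y' output are converted to fp16, while Scale/Bias/Mean/Variance keep their fp32 dtype. A condensed restatement of that rule (illustration only; the helper name is hypothetical):

```python
def slot_gets_fp16_cast(op_type, slot_name, is_input):
    if op_type in ('batch_norm', 'fused_bn_add_activation', 'layer_norm'):
        # only the data tensors are cast; parameters/statistics stay fp32
        return slot_name in ({'X', 'Z'} if is_input else {'Y'})
    return True
```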
@@ -15,6 +15,7 @@
 from __future__ import print_function
 import unittest
 import numpy as np
+import paddle
 from operator import mul
 import paddle.fluid.core as core
@@ -310,6 +311,8 @@ class TestLayerNormAPI(unittest.TestCase):
 class TestDygraphLayerNormAPIError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
+            paddle.enable_static()
+
             layer_norm = fluid.LayerNorm([32, 32])
             # the input of LayerNorm must be Variable.
             x1 = np.random.random((3, 32, 32)).astype('float32')
...
@@ -299,7 +299,8 @@ def layer_norm(x,
                                        'begin_norm_axis', begin_norm_axis)
         return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
 
-    check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'LayerNorm')
+    check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
+                             'LayerNorm')
 
     inputs = dict()
     inputs['X'] = [x]
@@ -311,11 +312,13 @@ def layer_norm(x,
     # create output
     helper = LayerHelper('layer_norm', **locals())
+    dtype = x.dtype
     mean_out = helper.create_variable_for_type_inference(
-        dtype=x.dtype, stop_gradient=True)
+        dtype=dtype, stop_gradient=True)
     variance_out = helper.create_variable_for_type_inference(
-        dtype=x.dtype, stop_gradient=True)
-    layer_norm_out = helper.create_variable_for_type_inference(x.dtype)
+        dtype=dtype, stop_gradient=True)
+    layer_norm_out = helper.create_variable_for_type_inference(dtype)
+
     helper.append_op(
         type="layer_norm",
...
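With the float16 dtype check added above, the CUDA fp16 kernels registered in this patch become reachable from the Python API. A hedged usage sketch (assumes a CUDA device; behavior on other devices may differ):

```python
import numpy as np
import paddle
import paddle.nn.functional as F

# float16 input; the statistics are accumulated in float32 inside the kernel
x = paddle.to_tensor(np.random.rand(4, 16, 64).astype('float16'))
y = F.layer_norm(x, normalized_shape=x.shape[1:])
print(y.dtype)  # paddle.float16
```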