未验证 提交 44b81e63 编写于 作者: F furnace 提交者: GitHub

[Cherry-pick] Layer norm fp16 and Nvidia optimize (#29169 #29434 #29522 #29576) (#30110)

* Layer norm fp16 (#29169)

* add fp16 for layer_norm op

* revert layernorm api

* fix forward

* fix forward

* fix backward for layernorm with fp16

* fix unit test for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16

* 1. revert to PADDLE_ENFORCE_NOT_NULL, 2. change static_cast<float> to static_cast<U>

* fix with_mkldnn compile error for layernorm with fp16

* fix with_mkldnn compile error for layernorm with fp16
Co-authored-by: Nzhiqiu <chenqiuliang@baidu.com>

* fix layer_norm accuracy (#29434)

* Layernorm opt (#29522)

* layernorm fw opt

* layernorm bw opt

* fix typo, test=develop

* remove const dim3 for windows CI compatibility

* merge develop
Co-authored-by: Nzlsh80826 <zlsh80826@gmail.com>

* Fix compile problem when cuda_arch < 6000 (#29576)

* fix compile problem when cuda_arch < 6000

* refine code

* refine code
Co-authored-by: Nzhiqiu <chenqiuliang@baidu.com>
Co-authored-by: Nzlsh80826 <zlsh80826@gmail.com>
上级 cb71fea0
......@@ -14,6 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/layer_norm_op.h"
#include <memory>
#include <string>
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
......@@ -98,7 +99,26 @@ class LayerNormOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const {
const framework::ExecutionContext &ctx) const override {
auto input_data_type = OperatorWithKernel::IndicateVarDataType(ctx, "X");
// By default, the type of the scale, bias, mean,
// and var tensors should both be float. (For float or float16 input tensor)
// or double (For double input tensor).
auto ln_param_type = framework::proto::VarType::FP32;
if (input_data_type == framework::proto::VarType::FP64) {
ln_param_type = framework::proto::VarType::FP64;
}
if (ctx.HasInput("Scale")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Scale")->type(),
platform::errors::InvalidArgument(
"Scale input should be of float type"));
}
if (ctx.HasInput("Bias")) {
PADDLE_ENFORCE_EQ(ln_param_type, ctx.Input<Tensor>("Bias")->type(),
platform::errors::InvalidArgument(
"Bias input should be of float type"));
}
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
......@@ -110,9 +130,8 @@ class LayerNormOp : public framework::OperatorWithKernel {
}
#endif
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout,
library);
}
};
......@@ -224,7 +243,13 @@ class LayerNormGradOp : public framework::OperatorWithKernel {
}
PADDLE_ENFORCE_NOT_NULL(
t, platform::errors::NotFound("Y@GRAD of LayerNorm Op is not found."));
return framework::OpKernelType(t->type(), ctx.GetPlace());
framework::LibraryType library = framework::LibraryType::kPlain;
framework::DataLayout layout = framework::DataLayout::kAnyLayout;
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace(),
layout, library);
}
};
......
......@@ -15,12 +15,22 @@ limitations under the License. */
#include <cub/cub.cuh>
#include <memory>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
#include "paddle/fluid/operators/layer_norm_op.h"
#include "paddle/fluid/platform/cudnn_helper.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
using DataLayout = framework::DataLayout;
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
inline static int GetDesiredBlockDim(int block_dim) {
const int kMaxBlockDim = 512;
return block_dim >= kMaxBlockDim
......@@ -97,59 +107,350 @@ struct PairForLayerNormAddFunctor {
}
};
template <typename T, int BlockDim>
__global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
T *y, T *mean, T *var, float epsilon,
template <typename T>
__inline__ __device__ T rsqrt(const T val) {
return static_cast<T>(1) / sqrt(val);
}
template <>
__inline__ __device__ float rsqrt(const float val) {
return rsqrtf(val);
}
template <>
__inline__ __device__ double rsqrt(const double val) {
return rsqrt(val);
}
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
template <>
__inline__ __device__ half rsqrt(const half val) {
return hrsqrt(val);
}
#endif
template <typename T, typename U, int BlockDim>
__global__ void LayerNormForward(const T *x, const U *scale, const U *bias,
T *y, U *mean, U *var, float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<double>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ U mean_share;
__shared__ U var_share;
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
// Step 1: Reduce to calculate mean and var
double mean_val = 0;
double var_val = 0;
U mean_val = 0;
U var_val = 0;
for (int i = beg_idx; i < end_idx; i += BlockDim) {
T tmp = x[i];
U tmp = static_cast<U>(x[i]);
mean_val += tmp;
var_val += (tmp * tmp);
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<double>(mean_val, var_val),
PairForLayerNormAddFunctor<double>());
.Reduce(PairForLayerNorm<U>(mean_val, var_val),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
auto tmp = pair.first_ / feature_size;
mean[blockIdx.x] = static_cast<T>(tmp);
var[blockIdx.x] = static_cast<T>(pair.second_ / feature_size - tmp * tmp);
mean[blockIdx.x] = mean_share = static_cast<U>(tmp);
var[blockIdx.x] = var_share =
static_cast<U>(pair.second_ / feature_size - tmp * tmp);
}
__syncthreads();
mean_val = mean[blockIdx.x];
var_val = static_cast<T>(real_sqrt(var[blockIdx.x] + epsilon));
mean_val = mean_share;
U invvar = rsqrt<U>(var_share + static_cast<U>(epsilon));
// Step 2: Calculate y
if (scale != nullptr) {
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = scale[j] * (x[i] - mean_val) / var_val + bias[j];
y[i] = static_cast<T>(
scale[j] * (static_cast<U>(x[i]) - mean_val) * invvar + bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = scale[j] * (x[i] - mean_val) / var_val;
y[i] = static_cast<T>(scale[j] * (static_cast<U>(x[i]) - mean_val) *
invvar);
}
}
} else { // scale == nullptr
if (bias != nullptr) {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = (x[i] - mean_val) / var_val + bias[j];
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar +
bias[j]);
}
} else {
for (int i = beg_idx, j = threadIdx.x; i < end_idx;
i += BlockDim, j += BlockDim) {
y[i] = (x[i] - mean_val) / var_val;
y[i] = static_cast<T>((static_cast<U>(x[i]) - mean_val) * invvar);
}
}
}
}
template <typename T, typename U, int VPT>
__inline__ __device__ void cuLoadAddStridedInputs(
const int i1_block, const int thr_load_row_off, const int thr_load_col_off,
const int i2_off, const int row_stride, U *warp_buf1, U *warp_buf2,
const T *input, const T *dout, const int i1_end, const int n2,
const U *__restrict__ mean, const U *__restrict__ var,
const float epsilon) {
const int i1 = i1_block + thr_load_row_off;
if (i1 >= i1_end) return;
U curr_mean = mean[i1];
U curr_invvar = rsqrt<U>(var[i1] + epsilon);
for (int k = 0; k < VPT; ++k) {
const int i2 = i2_off + k;
const int load_idx = i1 * n2 + i2;
const int write_idx = thr_load_row_off * row_stride + thr_load_col_off + k;
if (i2 < n2) {
U curr_input = static_cast<U>(input[load_idx]);
U curr_dout = static_cast<U>(dout[load_idx]);
warp_buf1[write_idx] += curr_dout;
warp_buf2[write_idx] +=
curr_dout * (curr_input - curr_mean) * curr_invvar;
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY, int VPTX>
__global__ void LayerNormBackwardPartGradGammaBeta(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2, const U *__restrict__ mean, const U *__restrict__ var,
float epsilon, U *part_grad_gamma, U *part_grad_beta) {
// VPTX -> value per thread.x, BDIMX -> blockDim.x, BDIMY -> blockDim.y, BDIMX
// -> blockDim.x
// template for compile time optimizations
constexpr int row_stride = BDIMX + 1;
const int thr_load_col_off = (threadIdx.x * VPTX) & (BDIMX - 1);
const int thr_load_row_off =
(threadIdx.x * VPTX) / BDIMX + threadIdx.y * BDIMY;
const int i2_off = blockIdx.x * BDIMX + thr_load_col_off;
constexpr int shared_cap = (BDIMX * BDIMY > 2 * VPTX * BDIMY * row_stride)
? BDIMX * BDIMY
: 2 * VPTX * BDIMY * row_stride;
__shared__ U buf[shared_cap];
U *warp_buf1 = reinterpret_cast<U *>(buf);
U *warp_buf2 = warp_buf1 + VPTX * BDIMY * row_stride;
for (int idx = threadIdx.y * blockDim.x + threadIdx.x;
idx < 2 * VPTX * BDIMY * row_stride; idx += BDIMX * BDIMY) {
buf[idx] = U(0);
}
__syncthreads();
for (int i1_block = blockIdx.y * BDIMY * VPTX; i1_block < n1;
i1_block += VPTX * BDIMY * gridDim.y) {
cuLoadAddStridedInputs<T, U, VPTX>(
i1_block, thr_load_row_off, thr_load_col_off, i2_off, row_stride,
warp_buf1, warp_buf2, input, dout, n1, n2, mean, var, epsilon);
}
__syncthreads();
// inter-warp reductions
// sum within each warp
U acc1 = U(0);
U acc2 = U(0);
for (int k = 0; k < VPTX; ++k) {
int row1 = threadIdx.y + k * VPTX;
int idx1 = row1 * row_stride + threadIdx.x;
acc1 += warp_buf1[idx1];
acc2 += warp_buf2[idx1];
}
warp_buf1[threadIdx.y * row_stride + threadIdx.x] = acc1;
warp_buf2[threadIdx.y * row_stride + threadIdx.x] = acc2;
__syncthreads();
// sum all warps
for (int offset = VPTX >> 1; offset > 1; offset >>= 1) {
if (threadIdx.y < offset) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + offset;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
warp_buf1[idx1] += warp_buf1[idx2];
warp_buf2[idx1] += warp_buf2[idx2];
}
__syncthreads();
}
int i2 = blockIdx.x * blockDim.x + threadIdx.x;
if (threadIdx.y == 0 && i2 < n2) {
int row1 = threadIdx.y;
int row2 = threadIdx.y + 1;
int idx1 = row1 * row_stride + threadIdx.x;
int idx2 = row2 * row_stride + threadIdx.x;
part_grad_beta[blockIdx.y * n2 + i2] = warp_buf1[idx1] + warp_buf1[idx2];
part_grad_gamma[blockIdx.y * n2 + i2] = warp_buf2[idx1] + warp_buf2[idx2];
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardSumGradGammaBeta(
const U *part_grad_gamma, const U *part_grad_beta, const int part_size,
// const int n1, const int n2, T* grad_gamma, T* grad_beta) {
const int n1, const int n2, U *grad_gamma, U *grad_beta) {
// sum partial gradients for gamma and beta
__shared__ U buf[BDIMX * BDIMY];
int i2 = blockIdx.x * BDIMX + threadIdx.x;
if (i2 < n2) {
// each warp does sequential reductions until reduced part_size is num_warps
int num_warp_reductions = part_size / BDIMY;
U sum_gamma = U(0);
U sum_beta = U(0);
const U *part_grad_gamma_ptr =
part_grad_gamma + threadIdx.y * num_warp_reductions * n2 + i2;
const U *part_grad_beta_ptr =
part_grad_beta + threadIdx.y * num_warp_reductions * n2 + i2;
for (int warp_offset = 0; warp_offset < num_warp_reductions;
++warp_offset) {
sum_gamma += part_grad_gamma_ptr[warp_offset * n2];
sum_beta += part_grad_beta_ptr[warp_offset * n2];
}
// inter-warp reductions
constexpr int nbsize3 = BDIMX * BDIMY / 2;
for (int offset = BDIMY / 2; offset >= 1; offset /= 2) {
// top half write to shared memory
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int write_idx = (threadIdx.y - offset) * blockDim.x + threadIdx.x;
buf[write_idx] = sum_gamma;
buf[write_idx + nbsize3] = sum_beta;
}
__syncthreads();
// bottom half sums
if (threadIdx.y < offset) {
const int read_idx = threadIdx.y * BDIMX + threadIdx.x;
sum_gamma += buf[read_idx];
sum_beta += buf[read_idx + nbsize3];
}
__syncthreads();
}
// write out fully summed gradients
if (threadIdx.y == 0) {
grad_gamma[i2] = sum_gamma;
grad_beta[i2] = sum_beta;
}
}
}
template <typename T, typename U, int BDIMX, int BDIMY>
__global__ void LayerNormBackwardComputeGradInput(
const T *__restrict__ dout, const T *__restrict__ input, const int n1,
const int n2,
// const U* __restrict__ mean, const U* __restrict__ var, const float
// epsilon, const T* gamma,
const U *__restrict__ mean, const U *__restrict__ var, const float epsilon,
const U *gamma, T *grad_input) {
for (auto i1 = blockIdx.y; i1 < n1; i1 += gridDim.y) {
U sum_loss1 = U(0);
U sum_loss2 = U(0);
const U c_mean = mean[i1];
const U c_invvar = rsqrt<U>(var[i1] + epsilon);
const T *k_input = input + i1 * n2;
const T *k_dout = dout + i1 * n2;
constexpr int numx = BDIMX * BDIMY;
const int thrx = threadIdx.x + threadIdx.y * BDIMX;
if (gamma != NULL) {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss * gamma[l + k];
sum_loss2 += c_loss * gamma[l + k] * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss * gamma[l];
sum_loss2 += c_loss * gamma[l] * (c_h - c_mean) * c_invvar;
}
} else {
int l = 4 * thrx;
for (; l + 3 < n2; l += 4 * numx) {
for (int k = 0; k < 4; ++k) {
const U c_h = static_cast<U>(k_input[l + k]);
const U c_loss = static_cast<U>(k_dout[l + k]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
for (; l < n2; ++l) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
sum_loss1 += c_loss;
sum_loss2 += c_loss * (c_h - c_mean) * c_invvar;
}
}
// intra-warp reductions
for (int mask = BDIMX / 2; mask > 0; mask /= 2) {
sum_loss1 +=
__shfl_xor_sync(0xffffffff, sum_loss1, mask,
warpSize); // WARP_SHFL_XOR(sum_loss1, mask);
sum_loss2 +=
__shfl_xor_sync(0xffffffff, sum_loss2, mask,
warpSize); // WARP_SHFL_XOR(sum_loss2, mask);
}
// inter-warp reductions
if (BDIMY > 1) {
__shared__ U buf[BDIMX * BDIMY];
for (int offset = BDIMY / 2; offset > 0; offset /= 2) {
// upper half of warps write to shared
if (threadIdx.y >= offset && threadIdx.y < 2 * offset) {
const int wrt_i = (threadIdx.y - offset) * BDIMX + threadIdx.x;
buf[2 * wrt_i] = sum_loss1;
buf[2 * wrt_i + 1] = sum_loss2;
}
__syncthreads();
// lower half merges
if (threadIdx.y < offset) {
const int read_i = threadIdx.y * blockDim.x + threadIdx.x;
sum_loss1 += buf[2 * read_i];
sum_loss2 += buf[2 * read_i + 1];
}
__syncthreads();
}
if (threadIdx.y == 0) {
buf[2 * threadIdx.x] = sum_loss1;
buf[2 * threadIdx.x + 1] = sum_loss2;
}
__syncthreads();
if (threadIdx.y != 0) {
sum_loss1 = buf[2 * threadIdx.x];
sum_loss2 = buf[2 * threadIdx.x + 1];
}
}
// all threads now have the two sums over l
U fH = (U)n2;
U term1 = (U(1) / fH) * c_invvar;
T *k_grad_input = grad_input + i1 * n2;
if (gamma != NULL) {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss * gamma[l];
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
} else {
for (int l = thrx; l < n2; l += numx) {
const U c_h = static_cast<U>(k_input[l]);
const U c_loss = static_cast<U>(k_dout[l]);
U f_grad_input = fH * c_loss;
f_grad_input -= sum_loss1;
f_grad_input -= (c_h - c_mean) * c_invvar * sum_loss2;
f_grad_input *= term1;
k_grad_input[l] = static_cast<T>(f_grad_input);
}
}
}
......@@ -157,35 +458,37 @@ __global__ void LayerNormForward(const T *x, const T *scale, const T *bias,
// Make sure that d_scale != nullptr && d_bias != nullptr
// Since d_scale != nullptr, scale would not be nullptr
template <typename T, int BlockDim, bool HasDx>
template <typename T, typename U, int BlockDim, bool HasDx>
__global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
T *d_scale, T *d_bias, T *d_x,
const T *mean, const T *var,
const T *scale, float epsilon,
U *d_scale, U *d_bias, T *d_x,
const U *mean, const U *var,
const U *scale, float epsilon,
int batch_size, int feature_size,
int col_offset) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = threadIdx.x * feature_size + (blockIdx.x + col_offset);
int end_idx = batch_size * feature_size + (blockIdx.x + col_offset);
int stride = BlockDim * feature_size;
T d_scale_partial = 0, d_bias_partial = 0;
U d_scale_partial = static_cast<U>(0), d_bias_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += stride) {
int row_idx = i / feature_size;
auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
d_scale_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
d_bias_partial += d_y[i];
auto var_val = real_sqrt(static_cast<U>(var[row_idx]) + epsilon);
d_scale_partial += static_cast<U>(d_y[i]) *
(static_cast<U>(x[i]) - mean[row_idx]) / var_val;
d_bias_partial += static_cast<U>(d_y[i]);
if (HasDx) {
d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
}
}
auto pair = BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_scale_partial, d_bias_partial),
PairForLayerNormAddFunctor<T>());
.Reduce(PairForLayerNorm<U>(d_scale_partial, d_bias_partial),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
d_scale[blockIdx.x + col_offset] = pair.first_;
......@@ -196,32 +499,36 @@ __global__ void LayerNormBackwardGradientAll(const T *x, const T *d_y,
// Make sure that there is only one true expression: d_scale != nullptr
// or d_bias != nullptr
// Notice: scale may be nullptr
template <typename T, int BlockDim, bool HasDx, bool HasDScale>
template <typename T, typename U, int BlockDim, bool HasDx, bool HasDScale>
__global__ void LayerNormBackwardGradientScaleOrBias(
const T *x, const T *d_y, T *d_scale, T *d_bias, T *d_x, const T *mean,
const T *var, const T *scale, float epsilon, int batch_size,
const T *x, const T *d_y, U *d_scale, U *d_bias, T *d_x, const U *mean,
const U *var, const U *scale, float epsilon, int batch_size,
int feature_size, int col_offset) {
using BlockReduce = cub::BlockReduce<T, BlockDim>;
using BlockReduce = cub::BlockReduce<U, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
int beg_idx = threadIdx.x * feature_size + blockIdx.x + col_offset;
int end_idx = batch_size * feature_size + blockIdx.x + col_offset;
int stride = BlockDim * feature_size;
T d_scale_or_d_bias_partial = 0;
U d_scale_or_d_bias_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += stride) {
int row_idx = i / feature_size;
auto var_val = static_cast<T>(real_sqrt(var[row_idx] + epsilon));
auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(var[row_idx]) + epsilon));
if (HasDScale) {
d_scale_or_d_bias_partial += d_y[i] * (x[i] - mean[row_idx]) / var_val;
d_scale_or_d_bias_partial += static_cast<U>(d_y[i]) *
(static_cast<U>(x[i]) - mean[row_idx]) /
var_val;
} else { // d_bias != nullptr
d_scale_or_d_bias_partial += d_y[i];
d_scale_or_d_bias_partial += static_cast<U>(d_y[i]);
}
if (HasDx) {
if (scale != nullptr) {
d_x[i] = d_y[i] * scale[blockIdx.x + col_offset] / var_val;
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) *
scale[blockIdx.x + col_offset] / var_val);
} else {
d_x[i] = d_y[i] / var_val;
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
}
}
}
......@@ -238,121 +545,138 @@ __global__ void LayerNormBackwardGradientScaleOrBias(
}
}
template <typename T, int BlockDim>
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardPostProcessToCalculateDX(const T *x, T *d_x,
const T *mean,
const T *var,
const U *mean,
const U *var,
float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T d_x_reduce_tmp[2];
__shared__ U d_x_reduce_tmp[2];
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
T block_mean = mean[blockIdx.x];
T block_var = var[blockIdx.x];
T d_x_mean_partial = 0, d_x_var_partial = 0;
U block_mean = mean[blockIdx.x];
U block_var = var[blockIdx.x];
U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x_mean_partial += d_x[i];
d_x_var_partial += d_x[i] * (x[i] - block_mean);
d_x_mean_partial += static_cast<U>(d_x[i]);
d_x_var_partial +=
static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
}
auto pair =
BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<T>());
.Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
d_x_reduce_tmp[0] = pair.first_ / feature_size;
d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
d_x_reduce_tmp[1] =
static_cast<float>(pair.second_) /
(feature_size * (static_cast<float>(block_var) + epsilon));
}
__syncthreads();
d_x_mean_partial = d_x_reduce_tmp[0];
d_x_var_partial = d_x_reduce_tmp[1];
for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x[i] -= d_x_mean_partial;
d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
d_x[i] -= static_cast<T>(d_x_mean_partial);
d_x[i] -=
static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
}
}
// Here, we only calculate d_x
template <typename T, int BlockDim>
template <typename T, typename U, int BlockDim>
__global__ void LayerNormBackwardGradientOnlyDX(const T *x, const T *d_y,
T *d_x, const T *mean,
const T *var, const T *scale,
T *d_x, const U *mean,
const U *var, const U *scale,
float epsilon,
int feature_size) {
using BlockReduce = cub::BlockReduce<PairForLayerNorm<T>, BlockDim>;
using BlockReduce = cub::BlockReduce<PairForLayerNorm<U>, BlockDim>;
__shared__ typename BlockReduce::TempStorage temp_storage;
__shared__ T d_x_reduce_tmp[2];
__shared__ U d_x_reduce_tmp[2];
int beg_idx = blockIdx.x * feature_size + threadIdx.x;
int end_idx = (blockIdx.x + 1) * feature_size;
T block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
T d_x_mean_partial = 0, d_x_var_partial = 0;
U block_mean = mean[blockIdx.x], block_var = var[blockIdx.x];
U d_x_mean_partial = static_cast<U>(0), d_x_var_partial = static_cast<U>(0);
for (int i = beg_idx; i < end_idx; i += BlockDim) {
auto var_val = static_cast<T>(real_sqrt(block_var + epsilon));
auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(block_var) + epsilon));
if (scale != nullptr) {
int col_idx = i % feature_size;
d_x[i] = d_y[i] * scale[col_idx] / var_val;
d_x[i] =
static_cast<T>(static_cast<U>(d_y[i]) * scale[col_idx] / var_val);
} else {
d_x[i] = d_y[i] / var_val;
d_x[i] = static_cast<T>(static_cast<U>(d_y[i]) / var_val);
}
d_x_mean_partial += d_x[i];
d_x_var_partial += d_x[i] * (x[i] - block_mean);
d_x_mean_partial += static_cast<U>(d_x[i]);
d_x_var_partial +=
static_cast<U>(d_x[i]) * (static_cast<U>(x[i]) - block_mean);
}
auto pair =
BlockReduce(temp_storage)
.Reduce(PairForLayerNorm<T>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<T>());
.Reduce(PairForLayerNorm<U>(d_x_mean_partial, d_x_var_partial),
PairForLayerNormAddFunctor<U>());
if (threadIdx.x == 0) {
d_x_reduce_tmp[0] = pair.first_ / feature_size;
d_x_reduce_tmp[1] = pair.second_ / (feature_size * (block_var + epsilon));
d_x_reduce_tmp[0] = static_cast<float>(pair.first_) / feature_size;
d_x_reduce_tmp[1] =
static_cast<float>(pair.second_) /
(feature_size * (static_cast<float>(block_var) + epsilon));
}
__syncthreads();
d_x_mean_partial = d_x_reduce_tmp[0];
d_x_var_partial = d_x_reduce_tmp[1];
for (int i = beg_idx; i < end_idx; i += BlockDim) {
d_x[i] -= d_x_mean_partial;
d_x[i] -= (x[i] - block_mean) * d_x_var_partial;
d_x[i] -= static_cast<T>(d_x_mean_partial);
d_x[i] -=
static_cast<T>((static_cast<U>(x[i]) - block_mean) * d_x_var_partial);
}
}
template <typename T>
template <typename T, typename U>
__global__ void LayerNormBackwardWhenBatchSizeIsOne(
const T *x, const T *d_y, T *d_x, T *d_scale, T *d_bias, const T *mean,
const T *var, const T *scale, float epsilon, int feature_size) {
const T *x, const T *d_y, T *d_x, U *d_scale, U *d_bias, const U *mean,
const U *var, const U *scale, float epsilon, int feature_size) {
int idx = threadIdx.x + blockIdx.x * blockDim.x;
if (idx < feature_size) {
auto var_val = static_cast<T>(real_sqrt(var[idx] + epsilon));
auto var_val =
static_cast<U>(real_sqrt(static_cast<float>(var[idx]) + epsilon));
if (d_x != nullptr) {
if (d_scale == nullptr) {
d_x[idx] = d_y[idx] / var_val;
d_x[idx] = static_cast<T>(static_cast<U>(d_y[idx]) / var_val);
} else {
d_x[idx] = d_y[idx] * scale[idx] / var_val;
d_x[idx] =
static_cast<T>(static_cast<U>(d_y[idx]) * scale[idx] / var_val);
}
}
if (d_scale != nullptr) {
d_scale[idx] = d_y[idx] * (x[idx] - mean[idx]) / var_val;
d_scale[idx] = static_cast<U>(d_y[idx]) *
(static_cast<U>(x[idx]) - mean[idx]) / var_val;
}
if (d_bias != nullptr) d_bias[idx] = d_y[idx];
if (d_bias != nullptr) d_bias[idx] = static_cast<U>(d_y[idx]);
}
}
template <typename T>
static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
const T *mean, const T *var, T *d_x, T *d_scale,
T *d_bias, float epsilon, int batch_size,
int feature_size, cudaStream_t stream) {
template <typename T, typename U>
static void LayerNormBackward(const T *x, const T *d_y, const U *scale,
const U *mean, const U *var, T *d_x, U *d_scale,
U *d_bias, float epsilon, int batch_size,
int feature_size,
const framework::ExecutionContext &ctx) {
auto &dev_ctx = ctx.cuda_device_context();
auto stream = dev_ctx.stream();
const int kMaxBlockDim = 512;
const int kMaxBlockNum = 128;
int gradient_flag = ((d_x != nullptr ? 1 : 0) << 2) |
......@@ -362,14 +686,14 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
if (batch_size == 1) {
LayerNormBackwardWhenBatchSizeIsOne<
T><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim, 0,
stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale, epsilon,
feature_size);
T, U><<<(feature_size + kMaxBlockDim - 1) / kMaxBlockDim, kMaxBlockDim,
0, stream>>>(x, d_y, d_x, d_scale, d_bias, mean, var, scale,
epsilon, feature_size);
if (d_x != nullptr) {
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<1, kBlockDim, 0, stream>>>(
T, U, kBlockDim><<<1, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
}
......@@ -383,7 +707,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, false,
T, U, kBlockDim, false,
false><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
......@@ -394,7 +718,8 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, false, true><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false,
true><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -404,7 +729,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientAll<
T, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, false><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
......@@ -413,7 +738,7 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardGradientOnlyDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_y, d_x, mean, var, scale, epsilon, feature_size));
}
break;
......@@ -422,14 +747,15 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, true, false><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, true,
false><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
break;
......@@ -438,33 +764,57 @@ static void LayerNormBackward(const T *x, const T *d_y, const T *scale,
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientScaleOrBias<
T, kBlockDim, true, true><<<block_num, kBlockDim, 0, stream>>>(
T, U, kBlockDim, true,
true><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
T, U, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
break;
case 7: // d_x != nullptr, d_scale != nullptr, d_bias != nullptr
switch (block_dim) {
FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE(
feature_size, kMaxBlockNum,
LayerNormBackwardGradientAll<
T, kBlockDim, true><<<block_num, kBlockDim, 0, stream>>>(
x, d_y, d_scale, d_bias, d_x, mean, var, scale, epsilon,
batch_size, feature_size, col_offset));
}
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormBackwardPostProcessToCalculateDX<
T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x, d_x, mean, var, epsilon, feature_size));
}
{
constexpr int VPT = 4;
constexpr int BDIMX2 = 32;
constexpr int BDIMY2 = 4;
dim3 threads2(BDIMX2, BDIMY2, 1);
constexpr int part_size = BDIMY2 * VPT;
const dim3 blocks2((feature_size + BDIMX2 - 1) / BDIMX2, part_size, 1);
auto part_grad_gamma_ptr =
memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
auto part_grad_beta_ptr =
memory::Alloc(dev_ctx, part_size * feature_size * sizeof(U));
U *part_grad_gamma = reinterpret_cast<U *>(part_grad_gamma_ptr->ptr());
U *part_grad_beta = reinterpret_cast<U *>(part_grad_beta_ptr->ptr());
LayerNormBackwardPartGradGammaBeta<T, U, BDIMX2, BDIMY2,
VPT><<<blocks2, threads2, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, part_grad_gamma,
part_grad_beta); // compute part_grad_gamma, beta
constexpr int BDIMX3 = 32;
constexpr int BDIMY3 = 8;
dim3 threads3(BDIMX3, BDIMY3, 1);
const dim3 blocks3((feature_size + BDIMX2 - 1) / BDIMX2, 1, 1);
LayerNormBackwardSumGradGammaBeta<
T, U, BDIMX3, BDIMY3><<<blocks3, threads3, 0, stream>>>(
part_grad_gamma, part_grad_beta, part_size, batch_size, feature_size,
d_scale, d_bias);
constexpr int BDIMX1 = 32;
constexpr int BDIMY1 = 4;
dim3 threads1(BDIMX1, BDIMY1, 1);
const dim3 blocks1(1, batch_size, 1);
LayerNormBackwardComputeGradInput<
T, U, BDIMX1, BDIMY1><<<blocks1, threads1, 0, stream>>>(
d_y, x, batch_size, feature_size, mean, var, epsilon, scale, d_x);
break;
}
default:
break;
}
......@@ -483,7 +833,7 @@ void LayerNormDirectCUDAFunctor<T>::operator()(cudaStream_t stream,
int feature_size = static_cast<int>(matrix_dim[1]);
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
LayerNormForward<T, T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
input, scale, bias, output, mean, variance, eps, feature_size));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
......@@ -498,6 +848,7 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float epsilon = ctx.Attr<float>("epsilon");
auto *scale = ctx.Input<Tensor>("Scale");
auto *bias = ctx.Input<Tensor>("Bias");
......@@ -511,10 +862,10 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
const auto x_dims = x->dims();
auto *x_data = x->data<T>();
auto *y_data = y->mutable_data<T>(ctx.GetPlace());
auto *mean_data = mean->mutable_data<T>(ctx.GetPlace());
auto *var_data = var->mutable_data<T>(ctx.GetPlace());
auto *scale_data = (scale == nullptr ? nullptr : scale->data<T>());
auto *bias_data = (bias == nullptr ? nullptr : bias->data<T>());
auto *mean_data = mean->mutable_data<U>(ctx.GetPlace());
auto *var_data = var->mutable_data<U>(ctx.GetPlace());
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *bias_data = (bias == nullptr ? nullptr : bias->data<U>());
auto matrix_dim = framework::flatten_to_2d(x_dims, begin_norm_axis);
int batch_size = static_cast<int>(matrix_dim[0]);
......@@ -524,7 +875,8 @@ class LayerNormKernel<platform::CUDADeviceContext, T>
switch (GetDesiredBlockDim(feature_size)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
LayerNormForward<T, U,
kBlockDim><<<batch_size, kBlockDim, 0, stream>>>(
x_data, scale_data, bias_data, y_data, mean_data, var_data,
epsilon, feature_size));
default:
......@@ -540,6 +892,7 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
: public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
using U = LayerNormParamType<T>;
const float epsilon = ctx.Attr<float>("epsilon");
// d_x, d_scale, d_bias may be nullptr
auto *d_x = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -554,14 +907,15 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
auto *x_data = x->data<T>();
auto *d_y_data = d_y->data<T>();
auto *mean_data = mean->data<T>();
auto *var_data = var->data<T>();
auto *scale_data = (scale == nullptr ? nullptr : scale->data<T>());
auto *mean_data = mean->data<U>();
auto *var_data = var->data<U>();
auto *scale_data = (scale == nullptr ? nullptr : scale->data<U>());
auto *d_scale_data =
(d_scale == nullptr ? nullptr
: d_scale->mutable_data<T>(ctx.GetPlace()));
: d_scale->mutable_data<U>(ctx.GetPlace()));
auto *d_bias_data =
(d_bias == nullptr ? nullptr : d_bias->mutable_data<T>(ctx.GetPlace()));
(d_bias == nullptr ? nullptr : d_bias->mutable_data<U>(ctx.GetPlace()));
auto *d_x_data =
(d_x == nullptr ? nullptr : d_x->mutable_data<T>(ctx.GetPlace()));
......@@ -571,14 +925,14 @@ class LayerNormGradKernel<platform::CUDADeviceContext, T>
int batch_size = static_cast<int>(matrix_dim[0]);
int feature_size = static_cast<int>(matrix_dim[1]);
auto stream = ctx.cuda_device_context().stream();
LayerNormBackward<T>(x_data, d_y_data, scale_data, mean_data, var_data,
LayerNormBackward<T, U>(x_data, d_y_data, scale_data, mean_data, var_data,
d_x_data, d_scale_data, d_bias_data, epsilon,
batch_size, feature_size, stream);
batch_size, feature_size, ctx);
}
};
template class LayerNormDirectCUDAFunctor<float>;
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE_BASE
#undef FIXED_BLOCK_DIM_FIXED_BLOCK_NUM_CASE
#undef FIXED_BLOCK_DIM_CASE_BASE
......@@ -587,11 +941,15 @@ template class LayerNormDirectCUDAFunctor<float>;
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
layer_norm,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>);
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormKernel<paddle::platform::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(
layer_norm_grad,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>);
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::LayerNormGradKernel<paddle::platform::CUDADeviceContext,
plat::float16>);
......@@ -109,9 +109,11 @@ gray_list = {
'elementwise_mod',
'elementwise_floordiv',
'batch_norm',
'layer_norm',
'tanh',
'sigmoid',
'lookup_table',
'lookup_table_v2',
'top_k',
'pool2d',
'pool3d',
......@@ -123,6 +125,7 @@ gray_list = {
'flatten2',
'stack',
'unstack',
'uniform_random',
'uniform_random_batch_size_like',
'gaussian_random',
'gaussian_random_batch_size_like',
......@@ -192,7 +195,6 @@ unsupported_fp16_list = {
'sequence_concat',
'sequence_slice',
'data_norm',
'layer_norm',
'group_norm',
'spectral_norm',
'depthwise_conv2d_transpose',
......
......@@ -76,7 +76,7 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
for in_name in op.input_names:
if src_dtype == core.VarDesc.VarType.FP32 and op.type in [
'batch_norm', 'fused_bn_add_activation'
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
]:
if in_name not in {'X', 'Z'}:
continue
......@@ -110,7 +110,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
op._set_attr('in_dtype', dest_dtype)
if src_dtype == core.VarDesc.VarType.FP32 and dest_dtype == core.VarDesc.VarType.FP16:
for out_name in op.output_names:
if op.type in ['batch_norm', 'fused_bn_add_activation'
if op.type in [
'batch_norm', 'fused_bn_add_activation', 'layer_norm'
] and out_name != 'Y':
continue
for out_var_name in op.output(out_name):
......
......@@ -15,6 +15,7 @@
from __future__ import print_function
import unittest
import numpy as np
import paddle
from operator import mul
import paddle.fluid.core as core
......@@ -310,6 +311,8 @@ class TestLayerNormAPI(unittest.TestCase):
class TestDygraphLayerNormAPIError(unittest.TestCase):
def test_errors(self):
with program_guard(Program(), Program()):
paddle.enable_static()
layer_norm = fluid.LayerNorm([32, 32])
# the input of LayerNorm must be Variable.
x1 = np.random.random((3, 32, 32)).astype('float32')
......
......@@ -299,7 +299,8 @@ def layer_norm(x,
'begin_norm_axis', begin_norm_axis)
return dygraph_utils._append_activation_in_dygraph(pre_act, act=None)
check_variable_and_dtype(x, 'input', ['float32', 'float64'], 'LayerNorm')
check_variable_and_dtype(x, 'input', ['float16', 'float32', 'float64'],
'LayerNorm')
inputs = dict()
inputs['X'] = [x]
......@@ -311,11 +312,13 @@ def layer_norm(x,
# create output
helper = LayerHelper('layer_norm', **locals())
dtype = x.dtype
mean_out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
dtype=dtype, stop_gradient=True)
variance_out = helper.create_variable_for_type_inference(
dtype=x.dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(x.dtype)
dtype=dtype, stop_gradient=True)
layer_norm_out = helper.create_variable_for_type_inference(dtype)
helper.append_op(
type="layer_norm",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册