未验证 提交 7975dfcf 编写于 作者: Z zhangkaihuo 提交者: GitHub

add a fusion op: fused_layernorm_residual_dropout_bias (#35151)

Fused elementwise_add, dropout, elementwise_add and layer_norm into one operator, only support Forward. 
No Python API changed.
上级 fb4d5689
......@@ -74,7 +74,8 @@ if (WITH_GPU OR WITH_ROCM)
# fused_dropout
# only support CUDA
if(NOT WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory)
endif()
endif()
......@@ -17,8 +17,7 @@ limitations under the License. */
#define _USE_MATH_DEFINES
#endif
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
namespace paddle {
namespace operators {
......@@ -75,66 +74,15 @@ __global__ void FusedDropoutActBias(
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0);
}
if (is_test) {
factor = static_cast<T>(1.0f - dropout_prob);
if (is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
}
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
LoadT src_vec;
LoadT bias_vec;
// vectorize load data from global
platform::Load<T, VecSize>(&src[r * cols + i], &src_vec);
if (bias) {
platform::Load<T, VecSize>(&bias[i], &bias_vec);
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
}
}
MaskStoreT mask_vec;
if (!is_test) {
float rand[VecSize];
RandVec<VecSize>(&state, rand);
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
}
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(1);
}
}
StoreT dest_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
const T tmp = src_vec[ii] + bias_vec[ii];
const T act_out = act(tmp);
dest_vec[ii] = act_out * static_cast<T>(mask_vec[ii]) * factor;
}
// store result to global
platform::Store<T, VecSize>(dest_vec, &dst[r * cols + i]);
if (!is_test) {
platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
}
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false, true,
Functor>(
r, i, cols, &state, dropout_prob, factor, src, nullptr, bias, dst,
mask, is_test, nullptr, nullptr, act);
}
}
}
......@@ -197,10 +145,8 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T args[2];
args[0] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
args[1] = src_vec[ii];
dx_vec[ii] = args[0] * act_grad.UseOut(args[1]);
T tmp = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]);
}
platform::Store<T, VecSize>(dx_vec, &dx[i]);
}
......@@ -243,10 +189,8 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
#pragma unroll
for (int i = 0; i < VecSize; i++) {
T val;
T args[2];
args[0] = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
args[1] = src_vec[i] + bias_vec[i];
val = args[0] * act_grad.UseOut(args[1]);
T tmp = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]);
dx_vec[i] = val;
tmp_sum[i] += val;
}
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/functors.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -133,5 +134,17 @@ inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias,
}
}
template <typename T>
inline __device__ T GetFactor(const float dropout_prob,
const bool is_upscale_in_train,
const bool is_test) {
T factor = is_upscale_in_train ? static_cast<T>(1.0f / (1.0f - dropout_prob))
: static_cast<T>(1.0f);
if (is_test) {
factor = is_upscale_in_train ? static_cast<T>(1.0f)
: static_cast<T>(1.0f - dropout_prob);
}
return factor;
}
} // namespace operators
} // namespace paddle
......@@ -23,6 +23,7 @@ limitations under the License. */
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
......@@ -31,6 +32,12 @@ namespace platform = paddle::platform;
namespace memory = paddle::memory;
USE_OP(dropout);
USE_OP(layer_norm);
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
/**
* @brief call paddle dropout op
......@@ -116,6 +123,60 @@ void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
ctx.Wait();
}
/**
* @brief call paddle layer_norm op
*/
template <typename T>
void LayerNorm(const std::vector<LayerNormParamType<T>> &scale,
const std::vector<LayerNormParamType<T>> &bias,
const std::vector<T> &x,
std::vector<LayerNormParamType<T>> *means,
std::vector<LayerNormParamType<T>> *vars, std::vector<T> *y,
const float epsilon, const int rows, const int cols,
const platform::CUDADeviceContext &ctx) {
framework::Scope scope;
auto place = ctx.GetPlace();
if (scale.size() > 0) {
auto var_scale = scope.Var("Scale");
auto tensor_scale = var_scale->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(scale, ctx, tensor_scale);
tensor_scale->Resize({cols});
}
if (bias.size() > 0) {
auto var_bias = scope.Var("Bias");
auto tensor_bias = var_bias->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(bias, ctx, tensor_bias);
tensor_bias->Resize({cols});
}
auto var_x = scope.Var("X");
auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
framework::TensorFromVector(x, ctx, tensor_x);
tensor_x->Resize({rows, cols});
auto var_y = scope.Var("Y");
auto tensor_y = var_y->GetMutable<framework::LoDTensor>();
auto var_mean = scope.Var("Mean");
auto tensor_mean = var_mean->GetMutable<framework::LoDTensor>();
auto var_variance = scope.Var("Variance");
auto tensor_variance = var_variance->GetMutable<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs.insert({"epsilon", epsilon});
auto op = framework::OpRegistry::CreateOp(
"layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}},
{{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs);
op->Run(scope, place);
framework::TensorToVector(*tensor_y, ctx, y);
framework::TensorToVector(*tensor_mean, ctx, means);
framework::TensorToVector(*tensor_variance, ctx, vars);
ctx.Wait();
}
template <typename T>
inline void ReduceSum(const std::vector<T> &dout, std::vector<T> *dbias,
const int rows, const int cols) {
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
namespace paddle {
namespace operators {
template <typename T>
using CudnnDataType = platform::CudnnDataType<T>;
template <typename T>
using LayerNormParamType = typename CudnnDataType<T>::BatchNormParamType;
/**
* @brief fused add_bias, dropout, add residual and leyer_norm into one
* operators. Currently only support forward
*/
template <typename T, int VecSize>
__device__ void CalcLayernormY(const LayerNormParamType<T> *scale,
const LayerNormParamType<T> *bias, const T *x,
T *y, const int row_id, const int col_id,
const int cols,
const LayerNormParamType<T> mean_val,
const LayerNormParamType<T> invvar) {
using U = LayerNormParamType<T>;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using LoadU = platform::AlignedVector<U, VecSize>;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
LoadU scale_vec;
LoadU bias_vec;
LoadT x_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
scale_vec[ii] = static_cast<U>(1);
bias_vec[ii] = static_cast<U>(0);
}
// vectorize load data from global
platform::Load<T, VecSize>(&x[row_id * cols + i], &x_vec);
if (scale != nullptr) {
platform::Load<U, VecSize>(&scale[i], &scale_vec);
}
if (bias != nullptr) {
platform::Load<U, VecSize>(&bias[i], &bias_vec);
}
StoreT y_vec;
for (int ii = 0; ii < VecSize; ii++) {
y_vec[ii] = static_cast<T>(
scale_vec[ii] * (static_cast<U>(x_vec[ii]) - mean_val) * invvar +
bias_vec[ii]);
}
platform::Store<T, VecSize>(y_vec, &y[row_id * cols + i]);
}
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
* means: [rows]: layernorm means
* vars: [rows]: layernorm vars
*/
template <typename T, typename MaskType, int VecSize>
__global__ void FusedLayernormResidualDropoutBias(
const size_t rows, const size_t cols, uint64_t seed,
const float dropout_prob, const bool is_upscale_in_train,
const bool is_test, const uint64_t increment, const float epsilon,
const T *src, const T *residual, const T *bias,
const LayerNormParamType<T> *scale,
const LayerNormParamType<T> *layernorm_bias, MaskType *mask, T *dst,
T *layernorm_dst, LayerNormParamType<T> *mean, LayerNormParamType<T> *var) {
int col_id = threadIdx.x;
int row_id = blockIdx.x;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
using U = LayerNormParamType<T>;
__shared__ U mean_share;
__shared__ U var_share;
__shared__ U shared_mean[32];
__shared__ U shared_var[32];
math::ReluFunctor<T> relu;
U mean_val = 0;
U var_val = 0;
for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, true, false,
math::ReluFunctor<T>>(
row_id, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, &mean_val, &var_val, relu);
}
mean_val = BlockReduceSum<U>(mean_val, shared_mean);
var_val = BlockReduceSum<U>(var_val, shared_var);
if (threadIdx.x == 0) {
auto scale = static_cast<float>(1.) / static_cast<float>(cols);
auto tmp = mean_val * scale;
mean[row_id] = mean_share = static_cast<U>(tmp);
var_share = static_cast<U>(var_val * scale - mean_share * mean_share);
var_share = var_share > U(0) ? var_share : U(0);
var[row_id] = var_share;
}
__syncthreads();
mean_val = mean_share;
U invvar = rsqrt_<U>(var_share + static_cast<U>(epsilon));
// calculate layernorm_dst
CalcLayernormY<T, VecSize>(scale, layernorm_bias, dst, layernorm_dst, row_id,
col_id, cols, mean_val, invvar);
}
/**
* @brief layernorm(residual + dropout(src + bias));
* @param
* rows: batch_size * seq_len
* cols: feature_size or hidden_size
* src: [rows, cols], inputs
* bias: [cols], linear bias, can be null
* residual:[rows, cols]
* mask: [rows, cols], dropout result, can be null if is_test = true
* dst: [rows, cols], residual + dropout(src+bias)
* layernorm_dst: [rows, cols], layernorm result
* layernorm_bias: [cols], layernorm bias, can be null
* scale: [cols]: layernorm scale, can be null
* means: [rows]: layernorm means
* vars: [rows]: layernorm vars
*/
template <typename T, typename MaskType>
void LaunchLayernormResidualDropoutBias(
const uint32_t rows, const uint32_t cols, const int increment,
uint64_t seed, const float dropout_prob, const float epsilon,
const bool is_upscale_in_train, const bool is_test, const T *src,
const T *residual, const T *bias, const LayerNormParamType<T> *scale,
const LayerNormParamType<T> *layernorm_bias, MaskType *mask_data, T *dst,
T *layernorm_dst, LayerNormParamType<T> *mean, LayerNormParamType<T> *var,
const platform::CUDADeviceContext &ctx) {
using U = LayerNormParamType<T>;
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T),
ctx.stream());
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
// call layernorm forward
switch (GetDesiredBlockDim(cols)) {
FIXED_BLOCK_DIM_CASE(
LayerNormForward<T, U,
kBlockDim><<<rows, kBlockDim, 0, ctx.stream()>>>(
dst, scale, layernorm_bias, layernorm_dst, mean, var, epsilon,
cols));
default:
PADDLE_THROW(platform::errors::InvalidArgument(
"Product from begin_norm_axis to end must be larger than 1"));
break;
}
return;
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
if (cols % VecSize != 0) {
int blockDim = GetDesiredBlockDim(cols);
FusedLayernormResidualDropoutBias<T, uint8_t,
1><<<rows, blockDim, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
} else {
int blockDim = GetDesiredBlockDim(cols / VecSize);
FusedLayernormResidualDropoutBias<
T, uint8_t, VecSize><<<rows, blockDim, 0, ctx.stream()>>>(
rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment,
epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst,
layernorm_dst, mean, var);
}
}
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include <random>
#include <vector>
#include "paddle/fluid/operators/fused/fused_dropout_test.h"
#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
/**
* @brief The unit test of fused_layernorm_residual_dropout_bias
*/
template <typename T>
struct TestFusedLayernormResidualDropoutBias {
uint32_t rows;
uint32_t cols;
uint64_t seed;
float dropout_prob, epsilon;
bool is_upscale_in_train;
bool is_test; // default false, Set to true for inference only
bool has_bias = true;
bool has_scale = true;
bool has_layernorm_bias = true;
framework::Tensor src, residual, bias, out, mask, scale, layernorm_bias,
layernorm_out, means, vars;
framework::Tensor dsrc, dbias;
std::vector<T> src_vec, residual_vec, bias_vec;
std::vector<LayerNormParamType<T>> means_vec, vars_vec, scale_vec,
layernorm_bias_vec;
std::vector<T> correct_out, correct_dsrc, correct_dbias,
correct_layernorm_out;
std::vector<LayerNormParamType<T>> correct_means, correct_vars;
std::vector<uint8_t> correct_mask;
platform::CUDAPlace place;
platform::CUDADeviceContext *ctx;
TestFusedLayernormResidualDropoutBias() {
rows = 32;
cols = 32;
seed = 0;
dropout_prob = 0.0;
is_upscale_in_train = false;
is_test = false;
has_bias = true;
has_scale = true;
has_layernorm_bias = true;
epsilon = 0.00001f;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto devicectx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(devicectx);
}
TestFusedLayernormResidualDropoutBias(int _rows, int _cols,
uint64_t _seed = 0,
float _dropout_prob = 0.0,
float _epsilon = 0.00001f,
bool _is_upscale_in_train = false,
bool _is_test = false) {
rows = _rows;
cols = _cols;
seed = _seed;
dropout_prob = _dropout_prob;
epsilon = _epsilon;
is_upscale_in_train = _is_upscale_in_train;
is_test = _is_test;
has_bias = true;
has_scale = true;
has_layernorm_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto devicectx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(devicectx);
}
~TestFusedLayernormResidualDropoutBias() {}
void SetUp() {
using U = LayerNormParamType<T>;
const int n = rows * cols;
correct_out.resize(n);
correct_mask.resize(n);
correct_dsrc.resize(n);
correct_dbias.resize(cols);
correct_means.resize(rows);
correct_vars.resize(rows);
correct_layernorm_out.resize(n);
src_vec.resize(n);
residual_vec.resize(n);
if (has_bias) {
bias_vec.resize(cols);
}
if (has_scale) {
scale_vec.resize(cols);
}
if (has_layernorm_bias) {
layernorm_bias_vec.resize(cols);
}
std::default_random_engine random(time(NULL));
std::uniform_real_distribution<float> dis(0.0, 1.0);
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
src_vec[i * cols + j] = static_cast<T>(dis(random));
residual_vec[i * cols + j] = static_cast<T>(dis(random));
if (i == 0) {
if (has_bias) {
bias_vec[j] = static_cast<T>(dis(random));
}
if (has_scale) {
scale_vec[j] = static_cast<U>(dis(random));
}
if (has_layernorm_bias) {
layernorm_bias_vec[j] = static_cast<U>(dis(random));
}
}
}
}
framework::TensorFromVector<T>(src_vec, *ctx, &src);
src.Resize({rows, cols});
framework::TensorFromVector<T>(residual_vec, *ctx, &residual);
residual.Resize({rows, cols});
if (has_bias) {
framework::TensorFromVector<T>(bias_vec, *ctx, &bias);
bias.Resize({cols});
}
if (has_scale) {
framework::TensorFromVector<U>(scale_vec, *ctx, &scale);
scale.Resize({cols});
}
if (has_layernorm_bias) {
framework::TensorFromVector<U>(layernorm_bias_vec, *ctx, &layernorm_bias);
layernorm_bias.Resize({cols});
}
{
out.Resize({rows, cols});
out.mutable_data<T>(place);
mask.Resize({rows, cols});
mask.mutable_data<uint8_t>(place);
means.Resize({rows});
means.mutable_data<U>(place);
vars.Resize({rows});
vars.mutable_data<U>(place);
layernorm_out.Resize({rows, cols});
layernorm_out.mutable_data<T>(place);
dsrc.Resize({rows, cols});
dsrc.mutable_data<T>(place);
if (has_bias) {
dbias.Resize({cols});
dbias.mutable_data<T>(place);
}
}
}
void BaseForward() {
using U = LayerNormParamType<T>;
std::vector<T> out1(rows * cols), out2(rows * cols);
if (has_bias) {
// add bias
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j];
}
}
// call dropout
Dropout<T>(out1, src.dims(), &out2, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test);
} else {
Dropout<T>(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test);
}
// add residual
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
correct_out[i * cols + j] =
residual_vec[i * cols + j] + out2[i * cols + j];
}
}
LayerNorm<T>(scale_vec, layernorm_bias_vec, correct_out, &correct_means,
&correct_vars, &correct_layernorm_out, epsilon, rows, cols,
*ctx);
ctx->Wait();
}
void FusedForward() {
using U = LayerNormParamType<T>;
int VecSize = MAX_CACHE_BYTES / sizeof(T);
if (cols % 4 != 0) {
VecSize = 1;
}
int threads = paddle::operators::GetDesiredBlockDim(cols / VecSize);
const int increment = ((cols - 1) / (threads * VecSize) + 1) * VecSize;
T *bias_ptr = nullptr;
U *scale_ptr = nullptr;
U *layernorm_bias_ptr = nullptr;
if (has_bias) {
bias_ptr = bias.data<T>();
}
if (has_scale) {
scale_ptr = scale.data<U>();
}
if (has_layernorm_bias) {
layernorm_bias_ptr = layernorm_bias.data<U>();
}
paddle::operators::LaunchLayernormResidualDropoutBias<T, uint8_t>(
rows, cols, increment, seed, dropout_prob, epsilon, is_upscale_in_train,
is_test, src.data<T>(), residual.data<T>(), bias_ptr, scale_ptr,
layernorm_bias_ptr, mask.data<uint8_t>(), out.data<T>(),
layernorm_out.data<T>(), means.data<U>(), vars.data<U>(), *ctx);
ctx->Wait();
}
void Run() {
SetUp();
BaseForward();
FusedForward();
}
void CheckOut(const T diff) {
using U = LayerNormParamType<T>;
const int n = rows * cols;
std::vector<T> _out(n), _layernorm_out(n);
std::vector<U> _means(rows), _vars(cols);
std::vector<uint8_t> _mask(n);
framework::TensorToVector(out, *ctx, &_out);
framework::TensorToVector(layernorm_out, *ctx, &_layernorm_out);
framework::TensorToVector(means, *ctx, &_means);
framework::TensorToVector(vars, *ctx, &_vars);
if (!is_test) {
framework::TensorToVector(mask, *ctx, &_mask);
}
ctx->Wait();
for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff);
EXPECT_LT(std::abs(_layernorm_out[i] - correct_layernorm_out[i]), diff);
if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]);
}
for (int i = 0; i < rows; i++) {
EXPECT_LT(std::abs(_means[i] - correct_means[i]), static_cast<U>(diff));
EXPECT_LT(std::abs(_vars[i] - correct_vars[i]), static_cast<U>(diff));
}
}
};
template <typename T>
static void BaseTest(const bool is_fp16 = false) {
const int rows = 16;
T default_diff = !is_fp16 ? static_cast<T>(1e-4) : static_cast<T>(1e-2);
for (auto cols : {16, 17}) {
for (auto has_bias : {true, false}) {
for (auto has_scale : {true, false}) {
for (auto has_layernorm_bias : {true, false}) {
TestFusedLayernormResidualDropoutBias<T> test(rows, cols);
test.has_bias = has_bias;
test.has_scale = has_scale;
test.has_layernorm_bias = has_layernorm_bias;
test.Run();
test.CheckOut(default_diff);
}
}
}
}
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBias) { BaseTest<float>(); }
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasDouble) {
BaseTest<double>();
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasFp16) {
BaseTest<platform::float16>(true);
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasIsUpscaleInTrain) {
const int rows = 16;
const int cols = 16;
for (auto is_upscale_in_train : {true, false}) {
TestFusedLayernormResidualDropoutBias<float> test(
rows, cols, 0, 1.0, 0.00001f, is_upscale_in_train, false);
test.Run();
test.CheckOut(static_cast<float>(1e-4));
}
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasIsTest) {
const int rows = 16;
const int cols = 16;
TestFusedLayernormResidualDropoutBias<float> test(rows, cols, 0, 0.35,
0.00001f, true, true);
test.Run();
test.CheckOut(static_cast<float>(1e-4));
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasSeed) {
const int rows = 16;
const int cols = 16;
TestFusedLayernormResidualDropoutBias<float> test(rows, cols, 125, 0.0,
0.00001f, false, false);
test.Run();
test.CheckOut(static_cast<float>(1e-4));
}
TEST(FusedDropout, GPUFusedLayernormResidualDropoutLargeShape) {
const int rows = 512;
const int cols = 512;
TestFusedLayernormResidualDropoutBias<float> test(rows, cols);
test.Run();
test.CheckOut(static_cast<float>(1e-4));
}
......@@ -23,14 +23,15 @@ namespace operators {
* @brief The fused function called by every thread
* VecSize can be 1, 2, 4 or 8
*/
template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm>
template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm,
bool Activation, typename Functor>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const int row_id, const int col_id, const int cols,
curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor,
const T *__restrict__ src, const T *__restrict__ residual,
const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test,
typename details::MPTypeTrait<T>::Type *mean_val,
typename details::MPTypeTrait<T>::Type *var_val) {
typename details::MPTypeTrait<T>::Type *var_val, Functor act_func) {
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
......@@ -42,10 +43,14 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
residual_vec[ii] = static_cast<T>(0);
}
// vectorize load data from global
platform::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
platform::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
if (residual) {
platform::Load<T, VecSize>(&residual[row_id * cols + col_id],
&residual_vec);
}
if (bias) {
platform::Load<T, VecSize>(&bias[col_id], &bias_vec);
......@@ -70,9 +75,12 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T tmp = src_vec[ii] + bias_vec[ii];
if (Activation) {
tmp = act_func(tmp);
}
dest_vec[ii] =
(src_vec[ii] + bias_vec[ii]) * static_cast<T>(mask_vec[ii]) * factor +
residual_vec[ii];
tmp * static_cast<T>(mask_vec[ii]) * factor + residual_vec[ii];
if (ComputeLayerNorm) {
U tmp = static_cast<U>(dest_vec[ii]);
*mean_val += tmp;
......@@ -106,19 +114,15 @@ __global__ void FusedResidualDropoutBias(
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = is_upscale_in_train ? static_cast<T>(1.0f / (1.0f - dropout_prob))
: static_cast<T>(1.0f);
if (is_test) {
factor = is_upscale_in_train ? static_cast<T>(1.0f)
: static_cast<T>(1.0f - dropout_prob);
}
const T factor = GetFactor<T>(dropout_prob, is_upscale_in_train, is_test);
math::ReluFunctor<T> relu;
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false>(
FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false, false,
math::ReluFunctor<T>>(
r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
mask, is_test, nullptr, nullptr);
mask, is_test, nullptr, nullptr, relu);
}
}
}
......
......@@ -165,6 +165,7 @@ struct TestFusedResidualDropoutBias {
auto config = paddle::operators::Get1DBlocksAnd2DGrids(
*ctx, static_cast<uint64_t>(rows), static_cast<uint64_t>(cols),
VecSize);
const int increment = ((cols - 1) / (config.thread_per_block.x *
config.block_per_grid.x * VecSize) +
1) *
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册