Unverified commit cee70434, authored by zhangkaihuo, committed by GitHub

add a fusion op: fused_dropout_act_bias (#35129)

Parent bab39eb2
......@@ -75,5 +75,6 @@ if (WITH_GPU OR WITH_ROCM)
# only support CUDA
if(NOT WITH_ROCM)
nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
endif()
endif()
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifndef _USE_MATH_DEFINES
#define _USE_MATH_DEFINES
#endif
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/math/functors.h"
namespace paddle {
namespace operators {
/**
* @brief the GELU functor
*/
template <typename T>
struct GeluFunctor {
inline __host__ __device__ T operator()(const T x) const {
using U = LayerNormParamType<T>;
const U casted_x = static_cast<U>(x);
const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
return static_cast<T>(out);
}
};
/**
* @brief the GELU grad functor
*/
template <typename T>
struct GeluGradFunctor {
inline __host__ __device__ T UseOut(const T x) const {
using U = LayerNormParamType<T>;
auto casted_x = static_cast<U>(x);
auto first =
static_cast<U>(0.5) *
(static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));
auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
exp(-static_cast<U>(0.5) * casted_x * casted_x);
return static_cast<T>((first + second));
}
};
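/* For reference, the expressions implemented by the two functors above are
*   gelu(x)  = 0.5 * x * (1 + erf(x / sqrt(2)))
*   gelu'(x) = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2 * pi)
* GeluGradFunctor::UseOut(x) returns gelu'(x); the backward kernels below
* multiply it by the incoming (dropout-scaled) gradient.
*/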
/**
* @brief dst = dropout(activation(src + bias));
* src, mask and dst have shape (rows, cols)
* bias has shape (1, cols)
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActBias(
Functor act, const uint64_t seed, const uint64_t rows, const uint64_t cols,
const int increment, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const T *__restrict__ src, const T *__restrict__ bias, T *dst,
MaskType *mask) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0);
}
if (is_test) {
factor = static_cast<T>(1.0f - dropout_prob);
if (is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
}
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
i += blockDim.x * gridDim.x * VecSize) {
LoadT src_vec;
LoadT bias_vec;
// vectorize load data from global
platform::Load<T, VecSize>(&src[r * cols + i], &src_vec);
if (bias) {
platform::Load<T, VecSize>(&bias[i], &bias_vec);
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
bias_vec[ii] = static_cast<T>(0);
}
}
MaskStoreT mask_vec;
if (!is_test) {
float rand[VecSize];
RandVec<VecSize>(&state, rand);
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
}
} else {
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(1);
}
}
StoreT dest_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
const T tmp = src_vec[ii] + bias_vec[ii];
const T act_out = act(tmp);
dest_vec[ii] = act_out * static_cast<T>(mask_vec[ii]) * factor;
}
// store result to global
platform::Store<T, VecSize>(dest_vec, &dst[r * cols + i]);
if (!is_test) {
platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
}
}
}
}
/**
* @brief dst = dropout(activation(src + bias));
*/
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBias(Functor act_functor, const uint64_t seed,
const uint32_t rows, const uint32_t cols,
const int increment, const float dropout_prob,
const bool is_upscale_in_train, const bool is_test,
const T *src, const T *bias, T *dst,
MaskType *mask_data,
const platform::CUDADeviceContext &ctx) {
// dropout_prob == 1.0f
if (std::abs(dropout_prob - 1.0f) < 1e-5) {
SetZero<T>(ctx, dst, rows * cols);
SetZero<MaskType>(ctx, mask_data, rows * cols);
return;
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
if (cols % VecSize == 0) {
FusedDropoutActBias<T, MaskType, VecSize, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, seed, rows, cols, increment, dropout_prob,
is_upscale_in_train, is_test, src, bias, dst, mask_data);
} else {
FusedDropoutActBias<T, MaskType, 1, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, seed, rows, cols, increment, dropout_prob,
is_upscale_in_train, is_test, src, bias, dst, mask_data);
}
}
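/* A minimal host-side usage sketch (illustrative only): it assumes a prepared
* CUDADeviceContext `dev_ctx` and device buffers `d_src`, `d_bias`, `d_out`
* and `d_mask` with the shapes documented above; `increment` would normally be
* computed from the launch configuration, as in the unit test's FusedForward.
*
*   GeluFunctor<float> gelu;
*   LaunchDropoutActBias<float, uint8_t, GeluFunctor<float>>(
*       gelu, seed, rows, cols, increment, dropout_prob, is_upscale_in_train,
*       is_test, d_src, d_bias, d_out, d_mask, dev_ctx);
*/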
/**
* @brief calculate the gradient when there is no bias
*/
template <typename T, typename MaskType, int VecSize, typename Functor>
__global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
const MaskType *mask, const T *src,
const T factor, const int64_t size, T *dx) {
int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
LoadT dout_vec;
LoadT src_vec;
MaskLoadT mask_vec;
platform::Load<T, VecSize>(&dout[i], &dout_vec);
platform::Load<MaskType, VecSize>(&mask[i], &mask_vec);
platform::Load<T, VecSize>(&src[i], &src_vec);
StoreT dx_vec;
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
T args[2];
args[0] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
args[1] = src_vec[ii];
dx_vec[ii] = args[0] * act_grad.UseOut(args[1]);
}
platform::Store<T, VecSize>(dx_vec, &dx[i]);
}
}
/**
* Each block has 8 * 128 threads.
* 1. calculate dx and reduce the total rows to 128 rows of partial sums
* 2. save the partial sums in (BlockSizeX * VecSize) x BlockSizeY shared memory
* 3. reduce the 128 partial sums of each column with 8 * VecSize warps
*/
template <typename T, typename MaskType, int BlockSizeX, int BlockSizeY,
int VecSize, typename Functor>
__global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout,
const MaskType *mask, const T *src,
const T *bias, const T factor,
const int64_t rows, const int64_t cols,
T *dx, T *dbias) {
int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
T tmp_sum[VecSize] = {static_cast<T>(0)};
// calculate the dx and temporary sum
if (col_id * VecSize < cols) {
for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
int index = row_id * cols + col_id * VecSize;
LoadT dout_vec;
LoadT src_vec;
LoadT bias_vec;
MaskLoadT mask_vec;
platform::Load<T, VecSize>(&dout[index], &dout_vec);
platform::Load<T, VecSize>(&src[index], &src_vec);
platform::Load<MaskType, VecSize>(&mask[index], &mask_vec);
platform::Load<T, VecSize>(&bias[col_id * VecSize], &bias_vec);
StoreT dx_vec;
#pragma unroll
for (int i = 0; i < VecSize; i++) {
T val;
T args[2];
args[0] = dout_vec[i] * static_cast<T>(mask_vec[i]) * factor;
args[1] = src_vec[i] + bias_vec[i];
val = args[0] * act_grad.UseOut(args[1]);
dx_vec[i] = val;
tmp_sum[i] += val;
}
platform::Store<T, VecSize>(dx_vec, &dx[index]);
}
}
CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}
/**
* @brief launch the kernel FusedDropoutActBiasGrad (or FusedDropoutActGrad when dbias is null)
*/
template <typename T, typename MaskType, typename Functor>
void LaunchDropoutActBiasGrad(Functor act_functor, const T *dout,
const MaskType *mask, const T *src, const T *bias,
const float dropout_prob,
const bool is_upscale_in_train,
const uint32_t rows, const uint32_t cols, T *dx,
T *dbias,
const platform::CUDADeviceContext &ctx) {
const T zero = static_cast<T>(0.0);
auto factor = dropout_prob == static_cast<float>(1.0f)
? zero
: static_cast<T>(1.0 / (1.0 - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
if (dbias != nullptr) {
const auto threads = 8;
const auto blocks =
std::max(static_cast<uint32_t>(1),
(cols / real_vec_size + threads - 1) / threads);
dim3 block_dim(threads, 128, 1);
dim3 grid_dim(blocks, 1, 1);
if (cols % VecSize == 0) {
FusedDropoutActBiasGrad<
T, MaskType, 8, 128, VecSize,
Functor><<<grid_dim, block_dim, 0, ctx.stream()>>>(
act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias);
} else {
FusedDropoutActBiasGrad<
T, MaskType, 8, 128, 1,
Functor><<<grid_dim, block_dim, 0, ctx.stream()>>>(
act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias);
}
} else {
const uint64_t n = rows * cols;
platform::GpuLaunchConfig config =
platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
if (n % VecSize == 0) {
FusedDropoutActGrad<T, MaskType, VecSize, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
} else {
FusedDropoutActGrad<T, MaskType, 1, Functor><<<
config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
act_functor, dout, mask, src, factor, n, dx);
}
}
}
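/* A minimal backward usage sketch under the same assumptions as the forward
* example (the d_* buffer names are illustrative); passing dbias == nullptr
* selects the no-bias path that launches FusedDropoutActGrad instead.
*
*   GeluGradFunctor<float> gelu_grad;
*   LaunchDropoutActBiasGrad<float, uint8_t, GeluGradFunctor<float>>(
*       gelu_grad, d_dout, d_mask, d_src, d_bias, dropout_prob,
*       is_upscale_in_train, rows, cols, d_dx, d_dbias, dev_ctx);
*/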
} // namespace operators
} // namespace paddle
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <time.h>
#include <random>
#include <vector>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/fluid/operators/fused/fused_dropout_test.h"
#include "paddle/fluid/operators/math/functors.h"
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace details = paddle::operators::details;
namespace math = paddle::operators::math;
/**
* @brief the unittest of fused_dropout_act_bias
* 1. generate random input data
* 2. add bias, apply the activation, call paddle dropout, and get the base result
* 3. call the FusedDropoutActBias function to get the fused result
* 4. compare the base result and the fused result
*/
template <typename T, typename Functor, typename GradFunctor>
struct TestFusedDropoutActBias {
uint32_t rows;
uint32_t cols;
uint64_t seed;
float dropout_prob;
bool is_upscale_in_train;
bool is_test; // default false; set to true for inference only
bool has_bias = true;
framework::Tensor src, bias, out, mask;
framework::Tensor dsrc, dbias;
std::vector<T> src_vec, bias_vec, out_vec, mask_vec;
std::vector<T> correct_out, correct_dsrc, correct_dbias;
std::vector<uint8_t> correct_mask;
platform::CUDAPlace place;
platform::CUDADeviceContext *ctx;
TestFusedDropoutActBias() {
rows = 32;
cols = 32;
seed = 0;
dropout_prob = 0.0;
is_upscale_in_train = false;
is_test = false;
has_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto devicectx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(devicectx);
}
TestFusedDropoutActBias(int rows_, int cols_, uint64_t seed_ = 0,
float dropout_prob_ = 0.0,
bool is_upscale_in_train_ = false,
bool is_test_ = false) {
rows = rows_;
cols = cols_;
seed = seed_;
dropout_prob = dropout_prob_;
is_upscale_in_train = is_upscale_in_train_;
is_test = is_test_;
has_bias = true;
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto devicectx = pool.Get(place);
ctx = reinterpret_cast<platform::CUDADeviceContext *>(devicectx);
}
~TestFusedDropoutActBias() {}
void SetUp() {
const int n = rows * cols;
correct_out.resize(n);
correct_mask.resize(n);
correct_dsrc.resize(n);
correct_dbias.resize(cols);
src_vec.resize(n);
bias_vec.resize(cols);
std::default_random_engine random(time(NULL));
std::uniform_real_distribution<float> dis(0.0, 1.0);
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
src_vec[i * cols + j] = static_cast<T>(dis(random));
if (i == 0) bias_vec[j] = dis(random);
}
}
framework::TensorFromVector<T>(src_vec, *ctx, &src);
src.Resize({rows, cols});
if (has_bias) {
framework::TensorFromVector<T>(bias_vec, *ctx, &bias);
bias.Resize({cols});
}
{
out.mutable_data<T>({rows, cols}, place);
mask.mutable_data<uint8_t>({rows, cols}, place);
dsrc.mutable_data<T>({rows, cols}, place);
if (has_bias) {
dbias.mutable_data<T>({cols}, place);
}
}
}
void BaseForward() {
std::vector<T> out1(rows * cols);
Functor act;
if (has_bias) {
// add bias and call activation
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
const T tmp = src_vec[i * cols + j] + bias_vec[j];
out1[i * cols + j] = act(tmp);
}
}
// call dropout
Dropout<T>(out1, src.dims(), &correct_out, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test);
} else {
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
const T tmp = src_vec[i * cols + j];
out1[i * cols + j] = act(tmp);
}
}
Dropout<T>(out1, src.dims(), &correct_out, &correct_mask, *ctx, seed,
dropout_prob, is_upscale_in_train, is_test);
}
ctx->Wait();
}
void BaseBackward() {
std::vector<T> _out(rows * cols);
// call dropout_grad
DropoutGrad<T>(&_out, src.dims(), correct_out, correct_mask, *ctx,
dropout_prob, is_upscale_in_train);
// calculate dbias
memset(&correct_dbias[0], 0, cols * sizeof(T));
GradFunctor act_grad;
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
T args[2];
args[0] = _out[i * cols + j];
if (has_bias) {
args[1] = src_vec[i * cols + j] + bias_vec[j];
} else {
args[1] = src_vec[i * cols + j];
}
T val = args[0] * act_grad.UseOut(args[1]);
correct_dsrc[i * cols + j] = val;
}
}
if (has_bias) {
// reduce_sum: keep the same accumulation order as the GPU
ReduceSum<T>(correct_dsrc, &correct_dbias, rows, cols);
}
}
void FusedForward() {
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
auto config = paddle::operators::Get1DBlocksAnd2DGrids(
*ctx, static_cast<uint64_t>(rows), static_cast<uint64_t>(cols),
VecSize);
const int increment = ((cols - 1) / (config.thread_per_block.x *
config.block_per_grid.x * VecSize) +
1) *
VecSize;
T *bias_ptr = nullptr;
if (has_bias) {
bias_ptr = bias.data<T>();
}
Functor act;
paddle::operators::LaunchDropoutActBias<T, uint8_t, Functor>(
act, seed, rows, cols, increment, dropout_prob, is_upscale_in_train,
is_test, src.data<T>(), bias_ptr, out.data<T>(), mask.data<uint8_t>(),
*ctx);
ctx->Wait();
}
void FusedBackward() {
if (is_test) return;
T *bias_ptr = nullptr;
T *dbias_ptr = nullptr;
if (has_bias) {
dbias_ptr = dbias.data<T>();
bias_ptr = bias.data<T>();
}
GradFunctor act_grad;
paddle::operators::LaunchDropoutActBiasGrad<T, uint8_t, GradFunctor>(
act_grad, out.data<T>(), mask.data<uint8_t>(), src.data<T>(), bias_ptr,
dropout_prob, is_upscale_in_train, rows, cols, dsrc.data<T>(),
dbias_ptr, *ctx);
}
void Run() {
SetUp();
BaseForward();
FusedForward();
BaseBackward();
FusedBackward();
}
void CheckOut(const T diff) {
const int n = rows * cols;
std::vector<T> _out(n);
std::vector<uint8_t> _mask(n);
framework::TensorToVector(out, *ctx, &_out);
if (!is_test) {
framework::TensorToVector<uint8_t>(mask, *ctx, &_mask);
}
ctx->Wait();
for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff);
if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]);
}
}
void CheckGrad(const T diff) {
if (is_test) return;
const int n = rows * cols;
std::vector<T> _dsrc(n);
framework::TensorToVector(dsrc, *ctx, &_dsrc);
for (int i = 0; i < n; i++) {
EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff);
}
if (has_bias) {
std::vector<T> _dbias(cols);
framework::TensorToVector(dbias, *ctx, &_dbias);
ctx->Wait();
for (int i = 0; i < cols; i++) {
EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff);
}
}
}
};
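// Example of driving the harness directly (the shape and dropout settings
// here are illustrative; BaseTest below sweeps cols and has_bias):
//   TestFusedDropoutActBias<float, math::ReluFunctor<float>,
//                           math::ReluGradFunctor<float>>
//       test(/*rows=*/16, /*cols=*/16, /*seed=*/0, /*dropout_prob=*/0.35,
//            /*is_upscale_in_train=*/true, /*is_test=*/false);
//   test.Run();
//   test.CheckOut(static_cast<float>(1e-5));
//   test.CheckGrad(static_cast<float>(1e-3));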
// test the shape, bias, and activation
template <typename T, typename Functor, typename GradFunctor>
static void BaseTest(const bool is_fp16 = false) {
const int rows = 16;
std::vector<int> cols_list = {16, 17};
bool has_bias[2] = {true, false};
T default_diff = !is_fp16 ? static_cast<T>(1e-5) : static_cast<T>(1e-1);
for (auto cols : {16, 17}) {
for (auto has_bias : {true, false}) {
TestFusedDropoutActBias<T, Functor, GradFunctor> test(rows, cols);
test.has_bias = has_bias;
test.Run();
test.CheckOut(default_diff);
test.CheckGrad(default_diff);
}
}
}
TEST(FusedDropout, GPUFusedDropoutActBias) {
BaseTest<float, math::ReluFunctor<float>, math::ReluGradFunctor<float>>();
BaseTest<float, paddle::operators::GeluFunctor<float>,
paddle::operators::GeluGradFunctor<float>>();
}
TEST(FusedDropout, GPUFusedDropoutActBiasDouble) {
BaseTest<double, math::ReluFunctor<double>, math::ReluGradFunctor<double>>();
BaseTest<double, paddle::operators::GeluFunctor<double>,
paddle::operators::GeluGradFunctor<double>>();
}
// test fp16. For inference, check_grad is not required. ref: test_dropout_op.py
TEST(FusedDropout, GPUFusedDropoutActBiasFp16) {
using fp16 = platform::float16;
BaseTest<fp16, math::ReluFunctor<fp16>, math::ReluGradFunctor<fp16>>(true);
}
TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) {
const int rows = 16;
const int cols = 16;
for (auto is_upscale_in_train : {true, false}) {
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
test(rows, cols, 0, 1.0, is_upscale_in_train, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3));
}
}
TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) {
const int rows = 16;
const int cols = 16;
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
test(rows, cols, 0, 0.35, true, true);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3));
}
TEST(FusedDropout, GPUFusedDropoutActBiasSeed) {
const int rows = 16;
const int cols = 16;
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
test(rows, cols, 125, 0.0, false, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3));
}
TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) {
const int rows = 256;
const int cols = 4096;
TestFusedDropoutActBias<float, math::ReluFunctor<float>,
math::ReluGradFunctor<float>>
test(rows, cols);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-3));
}
......@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
#include "paddle/fluid/platform/aligned_vector.h"
#include "paddle/fluid/platform/cuda_device_function.h"
#include "paddle/fluid/platform/device_context.h"
......@@ -39,8 +40,8 @@ namespace operators {
*/
inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
const platform::CUDADeviceContext &ctx, const uint32_t rows,
const uint32_t cols, const int VecSize) {
const uint32_t tmp_cols = cols / VecSize;
const uint32_t cols, const int vec_size) {
const uint32_t tmp_cols = cols / vec_size;
int threads = std::max(
static_cast<uint32_t>(32),
std::min(tmp_cols, static_cast<uint32_t>(ctx.GetMaxThreadsPerBlock())));
......@@ -54,19 +55,26 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
return config;
}
__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state,
float *data) {
template <int VecSize>
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
float *data);
template <>
__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
}
__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state,
float *data) {
template <>
__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state,
float *data) {
data[0] = curand_uniform(state);
data[1] = curand_uniform(state);
}
__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state,
float *data) {
template <>
__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state,
float *data) {
float4 rand4 = curand_uniform4(state);
data[0] = rand4.x;
data[1] = rand4.y;
......@@ -74,24 +82,54 @@ __forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state,
data[3] = rand4.z;
}
__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state,
float *data) {
Rand4(state, data);
Rand4(state, data + 4);
template <>
__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state,
float *data) {
RandVec<4>(state, data);
RandVec<4>(state, data + 4);
}
__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
float *data, const int VecSize) {
if (VecSize == 1) {
Rand1(state, data);
} else if (VecSize == 2) {
Rand2(state, data);
} else if (VecSize == 4) {
Rand4(state, data);
} else if (VecSize == 8) {
Rand8(state, data);
} else {
return;
template <typename T>
inline void SetZero(const platform::CUDADeviceContext &ctx, T *ptr,
const size_t size) {
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream()));
}
/**
* reduce the 128 partial sums of each column with 8 * VecSize warps
*/
template <typename T, int VecSize, int BlockSizeX, int BlockSizeY>
inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias,
const int cols) {
// save temporary sum to cache and do transpose
__shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
for (int i = 0; i < VecSize; i++) {
cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
}
__syncthreads();
// reduce sum
T sum = static_cast<T>(0);
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 5; // warp id
int y = tid & 31; // thread id on warp 0~31
// need BlockSizeX * VecSize warps
if (x < BlockSizeX * VecSize) {
// reduce 128 to 32
#pragma unroll
for (int i = 0; i < (BlockSizeY >> 5); i++) {
sum += cache[x][y + i * 32];
}
}
// reduce 32 to 1
sum = WarpReduceSum(sum);
// save sum to dbias
int bias_id = blockIdx.x * blockDim.x * VecSize + x;
if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) {
dbias[bias_id] = sum;
}
}
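/* Worked example, assuming BlockSizeX = 8, BlockSizeY = 128 and VecSize = 4:
* the block covers 8 * 4 = 32 bias columns and the shared cache is 32 x 128.
* With 8 * 128 = 1024 threads, warp x = tid / 32 handles cache row x (one
* bias column): lane y = tid % 32 first folds that row's 128 values into 32
* partial sums (BlockSizeY / 32 = 4 iterations), then WarpReduceSum folds the
* 32 lanes into the final dbias[blockIdx.x * 32 + x].
*/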
......
......@@ -115,3 +115,22 @@ void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
framework::TensorToVector(*tensor_dx, ctx, dx);
ctx.Wait();
}
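// Reference column-wise reduction used by the tests. It folds the rows
// pairwise (a tree reduction) to keep the accumulation order close to the
// GPU kernels'; note the halving loop assumes `rows` is a power of two,
// which holds for every shape used in these tests.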
template <typename T>
inline void ReduceSum(const std::vector<T> &dout, std::vector<T> *dbias,
const int rows, const int cols) {
for (int j = 0; j < cols; j++) {
std::vector<T> tmp_dbias(rows);
for (int i = 0; i < rows; i++) {
tmp_dbias[i] = dout[i * cols + j];
}
int tmp_rows = rows / 2;
while (tmp_rows) {
for (int i = 0; i < tmp_rows; i++) {
tmp_dbias[i] += tmp_dbias[i + tmp_rows];
}
tmp_rows /= 2;
}
(*dbias)[j] = tmp_dbias[0];
}
}
......@@ -15,7 +15,6 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
namespace paddle {
namespace operators {
......@@ -28,8 +27,9 @@ template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm>
__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
const int row_id, const int col_id, const int cols,
curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor,
const T *src, const T *residual, const T *bias, T *dst, MaskType *mask,
const bool is_test, typename details::MPTypeTrait<T>::Type *mean_val,
const T *__restrict__ src, const T *__restrict__ residual,
const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test,
typename details::MPTypeTrait<T>::Type *mean_val,
typename details::MPTypeTrait<T>::Type *var_val) {
using LoadT = platform::AlignedVector<T, VecSize>;
using StoreT = platform::AlignedVector<T, VecSize>;
......@@ -54,7 +54,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
MaskStoreT mask_vec;
if (!is_test) {
float rand[VecSize];
RandVec(state, rand, VecSize);
RandVec<VecSize>(state, rand);
#pragma unroll
for (int ii = 0; ii < VecSize; ii++) {
mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
......@@ -97,24 +97,21 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
template <typename T, typename MaskType, int VecSize>
__global__ void FusedResidualDropoutBias(
const size_t rows, const size_t cols, uint64_t seed,
const float dropout_prob, const bool is_upscale_in_train, const T *src,
const T *residual, const T *bias, MaskType *mask, T *dst,
uint64_t increment, const bool is_test) {
const float dropout_prob, const bool is_upscale_in_train,
const T *__restrict__ src, const T *__restrict__ residual,
const T *__restrict__ bias, MaskType *mask, T *dst, uint64_t increment,
const bool is_test) {
int col_id = blockDim.x * blockIdx.x + threadIdx.x;
int row_id = blockIdx.y;
int idx = row_id * cols + col_id;
curandStatePhilox4_32_10_t state;
curand_init(seed, idx, increment, &state);
T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
if (!is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
T factor = is_upscale_in_train ? static_cast<T>(1.0f / (1.0f - dropout_prob))
: static_cast<T>(1.0f);
if (is_test) {
factor = static_cast<T>(1.0f - dropout_prob);
if (is_upscale_in_train) {
factor = static_cast<T>(1.0f);
}
factor = is_upscale_in_train ? static_cast<T>(1.0f)
: static_cast<T>(1.0f - dropout_prob);
}
for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
for (int i = col_id * VecSize; i < cols;
......@@ -144,8 +141,7 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T),
ctx.stream());
if (!is_test) {
PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
SetZero<MaskType>(ctx, mask_data, rows * cols);
}
return;
}
......@@ -234,36 +230,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout,
}
}
// save temporary sum to cache and do transpose
__shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
for (int i = 0; i < VecSize; i++) {
cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
}
__syncthreads();
// reduce sum
T sum = static_cast<T>(0);
int tid = threadIdx.y * blockDim.x + threadIdx.x;
int x = tid >> 5; // warp id
int y = tid & 31; // thread id on warp 0~31
// need BlockSizeX * VecSize warps
if (x < BlockSizeX * VecSize) {
// reduce 128 to 32
#pragma unroll
for (int i = 0; i < (BlockSizeY >> 5); i++) {
sum += cache[x][y + i * 32];
}
}
// reduce 32 to 1
sum = WarpReduceSum(sum);
// save sum to dbias
int bias_id = blockIdx.x * blockDim.x * VecSize + x;
if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) {
dbias[bias_id] = sum;
}
CalculateDBias<T, VecSize, BlockSizeX, BlockSizeY>(tmp_sum, dbias, cols);
}
/**
......@@ -287,9 +254,9 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask,
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
if (dbias != nullptr) {
auto threads = std::min(cols / real_vec_size, static_cast<uint32_t>(8));
auto blocks =
std::max((uint32_t)1, (cols / real_vec_size + threads - 1) / threads);
const auto threads = 8;
auto blocks = std::max(static_cast<uint32_t>(1),
(cols / real_vec_size + threads - 1) / threads);
dim3 block_dim(threads, 128, 1);
dim3 grid_dim(blocks, 1, 1);
if (cols % VecSize == 0) {
......
......@@ -114,16 +114,12 @@ struct TestFusedResidualDropoutBias {
}
{
out.Resize({rows, cols});
out.mutable_data<T>(place);
mask.Resize({rows, cols});
mask.mutable_data<uint8_t>(place);
dsrc.Resize({rows, cols});
dsrc.mutable_data<T>(place);
out.mutable_data<T>({rows, cols}, place);
mask.mutable_data<uint8_t>({rows, cols}, place);
dsrc.mutable_data<T>({rows, cols}, place);
if (has_bias) {
dbias.Resize({cols});
dbias.mutable_data<T>(place);
dbias.mutable_data<T>({cols}, place);
}
}
}
......@@ -159,17 +155,16 @@ struct TestFusedResidualDropoutBias {
dropout_prob, is_upscale_in_train);
// calc dbias
memset(&correct_dbias[0], 0, cols * sizeof(T));
for (int i = 0; i < rows; i++) {
for (int j = 0; j < cols; j++) {
correct_dbias[j] += correct_out[i * cols + j];
}
if (has_bias) {
ReduceSum<T>(correct_out, &correct_dbias, rows, cols);
}
}
void FusedForward() {
const int VecSize = MAX_CACHE_BYTES / sizeof(T);
auto config = paddle::operators::Get1DBlocksAnd2DGrids(
*ctx, (uint64_t)rows, (uint64_t)cols, VecSize);
*ctx, static_cast<uint64_t>(rows), static_cast<uint64_t>(cols),
VecSize);
const int increment = ((cols - 1) / (config.thread_per_block.x *
config.block_per_grid.x * VecSize) +
1) *
......@@ -253,21 +248,14 @@ struct TestFusedResidualDropoutBias {
template <typename T>
static void BaseTest(const bool is_fp16 = false) {
const int rows = 16;
std::vector<int> cols_list = {16, 17};
bool has_bias[2] = {true, false};
T default_diff = static_cast<T>(1e-5);
if (is_fp16) {
default_diff = static_cast<T>(1e-2);
}
for (int i = 0; i < cols_list.size(); i++) {
for (int j = 0; j < 2; j++) {
TestFusedResidualDropoutBias<T> test(rows, cols_list[i]);
test.has_bias = has_bias[j];
T default_diff = !is_fp16 ? static_cast<T>(1e-5) : static_cast<T>(1e-1);
for (auto cols : {16, 17}) {
for (auto has_bias : {true, false}) {
TestFusedResidualDropoutBias<T> test(rows, cols);
test.has_bias = has_bias;
test.Run();
test.CheckOut(default_diff);
if (!is_fp16) {
test.CheckGrad(default_diff);
}
test.CheckGrad(default_diff);
}
}
}
......@@ -276,30 +264,23 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); }
TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest<double>(); }
// test fp16. For inference, check_grad is not required. ref: test_dropout_op.py
TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) {
BaseTest<platform::float16>(true);
}
TEST(FusedDropout, GPUFusedResidualDropoutBias2) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) {
const int rows = 16;
const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0, false, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5));
}
TEST(FusedDropout, GPUFusedResidualDropoutBias3) {
const int rows = 16;
const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0, true, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5));
for (auto is_upscale_in_train : {true, false}) {
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0,
is_upscale_in_train, false);
test.Run();
test.CheckOut(static_cast<float>(1e-5));
test.CheckGrad(static_cast<float>(1e-5));
}
}
TEST(FusedDropout, GPUFusedResidualDropoutBias4) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) {
const int rows = 16;
const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 0, 0.35, true, true);
......@@ -308,7 +289,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias4) {
test.CheckGrad(static_cast<float>(1e-5));
}
TEST(FusedDropout, GPUFusedResidualDropoutBias5) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) {
const int rows = 16;
const int cols = 16;
TestFusedResidualDropoutBias<float> test(rows, cols, 125, 0.0, false, false);
......@@ -317,8 +298,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias5) {
test.CheckGrad(static_cast<float>(1e-5));
}
// test large shape
TEST(FusedDropout, GPUFusedResidualDropoutBias6) {
TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) {
const int rows = 256;
const int cols = 4096;
TestFusedResidualDropoutBias<float> test(rows, cols);
......