diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 3df2144aa3594427563b0754ce8cc2f567188734..0a12735acf2a05c5d901fda659a3664500483ae1 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -75,5 +75,6 @@ if (WITH_GPU OR WITH_ROCM)
     # only support CUDA
     if(NOT WITH_ROCM)
         nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
+        nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
     endif()
 endif()
diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
new file mode 100755
index 0000000000000000000000000000000000000000..7d815bb8c39933ed9f9efe073bde0dabeac8185f
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h
@@ -0,0 +1,317 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+#ifndef _USE_MATH_DEFINES
+#define _USE_MATH_DEFINES
+#endif
+
+#include "paddle/fluid/operators/fused/fused_dropout_common.h"
+#include "paddle/fluid/operators/math/functors.h"
+
+namespace paddle {
+namespace operators {
+
+/**
+ * @brief the gelu functor
+ */
+template <typename T>
+struct GeluFunctor {
+  inline __host__ __device__ T operator()(const T x) const {
+    using U = LayerNormParamType<T>;
+    const U casted_x = static_cast<U>(x);
+    const U temp = erf(casted_x * static_cast<U>(M_SQRT1_2));
+    const U out = (casted_x * static_cast<U>(0.5) * (static_cast<U>(1) + temp));
+    return static_cast<T>(out);
+  }
+};
+
+/**
+ * @brief the gelu grad functor
+ */
+template <typename T>
+struct GeluGradFunctor {
+  inline __host__ __device__ T UseOut(const T x) const {
+    using U = LayerNormParamType<T>;
+    auto casted_x = static_cast<U>(x);
+
+    auto first =
+        static_cast<U>(0.5) *
+        (static_cast<U>(1) + erf(casted_x * static_cast<U>(M_SQRT1_2)));
+
+    auto second = static_cast<U>(0.5 * M_2_SQRTPI * M_SQRT1_2) * casted_x *
+                  exp(-static_cast<U>(0.5) * casted_x * casted_x);
+    return static_cast<T>((first + second));
+  }
+};
+
+/**
+ * @brief dst = dropout(activation(src + bias));
+ * the shapes of src, mask and dst are (rows, cols)
+ * the bias shape is (1, cols)
+ */
+template <typename T, typename MaskType, int VecSize, typename Functor>
+__global__ void FusedDropoutActBias(
+    Functor act, const uint64_t seed, const uint64_t rows, const uint64_t cols,
+    const int increment, const float dropout_prob,
+    const bool is_upscale_in_train, const bool is_test,
+    const T *__restrict__ src, const T *__restrict__ bias, T *dst,
+    MaskType *mask) {
+  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
+  int row_id = blockIdx.y;
+  int idx = row_id * cols + col_id;
+
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx, increment, &state);
+
+  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  if (!is_upscale_in_train) {
+    factor = static_cast<T>(1.0f);
+  }
+  if (is_test) {
+    factor = static_cast<T>(1.0f - dropout_prob);
+    if (is_upscale_in_train) {
+      factor = static_cast<T>(1.0f);
+    }
+  }
+
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using StoreT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+  using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
+
+  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
+    for (int i = col_id * VecSize; i < cols;
+         i += blockDim.x * gridDim.x * VecSize) {
+      LoadT src_vec;
+      LoadT bias_vec;
+      // vectorized load from global memory
+      platform::Load<T, VecSize>(&src[r * cols + i], &src_vec);
+
+      if (bias) {
+        platform::Load<T, VecSize>(&bias[i], &bias_vec);
+      } else {
+#pragma unroll
+        for (int ii = 0; ii < VecSize; ii++) {
+          bias_vec[ii] = static_cast<T>(0);
+        }
+      }
+
+      MaskStoreT mask_vec;
+      if (!is_test) {
+        float rand[VecSize];
+        RandVec<VecSize>(&state, rand);
+#pragma unroll
+        for (int ii = 0; ii < VecSize; ii++) {
+          mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
+        }
+      } else {
+#pragma unroll
+        for (int ii = 0; ii < VecSize; ii++) {
+          mask_vec[ii] = static_cast<MaskType>(1);
+        }
+      }
+
+      StoreT dest_vec;
+#pragma unroll
+      for (int ii = 0; ii < VecSize; ii++) {
+        const T tmp = src_vec[ii] + bias_vec[ii];
+        const T act_out = act(tmp);
+        dest_vec[ii] = act_out * static_cast<T>(mask_vec[ii]) * factor;
+      }
+      // store the result to global memory
+      platform::Store<T, VecSize>(dest_vec, &dst[r * cols + i]);
+      if (!is_test) {
+        platform::Store<MaskType, VecSize>(mask_vec, &mask[r * cols + i]);
+      }
+    }
+  }
+}
+
+/**
+ * @brief dst = dropout(activation(src + bias));
+ */
+template <typename T, typename MaskType, typename Functor>
+void LaunchDropoutActBias(Functor act_functor, const uint64_t seed,
+                          const uint32_t rows, const uint32_t cols,
+                          const int increment, const float dropout_prob,
+                          const bool is_upscale_in_train, const bool is_test,
+                          const T *src, const T *bias, T *dst,
+                          MaskType *mask_data,
+                          const platform::CUDADeviceContext &ctx) {
+  // dropout_prob == 1.0f: everything is dropped, so dst and mask are all zero
+  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
+    SetZero<T>(ctx, dst, rows * cols);
+    SetZero<MaskType>(ctx, mask_data, rows * cols);
+    return;
+  }
+
+  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
+  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
+  const auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
+  if (cols % VecSize == 0) {
+    FusedDropoutActBias<T, MaskType, VecSize, Functor><<<
+        config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+        act_functor, seed, rows, cols, increment, dropout_prob,
+        is_upscale_in_train, is_test, src, bias, dst, mask_data);
+  } else {
+    FusedDropoutActBias<T, MaskType, 1, Functor><<<
+        config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+        act_functor, seed, rows, cols, increment, dropout_prob,
+        is_upscale_in_train, is_test, src, bias, dst, mask_data);
+  }
+}
+
+/*
+ * @brief calculate the gradient of dx when there is no bias
+ */
+template <typename T, typename MaskType, int VecSize, typename Functor>
+__global__ void FusedDropoutActGrad(Functor act_grad, const T *dout,
+                                    const MaskType *mask, const T *src,
+                                    const T factor, const int64_t size, T *dx) {
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using StoreT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+  for (int i = idx * VecSize; i < size; i += blockDim.x * gridDim.x * VecSize) {
+    LoadT dout_vec;
+    LoadT src_vec;
+    MaskLoadT mask_vec;
+
+    platform::Load<T, VecSize>(&dout[i], &dout_vec);
+    platform::Load<MaskType, VecSize>(&mask[i], &mask_vec);
+    platform::Load<T, VecSize>(&src[i], &src_vec);
+
+    StoreT dx_vec;
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      T args[2];
+      args[0] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
+      args[1] = src_vec[ii];
+      dx_vec[ii] = args[0] * act_grad.UseOut(args[1]);
+    }
+    platform::Store<T, VecSize>(dx_vec, &dx[i]);
+  }
+}
+
+/**
+ * blocks(128 * 8)
+ * 1.
calculate the dx and reduce total rows to 128 rows + * 2. save 128*8 temporary sum in 8*128 shared memory + * 3. reduce the sum of 128 cols data by 8*VecSize warps + */ +template +__global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, + const MaskType *mask, const T *src, + const T *bias, const T factor, + const int64_t rows, const int64_t cols, + T *dx, T *dbias) { + int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x; + + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using MaskLoadT = platform::AlignedVector; + T tmp_sum[VecSize] = {static_cast(0)}; + // calculate the dx and temporary sum + if (col_id * VecSize < cols) { + for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) { + int index = row_id * cols + col_id * VecSize; + LoadT dout_vec; + LoadT src_vec; + LoadT bias_vec; + MaskLoadT mask_vec; + + platform::Load(&dout[index], &dout_vec); + platform::Load(&src[index], &src_vec); + platform::Load(&mask[index], &mask_vec); + platform::Load(&bias[col_id * VecSize], &bias_vec); + + StoreT dx_vec; +#pragma unroll + for (int i = 0; i < VecSize; i++) { + T val; + T args[2]; + args[0] = dout_vec[i] * static_cast(mask_vec[i]) * factor; + args[1] = src_vec[i] + bias_vec[i]; + val = args[0] * act_grad.UseOut(args[1]); + dx_vec[i] = val; + tmp_sum[i] += val; + } + platform::Store(dx_vec, &dx[index]); + } + } + + CalculateDBias(tmp_sum, dbias, cols); +} + +/** + * @brief to launch kernel FusedResidualDropoutBiasGradVec + */ +template +void LaunchDropoutActBiasGrad(Functor act_functor, const T *dout, + const MaskType *mask, const T *src, const T *bias, + const float dropout_prob, + const bool is_upscale_in_train, + const uint32_t rows, const uint32_t cols, T *dx, + T *dbias, + const platform::CUDADeviceContext &ctx) { + const T zero = static_cast(0.0); + auto factor = dropout_prob == static_cast(1.0f) + ? zero + : static_cast(1.0 / (1.0 - dropout_prob)); + if (!is_upscale_in_train) { + factor = static_cast(1.0f); + } + + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + int real_vec_size = cols % VecSize == 0 ? 
VecSize : 1; + + if (dbias != nullptr) { + const auto threads = 8; + const auto blocks = + std::max(static_cast(1), + (cols / real_vec_size + threads - 1) / threads); + dim3 block_dim(threads, 128, 1); + dim3 grid_dim(blocks, 1, 1); + if (cols % VecSize == 0) { + FusedDropoutActBiasGrad< + T, MaskType, 8, 128, VecSize, + Functor><<>>( + act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias); + } else { + FusedDropoutActBiasGrad< + T, MaskType, 8, 128, 1, + Functor><<>>( + act_functor, dout, mask, src, bias, factor, rows, cols, dx, dbias); + } + } else { + const uint64_t n = rows * cols; + platform::GpuLaunchConfig config = + platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size); + if (n % VecSize == 0) { + FusedDropoutActGrad<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + act_functor, dout, mask, src, factor, n, dx); + } else { + FusedDropoutActGrad<<< + config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>( + act_functor, dout, mask, src, factor, n, dx); + } + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu new file mode 100755 index 0000000000000000000000000000000000000000..0adbf0be4e28aa1d95b92a273f2a78851ca196ed --- /dev/null +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias_test.cu @@ -0,0 +1,346 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include + +#include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h" +#include "paddle/fluid/operators/fused/fused_dropout_test.h" +#include "paddle/fluid/operators/math/functors.h" + +namespace framework = paddle::framework; +namespace platform = paddle::platform; +namespace details = paddle::operators::details; +namespace math = paddle::operators::math; + +/** + * @brief the unittest of fused_dropout_act_bias + * 1. random input data + * 2. add bias, call activation, call paddle dropout, and get the base result + * 3. call FusedDropoutActBias function get fused result + * 4. 
compare ther base result and fused result + */ + +template +struct TestFusedDropoutActBias { + uint32_t rows; + uint32_t cols; + uint64_t seed; + float dropout_prob; + bool is_upscale_in_train; + bool is_test; // default false, Set to true for inference only + bool has_bias = true; + framework::Tensor src, bias, out, mask; + framework::Tensor dsrc, dbias; + + std::vector src_vec, bias_vec, out_vec, mask_vec; + std::vector correct_out, correct_dsrc, correct_dbias; + std::vector correct_mask; + + platform::CUDAPlace place; + platform::CUDADeviceContext *ctx; + + TestFusedDropoutActBias() { + rows = 32; + cols = 32; + seed = 0; + dropout_prob = 0.0; + is_upscale_in_train = false; + is_test = false; + has_bias = true; + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); + } + + TestFusedDropoutActBias(int rows_, int cols_, uint64_t seed_ = 0, + float dropout_prob_ = 0.0, + bool is_upscale_in_train_ = false, + bool is_test_ = false) { + rows = rows_; + cols = cols_; + seed = seed_; + dropout_prob = dropout_prob_; + is_upscale_in_train = is_upscale_in_train_; + is_test = is_test_; + has_bias = true; + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); + } + + ~TestFusedDropoutActBias() {} + + void SetUp() { + const int n = rows * cols; + correct_out.resize(n); + correct_mask.resize(n); + correct_dsrc.resize(n); + correct_dbias.resize(cols); + + src_vec.resize(n); + bias_vec.resize(cols); + std::default_random_engine random(time(NULL)); + std::uniform_real_distribution dis(0.0, 1.0); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + src_vec[i * cols + j] = static_cast(dis(random)); + if (i == 0) bias_vec[j] = dis(random); + } + } + + framework::TensorFromVector(src_vec, *ctx, &src); + src.Resize({rows, cols}); + if (has_bias) { + framework::TensorFromVector(bias_vec, *ctx, &bias); + bias.Resize({cols}); + } + + { + out.mutable_data({rows, cols}, place); + mask.mutable_data({rows, cols}, place); + dsrc.mutable_data({rows, cols}, place); + + if (has_bias) { + dbias.mutable_data({cols}, place); + } + } + } + + void BaseForward() { + std::vector out1(rows * cols); + Functor act; + if (has_bias) { + // add bias and call activation + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + const T tmp = src_vec[i * cols + j] + bias_vec[j]; + out1[i * cols + j] = act(tmp); + } + } + // call dropout + Dropout(out1, src.dims(), &correct_out, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); + } else { + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + const T tmp = src_vec[i * cols + j]; + out1[i * cols + j] = act(tmp); + } + } + + Dropout(out1, src.dims(), &correct_out, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); + } + ctx->Wait(); + } + + void BaseBackward() { + std::vector _out(rows * cols); + // call dropout_grad + DropoutGrad(&_out, src.dims(), correct_out, correct_mask, *ctx, + dropout_prob, is_upscale_in_train); + + // calculate dbias + memset(&correct_dbias[0], 0, cols * sizeof(T)); + GradFunctor act_grad; + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + T args[2]; + args[0] = _out[i * cols + j]; + if (has_bias) { + args[1] = src_vec[i * cols + j] + bias_vec[j]; + } else { + args[1] = src_vec[i * cols + j]; + } + T val = args[0] * act_grad.UseOut(args[1]); + 
correct_dsrc[i * cols + j] = val; + } + } + + if (has_bias) { + // reduce_sum: keep the same calculate order as the GPU + ReduceSum(correct_dsrc, &correct_dbias, rows, cols); + } + } + + void FusedForward() { + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + auto config = paddle::operators::Get1DBlocksAnd2DGrids( + *ctx, static_cast(rows), static_cast(cols), + VecSize); + const int increment = ((cols - 1) / (config.thread_per_block.x * + config.block_per_grid.x * VecSize) + + 1) * + VecSize; + + T *bias_ptr = nullptr; + if (has_bias) { + bias_ptr = bias.data(); + } + Functor act; + paddle::operators::LaunchDropoutActBias( + act, seed, rows, cols, increment, dropout_prob, is_upscale_in_train, + is_test, src.data(), bias_ptr, out.data(), mask.data(), + *ctx); + ctx->Wait(); + } + + void FusedBackward() { + if (is_test) return; + + T *bias_ptr = nullptr; + T *dbias_ptr = nullptr; + if (has_bias) { + dbias_ptr = dbias.data(); + bias_ptr = bias.data(); + } + GradFunctor act_grad; + paddle::operators::LaunchDropoutActBiasGrad( + act_grad, out.data(), mask.data(), src.data(), bias_ptr, + dropout_prob, is_upscale_in_train, rows, cols, dsrc.data(), + dbias_ptr, *ctx); + } + + void Run() { + SetUp(); + BaseForward(); + FusedForward(); + BaseBackward(); + FusedBackward(); + } + + void CheckOut(const T diff) { + const int n = rows * cols; + std::vector _out(n); + std::vector _mask(n); + framework::TensorToVector(out, *ctx, &_out); + if (!is_test) { + framework::TensorToVector(mask, *ctx, &_mask); + } + ctx->Wait(); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); + if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); + } + } + + void CheckGrad(const T diff) { + if (is_test) return; + + const int n = rows * cols; + + std::vector _dsrc(n); + framework::TensorToVector(dsrc, *ctx, &_dsrc); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff); + } + + if (has_bias) { + std::vector _dbias(cols); + framework::TensorToVector(dbias, *ctx, &_dbias); + ctx->Wait(); + for (int i = 0; i < cols; i++) { + EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff); + } + } + } +}; + +// test the shape , bias, activation +template +static void BaseTest(const bool is_fp16 = false) { + const int rows = 16; + std::vector cols_list = {16, 17}; + bool has_bias[2] = {true, false}; + T default_diff = !is_fp16 ? static_cast(1e-5) : static_cast(1e-1); + for (auto cols : {16, 17}) { + for (auto has_bias : {true, false}) { + TestFusedDropoutActBias test(rows, cols); + test.has_bias = has_bias; + test.Run(); + test.CheckOut(default_diff); + test.CheckGrad(default_diff); + } + } +} + +TEST(FusedDropout, GPUFusedDorpoutActBias) { + BaseTest, math::ReluGradFunctor>(); + BaseTest, + paddle::operators::GeluGradFunctor>(); +} +TEST(FusedDropout, GPUFusedDropoutActBiasDouble) { + BaseTest, math::ReluGradFunctor>(); + BaseTest, + paddle::operators::GeluGradFunctor>(); +} + +// test fp16, For inference, check_grad is not required. 
ref: test_dropout_op.py +TEST(FusedDropout, GPUFusedDropoutActBiasFp16) { + using fp16 = platform::float16; + BaseTest, math::ReluGradFunctor>(true); +} + +TEST(FusedDropout, GPUFusedDropoutActBiasIsUpscaleInTrain) { + const int rows = 16; + const int cols = 16; + for (auto is_upscale_in_train : {true, false}) { + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols, 0, 1.0, is_upscale_in_train, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); + } +} + +TEST(FusedDropout, GPUFusedDropoutActBiasIsTest) { + const int rows = 16; + const int cols = 16; + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols, 0, 0.35, true, true); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); +} + +TEST(FusedDropout, GPUFusedDropoutActBiasSeed) { + const int rows = 16; + const int cols = 16; + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols, 125, 0.0, false, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); +} + +TEST(FusedDropout, GPUFusedDropoutActBiasLargeShape) { + const int rows = 256; + const int cols = 4096; + TestFusedDropoutActBias, + math::ReluGradFunctor> + test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-3)); +} diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 24f6f53c63630e3e5f635a6a4dec78c546759adb..02c3a2d6f1a12ff1ba671efbb2525069068b7687 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -20,6 +20,7 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" @@ -39,8 +40,8 @@ namespace operators { */ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( const platform::CUDADeviceContext &ctx, const uint32_t rows, - const uint32_t cols, const int VecSize) { - const uint32_t tmp_cols = cols / VecSize; + const uint32_t cols, const int vec_size) { + const uint32_t tmp_cols = cols / vec_size; int threads = std::max( static_cast(32), std::min(tmp_cols, static_cast(ctx.GetMaxThreadsPerBlock()))); @@ -54,19 +55,26 @@ inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids( return config; } -__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state, - float *data) { +template +__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state, + float *data); + +template <> +__forceinline__ __device__ void RandVec<1>(curandStatePhilox4_32_10_t *state, + float *data) { data[0] = curand_uniform(state); } -__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state, - float *data) { +template <> +__forceinline__ __device__ void RandVec<2>(curandStatePhilox4_32_10_t *state, + float *data) { data[0] = curand_uniform(state); data[1] = curand_uniform(state); } -__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state, - float *data) { +template <> +__forceinline__ __device__ void RandVec<4>(curandStatePhilox4_32_10_t *state, + float *data) { float4 rand4 = curand_uniform4(state); data[0] = rand4.x; data[1] = rand4.y; @@ -74,24 +82,54 @@ __forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t 
*state, data[3] = rand4.z; } -__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state, - float *data) { - Rand4(state, data); - Rand4(state, data + 4); +template <> +__forceinline__ __device__ void RandVec<8>(curandStatePhilox4_32_10_t *state, + float *data) { + RandVec<4>(state, data); + RandVec<4>(state, data + 4); } -__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state, - float *data, const int VecSize) { - if (VecSize == 1) { - Rand1(state, data); - } else if (VecSize == 2) { - Rand2(state, data); - } else if (VecSize == 4) { - Rand4(state, data); - } else if (VecSize == 8) { - Rand8(state, data); - } else { - return; +template +inline void SetZero(const platform::CUDADeviceContext &ctx, T *ptr, + const size_t size) { + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemsetAsync(ptr, 0, size * sizeof(T), ctx.stream())); +} + +/** + * reduce the sum of 128 cols data by 8*VecSize warps + */ +template +inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias, + const int cols) { + // save temporary sum to cache and do transpose + __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; + for (int i = 0; i < VecSize; i++) { + cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; + } + __syncthreads(); + // reduce sum + T sum = static_cast(0); + int tid = threadIdx.y * blockDim.x + threadIdx.x; + int x = tid >> 5; // warp id + int y = tid & 31; // thread id on warp 0~31 + + // need BlockSizeX * VecSize warps + if (x < BlockSizeX * VecSize) { +// reduce 128 to 32 +#pragma unroll + for (int i = 0; i < (BlockSizeY >> 5); i++) { + sum += cache[x][y + i * 32]; + } + } + + // reduce 32 to 1 + sum = WarpReduceSum(sum); + + // save sum to dbias + int bias_id = blockIdx.x * blockDim.x * VecSize + x; + if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { + dbias[bias_id] = sum; } } diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index 288b415aef31f9990629fc15efa85c49630f1088..eae2f5457b07f8085e3d013c19db9d9b6b5e9ced 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -115,3 +115,22 @@ void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, framework::TensorToVector(*tensor_dx, ctx, dx); ctx.Wait(); } + +template +inline void ReduceSum(const std::vector &dout, std::vector *dbias, + const int rows, const int cols) { + for (int j = 0; j < cols; j++) { + std::vector tmp_dbias(rows); + for (int i = 0; i < rows; i++) { + tmp_dbias[i] = dout[i * cols + j]; + } + int tmp_rows = rows / 2; + while (tmp_rows) { + for (int i = 0; i < tmp_rows; i++) { + tmp_dbias[i] += tmp_dbias[i + tmp_rows]; + } + tmp_rows /= 2; + } + (*dbias)[j] = tmp_dbias[0]; + } +} diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index cd9dfd1c79ca8f454140522f23c7777bfcdf3239..0230244c981555d2a206c306eb7eff68b295310a 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -15,7 +15,6 @@ limitations under the License. 
*/ #pragma once #include "paddle/fluid/operators/fused/fused_dropout_common.h" -#include "paddle/fluid/operators/layer_norm_kernel.cu.h" namespace paddle { namespace operators { @@ -28,8 +27,9 @@ template __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, const int cols, curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, - const T *src, const T *residual, const T *bias, T *dst, MaskType *mask, - const bool is_test, typename details::MPTypeTrait::Type *mean_val, + const T *__restrict__ src, const T *__restrict__ residual, + const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test, + typename details::MPTypeTrait::Type *mean_val, typename details::MPTypeTrait::Type *var_val) { using LoadT = platform::AlignedVector; using StoreT = platform::AlignedVector; @@ -54,7 +54,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( MaskStoreT mask_vec; if (!is_test) { float rand[VecSize]; - RandVec(state, rand, VecSize); + RandVec(state, rand); #pragma unroll for (int ii = 0; ii < VecSize; ii++) { mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); @@ -97,24 +97,21 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( template __global__ void FusedResidualDropoutBias( const size_t rows, const size_t cols, uint64_t seed, - const float dropout_prob, const bool is_upscale_in_train, const T *src, - const T *residual, const T *bias, MaskType *mask, T *dst, - uint64_t increment, const bool is_test) { + const float dropout_prob, const bool is_upscale_in_train, + const T *__restrict__ src, const T *__restrict__ residual, + const T *__restrict__ bias, MaskType *mask, T *dst, uint64_t increment, + const bool is_test) { int col_id = blockDim.x * blockIdx.x + threadIdx.x; int row_id = blockIdx.y; int idx = row_id * cols + col_id; curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); - T factor = static_cast(1.0f / (1.0f - dropout_prob)); - if (!is_upscale_in_train) { - factor = static_cast(1.0f); - } + T factor = is_upscale_in_train ? static_cast(1.0f / (1.0f - dropout_prob)) + : static_cast(1.0f); if (is_test) { - factor = static_cast(1.0f - dropout_prob); - if (is_upscale_in_train) { - factor = static_cast(1.0f); - } + factor = is_upscale_in_train ? 
static_cast(1.0f) + : static_cast(1.0f - dropout_prob); } for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; @@ -144,8 +141,7 @@ void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols, memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), ctx.stream()); if (!is_test) { - PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( - mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + SetZero(ctx, mask_data, rows * cols); } return; } @@ -234,36 +230,7 @@ __global__ void FusedResidualDropoutBiasGrad(const T *dout, } } - // save temporary sum to cache and do transpose - __shared__ T cache[BlockSizeX * VecSize][BlockSizeY]; - for (int i = 0; i < VecSize; i++) { - cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i]; - } - __syncthreads(); - - // reduce sum - T sum = static_cast(0); - int tid = threadIdx.y * blockDim.x + threadIdx.x; - int x = tid >> 5; // warp id - int y = tid & 31; // thread id on warp 0~31 - - // need BlockSizeX * VecSize warps - if (x < BlockSizeX * VecSize) { -// reduce 128 to 32 -#pragma unroll - for (int i = 0; i < (BlockSizeY >> 5); i++) { - sum += cache[x][y + i * 32]; - } - } - - // reduce 32 to 1 - sum = WarpReduceSum(sum); - - // save sum to dbias - int bias_id = blockIdx.x * blockDim.x * VecSize + x; - if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) { - dbias[bias_id] = sum; - } + CalculateDBias(tmp_sum, dbias, cols); } /** @@ -287,9 +254,9 @@ void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask, const int VecSize = MAX_CACHE_BYTES / sizeof(T); int real_vec_size = cols % VecSize == 0 ? VecSize : 1; if (dbias != nullptr) { - auto threads = std::min(cols / real_vec_size, static_cast(8)); - auto blocks = - std::max((uint32_t)1, (cols / real_vec_size + threads - 1) / threads); + const auto threads = 8; + auto blocks = std::max(static_cast(1), + (cols / real_vec_size + threads - 1) / threads); dim3 block_dim(threads, 128, 1); dim3 grid_dim(blocks, 1, 1); if (cols % VecSize == 0) { diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index e687823bc8158bc4512c59485cacfd384de36ff8..d44df536bdd10c4d2c977ac679c54beec3671419 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -114,16 +114,12 @@ struct TestFusedResidualDropoutBias { } { - out.Resize({rows, cols}); - out.mutable_data(place); - mask.Resize({rows, cols}); - mask.mutable_data(place); - dsrc.Resize({rows, cols}); - dsrc.mutable_data(place); + out.mutable_data({rows, cols}, place); + mask.mutable_data({rows, cols}, place); + dsrc.mutable_data({rows, cols}, place); if (has_bias) { - dbias.Resize({cols}); - dbias.mutable_data(place); + dbias.mutable_data({cols}, place); } } } @@ -159,17 +155,16 @@ struct TestFusedResidualDropoutBias { dropout_prob, is_upscale_in_train); // calc dbias memset(&correct_dbias[0], 0, cols * sizeof(T)); - for (int i = 0; i < rows; i++) { - for (int j = 0; j < cols; j++) { - correct_dbias[j] += correct_out[i * cols + j]; - } + if (has_bias) { + ReduceSum(correct_out, &correct_dbias, rows, cols); } } void FusedForward() { const int VecSize = MAX_CACHE_BYTES / sizeof(T); auto config = paddle::operators::Get1DBlocksAnd2DGrids( - *ctx, (uint64_t)rows, (uint64_t)cols, VecSize); + *ctx, static_cast(rows), static_cast(cols), + VecSize); const int increment = ((cols - 1) / 
(config.thread_per_block.x * config.block_per_grid.x * VecSize) + 1) * @@ -253,21 +248,14 @@ struct TestFusedResidualDropoutBias { template static void BaseTest(const bool is_fp16 = false) { const int rows = 16; - std::vector cols_list = {16, 17}; - bool has_bias[2] = {true, false}; - T default_diff = static_cast(1e-5); - if (is_fp16) { - default_diff = static_cast(1e-2); - } - for (int i = 0; i < cols_list.size(); i++) { - for (int j = 0; j < 2; j++) { - TestFusedResidualDropoutBias test(rows, cols_list[i]); - test.has_bias = has_bias[j]; + T default_diff = !is_fp16 ? static_cast(1e-5) : static_cast(1e-1); + for (auto cols : {16, 17}) { + for (auto has_bias : {true, false}) { + TestFusedResidualDropoutBias test(rows, cols); + test.has_bias = has_bias; test.Run(); test.CheckOut(default_diff); - if (!is_fp16) { - test.CheckGrad(default_diff); - } + test.CheckGrad(default_diff); } } } @@ -276,30 +264,23 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest(); } TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest(); } -// test fp16, For inference, check_grad is not required. ref: testdropout_op.py TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) { BaseTest(true); } -TEST(FusedDropout, GPUFusedResidualDropoutBias2) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasIsUpscaleInTrain) { const int rows = 16; const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, false, false); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); -} - -TEST(FusedDropout, GPUFusedResidualDropoutBias3) { - const int rows = 16; - const int cols = 16; - TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, true, false); - test.Run(); - test.CheckOut(static_cast(1e-5)); - test.CheckGrad(static_cast(1e-5)); + for (auto is_upscale_in_train : {true, false}) { + TestFusedResidualDropoutBias test(rows, cols, 0, 1.0, + is_upscale_in_train, false); + test.Run(); + test.CheckOut(static_cast(1e-5)); + test.CheckGrad(static_cast(1e-5)); + } } -TEST(FusedDropout, GPUFusedResidualDropoutBias4) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasIsTest) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 0, 0.35, true, true); @@ -308,7 +289,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias4) { test.CheckGrad(static_cast(1e-5)); } -TEST(FusedDropout, GPUFusedResidualDropoutBias5) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasSeed) { const int rows = 16; const int cols = 16; TestFusedResidualDropoutBias test(rows, cols, 125, 0.0, false, false); @@ -317,8 +298,7 @@ TEST(FusedDropout, GPUFusedResidualDropoutBias5) { test.CheckGrad(static_cast(1e-5)); } -// test large shape -TEST(FusedDropout, GPUFusedResidualDropoutBias6) { +TEST(FusedDropout, GPUFusedResidualDropoutBiasLargeShape) { const int rows = 256; const int cols = 4096; TestFusedResidualDropoutBias test(rows, cols);
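
Note for reviewers: the GELU math implemented by GeluFunctor / GeluGradFunctor in fused_dropout_act_bias.h can be sanity-checked on the host. The following standalone sketch (not part of the patch; plain C++, function names are illustrative only) evaluates the same erf-based formulas in float:

// gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2)))
// d(gelu)/dx = 0.5 * (1 + erf(x / sqrt(2))) + x * exp(-x^2 / 2) / sqrt(2*pi)
#define _USE_MATH_DEFINES
#include <cmath>
#include <cstdio>

static float GeluForward(float x) {
  // mirrors GeluFunctor::operator()
  return 0.5f * x * (1.0f + erff(x * static_cast<float>(M_SQRT1_2)));
}

static float GeluGrad(float x) {
  // mirrors GeluGradFunctor::UseOut evaluated at the pre-activation value
  // (the "src + bias" term in the fused kernels)
  const float first = 0.5f * (1.0f + erff(x * static_cast<float>(M_SQRT1_2)));
  const float second = static_cast<float>(0.5 * M_2_SQRTPI * M_SQRT1_2) * x *
                       expf(-0.5f * x * x);
  return first + second;
}

int main() {
  for (float x : {-2.0f, -0.5f, 0.0f, 0.5f, 2.0f}) {
    printf("x=%5.2f gelu=%8.5f dgelu=%8.5f\n", x, GeluForward(x), GeluGrad(x));
  }
  return 0;
}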
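
The dropout scale factor chosen at the top of FusedDropoutActBias (and FusedResidualDropoutBias) follows the usual upscale_in_train / downscale_in_infer convention. A minimal host-side sketch of that selection (helper name is illustrative, not part of the patch):

#include <cassert>

// Returns the factor applied to the kept elements.
static float DropoutScaleFactor(float dropout_prob, bool is_upscale_in_train,
                                bool is_test) {
  if (is_test) {
    // Inference: upscale_in_train already rescaled during training, so use 1;
    // otherwise downscale by the keep probability.
    return is_upscale_in_train ? 1.0f : 1.0f - dropout_prob;
  }
  // Training: upscale_in_train rescales kept elements by 1/(1-p) so the
  // expectation matches inference; otherwise no rescaling happens here.
  return is_upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
}

int main() {
  assert(DropoutScaleFactor(0.5f, true, false) == 2.0f);
  assert(DropoutScaleFactor(0.5f, false, false) == 1.0f);
  assert(DropoutScaleFactor(0.5f, true, true) == 1.0f);
  assert(DropoutScaleFactor(0.5f, false, true) == 0.5f);
  return 0;
}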
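
The ReduceSum helper added to fused_dropout_test.h builds the reference dbias with a pairwise (tree-shaped) per-column reduction rather than a serial accumulation, which keeps its floating-point rounding closer to the GPU reduction it is compared against. A standalone sketch of the same idea (illustrative name, not part of the patch; like the test helper it assumes the row count is a power of two, as in these unit tests):

#include <cstdio>
#include <vector>

template <typename T>
void ColumnTreeReduce(const std::vector<T> &dout, std::vector<T> *dbias,
                      int rows, int cols) {
  for (int j = 0; j < cols; j++) {
    // gather one column, then halve the active range until one value remains
    std::vector<T> tmp(rows);
    for (int i = 0; i < rows; i++) tmp[i] = dout[i * cols + j];
    for (int half = rows / 2; half > 0; half /= 2) {
      for (int i = 0; i < half; i++) tmp[i] += tmp[i + half];
    }
    (*dbias)[j] = tmp[0];
  }
}

int main() {
  const int rows = 4, cols = 2;
  std::vector<float> dout = {1, 2, 3, 4, 5, 6, 7, 8};  // row-major (4 x 2)
  std::vector<float> dbias(cols);
  ColumnTreeReduce(dout, &dbias, rows, cols);
  printf("dbias = [%g, %g]\n", dbias[0], dbias[1]);  // expect [16, 20]
  return 0;
}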