diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt index 0a12735acf2a05c5d901fda659a3664500483ae1..e3dcff949f43c3438efdd7a2349168a6867339ad 100644 --- a/paddle/fluid/operators/fused/CMakeLists.txt +++ b/paddle/fluid/operators/fused/CMakeLists.txt @@ -74,7 +74,8 @@ if (WITH_GPU OR WITH_ROCM) # fused_dropout # only support CUDA if(NOT WITH_ROCM) - nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory) - nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory) + nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) + nv_test(test_fused_dropout_act_bias SRCS fused_dropout_act_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) + nv_test(test_fused_layernorm_residual_dropout_bias SRCS fused_layernorm_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op layer_norm_op device_context generator memory) endif() endif() diff --git a/paddle/fluid/operators/fused/fused_dropout_act_bias.h b/paddle/fluid/operators/fused/fused_dropout_act_bias.h index 7d815bb8c39933ed9f9efe073bde0dabeac8185f..994601a2f0608b4fc04966c7549c421f395f3ec7 100755 --- a/paddle/fluid/operators/fused/fused_dropout_act_bias.h +++ b/paddle/fluid/operators/fused/fused_dropout_act_bias.h @@ -17,8 +17,7 @@ limitations under the License. */ #define _USE_MATH_DEFINES #endif -#include "paddle/fluid/operators/fused/fused_dropout_common.h" -#include "paddle/fluid/operators/math/functors.h" +#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" namespace paddle { namespace operators { @@ -75,66 +74,15 @@ __global__ void FusedDropoutActBias( curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); - T factor = static_cast(1.0f / (1.0f - dropout_prob)); - if (!is_upscale_in_train) { - factor = static_cast(1.0); - } - if (is_test) { - factor = static_cast(1.0f - dropout_prob); - if (is_upscale_in_train) { - factor = static_cast(1.0f); - } - } - - using LoadT = platform::AlignedVector; - using StoreT = platform::AlignedVector; - using MaskLoadT = platform::AlignedVector; - using MaskStoreT = platform::AlignedVector; + const T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - LoadT src_vec; - LoadT bias_vec; - // vectorize load data from global - platform::Load(&src[r * cols + i], &src_vec); - - if (bias) { - platform::Load(&bias[i], &bias_vec); - } else { -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - bias_vec[ii] = static_cast(0); - } - } - - MaskStoreT mask_vec; - if (!is_test) { - float rand[VecSize]; - RandVec(&state, rand); -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - mask_vec[ii] = static_cast(rand[ii] >= dropout_prob); - } - } else { -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - mask_vec[ii] = static_cast(1); - } - } - - StoreT dest_vec; -#pragma unroll - for (int ii = 0; ii < VecSize; ii++) { - const T tmp = src_vec[ii] + bias_vec[ii]; - const T act_out = act(tmp); - dest_vec[ii] = act_out * static_cast(mask_vec[ii]) * factor; - } - // store result to global - platform::Store(dest_vec, &dst[r * cols + i]); - if (!is_test) { - platform::Store(mask_vec, &mask[r * cols + i]); - } + FusedResidualDropoutBiasOneThread( + r, i, cols, &state, dropout_prob, factor, src, nullptr, bias, dst, + mask, is_test, nullptr, nullptr, act); } } } @@ -197,10 +145,8 @@ __global__ void FusedDropoutActGrad(Functor act_grad, const T *dout, StoreT dx_vec; #pragma unroll for (int ii = 0; ii < VecSize; ii++) { - T args[2]; - args[0] = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; - args[1] = src_vec[ii]; - dx_vec[ii] = args[0] * act_grad.UseOut(args[1]); + T tmp = dout_vec[ii] * static_cast(mask_vec[ii]) * factor; + dx_vec[ii] = tmp * act_grad.UseOut(src_vec[ii]); } platform::Store(dx_vec, &dx[i]); } @@ -243,10 +189,8 @@ __global__ void FusedDropoutActBiasGrad(Functor act_grad, const T *dout, #pragma unroll for (int i = 0; i < VecSize; i++) { T val; - T args[2]; - args[0] = dout_vec[i] * static_cast(mask_vec[i]) * factor; - args[1] = src_vec[i] + bias_vec[i]; - val = args[0] * act_grad.UseOut(args[1]); + T tmp = dout_vec[i] * static_cast(mask_vec[i]) * factor; + val = tmp * act_grad.UseOut(src_vec[i] + bias_vec[i]); dx_vec[i] = val; tmp_sum[i] += val; } diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h index 02c3a2d6f1a12ff1ba671efbb2525069068b7687..3fb58eab077bca9af95a26aac15f54cfba48cc99 100644 --- a/paddle/fluid/operators/fused/fused_dropout_common.h +++ b/paddle/fluid/operators/fused/fused_dropout_common.h @@ -21,6 +21,7 @@ limitations under the License. */ #include "paddle/fluid/memory/memory.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/layer_norm_kernel.cu.h" +#include "paddle/fluid/operators/math/functors.h" #include "paddle/fluid/platform/aligned_vector.h" #include "paddle/fluid/platform/cuda_device_function.h" #include "paddle/fluid/platform/device_context.h" @@ -133,5 +134,17 @@ inline __device__ void CalculateDBias(const T *tmp_sum, T *dbias, } } +template +inline __device__ T GetFactor(const float dropout_prob, + const bool is_upscale_in_train, + const bool is_test) { + T factor = is_upscale_in_train ? static_cast(1.0f / (1.0f - dropout_prob)) + : static_cast(1.0f); + if (is_test) { + factor = is_upscale_in_train ? static_cast(1.0f) + : static_cast(1.0f - dropout_prob); + } + return factor; +} } // namespace operators } // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index eae2f5457b07f8085e3d013c19db9d9b6b5e9ced..a0d1cd43404eb9e43bc775ff79e7613e5e1317f0 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -23,6 +23,7 @@ limitations under the License. */ #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/memory/memory.h" +#include "paddle/fluid/operators/layer_norm_kernel.cu.h" #include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/string/printf.h" @@ -31,6 +32,12 @@ namespace platform = paddle::platform; namespace memory = paddle::memory; USE_OP(dropout); +USE_OP(layer_norm); + +template +using CudnnDataType = platform::CudnnDataType; +template +using LayerNormParamType = typename CudnnDataType::BatchNormParamType; /** * @brief call paddle dropout op @@ -116,6 +123,60 @@ void DropoutGrad(std::vector *dx, const framework::DDim &x_dim, ctx.Wait(); } +/** + * @brief call paddle layer_norm op + */ +template +void LayerNorm(const std::vector> &scale, + const std::vector> &bias, + const std::vector &x, + std::vector> *means, + std::vector> *vars, std::vector *y, + const float epsilon, const int rows, const int cols, + const platform::CUDADeviceContext &ctx) { + framework::Scope scope; + auto place = ctx.GetPlace(); + if (scale.size() > 0) { + auto var_scale = scope.Var("Scale"); + auto tensor_scale = var_scale->GetMutable(); + framework::TensorFromVector(scale, ctx, tensor_scale); + tensor_scale->Resize({cols}); + } + + if (bias.size() > 0) { + auto var_bias = scope.Var("Bias"); + auto tensor_bias = var_bias->GetMutable(); + framework::TensorFromVector(bias, ctx, tensor_bias); + tensor_bias->Resize({cols}); + } + + auto var_x = scope.Var("X"); + auto tensor_x = var_x->GetMutable(); + framework::TensorFromVector(x, ctx, tensor_x); + tensor_x->Resize({rows, cols}); + + auto var_y = scope.Var("Y"); + auto tensor_y = var_y->GetMutable(); + + auto var_mean = scope.Var("Mean"); + auto tensor_mean = var_mean->GetMutable(); + + auto var_variance = scope.Var("Variance"); + auto tensor_variance = var_variance->GetMutable(); + + framework::AttributeMap attrs; + attrs.insert({"epsilon", epsilon}); + + auto op = framework::OpRegistry::CreateOp( + "layer_norm", {{"X", {"X"}}, {"Scale", {"Scale"}}, {"Bias", {"Bias"}}}, + {{"Y", {"Y"}}, {"Mean", {"Mean"}}, {"Variance", {"Variance"}}}, attrs); + op->Run(scope, place); + framework::TensorToVector(*tensor_y, ctx, y); + framework::TensorToVector(*tensor_mean, ctx, means); + framework::TensorToVector(*tensor_variance, ctx, vars); + ctx.Wait(); +} + template inline void ReduceSum(const std::vector &dout, std::vector *dbias, const int rows, const int cols) { diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h new file mode 100644 index 0000000000000000000000000000000000000000..f257d3efa433e6a817de713e18d60ced9da5acbd --- /dev/null +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h @@ -0,0 +1,209 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h" + +namespace paddle { +namespace operators { + +template +using CudnnDataType = platform::CudnnDataType; +template +using LayerNormParamType = typename CudnnDataType::BatchNormParamType; + +/** + * @brief fused add_bias, dropout, add residual and leyer_norm into one + * operators. Currently only support forward + */ + +template +__device__ void CalcLayernormY(const LayerNormParamType *scale, + const LayerNormParamType *bias, const T *x, + T *y, const int row_id, const int col_id, + const int cols, + const LayerNormParamType mean_val, + const LayerNormParamType invvar) { + using U = LayerNormParamType; + using LoadT = platform::AlignedVector; + using StoreT = platform::AlignedVector; + using LoadU = platform::AlignedVector; + for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) { + LoadU scale_vec; + LoadU bias_vec; + LoadT x_vec; +#pragma unroll + for (int ii = 0; ii < VecSize; ii++) { + scale_vec[ii] = static_cast(1); + bias_vec[ii] = static_cast(0); + } + // vectorize load data from global + platform::Load(&x[row_id * cols + i], &x_vec); + + if (scale != nullptr) { + platform::Load(&scale[i], &scale_vec); + } + if (bias != nullptr) { + platform::Load(&bias[i], &bias_vec); + } + + StoreT y_vec; + for (int ii = 0; ii < VecSize; ii++) { + y_vec[ii] = static_cast( + scale_vec[ii] * (static_cast(x_vec[ii]) - mean_val) * invvar + + bias_vec[ii]); + } + platform::Store(y_vec, &y[row_id * cols + i]); + } +} + +/** + * @brief layernorm(residual + dropout(src + bias)); + * @param + * rows: batch_size * seq_len + * cols: feature_size or hidden_size + * src: [rows, cols], inputs + * bias: [cols], linear bias, can be null + * residual:[rows, cols] + * mask: [rows, cols], dropout result + * dst: [rows, cols], residual + dropout(src+bias) + * layernorm_dst: [rows, cols], layernorm result + * layernorm_bias: [cols], layernorm bias, can be null + * scale: [cols]: layernorm scale, can be null + * means: [rows]: layernorm means + * vars: [rows]: layernorm vars + */ +template +__global__ void FusedLayernormResidualDropoutBias( + const size_t rows, const size_t cols, uint64_t seed, + const float dropout_prob, const bool is_upscale_in_train, + const bool is_test, const uint64_t increment, const float epsilon, + const T *src, const T *residual, const T *bias, + const LayerNormParamType *scale, + const LayerNormParamType *layernorm_bias, MaskType *mask, T *dst, + T *layernorm_dst, LayerNormParamType *mean, LayerNormParamType *var) { + int col_id = threadIdx.x; + int row_id = blockIdx.x; + int idx = row_id * cols + col_id; + curandStatePhilox4_32_10_t state; + curand_init(seed, idx, increment, &state); + + T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); + using U = LayerNormParamType; + + __shared__ U mean_share; + __shared__ U var_share; + __shared__ U shared_mean[32]; + __shared__ U shared_var[32]; + + math::ReluFunctor relu; + U mean_val = 0; + U var_val = 0; + for (int i = col_id * VecSize; i < cols; i += blockDim.x * VecSize) { + FusedResidualDropoutBiasOneThread>( + row_id, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, + mask, is_test, &mean_val, &var_val, relu); + } + + mean_val = BlockReduceSum(mean_val, shared_mean); + var_val = BlockReduceSum(var_val, shared_var); + if (threadIdx.x == 0) { + auto scale = static_cast(1.) / static_cast(cols); + auto tmp = mean_val * scale; + mean[row_id] = mean_share = static_cast(tmp); + var_share = static_cast(var_val * scale - mean_share * mean_share); + var_share = var_share > U(0) ? var_share : U(0); + var[row_id] = var_share; + } + __syncthreads(); + + mean_val = mean_share; + U invvar = rsqrt_(var_share + static_cast(epsilon)); + + // calculate layernorm_dst + CalcLayernormY(scale, layernorm_bias, dst, layernorm_dst, row_id, + col_id, cols, mean_val, invvar); +} + +/** + * @brief layernorm(residual + dropout(src + bias)); + * @param + * rows: batch_size * seq_len + * cols: feature_size or hidden_size + * src: [rows, cols], inputs + * bias: [cols], linear bias, can be null + * residual:[rows, cols] + * mask: [rows, cols], dropout result, can be null if is_test = true + * dst: [rows, cols], residual + dropout(src+bias) + * layernorm_dst: [rows, cols], layernorm result + * layernorm_bias: [cols], layernorm bias, can be null + * scale: [cols]: layernorm scale, can be null + * means: [rows]: layernorm means + * vars: [rows]: layernorm vars + */ +template +void LaunchLayernormResidualDropoutBias( + const uint32_t rows, const uint32_t cols, const int increment, + uint64_t seed, const float dropout_prob, const float epsilon, + const bool is_upscale_in_train, const bool is_test, const T *src, + const T *residual, const T *bias, const LayerNormParamType *scale, + const LayerNormParamType *layernorm_bias, MaskType *mask_data, T *dst, + T *layernorm_dst, LayerNormParamType *mean, LayerNormParamType *var, + const platform::CUDADeviceContext &ctx) { + using U = LayerNormParamType; + // dropout_prob == 1.0f + if (std::abs(dropout_prob - 1.0f) < 1e-5) { + auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace()); + memory::Copy(cuda_place, dst, cuda_place, residual, rows * cols * sizeof(T), + ctx.stream()); + PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync( + mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream())); + + // call layernorm forward + switch (GetDesiredBlockDim(cols)) { + FIXED_BLOCK_DIM_CASE( + LayerNormForward<<>>( + dst, scale, layernorm_bias, layernorm_dst, mean, var, epsilon, + cols)); + default: + PADDLE_THROW(platform::errors::InvalidArgument( + "Product from begin_norm_axis to end must be larger than 1")); + break; + } + return; + } + + const int VecSize = MAX_CACHE_BYTES / sizeof(T); + if (cols % VecSize != 0) { + int blockDim = GetDesiredBlockDim(cols); + FusedLayernormResidualDropoutBias<<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment, + epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, + layernorm_dst, mean, var); + } else { + int blockDim = GetDesiredBlockDim(cols / VecSize); + FusedLayernormResidualDropoutBias< + T, uint8_t, VecSize><<>>( + rows, cols, seed, dropout_prob, is_upscale_in_train, is_test, increment, + epsilon, src, residual, bias, scale, layernorm_bias, mask_data, dst, + layernorm_dst, mean, var); + } +} + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu new file mode 100644 index 0000000000000000000000000000000000000000..50e3555b4bcd629c85aca27fd3f2999b6694ecb2 --- /dev/null +++ b/paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias_test.cu @@ -0,0 +1,332 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include + +#include +#include + +#include "paddle/fluid/operators/fused/fused_dropout_test.h" +#include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h" + +/** + * @brief The unit test of fused_layernorm_residual_dropout_bias + */ + +template +struct TestFusedLayernormResidualDropoutBias { + uint32_t rows; + uint32_t cols; + uint64_t seed; + float dropout_prob, epsilon; + bool is_upscale_in_train; + bool is_test; // default false, Set to true for inference only + bool has_bias = true; + bool has_scale = true; + bool has_layernorm_bias = true; + framework::Tensor src, residual, bias, out, mask, scale, layernorm_bias, + layernorm_out, means, vars; + framework::Tensor dsrc, dbias; + + std::vector src_vec, residual_vec, bias_vec; + std::vector> means_vec, vars_vec, scale_vec, + layernorm_bias_vec; + std::vector correct_out, correct_dsrc, correct_dbias, + correct_layernorm_out; + std::vector> correct_means, correct_vars; + std::vector correct_mask; + + platform::CUDAPlace place; + platform::CUDADeviceContext *ctx; + + TestFusedLayernormResidualDropoutBias() { + rows = 32; + cols = 32; + seed = 0; + dropout_prob = 0.0; + is_upscale_in_train = false; + is_test = false; + has_bias = true; + has_scale = true; + has_layernorm_bias = true; + epsilon = 0.00001f; + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); + } + + TestFusedLayernormResidualDropoutBias(int _rows, int _cols, + uint64_t _seed = 0, + float _dropout_prob = 0.0, + float _epsilon = 0.00001f, + bool _is_upscale_in_train = false, + bool _is_test = false) { + rows = _rows; + cols = _cols; + seed = _seed; + dropout_prob = _dropout_prob; + epsilon = _epsilon; + is_upscale_in_train = _is_upscale_in_train; + is_test = _is_test; + has_bias = true; + has_scale = true; + has_layernorm_bias = true; + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto devicectx = pool.Get(place); + ctx = reinterpret_cast(devicectx); + } + + ~TestFusedLayernormResidualDropoutBias() {} + + void SetUp() { + using U = LayerNormParamType; + const int n = rows * cols; + correct_out.resize(n); + correct_mask.resize(n); + correct_dsrc.resize(n); + correct_dbias.resize(cols); + correct_means.resize(rows); + correct_vars.resize(rows); + correct_layernorm_out.resize(n); + + src_vec.resize(n); + residual_vec.resize(n); + if (has_bias) { + bias_vec.resize(cols); + } + if (has_scale) { + scale_vec.resize(cols); + } + if (has_layernorm_bias) { + layernorm_bias_vec.resize(cols); + } + std::default_random_engine random(time(NULL)); + std::uniform_real_distribution dis(0.0, 1.0); + + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + src_vec[i * cols + j] = static_cast(dis(random)); + residual_vec[i * cols + j] = static_cast(dis(random)); + if (i == 0) { + if (has_bias) { + bias_vec[j] = static_cast(dis(random)); + } + if (has_scale) { + scale_vec[j] = static_cast(dis(random)); + } + if (has_layernorm_bias) { + layernorm_bias_vec[j] = static_cast(dis(random)); + } + } + } + } + + framework::TensorFromVector(src_vec, *ctx, &src); + src.Resize({rows, cols}); + framework::TensorFromVector(residual_vec, *ctx, &residual); + residual.Resize({rows, cols}); + if (has_bias) { + framework::TensorFromVector(bias_vec, *ctx, &bias); + bias.Resize({cols}); + } + if (has_scale) { + framework::TensorFromVector(scale_vec, *ctx, &scale); + scale.Resize({cols}); + } + if (has_layernorm_bias) { + framework::TensorFromVector(layernorm_bias_vec, *ctx, &layernorm_bias); + layernorm_bias.Resize({cols}); + } + + { + out.Resize({rows, cols}); + out.mutable_data(place); + mask.Resize({rows, cols}); + mask.mutable_data(place); + means.Resize({rows}); + means.mutable_data(place); + vars.Resize({rows}); + vars.mutable_data(place); + layernorm_out.Resize({rows, cols}); + layernorm_out.mutable_data(place); + dsrc.Resize({rows, cols}); + dsrc.mutable_data(place); + + if (has_bias) { + dbias.Resize({cols}); + dbias.mutable_data(place); + } + } + } + + void BaseForward() { + using U = LayerNormParamType; + std::vector out1(rows * cols), out2(rows * cols); + if (has_bias) { + // add bias + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j]; + } + } + // call dropout + Dropout(out1, src.dims(), &out2, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); + } else { + Dropout(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed, + dropout_prob, is_upscale_in_train, is_test); + } + // add residual + for (int i = 0; i < rows; i++) { + for (int j = 0; j < cols; j++) { + correct_out[i * cols + j] = + residual_vec[i * cols + j] + out2[i * cols + j]; + } + } + + LayerNorm(scale_vec, layernorm_bias_vec, correct_out, &correct_means, + &correct_vars, &correct_layernorm_out, epsilon, rows, cols, + *ctx); + ctx->Wait(); + } + + void FusedForward() { + using U = LayerNormParamType; + int VecSize = MAX_CACHE_BYTES / sizeof(T); + if (cols % 4 != 0) { + VecSize = 1; + } + int threads = paddle::operators::GetDesiredBlockDim(cols / VecSize); + const int increment = ((cols - 1) / (threads * VecSize) + 1) * VecSize; + + T *bias_ptr = nullptr; + U *scale_ptr = nullptr; + U *layernorm_bias_ptr = nullptr; + if (has_bias) { + bias_ptr = bias.data(); + } + if (has_scale) { + scale_ptr = scale.data(); + } + if (has_layernorm_bias) { + layernorm_bias_ptr = layernorm_bias.data(); + } + + paddle::operators::LaunchLayernormResidualDropoutBias( + rows, cols, increment, seed, dropout_prob, epsilon, is_upscale_in_train, + is_test, src.data(), residual.data(), bias_ptr, scale_ptr, + layernorm_bias_ptr, mask.data(), out.data(), + layernorm_out.data(), means.data(), vars.data(), *ctx); + ctx->Wait(); + } + + void Run() { + SetUp(); + BaseForward(); + FusedForward(); + } + + void CheckOut(const T diff) { + using U = LayerNormParamType; + const int n = rows * cols; + std::vector _out(n), _layernorm_out(n); + std::vector _means(rows), _vars(cols); + std::vector _mask(n); + framework::TensorToVector(out, *ctx, &_out); + framework::TensorToVector(layernorm_out, *ctx, &_layernorm_out); + framework::TensorToVector(means, *ctx, &_means); + framework::TensorToVector(vars, *ctx, &_vars); + if (!is_test) { + framework::TensorToVector(mask, *ctx, &_mask); + } + ctx->Wait(); + + for (int i = 0; i < n; i++) { + EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff); + EXPECT_LT(std::abs(_layernorm_out[i] - correct_layernorm_out[i]), diff); + if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]); + } + for (int i = 0; i < rows; i++) { + EXPECT_LT(std::abs(_means[i] - correct_means[i]), static_cast(diff)); + EXPECT_LT(std::abs(_vars[i] - correct_vars[i]), static_cast(diff)); + } + } +}; + +template +static void BaseTest(const bool is_fp16 = false) { + const int rows = 16; + T default_diff = !is_fp16 ? static_cast(1e-4) : static_cast(1e-2); + for (auto cols : {16, 17}) { + for (auto has_bias : {true, false}) { + for (auto has_scale : {true, false}) { + for (auto has_layernorm_bias : {true, false}) { + TestFusedLayernormResidualDropoutBias test(rows, cols); + test.has_bias = has_bias; + test.has_scale = has_scale; + test.has_layernorm_bias = has_layernorm_bias; + test.Run(); + test.CheckOut(default_diff); + } + } + } + } +} + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutBias) { BaseTest(); } + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasDouble) { + BaseTest(); +} + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasFp16) { + BaseTest(true); +} + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasIsUpscaleInTrain) { + const int rows = 16; + const int cols = 16; + for (auto is_upscale_in_train : {true, false}) { + TestFusedLayernormResidualDropoutBias test( + rows, cols, 0, 1.0, 0.00001f, is_upscale_in_train, false); + test.Run(); + test.CheckOut(static_cast(1e-4)); + } +} + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasIsTest) { + const int rows = 16; + const int cols = 16; + TestFusedLayernormResidualDropoutBias test(rows, cols, 0, 0.35, + 0.00001f, true, true); + test.Run(); + test.CheckOut(static_cast(1e-4)); +} + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutBiasSeed) { + const int rows = 16; + const int cols = 16; + TestFusedLayernormResidualDropoutBias test(rows, cols, 125, 0.0, + 0.00001f, false, false); + test.Run(); + test.CheckOut(static_cast(1e-4)); +} + +TEST(FusedDropout, GPUFusedLayernormResidualDropoutLargeShape) { + const int rows = 512; + const int cols = 512; + TestFusedLayernormResidualDropoutBias test(rows, cols); + test.Run(); + test.CheckOut(static_cast(1e-4)); +} diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h index 0230244c981555d2a206c306eb7eff68b295310a..d984ad1a27768ddad248863fcc17d9e088c42da1 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h @@ -23,14 +23,15 @@ namespace operators { * @brief The fused function called by every thread * VecSize can be 1, 2, 4 or 8 */ -template +template __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( const int row_id, const int col_id, const int cols, curandStatePhilox4_32_10_t *state, const float dropout_prob, const T factor, const T *__restrict__ src, const T *__restrict__ residual, const T *__restrict__ bias, T *dst, MaskType *mask, const bool is_test, typename details::MPTypeTrait::Type *mean_val, - typename details::MPTypeTrait::Type *var_val) { + typename details::MPTypeTrait::Type *var_val, Functor act_func) { using LoadT = platform::AlignedVector; using StoreT = platform::AlignedVector; using MaskStoreT = platform::AlignedVector; @@ -42,10 +43,14 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( #pragma unroll for (int ii = 0; ii < VecSize; ii++) { bias_vec[ii] = static_cast(0); + residual_vec[ii] = static_cast(0); } // vectorize load data from global platform::Load(&src[row_id * cols + col_id], &src_vec); - platform::Load(&residual[row_id * cols + col_id], &residual_vec); + if (residual) { + platform::Load(&residual[row_id * cols + col_id], + &residual_vec); + } if (bias) { platform::Load(&bias[col_id], &bias_vec); @@ -70,9 +75,12 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread( #pragma unroll for (int ii = 0; ii < VecSize; ii++) { + T tmp = src_vec[ii] + bias_vec[ii]; + if (Activation) { + tmp = act_func(tmp); + } dest_vec[ii] = - (src_vec[ii] + bias_vec[ii]) * static_cast(mask_vec[ii]) * factor + - residual_vec[ii]; + tmp * static_cast(mask_vec[ii]) * factor + residual_vec[ii]; if (ComputeLayerNorm) { U tmp = static_cast(dest_vec[ii]); *mean_val += tmp; @@ -106,19 +114,15 @@ __global__ void FusedResidualDropoutBias( int idx = row_id * cols + col_id; curandStatePhilox4_32_10_t state; curand_init(seed, idx, increment, &state); - - T factor = is_upscale_in_train ? static_cast(1.0f / (1.0f - dropout_prob)) - : static_cast(1.0f); - if (is_test) { - factor = is_upscale_in_train ? static_cast(1.0f) - : static_cast(1.0f - dropout_prob); - } + const T factor = GetFactor(dropout_prob, is_upscale_in_train, is_test); + math::ReluFunctor relu; for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) { for (int i = col_id * VecSize; i < cols; i += blockDim.x * gridDim.x * VecSize) { - FusedResidualDropoutBiasOneThread( + FusedResidualDropoutBiasOneThread>( r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst, - mask, is_test, nullptr, nullptr); + mask, is_test, nullptr, nullptr, relu); } } } diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu index d44df536bdd10c4d2c977ac679c54beec3671419..1a12e6b565f02035b3fb9673636c2344823f288e 100644 --- a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu +++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu @@ -165,6 +165,7 @@ struct TestFusedResidualDropoutBias { auto config = paddle::operators::Get1DBlocksAnd2DGrids( *ctx, static_cast(rows), static_cast(cols), VecSize); + const int increment = ((cols - 1) / (config.thread_per_block.x * config.block_per_grid.x * VecSize) + 1) *