diff --git a/paddle/fluid/operators/fused/CMakeLists.txt b/paddle/fluid/operators/fused/CMakeLists.txt
index 541e5afdf9b71e4f087adcc4fe58cacdc54f4f61..3df2144aa3594427563b0754ce8cc2f567188734 100644
--- a/paddle/fluid/operators/fused/CMakeLists.txt
+++ b/paddle/fluid/operators/fused/CMakeLists.txt
@@ -71,4 +71,9 @@ if (WITH_GPU OR WITH_ROCM)
     op_library(fused_bn_add_activation_op)
     file(APPEND ${pybind_file} "USE_CUDA_ONLY_OP(fused_bn_add_activation);\n")
   endif()
+  # fused_dropout
+  # only support CUDA
+  if(NOT WITH_ROCM)
+    nv_test(test_fused_residual_dropout_bias SRCS fused_residual_dropout_bias_test.cu DEPS tensor op_registry dropout_op device_context generator memory)
+  endif()
 endif()
diff --git a/paddle/fluid/operators/fused/fused_dropout_common.h b/paddle/fluid/operators/fused/fused_dropout_common.h
new file mode 100644
index 0000000000000000000000000000000000000000..24f6f53c63630e3e5f635a6a4dec78c546759adb
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_dropout_common.h
@@ -0,0 +1,99 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <algorithm>
+#include <cuda.h>
+#include <curand_kernel.h>
+
+#include "paddle/fluid/memory/memory.h"
+#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/fluid/platform/aligned_vector.h"
+#include "paddle/fluid/platform/cuda_device_function.h"
+#include "paddle/fluid/platform/device_context.h"
+#include "paddle/fluid/platform/float16.h"
+#include "paddle/fluid/platform/gpu_launch_config.h"
+
+namespace paddle {
+namespace operators {
+
+#define CACHE_LINE 128
+#define MAX_CACHE_BYTES (CACHE_LINE / CHAR_BIT)
+
+/**
+ * get the threads for fused_residual_dropout_bias:
+ * 1D blocks: blockDim.x = cols
+ * 2D grids: gridDim.y = rows
+ */
+inline platform::GpuLaunchConfig Get1DBlocksAnd2DGrids(
+    const platform::CUDADeviceContext &ctx, const uint32_t rows,
+    const uint32_t cols, const int VecSize) {
+  const uint32_t tmp_cols = cols / VecSize;
+  int threads = std::max(
+      static_cast<uint32_t>(32),
+      std::min(tmp_cols, static_cast<uint32_t>(ctx.GetMaxThreadsPerBlock())));
+  const auto blocks_x =
+      std::max(static_cast<uint32_t>(1), (tmp_cols + threads - 1) / threads);
+  const auto blocks_y = std::max(static_cast<uint32_t>(1), rows);
+  platform::GpuLaunchConfig config;
+  config.block_per_grid.x = blocks_x;
+  config.block_per_grid.y = blocks_y;
+  config.thread_per_block.x = threads;
+  return config;
+}
+
+__forceinline__ __device__ void Rand1(curandStatePhilox4_32_10_t *state,
+                                      float *data) {
+  data[0] = curand_uniform(state);
+}
+
+__forceinline__ __device__ void Rand2(curandStatePhilox4_32_10_t *state,
+                                      float *data) {
+  data[0] = curand_uniform(state);
+  data[1] = curand_uniform(state);
+}
+
+__forceinline__ __device__ void Rand4(curandStatePhilox4_32_10_t *state,
+                                      float *data) {
+  float4 rand4 = curand_uniform4(state);
+  data[0] = rand4.x;
+  data[1] = rand4.y;
+  data[2] = rand4.w;
+  data[3] = rand4.z;
+}
+
+__forceinline__ __device__ void Rand8(curandStatePhilox4_32_10_t *state,
+                                      float *data) {
+  Rand4(state, data);
+  Rand4(state, data + 4);
+}
+
+__forceinline__ __device__ void RandVec(curandStatePhilox4_32_10_t *state,
+                                        float *data, const int VecSize) {
+  if (VecSize == 1) {
+    Rand1(state, data);
+  } else if (VecSize == 2) {
+    Rand2(state, data);
+  } else if (VecSize == 4) {
+    Rand4(state, data);
+  } else if (VecSize == 8) {
+    Rand8(state, data);
+  } else {
+    return;
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h
new file mode 100644
index 0000000000000000000000000000000000000000..288b415aef31f9990629fc15efa85c49630f1088
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_dropout_test.h
@@ -0,0 +1,117 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include <random>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/memory/memory.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/string/printf.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+namespace memory = paddle::memory;
+
+USE_OP(dropout);
+
+/**
+ * @brief call paddle dropout op
+ */
+template <typename T>
+void Dropout(const std::vector<T> &x, const framework::DDim &x_dim,
+             std::vector<T> *out, std::vector<uint8_t> *mask,
+             const platform::CUDADeviceContext &ctx, uint64_t seed,
+             float dropout_prob, bool is_upscale_in_train, bool is_test) {
+  framework::Scope scope;
+  auto var_x = scope.Var("X");
+  auto tensor_x = var_x->GetMutable<framework::LoDTensor>();
+  framework::TensorFromVector(x, ctx, tensor_x);
+  tensor_x->Resize(x_dim);
+
+  auto var_out = scope.Var("Out");
+  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
+
+  auto var_mask = scope.Var("Mask");
+  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
+
+  framework::AttributeMap attrs;
+  attrs.insert({"fix_seed", 1});
+  attrs.insert({"seed", static_cast<int>(seed)});
+  attrs.insert({"dropout_prob", dropout_prob});
+  if (is_upscale_in_train) {
+    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
+  }
+
+  if (is_test) {
+    attrs.insert({"is_test", true});
+  }
+
+  auto op = framework::OpRegistry::CreateOp(
+      "dropout", {{"X", {"X"}}}, {{"Out", {"Out"}}, {"Mask", {"Mask"}}}, attrs);
+  op->Run(scope, ctx.GetPlace());
+
+  framework::TensorToVector(*tensor_out, ctx, out);
+  if (!is_test) {
+    framework::TensorToVector(*tensor_mask, ctx, mask);
+  }
+  ctx.Wait();
+}
+
+/**
+ * @brief call paddle dropout_grad op
+ */
+template <typename T>
+void DropoutGrad(std::vector<T> *dx, const framework::DDim &x_dim,
+                 const std::vector<T> &dout, const std::vector<uint8_t> &mask,
+                 const platform::CUDADeviceContext &ctx, float dropout_prob,
+                 bool is_upscale_in_train) {
+  framework::Scope scope;
+  const size_t n = x_dim[0] * x_dim[1];
+  auto var_out = scope.Var("DOut");
+  auto tensor_out = var_out->GetMutable<framework::LoDTensor>();
+  framework::TensorFromVector(dout, ctx, tensor_out);
+  tensor_out->Resize(x_dim);
+
+  auto var_mask = scope.Var("Mask");
+  auto tensor_mask = var_mask->GetMutable<framework::LoDTensor>();
+  framework::TensorFromVector(mask, ctx, tensor_mask);
+  tensor_mask->Resize(x_dim);
+
+  auto var_dx = scope.Var("DX");
+  auto tensor_dx = var_dx->GetMutable<framework::LoDTensor>();
+
+  framework::AttributeMap attrs;
+  attrs.insert({"dropout_prob", dropout_prob});
+  attrs.insert({"is_test", false});
+  if (is_upscale_in_train) {
+    attrs.insert({"dropout_implementation", std::string("upscale_in_train")});
+  } else {
+    attrs.insert(
+        {"dropout_implementation", std::string("downgrade_in_infer")});
+  }
+
+  auto op = framework::OpRegistry::CreateOp(
+      "dropout_grad", {{"Out@GRAD", {"DOut"}}, {"Mask", {"Mask"}}},
+      {{"X@GRAD", {"DX"}}}, attrs);
+  op->Run(scope, ctx.GetPlace());
+
+  framework::TensorToVector(*tensor_dx, ctx, dx);
+  ctx.Wait();
+}
diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias.h b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h
new file mode 100644
index 0000000000000000000000000000000000000000..cd9dfd1c79ca8f454140522f23c7777bfcdf3239
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias.h
@@ -0,0 +1,322 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#pragma once
+
+#include "paddle/fluid/operators/fused/fused_dropout_common.h"
+#include "paddle/fluid/operators/layer_norm_kernel.cu.h"
+
+namespace paddle {
+namespace operators {
+
+/**
+ * @brief The fused function called by every thread
+ * VecSize can be 1, 2, 4 or 8
+ */
+template <typename T, typename MaskType, int VecSize, bool ComputeLayerNorm>
+__forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
+    const int row_id, const int col_id, const int cols,
+    curandStatePhilox4_32_10_t *state, const float dropout_prob,
+    const T factor, const T *src, const T *residual, const T *bias, T *dst,
+    MaskType *mask, const bool is_test,
+    typename details::MPTypeTrait<T>::Type *mean_val,
+    typename details::MPTypeTrait<T>::Type *var_val) {
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using StoreT = platform::AlignedVector<T, VecSize>;
+  using MaskStoreT = platform::AlignedVector<MaskType, VecSize>;
+  using U = typename details::MPTypeTrait<T>::Type;
+
+  LoadT src_vec;
+  LoadT residual_vec;
+  LoadT bias_vec;
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    bias_vec[ii] = static_cast<T>(0);
+  }
+  // vectorize load data from global
+  platform::Load<T, VecSize>(&src[row_id * cols + col_id], &src_vec);
+  platform::Load<T, VecSize>(&residual[row_id * cols + col_id], &residual_vec);
+
+  if (bias) {
+    platform::Load<T, VecSize>(&bias[col_id], &bias_vec);
+  }
+
+  MaskStoreT mask_vec;
+  if (!is_test) {
+    float rand[VecSize];
+    RandVec(state, rand, VecSize);
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      mask_vec[ii] = static_cast<MaskType>(rand[ii] >= dropout_prob);
+    }
+  } else {
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      mask_vec[ii] = static_cast<MaskType>(1);
+    }
+  }
+
+  StoreT dest_vec;
+
+#pragma unroll
+  for (int ii = 0; ii < VecSize; ii++) {
+    dest_vec[ii] =
+        (src_vec[ii] + bias_vec[ii]) * static_cast<T>(mask_vec[ii]) * factor +
+        residual_vec[ii];
+    if (ComputeLayerNorm) {
+      U tmp = static_cast<U>(dest_vec[ii]);
+      *mean_val += tmp;
+      *var_val += (tmp * tmp);
+    }
+  }
+
+  // store result to global
+  platform::Store<T, VecSize>(dest_vec, &dst[row_id * cols + col_id]);
+  if (!is_test) {
+    platform::Store<MaskType, VecSize>(mask_vec,
+                                       &mask[row_id * cols + col_id]);
+  }
+}
+
+/**
+ * @brief dst = residual + dropout(src + bias);
+ * the src, residual, mask and dst shape is (rows, cols)
+ * the bias shape is (1, cols)
+ * is_test: only used in inference
+ * mask: can be null if is_test=true
+ */
+template <typename T, typename MaskType, int VecSize>
+__global__ void FusedResidualDropoutBias(
+    const size_t rows, const size_t cols, uint64_t seed,
+    const float dropout_prob, const bool is_upscale_in_train, const T *src,
+    const T *residual, const T *bias, MaskType *mask, T *dst,
+    uint64_t increment, const bool is_test) {
+  int col_id = blockDim.x * blockIdx.x + threadIdx.x;
+  int row_id = blockIdx.y;
+  int idx = row_id * cols + col_id;
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx, increment, &state);
+
+  T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
+  if (!is_upscale_in_train) {
+    factor = static_cast<T>(1.0f);
+  }
+  if (is_test) {
+    factor = static_cast<T>(1.0f - dropout_prob);
+    if (is_upscale_in_train) {
+      factor = static_cast<T>(1.0f);
+    }
+  }
+  for (int r = row_id; r < rows; r += blockDim.y * gridDim.y) {
+    for (int i = col_id * VecSize; i < cols;
+         i += blockDim.x * gridDim.x * VecSize) {
+      FusedResidualDropoutBiasOneThread<T, MaskType, VecSize, false>(
+          r, i, cols, &state, dropout_prob, factor, src, residual, bias, dst,
+          mask, is_test, nullptr, nullptr);
+    }
+  }
+}
+
+/**
+ * @brief dst = residual + dropout(src + bias);
+ */
+template <typename T, typename MaskType>
+void LaunchResidualDropoutBias(const uint32_t rows, const uint32_t cols,
+                               const int increment, uint64_t seed,
+                               const float dropout_prob, const bool is_test,
+                               bool is_upscale_in_train, const T *src,
+                               const T *residual, const T *bias,
+                               MaskType *mask_data, T *dst,
+                               const platform::CUDADeviceContext &ctx) {
+  // dropout_prob == 1.0f
+  if (std::abs(dropout_prob - 1.0f) < 1e-5) {
+    if (residual == dst) return;
+    auto cuda_place = BOOST_GET_CONST(platform::CUDAPlace, ctx.GetPlace());
+    memory::Copy(cuda_place, dst, cuda_place, residual,
+                 rows * cols * sizeof(T), ctx.stream());
+    if (!is_test) {
+      PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
+          mask_data, 0, rows * cols * sizeof(MaskType), ctx.stream()));
+    }
+    return;
+  }
+
+  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
+  const int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
+  auto config = Get1DBlocksAnd2DGrids(ctx, rows, cols, real_vec_size);
+  if (cols % VecSize == 0) {
+    FusedResidualDropoutBias<T, uint8_t, VecSize><<<
+        config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
+        bias, mask_data, dst, increment, is_test);
+  } else {
+    FusedResidualDropoutBias<
+        T, uint8_t,
+        1><<<config.block_per_grid, config.thread_per_block, 0,
+             ctx.stream()>>>(
+        rows, cols, seed, dropout_prob, is_upscale_in_train, src, residual,
+        bias, mask_data, dst, increment, is_test);
+  }
+}
+
+/*
+ * @brief calculate the grad of no bias
+ */
+template <typename T, typename MaskType, int VecSize>
+__global__ void FusedResidualDropoutGrad(const T *dout, const MaskType *mask,
+                                         const T factor, const int64_t size,
+                                         T *dx) {
+  int64_t idx = blockDim.x * blockIdx.x + threadIdx.x;
+
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using StoreT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+  for (int i = idx * VecSize; i < size;
+       i += blockDim.x * gridDim.x * VecSize) {
+    LoadT dout_vec;
+    MaskLoadT mask_vec;
+    platform::Load<T, VecSize>(&dout[i], &dout_vec);
+    platform::Load<MaskType, VecSize>(&mask[i], &mask_vec);
+
+    StoreT dx_vec;
+#pragma unroll
+    for (int ii = 0; ii < VecSize; ii++) {
+      dx_vec[ii] = dout_vec[ii] * static_cast<T>(mask_vec[ii]) * factor;
+    }
+    platform::Store<T, VecSize>(dx_vec, &dx[i]);
+  }
+}
+
+/**
+ * blocks(128 * 8)
+ * 1. calculate the dx and reduce total rows to 128 rows
+ * 2. save 128*8 temporary sum in 8*128 shared memory
+ * 3. reduce the sum of 128 rows data by 8*VecSize warps
+ */
+template <typename T, typename MaskType, int BlockSizeX, int BlockSizeY,
+          int VecSize>
+__global__ void FusedResidualDropoutBiasGrad(const T *dout,
+                                             const MaskType *mask,
+                                             const T factor,
+                                             const int64_t rows,
+                                             const int64_t cols, T *dx,
+                                             T *dbias) {
+  int64_t col_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  using LoadT = platform::AlignedVector<T, VecSize>;
+  using StoreT = platform::AlignedVector<T, VecSize>;
+  using MaskLoadT = platform::AlignedVector<MaskType, VecSize>;
+
+  T tmp_sum[VecSize] = {static_cast<T>(0)};
+  // calculate the dx and temporary sum
+  if (col_id * VecSize < cols) {
+    for (int row_id = threadIdx.y; row_id < rows; row_id += blockDim.y) {
+      int index = row_id * cols + col_id * VecSize;
+      LoadT out_vec;
+      MaskLoadT mask_vec;
+      StoreT dx_vec;
+      platform::Load<T, VecSize>(&dout[index], &out_vec);
+      platform::Load<MaskType, VecSize>(&mask[index], &mask_vec);
+
+#pragma unroll
+      for (int i = 0; i < VecSize; i++) {
+        dx_vec[i] = out_vec[i] * static_cast<T>(mask_vec[i]) * factor;
+        tmp_sum[i] += out_vec[i];
+      }
+
+      platform::Store<T, VecSize>(dx_vec, &dx[index]);
+    }
+  }
+
+  // save temporary sum to cache and do transpose
+  __shared__ T cache[BlockSizeX * VecSize][BlockSizeY];
+  for (int i = 0; i < VecSize; i++) {
+    cache[threadIdx.x * VecSize + i][threadIdx.y] = tmp_sum[i];
+  }
+  __syncthreads();
+
+  // reduce sum
+  T sum = static_cast<T>(0);
+  int tid = threadIdx.y * blockDim.x + threadIdx.x;
+  int x = tid >> 5;  // warp id
+  int y = tid & 31;  // thread id on warp 0~31
+
+  // need BlockSizeX * VecSize warps
+  if (x < BlockSizeX * VecSize) {
+// reduce 128 to 32
+#pragma unroll
+    for (int i = 0; i < (BlockSizeY >> 5); i++) {
+      sum += cache[x][y + i * 32];
+    }
+  }
+
+  // reduce 32 to 1
+  sum = WarpReduceSum(sum);
+
+  // save sum to dbias
+  int bias_id = blockIdx.x * blockDim.x * VecSize + x;
+  if (y == 0 && x < VecSize * BlockSizeX && bias_id < cols) {
+    dbias[bias_id] = sum;
+  }
+}
+
+/**
+ * @brief to launch kernel FusedResidualDropoutBiasGradVec
+ */
+template <typename T, typename MaskType>
+void LaunchResidualDropoutBiasGrad(const T *dout, const MaskType *mask,
+                                   const float dropout_prob,
+                                   const bool is_upscale_in_train,
+                                   const uint32_t rows, const uint32_t cols,
+                                   T *dx, T *dbias,
+                                   const platform::CUDADeviceContext &ctx) {
+  const T zero = static_cast<T>(0.0f);
+  auto factor = dropout_prob == static_cast<float>(1.0f)
+                    ? zero
+                    : static_cast<T>(1.0f / (1.0f - dropout_prob));
+  if (!is_upscale_in_train) {
+    factor = static_cast<T>(1.0f);
+  }
+
+  const int VecSize = MAX_CACHE_BYTES / sizeof(T);
+  int real_vec_size = cols % VecSize == 0 ? VecSize : 1;
+  if (dbias != nullptr) {
+    auto threads = std::min(cols / real_vec_size, static_cast<uint32_t>(8));
+    auto blocks =
+        std::max((uint32_t)1, (cols / real_vec_size + threads - 1) / threads);
+    dim3 block_dim(threads, 128, 1);
+    dim3 grid_dim(blocks, 1, 1);
+    if (cols % VecSize == 0) {
+      FusedResidualDropoutBiasGrad<
+          T, MaskType, 8, 128,
+          VecSize><<<grid_dim, block_dim, 0, ctx.stream()>>>(
+          dout, mask, factor, rows, cols, dx, dbias);
+    } else {
+      FusedResidualDropoutBiasGrad<
+          T, MaskType, 8, 128, 1><<<grid_dim, block_dim, 0, ctx.stream()>>>(
+          dout, mask, factor, rows, cols, dx, dbias);
+    }
+  } else {
+    const uint64_t n = rows * cols;
+    platform::GpuLaunchConfig config =
+        platform::GetGpuLaunchConfig1D(ctx, n / real_vec_size);
+    if (n % VecSize == 0) {
+      FusedResidualDropoutGrad<T, MaskType, VecSize><<<
+          config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+          dout, mask, factor, n, dx);
+    } else {
+      FusedResidualDropoutGrad<T, MaskType, 1><<<
+          config.block_per_grid, config.thread_per_block, 0, ctx.stream()>>>(
+          dout, mask, factor, n, dx);
+    }
+  }
+}
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu
new file mode 100644
index 0000000000000000000000000000000000000000..e687823bc8158bc4512c59485cacfd384de36ff8
--- /dev/null
+++ b/paddle/fluid/operators/fused/fused_residual_dropout_bias_test.cu
@@ -0,0 +1,328 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <time.h>
+
+#include <random>
+#include <vector>
+
+#include "paddle/fluid/operators/fused/fused_dropout_test.h"
+#include "paddle/fluid/operators/fused/fused_residual_dropout_bias.h"
+
+namespace framework = paddle::framework;
+namespace platform = paddle::platform;
+
+/**
+ * @brief the unittest of FusedResidualDropoutBias
+ * 1. random input data
+ * 2. add bias, call paddle dropout op, add residual, and get the base result
+ * 3. call the FusedResidualDropoutBias function to get the fused result
+ * 4. compare the base result and the fused result
+ */
+
+template <typename T>
+struct TestFusedResidualDropoutBias {
+  uint32_t rows;
+  uint32_t cols;
+  uint64_t seed;
+  float dropout_prob;
+  bool is_upscale_in_train;
+  bool is_test;  // default false, set to true for inference only
+  bool has_bias = true;
+  framework::Tensor src, residual, bias, out, mask;
+  framework::Tensor dsrc, dbias;
+
+  std::vector<T> src_vec, residual_vec, bias_vec;
+  std::vector<T> correct_out, correct_dsrc, correct_dbias;
+  std::vector<uint8_t> correct_mask;
+
+  platform::CUDAPlace place;
+  platform::CUDADeviceContext *ctx;
+
+  TestFusedResidualDropoutBias() {
+    rows = 32;
+    cols = 32;
+    seed = 0;
+    dropout_prob = 0.0;
+    is_upscale_in_train = false;
+    is_test = false;
+    has_bias = true;
+    platform::DeviceContextPool &pool =
+        platform::DeviceContextPool::Instance();
+    auto device_ctx = pool.Get(place);
+    ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx);
+  }
+
+  TestFusedResidualDropoutBias(int rows_, int cols_, uint64_t seed_ = 0,
+                               float dropout_prob_ = 0.0,
+                               bool is_upscale_in_train_ = false,
+                               bool is_test_ = false) {
+    rows = rows_;
+    cols = cols_;
+    seed = seed_;
+    dropout_prob = dropout_prob_;
+    is_upscale_in_train = is_upscale_in_train_;
+    is_test = is_test_;
+    has_bias = true;
+    platform::DeviceContextPool &pool =
+        platform::DeviceContextPool::Instance();
+    auto device_ctx = pool.Get(place);
+    ctx = reinterpret_cast<platform::CUDADeviceContext *>(device_ctx);
+  }
+
+  ~TestFusedResidualDropoutBias() {}
+
+  void SetUp() {
+    const int n = rows * cols;
+    correct_out.resize(n);
+    correct_mask.resize(n);
+    correct_dsrc.resize(n);
+    correct_dbias.resize(cols);
+
+    src_vec.resize(n);
+    residual_vec.resize(n);
+    bias_vec.resize(cols);
+    std::default_random_engine random(time(NULL));
+    std::uniform_real_distribution<float> dis(0.0, 1.0);
+
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < cols; j++) {
+        src_vec[i * cols + j] = static_cast<T>(dis(random));
+        residual_vec[i * cols + j] = static_cast<T>(dis(random));
+        if (i == 0) {
+          bias_vec[j] = dis(random);
+        }
+      }
+    }
+
+    framework::TensorFromVector(src_vec, *ctx, &src);
+    src.Resize({rows, cols});
+    framework::TensorFromVector(residual_vec, *ctx, &residual);
+    residual.Resize({rows, cols});
+    if (has_bias) {
+      framework::TensorFromVector(bias_vec, *ctx, &bias);
+      bias.Resize({cols});
+    }
+
+    {
+      out.Resize({rows, cols});
+      out.mutable_data<T>(place);
+      mask.Resize({rows, cols});
+      mask.mutable_data<uint8_t>(place);
+      dsrc.Resize({rows, cols});
+      dsrc.mutable_data<T>(place);
+
+      if (has_bias) {
+        dbias.Resize({cols});
+        dbias.mutable_data<T>(place);
+      }
+    }
+  }
+
+  void BaseForward() {
+    std::vector<T> out1(rows * cols), out2(rows * cols);
+    if (has_bias) {
+      // add bias
+      for (int i = 0; i < rows; i++) {
+        for (int j = 0; j < cols; j++) {
+          out1[i * cols + j] = src_vec[i * cols + j] + bias_vec[j];
+        }
+      }
+      // call dropout
+      Dropout<T>(out1, src.dims(), &out2, &correct_mask, *ctx, seed,
+                 dropout_prob, is_upscale_in_train, is_test);
+    } else {
+      Dropout<T>(src_vec, src.dims(), &out2, &correct_mask, *ctx, seed,
+                 dropout_prob, is_upscale_in_train, is_test);
+    }
+    ctx->Wait();
+    // add residual
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < cols; j++) {
+        correct_out[i * cols + j] =
+            residual_vec[i * cols + j] + out2[i * cols + j];
+      }
+    }
+  }
+
+  void BaseBackward() {
+    DropoutGrad<T>(&correct_dsrc, src.dims(), correct_out, correct_mask, *ctx,
+                   dropout_prob, is_upscale_in_train);
+    // calc dbias
+    memset(&correct_dbias[0], 0, cols * sizeof(T));
+    for (int i = 0; i < rows; i++) {
+      for (int j = 0; j < cols; j++) {
+        correct_dbias[j] += correct_out[i * cols + j];
+      }
+    }
+  }
+
+  void FusedForward() {
+    const int VecSize = MAX_CACHE_BYTES / sizeof(T);
+    auto config = paddle::operators::Get1DBlocksAnd2DGrids(
+        *ctx, (uint64_t)rows, (uint64_t)cols, VecSize);
+    const int increment = ((cols - 1) / (config.thread_per_block.x *
+                                         config.block_per_grid.x * VecSize) +
+                           1) *
+                          VecSize;
+
+    T *bias_ptr = nullptr;
+    if (has_bias) {
+      bias_ptr = bias.data<T>();
+    }
+    paddle::operators::LaunchResidualDropoutBias<T, uint8_t>(
+        rows, cols, increment, seed, dropout_prob, is_test,
+        is_upscale_in_train, src.data<T>(), residual.data<T>(), bias_ptr,
+        mask.data<uint8_t>(), out.data<T>(), *ctx);
+    ctx->Wait();
+  }
+
+  void FusedBackward() {
+    if (is_test) {
+      return;
+    }
+
+    T *bias_ptr = nullptr;
+    if (has_bias) {
+      bias_ptr = dbias.data<T>();
+    }
+    paddle::operators::LaunchResidualDropoutBiasGrad<T, uint8_t>(
+        out.data<T>(), mask.data<uint8_t>(), dropout_prob,
+        is_upscale_in_train, rows, cols, dsrc.data<T>(), bias_ptr, *ctx);
+  }
+
+  void Run() {
+    SetUp();
+    BaseForward();
+    FusedForward();
+    BaseBackward();
+    FusedBackward();
+  }
+
+  void CheckOut(const T diff) {
+    const int n = rows * cols;
+    std::vector<T> _out(n);
+    std::vector<uint8_t> _mask(n);
+    framework::TensorToVector(out, *ctx, &_out);
+    if (!is_test) {
+      framework::TensorToVector(mask, *ctx, &_mask);
+    }
+    ctx->Wait();
+
+    for (int i = 0; i < n; i++) {
+      EXPECT_LT(std::abs(_out[i] - correct_out[i]), diff);
+      if (!is_test) EXPECT_EQ(_mask[i], correct_mask[i]);
+    }
+  }
+
+  void CheckGrad(const T diff) {
+    if (is_test) {
+      return;
+    }
+
+    const int n = rows * cols;
+
+    std::vector<T> _dsrc(n);
+    framework::TensorToVector(dsrc, *ctx, &_dsrc);
+
+    for (int i = 0; i < n; i++) {
+      EXPECT_LT(std::abs(_dsrc[i] - correct_dsrc[i]), diff);
+    }
+
+    if (has_bias) {
+      std::vector<T> _dbias(cols);
+      framework::TensorToVector(dbias, *ctx, &_dbias);
+      ctx->Wait();
+      for (int i = 0; i < cols; i++) {
+        EXPECT_LT(std::abs(_dbias[i] - correct_dbias[i]), diff);
+      }
+    }
+  }
+};
+
+// test the shape and bias
+template <typename T>
+static void BaseTest(const bool is_fp16 = false) {
+  const int rows = 16;
+  std::vector<int> cols_list = {16, 17};
+  bool has_bias[2] = {true, false};
+  T default_diff = static_cast<T>(1e-5);
+  if (is_fp16) {
+    default_diff = static_cast<T>(1e-2);
+  }
+  for (int i = 0; i < cols_list.size(); i++) {
+    for (int j = 0; j < 2; j++) {
+      TestFusedResidualDropoutBias<T> test(rows, cols_list[i]);
+      test.has_bias = has_bias[j];
+      test.Run();
+      test.CheckOut(default_diff);
+      if (!is_fp16) {
+        test.CheckGrad(default_diff);
+      }
+    }
+  }
+}
+
+TEST(FusedDropout, GPUFusedResidualDropoutBias) { BaseTest<float>(); }
+
+TEST(FusedDropout, GPUFusedResidualDropoutBiasDouble) { BaseTest<double>(); }
+
+// test fp16, For inference, check_grad is not required. ref: test_dropout_op.py
+TEST(FusedDropout, GPUFusedResidualDropoutBiasFp16) {
+  BaseTest<platform::float16>(true);
+}
+
+TEST(FusedDropout, GPUFusedResidualDropoutBias2) {
+  const int rows = 16;
+  const int cols = 16;
+  TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0, false, false);
+  test.Run();
+  test.CheckOut(static_cast<float>(1e-5));
+  test.CheckGrad(static_cast<float>(1e-5));
+}
+
+TEST(FusedDropout, GPUFusedResidualDropoutBias3) {
+  const int rows = 16;
+  const int cols = 16;
+  TestFusedResidualDropoutBias<float> test(rows, cols, 0, 1.0, true, false);
+  test.Run();
+  test.CheckOut(static_cast<float>(1e-5));
+  test.CheckGrad(static_cast<float>(1e-5));
+}
+
+TEST(FusedDropout, GPUFusedResidualDropoutBias4) {
+  const int rows = 16;
+  const int cols = 16;
+  TestFusedResidualDropoutBias<float> test(rows, cols, 0, 0.35, true, true);
+  test.Run();
+  test.CheckOut(static_cast<float>(1e-5));
+  test.CheckGrad(static_cast<float>(1e-5));
+}
+
+TEST(FusedDropout, GPUFusedResidualDropoutBias5) {
+  const int rows = 16;
+  const int cols = 16;
+  TestFusedResidualDropoutBias<float> test(rows, cols, 125, 0.0, false, false);
+  test.Run();
+  test.CheckOut(static_cast<float>(1e-5));
+  test.CheckGrad(static_cast<float>(1e-5));
+}
+
+// test large shape
+TEST(FusedDropout, GPUFusedResidualDropoutBias6) {
+  const int rows = 256;
+  const int cols = 4096;
+  TestFusedResidualDropoutBias<float> test(rows, cols);
+  test.Run();
+  test.CheckOut(static_cast<float>(1e-5));
+  test.CheckGrad(static_cast<float>(1e-3));
+}
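
For reviewers, a minimal host-side sketch of the per-element semantics that the fused kernel and the tests above implement. The helper name FusedElementRef is illustrative only and is not part of the patch; it simply restates the factor/mask logic from FusedResidualDropoutBias.

// Reference semantics of dst = residual + dropout(src + bias) for one element.
// factor follows the kernel: "upscale_in_train" rescales kept values during
// training, "downgrade_in_infer" rescales at inference time instead.
#include <cstdint>

template <typename T>
T FusedElementRef(T src, T residual, T bias, uint8_t keep_mask,
                  float dropout_prob, bool is_upscale_in_train, bool is_test) {
  float factor;
  if (is_test) {
    factor = is_upscale_in_train ? 1.0f : 1.0f - dropout_prob;
  } else {
    factor = is_upscale_in_train ? 1.0f / (1.0f - dropout_prob) : 1.0f;
  }
  // In training, keep_mask is 1 with probability (1 - dropout_prob); in
  // inference the kernel behaves as if every element is kept.
  float kept = is_test ? 1.0f : static_cast<float>(keep_mask);
  return static_cast<T>((static_cast<float>(src) + static_cast<float>(bias)) *
                            kept * factor +
                        static_cast<float>(residual));
}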