diff --git a/paddle/phi/api/yaml/backward.yaml b/paddle/phi/api/yaml/backward.yaml index 6891cef3ffe707c687c8e9dd206094db30f18d95..0c05a4e806a97ddc2f43d83194a29c47388d8398 100644 --- a/paddle/phi/api/yaml/backward.yaml +++ b/paddle/phi/api/yaml/backward.yaml @@ -605,6 +605,16 @@ kernel : func : frame_grad +- backward_op : fused_dropout_add_grad + forward : fused_dropout_add (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed) -> Tensor(out), Tensor(seed_offset) + args : (Tensor seed_offset, Tensor out_grad, Scalar p, bool is_test, str mode, bool fix_seed) + output : Tensor(x_grad), Tensor(y_grad) + infer_meta : + func : GeneralBinaryGradInferMeta + param : [out_grad, out_grad] + kernel : + func : fused_dropout_add_grad + - backward_op : gather_nd_grad forward : gather_nd (Tensor x, Tensor index) -> Tensor(out) args : (Tensor x, Tensor index, Tensor out_grad) diff --git a/paddle/phi/api/yaml/ops.yaml b/paddle/phi/api/yaml/ops.yaml index b20c9e6769d4dcf891567fd571120adf4123de30..e953047639f67e80277782b06bafb52c7834c928 100644 --- a/paddle/phi/api/yaml/ops.yaml +++ b/paddle/phi/api/yaml/ops.yaml @@ -604,6 +604,16 @@ func : frame backward : frame_grad +- op : fused_dropout_add + args : (Tensor x, Tensor y, Scalar p, bool is_test, str mode, int seed, bool fix_seed) + output : Tensor(out), Tensor(seed_offset) + infer_meta : + func : FusedDropoutAddInferMeta + kernel : + func : fused_dropout_add + data_type : x + backward : fused_dropout_add_grad + - op : fused_linear_param_grad_add args : (Tensor x, Tensor dout, Tensor dweight, Tensor dbias, bool multi_precision = true) output : Tensor(dweight_out), Tensor(dbias_out) diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index a57bcc1d0d01d1ee1ea812f303251decfe5d7885..0a3e31054b0f5b23585389f062723c5d783ad80d 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -1287,6 +1287,22 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x, 
out->set_dtype(x.dtype()); } +void FusedDropoutAddInferMeta(const MetaTensor& x, + const MetaTensor& y, + const Scalar& p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + MetaTensor* out, + MetaTensor* seed_offset) { + out->share_meta(x); + if (seed_offset) { + seed_offset->set_dims({2}); + seed_offset->set_dtype(DataType::INT64); + } +} + // Used in FusedMatmulInferMeta static std::vector GetInputShape(phi::DDim dim, std::vector shape, diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h index 59e900aefb30d2f3254e1ed9e92c3211bdd77b0f..ed4da703ce5207caaa875103dad05b978e910972 100644 --- a/paddle/phi/infermeta/binary.h +++ b/paddle/phi/infermeta/binary.h @@ -222,6 +222,16 @@ void FillDiagonalTensorInferMeta(const MetaTensor& x, int dim2, MetaTensor* out); +void FusedDropoutAddInferMeta(const MetaTensor& x, + const MetaTensor& y, + const Scalar& p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + MetaTensor* out, + MetaTensor* seed_offset); + void FusedMatmulInferMeta(const MetaTensor& x, const MetaTensor& y, const MetaTensor& residual_data, diff --git a/paddle/phi/kernels/CMakeLists.txt b/paddle/phi/kernels/CMakeLists.txt index 7604a50448c30a50e9b7fd3a4e767b32f729394c..b963a4c506d79fae9c6a439f6317dc1bab1b45d5 100644 --- a/paddle/phi/kernels/CMakeLists.txt +++ b/paddle/phi/kernels/CMakeLists.txt @@ -81,11 +81,12 @@ set(COMMON_KERNEL_DEPS utf8proc gather_scatter_functor) +set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group) + if(WITH_FLASHATTN) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} phi_dynload_flashattn) endif() -set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group) if(WITH_NCCL OR WITH_RCCL) set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} process_group_nccl nccl_comm_context) diff --git a/paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h b/paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h new file mode 100644 index 
0000000000000000000000000000000000000000..5efa0d60a2ce477218f08817cdc5c0512da0c894 --- /dev/null +++ b/paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h @@ -0,0 +1,33 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void FusedDropoutAddGradKernel(const Context& dev_ctx, + const DenseTensor& seed_offset, + const DenseTensor& out_grad, + const Scalar& p, + bool is_test, + const std::string& mode, + bool fix_seed, + DenseTensor* x_grad, + DenseTensor* y_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/fused_dropout_add_kernel.h b/paddle/phi/kernels/fusion/fused_dropout_add_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..cbd359a6c33f94e1ebce2c17decd7cd598783178 --- /dev/null +++ b/paddle/phi/kernels/fusion/fused_dropout_add_kernel.h @@ -0,0 +1,55 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" +#include "paddle/phi/kernels/funcs/distribution_helper.h" + +namespace phi { + +template +void FusedDropoutAddKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* seed_offset); + +template +static inline std::vector GetRandomCudaProp(int numel, + const Context& dev_ctx) { + constexpr int kVecSize = funcs::uniform_distribution::kReturnsCount; + auto gpu_config = + backends::gpu::GetGpuLaunchConfig1D(dev_ctx, numel, kVecSize); + size_t grid_size = gpu_config.GetGridSize(); + size_t block_size = gpu_config.GetBlockSize(); + int64_t device_id = dev_ctx.GetPlace().GetDeviceId(); + const auto& prop = phi::backends::gpu::GetDeviceProperties(device_id); + size_t max_grid_size = + prop.maxThreadsPerMultiProcessor * prop.multiProcessorCount / block_size; + grid_size = std::min(grid_size, max_grid_size); + auto offset = + ((numel - 1) / (grid_size * block_size * kVecSize) + 1) * kVecSize; + size_t main_offset = + numel / (block_size * kVecSize) * (block_size * kVecSize); + return {grid_size, block_size, offset, main_offset}; +} + +} // namespace phi diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..e6fb4b97d66e153ea38d50e0f12dab95ecddda75 --- /dev/null 
+++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_grad_kernel.cu @@ -0,0 +1,241 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/fusion/fused_dropout_add_grad_kernel.h" +#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/funcs/distribution_helper.h" +#include "paddle/phi/kernels/funcs/dropout_impl_util.h" +#include "paddle/phi/kernels/funcs/functors.h" +#include "paddle/phi/kernels/primitive/compute_primitives.h" + +#include "paddle/phi/kernels/funcs/dropout_impl.cu.h" + +static constexpr int kNumCUDAThreads = 512; +static constexpr int kNumMaximumNumBlocks = 4096; + +static inline int NumBlocks(const int64_t N) {  // ceil(N / kNumCUDAThreads) capped at kNumMaximumNumBlocks; N is 64-bit so numel() is not narrowed before the cap + return static_cast<int>(std::min<int64_t>((N + kNumCUDAThreads - 1) / kNumCUDAThreads, + kNumMaximumNumBlocks)); +} + +namespace phi { + +template +__global__ void FuseScaleAddGrad(const T* grad, + T* x, + T* y, + const MT factor, + const int64_t limit, + bool upscale_in_train) { + CUDA_KERNEL_LOOP(i, limit) { + y[i] = grad[i]; + x[i] = upscale_in_train ? 
grad[i] + : static_cast(static_cast(grad[i]) * factor); + } +} + +template +__global__ void FuseScaleAddGradRateZero(const T* grad, + T* src, + T* res, + const int64_t limit) { + CUDA_KERNEL_LOOP(i, limit) { + res[i] = grad[i]; + src[i] = 0; + } +} + +template +struct NoMaskBwFunctor { + const float retain_prob_; + using MT = typename phi::kps::details::MPTypeTrait::Type; + MT factor_; + HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob) + : retain_prob_(retain_prob) { + factor_ = static_cast(1.0f / retain_prob_); + } + + HOSTDEVICE inline NoMaskBwFunctor(const float retain_prob, const MT factor) + : retain_prob_(retain_prob), factor_(factor) {} + + HOSTDEVICE inline void operator()(OutT* dst, + const T1* src_val, + const T2* rand, + int num) const { + static constexpr int kCount = + phi::funcs::uniform_distribution::kReturnsCount; +#pragma unroll + for (int i = 0; i < kCount; i++) { + dst[i + kCount] = src_val[i]; + dst[i] = rand[i] < retain_prob_ + ? static_cast(static_cast(src_val[i]) * factor_) + : static_cast(0); + } + } +}; + +template +__global__ void VectorizedDropoutBackward(const size_t n, + uint64_t seed, + T* src, + T* res, + const T* dst, + uint64_t increment, + size_t main_offset, + Functor functor) { + size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); + static constexpr int kCount = + phi::funcs::uniform_distribution::kReturnsCount; + size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount; +#ifdef PADDLE_WITH_HIP + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, idx + THREAD_ID_X, increment, &state); + using SType = hiprandStatePhilox4_32_10_t; +#else + curandStatePhilox4_32_10_t state; + curand_init(seed, idx + THREAD_ID_X, increment, &state); + using SType = curandStatePhilox4_32_10_t; +#endif + + float rands[kCount]; + T src_res[kCount * 2]; + T res_grad[kCount]; + + using Rand = phi::funcs::uniform_distribution; + using Cast = kps::IdentityFunctor; + + int deal_size = BLOCK_NUM_X * kCount; + size_t fix = idx * kCount; + for (; fix < 
main_offset; fix += stride) { + kps::ReadData(&src_res[0], dst + fix, deal_size);  // offset by fix so each block/iteration loads the same tile of the incoming gradient that it writes back below + kps::ElementwiseRandom( + &rands[0], Rand(), &state); + // x_grad + kps::OperatorTernary( + &src_res[0], &src_res[0], &rands[0], functor, kCount); + kps::WriteData(src + fix, &src_res[0], deal_size); + // res + kps::ElementwiseUnary( + &res_grad[0], &src_res[kCount], Cast()); + kps::WriteData(res + fix, &res_grad[0], deal_size); + if (fix > idx * kCount + 1) { + __syncthreads(); + } + } + int remainder = n - fix; + if (remainder > 0) { + kps::ReadData(&src_res[0], dst + fix, remainder); + kps::ElementwiseRandom( + &rands[0], Rand(), &state); + // x_grad + kps::OperatorTernary( + &src_res[0], &src_res[0], &rands[0], functor, kCount); + kps::WriteData(src + fix, &src_res[0], remainder); + + // res + kps::ElementwiseUnary( + &res_grad[0], &src_res[kCount], Cast()); + kps::WriteData(res + fix, &res_grad[0], remainder); + __syncthreads(); + } +} + +template +void FusedDropoutAddGradKernel(const Context& dev_ctx, + const DenseTensor& seed_offset, + const DenseTensor& out_grad, + const Scalar& p, + bool is_test, + const std::string& mode, + bool fix_seed, + DenseTensor* x_grad, + DenseTensor* y_grad) { + int64_t numel = out_grad.numel(); + auto stream = dev_ctx.stream(); + float dropout_rate = p.to(); + bool upscale_in_train = (mode == "upscale_in_train"); + + const auto* seed_offset_data = seed_offset.data(); + const uint64_t seed_data = static_cast(seed_offset_data[0]); + const uint64_t increment = static_cast(seed_offset_data[1]); + + auto* x_grad_data = dev_ctx.template Alloc(x_grad); + auto* y_grad_data = dev_ctx.template Alloc(y_grad); + + const auto* out_grad_data = out_grad.data(); + using MT = typename phi::kps::details::MPTypeTrait::Type; + int blocks = NumBlocks(numel); + int threads = kNumCUDAThreads; + + if (is_test) { + MT factor = static_cast(1.0f - dropout_rate); + FuseScaleAddGrad<<>>(out_grad_data, + x_grad_data, + y_grad_data, + factor, + numel, + upscale_in_train); + + } else 
{ + if (upscale_in_train && dropout_rate == 1.0f) { + FuseScaleAddGradRateZero<<>>( + out_grad_data, x_grad_data, y_grad_data, numel); + return; + } + auto random_prop = GetRandomCudaProp(numel, dev_ctx); + size_t grid_size = random_prop[0]; + size_t block_size = random_prop[1]; + size_t offset = random_prop[2]; + size_t main_offset = random_prop[3]; + + auto functor = upscale_in_train + ? NoMaskBwFunctor(1.0f - dropout_rate) + : NoMaskBwFunctor(1.0f - dropout_rate, 1.0f); +#define PD_DROPOUT_KERNEL_NAME \ + VectorizedDropoutBackward> + PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed, + PD_DROPOUT_KERNEL_NAME, + grid_size, + block_size, + 0, + stream, + offset, + KERNEL_PARAMS.As(1), + KERNEL_PARAMS.As(5), + numel, + seed_data, // need save + x_grad_data, + y_grad_data, + out_grad_data, // grad + increment, // need save + main_offset, + functor); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fused_dropout_add_grad, + GPU, + ALL_LAYOUT, + phi::FusedDropoutAddGradKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) { + kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); // seed_offset +} diff --git a/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..e47d4fcbabb5ac41b782058af84399ee6283bc00 --- /dev/null +++ b/paddle/phi/kernels/fusion/gpu/fused_dropout_add_kernel.cu @@ -0,0 +1,218 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. +#include "paddle/phi/kernels/fusion/fused_dropout_add_kernel.h" +#include "paddle/phi/core/kernel_registry.h" + +#include "paddle/phi/kernels/funcs/dropout_impl_util.h" +#include "paddle/phi/kernels/funcs/functors.h" +#include "paddle/phi/kernels/primitive/compute_primitives.h" + +#include "paddle/phi/kernels/funcs/dropout_impl.cu.h" + +namespace phi { + +template +struct NoMaskFwFunctor { + const float retain_prob_; + const bool is_upscale_in_train_; + using MT = typename phi::kps::details::MPTypeTrait::Type; + MT factor; + HOSTDEVICE inline NoMaskFwFunctor(const float retain_prob, + const bool is_upscale_in_train) + : retain_prob_(retain_prob), is_upscale_in_train_(is_upscale_in_train) { + factor = static_cast(1.0f / retain_prob_); + } + + HOSTDEVICE inline void operator()(OutT* dst, + const T1* src_val, + const T2* rand, + int num) const { + static constexpr int kCount = + phi::funcs::uniform_distribution::kReturnsCount; +#pragma unroll + for (int i = 0; i < kCount; i++) { + if (rand[i] < retain_prob_) { + dst[i] = is_upscale_in_train_ + ? static_cast(static_cast(src_val[i]) * factor) + : static_cast(src_val[i]); + dst[i] += src_val[i + kCount]; + } else { + dst[i] = src_val[i + kCount]; + } + } + } +}; + +template +struct ScaleAddFuctor { + using MT = typename phi::kps::details::MPTypeTrait::Type; + explicit ScaleAddFuctor(const MT factor, bool upscale_in_train) + : factor_(factor), upscale_in_train_(upscale_in_train) {} + + __device__ __forceinline__ T operator()(const T src, const T res) const { + return upscale_in_train_ + ? 
src + res + : static_cast(static_cast(src) * factor_) + res; + } + + private: + MT factor_; + bool upscale_in_train_; +}; + +template +__global__ void VectorizedDropoutForward(const size_t n, + uint64_t seed, + const T* src, + const T* res, + T* dst, + uint64_t increment, + size_t main_offset, + Functor functor) { + size_t idx = static_cast(BLOCK_ID_X * BLOCK_NUM_X); + static constexpr int kCount = + phi::funcs::uniform_distribution::kReturnsCount; + size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount; +#ifdef PADDLE_WITH_HIP + hiprandStatePhilox4_32_10_t state; + hiprand_init(seed, idx + THREAD_ID_X, increment, &state); + using SType = hiprandStatePhilox4_32_10_t; +#else + curandStatePhilox4_32_10_t state; + curand_init(seed, idx + THREAD_ID_X, increment, &state); + using SType = curandStatePhilox4_32_10_t; +#endif + T dst_res[kCount * 2]; + float rands[kCount]; + + using Rand = phi::funcs::uniform_distribution; + int deal_size = BLOCK_NUM_X * kCount; + size_t fix = idx * kCount; + + for (; fix < main_offset; fix += stride) { + kps::ReadData(&dst_res[0], src + fix, deal_size); + kps::ReadData(&dst_res[kCount], res + fix, deal_size); + + kps::ElementwiseRandom( + &rands[0], Rand(), &state); + + // dst + kps::OperatorTernary( + &dst_res[0], &dst_res[0], &rands[0], functor, kCount); + + kps::WriteData(dst + fix, &dst_res[0], deal_size); + + if (fix > idx * kCount + 1) { + __syncthreads(); + } + } + + int remainder = n - fix; + if (remainder > 0) { + kps::ReadData(&dst_res[0], src + fix, remainder); + kps::ReadData(&dst_res[kCount], res + fix, remainder); + kps::ElementwiseRandom( + &rands[0], Rand(), &state); + // dst + kps::OperatorTernary( + &dst_res[0], &dst_res[0], &rands[0], functor, kCount); + kps::WriteData(dst + fix, &dst_res[0], remainder); + __syncthreads(); + } +} + +template +void FusedDropoutAddKernel(const Context& dev_ctx, + const DenseTensor& x, + const DenseTensor& y, + const Scalar& p, + bool is_test, + const std::string& mode, + int seed, + bool 
fix_seed, + DenseTensor* out, + DenseTensor* seed_offset) { + auto* out_data = dev_ctx.template Alloc(out); + auto* seed_offset_data = dev_ctx.template HostAlloc(seed_offset); + int64_t numel = x.numel(); + auto stream = dev_ctx.stream(); + bool upscale_in_train = (mode == "upscale_in_train"); + + const auto* x_data = x.data(); + const auto* y_data = y.data(); + float dropout_rate = p.to(); + + if (!is_test) { + if (dropout_rate == 1.0f) { + phi::Copy(dev_ctx, y, dev_ctx.GetPlace(), false, out); + return; + } + uint64_t seed_data; + uint64_t increment; + auto random_prop = GetRandomCudaProp(numel, dev_ctx); + size_t grid_size = random_prop[0]; + size_t block_size = random_prop[1]; + size_t offset = random_prop[2]; + size_t main_offset = random_prop[3]; + funcs::GetSeedDataAndIncrement( + dev_ctx, nullptr, fix_seed, seed, offset, &seed_data, &increment); + seed_offset_data[0] = static_cast(seed_data); + seed_offset_data[1] = static_cast(increment); + + auto dst_functor = + NoMaskFwFunctor(1.0f - dropout_rate, upscale_in_train); + +#define PD_DROPOUT_KERNEL_NAME \ + VectorizedDropoutForward> + PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(!fix_seed, + PD_DROPOUT_KERNEL_NAME, + grid_size, + block_size, + 0, + stream, + offset, + KERNEL_PARAMS.As(1), + KERNEL_PARAMS.As(5), + numel, + seed_data, // need save + x_data, + y_data, + out_data, + increment, // need save + main_offset, + dst_functor); +#undef PD_DROPOUT_KERNEL_NAME + } else { + using MT = typename phi::kps::details::MPTypeTrait::Type; + MT factor = static_cast(1.0f - dropout_rate); + std::vector outs = {out}; + std::vector ins = {&x, &y}; + + phi::funcs::ElementwiseKernel( + dev_ctx, ins, &outs, ScaleAddFuctor(factor, upscale_in_train)); + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(fused_dropout_add, + GPU, + ALL_LAYOUT, + phi::FusedDropoutAddKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) { + kernel->OutputAt(1).SetDataType(phi::DataType::INT64); +} diff --git 
a/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py b/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py new file mode 100644 index 0000000000000000000000000000000000000000..55ceb432e75323076968d9a327e6b7ffbf2dc831 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_fused_dropout_add_op.py @@ -0,0 +1,189 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import unittest + +import numpy as np + +import paddle +import paddle.fluid.core as core +from paddle import fluid +from paddle.incubate.nn.functional import fused_dropout_add +from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd + + +def paddle_dropout_add(x, y, p=0.5, training=True, mode="upscale_in_train"): + tmp = paddle.nn.functional.dropout(x, p, training=training, mode=mode) + return tmp + y + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), + "core is not compiled with CUDA ", +) +class TestFusedDropoutAdd(unittest.TestCase): + def setUp(self): + self.shape = (2, 10, 10, 2) + self.dtype = 'float64' + self.dropout_rate = 0.9 + self.training = True + self.mode = "upscale_in_train" + self.seed = 1027 + + def get_paddle_tensor(self): + tmp = paddle.randn(self.shape, self.dtype) + tmp.stop_gradient = False + return tmp + + def get_forward_backward(self, dropout_add, seed): + paddle.disable_static() + paddle.seed(seed) + count = 3 + data = [] + fw = [] + bw = [] + for _ in range(count): 
+ data.append(self.get_paddle_tensor()) + + out = data[0] + for i in range(1, count): + out = dropout_add( + out, + data[i], + p=self.dropout_rate, + training=self.training, + mode=self.mode, + ) + fw.append(out) + + loss = paddle.mean(out) + loss.backward() + for i in range(count): + bw.append(data[i].grad) + return fw, bw + + def test_fused_dropout_add(self): + p_fw, p_bw = self.get_forward_backward( + paddle_dropout_add, seed=self.seed + ) + f_fw, f_bw = self.get_forward_backward( + fused_dropout_add, seed=self.seed + ) + for i in range(len(p_fw)): + np.testing.assert_allclose( + p_fw[i].numpy(), f_fw[i].numpy(), rtol=1e-05 + ) + np.testing.assert_allclose( + p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 + ) + + +def create_test_class(parent, dtype, mode, training, p, seed): + @unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA" + ) + class TestFusedDropoutAddCase(parent): + def setUp(self): + self.shape = (2, 10, 10, 2) + self.dtype = dtype + self.dropout_rate = p + self.training = training + self.mode = mode + self.seed = seed + + cls_name = "{0}_{1}_{2}_{3}_{4}_{5}".format( + parent.__name__, dtype, mode, str(training), str(p), str(seed) + ) + TestFusedDropoutAddCase.__name__ = cls_name + globals()[cls_name] = TestFusedDropoutAddCase + + +for dtype in ["float64", "float32", "float16"]: + for mode in ["upscale_in_train", "downscale_in_infer"]: + for p in [0.0, 0.5, 0.9, 1.0]: + for training in [True, False]: + for seed in [0, 1024]: + create_test_class( + TestFusedDropoutAdd, dtype, mode, training, p, seed + ) + + +@unittest.skipIf( + not core.is_compiled_with_cuda(), "core is not compiled with CUDA " +) +class TestFusedDropoutAddStatic(unittest.TestCase): + def setUp(self): + self.place = paddle.CUDAPlace(0) + self.shape = (2, 80, 8, 2) + self.dtype = 'float16' + + def test_static_op(self): + paddle.disable_static() + paddle.seed(312) + x_data = np.random.random(self.shape) + y_data = np.random.random(self.shape) + x = 
paddle.to_tensor( + x_data, place=self.place, dtype=self.dtype, stop_gradient=False + ) + y = paddle.to_tensor( + y_data, place=self.place, dtype=self.dtype, stop_gradient=False + ) + out = fused_dropout_add(x, y, p=0.5, training=True) + paddle.enable_static() + paddle.seed(312) + + with paddle.static.program_guard(paddle.static.Program()): + xs = paddle.static.data( + name="xs", shape=self.shape, dtype=self.dtype + ) + ys = paddle.static.data( + name="ys", shape=self.shape, dtype=self.dtype + ) + + outs = fused_dropout_add(xs, ys, p=0.5, training=True) + + exe = fluid.Executor(self.place) + out_s = exe.run( + feed={ + "xs": x_data.astype('float16'), + "ys": y_data.astype('float16'), + }, + fetch_list=[outs], + ) + np.testing.assert_allclose(out_s[0], out) + + def test_fused_dropout_add_layer(self): + x = paddle.randn(self.shape, self.dtype) + y = paddle.randn(self.shape, self.dtype) + fused_d_a = FusedDropoutAdd(p=0.5) + d = paddle.nn.Dropout(p=0.5) + print(d) + paddle.seed(2048) + fused_out = fused_d_a(x, y) + paddle.seed(2048) + out = d(x) + y + np.testing.assert_allclose(fused_out, out) + + def test_assert(self): + def check_raise(): + x = paddle.randn(self.shape, self.dtype) + y = paddle.randn(self.shape, self.dtype) + fused_d_a = FusedDropoutAdd(p=-1) + fused_out = fused_d_a(x, y) + + self.assertRaises(ValueError, check_raise) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/incubate/nn/__init__.py b/python/paddle/incubate/nn/__init__.py index fe6a2abd9bbed99205fd474056e5a5b6b8fbe906..3b6869f88c6283b304e61fd9b7e7c35eac9e622f 100644 --- a/python/paddle/incubate/nn/__init__.py +++ b/python/paddle/incubate/nn/__init__.py @@ -21,6 +21,7 @@ from .layer.fused_transformer import ( FusedBiasDropoutResidualLayerNorm, ) # noqa: F401 from .layer.fused_ec_moe import FusedEcMoe # noqa: F401 +from .layer.fused_dropout_add import FusedDropoutAdd # noqa: F401 __all__ = [ # noqa 'FusedMultiHeadAttention', @@ -30,4 +31,5 @@ __all__ = [ # noqa 
'FusedLinear', 'FusedBiasDropoutResidualLayerNorm', 'FusedEcMoe', + 'FusedDropoutAdd', ] diff --git a/python/paddle/incubate/nn/functional/__init__.py b/python/paddle/incubate/nn/functional/__init__.py index a8f9cb70eca41775fcbce98fb9f79615176544f9..9d9f570ccc5b263a1cb8dda2bf9e1649a6479b38 100644 --- a/python/paddle/incubate/nn/functional/__init__.py +++ b/python/paddle/incubate/nn/functional/__init__.py @@ -18,6 +18,7 @@ from .fused_transformer import fused_multi_transformer from .fused_matmul_bias import fused_matmul_bias, fused_linear from .fused_transformer import fused_bias_dropout_residual_layer_norm from .fused_ec_moe import fused_ec_moe +from .fused_dropout_add import fused_dropout_add __all__ = [ 'fused_multi_head_attention', @@ -27,4 +28,5 @@ __all__ = [ 'fused_linear', 'fused_bias_dropout_residual_layer_norm', 'fused_ec_moe', + 'fused_dropout_add', ] diff --git a/python/paddle/incubate/nn/functional/fused_dropout_add.py b/python/paddle/incubate/nn/functional/fused_dropout_add.py new file mode 100644 index 0000000000000000000000000000000000000000..251b9e2e77ed62aefb5d3fa4d481270dc9a97217 --- /dev/null +++ b/python/paddle/incubate/nn/functional/fused_dropout_add.py @@ -0,0 +1,116 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ + +from paddle import _C_ops +from paddle.common_ops_import import default_main_program +from paddle.fluid import core +from paddle.fluid.framework import in_dygraph_mode +from paddle.framework import LayerHelper + + +def fused_dropout_add( + x, y, p=0.5, training=True, mode='upscale_in_train', name=None +): + r""" + Fused Dropout and Add. + + Args: + x (Tensor): The input tensor that dropout is applied to. The data type is bfloat16, float16, float32 or float64. + y (Tensor): The tensor added to the dropout result. The data type is bfloat16, float16, float32 or float64. + + p (float|int, optional): Probability of setting units to zero. Default: 0.5. + training (bool, optional): A flag indicating whether it is in the training phase or not. Default: True. + mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer']. + + 1. upscale_in_train (default), upscale the output at training time + + - train: :math:`out = x \times \frac{mask}{(1.0 - p)} + y` + - inference: :math:`out = x + y` + + 2. downscale_in_infer, downscale the output at inference + + - train: :math:`out = x \times mask + y` + - inference: :math:`out = x \times (1.0 - p) + y` + + name (str, optional): Name for the operation, Default: None. For more information, please refer to :ref:`api_guide_Name`. + + Returns: + A Tensor representing the fused dropout and add, has same shape and data type as `x` . + + + Examples: + + .. 
code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn.functional import fused_dropout_add + + x = paddle.randn([4, 10], dtype='float16') + y = paddle.randn([4, 10], dtype='float16') + out = fused_dropout_add(x, y, p=0.5) + """ + if isinstance(p, (int, float)): + # fast return for p == 0 + if p == 0: + return x + y + elif p < 0 or p > 1: + raise ValueError("p argument should between 0 and 1") + if mode not in ('downscale_in_infer', 'upscale_in_train'): + raise ValueError( + "mode argument should be 'downscale_in_infer' or 'upscale_in_train'" + ) + seed = None + if in_dygraph_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed + out, seed_offset = _C_ops.fused_dropout_add( + x, + y, + p, + not training, + mode, + seed if seed is not None else 0, + seed is not None, + ) + return out + else: + helper = LayerHelper('fused_dropout_add', **locals()) + out = helper.create_variable_for_type_inference(dtype=x.dtype) + seed_offset = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.INT64, stop_gradient=True + ) + + def get_attrs(prog, dropout_prob, is_test, seed): + if (seed is None or seed == 0) and prog.random_seed != 0: + seed = prog.random_seed + attrs = { + 'p': dropout_prob, + 'is_test': is_test, + 'mode': mode, + 'seed': seed if seed is not None else 0, + 'fix_seed': seed is not None, + } + return attrs + + attrs = get_attrs(helper.main_program, p, not training, seed) + + helper.append_op( + type='fused_dropout_add', + inputs={'x': x, 'y': y}, + outputs={'out': [out], 'seed_offset': [seed_offset]}, + attrs=attrs, + ) + return out diff --git a/python/paddle/incubate/nn/layer/fused_dropout_add.py b/python/paddle/incubate/nn/layer/fused_dropout_add.py new file mode 100644 index 0000000000000000000000000000000000000000..373103442922e1cfb6b37ceed72ab2c2bc98dd2f --- /dev/null +++ b/python/paddle/incubate/nn/layer/fused_dropout_add.py @@ -0,0 +1,78 @@ +# Copyright (c) 2023 
PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from paddle.incubate.nn import functional as F +from paddle.nn import Layer + + +class FusedDropoutAdd(Layer): + r""" + Fused Dropout and Add. + + Parameters: + p (float|int, optional): Probability of setting units to zero. Default: 0.5 + mode(str, optional): ['upscale_in_train'(default) | 'downscale_in_infer'] + + 1. upscale_in_train (default), upscale the output at training time + + - train: :math:`out = x \times \frac{mask}{(1.0 - p)} + y` + - inference: :math:`out = x + y` + + 2. downscale_in_infer, downscale the output at inference + + - train: :math:`out = x \times mask + y` + - inference: :math:`out = x \times (1.0 - p) + y` + name (str, optional): Name for the operation, Default: None. For more information, please refer to :ref:`api_guide_Name`. + + Shape: + - x: N-D tensor. + - y: N-D tensor. + - output: N-D tensor, the same shape as x. + + + Examples: + .. 
code-block:: python + + # required: gpu + import paddle + from paddle.incubate.nn.layer.fused_dropout_add import FusedDropoutAdd + + x = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32") + y = paddle.to_tensor([[1,2,3], [4,5,6]], dtype="float32") + + m = FusedDropoutAdd(p=0.5) + + out = m(x, y) + """ + + def __init__(self, p=0.5, mode="upscale_in_train", name=None): + super().__init__() + self.p = p + self.mode = mode + self.name = name + + def forward(self, x, y): + out = F.fused_dropout_add( + x, + y, + p=self.p, + training=self.training, + mode=self.mode, + name=self.name, + ) + return out + + def extra_repr(self): + name_str = ', name={}'.format(self.name) if self.name else '' + return 'p={}, mode={}{}'.format(self.p, self.mode, name_str)