diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h
index 6db3efa3cdd60b3b75a4e0bdf147fd5d88bc969e..5eede02567b43f736e120fc6838026b429522d58 100644
--- a/paddle/fluid/operators/dropout_impl.cu.h
+++ b/paddle/fluid/operators/dropout_impl.cu.h
@@ -34,10 +34,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/dropout_impl_util.h"
-#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
-#include "paddle/fluid/platform/aligned_vector.h"
 #include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/kernels/funcs/broadcast_function.h"
 #include "paddle/phi/kernels/funcs/distribution_helper.h"
 #include "paddle/phi/kernels/funcs/functors.h"
@@ -142,15 +141,154 @@ __global__ void VectorizedRandomGenerator(const size_t n, uint64_t seed,
   }
 }
 
+template <typename T1, typename T2 = T1, typename OutT = T1>
+struct MaskFunctor {
+  const float retain_prob_;
+  HOSTDEVICE inline MaskFunctor(const float retain_prob)
+      : retain_prob_(retain_prob) {}
+
+  HOSTDEVICE inline void operator()(OutT* dst, const T2* rand, int num) const {
+    static constexpr int kCount =
+        phi::funcs::uniform_distribution<T2>::kReturnsCount;
+    // mark each of the kCount sampled elements as kept (1) or dropped (0)
+#pragma unroll
+    for (int i = 0; i < kCount; i++) {
+      if (rand[i] < retain_prob_) {
+        dst[i] = static_cast<OutT>(1);
+      } else {
+        dst[i] = static_cast<OutT>(0);
+      }
+    }
+  }
+};
+
+template <typename T, typename MaskType>
+struct DstFunctor {
+  using MT = typename details::MPTypeTrait<T>::Type;
+  MT factor;
+  HOSTDEVICE inline DstFunctor(const float retain_prob,
+                               const bool is_upscale_in_train,
+                               const int64_t num)
+      : retain_prob_(retain_prob),
+        is_upscale_in_train_(is_upscale_in_train),
+        num_(num) {
+    factor = static_cast<MT>(1.0f / retain_prob_);
+  }
+
+  HOSTDEVICE inline T operator()(const T src_val, const MaskType mask) const {
+    if (mask == static_cast<MaskType>(1)) {
+      return is_upscale_in_train_
+                 ? static_cast<T>(static_cast<MT>(src_val) * factor)
+                 : static_cast<T>(src_val);
+    }
+    return static_cast<T>(0);
+  }
+
+ private:
+  const float retain_prob_;
+  const bool is_upscale_in_train_;
+  // num_ is retained for the kernel-side constructor signature; it is not
+  // needed per element.
+  const int64_t num_;
+};
+
+template <typename T, typename MaskType>
+__global__ void VectorizedGeneratorMask(const size_t n, uint64_t seed,
+                                        const float dropout_prob, const T* src,
+                                        MaskType* mask, uint64_t increment,
+                                        size_t main_offset) {
+  constexpr int kCount =
+      phi::funcs::uniform_distribution<float>::kReturnsCount;
+  size_t idx = static_cast<size_t>(BLOCK_ID_X * BLOCK_NUM_X);
+  size_t stride = BLOCK_NUM_X * GRID_NUM_X * kCount;
+#ifdef PADDLE_WITH_HIP
+  hiprandStatePhilox4_32_10_t state;
+  hiprand_init(seed, idx + THREAD_ID_X, increment, &state);
+  using SType = hiprandStatePhilox4_32_10_t;
+#else
+  curandStatePhilox4_32_10_t state;
+  curand_init(seed, idx + THREAD_ID_X, increment, &state);
+  using SType = curandStatePhilox4_32_10_t;
+#endif
+  T dst_mask[kCount];  // holds the kept(1)/dropped(0) flags computed as T
+  float rands[kCount];
+  MaskType mask_result[kCount];
+  using Rand = phi::funcs::uniform_distribution<float>;
+  using Cast = kps::IdentityFunctor<T>;
+  int deal_size = BLOCK_NUM_X * kCount;
+
+  size_t fix = idx * kCount;
+
+  auto mask_functor = MaskFunctor<T, float>(1.0f - dropout_prob);
+  for (; fix < main_offset; fix += stride) {
+    kps::ReadData(&dst_mask[0], src + fix, deal_size);
+    kps::ElementwiseRandom(&rands[0], Rand(), &state);
+    // dst
+    kps::OperatorBinary(&dst_mask[0], &rands[0], mask_functor, kCount);
+
+    // mask
+    kps::ElementwiseUnary(&mask_result[0], &dst_mask[0], Cast());
+    kps::WriteData(mask + fix, &mask_result[0], deal_size);
+    if (fix > idx * kCount + 1) {
+      __syncthreads();
+    }
+  }
+  int remainder = n - fix;
+  if (remainder > 0) {
+    kps::ReadData(&dst_mask[0], src + fix, remainder);
+    kps::ElementwiseRandom(&rands[0], Rand(), &state);
+    // dst
+    kps::OperatorBinary(&dst_mask[0], &rands[0], mask_functor, kCount);
+    // mask
+    kps::ElementwiseUnary(&mask_result[0], &dst_mask[0], Cast());
+    kps::WriteData(mask + fix, &mask_result[0], remainder);
+    __syncthreads();
+  }
+}
+
+inline void CalcBroadcastedMask(const phi::GPUContext& dev_ctx,
+                                const framework::Tensor& mask,
+                                framework::Tensor* broadcasted_mask) {
+  // The broadcast of mask can be combined into the following
+  // ElementwiseKernel once the BroadcastKernel supports different input
+  // types.
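+  // For example, with axis = {1} and x of shape [4, 32, 16], the stored
+  // mask has shape [1, 32, 1] and is expanded here to [4, 32, 16] so the
+  // elementwise kernels below can consume it directly.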
+ broadcasted_mask->mutable_data(dev_ctx.GetPlace()); + + std::vector ins = {&mask}; + std::vector outs = {broadcasted_mask}; + phi::funcs::BroadcastKernel( + dev_ctx, ins, &outs, -1, kps::IdentityFunctor()); +} + +template +void ScaleByDropoutFactor(const phi::GPUContext& dev_ctx, + const framework::Tensor& x, framework::Tensor* y, + MT factor) { + std::vector ins = {&x}; + std::vector outs = {y}; + auto functor = phi::funcs::ScaleFunctor(factor); + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, functor); +} + template void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, - const std::string dropout_implementation, float dropout_prob, bool upscale_in_train, bool is_fix_seed, int seed_val, const framework::Tensor& x, const framework::Tensor* seed, - framework::Tensor* mask, framework::Tensor* y) { - auto& place = *dev_ctx.eigen_device(); + framework::Tensor* mask, framework::Tensor* y, + bool is_dropout_nd = false) { int64_t x_numel = x.numel(); auto stream = dev_ctx.stream(); auto* x_data = x.data(); @@ -198,33 +336,38 @@ void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, size_t main_offset = size / (block_size * kVecSize) * (block_size * kVecSize); + if (is_dropout_nd) { + VectorizedGeneratorMask<<>>( + size, seed_data, dropout_prob, x_data, mask_data, increment, + main_offset); + + framework::Tensor broadcasted_mask; + broadcasted_mask.Resize(x.dims()); + CalcBroadcastedMask(dev_ctx, *mask, &broadcasted_mask); + + auto dst_functor = DstFunctor(1.0f - dropout_prob, + upscale_in_train, x_numel); + std::vector ins = {&x, &broadcasted_mask}; + std::vector outs = {y}; + phi::funcs::ElementwiseKernel(dev_ctx, ins, &outs, dst_functor); + } else { #define PD_DROPOUT_KERNEL_NAME VectorizedRandomGenerator - PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL( - !is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0, stream, - offset, KERNEL_PARAMS.As(1), KERNEL_PARAMS.As(7), - size, seed_data, dropout_prob, x_data, mask_data, y_data, - upscale_in_train, increment, main_offset); + PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL( + !is_fix_seed, PD_DROPOUT_KERNEL_NAME, grid_size, block_size, 0, + stream, offset, KERNEL_PARAMS.As(1), + KERNEL_PARAMS.As(7), size, seed_data, dropout_prob, x_data, + mask_data, y_data, upscale_in_train, increment, main_offset); #undef PD_DROPOUT_KERNEL_NAME + } } else { if (upscale_in_train) { -// todo: can y share with data with x directly? 
-#ifdef PADDLE_WITH_HIP - PADDLE_ENFORCE_GPU_SUCCESS( - hipMemcpyAsync(y_data, x_data, sizeof(T) * x_numel, - hipMemcpyDeviceToDevice, stream)); -#else - PADDLE_ENFORCE_GPU_SUCCESS( - cudaMemcpyAsync(y_data, x_data, sizeof(T) * x_numel, - cudaMemcpyDeviceToDevice, stream)); -#endif + // y = x + framework::TensorCopy(x, dev_ctx.GetPlace(), dev_ctx, y); } else { using MT = typename details::MPTypeTrait::Type; MT factor = static_cast(1.0f - dropout_prob); - std::vector ins = {&x}; - std::vector outs = {y}; - auto functor = phi::funcs::ScaleFunctor(factor); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); + // y = factor * x + ScaleByDropoutFactor(dev_ctx, x, y, factor); } } } @@ -246,45 +389,44 @@ struct CudaDropoutGradFunctor { }; template -void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, - const std::string dropout_implementation, - float dropout_prob, +void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, + float dropout_prob, bool upscale_in_train, const framework::Tensor& grad_y, - const framework::Tensor& mask, int64_t size, + const framework::Tensor& mask, framework::Tensor* grad_x, - bool is_test = false) { + bool is_dropout_nd = false) { using MT = typename details::MPTypeTrait::Type; + auto stream = dev_ctx.stream(); - MT factor; if (is_test) { - if (dropout_implementation == "upscale_in_train") { - factor = static_cast(1.0f); - } else { - factor = static_cast(1.0f - dropout_prob); - } - std::vector ins = {&grad_y}; - std::vector outs = {grad_x}; - auto functor = phi::funcs::ScaleFunctor(factor); - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx, ins, - &outs, functor); + MT factor = static_cast(upscale_in_train ? 1.0f : 1.0f - dropout_prob); + // y = factor * x + ScaleByDropoutFactor(dev_ctx, grad_y, grad_x, factor); } else { - std::vector ins = {&grad_y, &mask}; + framework::Tensor broadcasted_mask; + if (is_dropout_nd) { + broadcasted_mask.Resize(grad_y.dims()); + CalcBroadcastedMask(dev_ctx, mask, &broadcasted_mask); + } + + std::vector ins = { + &grad_y, is_dropout_nd ? &broadcasted_mask : &mask}; std::vector outs = {grad_x}; - if (dropout_implementation == "upscale_in_train") { + if (upscale_in_train) { if (dropout_prob == 1.0f) { #ifdef PADDLE_WITH_HIP - hipMemset(grad_x->data(), 0, size * sizeof(T)); + hipMemset(grad_x->data(), 0, grad_x->numel() * sizeof(T)); #else - cudaMemset(grad_x->data(), 0, size * sizeof(T)); + cudaMemset(grad_x->data(), 0, grad_x->numel() * sizeof(T)); #endif } else { - factor = static_cast(1.0f / (1.0f - dropout_prob)); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( + MT factor = static_cast(1.0f / (1.0f - dropout_prob)); + phi::funcs::ElementwiseKernel( dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); } } else { - factor = static_cast(1.0f); - paddle::operators::LaunchSameDimsElementwiseCudaKernel( + MT factor = static_cast(1.0f); + phi::funcs::ElementwiseKernel( dev_ctx, ins, &outs, CudaDropoutGradFunctor(factor)); } } diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 9426efa4942088473d5487a7bc5c08930baef002..3f65a6bfda97f3820d4426e4be00939d83d99e63 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -161,15 +161,49 @@ class DropoutGradOpMaker : public framework::SingleGradOpMaker { } }; +class DropoutNdOpMaker : public DropoutOpMaker { + public: + void Make() override { + DropoutOpMaker::Make(); + AddAttr>("axis", + "(std::vector). 
List of integers," + " indicating the dimensions to be dropout_nd.") + .SetDefault({}); + } +}; + +template +class DropoutNdGradOpMaker : public framework::SingleGradOpMaker { + public: + using framework::SingleGradOpMaker::SingleGradOpMaker; + + protected: + void Apply(GradOpPtr op) const override { + op->SetType("dropout_nd_grad"); + op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out")); + op->SetInput("Mask", this->Output("Mask")); + op->SetOutput(framework::GradVarName("X"), this->InputGrad("X")); + op->SetAttrMap(this->Attrs()); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; + DECLARE_INFER_SHAPE_FUNCTOR(dropout, DropoutInferShapeFunctor, PD_INFER_META(phi::DropoutInferMeta)); - REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, ops::DropoutGradOpMaker, ops::DropoutGradOpMaker, DropoutInferShapeFunctor); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); + +DECLARE_INFER_SHAPE_FUNCTOR(dropout_nd, DropoutNdInferShapeFunctor, + PD_INFER_META(phi::DropoutNdInferMeta)); +REGISTER_OPERATOR(dropout_nd, ops::DropoutOp, ops::DropoutNdOpMaker, + ops::DropoutNdGradOpMaker, + ops::DropoutNdGradOpMaker, + DropoutNdInferShapeFunctor); +REGISTER_OPERATOR(dropout_nd_grad, ops::DropoutOpGrad); diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index ce95b0a320c66382f4ad441ae57337c90a757210..ef00a0203c7b83efb996e802a73cf37c1f45680d 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -131,8 +131,7 @@ class FMHARef { auto functor = phi::funcs::ScaleFunctor(alpha); std::vector ins = {&q_tensor}; std::vector outs = {&q_tensor}; - paddle::operators::LaunchSameDimsElementwiseCudaKernel(dev_ctx_, ins, - &outs, functor); + phi::funcs::ElementwiseKernel(dev_ctx_, ins, &outs, functor); } // q*k^t, batched_gemm @@ -186,13 +185,11 @@ class FMHARef { if (dropout_param_.dropout_prob_) { DropoutFwGPUKernelDriver( static_cast(dev_ctx_), - dropout_param_.is_test_, - static_cast( - dropout_param_.dropout_implementation_), - dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, - dropout_param_.is_fix_seed_, dropout_param_.seed_val_, + dropout_param_.is_test_, dropout_param_.dropout_prob_, + dropout_param_.is_upscale_in_train_, dropout_param_.is_fix_seed_, + dropout_param_.seed_val_, static_cast(*softmax_out_tensor), dropout_param_.seed_, - dropout_mask_out_tensor, dropout_out_tensor); + dropout_mask_out_tensor, dropout_out_tensor, false); blas.BatchedGEMM(transA, transB, gemm_m, gemm_n, gemm_k, alpha, dropout_out_data, v_ptr, beta, qktv_out_data, gemm_batch_size, stride_a, stride_b); @@ -288,13 +285,10 @@ class FMHARef { // dropout bw if (dropout_param_.dropout_prob_) { DropoutGradGPUKernelDriver( - static_cast(dev_ctx_), - static_cast( - dropout_param_.dropout_implementation_), - dropout_param_.dropout_prob_, + static_cast(dev_ctx_), false, + dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, static_cast(*dropout_out_grad_tensor), - dropout_mask_out_tensor, softmax_out_grad_tensor->numel(), - softmax_out_grad_tensor); + dropout_mask_out_tensor, softmax_out_grad_tensor, false); } if (src_mask_tensor != nullptr) { diff --git a/paddle/fluid/platform/aligned_vector.h b/paddle/fluid/platform/aligned_vector.h deleted file mode 100644 index b42ae15405e7ff0cb8d24fa7bf14aac442c16163..0000000000000000000000000000000000000000 --- a/paddle/fluid/platform/aligned_vector.h +++ /dev/null @@ -1,77 +0,0 @@ -/* Copyright (c) 2021 
PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.1 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.1 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once - -#include "paddle/phi/core/hostdevice.h" - -namespace paddle { -namespace platform { - -// Aligned vector generates vectorized load/store on CUDA. -template -struct alignas(sizeof(T) * Size) AlignedVector { - T val[Size]; - - HOSTDEVICE inline const T& operator[](int i) const { return val[i]; } - HOSTDEVICE inline T& operator[](int i) { return val[i]; } -}; - -template -HOSTDEVICE inline void Load(const T* addr, AlignedVector* vec) { - const AlignedVector* addr_vec = - reinterpret_cast*>(addr); - *vec = *addr_vec; -} - -template -HOSTDEVICE inline void Store(const AlignedVector& vec, T* addr) { - AlignedVector* addr_vec = - reinterpret_cast*>(addr); - *addr_vec = vec; -} - -/* - * Only the address of input data is the multiplier of 1,2,4, vectorized load - * with corresponding multiplier-value is possible. Moreover, the maximum length - * of vectorized load is 128 bits once. Hence, valid length of vectorized load - * shall be determined under both former constraints. - */ -template -int GetVectorizedSize(const T* pointer) { - constexpr int max_load_bits = 128; - int valid_vec_size = max_load_bits / CHAR_BIT / sizeof(T); - uint64_t address = reinterpret_cast(pointer); - constexpr int vec8 = std::alignment_of>::value; // NOLINT - constexpr int vec4 = std::alignment_of>::value; // NOLINT - constexpr int vec2 = std::alignment_of>::value; // NOLINT - if (address % vec8 == 0) { - /* - * Currently, decide to deal with no more than 4 data once while adopting - * vectorization load/store, if performance test shows that dealing with - * 8 data once in vectorization load/store does get optimized, return code - * below can be changed into " return std::min(8, valid_vec_size); " . - */ - return std::min(4, valid_vec_size); - } else if (address % vec4 == 0) { - return std::min(4, valid_vec_size); - } else if (address % vec2 == 0) { - return std::min(2, valid_vec_size); - } else { - return 1; - } -} - -} // namespace platform -} // namespace paddle diff --git a/paddle/phi/infermeta/binary.cc b/paddle/phi/infermeta/binary.cc index f10fc54795ddb05d6ee9c5095e4adaa113b3428b..add27da56b59a835a5a97de9780b4604e4141cc7 100644 --- a/paddle/phi/infermeta/binary.cc +++ b/paddle/phi/infermeta/binary.cc @@ -886,6 +886,58 @@ void DropoutInferMeta(const MetaTensor& x, } } +void DropoutNdInferMeta(const MetaTensor& x, + const MetaTensor& seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + const std::vector& axis, + MetaTensor* out, + MetaTensor* mask) { + auto x_dims = x.dims(); + + PADDLE_ENFORCE_LE( + axis.size(), + x_dims.size(), + phi::errors::InvalidArgument( + "The length of axis is expected to be less than or equal to the " + "dimension size of x. 
But received the length of axis is %d, the "
+          "dimension size of x is %d, x's shape is {%s}.",
+          axis.size(),
+          x_dims.size(),
+          x_dims));
+  for (size_t i = 0; i < axis.size(); ++i) {
+    PADDLE_ENFORCE_EQ(
+        axis[i] >= 0 && axis[i] <= x_dims.size() - 1,
+        true,
+        phi::errors::InvalidArgument(
+            "The %d-th value of axis is expected to be greater than or "
+            "equal to 0 and less than the dimension size of x. But "
+            "received axis is {%s}, the dimension size of x is %d.",
+            i,
+            phi::make_ddim(axis),
+            x_dims.size()));
+  }
+
+  out->set_dims(x_dims);
+  out->share_lod(x);
+  out->set_dtype(x.dtype());
+
+  if (mask != nullptr) {
+    std::vector<int64_t> mask_dims(x.dims().size(), 1);
+
+    std::for_each(
+        axis.begin(), axis.end(), [&mask_dims, &x_dims](const int64_t& t) {
+          mask_dims[t] = x_dims[t];
+        });
+
+    mask->set_dims(make_ddim(mask_dims));
+    mask->set_dtype(DataType::UINT8);
+  }
+}
+
 void DotInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out) {
   auto x_dims = x.dims();
   auto x_rank = static_cast<size_t>(x_dims.size());
diff --git a/paddle/phi/infermeta/binary.h b/paddle/phi/infermeta/binary.h
index 2cd34406fc2d2708af42284e86bceae5dbc23fb6..9709edf63ccc07cb36246e45e734d09c028434f0 100644
--- a/paddle/phi/infermeta/binary.h
+++ b/paddle/phi/infermeta/binary.h
@@ -145,6 +145,17 @@ void DropoutInferMeta(const MetaTensor& x,
                       MetaTensor* out,
                       MetaTensor* mask);
 
+void DropoutNdInferMeta(const MetaTensor& x,
+                        const MetaTensor& seed_tensor,
+                        float p,
+                        bool is_test,
+                        const std::string& mode,
+                        int seed,
+                        bool fix_seed,
+                        const std::vector<int>& axis,
+                        MetaTensor* out,
+                        MetaTensor* mask);
+
 void ElementwiseInferMeta(const MetaTensor& x,
                          const MetaTensor& y,
                          MetaTensor* out);
diff --git a/paddle/phi/kernels/cpu/dropout_grad_kernel.cc b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc
index db95656421884d1df29fa4be695ea2da2f7e025b..42b2834aaffc9ff45230c8a8f4d96b1b6a8eedb6 100644
--- a/paddle/phi/kernels/cpu/dropout_grad_kernel.cc
+++ b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc
@@ -21,16 +21,17 @@
 namespace phi {
 
 template <typename T, typename Context>
-void DropoutGradRawKernel(const Context& dev_ctx,
-                          const DenseTensor& mask,
-                          const DenseTensor& out_grad,
-                          float p,
-                          bool is_test,
-                          const std::string& mode,
-                          DenseTensor* x_grad) {
+void DropoutNdGradKernel(const Context& dev_ctx,
+                         const DenseTensor& mask,
+                         const DenseTensor& out_grad,
+                         float p,
+                         bool is_test,
+                         const std::string& mode,
+                         const std::vector<int>& axis,
+                         DenseTensor* x_grad) {
   auto* grad_x = x_grad;
   auto* grad_y = &out_grad;
-  grad_x->mutable_data<T>(dev_ctx.GetPlace());
+  dev_ctx.template Alloc<T>(grad_x);
 
   auto dX = EigenVector<T>::Flatten(*grad_x);
   auto dY = EigenVector<T>::Flatten(*grad_y);
@@ -44,19 +45,41 @@ void DropoutGradRawKernel(const Context& dev_ctx,
       dX.device(place) = dY * static_cast<T>(1.0f - p);
     }
   } else {
+    std::vector<int64_t> out_dims = phi::vectorize(out_grad.dims());
     auto M = EigenVector<uint8_t>::Flatten(mask);
     if (dropout_implementation == "upscale_in_train") {
       if (p == 1.0f) {
         dX.device(place) = static_cast<T>(0) * dY;
       } else {
-        dX.device(place) = dY * M.cast<T>() / static_cast<T>(1.0f - p);
+        if (axis.empty()) {
+          dX.device(place) = dY * M.cast<T>() / static_cast<T>(1.0f - p);
+        } else {
+          dX.device(place) =
+              dY * M.broadcast(out_dims).cast<T>() / static_cast<T>(1.0f - p);
+        }
       }
     } else {
-      dX.device(place) = dY * M.cast<T>();
+      if (axis.empty()) {
+        dX.device(place) = dY * M.cast<T>();
+      } else {
+        dX.device(place) = dY * M.broadcast(out_dims).cast<T>();
+      }
     }
   }
 }
 
+template <typename T, typename Context>
+void DropoutGradRawKernel(const Context& dev_ctx,
+                          const DenseTensor& mask,
+                          const DenseTensor& out_grad,
+                          float p,
+                          bool is_test,
const std::string& mode, + DenseTensor* x_grad) { + DropoutNdGradKernel( + dev_ctx, mask, out_grad, p, is_test, mode, {}, x_grad); +} + } // namespace phi PD_REGISTER_KERNEL(dropout_grad, @@ -66,3 +89,7 @@ PD_REGISTER_KERNEL(dropout_grad, float, double, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL( + dropout_nd_grad, CPU, ALL_LAYOUT, phi::DropoutNdGradKernel, float, double) { +} diff --git a/paddle/phi/kernels/cpu/dropout_kernel.cc b/paddle/phi/kernels/cpu/dropout_kernel.cc index d9c02eff0106fecf932d29ed358a0d1c5ec80c4f..d3ca21cfe33b9d1006d86f93c67800c97ed393dd 100644 --- a/paddle/phi/kernels/cpu/dropout_kernel.cc +++ b/paddle/phi/kernels/cpu/dropout_kernel.cc @@ -17,10 +17,34 @@ #include "paddle/fluid/framework/generator.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/expand_kernel.h" #include "paddle/phi/kernels/funcs/eigen/common.h" namespace phi { +template +void ComputeDropoutInference(const Context& ctx, + const DenseTensor& x, + float dropout_prob, + bool upscale_in_train, + DenseTensor* y) { + if (upscale_in_train) { + const auto* X_data = x.data(); + T* Y_data = ctx.template Alloc(y); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < x.numel(); i++) { + Y_data[i] = X_data[i]; + } + } else { + auto X = EigenMatrix::Reshape(x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + auto& place = *ctx.eigen_device(); + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } +} + template void DropoutRawKernel(const Context& dev_ctx, const DenseTensor& x, @@ -34,13 +58,13 @@ void DropoutRawKernel(const Context& dev_ctx, DenseTensor* mask) { auto* y = out; const auto* x_data = x.data(); - auto* y_data = y->mutable_data(dev_ctx.GetPlace()); + T* y_data = dev_ctx.template Alloc(y); float dropout_prob = p; auto& dropout_implementation = mode; bool upscale_in_train = (dropout_implementation == "upscale_in_train"); if (!is_test) { - auto* mask_data = mask->mutable_data(dev_ctx.GetPlace()); + auto* mask_data = dev_ctx.template Alloc(mask); size_t size = phi::product(mask->dims()); // Special case when dropout_prob is 1.0 @@ -76,21 +100,92 @@ void DropoutRawKernel(const Context& dev_ctx, } } } else { - if (upscale_in_train) { - const auto* X_data = x.data(); - auto* Y_data = y->mutable_data(dev_ctx.GetPlace()); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < x.numel(); i++) { - Y_data[i] = X_data[i]; - } + ComputeDropoutInference( + dev_ctx, x, dropout_prob, upscale_in_train, y); + } +} + +template +void DropoutNdKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + const std::vector& axis, + DenseTensor* out, + DenseTensor* mask) { + auto* y = out; + const auto* x_data = x.data(); + T* y_data = dev_ctx.template Alloc(y); + float dropout_prob = p; + + auto& dropout_implementation = mode; + bool upscale_in_train = (dropout_implementation == "upscale_in_train"); + if (!is_test) { + DenseTensor t_mask; + t_mask.Resize(mask->dims()); + T* t_mask_data = dev_ctx.template Alloc(&t_mask); + auto* mask_data = dev_ctx.template Alloc(mask); + size_t size = phi::product(mask->dims()); + + // Special case when dropout_prob is 1.0 + if (dropout_prob == 1.0f) { + std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT + std::memset(t_mask_data, 0, size * sizeof(*t_mask_data)); // NOLINT + std::memset(mask_data, 0, size * sizeof(*mask_data)); // 
NOLINT + return; + } + // std::minstd_rand engine; + // NOTE: fixed seed should only be used in unittest or for debug. + // Guarantee to use random seed in training. + int seed_data = 0; + if (seed_tensor.get_ptr() != nullptr) { + seed_data = *(seed_tensor->data()); } else { - auto X = EigenMatrix::Reshape(x, 1); - auto Y = EigenMatrix::Reshape(*y, 1); - auto& place = *dev_ctx.eigen_device(); - Y.device(place) = X * static_cast(1.0f - dropout_prob); + seed_data = fix_seed ? seed : 0; } + auto engine = paddle::framework::GetCPURandomEngine(seed_data); + + std::uniform_real_distribution dist(0, 1); + + for (size_t i = 0; i < size; ++i) { + if (dist(*engine) < dropout_prob) { + t_mask_data[i] = 0; + mask_data[i] = 0; + } else { + t_mask_data[i] = 1; + mask_data[i] = 1; + } + } + auto& x_dims = x.dims(); + DenseTensor broadcast_mask; + broadcast_mask.Resize(x_dims); + T* broadcast_mask_data = dev_ctx.template Alloc(&broadcast_mask); + + std::vector mask_bst_dims_vec; + for (int i = 0; i < x_dims.size(); i++) { + mask_bst_dims_vec.emplace_back(x_dims[i]); + } + IntArray mask_bst_dims(mask_bst_dims_vec); + ExpandKernel(dev_ctx, t_mask, mask_bst_dims, &broadcast_mask); + + for (auto i = 0; i < x.numel(); i++) { + if (broadcast_mask_data[i] == static_cast(1)) { + if (upscale_in_train) { + y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); + } else { + y_data[i] = x_data[i]; + } + } else { + y_data[i] = 0; + } + } + } else { + ComputeDropoutInference( + dev_ctx, x, dropout_prob, upscale_in_train, y); } } @@ -103,3 +198,6 @@ PD_REGISTER_KERNEL(dropout, float, double, phi::dtype::bfloat16) {} + +PD_REGISTER_KERNEL( + dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {} diff --git a/paddle/phi/kernels/dropout_grad_kernel.h b/paddle/phi/kernels/dropout_grad_kernel.h index ae3f82056632ddde8968b7468eb16030f0c926f5..d8d5363ad59b7298d7b4216204dc3c433152e34a 100644 --- a/paddle/phi/kernels/dropout_grad_kernel.h +++ b/paddle/phi/kernels/dropout_grad_kernel.h @@ -28,4 +28,14 @@ void DropoutGradRawKernel(const Context& dev_ctx, const std::string& mode, DenseTensor* x_grad); +template +void DropoutNdGradKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + const std::vector& axis, + DenseTensor* x_grad); + } // namespace phi diff --git a/paddle/phi/kernels/dropout_kernel.h b/paddle/phi/kernels/dropout_kernel.h index 6febcd78e1107a34825b82b5843a5d7b10018a6d..cba8160058e9905fd759b8d83027edce97dc332b 100644 --- a/paddle/phi/kernels/dropout_kernel.h +++ b/paddle/phi/kernels/dropout_kernel.h @@ -14,6 +14,7 @@ #pragma once +#include "paddle/phi/common/int_array.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/dense_tensor.h" @@ -31,4 +32,17 @@ void DropoutRawKernel(const Context& dev_ctx, DenseTensor* out, DenseTensor* mask); +template +void DropoutNdKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + const std::vector& axis, + DenseTensor* out, + DenseTensor* mask); + } // namespace phi diff --git a/paddle/phi/kernels/gpu/dropout_grad_kernel.cu b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu index b27029fe863fad82fbafbc191a1f6efa0708c1a5..1eea13a5a226b481cfdc80c4cfa6b9bcde24784b 100644 --- a/paddle/phi/kernels/gpu/dropout_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu @@ -27,10 +27,25 @@ void DropoutGradRawKernel(const Context& dev_ctx, bool 
is_test, const std::string& mode, DenseTensor* x_grad) { + bool upscale_in_train = (mode == "upscale_in_train"); x_grad->mutable_data(dev_ctx.GetPlace()); - auto size = x_grad->numel(); paddle::operators::DropoutGradGPUKernelDriver( - dev_ctx, mode, p, out_grad, mask, size, x_grad, is_test); + dev_ctx, is_test, p, upscale_in_train, out_grad, mask, x_grad, false); +} + +template +void DropoutNdGradKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + const std::vector& axis, + DenseTensor* x_grad) { + bool upscale_in_train = (mode == "upscale_in_train"); + dev_ctx.template Alloc(x_grad); + paddle::operators::DropoutGradGPUKernelDriver( + dev_ctx, is_test, p, upscale_in_train, out_grad, mask, x_grad, true); } } // namespace phi @@ -43,3 +58,12 @@ PD_REGISTER_KERNEL(dropout_grad, double, phi::dtype::bfloat16, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(dropout_nd_grad, + GPU, + ALL_LAYOUT, + phi::DropoutNdGradKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/dropout_kernel.cu b/paddle/phi/kernels/gpu/dropout_kernel.cu index 8ae3dd25cc8f6272f2d8c41d0445434f31660e53..3811440be75113906881775971709087cdeecc9b 100644 --- a/paddle/phi/kernels/gpu/dropout_kernel.cu +++ b/paddle/phi/kernels/gpu/dropout_kernel.cu @@ -30,22 +30,48 @@ void DropoutRawKernel(const Context& dev_ctx, bool fix_seed, DenseTensor* out, DenseTensor* mask) { - out->mutable_data(dev_ctx.GetPlace()); - float dropout_prob = p; bool upscale_in_train = (mode == "upscale_in_train"); + out->mutable_data(dev_ctx.GetPlace()); mask->mutable_data(dev_ctx.GetPlace()); + paddle::operators::DropoutFwGPUKernelDriver(dev_ctx, + is_test, + p, + upscale_in_train, + fix_seed, + seed, + x, + seed_tensor.get_ptr(), + mask, + out, + false); +} +template +void DropoutNdKernel(const Context& dev_ctx, + const DenseTensor& x, + const paddle::optional& seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + const std::vector& axis, + DenseTensor* out, + DenseTensor* mask) { + bool upscale_in_train = (mode == "upscale_in_train"); + dev_ctx.template Alloc(out); + dev_ctx.template Alloc(mask); paddle::operators::DropoutFwGPUKernelDriver(dev_ctx, is_test, - mode, - dropout_prob, + p, upscale_in_train, fix_seed, seed, x, seed_tensor.get_ptr(), mask, - out); + out, + true); } } // namespace phi @@ -58,3 +84,12 @@ PD_REGISTER_KERNEL(dropout, double, phi::dtype::bfloat16, phi::dtype::float16) {} + +PD_REGISTER_KERNEL(dropout_nd, + GPU, + ALL_LAYOUT, + phi::DropoutNdKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/compat/dropout_sig.cc b/paddle/phi/ops/compat/dropout_sig.cc index 712c5cbb0d634bef52deaa1a04105484d6d7aad7..403e752ca0e83b872513d3d3478c4c371c46c6a2 100644 --- a/paddle/phi/ops/compat/dropout_sig.cc +++ b/paddle/phi/ops/compat/dropout_sig.cc @@ -32,7 +32,31 @@ KernelSignature DropoutGradOpArgumentMapping( {"X@GRAD"}); } +KernelSignature DropoutNdOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature("dropout_nd", + {"X", "Seed"}, + {"dropout_prob", + "is_test", + "dropout_implementation", + "seed", + "fix_seed", + "axis"}, + {"Out", "Mask"}); +} + +KernelSignature DropoutNdGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature( + "dropout_nd_grad", + {"Mask", "Out@GRAD"}, + {"dropout_prob", "is_test", "dropout_implementation", "axis"}, + 
{"X@GRAD"}); +} + } // namespace phi PD_REGISTER_ARG_MAPPING_FN(dropout, phi::DropoutOpArgumentMapping); PD_REGISTER_ARG_MAPPING_FN(dropout_grad, phi::DropoutGradOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(dropout_nd, phi::DropoutNdOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(dropout_nd_grad, + phi::DropoutNdGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py b/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py new file mode 100644 index 0000000000000000000000000000000000000000..c696863c612b0010242eef9686f2970f5aa28f62 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_dropout_nd_op.py @@ -0,0 +1,130 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import numpy as np +from op_test import OpTest, convert_float_to_uint16 +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +from paddle.fluid.layer_helper import LayerHelper +from paddle.fluid.framework import _non_static_mode +from paddle import _C_ops +from paddle.static import default_main_program + + +def dropout_nd(x, + p=0.5, + axis=None, + training=True, + mode="upscale_in_train", + name=None): + drop_axes = [axis] if isinstance(axis, int) else list(axis) + seed = None + mode = 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode #semantic transfer + if _non_static_mode(): + if default_main_program().random_seed != 0: + seed = default_main_program().random_seed + + out, mask = _C_ops.dropout_nd(x, 'dropout_prob', p, 'is_test', + not training, 'fix_seed', seed + is not None, 'seed', + seed if seed is not None else 0, + 'dropout_implementation', mode, 'axis', + drop_axes) + return out + + helper = LayerHelper('dropout_nd', **locals()) + check_variable_and_dtype(x, 'x', ['float16', 'float32', 'float64'], + 'dropout') + + out = helper.create_variable_for_type_inference(dtype=x.dtype) + mask = helper.create_variable_for_type_inference( + dtype=core.VarDesc.VarType.UINT8, stop_gradient=True) + + def get_attrs(prog, dropout_prob, is_test, seed): + if (seed is None or seed == 0) and prog.random_seed != 0: + seed = prog.random_seed + attrs = { + 'dropout_prob': dropout_prob, + 'is_test': is_test, + 'fix_seed': seed is not None, + 'seed': seed if seed is not None else 0, + 'dropout_implementation': mode, + 'axis': drop_axes + } + return attrs + + attrs = get_attrs(helper.main_program, p, not training, seed) + + helper.append_op(type='dropout_nd', + inputs={'X': [x]}, + outputs={ + 'Out': [out], + 'Mask': [mask] + }, + attrs=attrs) + return out + + +paddle.enable_static() + + +class TestDropoutNdOp(OpTest): + + def setUp(self): + self.op_type = "dropout_nd" + self.inputs = {'X': np.random.random((4, 32, 16)).astype("float64")} + self.attrs = { + 'dropout_prob': 0.0, + 'fix_seed': True, + 'is_test': False, + 'axis': [1] + } + self.outputs = { + 'Out': self.inputs['X'], + 'Mask': np.ones((1, 32, 1)).astype('uint8') + } + + def 
test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out') + + +class TestDropoutNdAPI(unittest.TestCase): + + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def test_dygraph(self): + paddle.disable_static() + for place in self.places: + with fluid.dygraph.guard(place): + in_np = np.random.random([4, 32, 16]).astype("float32") + input = paddle.to_tensor(in_np) + res1 = dropout_nd(x=input, p=0., axis=[0, 1]) + res2 = dropout_nd(x=input, p=0.5, axis=[0, 1]) + self.assertTrue(np.allclose(res1.numpy(), in_np)) + paddle.enable_static() + + +if __name__ == '__main__': + unittest.main()
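
Reviewer note: below is a minimal NumPy sketch, not part of the patch, of the forward semantics the new dropout_nd kernels implement under mode="upscale_in_train". The helper name dropout_nd_reference and its defaults are illustrative assumptions; only the mask shape and the 1/(1-p) rescaling mirror DropoutNdInferMeta and CalcBroadcastedMask above.

    import numpy as np


    def dropout_nd_reference(x, p=0.5, axis=(1,), seed=0):
        # One mask value per position along `axis`; all other dims share it,
        # matching the [1, 32, 1] mask DropoutNdInferMeta derives for a
        # [4, 32, 16] input with axis = [1].
        rng = np.random.default_rng(seed)
        mask_shape = [x.shape[d] if d in axis else 1 for d in range(x.ndim)]
        mask = (rng.random(mask_shape) >= p).astype(np.uint8)
        if p == 1.0:
            return np.zeros_like(x), mask
        # upscale_in_train: kept elements are rescaled by 1 / (1 - p) so the
        # expected value of the output matches the input.
        return x * mask.astype(x.dtype) / (1.0 - p), mask


    out, mask = dropout_nd_reference(np.random.rand(4, 32, 16), p=0.5, axis=(1,))
    assert mask.shape == (1, 32, 1) and out.shape == (4, 32, 16)

This is also why the backward kernels only need the stored low-rank mask plus a broadcast: the gradient of the output with respect to the input is simply mask / (1 - p), applied elementwise.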