From 99fc1b084dbca6ad4f1c0137548ca8a308f1d819 Mon Sep 17 00:00:00 2001
From: hong <43953930+phlrain@users.noreply.github.com>
Date: Thu, 10 Mar 2022 09:55:41 +0800
Subject: [PATCH] Move dropout to phi (#40148)

* move dropout to phi; test=develop

* fix xpu, npu compile error; test=develop
---
 .../inference/tensorrt/convert/dropout_op.cc  |   2 +-
 .../tensorrt/convert/test_dropout_op.cc       |   2 +-
 paddle/fluid/operators/assign_op_npu_test.cc  |   1 -
 .../collective/c_allgather_op_npu_test.cc     |   1 -
 .../collective/c_allreduce_max_op_npu_test.cc |   1 -
 .../collective/c_allreduce_sum_op_npu_test.cc |   1 -
 .../collective/c_broadcast_op_npu_test.cc     |   1 -
 .../collective/c_reduce_sum_op_npu_test.cc    |   1 -
 .../collective/c_reducescatter_op_npu_test.cc |   1 -
 .../c_sync_comm_stream_op_npu_test.cc         |   1 -
 .../collective/checknumeric_npu_test.cc       |   1 -
 .../collective/recv_v2_op_npu_test.cc         |   1 -
 .../collective/send_v2_op_npu_test.cc         |   1 -
 paddle/fluid/operators/dropout_impl.cu.h      |  27 ++--
 paddle/fluid/operators/dropout_impl_util.h    |   2 +-
 paddle/fluid/operators/dropout_op.cc          |  13 +-
 paddle/fluid/operators/dropout_op.cu          |  94 -----------
 paddle/fluid/operators/dropout_op.h           | 151 ------------------
 paddle/fluid/operators/dropout_op_npu.cc      |   2 +-
 paddle/fluid/operators/dropout_op_test.cc     |   3 +-
 paddle/fluid/operators/dropout_op_xpu.cc      |   4 +-
 .../elementwise/elementwise_op_npu_test.cc    |   1 -
 paddle/fluid/operators/expand_op_npu_test.cc  |   1 -
 paddle/fluid/operators/fused/fmha_ref.h       |  11 +-
 .../operators/fused/fused_dropout_test.h      |   2 +-
 paddle/fluid/operators/gelu_op_npu_test.cc    |   1 -
 .../fluid/operators/increment_op_npu_test.cc  |   1 -
 paddle/fluid/operators/range_op_npu_test.cc   |   1 -
 paddle/fluid/operators/rnn_op.h               |  10 +-
 paddle/fluid/operators/softmax_op_npu_test.cc |   1 -
 paddle/fluid/operators/squeeze_op_npu_test.cc |   1 -
 .../fluid/operators/transpose_op_npu_test.cc  |   1 -
 .../fluid/operators/unsqueeze_op_npu_test.cc  |   1 -
 paddle/phi/kernels/cpu/dropout_grad_kernel.cc |  67 ++++++++
 paddle/phi/kernels/cpu/dropout_kernel.cc      | 104 ++++++++++++
 paddle/phi/kernels/dropout_grad_kernel.h      |  31 ++++
 paddle/phi/kernels/dropout_kernel.h           |  34 ++++
 paddle/phi/kernels/gpu/dropout_grad_kernel.cu |  46 ++++++
 paddle/phi/kernels/gpu/dropout_kernel.cu      |  61 +++++++
 paddle/phi/ops/compat/dropout_sig.cc          |  38 +++++
 .../fluid/tests/unittests/test_dropout_op.py  |  60 +++++++
 41 files changed, 481 insertions(+), 303 deletions(-)
 delete mode 100644 paddle/fluid/operators/dropout_op.cu
 delete mode 100644 paddle/fluid/operators/dropout_op.h
 create mode 100644 paddle/phi/kernels/cpu/dropout_grad_kernel.cc
 create mode 100644 paddle/phi/kernels/cpu/dropout_kernel.cc
 create mode 100644 paddle/phi/kernels/dropout_grad_kernel.h
 create mode 100644 paddle/phi/kernels/dropout_kernel.h
 create mode 100644 paddle/phi/kernels/gpu/dropout_grad_kernel.cu
 create mode 100644 paddle/phi/kernels/gpu/dropout_kernel.cu
 create mode 100644 paddle/phi/ops/compat/dropout_sig.cc

diff --git a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc
index 8c61200f7f5..b69292827aa 100644
--- a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc
+++ b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc
@@ -89,5 +89,5 @@ class DropoutOpConverter : public OpConverter {
 }  // namespace inference
 }  // namespace paddle
 
-USE_OP(dropout);
+USE_OP_ITSELF(dropout);
 REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter);
diff --git a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc
b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc index 474fd92071f..cf377396087 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc @@ -57,4 +57,4 @@ TEST(DropoutOpConverter, main) { } // namespace inference } // namespace paddle -USE_OP(dropout); +USE_OP_ITSELF(dropout); diff --git a/paddle/fluid/operators/assign_op_npu_test.cc b/paddle/fluid/operators/assign_op_npu_test.cc index 72488a932d9..b452dea8536 100644 --- a/paddle/fluid/operators/assign_op_npu_test.cc +++ b/paddle/fluid/operators/assign_op_npu_test.cc @@ -23,7 +23,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc b/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc index c0968581acd..7206dd01bca 100644 --- a/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc b/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc index 31b00a93f13..0946ad8aca6 100644 --- a/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc index 9c11704704e..61e5f279034 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc b/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc index 5787090e6a5..cf4d6a28744 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc b/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc index c79b2f92b69..c4e410d04da 100644 --- a/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc b/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc index d9a7a4abb08..8b498787c69 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc index b8abf458c1c..133085ad3f3 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/checknumeric_npu_test.cc b/paddle/fluid/operators/collective/checknumeric_npu_test.cc index bb78971734b..36c6f4fadd0 100644 --- a/paddle/fluid/operators/collective/checknumeric_npu_test.cc +++ b/paddle/fluid/operators/collective/checknumeric_npu_test.cc @@ -27,7 +27,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc b/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc index 8f7b8c4a904..6e02d362156 100644 --- a/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc +++ b/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/send_v2_op_npu_test.cc b/paddle/fluid/operators/collective/send_v2_op_npu_test.cc index c40b2c3e76a..57e3dd53cc7 100644 --- a/paddle/fluid/operators/collective/send_v2_op_npu_test.cc +++ b/paddle/fluid/operators/collective/send_v2_op_npu_test.cc @@ -25,7 +25,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index 17665ad67e4..144198367d5 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -32,10 +32,9 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/dropout_impl_util.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/functors.h" namespace paddle { @@ -177,12 +176,13 @@ __global__ void DropoutGradCUDAKernel( } template -void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, - bool is_test, +void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, const std::string dropout_implementation, float dropout_prob, bool upscale_in_train, - bool is_fix_seed, int seed_val, const Tensor& x, - const Tensor* seed, Tensor* mask, Tensor* y) { + bool is_fix_seed, int seed_val, + const framework::Tensor& x, + const framework::Tensor* seed, + framework::Tensor* mask, framework::Tensor* y) { auto& place = *dev_ctx.eigen_device(); int64_t x_numel = x.numel(); auto stream = dev_ctx.stream(); @@ -220,7 +220,8 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, // VectorizedRandomGenerator use curand_uniform4, so we only support // vec_size is 4; int vec_size = (phi::GetVectorizedSize(x_data) == 4) ? 
4 : 1; - auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size); + auto gpu_config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size); auto offset = ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size; @@ -278,11 +279,13 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, } template -void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, +void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, const std::string dropout_implementation, - float dropout_prob, const Tensor& grad_y, - const Tensor& mask, int64_t size, - Tensor* grad_x, bool is_test = false) { + float dropout_prob, + const framework::Tensor& grad_y, + const framework::Tensor& mask, int64_t size, + framework::Tensor* grad_x, + bool is_test = false) { using MT = typename details::MPTypeTrait::Type; auto stream = dev_ctx.stream(); MT factor; diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h index d7db7dddce3..c62d45570ba 100644 --- a/paddle/fluid/operators/dropout_impl_util.h +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, +inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx, const framework::Tensor* seed, const bool is_fix_seed, const int seed_val, const int offset, uint64_t* seed_data, diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 7613b04bccf..6d52ce45c4c 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/dropout_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -177,14 +177,3 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, ops::DropoutGradOpMaker, ops::DropoutGradOpMaker); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); -REGISTER_OP_CPU_KERNEL( - dropout, ops::CPUDropoutKernel, - ops::CPUDropoutKernel, - ops::CPUDropoutKernel); -REGISTER_OP_CPU_KERNEL( - dropout_grad, - ops::DropoutGradKernel, - ops::DropoutGradKernel, - ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu deleted file mode 100644 index f6ddff1d032..00000000000 --- a/paddle/fluid/operators/dropout_op.cu +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/dropout_impl.cu.h" -#include "paddle/fluid/operators/dropout_op.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template -class GPUDropoutKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* seed = - context.HasInput("Seed") ? context.Input("Seed") : nullptr; - auto* y = context.Output("Out"); - y->mutable_data(context.GetPlace()); - float dropout_prob = context.Attr("dropout_prob"); - - auto& dropout_implementation = - context.Attr("dropout_implementation"); - bool upscale_in_train = (dropout_implementation == "upscale_in_train"); - - bool is_test = context.Attr("is_test"); - - auto& dev_ctx = context.cuda_device_context(); - auto* mask = context.Output("Mask"); - mask->mutable_data(context.GetPlace()); - - bool is_fix_seed = context.Attr("fix_seed"); - int seed_val = context.Attr("seed"); - DropoutFwGPUKernelDriver(dev_ctx, is_test, dropout_implementation, - dropout_prob, upscale_in_train, is_fix_seed, - seed_val, *x, seed, mask, y); - } -}; - -template -class GPUDropoutGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* grad_x = context.Output(framework::GradVarName("X")); - auto* grad_y = context.Input(framework::GradVarName("Out")); - auto* mask = context.Input("Mask"); - grad_x->mutable_data(context.GetPlace()); - auto size = grad_x->numel(); - auto& dropout_implementation = - context.Attr("dropout_implementation"); - float dropout_prob = context.Attr("dropout_prob"); - - bool is_test = context.Attr("is_test"); - - auto& dev_ctx = - context.template device_context(); - DropoutGradGPUKernelDriver(dev_ctx, dropout_implementation, dropout_prob, - *grad_y, *mask, size, grad_x, is_test); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel, - ops::GPUDropoutKernel, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL( - dropout_grad, ops::GPUDropoutGradKernel, - ops::GPUDropoutGradKernel, - ops::GPUDropoutGradKernel, - ops::GPUDropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h deleted file mode 100644 index ea6ed0e6194..00000000000 --- a/paddle/fluid/operators/dropout_op.h +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
*/ -#pragma once - -#include -#include -#include - -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; - -template -using EigenVector = framework::EigenVector; - -template -class CPUDropoutKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* seed = - context.HasInput("Seed") ? context.Input("Seed") : nullptr; - auto* y = context.Output("Out"); - const auto* x_data = x->data(); - auto* y_data = y->mutable_data(context.GetPlace()); - float dropout_prob = context.Attr("dropout_prob"); - - auto& dropout_implementation = - context.Attr("dropout_implementation"); - bool upscale_in_train = (dropout_implementation == "upscale_in_train"); - if (!context.Attr("is_test")) { - auto* mask = context.Output("Mask"); - auto* mask_data = mask->mutable_data(context.GetPlace()); - size_t size = phi::product(mask->dims()); - - // Special case when dropout_prob is 1.0 - if (dropout_prob == 1.0f) { - std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT - std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT - return; - } - // std::minstd_rand engine; - // NOTE: fixed seed should only be used in unittest or for debug. - // Guarantee to use random seed in training. - int seed_data = 0; - if (seed) { - seed_data = *(seed->data()); - } else { - seed_data = - context.Attr("fix_seed") ? context.Attr("seed") : 0; - } - auto engine = framework::GetCPURandomEngine(seed_data); - - std::uniform_real_distribution dist(0, 1); - - for (size_t i = 0; i < size; ++i) { - if (dist(*engine) < dropout_prob) { - mask_data[i] = 0; - y_data[i] = 0; - } else { - mask_data[i] = 1; - if (upscale_in_train) { - y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); - } else { - y_data[i] = x_data[i]; - } - } - } - } else { - if (upscale_in_train) { - const auto* X_data = x->data(); - auto* Y_data = y->mutable_data(context.GetPlace()); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < x->numel(); i++) { - Y_data[i] = X_data[i]; - } - } else { - auto X = EigenMatrix::Reshape(*x, 1); - auto Y = EigenMatrix::Reshape(*y, 1); - auto& place = - *context.template device_context().eigen_device(); - Y.device(place) = X * static_cast(1.0f - dropout_prob); - } - } - } -}; -template -class DropoutGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* grad_x = context.Output(framework::GradVarName("X")); - auto* grad_y = context.Input(framework::GradVarName("Out")); - auto* mask = context.Input("Mask"); - grad_x->mutable_data(context.GetPlace()); - - auto dX = EigenVector::Flatten(*grad_x); - auto dY = EigenVector::Flatten(*grad_y); - - auto& place = - *context.template device_context().eigen_device(); - auto& dropout_implementation = - context.Attr("dropout_implementation"); - if (context.Attr("is_test") == true) { - if (dropout_implementation == "upscale_in_train") { - dX.device(place) = static_cast(1) * dY; - } else { - float dropout_prob = context.Attr("dropout_prob"); - dX.device(place) = dY * static_cast(1.0f - dropout_prob); - } - } else { - auto M = EigenVector::Flatten(*mask); - if (dropout_implementation == "upscale_in_train") { - float dropout_prob = 
context.Attr("dropout_prob"); - if (dropout_prob == 1.0f) { - dX.device(place) = static_cast(0) * dY; - } else { - dX.device(place) = - dY * M.cast() / static_cast(1.0f - dropout_prob); - } - } else { - dX.device(place) = dY * M.cast(); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/dropout_op_npu.cc b/paddle/fluid/operators/dropout_op_npu.cc index 6aae5667606..07b3b538116 100644 --- a/paddle/fluid/operators/dropout_op_npu.cc +++ b/paddle/fluid/operators/dropout_op_npu.cc @@ -15,8 +15,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/phi/core/ddim.h" diff --git a/paddle/fluid/operators/dropout_op_test.cc b/paddle/fluid/operators/dropout_op_test.cc index 206d9a6c5e9..bdf08646f1d 100644 --- a/paddle/fluid/operators/dropout_op_test.cc +++ b/paddle/fluid/operators/dropout_op_test.cc @@ -24,14 +24,13 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace f = paddle::framework; namespace p = paddle::platform; -USE_OP(dropout); +USE_OP_ITSELF(dropout); void Compare(f::Scope* scope, const p::DeviceContext& ctx) { // init diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc index 07b7e2cc7c0..7d8660f238a 100644 --- a/paddle/fluid/operators/dropout_op_xpu.cc +++ b/paddle/fluid/operators/dropout_op_xpu.cc @@ -8,15 +8,17 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/dropout_op.h" + #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { namespace operators { #ifdef PADDLE_WITH_XPU +using Tensor = framework::Tensor; template class DropoutXPUKernel : public framework::OpKernel { using XPUTyp = typename XPUTypeTrait::Type; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc index fc128a88f20..3e9263fe93a 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc +++ b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/expand_op_npu_test.cc b/paddle/fluid/operators/expand_op_npu_test.cc index cdd4e1dbaae..df00ae54c10 100644 --- a/paddle/fluid/operators/expand_op_npu_test.cc +++ b/paddle/fluid/operators/expand_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 02027767579..3c9e16785ea 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -140,9 +140,9 @@ class FMHARef { if (dropout_param_.dropout_prob_) { DropoutFwGPUKernelDriver( - dev_ctx_, dropout_param_.is_test_, - static_cast( - dropout_param_.dropout_implementation_), + static_cast(dev_ctx_), + dropout_param_.is_test_, static_cast( + dropout_param_.dropout_implementation_), dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, dropout_param_.is_fix_seed_, dropout_param_.seed_val_, static_cast(*softmax_out_tensor), dropout_param_.seed_, @@ -242,8 +242,9 @@ class FMHARef { // dropout bw if (dropout_param_.dropout_prob_) { DropoutGradGPUKernelDriver( - dev_ctx_, static_cast( - dropout_param_.dropout_implementation_), + static_cast(dev_ctx_), + static_cast( + dropout_param_.dropout_implementation_), dropout_param_.dropout_prob_, static_cast(*dropout_out_grad_tensor), dropout_mask_out_tensor, softmax_out_grad_tensor->numel(), diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index d7952df470d..18c7187fc8e 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -31,7 +31,7 @@ namespace framework = paddle::framework; namespace platform = paddle::platform; namespace memory = paddle::memory; -USE_OP(dropout); +USE_OP_ITSELF(dropout); USE_OP(layer_norm); template diff --git a/paddle/fluid/operators/gelu_op_npu_test.cc b/paddle/fluid/operators/gelu_op_npu_test.cc index 00ff7ad2166..f3ac5313832 100644 --- a/paddle/fluid/operators/gelu_op_npu_test.cc +++ b/paddle/fluid/operators/gelu_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/increment_op_npu_test.cc b/paddle/fluid/operators/increment_op_npu_test.cc index 09f4e63943a..8324a6215bc 100644 --- a/paddle/fluid/operators/increment_op_npu_test.cc +++ b/paddle/fluid/operators/increment_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/range_op_npu_test.cc b/paddle/fluid/operators/range_op_npu_test.cc index 24741efe426..c7e91ba35de 100644 --- a/paddle/fluid/operators/range_op_npu_test.cc +++ b/paddle/fluid/operators/range_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index b636184ae45..a473b54c1f8 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -16,9 +16,9 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/unique_op.h" @@ -36,6 +36,14 @@ using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; using TensorList = std::vector; +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; + #define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \ inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \ const std::string& mode = ctx.Attr("mode"); \ diff --git a/paddle/fluid/operators/softmax_op_npu_test.cc b/paddle/fluid/operators/softmax_op_npu_test.cc index 3bc55fafd81..3148b31a832 100644 --- a/paddle/fluid/operators/softmax_op_npu_test.cc +++ b/paddle/fluid/operators/softmax_op_npu_test.cc @@ -22,7 +22,6 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/squeeze_op_npu_test.cc b/paddle/fluid/operators/squeeze_op_npu_test.cc index 956544c5360..d61f5aa3f63 100644 --- a/paddle/fluid/operators/squeeze_op_npu_test.cc +++ b/paddle/fluid/operators/squeeze_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/transpose_op_npu_test.cc b/paddle/fluid/operators/transpose_op_npu_test.cc index 5617d728a51..fb39034c8e9 100644 --- a/paddle/fluid/operators/transpose_op_npu_test.cc +++ b/paddle/fluid/operators/transpose_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/unsqueeze_op_npu_test.cc b/paddle/fluid/operators/unsqueeze_op_npu_test.cc index 3e11c952d15..a8ced783744 100644 --- a/paddle/fluid/operators/unsqueeze_op_npu_test.cc +++ b/paddle/fluid/operators/unsqueeze_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/cpu/dropout_grad_kernel.cc b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc new file mode 100644 index 00000000000..b77a6c55b14 --- /dev/null +++ b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/dropout_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void DropoutGradRawKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + DenseTensor* x_grad) { + auto* grad_x = x_grad; + auto* grad_y = &out_grad; + grad_x->mutable_data(dev_ctx.GetPlace()); + + auto dX = EigenVector::Flatten(*grad_x); + auto dY = EigenVector::Flatten(*grad_y); + + auto& place = *dev_ctx.eigen_device(); + auto& dropout_implementation = mode; + if (is_test == true) { + if (dropout_implementation == "upscale_in_train") { + dX.device(place) = static_cast(1) * dY; + } else { + dX.device(place) = dY * static_cast(1.0f - p); + } + } else { + auto M = EigenVector::Flatten(mask); + if (dropout_implementation == "upscale_in_train") { + if (p == 1.0f) { + dX.device(place) = static_cast(0) * dY; + } else { + dX.device(place) = dY * M.cast() / static_cast(1.0f - p); + } + } else { + dX.device(place) = dY * M.cast(); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout_grad, + CPU, + ALL_LAYOUT, + phi::DropoutGradRawKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/dropout_kernel.cc b/paddle/phi/kernels/cpu/dropout_kernel.cc new file mode 100644 index 00000000000..c00aedef8c6 --- /dev/null +++ b/paddle/phi/kernels/cpu/dropout_kernel.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/dropout_kernel.h" +#include "paddle/fluid/framework/generator.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void DropoutRawKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* mask) { + auto* y = out; + const auto* x_data = x.data(); + auto* y_data = y->mutable_data(dev_ctx.GetPlace()); + float dropout_prob = p; + + auto& dropout_implementation = mode; + bool upscale_in_train = (dropout_implementation == "upscale_in_train"); + if (!is_test) { + auto* mask_data = mask->mutable_data(dev_ctx.GetPlace()); + size_t size = phi::product(mask->dims()); + + // Special case when dropout_prob is 1.0 + if (dropout_prob == 1.0f) { + std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT + std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT + return; + } + // std::minstd_rand engine; + // NOTE: fixed seed should only be used in unittest or for debug. + // Guarantee to use random seed in training. + int seed_data = 0; + if (seed_tensor.get_ptr() != nullptr) { + seed_data = *(seed_tensor->data()); + } else { + seed_data = fix_seed ? seed : 0; + } + auto engine = paddle::framework::GetCPURandomEngine(seed_data); + + std::uniform_real_distribution dist(0, 1); + + for (size_t i = 0; i < size; ++i) { + if (dist(*engine) < dropout_prob) { + mask_data[i] = 0; + y_data[i] = 0; + } else { + mask_data[i] = 1; + if (upscale_in_train) { + y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); + } else { + y_data[i] = x_data[i]; + } + } + } + } else { + if (upscale_in_train) { + const auto* X_data = x.data(); + auto* Y_data = y->mutable_data(dev_ctx.GetPlace()); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < x.numel(); i++) { + Y_data[i] = X_data[i]; + } + } else { + auto X = EigenMatrix::Reshape(x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + auto& place = *dev_ctx.eigen_device(); + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout, + CPU, + ALL_LAYOUT, + phi::DropoutRawKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/dropout_grad_kernel.h b/paddle/phi/kernels/dropout_grad_kernel.h new file mode 100644 index 00000000000..ae3f8205663 --- /dev/null +++ b/paddle/phi/kernels/dropout_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DropoutGradRawKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/dropout_kernel.h b/paddle/phi/kernels/dropout_kernel.h new file mode 100644 index 00000000000..dc9f89e08e1 --- /dev/null +++ b/paddle/phi/kernels/dropout_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DropoutRawKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* mask); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/dropout_grad_kernel.cu b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu new file mode 100644 index 00000000000..94d4942a418 --- /dev/null +++ b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/dropout_impl.cu.h" +#include "paddle/phi/kernels/dropout_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void DropoutGradRawKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + DenseTensor* x_grad) { + x_grad->mutable_data(dev_ctx.GetPlace()); + auto size = x_grad->numel(); + paddle::operators::DropoutGradGPUKernelDriver( + dev_ctx, mode, p, out_grad, mask, size, x_grad, is_test); +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout_grad, + GPU, + ALL_LAYOUT, + phi::DropoutGradRawKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/dropout_kernel.cu b/paddle/phi/kernels/gpu/dropout_kernel.cu new file mode 100644 index 00000000000..bd1683ad0c7 --- /dev/null +++ b/paddle/phi/kernels/gpu/dropout_kernel.cu @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/dropout_impl.cu.h" +#include "paddle/phi/kernels/dropout_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void DropoutRawKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* mask) { + out->mutable_data(dev_ctx.GetPlace()); + float dropout_prob = p; + bool upscale_in_train = (mode == "upscale_in_train"); + mask->mutable_data(dev_ctx.GetPlace()); + + paddle::operators::DropoutFwGPUKernelDriver(dev_ctx, + is_test, + mode, + dropout_prob, + upscale_in_train, + fix_seed, + seed, + x, + seed_tensor.get_ptr(), + mask, + out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout, + GPU, + ALL_LAYOUT, + phi::DropoutRawKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/compat/dropout_sig.cc b/paddle/phi/ops/compat/dropout_sig.cc new file mode 100644 index 00000000000..6bf229c98bd --- /dev/null +++ b/paddle/phi/ops/compat/dropout_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "dropout", + {"X", "Seed"}, + {"dropout_prob", "is_test", "dropout_implementation", "seed", "fix_seed"}, + {"Out", "Mask"}); +} + +KernelSignature DropoutGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("dropout_grad", + {"Mask", GradVarName("Out")}, + {"dropout_prob", "is_test", "dropout_implementation"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(dropout, phi::DropoutOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(dropout_grad, phi::DropoutGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index f670f7c3809..fd2f642b770 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -933,5 +933,65 @@ class TestDropoutWithDeterminateSeedGenerator(unittest.TestCase): self.check_static_result(place=place) +class TestDropoutBackward(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def cal_grad_upscale_train(self, mask, prob): + return mask.astype("float32") / (1 - prob) + + def cal_grad_downscale_in_infer(self, mask): + return mask.astype("float32") + + def test_backward_downscale_in_infer(self): + for place in self.places: + with fluid.dygraph.guard(place): + + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', 0.5) + out.backward() + + self.assertTrue( + np.array_equal(input.gradient( + ), self.cal_grad_downscale_in_infer(mask.numpy()))) + + def test_backward_upscale_train(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.5 + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + def test_backward_upscale_train_2(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.3 + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + if __name__ == '__main__': + paddle.enable_static() unittest.main() -- GitLab