diff --git a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc index 8c61200f7f57cdf57b372c37c8f7cea40c4a8d4c..b69292827aa136fd1d8a1f66d80823e6344a6174 100644 --- a/paddle/fluid/inference/tensorrt/convert/dropout_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/dropout_op.cc @@ -89,5 +89,5 @@ class DropoutOpConverter : public OpConverter { } // namespace inference } // namespace paddle -USE_OP(dropout); +USE_OP_ITSELF(dropout); REGISTER_TRT_OP_CONVERTER(dropout, DropoutOpConverter); diff --git a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc index 474fd92071fb0795b868f0cd86591061cf8b6581..cf377396087637f115523ddc60a468e2a23d57d4 100644 --- a/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc +++ b/paddle/fluid/inference/tensorrt/convert/test_dropout_op.cc @@ -57,4 +57,4 @@ TEST(DropoutOpConverter, main) { } // namespace inference } // namespace paddle -USE_OP(dropout); +USE_OP_ITSELF(dropout); diff --git a/paddle/fluid/operators/assign_op_npu_test.cc b/paddle/fluid/operators/assign_op_npu_test.cc index 72488a932d9c33cbfeddc9f35818e42ebe0137fa..b452dea8536dd98d6d4060d5224e39daf9137c50 100644 --- a/paddle/fluid/operators/assign_op_npu_test.cc +++ b/paddle/fluid/operators/assign_op_npu_test.cc @@ -23,7 +23,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc b/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc index c0968581acda9950aaa8ee2b8f3af15e1db59a67..7206dd01bcaa3e588cc275c2fdf25e70aacc1663 100644 --- a/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allgather_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc b/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc index 31b00a93f1396564907a7872e919ba6c96f666d8..0946ad8aca65e28835ea1d139fb94c309ce840a1 100644 --- a/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allreduce_max_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc b/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc index 9c11704704ed420b14a6ccd9873e0bfbe143b4fe..61e5f27903477972ef10465ccfd6f8de8ce8fba6 100644 --- a/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_allreduce_sum_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc b/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc index 5787090e6a52f2f37bd504a904108cd1d24caf5f..cf4d6a28744b368212fe8bcb0924001aa53b5a4e 100644 --- a/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_broadcast_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc b/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc index c79b2f92b69a1e6cc5c6f1cf17fa402c671a1997..c4e410d04da5fb5e9b6bfe4d7d5c263084889f54 100644 --- a/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_reduce_sum_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc b/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc index d9a7a4abb08fc883b9b9210fcdefd56af127263a..8b498787c69db0f978acaa68ba63883270e11eb4 100644 --- a/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_reducescatter_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc index b8abf458c1c6d395fef08238abaa114ff5dc6e9e..133085ad3f3b0ffd00dbf4d026687b0311116951 100644 --- a/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc +++ b/paddle/fluid/operators/collective/c_sync_comm_stream_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/checknumeric_npu_test.cc b/paddle/fluid/operators/collective/checknumeric_npu_test.cc index bb78971734bf05e94f7b0ebc1f1540b254f98067..36c6f4fadd0fcc9b06c61d5c45ce6829f2d3d977 100644 --- a/paddle/fluid/operators/collective/checknumeric_npu_test.cc +++ b/paddle/fluid/operators/collective/checknumeric_npu_test.cc @@ -27,7 +27,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc b/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc index 8f7b8c4a9040be3a2b4540c693c128e92c06a180..6e02d362156970cdee7257c7d00b70cef0519757 100644 --- a/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc +++ b/paddle/fluid/operators/collective/recv_v2_op_npu_test.cc @@ -26,7 +26,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/collective/send_v2_op_npu_test.cc b/paddle/fluid/operators/collective/send_v2_op_npu_test.cc index c40b2c3e76a02ce6e5e754b2dc4280d6917145e7..57e3dd53cc7748fa0fb66e7e934a1c9cd764a15f 100644 --- a/paddle/fluid/operators/collective/send_v2_op_npu_test.cc +++ b/paddle/fluid/operators/collective/send_v2_op_npu_test.cc @@ -25,7 +25,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/dropout_impl.cu.h b/paddle/fluid/operators/dropout_impl.cu.h index 17665ad67e40e8b73e63f37147c62f8566ab68f0..144198367d538e178a745c22902bb77a65f45fe4 100644 --- a/paddle/fluid/operators/dropout_impl.cu.h +++ b/paddle/fluid/operators/dropout_impl.cu.h @@ -32,10 +32,9 @@ limitations under the License. */ #include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/amp/fp16_type_traits.h" #include "paddle/fluid/operators/dropout_impl_util.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h" -#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/phi/kernels/funcs/aligned_vector.h" +#include "paddle/fluid/platform/aligned_vector.h" +#include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/kernels/funcs/functors.h" namespace paddle { @@ -177,12 +176,13 @@ __global__ void DropoutGradCUDAKernel( } template -void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, - bool is_test, +void DropoutFwGPUKernelDriver(const phi::GPUContext& dev_ctx, bool is_test, const std::string dropout_implementation, float dropout_prob, bool upscale_in_train, - bool is_fix_seed, int seed_val, const Tensor& x, - const Tensor* seed, Tensor* mask, Tensor* y) { + bool is_fix_seed, int seed_val, + const framework::Tensor& x, + const framework::Tensor* seed, + framework::Tensor* mask, framework::Tensor* y) { auto& place = *dev_ctx.eigen_device(); int64_t x_numel = x.numel(); auto stream = dev_ctx.stream(); @@ -220,7 +220,8 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, // VectorizedRandomGenerator use curand_uniform4, so we only support // vec_size is 4; int vec_size = (phi::GetVectorizedSize(x_data) == 4) ? 4 : 1; - auto gpu_config = GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size); + auto gpu_config = + phi::backends::gpu::GetGpuLaunchConfig1D(dev_ctx, x_numel, vec_size); auto offset = ((x_numel - 1) / (gpu_config.GetThreadNum() * vec_size) + 1) * vec_size; @@ -278,11 +279,13 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, } template -void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx, +void DropoutGradGPUKernelDriver(const phi::GPUContext& dev_ctx, const std::string dropout_implementation, - float dropout_prob, const Tensor& grad_y, - const Tensor& mask, int64_t size, - Tensor* grad_x, bool is_test = false) { + float dropout_prob, + const framework::Tensor& grad_y, + const framework::Tensor& mask, int64_t size, + framework::Tensor* grad_x, + bool is_test = false) { using MT = typename details::MPTypeTrait::Type; auto stream = dev_ctx.stream(); MT factor; diff --git a/paddle/fluid/operators/dropout_impl_util.h b/paddle/fluid/operators/dropout_impl_util.h index d7db7dddce3887ca25ea1df34048f15663b2e987..c62d45570ba291dc60120c393d21842cc6548c61 100644 --- a/paddle/fluid/operators/dropout_impl_util.h +++ b/paddle/fluid/operators/dropout_impl_util.h @@ -20,7 +20,7 @@ limitations under the License. */ namespace paddle { namespace operators { -inline void GetSeedDataAndIncrement(const platform::CUDADeviceContext& dev_ctx, +inline void GetSeedDataAndIncrement(const phi::GPUContext& dev_ctx, const framework::Tensor* seed, const bool is_fix_seed, const int seed_val, const int offset, uint64_t* seed_data, diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc index 7613b04bccfdc2084decc0b383eec199f7e10991..6d52ce45c4c10099dbeb4d4fadbf91f8c390ef46 100644 --- a/paddle/fluid/operators/dropout_op.cc +++ b/paddle/fluid/operators/dropout_op.cc @@ -12,9 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/dropout_op.h" #include #include +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -177,14 +177,3 @@ REGISTER_OPERATOR(dropout, ops::DropoutOp, ops::DropoutOpMaker, ops::DropoutGradOpMaker, ops::DropoutGradOpMaker); REGISTER_OPERATOR(dropout_grad, ops::DropoutOpGrad); -REGISTER_OP_CPU_KERNEL( - dropout, ops::CPUDropoutKernel, - ops::CPUDropoutKernel, - ops::CPUDropoutKernel); -REGISTER_OP_CPU_KERNEL( - dropout_grad, - ops::DropoutGradKernel, - ops::DropoutGradKernel, - ops::DropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu deleted file mode 100644 index f6ddff1d0327d3c7961781f875da69f89df1edec..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/dropout_op.cu +++ /dev/null @@ -1,94 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/dropout_impl.cu.h" -#include "paddle/fluid/operators/dropout_op.h" -#include "paddle/fluid/platform/bfloat16.h" -#include "paddle/fluid/platform/float16.h" - -namespace paddle { -namespace operators { - -// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. -// Use std::random and thrust::random(thrust is a std library in CUDA) to -// implement uniform random. -template -class GPUDropoutKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* seed = - context.HasInput("Seed") ? context.Input("Seed") : nullptr; - auto* y = context.Output("Out"); - y->mutable_data(context.GetPlace()); - float dropout_prob = context.Attr("dropout_prob"); - - auto& dropout_implementation = - context.Attr("dropout_implementation"); - bool upscale_in_train = (dropout_implementation == "upscale_in_train"); - - bool is_test = context.Attr("is_test"); - - auto& dev_ctx = context.cuda_device_context(); - auto* mask = context.Output("Mask"); - mask->mutable_data(context.GetPlace()); - - bool is_fix_seed = context.Attr("fix_seed"); - int seed_val = context.Attr("seed"); - DropoutFwGPUKernelDriver(dev_ctx, is_test, dropout_implementation, - dropout_prob, upscale_in_train, is_fix_seed, - seed_val, *x, seed, mask, y); - } -}; - -template -class GPUDropoutGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* grad_x = context.Output(framework::GradVarName("X")); - auto* grad_y = context.Input(framework::GradVarName("Out")); - auto* mask = context.Input("Mask"); - grad_x->mutable_data(context.GetPlace()); - auto size = grad_x->numel(); - auto& dropout_implementation = - context.Attr("dropout_implementation"); - float dropout_prob = context.Attr("dropout_prob"); - - bool is_test = context.Attr("is_test"); - - auto& dev_ctx = - context.template device_context(); - DropoutGradGPUKernelDriver(dev_ctx, dropout_implementation, dropout_prob, - *grad_y, *mask, size, grad_x, is_test); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -namespace plat = paddle::platform; -REGISTER_OP_CUDA_KERNEL( - dropout, ops::GPUDropoutKernel, - ops::GPUDropoutKernel, - ops::GPUDropoutKernel, - ops::GPUDropoutKernel); -REGISTER_OP_CUDA_KERNEL( - dropout_grad, ops::GPUDropoutGradKernel, - ops::GPUDropoutGradKernel, - ops::GPUDropoutGradKernel, - ops::GPUDropoutGradKernel); diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h deleted file mode 100644 index ea6ed0e61947470c22f18e47acce2fca4cb9c41f..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/dropout_op.h +++ /dev/null @@ -1,151 +0,0 @@ -/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ -#pragma once - -#include -#include -#include - -#include -#include "paddle/fluid/framework/eigen.h" -#include "paddle/fluid/framework/generator.h" -#include "paddle/fluid/framework/op_registry.h" - -namespace paddle { -namespace operators { - -using Tensor = framework::Tensor; -template -using EigenMatrix = framework::EigenMatrix; - -template -using EigenVector = framework::EigenVector; - -template -class CPUDropoutKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* x = context.Input("X"); - auto* seed = - context.HasInput("Seed") ? context.Input("Seed") : nullptr; - auto* y = context.Output("Out"); - const auto* x_data = x->data(); - auto* y_data = y->mutable_data(context.GetPlace()); - float dropout_prob = context.Attr("dropout_prob"); - - auto& dropout_implementation = - context.Attr("dropout_implementation"); - bool upscale_in_train = (dropout_implementation == "upscale_in_train"); - if (!context.Attr("is_test")) { - auto* mask = context.Output("Mask"); - auto* mask_data = mask->mutable_data(context.GetPlace()); - size_t size = phi::product(mask->dims()); - - // Special case when dropout_prob is 1.0 - if (dropout_prob == 1.0f) { - std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT - std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT - return; - } - // std::minstd_rand engine; - // NOTE: fixed seed should only be used in unittest or for debug. - // Guarantee to use random seed in training. - int seed_data = 0; - if (seed) { - seed_data = *(seed->data()); - } else { - seed_data = - context.Attr("fix_seed") ? context.Attr("seed") : 0; - } - auto engine = framework::GetCPURandomEngine(seed_data); - - std::uniform_real_distribution dist(0, 1); - - for (size_t i = 0; i < size; ++i) { - if (dist(*engine) < dropout_prob) { - mask_data[i] = 0; - y_data[i] = 0; - } else { - mask_data[i] = 1; - if (upscale_in_train) { - y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); - } else { - y_data[i] = x_data[i]; - } - } - } - } else { - if (upscale_in_train) { - const auto* X_data = x->data(); - auto* Y_data = y->mutable_data(context.GetPlace()); -#ifdef PADDLE_WITH_MKLML -#pragma omp parallel for -#endif - for (int i = 0; i < x->numel(); i++) { - Y_data[i] = X_data[i]; - } - } else { - auto X = EigenMatrix::Reshape(*x, 1); - auto Y = EigenMatrix::Reshape(*y, 1); - auto& place = - *context.template device_context().eigen_device(); - Y.device(place) = X * static_cast(1.0f - dropout_prob); - } - } - } -}; -template -class DropoutGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& context) const override { - auto* grad_x = context.Output(framework::GradVarName("X")); - auto* grad_y = context.Input(framework::GradVarName("Out")); - auto* mask = context.Input("Mask"); - grad_x->mutable_data(context.GetPlace()); - - auto dX = EigenVector::Flatten(*grad_x); - auto dY = EigenVector::Flatten(*grad_y); - - auto& place = - *context.template device_context().eigen_device(); - auto& dropout_implementation = - context.Attr("dropout_implementation"); - if (context.Attr("is_test") == true) { - if (dropout_implementation == "upscale_in_train") { - dX.device(place) = static_cast(1) * dY; - } else { - float dropout_prob = context.Attr("dropout_prob"); - dX.device(place) = dY * static_cast(1.0f - dropout_prob); - } - } else { - auto M = EigenVector::Flatten(*mask); - if (dropout_implementation == "upscale_in_train") { - float dropout_prob = context.Attr("dropout_prob"); - if (dropout_prob == 1.0f) { - dX.device(place) = static_cast(0) * dY; - } else { - dX.device(place) = - dY * M.cast() / static_cast(1.0f - dropout_prob); - } - } else { - dX.device(place) = dY * M.cast(); - } - } - } -}; - -} // namespace operators -} // namespace paddle diff --git a/paddle/fluid/operators/dropout_op_npu.cc b/paddle/fluid/operators/dropout_op_npu.cc index 6aae566760623c666f3ce82a890a119e3e173390..07b3b5381162575cbfc03dd8cc10d0c88a2d21e8 100644 --- a/paddle/fluid/operators/dropout_op_npu.cc +++ b/paddle/fluid/operators/dropout_op_npu.cc @@ -15,8 +15,8 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/device/npu/npu_op_runner.h" #include "paddle/phi/core/ddim.h" diff --git a/paddle/fluid/operators/dropout_op_test.cc b/paddle/fluid/operators/dropout_op_test.cc index 206d9a6c5e9c9869216f0a6c137accc931aa2a77..bdf08646f1d8b94d6d8d141d8a9fa9864cdc937b 100644 --- a/paddle/fluid/operators/dropout_op_test.cc +++ b/paddle/fluid/operators/dropout_op_test.cc @@ -24,14 +24,13 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace f = paddle::framework; namespace p = paddle::platform; -USE_OP(dropout); +USE_OP_ITSELF(dropout); void Compare(f::Scope* scope, const p::DeviceContext& ctx) { // init diff --git a/paddle/fluid/operators/dropout_op_xpu.cc b/paddle/fluid/operators/dropout_op_xpu.cc index 07b7e2cc7c09b09d6640f49fce438d58d0cc9cf2..7d8660f238abc8446b2988aad24a64c565e01ef9 100644 --- a/paddle/fluid/operators/dropout_op_xpu.cc +++ b/paddle/fluid/operators/dropout_op_xpu.cc @@ -8,15 +8,17 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/dropout_op.h" + #include #include +#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/device/device_wrapper.h" namespace paddle { namespace operators { #ifdef PADDLE_WITH_XPU +using Tensor = framework::Tensor; template class DropoutXPUKernel : public framework::OpKernel { using XPUTyp = typename XPUTypeTrait::Type; diff --git a/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc index fc128a88f2096a26141ff7922b1d9166b8302ded..3e9263fe93acd93638ff9e496203b7ea432cea86 100644 --- a/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc +++ b/paddle/fluid/operators/elementwise/elementwise_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/expand_op_npu_test.cc b/paddle/fluid/operators/expand_op_npu_test.cc index cdd4e1dbaae6a6a74bb11be44589877234021764..df00ae54c1036b1b0f0899eb0a949d58c398aa48 100644 --- a/paddle/fluid/operators/expand_op_npu_test.cc +++ b/paddle/fluid/operators/expand_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/fused/fmha_ref.h b/paddle/fluid/operators/fused/fmha_ref.h index 020277675797358bf87a58ac108e6eaaddb26ccc..3c9e16785eac814ffba34455d635d798042cdf43 100644 --- a/paddle/fluid/operators/fused/fmha_ref.h +++ b/paddle/fluid/operators/fused/fmha_ref.h @@ -140,9 +140,9 @@ class FMHARef { if (dropout_param_.dropout_prob_) { DropoutFwGPUKernelDriver( - dev_ctx_, dropout_param_.is_test_, - static_cast( - dropout_param_.dropout_implementation_), + static_cast(dev_ctx_), + dropout_param_.is_test_, static_cast( + dropout_param_.dropout_implementation_), dropout_param_.dropout_prob_, dropout_param_.is_upscale_in_train_, dropout_param_.is_fix_seed_, dropout_param_.seed_val_, static_cast(*softmax_out_tensor), dropout_param_.seed_, @@ -242,8 +242,9 @@ class FMHARef { // dropout bw if (dropout_param_.dropout_prob_) { DropoutGradGPUKernelDriver( - dev_ctx_, static_cast( - dropout_param_.dropout_implementation_), + static_cast(dev_ctx_), + static_cast( + dropout_param_.dropout_implementation_), dropout_param_.dropout_prob_, static_cast(*dropout_out_grad_tensor), dropout_mask_out_tensor, softmax_out_grad_tensor->numel(), diff --git a/paddle/fluid/operators/fused/fused_dropout_test.h b/paddle/fluid/operators/fused/fused_dropout_test.h index d7952df470d81566c3833e79e8cfa31a7d2dc68c..18c7187fc8e64c9fed8a86a984954b5420c1e5b5 100644 --- a/paddle/fluid/operators/fused/fused_dropout_test.h +++ b/paddle/fluid/operators/fused/fused_dropout_test.h @@ -31,7 +31,7 @@ namespace framework = paddle::framework; namespace platform = paddle::platform; namespace memory = paddle::memory; -USE_OP(dropout); +USE_OP_ITSELF(dropout); USE_OP(layer_norm); template diff --git a/paddle/fluid/operators/gelu_op_npu_test.cc b/paddle/fluid/operators/gelu_op_npu_test.cc index 00ff7ad2166dcf99d7b60ec45adfe70b478dedf8..f3ac53138328dbfad12c6d530a6517f40c658677 100644 --- a/paddle/fluid/operators/gelu_op_npu_test.cc +++ b/paddle/fluid/operators/gelu_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/increment_op_npu_test.cc b/paddle/fluid/operators/increment_op_npu_test.cc index 09f4e63943ad3784a598524273831bf875ed9213..8324a6215bca8145ba36dabb3d8108006a57e829 100644 --- a/paddle/fluid/operators/increment_op_npu_test.cc +++ b/paddle/fluid/operators/increment_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/range_op_npu_test.cc b/paddle/fluid/operators/range_op_npu_test.cc index 24741efe426b18b7cecae9332c522d67aee98d63..c7e91ba35dee1356ddd71ade0fe9892f8032c77b 100644 --- a/paddle/fluid/operators/range_op_npu_test.cc +++ b/paddle/fluid/operators/range_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/rnn_op.h b/paddle/fluid/operators/rnn_op.h index b636184ae457edf5c8028fecfb92a3ea96f5a0d9..a473b54c1f855945a5f3f0ac8d0826b15494ba1a 100644 --- a/paddle/fluid/operators/rnn_op.h +++ b/paddle/fluid/operators/rnn_op.h @@ -16,9 +16,9 @@ limitations under the License. */ #include #include +#include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/operators/activation_op.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/operators/math/concat_and_split.h" #include "paddle/fluid/operators/math/fc.h" #include "paddle/fluid/operators/unique_op.h" @@ -36,6 +36,14 @@ using LoDTensor = framework::LoDTensor; using Tensor = framework::Tensor; using TensorList = std::vector; +template +using EigenMatrix = framework::EigenMatrix; + +template +using EigenVector = framework::EigenVector; + #define DEFINE_MODE_DETECTOR(MODE_NAME, MODE_STR) \ inline bool is_##MODE_NAME(const framework::ExecutionContext& ctx) { \ const std::string& mode = ctx.Attr("mode"); \ diff --git a/paddle/fluid/operators/softmax_op_npu_test.cc b/paddle/fluid/operators/softmax_op_npu_test.cc index 3bc55fafd81e18d0a986268ff4692129c6515edc..3148b31a8322e2bab39ad7f723ee59a6db64c204 100644 --- a/paddle/fluid/operators/softmax_op_npu_test.cc +++ b/paddle/fluid/operators/softmax_op_npu_test.cc @@ -22,7 +22,6 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/tensor_util.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/squeeze_op_npu_test.cc b/paddle/fluid/operators/squeeze_op_npu_test.cc index 956544c53609eb29326dc5cf295d978d767ac176..d61f5aa3f634cd2aee1e5c2f34f4467b1697e455 100644 --- a/paddle/fluid/operators/squeeze_op_npu_test.cc +++ b/paddle/fluid/operators/squeeze_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/transpose_op_npu_test.cc b/paddle/fluid/operators/transpose_op_npu_test.cc index 5617d728a51dc1c5e21053a2af05d062ecc1a22b..fb39034c8e92c1ac39aa1ca6e57d5a08ca1ca9d6 100644 --- a/paddle/fluid/operators/transpose_op_npu_test.cc +++ b/paddle/fluid/operators/transpose_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/fluid/operators/unsqueeze_op_npu_test.cc b/paddle/fluid/operators/unsqueeze_op_npu_test.cc index 3e11c952d15f3460f987f6fa2cb28970f97cc96b..a8ced783744a961eb8ce64983de7e9615763c1b6 100644 --- a/paddle/fluid/operators/unsqueeze_op_npu_test.cc +++ b/paddle/fluid/operators/unsqueeze_op_npu_test.cc @@ -24,7 +24,6 @@ limitations under the License. */ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/program_desc.h" -#include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/string/printf.h" #include "paddle/phi/kernels/funcs/math_function.h" diff --git a/paddle/phi/kernels/cpu/dropout_grad_kernel.cc b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..b77a6c55b14716e2747a2cb76d4b1bda380a2d02 --- /dev/null +++ b/paddle/phi/kernels/cpu/dropout_grad_kernel.cc @@ -0,0 +1,67 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/dropout_grad_kernel.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void DropoutGradRawKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + DenseTensor* x_grad) { + auto* grad_x = x_grad; + auto* grad_y = &out_grad; + grad_x->mutable_data(dev_ctx.GetPlace()); + + auto dX = EigenVector::Flatten(*grad_x); + auto dY = EigenVector::Flatten(*grad_y); + + auto& place = *dev_ctx.eigen_device(); + auto& dropout_implementation = mode; + if (is_test == true) { + if (dropout_implementation == "upscale_in_train") { + dX.device(place) = static_cast(1) * dY; + } else { + dX.device(place) = dY * static_cast(1.0f - p); + } + } else { + auto M = EigenVector::Flatten(mask); + if (dropout_implementation == "upscale_in_train") { + if (p == 1.0f) { + dX.device(place) = static_cast(0) * dY; + } else { + dX.device(place) = dY * M.cast() / static_cast(1.0f - p); + } + } else { + dX.device(place) = dY * M.cast(); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout_grad, + CPU, + ALL_LAYOUT, + phi::DropoutGradRawKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/cpu/dropout_kernel.cc b/paddle/phi/kernels/cpu/dropout_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..c00aedef8c67d5b88efb29c746c70f5f7859507a --- /dev/null +++ b/paddle/phi/kernels/cpu/dropout_kernel.cc @@ -0,0 +1,104 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/dropout_kernel.h" +#include "paddle/fluid/framework/generator.h" +#include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/eigen/common.h" + +namespace phi { + +template +void DropoutRawKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* mask) { + auto* y = out; + const auto* x_data = x.data(); + auto* y_data = y->mutable_data(dev_ctx.GetPlace()); + float dropout_prob = p; + + auto& dropout_implementation = mode; + bool upscale_in_train = (dropout_implementation == "upscale_in_train"); + if (!is_test) { + auto* mask_data = mask->mutable_data(dev_ctx.GetPlace()); + size_t size = phi::product(mask->dims()); + + // Special case when dropout_prob is 1.0 + if (dropout_prob == 1.0f) { + std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT + std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT + return; + } + // std::minstd_rand engine; + // NOTE: fixed seed should only be used in unittest or for debug. + // Guarantee to use random seed in training. + int seed_data = 0; + if (seed_tensor.get_ptr() != nullptr) { + seed_data = *(seed_tensor->data()); + } else { + seed_data = fix_seed ? seed : 0; + } + auto engine = paddle::framework::GetCPURandomEngine(seed_data); + + std::uniform_real_distribution dist(0, 1); + + for (size_t i = 0; i < size; ++i) { + if (dist(*engine) < dropout_prob) { + mask_data[i] = 0; + y_data[i] = 0; + } else { + mask_data[i] = 1; + if (upscale_in_train) { + y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); + } else { + y_data[i] = x_data[i]; + } + } + } + } else { + if (upscale_in_train) { + const auto* X_data = x.data(); + auto* Y_data = y->mutable_data(dev_ctx.GetPlace()); +#ifdef PADDLE_WITH_MKLML +#pragma omp parallel for +#endif + for (int i = 0; i < x.numel(); i++) { + Y_data[i] = X_data[i]; + } + } else { + auto X = EigenMatrix::Reshape(x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + auto& place = *dev_ctx.eigen_device(); + Y.device(place) = X * static_cast(1.0f - dropout_prob); + } + } +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout, + CPU, + ALL_LAYOUT, + phi::DropoutRawKernel, + float, + double, + phi::dtype::bfloat16) {} diff --git a/paddle/phi/kernels/dropout_grad_kernel.h b/paddle/phi/kernels/dropout_grad_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..ae3f82056632ddde8968b7468eb16030f0c926f5 --- /dev/null +++ b/paddle/phi/kernels/dropout_grad_kernel.h @@ -0,0 +1,31 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DropoutGradRawKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + DenseTensor* x_grad); + +} // namespace phi diff --git a/paddle/phi/kernels/dropout_kernel.h b/paddle/phi/kernels/dropout_kernel.h new file mode 100644 index 0000000000000000000000000000000000000000..dc9f89e08e17ac6c8113f28be6604d638cd6cd3a --- /dev/null +++ b/paddle/phi/kernels/dropout_kernel.h @@ -0,0 +1,34 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/phi/common/scalar.h" +#include "paddle/phi/core/dense_tensor.h" + +namespace phi { + +template +void DropoutRawKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* mask); + +} // namespace phi diff --git a/paddle/phi/kernels/gpu/dropout_grad_kernel.cu b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..94d4942a41878f06fbe7a0ffa6e8cc2c3f42f159 --- /dev/null +++ b/paddle/phi/kernels/gpu/dropout_grad_kernel.cu @@ -0,0 +1,46 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/dropout_impl.cu.h" +#include "paddle/phi/kernels/dropout_grad_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void DropoutGradRawKernel(const Context& dev_ctx, + const DenseTensor& mask, + const DenseTensor& out_grad, + float p, + bool is_test, + const std::string& mode, + DenseTensor* x_grad) { + x_grad->mutable_data(dev_ctx.GetPlace()); + auto size = x_grad->numel(); + paddle::operators::DropoutGradGPUKernelDriver( + dev_ctx, mode, p, out_grad, mask, size, x_grad, is_test); +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout_grad, + GPU, + ALL_LAYOUT, + phi::DropoutGradRawKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/gpu/dropout_kernel.cu b/paddle/phi/kernels/gpu/dropout_kernel.cu new file mode 100644 index 0000000000000000000000000000000000000000..bd1683ad0c7d8c9a91721a15ae0c24046aab414e --- /dev/null +++ b/paddle/phi/kernels/gpu/dropout_kernel.cu @@ -0,0 +1,61 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/operators/dropout_impl.cu.h" +#include "paddle/phi/kernels/dropout_kernel.h" + +#include "paddle/phi/backends/gpu/gpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void DropoutRawKernel(const Context& dev_ctx, + const DenseTensor& x, + paddle::optional seed_tensor, + float p, + bool is_test, + const std::string& mode, + int seed, + bool fix_seed, + DenseTensor* out, + DenseTensor* mask) { + out->mutable_data(dev_ctx.GetPlace()); + float dropout_prob = p; + bool upscale_in_train = (mode == "upscale_in_train"); + mask->mutable_data(dev_ctx.GetPlace()); + + paddle::operators::DropoutFwGPUKernelDriver(dev_ctx, + is_test, + mode, + dropout_prob, + upscale_in_train, + fix_seed, + seed, + x, + seed_tensor.get_ptr(), + mask, + out); +} + +} // namespace phi + +PD_REGISTER_KERNEL(dropout, + GPU, + ALL_LAYOUT, + phi::DropoutRawKernel, + float, + double, + phi::dtype::bfloat16, + phi::dtype::float16) {} diff --git a/paddle/phi/ops/compat/dropout_sig.cc b/paddle/phi/ops/compat/dropout_sig.cc new file mode 100644 index 0000000000000000000000000000000000000000..6bf229c98bd07face5a5ba7778318cf1662f29a9 --- /dev/null +++ b/paddle/phi/ops/compat/dropout_sig.cc @@ -0,0 +1,38 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/core/compat/op_utils.h" + +namespace phi { + +KernelSignature DropoutOpArgumentMapping(const ArgumentMappingContext& ctx) { + return KernelSignature( + "dropout", + {"X", "Seed"}, + {"dropout_prob", "is_test", "dropout_implementation", "seed", "fix_seed"}, + {"Out", "Mask"}); +} + +KernelSignature DropoutGradOpArgumentMapping( + const ArgumentMappingContext& ctx) { + return KernelSignature("dropout_grad", + {"Mask", GradVarName("Out")}, + {"dropout_prob", "is_test", "dropout_implementation"}, + {GradVarName("X")}); +} + +} // namespace phi + +PD_REGISTER_ARG_MAPPING_FN(dropout, phi::DropoutOpArgumentMapping); +PD_REGISTER_ARG_MAPPING_FN(dropout_grad, phi::DropoutGradOpArgumentMapping); diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py index f670f7c38097b39723e22d15d152590fa19c607d..fd2f642b770d646e74168800bbe8820534581354 100644 --- a/python/paddle/fluid/tests/unittests/test_dropout_op.py +++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py @@ -933,5 +933,65 @@ class TestDropoutWithDeterminateSeedGenerator(unittest.TestCase): self.check_static_result(place=place) +class TestDropoutBackward(unittest.TestCase): + def setUp(self): + np.random.seed(123) + self.places = [fluid.CPUPlace()] + if core.is_compiled_with_cuda(): + self.places.append(fluid.CUDAPlace(0)) + + def cal_grad_upscale_train(self, mask, prob): + return mask.astype("float32") / (1 - prob) + + def cal_grad_downscale_in_infer(self, mask): + return mask.astype("float32") + + def test_backward_downscale_in_infer(self): + for place in self.places: + with fluid.dygraph.guard(place): + + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', 0.5) + out.backward() + + self.assertTrue( + np.array_equal(input.gradient( + ), self.cal_grad_downscale_in_infer(mask.numpy()))) + + def test_backward_upscale_train(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.5 + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + def test_backward_upscale_train_2(self): + for place in self.places: + with fluid.dygraph.guard(place): + + prob = 0.3 + input = paddle.uniform([40, 40], dtype="float32") + input.stop_gradient = False + out, mask = core.ops.dropout(input, 'dropout_prob', prob, + "dropout_implementation", + "upscale_in_train") + out.backward() + + self.assertTrue( + np.allclose(input.gradient( + ), self.cal_grad_upscale_train(mask.numpy(), prob))) + + if __name__ == '__main__': + paddle.enable_static() unittest.main()