/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #define EIGEN_USE_GPU #include #include #include #include #include "paddle/fluid/operators/dropout_op.h" #include "paddle/fluid/platform/float16.h" namespace paddle { namespace operators { template __global__ void RandomGenerator(const size_t n, const int seed, const AttrType dropout_prob, const T* src, T* mask_data, T* dst) { thrust::minstd_rand rng; rng.seed(seed); thrust::uniform_real_distribution dist(0, 1); int idx = blockDim.x * blockIdx.x + threadIdx.x; for (; idx < n; idx += blockDim.x * gridDim.x) { if (dist(rng) < dropout_prob) { mask_data[idx] = static_cast(0); } else { mask_data[idx] = static_cast(1); } dst[idx] = mask_data[idx] * src[idx]; } } // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. // Use std::random and thrust::random(thrust is a std library in CUDA) to // implement uniform random. template class GPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* y = context.Output("Out"); y->mutable_data(context.GetPlace()); AttrType dropout_prob = context.Attr("dropout_prob")); auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); size_t size = framework::product(mask->dims()); auto* x_data = x->data(); auto* y_data = y->mutable_data(context.GetPlace()); std::random_device rnd; int seed = context.Attr("fix_seed") ? context.Attr("seed") : rnd(); int threads = 512; int grid = (x->numel() + threads - 1) / threads; RandomGenerator<<>>( size, seed, dropout_prob, x_data, mask_data, y_data); } else { Y.device(place) = X * static_cast(1.0f - dropout_prob); } } }; } // namespace operators } // namespace paddle namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( dropout, ops::GPUDropoutKernel, ops::GPUDropoutKernel); REGISTER_OP_CUDA_KERNEL(dropout_grad, ops::DropoutGradKernel);