/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <cuda.h>
#include <curand_kernel.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <random>
#include <string>
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/platform/dynload/curand.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {

// Element-wise dropout over src[0, n) using a grid-stride loop.
//
// For every element one uniform sample is drawn from a Philox generator
// (curand_uniform). If the sample is below `dropout_prob` the element is
// dropped (mask 0, output 0); otherwise it is kept (mask 1) and, when
// `is_upscale_in_train` is set, divided by (1 - dropout_prob).
//
// Launch: 1-D grid and 1-D blocks, no shared memory; the loop makes any grid
// size correct. `src`, `dst` and `mask_data` must each hold at least `n`
// elements.
template <typename T, typename MaskType>
__global__ void RandomGenerator(const size_t n, const int seed,
                                const float dropout_prob, const T* src,
                                MaskType* mask_data, T* dst,
                                bool is_upscale_in_train) {
  curandStatePhilox4_32_10_t state;
  // Use 64-bit indexing so tensors with more than INT_MAX elements do not
  // overflow the thread index.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;
  bool first_iteration = true;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < n; idx += stride) {
    // Each element uses its own Philox subsequence (`idx`). The offset is
    // `idx` for a thread's first element and `stride` afterwards; these are
    // the same values the previous int-indexed version passed, so a fixed
    // seed keeps producing the same mask for tensors below INT_MAX elements.
    curand_init(seed, idx, first_iteration ? idx : stride, &state);
    first_iteration = false;

    const T s = src[idx];
    MaskType mask;
    T dest;
    if (curand_uniform(&state) < dropout_prob) {
      mask = 0;
      dest = 0;
    } else {
      mask = 1;
      if (is_upscale_in_train) {
        dest = s / static_cast<T>(1.0f - dropout_prob);
      } else {
        dest = s;
      }
    }
    mask_data[idx] = mask;
    dst[idx] = dest;
  }
}

// CUDA forward kernel of the `dropout` operator.
//
// Training (`is_test == false`):
//   * Out  = X with each element dropped with probability `dropout_prob`;
//            kept elements are divided by (1 - dropout_prob) when
//            `dropout_implementation == "upscale_in_train"`.
//   * Mask = uint8 tensor, 1 for kept and 0 for dropped elements.
//   The random seed comes from the optional `Seed` input (copied to host if it
//   lives on the GPU), else the `seed` attribute when `fix_seed` is true,
//   else std::random_device. Random numbers are generated with cuRAND in
//   RandomGenerator.
//
// Inference (`is_test == true`):
//   Out = X for "upscale_in_train", otherwise X * (1 - dropout_prob).
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* seed =
        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
    auto* y = context.Output<Tensor>("Out");
    y->mutable_data<T>(context.GetPlace());
    float dropout_prob = context.Attr<float>("dropout_prob");

    auto& dropout_implementation =
        context.Attr<std::string>("dropout_implementation");
    bool upscale_in_train = (dropout_implementation == "upscale_in_train");

    auto& place = *context.template device_context<Place>().eigen_device();
    if (!context.Attr<bool>("is_test")) {
      int64_t x_numel = x->numel();
      auto stream = context.cuda_device_context().stream();

      auto* mask = context.Output<Tensor>("Mask");
      auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
      size_t size = framework::product(mask->dims());
      auto* x_data = x->data<T>();

      int seed_data;
      if (seed) {
        if (platform::is_gpu_place(seed->place())) {
          // The seed must be known on the host to be passed as a kernel
          // argument, so copy a GPU-resident seed tensor back first.
          framework::Tensor temp;
          TensorCopySync(*seed, platform::CPUPlace(), &temp);
          seed_data = *(temp.data<int>());
        } else {
          seed_data = *(seed->data<int>());
        }
      } else if (context.Attr<bool>("fix_seed")) {
        seed_data = context.Attr<int>("seed");
      } else {
        // Only construct the non-deterministic source when no seed was given.
        std::random_device rnd;
        seed_data = rnd();
      }

      auto* y_data = y->mutable_data<T>(context.GetPlace());
      if (dropout_prob == 1.0f) {
        // Everything is dropped: zero both outputs without drawing randoms.
        PADDLE_ENFORCE_CUDA_SUCCESS(
            cudaMemsetAsync(y_data, 0, x_numel * sizeof(T), stream));
        PADDLE_ENFORCE_CUDA_SUCCESS(cudaMemsetAsync(
            mask_data, 0, x_numel * sizeof(*mask_data), stream));
        return;
      }

      // An empty tensor would yield grid == 0, which is an invalid launch
      // configuration; there is nothing to compute in that case.
      if (x_numel == 0) {
        return;
      }

      constexpr int kThreads = 512;
      const int grid = static_cast<int>((x_numel + kThreads - 1) / kThreads);
      RandomGenerator<T, uint8_t><<<grid, kThreads, 0, stream>>>(
          size, seed_data, dropout_prob, x_data, mask_data, y_data,
          upscale_in_train);
      // Kernel launches do not return an error code; surface launch failures.
      PADDLE_ENFORCE_CUDA_SUCCESS(cudaGetLastError());
    } else {
      auto X = EigenMatrix<T>::Reshape(*x, 1);
      auto Y = EigenMatrix<T>::Reshape(*y, 1);
      if (upscale_in_train) {
        Y.device(place) = X;
      } else {
        Y.device(place) = X * static_cast<T>(1.0f - dropout_prob);
      }
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(
    dropout, ops::GPUDropoutKernel<plat::CUDADeviceContext, float>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, plat::float16>,
    ops::GPUDropoutKernel<plat::CUDADeviceContext, double>);
REGISTER_OP_CUDA_KERNEL(
    dropout_grad, ops::DropoutGradKernel<plat::CUDADeviceContext, float>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, plat::float16>,
    ops::DropoutGradKernel<plat::CUDADeviceContext, double>);