/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/random.h>
#include <thrust/transform.h>
#include <cstdint>
#include <random>
#include "paddle/framework/eigen.h"
#include "paddle/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
template <typename T, int MajorType = Eigen::RowMajor,
          typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

// Forward dropout on CPU.
//
// Inputs:  "X"    - tensor to apply dropout to.
// Outputs: "Out"  - X with dropped elements set to 0 and kept elements
//                   multiplied by (1 - dropout_prob);
//          "Mask" - 1 where the element was kept, 0 where it was dropped.
// Attrs:   "dropout_prob" (float) - an element is dropped when a uniform
//                                   [0, 1) draw is below this value;
//          "seed" (int) - seed for a std::minstd_rand engine that is
//                         re-seeded on every call, so a fixed seed always
//                         yields the same mask.
//
// NOTE(review): kept elements are scaled by (1 - dropout_prob) in the forward
// pass itself, so E[Out] = (1 - dropout_prob)^2 * X. Inverted dropout would
// divide by (1 - dropout_prob) instead - confirm this scaling is intended.
template <typename Place, typename T>
class CPUDropoutKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    auto* mask = context.Output<Tensor>("Mask");
    T* mask_data = mask->mutable_data<T>(context.GetPlace());
    T* y_data = y->mutable_data<T>(context.GetPlace());
    const T* x_data = x->data<T>();

    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
    int seed = context.op_.GetAttr<int>("seed");

    std::minstd_rand engine;
    engine.seed(seed);
    std::uniform_real_distribution<float> dist(0, 1);

    // Loop-invariant: every kept element is scaled by the same factor.
    const T keep_scale = static_cast<T>(1 - dropout_prob);
    const int64_t size = framework::product(mask->dims());
    for (int64_t i = 0; i < size; ++i) {
      if (dist(engine) < dropout_prob) {
        mask_data[i] = 0;
        y_data[i] = 0;
      } else {
        mask_data[i] = 1;
        y_data[i] = keep_scale * x_data[i];
      }
    }
  }
};

// Functor, callable on host or device, that maps an element index `n` to that
// element's mask value: 0 (dropped) or 1 (kept).
//
// Every call builds a fresh thrust::minstd_rand seeded with `seed_` and skips
// ahead `n` draws. The value for element n therefore depends only on
// (seed_, n), not on which thread evaluates it or in what order.
template <typename T>
struct MaskGenerator {
  float dropout_prob_;
  int seed_;

  __host__ __device__ MaskGenerator(float dropout_prob, int seed)
      : dropout_prob_(dropout_prob), seed_(seed) {}

  // The index is int64_t (previously unsigned int) so that tensors with more
  // than 2^32 elements do not wrap around. Narrower index types still convert
  // implicitly, so existing callers are unaffected.
  __host__ __device__ T operator()(const int64_t n) const {
    thrust::minstd_rand rng;
    rng.seed(seed_);
    thrust::uniform_real_distribution<float> dist(0, 1);
    rng.discard(static_cast<unsigned long long>(n));
    if (dist(rng) < dropout_prob_) {
      return static_cast<T>(0);
    } else {
      return static_cast<T>(1);
    }
  }
};

// Forward dropout on GPU. It takes the same inputs, outputs and attributes as
// CPUDropoutKernel. The mask comes from MaskGenerator (thrust::minstd_rand)
// rather than std::minstd_rand, so for a given seed the GPU mask generally
// differs from the CPU mask.
//
// Eigen::Tensor::setRandom appears to segfault on GPU. Uniform random numbers
// are therefore generated with std::random on CPU and thrust::random on GPU
// (thrust is CUDA's standard template library).
template <typename Place, typename T>
class GPUDropoutKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* x = context.Input<Tensor>("X");
    auto* y = context.Output<Tensor>("Out");
    auto* mask = context.Output<Tensor>("Mask");
    y->mutable_data<T>(context.GetPlace());
    T* mask_data = mask->mutable_data<T>(context.GetPlace());

    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
    int seed = context.op_.GetAttr<int>("seed");

    // int64_t: an `int` would overflow for tensors with >= 2^31 elements.
    const int64_t size = framework::product(mask->dims());
    // Empty tensor: there is nothing to compute, and dims[0] may be 0, which
    // would make the 2-D reshape below divide by zero.
    if (size == 0) return;

    // NOTE(review): thrust::transform runs on thrust's default stream, while
    // the Eigen expression below runs on the context's Eigen device. Correct
    // ordering of the Mask write and read relies on those two being
    // serialized - confirm.
    thrust::counting_iterator<int64_t> index_sequence_begin(0);
    thrust::transform(index_sequence_begin, index_sequence_begin + size,
                      thrust::device_ptr<T>(mask_data),
                      MaskGenerator<T>(dropout_prob, seed));

    // View X, Out and Mask as [dims[0], size / dims[0]] matrices so that the
    // element-wise product can be evaluated on the Eigen device.
    auto dims = x->dims();
    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
    auto X = EigenMatrix<T>::From(*x, new_dims);
    auto Y = EigenMatrix<T>::From(*y, new_dims);
    auto M = EigenMatrix<T>::From(*mask, new_dims);

    auto place = context.GetEigenDevice<Place>();
    Y.device(place) = X * M * static_cast<T>(1 - dropout_prob);
  }
};

// Backward of dropout, usable with any Place that provides an Eigen device:
//   X@GRAD = Out@GRAD * Mask * (1 - dropout_prob),
// which mirrors the forward kernels' scaling of kept elements.
template <typename Place, typename T>
class DropoutGradKernel : public framework::OpKernel {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
    auto* mask = context.Input<Tensor>("Mask");
    grad_x->mutable_data<T>(context.GetPlace());

    auto dims = grad_x->dims();
    // int64_t: the previous static_cast<int> truncated large element counts.
    const int64_t size = framework::product(dims);
    // Empty gradient: nothing to compute; also avoids dividing by dims[0] == 0.
    if (size == 0) return;

    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
    auto M = EigenMatrix<T>::From(*mask, new_dims);
    auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
    auto dY = EigenMatrix<T>::From(*grad_y, new_dims);

    auto place = context.GetEigenDevice<Place>();
    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
    dX.device(place) = dY * M * static_cast<T>(1 - dropout_prob);
  }
};

}  // namespace operators
}  // namespace paddle