/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved. Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ #pragma once #include #include #include #include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/generator.h" #include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { using Tensor = framework::Tensor; template using EigenMatrix = framework::EigenMatrix; template using EigenVector = framework::EigenVector; template class CPUDropoutKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* x = context.Input("X"); auto* seed = context.HasInput("Seed") ? context.Input("Seed") : nullptr; auto* y = context.Output("Out"); const auto* x_data = x->data(); auto* y_data = y->mutable_data(context.GetPlace()); float dropout_prob = context.Attr("dropout_prob"); auto& dropout_implementation = context.Attr("dropout_implementation"); bool upscale_in_train = (dropout_implementation == "upscale_in_train"); if (!context.Attr("is_test")) { auto* mask = context.Output("Mask"); auto* mask_data = mask->mutable_data(context.GetPlace()); size_t size = phi::product(mask->dims()); // Special case when dropout_prob is 1.0 if (dropout_prob == 1.0f) { std::memset(y_data, 0, size * sizeof(*y_data)); // NOLINT std::memset(mask_data, 0, size * sizeof(*mask_data)); // NOLINT return; } // std::minstd_rand engine; // NOTE: fixed seed should only be used in unittest or for debug. // Guarantee to use random seed in training. int seed_data = 0; if (seed) { seed_data = *(seed->data()); } else { seed_data = context.Attr("fix_seed") ? context.Attr("seed") : 0; } auto engine = framework::GetCPURandomEngine(seed_data); std::uniform_real_distribution dist(0, 1); for (size_t i = 0; i < size; ++i) { if (dist(*engine) < dropout_prob) { mask_data[i] = 0; y_data[i] = 0; } else { mask_data[i] = 1; if (upscale_in_train) { y_data[i] = x_data[i] / static_cast(1.0f - dropout_prob); } else { y_data[i] = x_data[i]; } } } } else { if (upscale_in_train) { const auto* X_data = x->data(); auto* Y_data = y->mutable_data(context.GetPlace()); #ifdef PADDLE_WITH_MKLML #pragma omp parallel for #endif for (int i = 0; i < x->numel(); i++) { Y_data[i] = X_data[i]; } } else { auto X = EigenMatrix::Reshape(*x, 1); auto Y = EigenMatrix::Reshape(*y, 1); auto& place = *context.template device_context().eigen_device(); Y.device(place) = X * static_cast(1.0f - dropout_prob); } } } }; template class DropoutGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& context) const override { auto* grad_x = context.Output(framework::GradVarName("X")); auto* grad_y = context.Input(framework::GradVarName("Out")); auto* mask = context.Input("Mask"); grad_x->mutable_data(context.GetPlace()); auto dX = EigenVector::Flatten(*grad_x); auto dY = EigenVector::Flatten(*grad_y); auto& place = *context.template device_context().eigen_device(); auto& dropout_implementation = context.Attr("dropout_implementation"); if (context.Attr("is_test") == true) { if (dropout_implementation == "upscale_in_train") { dX.device(place) = static_cast(1) * dY; } else { float dropout_prob = context.Attr("dropout_prob"); dX.device(place) = dY * static_cast(1.0f - dropout_prob); } } else { auto M = EigenVector::Flatten(*mask); if (dropout_implementation == "upscale_in_train") { float dropout_prob = context.Attr("dropout_prob"); if (dropout_prob == 1.0f) { dX.device(place) = static_cast(0) * dY; } else { dX.device(place) = dY * M.cast() / static_cast(1.0f - dropout_prob); } } else { dX.device(place) = dY * M.cast(); } } } }; } // namespace operators } // namespace paddle