From 9a44f3d6dabb676aad0c63854c115aa75247bf84 Mon Sep 17 00:00:00 2001 From: Xinghai Sun Date: Sat, 2 Sep 2017 18:33:58 +0800 Subject: [PATCH] Add dropout operator. --- paddle/operators/dropout_op.cc | 81 ++++++++++++++++++++++++++++++++++ paddle/operators/dropout_op.cu | 22 +++++++++ paddle/operators/dropout_op.h | 70 +++++++++++++++++++++++++++++ paddle/pybind/pybind.cc | 1 + 4 files changed, 174 insertions(+) create mode 100644 paddle/operators/dropout_op.cc create mode 100644 paddle/operators/dropout_op.cu create mode 100644 paddle/operators/dropout_op.h diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc new file mode 100644 index 00000000000..a9950a48e0f --- /dev/null +++ b/paddle/operators/dropout_op.cc @@ -0,0 +1,81 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/dropout_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; + +class DropoutOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + auto dims = ctx.Input("X")->dims(); + ctx.Output("Out")->Resize(dims); + ctx.Output("Mask")->Resize(dims); + } +}; + +class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DropoutOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddInput("X", "The input of dropout op."); + AddOutput("Out", "The output of dropout op."); + AddOutput("Mask", "The dropout mask.").AsIntermediate(); + + AddComment(R"DOC(Dropout Operator.)DOC"); + } +}; + +class DropoutOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) must not be null."); + + auto x_dims = ctx.Input("X")->dims(); + auto mask_dims = ctx.Input("Mask")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + PADDLE_ENFORCE_EQ(x_dims, out_dims, + "Dimensions of Input(X) and Out must be the same."); + PADDLE_ENFORCE_EQ(x_dims, mask_dims, + "Dimensions of Input(X) and Mask must be the same."); + + auto *x_grad = ctx.Output(framework::GradVarName("X")); + x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); +REGISTER_OP_CPU_KERNEL(dropout, + ops::DropoutKernel); +REGISTER_OP_CPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu new file mode 100644 index 00000000000..9e9efaa3b1a --- /dev/null +++ b/paddle/operators/dropout_op.cu @@ -0,0 +1,22 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include "paddle/operators/dropout_op.h" + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL(dropout, + ops::DropoutKernel); +REGISTER_OP_GPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h new file mode 100644 index 00000000000..d5d32df74b7 --- /dev/null +++ b/paddle/operators/dropout_op.h @@ -0,0 +1,70 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class DropoutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + auto* mask = context.Output("Mask"); + mask->mutable_data(context.GetPlace()); + y->mutable_data(context.GetPlace()); + + auto dims = x->dims(); + auto X = EigenMatrix::From(*x); + auto Y = EigenMatrix::From(*y); + auto M = EigenMatrix::From(*mask); + + auto place = context.GetEigenDevice(); + M.device(place).setRandom(); + float dropout_prob = context.op_.GetAttr("dropout_prob"); + M.device(place) = (M > dropout_prob).cast(); + Y.device(place) = X * Y; + } +}; + +template +class DropoutGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* grad_x = context.Output(framework::GradVarName("X")); + auto* grad_y = context.Input(framework::GradVarName("Out")); + auto* mask = context.Input("Mask"); + grad_x->mutable_data(context.GetPlace()); + + auto dims = grad_x->dims(); + auto M = EigenMatrix::From(*mask); + auto dX = EigenMatrix::From(*grad_x); + auto dY = EigenMatrix::From(*grad_y); + + auto place = context.GetEigenDevice(); + dX.device(place) = dY * M; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index 3bc150ccb7a..42fce51024e 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -46,6 +46,7 @@ USE_OP(lookup_table); USE_OP(scale); USE_OP_ITSELF(identity); USE_OP(minus); +USE_OP(dropout); USE_CPU_ONLY_OP(gather); USE_CPU_ONLY_OP(scatter); -- GitLab