diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a9950a48e0f42ce46ffd27ecc972c35d82a98fdb
--- /dev/null
+++ b/paddle/operators/dropout_op.cc
@@ -0,0 +1,91 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#include "paddle/operators/dropout_op.h"
+
+namespace paddle {
+namespace operators {
+
+using framework::Tensor;
+
+// Forward dropout operator. Shape inference gives both outputs, "Out" and
+// "Mask", the dimensions of the input "X".
+class DropoutOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    auto dims = ctx.Input<Tensor>("X")->dims();
+    ctx.Output<Tensor>("Out")->Resize(dims);
+    ctx.Output<Tensor>("Mask")->Resize(dims);
+  }
+};
+
+// Declares the input, outputs and attribute of the dropout operator.
+class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  DropoutOpMaker(framework::OpProto *proto,
+                 framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddInput("X", "The input of dropout op.");
+    AddOutput("Out", "The output of dropout op.");
+    AddOutput("Mask", "The dropout mask.").AsIntermediate();
+    // DropoutKernel reads "dropout_prob" at run time, so the attribute has to
+    // be declared here.
+    AddAttr<float>("dropout_prob", "Probability of setting units to zero.")
+        .SetDefault(.5f);
+
+    AddComment(R"DOC(Dropout Operator.)DOC");
+  }
+};
+
+// Backward dropout operator. Shape inference checks that "X", "Mask" and the
+// gradient of "Out" are all set and have equal dimensions, then gives the
+// gradient of "X" those dimensions.
+class DropoutOpGrad : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+ protected:
+  void InferShape(const framework::InferShapeContext &ctx) const override {
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
+    PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
+                            "Input(Out@GRAD) must not be null.");
+
+    auto x_dims = ctx.Input<Tensor>("X")->dims();
+    auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
+    auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
+    PADDLE_ENFORCE_EQ(x_dims, out_dims,
+                      "Dimensions of Input(X) and Out must be the same.");
+    PADDLE_ENFORCE_EQ(x_dims, mask_dims,
+                      "Dimensions of Input(X) and Mask must be the same.");
+
+    auto *x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
+    x_grad->Resize(x_dims);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
+            ops::DropoutOpGrad);
+REGISTER_OP_CPU_KERNEL(dropout,
+                       ops::DropoutKernel<paddle::platform::CPUPlace, float>);
+REGISTER_OP_CPU_KERNEL(
+    dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu
new file mode 100644
index 0000000000000000000000000000000000000000..9e9efaa3b1af7e6905ea9ed565c7b40433010db7
--- /dev/null
+++ b/paddle/operators/dropout_op.cu
@@ -0,0 +1,23 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#define EIGEN_USE_GPU
+#include "paddle/operators/dropout_op.h"
+
+// Register the float instantiations of both dropout kernels for the GPU.
+namespace ops = paddle::operators;
+REGISTER_OP_GPU_KERNEL(dropout,
+                       ops::DropoutKernel<paddle::platform::GPUPlace, float>);
+REGISTER_OP_GPU_KERNEL(
+    dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5d32df74b70c9d81b094d4f4ea6d6681b9573ba
--- /dev/null
+++ b/paddle/operators/dropout_op.h
@@ -0,0 +1,76 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+   Licensed under the Apache License, Version 2.0 (the "License");
+   you may not use this file except in compliance with the License.
+   You may obtain a copy of the License at
+
+   http://www.apache.org/licenses/LICENSE-2.0
+
+   Unless required by applicable law or agreed to in writing, software
+   distributed under the License is distributed on an "AS IS" BASIS,
+   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+   See the License for the specific language governing permissions and
+   limitations under the License. */
+
+#pragma once
+#include "paddle/framework/eigen.h"
+#include "paddle/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+template <typename T, int MajorType = Eigen::RowMajor,
+          typename IndexType = Eigen::DenseIndex>
+using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
+
+// Forward dropout kernel: fills "Mask" with random values, thresholds them
+// against the "dropout_prob" attribute to obtain a 0/1 mask, and writes the
+// element-wise product of "X" and that mask to "Out".
+template <typename Place, typename T>
+class DropoutKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* x = context.Input<Tensor>("X");
+    auto* y = context.Output<Tensor>("Out");
+    auto* mask = context.Output<Tensor>("Mask");
+    mask->mutable_data<T>(context.GetPlace());
+    y->mutable_data<T>(context.GetPlace());
+
+    auto X = EigenMatrix<T>::From(*x);
+    auto Y = EigenMatrix<T>::From(*y);
+    auto M = EigenMatrix<T>::From(*mask);
+
+    auto place = context.GetEigenDevice<Place>();
+    M.device(place).setRandom();
+    float dropout_prob = context.op_.GetAttr<float>("dropout_prob");
+    // An entry becomes 1 when its random value exceeds dropout_prob, else 0.
+    M.device(place) = (M > dropout_prob).cast<T>();
+    // Multiply by the mask; "Out" has only been allocated at this point, so
+    // it must not appear on the right-hand side.
+    Y.device(place) = X * M;
+  }
+};
+
+// Backward dropout kernel: the gradient of "X" is the gradient of "Out"
+// multiplied element-wise by the "Mask" input.
+template <typename Place, typename T>
+class DropoutGradKernel : public framework::OpKernel {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
+    auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
+    auto* mask = context.Input<Tensor>("Mask");
+    grad_x->mutable_data<T>(context.GetPlace());
+
+    auto M = EigenMatrix<T>::From(*mask);
+    auto dX = EigenMatrix<T>::From(*grad_x);
+    auto dY = EigenMatrix<T>::From(*grad_y);
+
+    auto place = context.GetEigenDevice<Place>();
+    dX.device(place) = dY * M;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc
index 3bc150ccb7af2885439cc2344aa0db9ba3b1ca03..42fce51024e1a52a22de1c6f60e5466018f895ad 100644
--- a/paddle/pybind/pybind.cc
+++ b/paddle/pybind/pybind.cc
@@ -46,6 +46,7 @@ USE_OP(lookup_table);
 USE_OP(scale);
 USE_OP_ITSELF(identity);
 USE_OP(minus);
+USE_OP(dropout);
 USE_CPU_ONLY_OP(gather);
 USE_CPU_ONLY_OP(scatter);
 