diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..b111b9fccb2310bd5fb92bda878a497c51f62ce0 --- /dev/null +++ b/paddle/operators/dropout_op.cc @@ -0,0 +1,113 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/operators/dropout_op.h" + +namespace paddle { +namespace operators { + +using framework::Tensor; +using framework::LoDTensor; + +class DropoutOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0); + PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("is_training") == 0 || + ctx.Attr("is_training") == 1); + + auto dims = ctx.Input("X")->dims(); + ctx.Output("Out")->Resize(dims); + if (ctx.Attr("is_training") == 1) { + ctx.Output("Mask")->Resize(dims); + } + } +}; + +template +class DropoutOpMaker : public framework::OpProtoAndCheckerMaker { + public: + DropoutOpMaker(framework::OpProto *proto, + framework::OpAttrChecker *op_checker) + : OpProtoAndCheckerMaker(proto, op_checker) { + AddAttr("dropout_prob", "Probability of setting units to zero.") + .SetDefault(.5f); + // TODO(xinghai-sun): use bool for is_training after bool is supported. + AddAttr("is_training", "Whether in training phase.").SetDefault(1); + AddAttr("seed", "Dropout random seed.").SetDefault(0); + AddInput("X", "The input of dropout op."); + AddOutput("Out", "The output of dropout op."); + AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate(); + + AddComment(R"DOC( +Dropout Operator. + +"Dropout" refers to randomly dropping out units in a nerual network. It is a +regularization technique for reducing overfitting by preventing neuron +co-adaption during training. The dropout operator randomly set (according to +the given dropout probability) the outputs of some units to zero, while others +being set to their inputs. +)DOC"); + } +}; + +template +class DropoutOpGrad : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + protected: + void InferShape(const framework::InferShapeContext &ctx) const override { + PADDLE_ENFORCE_EQ(ctx.Attr("is_training"), 1, + "GradOp is only callable when is_training is true"); + + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null."); + PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")), + "Input(Out@GRAD) must not be null."); + + PADDLE_ENFORCE_GE(ctx.Attr("dropout_prob"), 0); + PADDLE_ENFORCE_LE(ctx.Attr("dropout_prob"), 1); + // TODO(xinghai-sun): remove this check after swtiching to bool + PADDLE_ENFORCE(ctx.Attr("is_training") == 0 || + ctx.Attr("is_training") == 1); + auto x_dims = ctx.Input("X")->dims(); + auto out_dims = ctx.Input(framework::GradVarName("Out"))->dims(); + PADDLE_ENFORCE_EQ(x_dims, out_dims, + "Dimensions of Input(X) and Out@Grad must be the same."); + auto mask_dims = ctx.Input("Mask")->dims(); + PADDLE_ENFORCE_EQ(x_dims, mask_dims, + "Dimensions of Input(X) and Mask must be the same."); + + auto *x_grad = ctx.Output(framework::GradVarName("X")); + x_grad->Resize(x_dims); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad, + ops::DropoutOpGrad); +REGISTER_OP_CPU_KERNEL( + dropout, ops::CPUDropoutKernel); +REGISTER_OP_CPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu new file mode 100644 index 0000000000000000000000000000000000000000..186237fb238add37f32403309a0f7e8a9846d335 --- /dev/null +++ b/paddle/operators/dropout_op.cu @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#define EIGEN_USE_GPU +#include +#include +#include +#include +#include "paddle/operators/dropout_op.h" + +namespace paddle { +namespace operators { + +template +struct MaskGenerator { + AttrType dropout_prob; + int seed; + + __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed) + : dropout_prob(dropout_prob), seed(seed) {} + + __host__ __device__ T operator()(const unsigned int n) const { + thrust::minstd_rand rng; + rng.seed(seed); + thrust::uniform_real_distribution dist(0, 1); + rng.discard(n); + if (dist(rng) < dropout_prob) { + return static_cast(0); + } else { + return static_cast(1); + } + } +}; + +// It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT. +// Use std::random and thrust::random(thrust is a std library in CUDA) to +// implement uniform random. +template +class GPUDropoutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + y->mutable_data(context.GetPlace()); + AttrType dropout_prob = context.Attr("dropout_prob"); + + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + + auto place = context.GetEigenDevice(); + if (context.Attr("is_training") == 1) { + auto* mask = context.Output("Mask"); + auto* mask_data = mask->mutable_data(context.GetPlace()); + int size = framework::product(mask->dims()); + int seed = context.Attr("seed"); + thrust::counting_iterator index_sequence_begin(0); + thrust::transform(index_sequence_begin, index_sequence_begin + size, + thrust::device_ptr(mask_data), + MaskGenerator(dropout_prob, seed)); + auto M = EigenMatrix::Reshape(*mask, 1); + Y.device(place) = X * M; + } else { + Y.device(place) = X * dropout_prob; + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OP_GPU_KERNEL( + dropout, ops::GPUDropoutKernel); +REGISTER_OP_GPU_KERNEL( + dropout_grad, ops::DropoutGradKernel); diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h new file mode 100644 index 0000000000000000000000000000000000000000..82eafee0e0e7db7b4b4ae5405f37146d061aefd5 --- /dev/null +++ b/paddle/operators/dropout_op.h @@ -0,0 +1,86 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include +#include "paddle/framework/eigen.h" +#include "paddle/framework/op_registry.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; +template +using EigenMatrix = framework::EigenMatrix; + +template +class CPUDropoutKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + auto* x = context.Input("X"); + auto* y = context.Output("Out"); + const auto* x_data = x->data(); + auto* y_data = y->mutable_data(context.GetPlace()); + AttrType dropout_prob = context.Attr("dropout_prob"); + + if (context.Attr("is_training") == 1) { + auto* mask = context.Output("Mask"); + auto* mask_data = mask->mutable_data(context.GetPlace()); + int seed = context.Attr("seed"); + std::minstd_rand engine; + engine.seed(seed); + std::uniform_real_distribution dist(0, 1); + size_t size = framework::product(mask->dims()); + for (size_t i = 0; i < size; ++i) { + if (dist(engine) < dropout_prob) { + mask_data[i] = 0; + y_data[i] = 0; + } else { + mask_data[i] = 1; + y_data[i] = x_data[i]; + } + } + } else { + auto X = EigenMatrix::Reshape(*x, 1); + auto Y = EigenMatrix::Reshape(*y, 1); + auto place = context.GetEigenDevice(); + Y.device(place) = X * dropout_prob; + } + } +}; + +template +class DropoutGradKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& context) const override { + PADDLE_ENFORCE_EQ(context.Attr("is_training"), 1, + "GradOp is only callable when is_training is true"); + + auto* grad_x = context.Output(framework::GradVarName("X")); + auto* grad_y = context.Input(framework::GradVarName("Out")); + auto* mask = context.Input("Mask"); + grad_x->mutable_data(context.GetPlace()); + + auto M = EigenMatrix::Reshape(*mask, 1); + auto dX = EigenMatrix::Reshape(*grad_x, 1); + auto dY = EigenMatrix::Reshape(*grad_y, 1); + + auto place = context.GetEigenDevice(); + dX.device(place) = dY * M; + } +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py new file mode 100644 index 0000000000000000000000000000000000000000..3638fee1a1c26195791bc1f5a46dd749da0aee95 --- /dev/null +++ b/python/paddle/v2/framework/tests/test_dropout_op.py @@ -0,0 +1,59 @@ +import unittest +import numpy as np +from op_test import OpTest + + +class TestDropoutOp(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'is_training': 1} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))} + + def test_check_output(self): + self.check_output() + + def test_check_grad_normal(self): + self.check_grad(['X'], 'Out', max_relative_error=0.05) + + +class TestDropoutOp2(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 1.0, 'is_training': 1} + self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))} + + +class TestDropoutOp3(TestDropoutOp): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")} + self.attrs = {'dropout_prob': 0.0, 'is_training': 1} + self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))} + + +class TestDropoutOp4(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64)).astype("float32")} + self.attrs = {'dropout_prob': 0.35, 'is_training': 0} + self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + + def test_check_output(self): + self.check_output() + + +class TestDropoutOp5(OpTest): + def setUp(self): + self.op_type = "dropout" + self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")} + self.attrs = {'dropout_prob': 0.75, 'is_training': 0} + self.outputs = {'Out': self.inputs['X'] * self.attrs['dropout_prob']} + + def test_check_output(self): + self.check_output() + + +if __name__ == '__main__': + unittest.main()