diff --git a/paddle/operators/dropout_op.cc b/paddle/operators/dropout_op.cc
index 34d5e762cb0d3ef3d7efc5f42f25908584e9b8de..74e72cf1168c4712fbcc54c7c2a495346db1e8b3 100644
--- a/paddle/operators/dropout_op.cc
+++ b/paddle/operators/dropout_op.cc
@@ -30,6 +30,10 @@ class DropoutOp : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("X"), "Input(X) must not be null.");
     PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
     PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
+    // TODO(xinghai-sun): remove this check after switching to bool
+    PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
+                   ctx.Attr<int>("is_training") == 1);
+    // resize
     auto dims = ctx.Input<Tensor>("X")->dims();
     ctx.Output<Tensor>("Out")->Resize(dims);
     ctx.Output<Tensor>("Mask")->Resize(dims);
@@ -37,13 +41,16 @@ class DropoutOp : public framework::OperatorWithKernel {
   }
 };
 
+template <typename AttrType>
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   DropoutOpMaker(framework::OpProto *proto,
                  framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddAttr<float>("dropout_prob", "Probability for dropping out units.")
+    AddAttr<AttrType>("dropout_prob", "Probability of setting units to zero.")
         .SetDefault(.5f);
+    // TODO(xinghai-sun): use bool for is_training after bool is supported.
+    AddAttr<int>("is_training", "Whether in training phase.").SetDefault(1);
     AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
     AddInput("X", "The input of dropout op.");
     AddOutput("Out", "The output of dropout op.");
@@ -61,6 +68,7 @@ being set to their inputs.
   }
 };
 
+template <typename AttrType>
 class DropoutOpGrad : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -72,8 +80,11 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar("Mask"), "Mask must not be null.");
     PADDLE_ENFORCE_NOT_NULL(ctx.InputVar(framework::GradVarName("Out")),
                             "Input(Out@GRAD) must not be null.");
-    PADDLE_ENFORCE_GE(ctx.Attr<float>("dropout_prob"), 0);
-    PADDLE_ENFORCE_LE(ctx.Attr<float>("dropout_prob"), 1);
+    PADDLE_ENFORCE_GE(ctx.Attr<AttrType>("dropout_prob"), 0);
+    PADDLE_ENFORCE_LE(ctx.Attr<AttrType>("dropout_prob"), 1);
+    // TODO(xinghai-sun): remove this check after switching to bool
+    PADDLE_ENFORCE(ctx.Attr<int>("is_training") == 0 ||
+                   ctx.Attr<int>("is_training") == 1);
     auto x_dims = ctx.Input<Tensor>("X")->dims();
     auto mask_dims = ctx.Input<Tensor>("Mask")->dims();
     auto out_dims = ctx.Input<Tensor>(framework::GradVarName("Out"))->dims();
@@ -91,9 +102,9 @@ class DropoutOpGrad : public framework::OperatorWithKernel {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker, dropout_grad,
-            ops::DropoutOpGrad);
+REGISTER_OP(dropout, ops::DropoutOp, ops::DropoutOpMaker<float>, dropout_grad,
+            ops::DropoutOpGrad<float>);
 REGISTER_OP_CPU_KERNEL(
-    dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float>);
+    dropout, ops::CPUDropoutKernel<paddle::platform::CPUPlace, float, float>);
 REGISTER_OP_CPU_KERNEL(
     dropout_grad, ops::DropoutGradKernel<paddle::platform::CPUPlace, float>);
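For reference, the forward semantics selected by the new is_training attribute can be sketched in plain C++ (a minimal sketch with illustrative names, independent of the Paddle types above): in training, each unit is zeroed with probability dropout_prob via a sampled mask; at inference, the mask is left at zero and the output is the input scaled by dropout_prob, which is exactly what the CPU and GPU kernels below implement.

    #include <cstddef>
    #include <random>
    #include <vector>

    // Reference dropout forward pass mirroring the two kernel branches.
    void dropout_ref(const std::vector<float>& x, std::vector<float>& y,
                     std::vector<float>& mask, float dropout_prob, int seed,
                     bool is_training) {
      std::minstd_rand engine(seed);
      std::uniform_real_distribution<float> dist(0.0f, 1.0f);
      for (std::size_t i = 0; i < x.size(); ++i) {
        if (is_training) {
          // Drop each unit independently with probability dropout_prob.
          mask[i] = dist(engine) < dropout_prob ? 0.0f : 1.0f;
          y[i] = x[i] * mask[i];
        } else {
          mask[i] = 0.0f;              // inference: mask carries no information
          y[i] = x[i] * dropout_prob;  // scale instead of sampling
        }
      }
    }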
diff --git a/paddle/operators/dropout_op.cu b/paddle/operators/dropout_op.cu
index ccee7cfa7ac74694515234847ec99237d3a7c8fd..f5fbad5ca0af6d265dc218e08d6552018c5319bd 100644
--- a/paddle/operators/dropout_op.cu
+++ b/paddle/operators/dropout_op.cu
@@ -22,18 +22,18 @@ namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename AttrType>
 struct MaskGenerator {
-  float dropout_prob;
+  AttrType dropout_prob;
   int seed;
 
-  __host__ __device__ MaskGenerator(float dropout_prob, int seed)
+  __host__ __device__ MaskGenerator(AttrType dropout_prob, int seed)
       : dropout_prob(dropout_prob), seed(seed) {}
 
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed);
-    thrust::uniform_real_distribution<float> dist(0, 1);
+    thrust::uniform_real_distribution<AttrType> dist(0, 1);
     rng.discard(n);
     if (dist(rng) < dropout_prob) {
       return static_cast<T>(0);
@@ -46,33 +46,35 @@ struct MaskGenerator {
 // It seems that Eigen::Tensor::setRandom in GPU will SEGFAULT.
 // Use std::random and thrust::random(thrust is a std library in CUDA) to
 // implement uniform random.
-template <typename Place, typename T>
+template <typename Place, typename T, typename AttrType = float>
 class GPUDropoutKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Output<Tensor>("Out");
-    auto* mask = context.Output<Tensor>("Mask");
     y->mutable_data<T>(context.GetPlace());
+    auto* mask = context.Output<Tensor>("Mask");
+    auto* mask_data = mask->mutable_data<T>(context.GetPlace());
 
-    float dropout_prob = context.Attr<float>("dropout_prob");
-    int seed = context.Attr<int>("seed");
-    thrust::counting_iterator<unsigned int> index_sequence_begin(0);
-    int size = framework::product(mask->dims());
-    T* mask_data = mask->mutable_data<T>(context.GetPlace());
-    thrust::transform(index_sequence_begin, index_sequence_begin + size,
-                      thrust::device_ptr<T>(mask_data),
-                      MaskGenerator<T>(dropout_prob, seed));
+    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
 
-    auto dims = x->dims();
-    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
-    auto X = EigenMatrix<T>::From(*x, new_dims);
-    auto Y = EigenMatrix<T>::From(*y, new_dims);
-    auto M = EigenMatrix<T>::From(*mask, new_dims);
+    auto X = EigenMatrix<T>::Reshape(*x, 1);
+    auto Y = EigenMatrix<T>::Reshape(*y, 1);
+    auto M = EigenMatrix<T>::Reshape(*mask, 1);
 
     auto place = context.GetEigenDevice<Place>();
-    Y.device(place) = X * M;
-    // TODO(xinghai-sun): add test time logits.
+    int size = framework::product(mask->dims());
+    if (context.Attr<int>("is_training") == 1) {
+      int seed = context.Attr<int>("seed");
+      thrust::counting_iterator<unsigned int> index_sequence_begin(0);
+      thrust::transform(index_sequence_begin, index_sequence_begin + size,
+                        thrust::device_ptr<T>(mask_data),
+                        MaskGenerator<T, AttrType>(dropout_prob, seed));
+      Y.device(place) = X * M;
+    } else {
+      cudaMemset(mask_data, 0, sizeof(T) * size);
+      Y.device(place) = X * dropout_prob;
+    }
   }
 };
@@ -81,6 +83,6 @@ class GPUDropoutKernel : public framework::OpKernel {
 
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float>);
+    dropout, ops::GPUDropoutKernel<paddle::platform::GPUPlace, float, float>);
 REGISTER_OP_GPU_KERNEL(
     dropout_grad, ops::DropoutGradKernel<paddle::platform::GPUPlace, float>);
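The GPU kernel cannot share one sequential RNG across threads, so each mask element is computed as a pure function of (seed, n): the functor re-seeds, then discards n draws before sampling. A standalone sketch of that counter-based trick in plain C++ (std::minstd_rand matches thrust::minstd_rand; the function name is illustrative):

    #include <cstdio>
    #include <random>

    // Mask value for element n, reproducible and order-independent:
    // any thread can evaluate its own element in isolation.
    float mask_at(unsigned int n, float dropout_prob, int seed) {
      std::minstd_rand rng(seed);
      std::uniform_real_distribution<float> dist(0.0f, 1.0f);
      rng.discard(n);  // jump to the n-th draw of the stream
      return dist(rng) < dropout_prob ? 0.0f : 1.0f;
    }

    int main() {
      // Same seed => same mask, regardless of evaluation order.
      for (unsigned int n = 0; n < 8; ++n) {
        std::printf("%g ", mask_at(n, 0.5f, 42));
      }
      std::printf("\n");
      return 0;
    }

For a linear congruential engine, discard(n) may cost up to n steps, though implementations can skip ahead faster via modular exponentiation; either way the mask stays a deterministic function of the seed.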
diff --git a/paddle/operators/dropout_op.h b/paddle/operators/dropout_op.h
index c9e45fa22038c0050431ccfda35cb63261c2978c..00fdfb4c5f13d6abcceec682b09ddb2f211f5f41 100644
--- a/paddle/operators/dropout_op.h
+++ b/paddle/operators/dropout_op.h
@@ -25,34 +25,42 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
 
-template <typename Place, typename T>
+template <typename Place, typename T, typename AttrType = float>
 class CPUDropoutKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
     auto* y = context.Output<Tensor>("Out");
     auto* mask = context.Output<Tensor>("Mask");
-    T* mask_data = mask->mutable_data<T>(context.GetPlace());
-    T* y_data = y->mutable_data<T>(context.GetPlace());
-    const T* x_data = x->data<T>();
+    auto* mask_data = mask->mutable_data<T>(context.GetPlace());
+    auto* y_data = y->mutable_data<T>(context.GetPlace());
+    const auto* x_data = x->data<T>();
 
-    float dropout_prob = context.Attr<float>("dropout_prob");
-    int seed = context.Attr<int>("seed");
+    AttrType dropout_prob = context.Attr<AttrType>("dropout_prob");
 
-    std::minstd_rand engine;
-    engine.seed(seed);
-    std::uniform_real_distribution<float> dist(0, 1);
-    size_t size = framework::product(mask->dims());
-    for (size_t i = 0; i < size; ++i) {
-      if (dist(engine) < dropout_prob) {
-        mask_data[i] = 0;
-        y_data[i] = 0;
-      } else {
-        mask_data[i] = 1;
-        y_data[i] = x_data[i];
+    if (context.Attr<int>("is_training") == 1) {
+      int seed = context.Attr<int>("seed");
+      std::minstd_rand engine;
+      engine.seed(seed);
+      std::uniform_real_distribution<AttrType> dist(0, 1);
+      size_t size = framework::product(mask->dims());
+      for (size_t i = 0; i < size; ++i) {
+        if (dist(engine) < dropout_prob) {
+          mask_data[i] = 0;
+          y_data[i] = 0;
+        } else {
+          mask_data[i] = 1;
+          y_data[i] = x_data[i];
+        }
       }
+    } else {
+      size_t size = framework::product(mask->dims());
+      memset(mask_data, 0, sizeof(T) * size);
+      auto X = EigenMatrix<T>::Reshape(*x, 1);
+      auto Y = EigenMatrix<T>::Reshape(*y, 1);
+      auto place = context.GetEigenDevice<Place>();
+      Y.device(place) = X * dropout_prob;
     }
-    // TODO: add test phase logits.
   }
 };
@@ -60,21 +68,19 @@ template <typename Place, typename T>
 class DropoutGradKernel : public framework::OpKernel {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
+    PADDLE_ENFORCE_EQ(context.Attr<int>("is_training"), 1,
+                      "Only callable when is_training is true");
     auto* grad_x = context.Output<Tensor>(framework::GradVarName("X"));
     auto* grad_y = context.Input<Tensor>(framework::GradVarName("Out"));
     auto* mask = context.Input<Tensor>("Mask");
     grad_x->mutable_data<T>(context.GetPlace());
 
-    auto dims = grad_x->dims();
-    int size = static_cast<int>(framework::product(dims));
-    auto new_dims = framework::make_ddim({dims[0], size / dims[0]});
-    auto M = EigenMatrix<T>::From(*mask, new_dims);
-    auto dX = EigenMatrix<T>::From(*grad_x, new_dims);
-    auto dY = EigenMatrix<T>::From(*grad_y, new_dims);
+    auto M = EigenMatrix<T>::Reshape(*mask, 1);
+    auto dX = EigenMatrix<T>::Reshape(*grad_x, 1);
+    auto dY = EigenMatrix<T>::Reshape(*grad_y, 1);
 
     auto place = context.GetEigenDevice<Place>();
     dX.device(place) = dY * M;
-    // TODO: add test time logits.
   }
 };
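The backward rule implemented above is purely elementwise, dX = dY * M: units that were dropped pass no gradient through. The inference branch samples no mask (it only rescales), which is why the grad kernel now enforces is_training == 1. A minimal CPU reference (illustrative names, no Paddle types):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Reference dropout backward pass: grad_x[i] = grad_y[i] * mask[i].
    std::vector<float> dropout_grad_ref(const std::vector<float>& grad_y,
                                        const std::vector<float>& mask) {
      assert(grad_y.size() == mask.size());
      std::vector<float> grad_x(grad_y.size());
      for (std::size_t i = 0; i < grad_y.size(); ++i) {
        grad_x[i] = grad_y[i] * mask[i];  // dropped units get zero gradient
      }
      return grad_x;
    }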
diff --git a/python/paddle/v2/framework/tests/test_dropout_op.py b/python/paddle/v2/framework/tests/test_dropout_op.py
index 1387b87dc7b7857fbd47b5139a98a6f1754589d4..d49952492947f53cd5a416f6999640023cbdf44e 100644
--- a/python/paddle/v2/framework/tests/test_dropout_op.py
+++ b/python/paddle/v2/framework/tests/test_dropout_op.py
@@ -7,7 +7,7 @@ class TestDropoutOp(OpTest):
     def setUp(self):
         self.op_type = "dropout"
         self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
-        self.attrs = {'dropout_prob': 0.0}
+        self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
         self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64))}
 
     def test_check_output(self):
@@ -21,7 +21,7 @@ class TestDropoutOp2(TestDropoutOp):
     def setUp(self):
         self.op_type = "dropout"
         self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
-        self.attrs = {'dropout_prob': 1.0}
+        self.attrs = {'dropout_prob': 1.0, 'is_training': 1}
         self.outputs = {'Out': np.zeros((32, 64)), 'Mask': np.zeros((32, 64))}
 
 
@@ -29,9 +29,37 @@ class TestDropoutOp3(TestDropoutOp):
     def setUp(self):
         self.op_type = "dropout"
         self.inputs = {'X': np.random.random((32, 64, 2)).astype("float32")}
-        self.attrs = {'dropout_prob': 0.0}
+        self.attrs = {'dropout_prob': 0.0, 'is_training': 1}
         self.outputs = {'Out': self.inputs['X'], 'Mask': np.ones((32, 64, 2))}
 
 
+class TestDropoutOp4(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.35, 'is_training': 0}
+        self.outputs = {
+            'Out': self.inputs['X'] * self.attrs['dropout_prob'],
+            'Mask': np.zeros((32, 64))
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestDropoutOp5(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {'X': np.random.random((32, 64, 3)).astype("float32")}
+        self.attrs = {'dropout_prob': 0.75, 'is_training': 0}
+        self.outputs = {
+            'Out': self.inputs['X'] * self.attrs['dropout_prob'],
+            'Mask': np.zeros((32, 64, 3))
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+
 if __name__ == '__main__':
     unittest.main()