diff --git a/paddle/fluid/operators/dropout_op.cc b/paddle/fluid/operators/dropout_op.cc
index 597d446b00a61f3b9fc60535d98efa900d860074..c494d27d41b520be187081e5e44410ea14241df5 100644
--- a/paddle/fluid/operators/dropout_op.cc
+++ b/paddle/fluid/operators/dropout_op.cc
@@ -35,12 +35,23 @@ class DropoutOp : public framework::OperatorWithKernel {
     }
     ctx->ShareLoD("X", /*->*/ "Out");
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(
+        OperatorWithKernel::IndicateVarDataType(ctx, "X"), ctx.GetPlace());
+  }
 };
 
 class DropoutOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   void Make() override {
     AddInput("X", "The input of dropout op.");
+    AddInput("Seed",
+             "The seed of dropout op, it has higher priority than the attr "
+             "fix_seed and seed")
+        .AsDispensable();
     AddOutput("Out", "The output of dropout op.");
     AddOutput("Mask", "The random sampled dropout mask.").AsIntermediate();
diff --git a/paddle/fluid/operators/dropout_op.cu b/paddle/fluid/operators/dropout_op.cu
index 3e0cb76d0435c0b2bc22ec68eb08581183da0672..d9b5572d95c5f913ad35b31cf1b6da31c81beaf8 100644
--- a/paddle/fluid/operators/dropout_op.cu
+++ b/paddle/fluid/operators/dropout_op.cu
@@ -67,6 +67,8 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
+    auto* seed =
+        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
     auto* y = context.Output<Tensor>("Out");
     y->mutable_data<T>(context.GetPlace());
     float dropout_prob = context.Attr<float>("dropout_prob");
@@ -84,6 +86,20 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
       auto* mask_data = mask->mutable_data<uint8_t>(context.GetPlace());
       size_t size = framework::product(mask->dims());
       auto* x_data = x->data<T>();
+      int seed_data;
+      std::random_device rnd;
+      if (seed) {
+        if (platform::is_gpu_place(seed->place())) {
+          framework::Tensor temp;
+          TensorCopySync(*seed, platform::CPUPlace(), &temp);
+          seed_data = *(temp.data<int>());
+        } else {
+          seed_data = *(seed->data<int>());
+        }
+      } else {
+        seed_data =
+            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
+      }
       auto* y_data = y->mutable_data<T>(context.GetPlace());
       if (dropout_prob == 1.0f) {
         PADDLE_ENFORCE_CUDA_SUCCESS(
@@ -93,14 +109,10 @@ class GPUDropoutKernel : public framework::OpKernel<T> {
         return;
       }
 
-      std::random_device rnd;
-      int seed =
-          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
-
       int threads = 512;
       int grid = (x_numel + threads - 1) / threads;
       RandomGenerator<T><<<grid, threads, 0, stream>>>(
-          size, seed, dropout_prob, x_data, mask_data, y_data,
+          size, seed_data, dropout_prob, x_data, mask_data, y_data,
           upscale_in_train);
     } else {
       auto X = EigenMatrix<T>::Reshape(*x, 1);
diff --git a/paddle/fluid/operators/dropout_op.h b/paddle/fluid/operators/dropout_op.h
index 20742f9a453c5ad3c3702fd939e28312263323f5..b2bfbc1f82619939bfdee188adf4a000610acf14 100644
--- a/paddle/fluid/operators/dropout_op.h
+++ b/paddle/fluid/operators/dropout_op.h
@@ -33,6 +33,8 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto* x = context.Input<Tensor>("X");
+    auto* seed =
+        context.HasInput("Seed") ? context.Input<Tensor>("Seed") : nullptr;
     auto* y = context.Output<Tensor>("Out");
     const auto* x_data = x->data<T>();
     auto* y_data = y->mutable_data<T>(context.GetPlace());
@@ -57,9 +59,14 @@ class CPUDropoutKernel : public framework::OpKernel<T> {
       // Guarantee to use random seed in training.
       std::random_device rnd;
       std::minstd_rand engine;
-      int seed =
-          context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
-      engine.seed(seed);
+      int seed_data;
+      if (seed) {
+        seed_data = *(seed->data<int>());
+      } else {
+        seed_data =
+            context.Attr<bool>("fix_seed") ? context.Attr<int>("seed") : rnd();
+      }
+      engine.seed(seed_data);
 
       std::uniform_real_distribution<float> dist(0, 1);
 
diff --git a/paddle/fluid/operators/seed_op.cc b/paddle/fluid/operators/seed_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..86c551f4c7426bbbeaef157d6d7c66094ecf11f2
--- /dev/null
+++ b/paddle/fluid/operators/seed_op.cc
@@ -0,0 +1,57 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/operators/seed_op.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+class SeedOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    ctx->SetOutputDim("Out", {1});
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(framework::proto::VarType::INT32,
+                                   platform::CPUPlace());
+  }
+};
+
+class SeedOpMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddOutput("Out", "The output of seed op.");
+    AddAttr<int>("seed", "Dropout random seed.").SetDefault(0);
+    AddComment(R"DOC(
+Seed Operator.
+)DOC");
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+REGISTER_OPERATOR(
+    seed, ops::SeedOp, ops::SeedOpMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+REGISTER_OP_CPU_KERNEL(
+    seed, ops::CPUSeedKernel<paddle::platform::CPUDeviceContext, int>);
diff --git a/paddle/fluid/operators/seed_op.h b/paddle/fluid/operators/seed_op.h
new file mode 100644
index 0000000000000000000000000000000000000000..f8b513fca4824c3c8e242326f99e6c840520e7a3
--- /dev/null
+++ b/paddle/fluid/operators/seed_op.h
@@ -0,0 +1,44 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#pragma once
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class CPUSeedKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& context) const override {
+    auto* out = context.Output<Tensor>("Out");
+    auto* out_data = out->mutable_data<T>(context.GetPlace());
+    int user_seed = context.Attr<int>("seed");
+
+    // NOTE: fixed seed should only be used in unittest or for debug.
+    // Guarantee to use random seed in training.
+    std::random_device rnd;
+    int seed;
+    if (user_seed != 0) {
+      seed = user_seed;
+    } else {
+      seed = rnd();
+    }
+    out_data[0] = seed;
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py
index 0d2c7e31fbbe663ac6072ad947dc2a56d3835a97..5aceb13613c2d15d47bd62bfc185123af2ea38a9 100644
--- a/python/paddle/fluid/backward.py
+++ b/python/paddle/fluid/backward.py
@@ -56,7 +56,7 @@ class ProgramStats(object):
     def get_reserved_vars(self):
         var_name = []
         for op in self.ops:
-            if op.desc.type() == "dropout":
+            if op.desc.type() == "seed":
                 var_name.extend(op.desc.output_arg_names())
         return var_name
 
@@ -136,6 +136,42 @@ class ProgramStats(object):
         sorted_checkpoints = sorted(sorted_checkpoints, key=lambda x: x[1])
         return [x[0] for x in sorted_checkpoints]
 
+    def modify_forward_desc_for_recompute(self):
+        op_types = [op.desc.type() for op in self.ops]
+        if "dropout" not in op_types:
+            return
+
+        op_idx = 0
+        while (op_idx < len(self.ops)):
+            op = self.ops[op_idx]
+            if op.desc.type() != "dropout":
+                op_idx += 1
+                continue
+            # add a seed op so that the two dropout op can generate same output
+            op_unique_name = unique_name.generate("seed")
+            var_unique_name = unique_name.generate_with_ignorable_key(".".join(
+                [op_unique_name, 'tmp']))
+            added_var = self.block.create_var(
+                name=var_unique_name,
+                dtype='int32',
+                type=core.VarDesc.VarType.LOD_TENSOR,
+                persistable=False,
+                stop_gradient=False)
+            seed = 0 if op.attr("fix_seed") is False else int(op.attr("seed"))
+            added_op = self.block._insert_op(
+                index=op.idx,
+                type='seed',
+                inputs={},
+                outputs={'Out': [added_var]},
+                attrs={'seed': seed})
+            self.ops.insert(op_idx, added_op)
+            # modify dropout op desc so that it accept a seed var as input
+            op.desc.set_input("Seed", [var_unique_name])
+            op.desc.remove_attr("fix_seed")
+            op.desc.remove_attr("seed")
+            self.block._sync_with_cpp()
+            op_idx += 2
+
 
 def _pretty_op_desc_(op_desc, prefix):
     out_s = "%s\tname:[%s]\n%s \tinputs:[%s]\n%s \toutputs:[%s]" % \
@@ -589,6 +625,7 @@ def _append_backward_ops_with_checkpoints_(
         checkpoints: variables that a user defined as checkpoint for forward recomputation
     Algorithms:
+        0) deal with forward recomputing program descs
         1) find ops between checkpoints, i.e. recompute_segments
         2) go through all forward ops and induct all variables that will be hold in
            memory
           a. variables that are used across segments will be held in memory
@@ -609,10 +646,12 @@ def _append_backward_ops_with_checkpoints_(
     checkpoints_name = list(set(checkpoints_name))
     local_block = block.program._create_block()
     buffer_block = block.program._create_block()
-
-    # 1) find ops between checkpoints, i.e. recompute_segments
+    # 0) deal with forward recomputing program descs
     program_stat = ProgramStats(block, ops)
+    program_stat.modify_forward_desc_for_recompute()
     program_stat.build_stats()
+
+    # 1) find ops between checkpoints, i.e. recompute_segments
     checkpoints_name = program_stat.sort_checkpoints(checkpoints_name)
 
     segments = []
diff --git a/python/paddle/fluid/tests/unittests/test_dropout_op.py b/python/paddle/fluid/tests/unittests/test_dropout_op.py
index 6a71da5dad41cf4751f4504d2077dfc63982f6a9..7047e5a2d876ad97254cd3ba747c117d0de12ec1 100644
--- a/python/paddle/fluid/tests/unittests/test_dropout_op.py
+++ b/python/paddle/fluid/tests/unittests/test_dropout_op.py
@@ -150,6 +150,27 @@ class TestDropoutOp9(OpTest):
         self.check_output()
 
 
+class TestDropoutOpWithSeed(OpTest):
+    def setUp(self):
+        self.op_type = "dropout"
+        self.inputs = {
+            "X": np.random.random((32, 64)).astype("float32"),
+            "Seed": np.asarray(
+                [125], dtype="int32")
+        }
+        self.attrs = {'dropout_prob': 0.0, }
+        self.outputs = {
+            'Out': self.inputs['X'],
+            'Mask': np.ones((32, 64)).astype('uint8')
+        }
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X'], 'Out', max_relative_error=0.05)
+
+
 class TestFP16DropoutOp(OpTest):
     def setUp(self):
         self.op_type = "dropout"
diff --git a/python/paddle/fluid/tests/unittests/test_optimizer.py b/python/paddle/fluid/tests/unittests/test_optimizer.py
index a10a5e362287dd66a657c9a182cab3cb522e1b67..678b52c87568ad1a3a6f997fdb1f59e7d80506aa 100644
--- a/python/paddle/fluid/tests/unittests/test_optimizer.py
+++ b/python/paddle/fluid/tests/unittests/test_optimizer.py
@@ -614,7 +614,7 @@ class TestLookaheadOptimizer(unittest.TestCase):
 
 
 class TestRecomputeOptimizer(unittest.TestCase):
-    def net(self, return_input=False):
+    def net(self, return_input=False, with_dropout=False):
         program = framework.Program()
         block = program.global_block()
         mul_x = block.create_parameter(
@@ -623,6 +623,14 @@ class TestRecomputeOptimizer(unittest.TestCase):
             dtype="float32", shape=[10, 8], lod_level=0, name="mul.y")
         mul_out = block.create_var(
             dtype="float32", shape=[5, 8], lod_level=0, name="mul.out")
+        if with_dropout == True:
+            mul_out_drop = block.create_var(
+                dtype="float32",
+                shape=[5, 8],
+                lod_level=0,
+                name="mul.out.dropout")
+            mul_out_mask = block.create_var(
+                dtype="uint8", shape=[5, 8], lod_level=0, name="mul.out.mask")
         b1 = block.create_parameter(
             dtype="float32", shape=[5, 8], lod_level=0, name="b1")
         b1_out = block.create_var(
@@ -639,11 +647,24 @@ class TestRecomputeOptimizer(unittest.TestCase):
                     "Y": mul_y},
             outputs={"Out": mul_out},
             attrs={"x_num_col_dims": 1})
-        block.append_op(
-            type="elementwise_add",
-            inputs={"X": mul_out,
-                    "Y": b1},
-            outputs={"Out": b1_out})
+        if with_dropout == True:
+            block.append_op(
+                type='dropout',
+                inputs={'X': [mul_out]},
+                outputs={'Out': [mul_out_drop],
+                         'Mask': [mul_out_mask]},
+                attrs={'dropout_prob': 0.5, })
+            block.append_op(
+                type="elementwise_add",
+                inputs={"X": mul_out_drop,
+                        "Y": b1},
+                outputs={"Out": b1_out})
+        else:
+            block.append_op(
+                type="elementwise_add",
+                inputs={"X": mul_out,
+                        "Y": b1},
+                outputs={"Out": b1_out})
         block.append_op(
             type="elementwise_add",
             inputs={"X": b1_out,
@@ -799,6 +820,29 @@ class TestRecomputeOptimizer(unittest.TestCase):
             "load function is not supported by Recompute Optimizer for now",
             cpt.get_exception_message(e))
 
+    def test_dropout(self):
+        """
+        If there are dropout layers in the forward nets, we should add a
+        seed op
+        """
+        mul_out, b1_out, b2_out, mean_out = self.net(with_dropout=True)
+        self.assertEqual(len(mean_out.block.ops), 5)
+        self.assertEqual(
+            [op.type for op in mean_out.block.ops],
+            ["mul", "dropout", "elementwise_add", "elementwise_add", "mean"])
+        sgd_optimizer = optimizer.SGD(learning_rate=1.0)
+        recompute_optimizer = optimizer.RecomputeOptimizer(sgd_optimizer)
+        recompute_optimizer._set_checkpoints([b1_out])
+        opts, params_grads = recompute_optimizer.minimize(mean_out)
+
+        self.assertEqual(len(mean_out.block.ops), 17)
+        self.assertEqual([op.type for op in mean_out.block.ops], [
+            "mul", "seed", "dropout", "elementwise_add", "elementwise_add",
+            "mean", "fill_constant", "mean_grad", "elementwise_add_grad", "mul",
+            "dropout", "elementwise_add_grad", "dropout_grad", "mul_grad",
+            "sgd", "sgd", "sgd"
+        ])
+
 
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/test_seed_op.py b/python/paddle/fluid/tests/unittests/test_seed_op.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d6705f72569b60df0d4dc15f7c00556edaa9d1d
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/test_seed_op.py
@@ -0,0 +1,46 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle.fluid as fluid
+
+
+class TestSeedOpFixSeed(OpTest):
+    def setUp(self):
+        self.op_type = "seed"
+        self.inputs = {}
+        self.attrs = {"seed": 123}
+        self.outputs = {"Out": np.asarray((123)).astype('int32')}
+
+    def test_check_output(self):
+        self.check_output()
+
+
+class TestSeedOpDiffSeed(OpTest):
+    def setUp(self):
+        self.op_type = "seed"
+        self.inputs = {}
+        self.attrs = {"seed": 0}
+        self.outputs = {"Out": np.asarray((123)).astype('int32')}
+
+    def test_check_output(self):
+        self.check_output(no_check_set=["Out"])
+
+
+if __name__ == '__main__':
+    unittest.main()
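
Usage illustration (not part of the patch): the sketch below mirrors the new unit test through the fluid layers API. A forward net containing dropout is wrapped in RecomputeOptimizer; after minimize(), the rewritten program carries an explicit seed op in front of the dropout op, so the recomputed dropout in the backward section reuses the same seed variable and reproduces the same mask. Layer sizes, variable names, and the learning rate are assumptions made for this example, and the assert is only a rough sanity check.

# Illustrative sketch only; names and shapes are assumptions, not part of the patch.
import paddle.fluid as fluid

main_prog, startup_prog = fluid.Program(), fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name="x", shape=[None, 64], dtype="float32")
    hidden = fluid.layers.fc(input=x, size=32, act="relu")
    dropped = fluid.layers.dropout(hidden, dropout_prob=0.5)
    loss = fluid.layers.mean(fluid.layers.fc(input=dropped, size=1))

    sgd = fluid.optimizer.SGD(learning_rate=0.01)
    recompute = fluid.optimizer.RecomputeOptimizer(sgd)
    recompute._set_checkpoints([hidden])  # same call as in the unit test above
    recompute.minimize(loss)

op_types = [op.type for op in main_prog.global_block().ops]
# After the rewrite, a "seed" op precedes "dropout" in the forward pass, and the
# recomputed dropout in the backward section consumes the same seed variable.
assert "seed" in op_types and op_types.index("seed") < op_types.index("dropout")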