From fd8d83e68a20d056340b7de35f19538c5831d492 Mon Sep 17 00:00:00 2001
From: chengduo
Date: Wed, 19 Sep 2018 13:33:40 +0800
Subject: [PATCH] Fix the nested dyn_rnn (#13417)

* add unit test for nested drnn
* add nested dyn_rnn
* refine while_op
* fix bug
---
 paddle/fluid/operators/while_op.cc            |  90 ++++++++----
 .../fluid/tests/unittests/test_dyn_rnn.py     | 136 ++++++++++++++++++
 2 files changed, 199 insertions(+), 27 deletions(-)

diff --git a/paddle/fluid/operators/while_op.cc b/paddle/fluid/operators/while_op.cc
index 791138a8c0..16eac1ec24 100644
--- a/paddle/fluid/operators/while_op.cc
+++ b/paddle/fluid/operators/while_op.cc
@@ -1,16 +1,16 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
+// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
 
 #include <vector>
 #include "paddle/fluid/framework/executor.h"
@@ -138,6 +138,10 @@ class WhileGradOp : public framework::OperatorBase {
       auto inside_og_name = inside_og_names[i];
       VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
               << inside_og_name;
+      if (scope.FindVar(outside_og_name) == nullptr) {
+        continue;
+      }
+
       auto &og_outside =
           detail::Ref(scope.FindVar(outside_og_name),
                       "Cannot find Outside Gradient %s", outside_og_name);
@@ -167,20 +171,46 @@ class WhileGradOp : public framework::OperatorBase {
            PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
          }
        }
+      } else {
+        PADDLE_THROW("Currently only support LoDTensor and LoDTensorArray.");
      }
    }
    executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true, true);
 
-    auto &pg_names = Outputs(kXGRAD);
+    // The Outputs(kXGRAD) contains the names of the gradients of parameters
+    // and inputs.
+    auto &pg_ig_names = Outputs(kXGRAD);
     auto &p_names = Inputs(kX);
-    PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
-    for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
-      if (pg_names[param_id] == framework::kEmptyVarName) {
+    PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size());
+    for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
+      if (pg_ig_names[param_id] == framework::kEmptyVarName) {
        continue;  // parameter doesn't have gradient
      }
      auto inside_grad_name = framework::GradVarName(p_names[param_id]);
+      // For some grad_ops an input may have no gradient; for example,
+      // the input (Idx) of lookup_table_grad_op does not have a
+      // gradient.
+      auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
+      PADDLE_ENFORCE(pg_ig_var != nullptr);
+      if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
+        auto pg_ig_lod_t_arr =
+            pg_ig_var->GetMutable<framework::LoDTensorArray>();
+        bool empty = true;
+        for (auto &each : *pg_ig_lod_t_arr) {
+          if (each.numel() != 0) {
+            empty = false;
+            break;
+          }
+        }
+        if (empty) {
+          LOG(WARNING) << pg_ig_names[param_id]
+                       << " is not found in cur_scope.";
+          continue;
+        }
+      }
+
      //  // TODO(tonyyang-svail): Not sure we need the following
      //  // If does not compute gradient of that variable inside rnn,
      //  just
@@ -194,6 +224,11 @@ class WhileGradOp : public framework::OperatorBase {
      if (cur_scope_iter == step_scopes->rbegin()) {
        auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
        PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
+        PADDLE_ENFORCE(var->IsType<framework::LoDTensorArray>() ||
+                           var->IsType<framework::LoDTensor>(),
+                       "Currently the type of var only can be LoDTensorArray "
+                       "or LoDTensor.");
+
        if (var->IsType<framework::LoDTensor>()) {
          auto &inside_tensor = var->Get<framework::LoDTensor>();
          framework::AttributeMap attrs;
@@ -201,7 +236,7 @@ class WhileGradOp : public framework::OperatorBase {
          attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
          attrs["value"] = 0.0f;
 
-          auto var_name = pg_names[param_id];
+          auto var_name = pg_ig_names[param_id];
          auto zero_op = framework::OpRegistry::CreateOp(
              "fill_constant", framework::VariableNameMap{},
              {{"Out", {var_name}}}, attrs);
@@ -213,8 +248,8 @@ class WhileGradOp : public framework::OperatorBase {
      }
      auto new_inside_name = cur_scope.Rename(inside_grad_name);
      auto sum_op = framework::OpRegistry::CreateOp(
-          "sum", {{"X", {pg_names[param_id], new_inside_name}}},
-          {{"Out", {pg_names[param_id]}}},
+          "sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
+          {{"Out", {pg_ig_names[param_id]}}},
          framework::AttributeMap{{"use_mkldnn", {false}}});
      sum_op->Run(cur_scope, dev_place);
      cur_scope.Rename(new_inside_name, inside_grad_name);
@@ -281,6 +316,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
            parent_block->FindVarRecursive(input_name) != nullptr)) {
          continue;
        }
+
        output_grads.insert(input_name);
      }
      for (auto &output_name : op->OutputArgumentNames()) {
@@ -309,13 +345,13 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     auto p_names = op_desc.Input(kX);
-    auto pg_names = op_desc.Output(framework::GradVarName(kX));
+    auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
 
     for (size_t i = 0; i < p_names.size(); ++i) {
       auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
-      auto *g_var = block->FindVarRecursive(pg_names[i]);
+      auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
       if (g_var != nullptr) {  // Gradient could be @EMPTY@
-        VLOG(5) << "Setting " << pg_names[i] << " following " << p_names[i]
+        VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
                 << " type: " << p_var.GetType();
         g_var->SetType(p_var.GetType());
         g_var->SetDataType(p_var.GetDataType());
@@ -333,21 +369,21 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
     ctx->HasInputs(framework::GradVarName(kOutputs));
 
     auto p_names = ctx->Inputs(kX);
-    auto pg_names = ctx->Outputs(kXGRAD);
+    auto pg_ig_names = ctx->Outputs(kXGRAD);
     auto var_types = ctx->GetInputsVarType(kX);
     std::vector<std::string> names_to_set;
     std::vector<framework::DDim> dims_to_set;
     for (size_t i = 0; i < p_names.size(); ++i) {
-      if (pg_names[i] == framework::kEmptyVarName) {
+      if (pg_ig_names[i] == framework::kEmptyVarName) {
        continue;
      }
      auto dims =
          ctx->GetInputsElementDim(kX, i);
      if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
-        names_to_set.push_back(pg_names[i]);
+        names_to_set.push_back(pg_ig_names[i]);
        dims_to_set.push_back(dims);
      } else if (var_types[i] ==
                 framework::proto::VarType::LOD_TENSOR_ARRAY) {
        // not sure how to set the dim of LOD_TENSOR_ARRAY
-        names_to_set.push_back(pg_names[i]);
+        names_to_set.push_back(pg_ig_names[i]);
        dims_to_set.push_back(dims);
      }
    }
diff --git a/python/paddle/fluid/tests/unittests/test_dyn_rnn.py b/python/paddle/fluid/tests/unittests/test_dyn_rnn.py
index d84dab1499..3191eb94d7 100644
--- a/python/paddle/fluid/tests/unittests/test_dyn_rnn.py
+++ b/python/paddle/fluid/tests/unittests/test_dyn_rnn.py
@@ -144,6 +144,142 @@ class TestDynRNN(unittest.TestCase):
         # loss should be small after 100 mini-batch
         self.assertLess(val[0], loss_0[0])
 
+    # this unit test is just used to test the two-layer nested dyn_rnn.
+    def test_train_nested_dyn_rnn(self):
+        word_dict = [i for i in range(30)]
+
+        def fake_reader():
+            seq_len, label = [[2, 2]], [0, 1]
+            data = []
+            for ele in seq_len:
+                for j in ele:
+                    data.append([numpy.random.randint(30) \
+                        for _ in range(j)])
+
+            while True:
+                yield data, label
+
+        train_data = paddle.batch(fake_reader, batch_size=2)
+
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            sentence = fluid.layers.data(
+                name='word', shape=[1], dtype='int64', lod_level=2)
+            label = fluid.layers.data(
+                name='label', shape=[1], dtype='float32', lod_level=1)
+
+            rnn = fluid.layers.DynamicRNN()
+            with rnn.block():
+                in_ = rnn.step_input(sentence)
+                sent_emb = fluid.layers.embedding(
+                    input=in_, size=[len(word_dict), 32], dtype='float32')
+                out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')
+
+                rnn1 = fluid.layers.DynamicRNN()
+                with rnn1.block():
+                    in_1 = rnn1.step_input(out_)
+                    out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
+                    rnn1.output(out_1)
+
+                last = fluid.layers.sequence_last_step(input=rnn1())
+                rnn.output(last)
+
+            last = rnn()
+            logits = fluid.layers.fc(input=last, size=1, act=None)
+            loss = fluid.layers.sigmoid_cross_entropy_with_logits(
+                x=logits, label=label)
+            loss = fluid.layers.mean(loss)
+            sgd = fluid.optimizer.SGD(1e-3)
+            #sgd = fluid.optimizer.Adam(1e-3)
+            sgd.minimize(loss=loss)
+
+        cpu = fluid.CPUPlace()
+        exe = fluid.Executor(cpu)
+        exe.run(startup_program)
+        feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
+        data = next(train_data())
+        val = exe.run(main_program, feed=feeder.feed(data),
+                      fetch_list=[loss])[0]
+
+        for _ in range(100):
+            val = exe.run(main_program,
+                          feed=feeder.feed(data),
+                          fetch_list=[loss])[0]
+        print(val)
+
+    # this unit test is just used to test the two-layer nested dyn_rnn.
+    def test_train_nested_dyn_rnn2(self):
+        word_dict = [i for i in range(30)]
+
+        def fake_reader():
+            seq_len, label = [[2, 2]], [0, 1]
+            data = []
+            for ele in seq_len:
+                for j in ele:
+                    data.append([numpy.random.randint(30) \
+                        for _ in range(j)])
+
+            while True:
+                yield data, label
+
+        train_data = paddle.batch(fake_reader, batch_size=2)
+        hidden_size = 32
+        main_program = fluid.Program()
+        startup_program = fluid.Program()
+        with fluid.program_guard(main_program, startup_program):
+            sentence = fluid.layers.data(
+                name='word', shape=[1], dtype='int64', lod_level=2)
+            label = fluid.layers.data(
+                name='label', shape=[1], dtype='float32', lod_level=1)
+
+            rnn = fluid.layers.DynamicRNN()
+            with rnn.block():
+                in_ = rnn.step_input(sentence)
+                sent_emb = fluid.layers.embedding(
+                    input=in_,
+                    size=[len(word_dict), hidden_size],
+                    dtype='float32')
+                input_forward_proj = fluid.layers.fc(input=sent_emb,
+                                                     size=hidden_size * 4,
+                                                     act=None,
+                                                     bias_attr=False)
+                forward, _ = fluid.layers.dynamic_lstm(
+                    input=input_forward_proj,
+                    size=hidden_size * 4,
+                    use_peepholes=False)
+
+                rnn1 = fluid.layers.DynamicRNN()
+                with rnn1.block():
+                    in_1 = rnn1.step_input(forward)
+                    out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
+                    rnn1.output(out_1)
+
+                last = fluid.layers.sequence_last_step(input=rnn1())
+                rnn.output(last)
+
+            last = rnn()
+            logits = fluid.layers.fc(input=last, size=1, act=None)
+            loss = fluid.layers.sigmoid_cross_entropy_with_logits(
+                x=logits, label=label)
+            loss = fluid.layers.mean(loss)
+            sgd = fluid.optimizer.SGD(1e-3)
+            #sgd = fluid.optimizer.Adam(1e-3)
+            sgd.minimize(loss=loss)
+
+        cpu = fluid.CPUPlace()
+        exe = fluid.Executor(cpu)
+        exe.run(startup_program)
+        feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
+        data = next(train_data())
+        val = exe.run(main_program, feed=feeder.feed(data),
+                      fetch_list=[loss])[0]
+
+        for _ in range(100):
+            val = exe.run(main_program,
+                          feed=feeder.feed(data),
+                          fetch_list=[loss])[0]
+
 
 if __name__ == '__main__':
     unittest.main()
--
GitLab
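
For reference, here is a minimal sketch (not part of the patch) of the two-level nested DynamicRNN pattern that the new unit tests exercise and that the while_op gradient changes are meant to support. It only uses the Fluid 1.x layer APIs already shown in the tests above; the helper name build_nested_rnn and its default sizes are illustrative, not taken from the patch.

# Minimal sketch of a two-level nested DynamicRNN, assuming the Fluid 1.x
# layers API used in the tests above. Build it inside a fluid.program_guard
# block, then attach a loss and optimizer as the tests do.
import paddle.fluid as fluid


def build_nested_rnn(word_dict_size, emb_size=32, hidden_size=100):
    # 'sentence' carries a 2-level LoD: sequences of sentences, each of words.
    sentence = fluid.layers.data(
        name='word', shape=[1], dtype='int64', lod_level=2)

    rnn = fluid.layers.DynamicRNN()
    with rnn.block():
        # Each outer step receives one sentence (a lod_level=1 word sequence).
        sent = rnn.step_input(sentence)
        emb = fluid.layers.embedding(
            input=sent, size=[word_dict_size, emb_size], dtype='float32')
        hidden = fluid.layers.fc(input=emb, size=hidden_size, act='tanh')

        # The inner DynamicRNN iterates over that sentence word by word;
        # each inner step receives the hidden vector of one word.
        inner = fluid.layers.DynamicRNN()
        with inner.block():
            step = inner.step_input(hidden)
            inner.output(
                fluid.layers.fc(input=[step], size=hidden_size, act='tanh'))

        # Keep only the last inner step as this outer step's output.
        rnn.output(fluid.layers.sequence_last_step(input=inner()))

    return rnn()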