未验证 提交 fd8d83e6 编写于 作者: C chengduo 提交者: GitHub

Fix the nested dyn_rnn (#13417)

* add unit test for nested drnn

* add nested dyn_rnn

* refine while_op

* fix bug
上级 cf128231
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/fluid/framework/executor.h"
......@@ -138,6 +138,10 @@ class WhileGradOp : public framework::OperatorBase {
auto inside_og_name = inside_og_names[i];
VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
if (scope.FindVar(outside_og_name) == nullptr) {
continue;
}
auto &og_outside =
detail::Ref(scope.FindVar(outside_og_name),
"Cannot find Outside Gradient %s", outside_og_name);
......@@ -167,20 +171,46 @@ class WhileGradOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(inside_array[j].numel(), 0);
}
}
} else {
PADDLE_THROW("Currently only support LoDTensor and LoDTensorArray.");
}
}
executor.RunPreparedContext(ctx.get(), *cur_scope_iter, false, true,
true);
auto &pg_names = Outputs(kXGRAD);
// The Outputs(kXGRAD) contains the names of the gradient of parameters
// and inputs.
auto &pg_ig_names = Outputs(kXGRAD);
auto &p_names = Inputs(kX);
PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
if (pg_names[param_id] == framework::kEmptyVarName) {
PADDLE_ENFORCE_EQ(pg_ig_names.size(), p_names.size());
for (size_t param_id = 0; param_id < pg_ig_names.size(); ++param_id) {
if (pg_ig_names[param_id] == framework::kEmptyVarName) {
continue; // parameter doesn't have gradient
}
auto inside_grad_name = framework::GradVarName(p_names[param_id]);
// for some grad_op, their input doesn't have gradient,
// for example lookup_table_grad_op, the input(Idx) doesn't have
// gradient.
auto pg_ig_var = cur_scope.FindVar(inside_grad_name);
PADDLE_ENFORCE(pg_ig_var != nullptr);
if (pg_ig_var->IsType<framework::LoDTensorArray>()) {
auto pg_ig_lod_t_arr =
pg_ig_var->GetMutable<framework::LoDTensorArray>();
bool empty = true;
for (auto &each : *pg_ig_lod_t_arr) {
if (each.numel() != 0) {
empty = false;
break;
}
}
if (empty) {
LOG(WARNING) << pg_ig_names[param_id]
<< " is not found in cur_scope.";
continue;
}
}
// // TODO(tonyyang-svail): Not sure we need the following
// // If does not compute gradient of that variable inside rnn,
// just
......@@ -194,6 +224,11 @@ class WhileGradOp : public framework::OperatorBase {
if (cur_scope_iter == step_scopes->rbegin()) {
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
PADDLE_ENFORCE_NOT_NULL(var, "Can not find var %s", inside_grad_name);
PADDLE_ENFORCE(var->IsType<framework::LoDTensorArray>() ||
var->IsType<LoDTensor>(),
"Currently the type of var only can be LoDTensorArray "
"or LoDTensor.");
if (var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
......@@ -201,7 +236,7 @@ class WhileGradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
auto var_name = pg_names[param_id];
auto var_name = pg_ig_names[param_id];
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{},
{{"Out", {var_name}}}, attrs);
......@@ -213,8 +248,8 @@ class WhileGradOp : public framework::OperatorBase {
}
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_names[param_id], new_inside_name}}},
{{"Out", {pg_names[param_id]}}},
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
{{"Out", {pg_ig_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
......@@ -281,6 +316,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
parent_block->FindVarRecursive(input_name) != nullptr)) {
continue;
}
output_grads.insert(input_name);
}
for (auto &output_name : op->OutputArgumentNames()) {
......@@ -309,13 +345,13 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
auto p_names = op_desc.Input(kX);
auto pg_names = op_desc.Output(framework::GradVarName(kX));
auto pg_ig_names = op_desc.Output(framework::GradVarName(kX));
for (size_t i = 0; i < p_names.size(); ++i) {
auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
auto *g_var = block->FindVarRecursive(pg_names[i]);
auto *g_var = block->FindVarRecursive(pg_ig_names[i]);
if (g_var != nullptr) { // Gradient could be @EMPTY@
VLOG(5) << "Setting " << pg_names[i] << " following " << p_names[i]
VLOG(5) << "Setting " << pg_ig_names[i] << " following " << p_names[i]
<< " type: " << p_var.GetType();
g_var->SetType(p_var.GetType());
g_var->SetDataType(p_var.GetDataType());
......@@ -333,21 +369,21 @@ class WhileGradOpShapeInference : public framework::InferShapeBase {
ctx->HasInputs(framework::GradVarName(kOutputs));
auto p_names = ctx->Inputs(kX);
auto pg_names = ctx->Outputs(kXGRAD);
auto pg_ig_names = ctx->Outputs(kXGRAD);
auto var_types = ctx->GetInputsVarType(kX);
std::vector<std::string> names_to_set;
std::vector<framework::DDim> dims_to_set;
for (size_t i = 0; i < p_names.size(); ++i) {
if (pg_names[i] == framework::kEmptyVarName) {
if (pg_ig_names[i] == framework::kEmptyVarName) {
continue;
}
auto dims = ctx->GetInputsElementDim(kX, i);
if (var_types[i] == framework::proto::VarType::LOD_TENSOR) {
names_to_set.push_back(pg_names[i]);
names_to_set.push_back(pg_ig_names[i]);
dims_to_set.push_back(dims);
} else if (var_types[i] == framework::proto::VarType::LOD_TENSOR_ARRAY) {
// not sure how to set the dim of LOD_TENSOR_ARRAY
names_to_set.push_back(pg_names[i]);
names_to_set.push_back(pg_ig_names[i]);
dims_to_set.push_back(dims);
}
}
......
......@@ -144,6 +144,142 @@ class TestDynRNN(unittest.TestCase):
# loss should be small after 100 mini-batch
self.assertLess(val[0], loss_0[0])
# this unit test is just used to the two layer nested dyn_rnn.
def test_train_nested_dyn_rnn(self):
word_dict = [i for i in range(30)]
def fake_reader():
seq_len, label = [[2, 2]], [0, 1]
data = []
for ele in seq_len:
for j in ele:
data.append([numpy.random.randint(30) \
for _ in range(j)])
while True:
yield data, label
train_data = paddle.batch(fake_reader, batch_size=2)
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
sentence = fluid.layers.data(
name='word', shape=[1], dtype='int64', lod_level=2)
label = fluid.layers.data(
name='label', shape=[1], dtype='float32', lod_level=1)
rnn = fluid.layers.DynamicRNN()
with rnn.block():
in_ = rnn.step_input(sentence)
sent_emb = fluid.layers.embedding(
input=in_, size=[len(word_dict), 32], dtype='float32')
out_ = fluid.layers.fc(input=sent_emb, size=100, act='tanh')
rnn1 = fluid.layers.DynamicRNN()
with rnn1.block():
in_1 = rnn1.step_input(out_)
out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
rnn1.output(out_1)
last = fluid.layers.sequence_last_step(input=rnn1())
rnn.output(last)
last = rnn()
logits = fluid.layers.fc(input=last, size=1, act=None)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(1e-3)
#sgd = fluid.optimizer.Adam(1e-3)
sgd.minimize(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
data = next(train_data())
val = exe.run(main_program, feed=feeder.feed(data),
fetch_list=[loss])[0]
for _ in range(100):
val = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])[0]
print(val)
# this unit test is just used to the two layer nested dyn_rnn.
def test_train_nested_dyn_rnn2(self):
word_dict = [i for i in range(30)]
def fake_reader():
seq_len, label = [[2, 2]], [0, 1]
data = []
for ele in seq_len:
for j in ele:
data.append([numpy.random.randint(30) \
for _ in range(j)])
while True:
yield data, label
train_data = paddle.batch(fake_reader, batch_size=2)
hidden_size = 32
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
sentence = fluid.layers.data(
name='word', shape=[1], dtype='int64', lod_level=2)
label = fluid.layers.data(
name='label', shape=[1], dtype='float32', lod_level=1)
rnn = fluid.layers.DynamicRNN()
with rnn.block():
in_ = rnn.step_input(sentence)
sent_emb = fluid.layers.embedding(
input=in_,
size=[len(word_dict), hidden_size],
dtype='float32')
input_forward_proj = fluid.layers.fc(input=sent_emb,
size=hidden_size * 4,
act=None,
bias_attr=False)
forward, _ = fluid.layers.dynamic_lstm(
input=input_forward_proj,
size=hidden_size * 4,
use_peepholes=False)
rnn1 = fluid.layers.DynamicRNN()
with rnn1.block():
in_1 = rnn1.step_input(forward)
out_1 = fluid.layers.fc(input=[in_1], size=100, act='tanh')
rnn1.output(out_1)
last = fluid.layers.sequence_last_step(input=rnn1())
rnn.output(last)
last = rnn()
logits = fluid.layers.fc(input=last, size=1, act=None)
loss = fluid.layers.sigmoid_cross_entropy_with_logits(
x=logits, label=label)
loss = fluid.layers.mean(loss)
sgd = fluid.optimizer.SGD(1e-3)
#sgd = fluid.optimizer.Adam(1e-3)
sgd.minimize(loss=loss)
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(startup_program)
feeder = fluid.DataFeeder(feed_list=[sentence, label], place=cpu)
data = next(train_data())
val = exe.run(main_program, feed=feeder.feed(data),
fetch_list=[loss])[0]
for _ in range(100):
val = exe.run(main_program,
feed=feeder.feed(data),
fetch_list=[loss])[0]
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册