Unverified commit 69536892, authored by H hong, committed by GitHub

change staticRNN to while (#48213)

* change staticRNN to while

* update code

* fix rnn bug

* update

* fix _find_op_path_ bugs in append_backward.

* polish code

* revert op proto

* update

* update while

* format

* revert test while loop op

* fix create array

* fix windows error

* fix bug

* update

* fix array write bug
Co-authored-by: xiongkun <xiongkun03@baidu.com>
Parent c4f30c51
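For orientation before the diffs: the core of this change is rewriting `_rnn_static_graph` so the recurrence is built with a While loop over tensor arrays instead of `StaticRNN`. Below is a minimal, hypothetical sketch of that pattern using the public `paddle.static.nn.while_loop` API; the shapes, names, and fixed step count are illustrative assumptions, while the real rewrite (further down) uses the lower-level `control_flow.While` plus sequence-length masking.

```python
# Hypothetical sketch only: unroll a SimpleRNNCell over time with
# paddle.static.nn.while_loop instead of StaticRNN. Shapes and names are
# illustrative, not taken from the patch.
import numpy as np
import paddle

paddle.enable_static()

main = paddle.static.Program()
startup = paddle.static.Program()
with paddle.static.program_guard(main, startup):
    # time-major input: [time, batch, input_size]
    x = paddle.static.data("x", [12, 4, 16], dtype="float32")
    cell = paddle.nn.SimpleRNNCell(16, 32)

    i = paddle.zeros([1], dtype="int64")        # current time step
    n = paddle.full([1], 12, dtype="int64")     # max_seq_len (static here, dynamic in the patch)
    h = paddle.zeros([4, 32], dtype="float32")  # initial hidden state

    def cond(i, h):
        return i < n

    def body(i, h):
        step_in = x[i]           # slice one time step: [batch, input_size]
        _, h = cell(step_in, h)  # run the cell; its output equals the new state
        i = paddle.increment(i)
        return i, h

    _, last_h = paddle.static.nn.while_loop(cond, body, [i, h])

exe = paddle.static.Executor(paddle.CPUPlace())
exe.run(startup)
(last_h_np,) = exe.run(
    main,
    feed={"x": np.random.randn(12, 4, 16).astype("float32")},
    fetch_list=[last_h],
)
print(last_h_np.shape)  # (4, 32)
```

The sketch carries only `(i, h)` as loop variables; the actual rewrite additionally writes every step output into a LoDTensorArray, stacks it with `tensor_array_to_tensor`, and applies the sequence-length mask shown in the rnn.py diff below.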
@@ -84,3 +84,4 @@ paddle/fluid/pybind/tmp_eager_op_function_impl.h
 paddle/fluid/pybind/eager_op_function_impl.h
 paddle/fluid/pybind/eager_op_function_impl.h
 paddle/fluid/pybind/op_function_impl.h
+paddle/fluid/pybind/*final_state_op_function_impl.h
@@ -305,6 +305,7 @@ class WhileOp : public framework::OperatorBase {
         cond_data = GetCondData(
             scope.FindVar(Input(kCondition))->Get<phi::DenseTensor>());
       }
       scope.DeleteScope(&current_scope);
     }
   }
@@ -367,6 +368,7 @@ class WhileGradOp : public framework::OperatorBase {
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
+    auto *parent_block = block->ParentBlock();
     auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
     VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
@@ -428,6 +430,35 @@ class WhileGradOp : public framework::OperatorBase {
           continue;
         }
+        if (cur_scope_iter == step_scopes->rbegin()) {
+          auto &og_outside = *scope.FindVar(outside_og_name);
+          if (og_outside.IsType<phi::DenseTensor>() &&
+              !og_outside.GetMutable<phi::DenseTensor>()->IsInitialized()) {
+            auto *var_desc = parent_block->FindVarRecursive(outside_og_name);
+            PADDLE_ENFORCE_NOT_NULL(var_desc,
+                                    platform::errors::PreconditionNotMet(
+                                        "Var `%s` is not found in parent "
+                                        "block, can't fill constant.",
+                                        outside_og_name));
+            auto shape = var_desc->GetShape();
+            VLOG(8) << "Found uninitialized tensor " << outside_og_name
+                    << " in step 0, fill it with 0.0f. dims="
+                    << phi::make_ddim(shape);
+            framework::AttributeMap attrs;
+            attrs["dtype"] = var_desc->GetDataType();
+            attrs["shape"] = phi::vectorize<int>(phi::make_ddim(shape));
+            attrs["value"] = 0.0f;
+            auto var_name = outside_og_name;
+            auto zero_op =
+                framework::OpRegistry::CreateOp("fill_constant",
+                                                framework::VariableNameMap{},
+                                                {{"Out", {var_name}}},
+                                                attrs);
+            zero_op->Run(scope, dev_place);
+          }
+        }
         auto &og_outside = *scope.FindVar(outside_og_name);
         auto &og_inside = *cur_scope.Var(inside_og_name);
         if (og_outside.IsType<phi::DenseTensor>()) {
@@ -534,9 +565,10 @@ class WhileGradOp : public framework::OperatorBase {
         // continue;
         // }
-        auto var_iter = std::find(outside_og_names.begin(),
-                                  outside_og_names.end(),
-                                  pg_ig_names[param_id]);
+        auto is_var_input_and_output =
+            std::find(outside_og_names.begin(),
+                      outside_og_names.end(),
+                      pg_ig_names[param_id]) != outside_og_names.end();
         // zero gradient variable in step 0
         if (cur_scope_iter == step_scopes->rbegin()) {
@@ -555,8 +587,7 @@ class WhileGradOp : public framework::OperatorBase {
                   inside_grad_name,
                   framework::ToTypeName(var->Type())));
-          if ((var_iter == outside_og_names.end()) &&
-              var->IsType<phi::DenseTensor>()) {
+          if (!is_var_input_and_output && var->IsType<phi::DenseTensor>()) {
            auto &inside_tensor = var->Get<phi::DenseTensor>();
            framework::AttributeMap attrs;
            attrs["dtype"] =
@@ -575,10 +606,7 @@ class WhileGradOp : public framework::OperatorBase {
                 inside_tensor.lod());
           }
         }
-        auto var_outside = scope.FindVar(pg_ig_names[param_id]);
-        if ((var_iter == outside_og_names.end()) ||
-            ((var_iter != outside_og_names.end()) &&
-             var_outside->IsType<framework::LoDTensorArray>())) {
+        if (!is_var_input_and_output) {
           auto new_inside_name = cur_scope.Rename(inside_grad_name);
           auto sum_op = framework::OpRegistry::CreateOp(
               "sum",
@@ -587,6 +615,8 @@ class WhileGradOp : public framework::OperatorBase {
               framework::AttributeMap{{"use_mkldnn", {false}}});
           sum_op->Run(cur_scope, dev_place);
           cur_scope.Rename(new_inside_name, inside_grad_name);
+        } else {
+          ShareVariable(cur_scope, scope, pg_ig_names[param_id]);
         }
       }
       dev_ctx.Wait();
@@ -595,6 +625,29 @@ class WhileGradOp : public framework::OperatorBase {
     step_scopes->clear();
   }
+  void ShareVariable(const framework::Scope &source,
+                     const framework::Scope &dest,
+                     std::string name) const {
+    auto from_var = source.FindVar(name);
+    auto to_var = dest.FindVar(name);
+    if (from_var->IsType<phi::DenseTensor>()) {
+      if (from_var->Get<phi::DenseTensor>().IsInitialized()) {
+        to_var->GetMutable<phi::DenseTensor>()->ShareDataWith(
+            from_var->Get<phi::DenseTensor>());
+      }
+    } else if (from_var->IsType<framework::LoDTensorArray>()) {
+      auto from_arr = from_var->GetMutable<framework::LoDTensorArray>();
+      auto to_arr = to_var->GetMutable<framework::LoDTensorArray>();
+      to_arr->clear();
+      to_arr->resize(from_arr->size());
+      for (size_t i = 0; i < to_arr->size(); ++i) {
+        if (from_arr->at(i).IsInitialized()) {
+          to_arr->at(i).ShareDataWith(from_arr->at(i));
+        }
+      }
+    }
+  }
  private:
   mutable std::shared_ptr<framework::Executor> executor_{nullptr};
   mutable std::unique_ptr<framework::ExecutorPrepareContext> ctx_{nullptr};
@@ -646,6 +699,7 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
       block_ins.insert(o);
     }
     std::unordered_set<std::string> output_grads;
     for (const auto *op : grad_block->AllOps()) {
       for (auto &input_name : op->InputArgumentNames()) {
         // If the input of Op has been recorded or is generated by the forward
@@ -658,7 +712,6 @@ class WhileGradOpMaker : public framework::SingleGradOpMaker<T> {
             parent_block->FindVarRecursive(input_name) != nullptr)) {
           continue;
         }
         output_grads.insert(input_name);
       }
       for (auto &output_name : op->OutputArgumentNames()) {
...
@@ -2220,6 +2220,10 @@ def _find_op_path_(
             op.desc.output_arg_names(), output_names
         ):
             relevant_op_flags[i] = True
+            if core.has_non_empty_grad_op_maker(op.type):
+                for name in op.desc.input_arg_names():
+                    if name not in no_grad_set:
+                        output_names.add(name)
     op_path = [
         block.ops[i] for i in range(len(block.ops)) if relevant_op_flags[i]
...
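The `_find_op_path_` change above keeps upstream producers in the op path: when a relevant op has a non-empty grad op maker, its inputs (minus `no_grad_set`) are added to the set of required names, so the ops that produce them are retained as well. A toy, self-contained sketch of that reverse sweep (dict-based op records, hypothetical; not Paddle's real `OpDesc` API):

```python
# Toy sketch of the reverse sweep in _find_op_path_ (hypothetical op records).
def find_op_path(ops, output_names, no_grad_set):
    relevant = [False] * len(ops)
    needed = set(output_names)
    for i in reversed(range(len(ops))):
        op = ops[i]
        if needed & set(op["outputs"]):
            relevant[i] = True
            if op["has_grad_op"]:  # mirrors core.has_non_empty_grad_op_maker
                needed |= set(op["inputs"]) - no_grad_set
    return [op["type"] for i, op in enumerate(ops) if relevant[i]]


# Toy program: z = matmul(x, w); out = while(z). Only `out` is requested.
ops = [
    {"type": "matmul", "inputs": ["x", "w"], "outputs": ["z"], "has_grad_op": True},
    {"type": "while", "inputs": ["z"], "outputs": ["out"], "has_grad_op": True},
]
print(find_op_path(ops, {"out"}, set()))  # ['matmul', 'while']
```

In this toy program, dropping the `needed |= ...` propagation would leave only the `while` op in the path and the `matmul` feeding it would be lost, which appears to be the kind of issue the commit bullet "fix _find_op_path_ bugs in append_backward" addresses.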
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
paddle.set_default_dtype("float64")
import unittest
import numpy as np
from paddle import fluid
from paddle.fluid import framework
bidirectional_list = ["bidirectional", "bidirect"]
class TestSimpleRNN(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super().__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction in bidirectional_list else 1
self.place = place
self.batch_size = 4
self.input_size = 16
self.hidden_size = 16
self.seq_len = 12
self.seed = 1234
def setUp(self):
        # Since `set_device` is global, set it in `setUp` rather than in
        # `__init__` to avoid using a wrong device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(self.place)
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
cell_dy = paddle.nn.SimpleRNNCell(self.input_size, self.hidden_size)
self.rnn_net = paddle.nn.RNN(cell_dy, time_major=self.time_major)
paddle.enable_static()
with paddle.fluid.unique_name.guard():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(
main_program=main_program, startup_program=startup_program
):
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
self.exe = fluid.Executor(
fluid.CPUPlace()
if self.place == "cpu"
else fluid.CUDAPlace(0)
)
rnn_in_data = paddle.static.data(
"x",
[None, self.batch_size, self.hidden_size],
dtype="float64",
)
pre_h_data = paddle.static.data(
"pre_h",
[self.batch_size, self.hidden_size],
dtype="float64",
)
seq_len_data = paddle.static.data(
"seq_len", [self.batch_size], dtype="int64"
)
cell_st = paddle.nn.SimpleRNNCell(
self.input_size, self.hidden_size
)
self.rnn_st = paddle.nn.RNN(cell_st, time_major=self.time_major)
st_out, st_last_h = self.rnn_st(
rnn_in_data, pre_h_data, sequence_length=seq_len_data
)
self.fetch_list = [st_out, st_last_h]
self.exe.run(framework.default_startup_program())
self.main_program = framework.default_main_program()
paddle.disable_static(self.place)
def test_base(self, test_seq_len=False):
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(4, 16)
paddle.disable_static(self.place)
if test_seq_len:
seq_len = np.array([9, 10, 8, 12], "int64")
else:
seq_len = np.array([12, 12, 12, 12], "int64")
y1, h1 = self.rnn_net(
paddle.to_tensor(x),
paddle.to_tensor(prev_h),
sequence_length=paddle.to_tensor(seq_len),
)
paddle.enable_static()
out = self.exe.run(
self.main_program,
feed={"x": x, "pre_h": prev_h, "seq_len": seq_len},
fetch_list=[self.fetch_list],
)
y2, h2 = out
np.testing.assert_allclose(y1.numpy(), y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1.numpy(), h2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_base()
self.test_base(True)
class TestGRU(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super().__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction in bidirectional_list else 1
self.place = place
self.batch_size = 4
self.input_size = 16
self.hidden_size = 16
self.seq_len = 12
self.seed = 1234
def setUp(self):
        # Since `set_device` is global, set it in `setUp` rather than in
        # `__init__` to avoid using a wrong device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(self.place)
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
cell_dy = paddle.nn.GRUCell(self.input_size, self.hidden_size)
self.rnn_net = paddle.nn.RNN(cell_dy, time_major=self.time_major)
paddle.enable_static()
with paddle.fluid.unique_name.guard():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(
main_program=main_program, startup_program=startup_program
):
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
self.exe = fluid.Executor(
fluid.CPUPlace()
if self.place == "cpu"
else fluid.CUDAPlace(0)
)
rnn_in_data = paddle.static.data(
"x",
[None, self.batch_size, self.hidden_size],
dtype="float64",
)
pre_h_data = paddle.static.data(
"pre_h",
[self.batch_size, self.hidden_size],
dtype="float64",
)
seq_len_data = paddle.static.data(
"seq_len", [self.batch_size], dtype="int64"
)
cell_st = paddle.nn.GRUCell(self.input_size, self.hidden_size)
self.rnn_st = paddle.nn.RNN(cell_st, time_major=self.time_major)
st_out, st_last_h = self.rnn_st(
rnn_in_data, pre_h_data, sequence_length=seq_len_data
)
self.fetch_list = [st_out, st_last_h]
self.exe.run(framework.default_startup_program())
self.main_program = framework.default_main_program()
paddle.disable_static(self.place)
def test_base(self, test_seq_len=False):
x = np.random.randn(12, 4, 16)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(4, 16)
paddle.disable_static(self.place)
if test_seq_len:
seq_len = np.array([9, 10, 8, 12], "int64")
else:
seq_len = np.array([12, 12, 12, 12], "int64")
y1, h1 = self.rnn_net(
paddle.to_tensor(x),
paddle.to_tensor(prev_h),
sequence_length=paddle.to_tensor(seq_len),
)
paddle.enable_static()
out = self.exe.run(
self.main_program,
feed={"x": x, "pre_h": prev_h, "seq_len": seq_len},
fetch_list=[self.fetch_list],
)
y2, h2 = out
np.testing.assert_allclose(y1.numpy(), y2, atol=1e-8, rtol=1e-5)
np.testing.assert_allclose(h1.numpy(), h2, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_base()
self.test_base(True)
class TestGRUBackward(unittest.TestCase):
def __init__(self, time_major=True, direction="forward", place="cpu"):
super().__init__("runTest")
self.time_major = time_major
self.direction = direction
self.num_directions = 2 if direction in bidirectional_list else 1
self.place = place
self.batch_size = 4
self.input_size = 4
self.hidden_size = 4
self.seq_len = 12
self.seed = 1234
def setUp(self):
        # Since `set_device` is global, set it in `setUp` rather than in
        # `__init__` to avoid using a wrong device set by another test case.
place = paddle.set_device(self.place)
paddle.disable_static(self.place)
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
cell_dy = paddle.nn.SimpleRNNCell(self.input_size, self.hidden_size)
self.rnn_net = paddle.nn.RNN(cell_dy, time_major=self.time_major)
paddle.enable_static()
with paddle.fluid.unique_name.guard():
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(
main_program=main_program, startup_program=startup_program
):
paddle.seed(self.seed)
paddle.framework.random._manual_program_seed(self.seed)
self.exe = paddle.fluid.Executor(
fluid.CPUPlace()
if self.place == "cpu"
else fluid.CUDAPlace(0)
)
rnn_in_data = paddle.static.data(
"x",
[None, self.batch_size, self.hidden_size],
dtype="float64",
)
pre_h_data = paddle.static.data(
"pre_h",
[self.batch_size, self.hidden_size],
dtype="float64",
)
seq_len_data = paddle.static.data(
"seq_len", [self.batch_size], dtype="int64"
)
pre_h_data.stop_gradient = False
rnn_in_data.stop_gradient = False
cell_st = paddle.nn.SimpleRNNCell(
self.input_size, self.hidden_size
)
self.rnn_st = paddle.nn.RNN(cell_st, time_major=self.time_major)
st_out, st_last_h = self.rnn_st(
rnn_in_data, pre_h_data, sequence_length=seq_len_data
)
loss = paddle.sum(st_out)
sgd = paddle.optimizer.SGD(0.0)
sgd.minimize(loss)
self.fetch_list = [st_out, st_last_h, "pre_h@GRAD", "x@GRAD"]
self.exe.run(framework.default_startup_program())
self.main_program = framework.default_main_program()
paddle.disable_static(self.place)
def test_base(self, test_seq_len=False):
x = np.random.randn(12, 4, self.hidden_size)
if not self.time_major:
x = np.transpose(x, [1, 0, 2])
prev_h = np.random.randn(4, self.hidden_size)
paddle.disable_static(self.place)
if test_seq_len:
seq_len = np.array([9, 10, 8, 12], "int64")
else:
seq_len = np.array([12, 12, 12, 12], "int64")
x_in = paddle.to_tensor(x)
h_in = paddle.to_tensor(prev_h)
x_in.stop_gradient = False
h_in.stop_gradient = False
y1, h1 = self.rnn_net(
x_in,
h_in,
sequence_length=paddle.to_tensor(seq_len),
)
loss = y1.sum()
loss.backward()
h1_grad = h_in.gradient()
paddle.enable_static()
out = self.exe.run(
self.main_program,
feed={"x": x, "pre_h": prev_h, "seq_len": seq_len},
fetch_list=[self.fetch_list],
)
y2, h2, g1, g2 = out
np.testing.assert_allclose(h1_grad, g1, atol=1e-8, rtol=1e-5)
def runTest(self):
self.test_base(True)
self.test_base()
self.test_base()
self.test_base(True)
if __name__ == "__main__":
paddle.enable_static()
unittest.main()
@@ -248,8 +248,8 @@ def _rnn_static_graph(
     if not time_major:
         inputs = map_structure(_transpose_batch_time, inputs)
+    max_seq_len = paddle.shape(flatten(inputs)[0])[0]
     if sequence_length:
-        max_seq_len = paddle.shape(flatten(inputs)[0])[0]
         mask = sequence_lod.sequence_mask(
             sequence_length,
             maxlen=max_seq_len,
@@ -260,30 +260,77 @@ def _rnn_static_graph(
         inputs = map_structure(lambda x: paddle.reverse(x, axis=[0]), inputs)
         mask = paddle.reverse(mask, axis=[0]) if sequence_length else None
-    # StaticRNN
-    rnn = control_flow.StaticRNN()
-    with rnn.step():
-        inputs = map_structure(rnn.step_input, inputs)
-        states = map_structure(rnn.memory, initial_states)
-        copy_states = map_structure(lambda x: x, states)
-        outputs, new_states = cell(inputs, copy_states, **kwargs)
-        utils.assert_same_structure(states, new_states)
+    with paddle.fluid.framework.device_guard("cpu"):
+        start_i = paddle.zeros([1], dtype="int64")
+        end = max_seq_len
+    end = paddle.cast(end, "int64")
+    cond = start_i < end
+    while_op = control_flow.While(cond)
+    out_array = paddle.tensor.create_array(dtype=flatten(inputs)[0].dtype)
+    init_array = map_structure(
+        lambda x: paddle.tensor.create_array(dtype=x.dtype), initial_states
+    )
+    map_structure(
+        lambda x, y: paddle.tensor.array_write(x, start_i, y),
+        initial_states,
+        init_array,
+    )
+    with while_op.block():
+        step_in = inputs[start_i]
+        # step_in = paddle.fluid.layers.Print( step_in, message="step in")
+        pre_state = map_structure(
+            lambda x: paddle.tensor.array_read(x, start_i), init_array
+        )
+        # pre_state = paddle.fluid.layers.Print( pre_state, message="pre")
+        outputs, new_states = cell(step_in, pre_state, **kwargs)
+        assert isinstance(outputs, paddle.fluid.framework.Variable)
+        utils.assert_same_structure(new_states, pre_state)
         if sequence_length:
-            step_mask = rnn.step_input(mask)
+            step_mask = paddle.unsqueeze(mask[start_i], 1)
+            # paddle.fluid.layers.Print( step_mask, message="mask")
+            # new_states = map_structure(
+            #     partial(_maybe_copy, step_mask=step_mask),
+            #     pre_state, new_states
+            # )
             new_states = map_structure(
-                partial(_maybe_copy, step_mask=step_mask), states, new_states
+                lambda x, y: (x * step_mask + y * (1.0 - step_mask)),
+                new_states,
+                pre_state,
             )
-        map_structure(rnn.update_memory, states, new_states)
-        flat_outputs = flatten(outputs)
-        map_structure(rnn.step_output, outputs)
-        map_structure(rnn.step_output, new_states)
-    rnn_out = rnn()
-    final_outputs = rnn_out[: len(flat_outputs)]
-    final_outputs = utils.pack_sequence_as(outputs, final_outputs)
-    final_states = map_structure(lambda x: x[-1], rnn_out[len(flat_outputs) :])
-    final_states = utils.pack_sequence_as(new_states, final_states)
+        paddle.tensor.array_write(outputs, start_i, out_array)
+        with paddle.fluid.framework.device_guard("cpu"):
+            start_i = paddle.tensor.increment(x=start_i, value=1)
+        map_structure(
+            lambda x, y: paddle.tensor.array_write(x, start_i, y),
+            new_states,
+            init_array,
+        )
+        with paddle.fluid.framework.device_guard("cpu"):
+            new_cond = paddle.tensor.less_than(start_i, end)
+            paddle.fluid.layers.assign(new_cond, cond)
+    out, _ = paddle.fluid.layers.tensor_array_to_tensor(
+        out_array, axis=0, use_stack=True
+    )
+    all_state = map_structure(
+        lambda x: paddle.fluid.layers.tensor_array_to_tensor(
+            x, axis=0, use_stack=True
+        )[0],
+        init_array,
+    )
+    final_outputs = out
+    final_states = map_structure(lambda x: x[-1], all_state)
     if is_reverse:
         final_outputs = map_structure(
...
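One detail worth noting in the new while body above: for time steps past a sample's sequence length, the state must not advance, so the rewrite replaces StaticRNN's `_maybe_copy` helper with the masking formula `new * step_mask + prev * (1 - step_mask)`. A tiny NumPy check of that formula with made-up values:

```python
import numpy as np

# Two samples in the batch: sample 0 is still active at this step (mask=1),
# sample 1 has already reached its sequence length (mask=0).
prev_state = np.array([[1.0, 1.0], [2.0, 2.0]])
new_state = np.array([[9.0, 9.0], [8.0, 8.0]])
step_mask = np.array([[1.0], [0.0]])  # shape [batch, 1], broadcast over features

merged = new_state * step_mask + prev_state * (1.0 - step_mask)
print(merged)
# [[9. 9.]   <- active sample takes the freshly computed state
#  [2. 2.]]  <- finished sample keeps its previous state
```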