提交 52dad013 编写于 作者: Y yangyaming

Add static_input.

上级 377424bf
...@@ -1210,6 +1210,26 @@ class DynamicRNN(object): ...@@ -1210,6 +1210,26 @@ class DynamicRNN(object):
outputs={'Out': input_array}) outputs={'Out': input_array})
return array_read(array=input_array, i=self.step_idx) return array_read(array=input_array, i=self.step_idx)
def static_input(self, x):
self._assert_in_rnn_block_("static_input")
if not isinstance(x, Variable):
raise TypeError(
"static_input() can only take a Variable as its input")
if self.lod_rank_table is None:
raise RuntimeError(
"static_input() must be called after step_input().")
parent_block = self._parent_block_()
x_reordered = parent_block.create_var(
name=unique_name("dynamic_rnn_static_input_reordered"),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=x.dtype)
parent_block.append_op(
type='reorder_lod_tensor_by_rank',
inputs={'X': [x],
'RankTable': [self.lod_rank_table]},
outputs={'Out': [x_reordered]})
return shrink_memory(x_reordered, self.step_idx, self.lod_rank_table)
@contextlib.contextmanager @contextlib.contextmanager
def block(self): def block(self):
if self.status != DynamicRNN.BEFORE_RNN: if self.status != DynamicRNN.BEFORE_RNN:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册