未验证 提交 1511a049 编写于 作者: Y Yang yaming 提交者: GitHub

Merge pull request #7540 from pkuyym/fix-7533

Add reorder flag for DynamicRNN's memory function.
......@@ -1343,20 +1343,44 @@ class DynamicRNN(object):
else:
return self.outputs
def memory(self, init=None, shape=None, value=0.0, dtype='float32'):
def memory(self,
init=None,
shape=None,
value=0.0,
need_reorder=False,
dtype='float32'):
self._assert_in_rnn_block_('memory')
if init is not None:
if not isinstance(init, Variable):
raise TypeError(
"The input arg `init` of memory() must be a Variable")
parent_block = self._parent_block_()
init_tensor = init
if need_reorder == True:
if self.lod_rank_table is None:
raise ValueError(
'If set need_reorder to True, make sure step_input be '
'invoked before '
'memory(init=init, need_reordered=True, ...).')
init_reordered = parent_block.create_var(
name=unique_name('dynamic_rnn_mem_init_reordered'),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=init.dtype)
parent_block.append_op(
type='reorder_lod_tensor_by_rank',
inputs={
'X': [init_tensor],
'RankTable': [self.lod_rank_table]
},
outputs={'Out': [init_reordered]})
init_tensor = init_reordered
mem_array = parent_block.create_var(
name=unique_name('dynamic_rnn_mem_array'),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=init.dtype)
parent_block.append_op(
type='write_to_array',
inputs={'X': init,
inputs={'X': init_tensor,
'I': self.zero_idx},
outputs={'Out': mem_array})
retv = array_read(array=mem_array, i=self.step_idx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册