提交 c01bb26f 编写于 作者: Y yangyaming

Add reorder flag for DynamicRNN's memory function.

上级 cb6b468e
...@@ -1310,20 +1310,44 @@ class DynamicRNN(object): ...@@ -1310,20 +1310,44 @@ class DynamicRNN(object):
else: else:
return self.outputs return self.outputs
def memory(self, init=None, shape=None, value=0.0, dtype='float32'): def memory(self,
init=None,
shape=None,
value=0.0,
need_reorder=False,
dtype='float32'):
self._assert_in_rnn_block_('memory') self._assert_in_rnn_block_('memory')
if init is not None: if init is not None:
if not isinstance(init, Variable): if not isinstance(init, Variable):
raise TypeError( raise TypeError(
"The input arg `init` of memory() must be a Variable") "The input arg `init` of memory() must be a Variable")
parent_block = self._parent_block_() parent_block = self._parent_block_()
init_tensor = init
if need_reorder == True:
if self.lod_rank_table is None:
raise ValueError(
'If set need_reorder to True, make sure step_input be '
'invoked before '
'memory(init=init, need_reordered=True, ...).')
init_reordered = parent_block.create_var(
name=unique_name('dynamic_rnn_mem_init_reordered'),
type=core.VarDesc.VarType.LOD_TENSOR,
dtype=init.dtype)
parent_block.append_op(
type='reorder_lod_tensor_by_rank',
inputs={
'X': [init_tensor],
'RankTable': [self.lod_rank_table]
},
outputs={'Out': [init_reordered]})
init_tensor = init_reordered
mem_array = parent_block.create_var( mem_array = parent_block.create_var(
name=unique_name('dynamic_rnn_mem_array'), name=unique_name('dynamic_rnn_mem_array'),
type=core.VarDesc.VarType.LOD_TENSOR_ARRAY, type=core.VarDesc.VarType.LOD_TENSOR_ARRAY,
dtype=init.dtype) dtype=init.dtype)
parent_block.append_op( parent_block.append_op(
type='write_to_array', type='write_to_array',
inputs={'X': init, inputs={'X': init_tensor,
'I': self.zero_idx}, 'I': self.zero_idx},
outputs={'Out': mem_array}) outputs={'Out': mem_array})
retv = array_read(array=mem_array, i=self.step_idx) retv = array_read(array=mem_array, i=self.step_idx)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册