未验证 提交 18911b6e 编写于 作者: W whs 提交者: GitHub

[enhence] Make step_input of dynamic_rnn support custom lod level. (#15972)

* Make step_input support custom lod level.
test=develop

* Fix API.spec
test=develop

* Fix API.spec.
test=develop

* Fix API.spec
test=develop

* Add default value in document of step_input.
test=develop

* Fix document.
test=develop

* Fix API.spec
test=develop
上级 d3acf680
...@@ -277,7 +277,7 @@ paddle.fluid.layers.DynamicRNN.block (ArgSpec(args=['self'], varargs=None, keywo ...@@ -277,7 +277,7 @@ paddle.fluid.layers.DynamicRNN.block (ArgSpec(args=['self'], varargs=None, keywo
paddle.fluid.layers.DynamicRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'value', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, False, 'float32')), ('document', 'b9174d4e91505b0c8ecc193eb51e248d')) paddle.fluid.layers.DynamicRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'value', 'need_reorder', 'dtype'], varargs=None, keywords=None, defaults=(None, None, 0.0, False, 'float32')), ('document', 'b9174d4e91505b0c8ecc193eb51e248d'))
paddle.fluid.layers.DynamicRNN.output (ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None), ('document', 'b439a176a3328de8a75bdc5c08eece4a')) paddle.fluid.layers.DynamicRNN.output (ArgSpec(args=['self'], varargs='outputs', keywords=None, defaults=None), ('document', 'b439a176a3328de8a75bdc5c08eece4a'))
paddle.fluid.layers.DynamicRNN.static_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', 'f29ad2478b6b2ad4f413d2936a331ea0')) paddle.fluid.layers.DynamicRNN.static_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', 'f29ad2478b6b2ad4f413d2936a331ea0'))
paddle.fluid.layers.DynamicRNN.step_input (ArgSpec(args=['self', 'x'], varargs=None, keywords=None, defaults=None), ('document', '169d694d2224f62b4f3afdc3dbc19e95')) paddle.fluid.layers.DynamicRNN.step_input (ArgSpec(args=['self', 'x', 'level'], varargs=None, keywords=None, defaults=(0,)), ('document', '7568c5ac7622a10288d3307a94134655'))
paddle.fluid.layers.DynamicRNN.update_memory (ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None), ('document', '5d83987da13b98363d6a807a52d8024f')) paddle.fluid.layers.DynamicRNN.update_memory (ArgSpec(args=['self', 'ex_mem', 'new_mem'], varargs=None, keywords=None, defaults=None), ('document', '5d83987da13b98363d6a807a52d8024f'))
paddle.fluid.layers.StaticRNN.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754')) paddle.fluid.layers.StaticRNN.__init__ (ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=(None,)), ('document', '6adf97f83acf6453d4a6a4b1070f3754'))
paddle.fluid.layers.StaticRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)), ('document', 'c24e368e23afac1ed91a78a639d7a9c7')) paddle.fluid.layers.StaticRNN.memory (ArgSpec(args=['self', 'init', 'shape', 'batch_ref', 'init_value', 'init_batch_dim_idx', 'ref_batch_dim_idx'], varargs=None, keywords=None, defaults=(None, None, None, 0.0, 0, 1)), ('document', 'c24e368e23afac1ed91a78a639d7a9c7'))
......
...@@ -1448,12 +1448,13 @@ class DynamicRNN(object): ...@@ -1448,12 +1448,13 @@ class DynamicRNN(object):
self.input_array = [] self.input_array = []
self.mem_link = [] self.mem_link = []
def step_input(self, x): def step_input(self, x, level=0):
""" """
Mark a sequence as a dynamic RNN input. Mark a sequence as a dynamic RNN input.
Args: Args:
x(Variable): The input sequence. x(Variable): The input sequence.
level(int): The level of lod used to split steps. Default: 0.
Returns: Returns:
The current timestep in the input sequence. The current timestep in the input sequence.
...@@ -1471,7 +1472,8 @@ class DynamicRNN(object): ...@@ -1471,7 +1472,8 @@ class DynamicRNN(object):
parent_block.append_op( parent_block.append_op(
type='lod_rank_table', type='lod_rank_table',
inputs={"X": x}, inputs={"X": x},
outputs={"Out": self.lod_rank_table}) outputs={"Out": self.lod_rank_table},
attrs={"level": level})
self.max_seq_len = parent_block.create_var( self.max_seq_len = parent_block.create_var(
name=unique_name.generate('dynamic_rnn_max_seq_len'), name=unique_name.generate('dynamic_rnn_max_seq_len'),
dtype='int64') dtype='int64')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册