From 52dad013ce914f27a7bc296dbd0435090f23d9b2 Mon Sep 17 00:00:00 2001 From: yangyaming Date: Wed, 10 Jan 2018 20:23:02 +0800 Subject: [PATCH] Add static_input. --- python/paddle/v2/fluid/layers/control_flow.py | 20 +++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py index 9ad021fa992..f134e56cda6 100644 --- a/python/paddle/v2/fluid/layers/control_flow.py +++ b/python/paddle/v2/fluid/layers/control_flow.py @@ -1210,6 +1210,26 @@ class DynamicRNN(object): outputs={'Out': input_array}) return array_read(array=input_array, i=self.step_idx) + def static_input(self, x): + self._assert_in_rnn_block_("static_input") + if not isinstance(x, Variable): + raise TypeError( + "static_input() can only take a Variable as its input") + if self.lod_rank_table is None: + raise RuntimeError( + "static_input() must be called after step_input().") + parent_block = self._parent_block_() + x_reordered = parent_block.create_var( + name=unique_name("dynamic_rnn_static_input_reordered"), + type=core.VarDesc.VarType.LOD_TENSOR, + dtype=x.dtype) + parent_block.append_op( + type='reorder_lod_tensor_by_rank', + inputs={'X': [x], + 'RankTable': [self.lod_rank_table]}, + outputs={'Out': [x_reordered]}) + return shrink_memory(x_reordered, self.step_idx, self.lod_rank_table) + @contextlib.contextmanager def block(self): if self.status != DynamicRNN.BEFORE_RNN: -- GitLab