From b2ee91903daedcd7a2c0eda9096da588811dacae Mon Sep 17 00:00:00 2001
From: Yang Yang
Date: Tue, 19 Dec 2017 08:56:03 +0000
Subject: [PATCH] add parallel_do test

---
 paddle/operators/parallel_do_op.cc            |   4 +-
 python/paddle/v2/fluid/layers/control_flow.py | 122 +++++++++++++++---
 2 files changed, 107 insertions(+), 19 deletions(-)

diff --git a/paddle/operators/parallel_do_op.cc b/paddle/operators/parallel_do_op.cc
index 3ab4bd3df28..4c026c2239e 100644
--- a/paddle/operators/parallel_do_op.cc
+++ b/paddle/operators/parallel_do_op.cc
@@ -23,9 +23,9 @@ namespace operators {
 constexpr char kInputs[] = "inputs";
 constexpr char kParameters[] = "parameters";
 constexpr char kPlaces[] = "places";
-constexpr char kParallelBlock[] = "parallel_block";
+constexpr char kParallelBlock[] = "sub_block";
 constexpr char kOutputs[] = "outputs";
-constexpr char kParallelScopes[] = "sub_block";
+constexpr char kParallelScopes[] = "sub_scopes";
 // #define GRAD_SUFFIX "@GRAD"
 // constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
 // constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
diff --git a/python/paddle/v2/fluid/layers/control_flow.py b/python/paddle/v2/fluid/layers/control_flow.py
index dc6c0e7f518..4791d749700 100644
--- a/python/paddle/v2/fluid/layers/control_flow.py
+++ b/python/paddle/v2/fluid/layers/control_flow.py
@@ -5,12 +5,12 @@ from tensor import assign, fill_constant
 import contextlib
 
 __all__ = [
-    'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
-    'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table',
-    'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
-    'increment', 'array_write', 'create_array', 'less_than', 'array_read',
-    'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
-    'StaticRNN'
+    'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
+    'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
+    'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
+    'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
+    'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
+    'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'ParallelDo'
 ]
 
 
@@ -67,29 +67,117 @@ class BlockGuard(object):
         return True
 
 
-class StaticRNNGuard(BlockGuard):
+class ParallelDo(object):
     """
-    StaticRNNGuard class.
+    ParallelDo class.
 
-    StaticRNNGuard class is used to create a StaticRNN block in a program.
+    ParallelDo class is used to run a block over multiple places in parallel.
+    """
+
+    def __init__(self, places, name=None):
+        self.helper = LayerHelper("parallel_do", name=name)
+        self.inputs = []
+        self.places = places
+        self.outputs = []
+        self.status = StaticRNN.BEFORE_RNN_BLOCK
+
+    def do(self):
+        return BlockGuardWithCompletion(self)
+
+    def parent_block(self):
+        prog = self.helper.main_program
+        parent_idx = prog.current_block().parent_idx
+        assert parent_idx >= 0
+        parent_block = prog.block(parent_idx)
+        return parent_block
+
+    def __call__(self, *args, **kwargs):
+        if self.status != StaticRNN.AFTER_RNN_BLOCK:
+            raise ValueError("ParallelDo output can only be retrieved after its block")
+        if len(self.outputs) == 0:
+            raise ValueError("ParallelDo has no output")
+        elif len(self.outputs) == 1:
+            return self.outputs[0]
+        else:
+            return self.outputs
+
+    def read_input(self, var):
+        self.inputs.append(var)
+
+    def write_output(self, var):
+        self.outputs.append(var)
+
+    def get_parameters(self):
+        main_program = self.helper.main_program
+        current_block = main_program.current_block()
+        parent_block = self.parent_block()
+
+        local_inputs = set()
+
+        for op in current_block.ops:
+            for oname in op.output_names:
+                for out_var_name in op.output(oname):
+                    local_inputs.add(out_var_name)
+
+        for var in self.inputs:
+            local_inputs.add(var.name)
+
+        params = list()
+        for op in current_block.ops:
+            for iname in op.input_names:
+                for in_var_name in op.input(iname):
+                    if in_var_name not in local_inputs:
+                        params.append(in_var_name)
+
+        return [parent_block.var(name) for name in params]
+
+    def complete_op(self):
+        main_program = self.helper.main_program
+        current_block = main_program.current_block()
+        parent_block = self.parent_block()
+
+        step_scope = parent_block.create_var(
+            type=core.VarDesc.VarType.STEP_SCOPES)
+
+        inputs = [parent_block.var(i.name) for i in self.inputs]
+
+        parent_block.append_op(
+            type='parallel_do',
+            inputs={
+                'inputs': inputs,
+                'parameters': self.get_parameters(),
+                'places': self.places
+            },
+            outputs={'outputs': self.outputs,
+                     'step_scopes': [step_scope]},
+            attrs={'sub_block': current_block})
+
+
+class BlockGuardWithCompletion(BlockGuard):
+    """
+    BlockGuardWithCompletion class.
+
+    BlockGuardWithCompletion class is used to create an op with a block in a program; the op is completed when the block exits.
""" def __init__(self, rnn): - if not isinstance(rnn, StaticRNN): - raise TypeError("StaticRNNGuard takes a StaticRNN") - super(StaticRNNGuard, self).__init__(rnn.helper.main_program) + if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)): + raise TypeError( + "BlockGuardWithCompletion takes a StaticRNN or ParallelDo") + super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program) self.rnn = rnn def __enter__(self): self.rnn.status = StaticRNN.IN_RNN_BLOCK - return super(StaticRNNGuard, self).__enter__() + return super(BlockGuardWithCompletion, self).__enter__() def __exit__(self, exc_type, exc_val, exc_tb): if exc_type is not None: return False self.rnn.status = StaticRNN.AFTER_RNN_BLOCK - self.rnn.complete_rnn_op() - return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb) + self.rnn.complete_op() + return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val, + exc_tb) class StaticRNNMemoryLink(object): @@ -135,7 +223,7 @@ class StaticRNN(object): self.seq_len = None def step(self): - return StaticRNNGuard(self) + return BlockGuardWithCompletion(self) def _assert_in_rnn_block_(self, method): if self.status != StaticRNN.IN_RNN_BLOCK: @@ -251,7 +339,7 @@ class StaticRNN(object): else: return self.outputs - def complete_rnn_op(self): + def complete_op(self): main_program = self.helper.main_program rnn_block = main_program.current_block() parent_block = self.parent_block() -- GitLab