提交 b2ee9190 编写于 作者: Y Yang Yang

add parallel_do test

上级 aea5ccca
...@@ -23,9 +23,9 @@ namespace operators { ...@@ -23,9 +23,9 @@ namespace operators {
constexpr char kInputs[] = "inputs"; constexpr char kInputs[] = "inputs";
constexpr char kParameters[] = "parameters"; constexpr char kParameters[] = "parameters";
constexpr char kPlaces[] = "places"; constexpr char kPlaces[] = "places";
constexpr char kParallelBlock[] = "parallel_block"; constexpr char kParallelBlock[] = "sub_block";
constexpr char kOutputs[] = "outputs"; constexpr char kOutputs[] = "outputs";
constexpr char kParallelScopes[] = "sub_block"; constexpr char kParallelScopes[] = "sub_scopes";
// #define GRAD_SUFFIX "@GRAD" // #define GRAD_SUFFIX "@GRAD"
// constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX; // constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
// constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX; // constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
......
...@@ -5,12 +5,12 @@ from tensor import assign, fill_constant ...@@ -5,12 +5,12 @@ from tensor import assign, fill_constant
import contextlib import contextlib
__all__ = [ __all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard', 'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table', 'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor', 'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
'increment', 'array_write', 'create_array', 'less_than', 'array_read', 'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock', 'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
'StaticRNN' 'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'ParallelDo'
] ]
...@@ -67,29 +67,117 @@ class BlockGuard(object): ...@@ -67,29 +67,117 @@ class BlockGuard(object):
return True return True
class StaticRNNGuard(BlockGuard): class ParallelDo(object):
""" """
StaticRNNGuard class. ParallelDo class.
StaticRNNGuard class is used to create a StaticRNN block in a program. ParallelDo class is used to create a ParallelDo.
"""
def __init__(self, places, name=None):
self.helper = LayerHelper("parallel_do", name=name)
self.inputs = []
self.places = places
self.outputs = []
self.status = StaticRNN.BEFORE_RNN_BLOCK
def do(self):
return BlockGuardWithCompletion(self)
def parent_block(self):
prog = self.helper.main_program
parent_idx = prog.current_block().parent_idx
assert parent_idx >= 0
parent_block = prog.block(parent_idx)
return parent_block
def __call__(self, *args, **kwargs):
if self.status != StaticRNN.AFTER_RNN_BLOCK:
raise ValueError("RNN output can only be retrieved after rnn block")
if len(self.outputs) == 0:
raise ValueError("RNN has no output")
elif len(self.outputs) == 1:
return self.outputs[0]
else:
return self.outputs
def read_input(self, var):
self.inputs.append(var)
def write_output(self, var):
self.outputs.append(var)
def get_parameters(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
local_inputs = set()
for op in current_block.ops:
for oname in op.output_names:
for out_var_name in op.output(oname):
local_inputs.add(out_var_name)
for var in self.inputs:
local_inputs.add(var.name)
params = list()
for op in current_block.ops:
for iname in op.input_names:
for in_var_name in op.input(iname):
if in_var_name not in local_inputs:
params.append(in_var_name)
return [parent_block.var(name) for name in params]
def complete_op(self):
main_program = self.helper.main_program
current_block = main_program.current_block()
parent_block = self.parent_block()
step_scope = parent_block.create_var(
type=core.VarDesc.VarType.STEP_SCOPES)
inputs = [parent_block.var(i.name) for i in self.inputs]
parent_block.append_op(
type='parallel_do',
inputs={
'inputs': inputs,
'parameters': self.get_parameters(),
'places': self.places
},
outputs={'outputs': self.outputs,
'step_scopes': [step_scope]},
attrs={'sub_block': current_block})
class BlockGuardWithCompletion(BlockGuard):
"""
BlockGuardWithCompletion class.
BlockGuardWithCompletion class is used to create an op with a block in a program.
""" """
def __init__(self, rnn): def __init__(self, rnn):
if not isinstance(rnn, StaticRNN): if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)):
raise TypeError("StaticRNNGuard takes a StaticRNN") raise TypeError(
super(StaticRNNGuard, self).__init__(rnn.helper.main_program) "BlockGuardWithCompletion takes a StaticRNN or ParallelDo")
super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program)
self.rnn = rnn self.rnn = rnn
def __enter__(self): def __enter__(self):
self.rnn.status = StaticRNN.IN_RNN_BLOCK self.rnn.status = StaticRNN.IN_RNN_BLOCK
return super(StaticRNNGuard, self).__enter__() return super(BlockGuardWithCompletion, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None: if exc_type is not None:
return False return False
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op() self.rnn.complete_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb) return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val,
exc_tb)
class StaticRNNMemoryLink(object): class StaticRNNMemoryLink(object):
...@@ -135,7 +223,7 @@ class StaticRNN(object): ...@@ -135,7 +223,7 @@ class StaticRNN(object):
self.seq_len = None self.seq_len = None
def step(self): def step(self):
return StaticRNNGuard(self) return BlockGuardWithCompletion(self)
def _assert_in_rnn_block_(self, method): def _assert_in_rnn_block_(self, method):
if self.status != StaticRNN.IN_RNN_BLOCK: if self.status != StaticRNN.IN_RNN_BLOCK:
...@@ -251,7 +339,7 @@ class StaticRNN(object): ...@@ -251,7 +339,7 @@ class StaticRNN(object):
else: else:
return self.outputs return self.outputs
def complete_rnn_op(self): def complete_op(self):
main_program = self.helper.main_program main_program = self.helper.main_program
rnn_block = main_program.current_block() rnn_block = main_program.current_block()
parent_block = self.parent_block() parent_block = self.parent_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册