提交 b2ee9190 编写于 作者: Y Yang Yang

add parallel_do test

上级 aea5ccca
......@@ -23,9 +23,9 @@ namespace operators {
constexpr char kInputs[] = "inputs";
constexpr char kParameters[] = "parameters";
constexpr char kPlaces[] = "places";
constexpr char kParallelBlock[] = "parallel_block";
constexpr char kParallelBlock[] = "sub_block";
constexpr char kOutputs[] = "outputs";
constexpr char kParallelScopes[] = "sub_block";
constexpr char kParallelScopes[] = "sub_scopes";
// #define GRAD_SUFFIX "@GRAD"
// constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
// constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
......
......@@ -5,12 +5,12 @@ from tensor import assign, fill_constant
import contextlib
__all__ = [
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard', 'StaticRNNGuard',
'StaticRNNMemoryLink', 'WhileGuard', 'While', 'lod_rank_table',
'max_sequence_len', 'topk', 'lod_tensor_to_array', 'array_to_lod_tensor',
'increment', 'array_write', 'create_array', 'less_than', 'array_read',
'shrink_memory', 'array_length', 'IfElse', 'DynamicRNN', 'ConditionalBlock',
'StaticRNN'
'split_lod_tensor', 'merge_lod_tensor', 'BlockGuard',
'BlockGuardWithCompletion', 'StaticRNNMemoryLink', 'WhileGuard', 'While',
'lod_rank_table', 'max_sequence_len', 'topk', 'lod_tensor_to_array',
'array_to_lod_tensor', 'increment', 'array_write', 'create_array',
'less_than', 'array_read', 'shrink_memory', 'array_length', 'IfElse',
'DynamicRNN', 'ConditionalBlock', 'StaticRNN', 'ParallelDo'
]
......@@ -67,29 +67,117 @@ class BlockGuard(object):
return True
class StaticRNNGuard(BlockGuard):
class ParallelDo(object):
"""
StaticRNNGuard class.
ParallelDo class.
StaticRNNGuard class is used to create a StaticRNN block in a program.
ParallelDo class is used to create a ParallelDo.
"""
def __init__(self, places, name=None):
    """Set up an empty ParallelDo builder.

    Args:
        places: stored unchanged; complete_op() passes it as the
            ``places`` input of the ``parallel_do`` op.
        name: optional name forwarded to ``LayerHelper``.
    """
    self.helper = LayerHelper("parallel_do", name=name)
    # inputs / outputs are filled by read_input() / write_output().
    self.inputs, self.places, self.outputs = [], places, []
    # The status values are borrowed from StaticRNN; __call__ compares
    # against StaticRNN.AFTER_RNN_BLOCK.
    self.status = StaticRNN.BEFORE_RNN_BLOCK
def do(self):
    """Return a ``BlockGuardWithCompletion`` wrapping this object.

    The guard is a context manager, so the intended use is
    ``with parallel_do.do(): ...``.
    """
    guard = BlockGuardWithCompletion(self)
    return guard
def parent_block(self):
    """Return the block that encloses the program's current block.

    Raises:
        AssertionError: if the current block's ``parent_idx`` is negative.
    """
    program = self.helper.main_program
    idx = program.current_block().parent_idx
    assert idx >= 0
    return program.block(idx)
def __call__(self, *args, **kwargs):
    """Return what was registered through write_output().

    Positional and keyword arguments are accepted but not used.

    Returns:
        The single registered output, or the whole list of outputs when
        more than one was registered.

    Raises:
        ValueError: if the status is not ``StaticRNN.AFTER_RNN_BLOCK``, or
            if no output was registered.
    """
    # NOTE(review): both messages say "RNN" although this class is
    # ParallelDo -- confirm whether that wording is intended.
    if self.status != StaticRNN.AFTER_RNN_BLOCK:
        raise ValueError("RNN output can only be retrieved after rnn block")
    produced = self.outputs
    count = len(produced)
    if count == 0:
        raise ValueError("RNN has no output")
    return produced[0] if count == 1 else produced
def read_input(self, var):
    """Record ``var`` as an input of the parallel block.

    Recorded inputs are passed as the ``inputs`` of the ``parallel_do`` op,
    and get_parameters() treats their names as local to the block.
    """
    registered = self.inputs
    registered.append(var)
def write_output(self, var):
    """Record ``var`` as an output of the parallel block.

    Recorded outputs are what __call__ returns and what complete_op()
    passes as the ``outputs`` of the ``parallel_do`` op.
    """
    registered = self.outputs
    registered.append(var)
def get_parameters(self):
    """Collect the parent-block variables that the parallel block reads.

    A name is local to the block when some op in the current block writes
    it, or when it belongs to a variable registered through read_input().
    Every op input whose name is not local is looked up in the parent block.

    Returns:
        list: ``parent_block.var(name)`` for each external name, in order
        of first use and without duplicates.
    """
    current_block = self.helper.main_program.current_block()
    parent_block = self.parent_block()

    # Names produced inside the block, plus the explicitly declared inputs.
    local_names = set(var.name for var in self.inputs)
    for op in current_block.ops:
        for slot in op.output_names:
            local_names.update(op.output(slot))

    # An external variable read by several ops (or several slots) must be
    # reported only once; `seen` gives a cheap membership test while
    # `external` keeps the deterministic first-use order.
    seen = set()
    external = []
    for op in current_block.ops:
        for slot in op.input_names:
            for name in op.input(slot):
                if name in local_names or name in seen:
                    continue
                seen.add(name)
                external.append(name)

    return [parent_block.var(name) for name in external]
def complete_op(self):
    """Append the ``parallel_do`` op for the current block to its parent.

    The op receives the registered inputs (resolved by name in the parent
    block), the external variables found by get_parameters(), and
    ``self.places``. It writes the registered outputs plus a newly created
    STEP_SCOPES variable, and carries the current block as its
    ``sub_block`` attribute.
    """
    current_block = self.helper.main_program.current_block()
    parent_block = self.parent_block()

    scopes_var = parent_block.create_var(
        type=core.VarDesc.VarType.STEP_SCOPES)
    inputs = [parent_block.var(i.name) for i in self.inputs]

    parent_block.append_op(
        type='parallel_do',
        inputs={
            'inputs': inputs,
            'parameters': self.get_parameters(),
            'places': self.places
        },
        # Slot names have to match the operator's C++ constants. The scopes
        # output is declared there as kParallelScopes = "sub_scopes", so
        # the previous key 'step_scopes' named a slot the op does not have.
        outputs={'outputs': self.outputs,
                 'sub_scopes': [scopes_var]},
        attrs={'sub_block': current_block})
class BlockGuardWithCompletion(BlockGuard):
"""
BlockGuardWithCompletion class.
BlockGuardWithCompletion class is used to create an op with a block in a program.
"""
def __init__(self, rnn):
if not isinstance(rnn, StaticRNN):
raise TypeError("StaticRNNGuard takes a StaticRNN")
super(StaticRNNGuard, self).__init__(rnn.helper.main_program)
if not (isinstance(rnn, StaticRNN) or isinstance(rnn, ParallelDo)):
raise TypeError(
"BlockGuardWithCompletion takes a StaticRNN or ParallelDo")
super(BlockGuardWithCompletion, self).__init__(rnn.helper.main_program)
self.rnn = rnn
def __enter__(self):
self.rnn.status = StaticRNN.IN_RNN_BLOCK
return super(StaticRNNGuard, self).__enter__()
return super(BlockGuardWithCompletion, self).__enter__()
def __exit__(self, exc_type, exc_val, exc_tb):
if exc_type is not None:
return False
self.rnn.status = StaticRNN.AFTER_RNN_BLOCK
self.rnn.complete_rnn_op()
return super(StaticRNNGuard, self).__exit__(exc_type, exc_val, exc_tb)
self.rnn.complete_op()
return super(BlockGuardWithCompletion, self).__exit__(exc_type, exc_val,
exc_tb)
class StaticRNNMemoryLink(object):
......@@ -135,7 +223,7 @@ class StaticRNN(object):
self.seq_len = None
def step(self):
return StaticRNNGuard(self)
return BlockGuardWithCompletion(self)
def _assert_in_rnn_block_(self, method):
if self.status != StaticRNN.IN_RNN_BLOCK:
......@@ -251,7 +339,7 @@ class StaticRNN(object):
else:
return self.outputs
def complete_rnn_op(self):
def complete_op(self):
main_program = self.helper.main_program
rnn_block = main_program.current_block()
parent_block = self.parent_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册