提交 50cf103e 编写于 作者: Q qijun

make memory optimization module compatible with parallel_do

上级 bf9ed4a9
...@@ -29,6 +29,8 @@ dtype_to_size = { ...@@ -29,6 +29,8 @@ dtype_to_size = {
core.VarDesc.VarType.BOOL: 1 core.VarDesc.VarType.BOOL: 1
} }
sub_block_ops = ["while", "while_grad", "parallel_do", "parallel_do_grad"]
class ControlFlowGraph(object): class ControlFlowGraph(object):
def __init__(self, Program, ops, forward_num, skip_opt): def __init__(self, Program, ops, forward_num, skip_opt):
...@@ -141,7 +143,7 @@ class ControlFlowGraph(object): ...@@ -141,7 +143,7 @@ class ControlFlowGraph(object):
self.pool = [] self.pool = []
for i in range(self.op_size): for i in range(self.op_size):
op = self._ops[i] op = self._ops[i]
if op.type() == "while" or op.type() == "while_grad": if op.type() in sub_block_ops:
continue continue
block_desc = op.block() block_desc = op.block()
is_forward = i < self._forward_num is_forward = i < self._forward_num
...@@ -198,67 +200,75 @@ class ControlFlowGraph(object): ...@@ -198,67 +200,75 @@ class ControlFlowGraph(object):
block_desc, var_name, is_forward).shape())) block_desc, var_name, is_forward).shape()))
def get_cfgs(input_program): def _process_sub_block_pair(pdesc, sub_block_pair):
ops_list = [] ops_list = []
pdesc = input_program.get_desc()
block_desc = pdesc.block(0) block_desc = pdesc.block(0)
op_size = block_desc.op_size() op_size = block_desc.op_size()
# Get global block ops for fwd_op, bwd_op in sub_block_pair:
ops_list.append( sub_block_ids = []
([block_desc.op(i) for i in range(op_size)], op_size, set())) grad_sub_block_ids = []
sub_block_id_pair = []
while_sub_block_ids = [] sub_op_dict = {}
while_grad_sub_block_ids = []
while_block_id_pair = []
while_op_dict = {}
for i in range(op_size): for i in range(op_size):
op = block_desc.op(i) op = block_desc.op(i)
if op.type() == "while": if op.type() == fwd_op:
while_sub_block_ids.append(op.attr("sub_block").id) sub_block_ids.append(op.attr("sub_block").id)
while_op_dict[op.attr("sub_block").id] = op sub_op_dict[op.attr("sub_block").id] = op
elif op.type() == "while_grad": elif op.type() == bwd_op:
while_grad_sub_block_ids.append(op.attr("sub_block").id) grad_sub_block_ids.append(op.attr("sub_block").id)
while_op_dict[op.attr("sub_block").id] = op sub_op_dict[op.attr("sub_block").id] = op
# Find while/while_grad block pair # Find fwd_op/bwd_op block pair
for grad_id in while_grad_sub_block_ids: for grad_id in grad_sub_block_ids:
parent_id = pdesc.block(grad_id).parent parent_id = pdesc.block(grad_id).parent
if parent_id in while_sub_block_ids: if parent_id in sub_block_ids:
while_block_id_pair.append((parent_id, grad_id)) sub_block_id_pair.append((parent_id, grad_id))
while_sub_block_ids.remove(parent_id) sub_block_ids.remove(parent_id)
# Get fwd_op/bwd_op block ops
for parent_id, grad_id in sub_block_id_pair:
sub_block_ops = []
sub_block = pdesc.block(parent_id)
block_op_size = sub_block.op_size()
for i in range(block_op_size):
sub_block_ops.append(sub_block.op(i))
# Get while/while_grad block ops grad_sub_block = pdesc.block(grad_id)
for parent_id, grad_id in while_block_id_pair: grad_sub_block_op_size = grad_sub_block.op_size()
while_block_ops = [] for i in range(grad_sub_block_op_size):
while_block = pdesc.block(parent_id) sub_block_ops.append(grad_sub_block.op(i))
while_block_op_size = while_block.op_size()
for i in range(while_block_op_size):
while_block_ops.append(while_block.op(i))
while_grad_block = pdesc.block(grad_id) sub_op_output = set()
while_grad_block_op_size = while_grad_block.op_size() sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
for i in range(while_grad_block_op_size): sub_op_output.update(sub_op_dict[grad_id].output_arg_names())
while_block_ops.append(while_grad_block.op(i)) ops_list.append((sub_block_ops, block_op_size, sub_op_output))
while_op_output = set() # Process rest fwd_op block ops
while_op_output.update(while_op_dict[parent_id].output_arg_names()) for parent_id in sub_block_ids:
while_op_output.update(while_op_dict[grad_id].output_arg_names()) sub_block_ops = []
sub_block = pdesc.block(parent_id)
sub_block_op_size = sub_block.op_size()
for i in range(sub_block_op_size):
sub_block_ops.append(sub_block.op(i))
sub_op_output = set()
sub_op_output.update(sub_op_dict[parent_id].output_arg_names())
ops_list.append((sub_block_ops, sub_block_op_size, sub_op_output))
return ops_list
ops_list.append((while_block_ops, while_block_op_size, while_op_output))
# Process rest while block ops def _get_cfgs(input_program):
for parent_id in while_sub_block_ids: ops_list = []
while_block_ops = [] pdesc = input_program.get_desc()
while_block = pdesc.block(parent_id) block_desc = pdesc.block(0)
while_block_op_size = while_block.op_size() op_size = block_desc.op_size()
for i in range(while_block_op_size): # Get global block ops
while_block_ops.append(while_block.op(i)) ops_list.append(
([block_desc.op(i) for i in range(op_size)], op_size, set()))
while_op_output = set() sub_block_pair = [("while", "while_grad"), ("parallel_do",
while_op_output.update(while_op_dict[parent_id].output_arg_names()) "parallel_do_grad")]
ops_list.append((while_block_ops, while_block_op_size, while_op_output)) ops_list.extend(_process_sub_block_pair(pdesc, sub_block_pair))
cfgs = [ cfgs = [
ControlFlowGraph(input_program, ops, forward_num, skip_opt) ControlFlowGraph(input_program, ops, forward_num, skip_opt)
...@@ -268,6 +278,6 @@ def get_cfgs(input_program): ...@@ -268,6 +278,6 @@ def get_cfgs(input_program):
def memory_optimize(input_program): def memory_optimize(input_program):
cfgs = get_cfgs(input_program) cfgs = _get_cfgs(input_program)
for cfg in cfgs: for cfg in cfgs:
cfg.memory_optimize() cfg.memory_optimize()
...@@ -24,15 +24,21 @@ import sys ...@@ -24,15 +24,21 @@ import sys
fluid.default_startup_program().random_seed = 111 fluid.default_startup_program().random_seed = 111
x = fluid.layers.data(name='x', shape=[13], dtype='float32') x = fluid.layers.data(name='x', shape=[13], dtype='float32')
y_predict = fluid.layers.fc(input=x, size=1, act=None)
y = fluid.layers.data(name='y', shape=[1], dtype='float32') y = fluid.layers.data(name='y', shape=[1], dtype='float32')
cost = fluid.layers.square_error_cost(input=y_predict, label=y) places = fluid.layers.get_places(device_count=2, device_type='CPU')
avg_cost = fluid.layers.mean(x=cost) pd = fluid.layers.ParallelDo(places)
with pd.do():
x_ = pd.read_input(x)
y_ = pd.read_input(y)
y_predict = fluid.layers.fc(input=x_, size=1, act=None)
cost = fluid.layers.square_error_cost(input=y_predict, label=y_)
avg_cost = fluid.layers.mean(x=cost)
pd.write_output(avg_cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.1) cost = pd()
avg_cost = fluid.layers.mean(x=cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.01)
sgd_optimizer.minimize(avg_cost) sgd_optimizer.minimize(avg_cost)
fluid.memory_optimize(fluid.default_main_program()) fluid.memory_optimize(fluid.default_main_program())
...@@ -65,6 +71,7 @@ for pass_id in range(PASS_NUM): ...@@ -65,6 +71,7 @@ for pass_id in range(PASS_NUM):
if avg_loss_value[0] < 10.0: if avg_loss_value[0] < 10.0:
exit(0) # if avg cost less than 10.0, we think our code is good. exit(0) # if avg cost less than 10.0, we think our code is good.
print avg_loss_value[0]
if math.isnan(float(avg_loss_value)): if math.isnan(float(avg_loss_value)):
sys.exit("got NaN loss, training failed.") sys.exit("got NaN loss, training failed.")
exit(1) exit(1)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册