Unverified commit a08ee62a, authored by JZ-LIANG, committed by GitHub

Auto Parallel support conditional block (#39612)

* add subblock logic for context and partitioner

* partitioner support sub blocks

* revise typos

* fixed param init bug for while

* chmod 644

* add unittest

* mv forward parser

* update unittest

* update dist op ctx

* update dist op ctx

* fixed bug in dist op ctx

* fixed bug for recompute subblock
Parent ae8c811a
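The hunks below wire a new BlockState into the auto-parallel pipeline and turn the DistributedOperatorContext accessors into properties. As a minimal sketch (not part of the commit), here is the call order these changes imply, inferred from the parallelizer and unit-test hunks further down; the import path follows the new test file and the argument names are placeholders supplied by the caller.

# Sketch only: block-state bookkeeping around partitioning, assuming the
# Partitioner import path used by the unit test in this commit.
from paddle.distributed.auto_parallel.partitioner import Partitioner


def partition_with_subblocks(dist_context, complete_train_program,
                             startup_program, params_grads, rank_id):
    # 1. after forward annotation completion, record the forward block
    #    hierarchy (global block plus any while/cond sub-blocks)
    dist_context.block_state.parse_forward_blocks(complete_train_program)

    # (in the real AutoParallelizer, append_backward and backward
    #  annotation completion run between steps 1 and 2)

    # 2. once grad sub-blocks exist, map each backward block back to the
    #    forward block it was generated from
    dist_context.block_state.parse_backward_blocks(complete_train_program)

    # 3. the partitioner then walks every block recorded in block_state,
    #    not just the global block
    partitioner = Partitioner(dist_context, rank_id)
    return partitioner.partition(complete_train_program, startup_program,
                                 params_grads)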
@@ -55,6 +55,7 @@ class DistributedContext:
         self._is_initialized_for_program = False
         self._dist_tensors_for_program = {}
         self._dist_ops_for_program = {}
+        self._block_state = BlockState()
         # Graph related data members
         self._is_initialized_for_graph = False
         self._serial_graph = None
@@ -102,6 +103,10 @@ class DistributedContext:
     def dist_op_context(self):
         return self._dist_op_context

+    @property
+    def block_state(self):
+        return self._block_state
+
     @property
     def dist_main_programs(self):
         return self._dist_main_programs
@@ -512,66 +517,83 @@ class DistributedOperatorContext:
     def __init__(self):
         self._dst_main_program = None
+        self._main_block = None
         self._dst_startup_program = None
-        self._varname_mapping = None
-        self._rank_id = None
+        self._startup_block = None
         self._cur_src_op = None
         self._cur_dist_attr = None
         self.grad_op_id_to_op_id = {}
+        self._work_block = None
         self.already_init_sync_vars = set()
+        self.varname_mapping = None
+        self.rank_id = None

     def __deepcopy__(self, memo):
         cls = self.__class__
         result = cls.__new__(cls)
         memo[id(self)] = result
         for k, v in self.__dict__.items():
-            if k == "_dst_main_program" or k == "_dst_startup_program" or k == "_cur_src_op":
+            if k in [
+                    "_dst_main_program", "_dst_startup_program", "_cur_src_op",
+                    "_work_block", "_main_block", "_startup_block"
+            ]:
                 setattr(result, k, v)
             else:
                 setattr(result, k, copy.deepcopy(v, memo))
         return result

-    def set_dst_main_program(self, prog):
-        self._dst_main_program = prog
-
-    def get_dst_main_program(self):
+    @property
+    def dst_main_program(self):
         return self._dst_main_program

-    def set_dst_startup_program(self, prog):
-        self._dst_startup_program = prog
-
-    def get_dst_startup_program(self):
-        return self._dst_startup_program
-
-    def set_varname_mapping(self, mapping):
-        self._varname_mapping = mapping
-
-    def get_varname_mapping(self):
-        return self._varname_mapping
-
-    def set_rank_id(self, rank_id):
-        self._rank_id = rank_id
-
-    def get_rank_id(self):
-        return self._rank_id
-
-    def set_cur_src_op(self, cur_src_op):
-        self._cur_src_op = cur_src_op
-
-    def get_cur_src_op(self):
+    @dst_main_program.setter
+    def dst_main_program(self, prog):
+        self._dst_main_program = prog
+        self._main_block = prog.blocks[0]
+
+    @property
+    def main_block(self):
+        return self._main_block
+
+    @property
+    def dst_startup_program(self):
+        return self._dst_startup_program
+
+    @dst_startup_program.setter
+    def dst_startup_program(self, prog):
+        self._dst_startup_program = prog
+        self._startup_block = prog.blocks[0]
+
+    @property
+    def startup_block(self):
+        return self._startup_block
+
+    @property
+    def work_block(self):
+        assert self._work_block is not None
+        return self._work_block
+
+    @work_block.setter
+    def work_block(self, block):
+        assert block is not None
+        self._work_block = block
+
+    @property
+    def cur_src_op(self):
+        assert self._cur_src_op is not None
         return self._cur_src_op

     def prepare_context(self, src_op):
-        self.set_cur_src_op(src_op)
+        self._cur_src_op = src_op

         # build input varname mapping
         kinputs = {}
         for input_name in src_op.desc.input_names():
             varnames = []
             for varname in src_op.desc.input(input_name):
-                assert varname in self._varname_mapping
-                varnames.append(self._varname_mapping[varname])
+                assert varname in self.varname_mapping
+                varnames.append(self.varname_mapping[varname])
             kinputs[input_name] = varnames

         # build output varname mapping
@@ -579,8 +601,52 @@ class DistributedOperatorContext:
         for output_name in src_op.desc.output_names():
             varnames = []
             for varname in src_op.desc.output(output_name):
-                assert varname in self._varname_mapping
-                varnames.append(self._varname_mapping[varname])
+                assert varname in self.varname_mapping
+                varnames.append(self.varname_mapping[varname])
             koutputs[output_name] = varnames

         return kinputs, koutputs
+
+
+class BlockState(object):
+    def __init__(self):
+        self.nblock = 0
+        self.forward_indices = []
+        self.backward_indices = []
+        self.backward_to_forward_index_map = {}
+
+    def parse_forward_blocks(self, program):
+
+        while program.current_block_idx != 0:
+            program._rollback()
+
+        assert program.current_block_idx == 0
+
+        for idx, block in enumerate(program.blocks):
+
+            assert idx == block.idx, "index doesn't match"
+            assert block.forward_block_idx == -1, "forward_block_idx of forward block [{}] is not [{}]".format(
+                idx, block.forward_block_idx)
+            self.forward_indices.append(idx)
+            self.nblock += 1
+
+        assert self.nblock >= 1
+
+    def parse_backward_blocks(self, program):
+
+        assert 0 in self.forward_indices, "forward block idx are{}".format(
+            self.forward_indices)
+        self.backward_to_forward_index_map[0] = 0
+
+        for idx, block in enumerate(program.blocks):
+
+            if idx < len(self.forward_indices):
+                continue
+
+            assert idx == block.idx, "index doesn't match"
+            assert block.forward_block_idx in self.forward_indices
+            self.backward_indices.append(idx)
+            self.backward_to_forward_index_map[idx] = block.forward_block_idx
+            self.nblock += 1
+
+        assert self.nblock == len(program.blocks)
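To make the BlockState bookkeeping above concrete, here is a standalone illustration with mock stand-ins (hypothetical classes, not Paddle objects) that expose only the attributes BlockState reads; it assumes BlockState is imported from the module patched above.

# Illustration only: _MockBlock/_MockProgram are hypothetical stand-ins with
# just the attributes BlockState uses (idx, forward_block_idx, blocks,
# current_block_idx, _rollback).
class _MockBlock(object):
    def __init__(self, idx, forward_block_idx=-1):
        self.idx = idx
        self.forward_block_idx = forward_block_idx


class _MockProgram(object):
    def __init__(self, blocks):
        self.blocks = blocks
        self.current_block_idx = 0

    def _rollback(self):
        self.current_block_idx = 0


# forward program: global block 0 plus a while sub-block 1
state = BlockState()
state.parse_forward_blocks(_MockProgram([_MockBlock(0), _MockBlock(1)]))
assert state.forward_indices == [0, 1]

# after backward generation, block 2 is the grad sub-block of block 1
state.parse_backward_blocks(
    _MockProgram(
        [_MockBlock(0), _MockBlock(1), _MockBlock(2, forward_block_idx=1)]))
assert state.backward_indices == [2]
assert state.backward_to_forward_index_map == {0: 0, 2: 1}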
@@ -76,9 +76,9 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
         # by now the backward function only insert the gradient allreduce for dist op itself
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.main_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
......
@@ -32,6 +32,8 @@ from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY,
 from ..process_group import new_process_group
 from ..utils import _get_comm_group, _get_corresponding_rank

+__op_not_need_param_init__ = ["while", "cond"]
+

 class DistributedDefault(DistributedOperatorImplContainer):
     def __init__(self, op_type):
@@ -195,10 +197,10 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
     def forward(ctx, *args, **kwargs):

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id

         # check validation of inputs / outputs
         for input_name in src_op.desc.input_names():
@@ -227,6 +229,9 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         main_block._sync_with_cpp()

         # param initialization sync
+        if src_op.type in __op_not_need_param_init__:
+            return
+
         for varname in dist_op_desc.input_arg_names():
             if startup_block.has_var(varname) and startup_block.var(
                     varname
@@ -278,12 +283,12 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
         # by now the backward function only insert the gradient allreduce for dist op itself
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
+        main_block = dist_op_context.work_block
+        backward_op = dist_op_context.cur_src_op
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
-        rank_id = dist_op_context.get_rank_id()
+        rank_id = dist_op_context.rank_id

         # check validation of inputs / outputs
         for input_name in backward_op.desc.input_names():
......
@@ -128,10 +128,10 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -311,9 +311,9 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
         # by now the backward function only insert the gradient allreduce for dist op itself
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
......
@@ -223,9 +223,9 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
     # by now the backward function only insert the gradient allreduce for dist op itself
     dist_op_context = ctx.dist_op_context
-    main_block = dist_op_context.get_dst_main_program().global_block()
-    backward_op = dist_op_context.get_cur_src_op()
-    rank_id = dist_op_context.get_rank_id()
+    main_block = dist_op_context.work_block
+    backward_op = dist_op_context.cur_src_op
+    rank_id = dist_op_context.rank_id
     dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
     assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
         str(backward_op))
@@ -257,7 +257,7 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
         kwargs['Y@GRAD'])

     X_var = main_block.var(kwargs['X'][0])
-    Y_var = main_block.var(kwargs['Y'][0])
+    Y_var = main_block._var_recursive(kwargs['Y'][0])
     Out_grad = main_block.var(kwargs['Out@GRAD'][0])
     Y_grad = main_block.var(kwargs['Y@GRAD'][0])
@@ -433,7 +433,8 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):

 def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):

-    assert Weight_var.name not in dist_op_context.already_init_sync_vars
+    assert Weight_var.name not in dist_op_context.already_init_sync_vars, "{} is in {}.".format(
+        Weight_var.name, dist_op_context.already_init_sync_vars)
     assert startup_block.has_var(Weight_var.name)
     dist_op_context.already_init_sync_vars.add(Weight_var.name)
     param = startup_block.var(Weight_var.name)
@@ -528,10 +529,10 @@ class DistributedMatmulImpl0(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -753,10 +754,10 @@ class DistributedMatmulImpl1(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -1042,10 +1043,10 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -1071,7 +1072,7 @@ class DistributedMatmulV2Impl0(DistributedOperatorImpl):
                 output_name)

         X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])

         # TODO infer logic comm presentation
@@ -1261,10 +1262,10 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        startup_block = dist_op_context.get_dst_startup_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        startup_block = dist_op_context.startup_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -1290,7 +1291,7 @@ class DistributedMatmulV2Impl1(DistributedOperatorImpl):
                 output_name)

         X_var = main_block.var(kwargs['X'][0])
-        Weight_var = main_block.var(kwargs['Y'][0])
+        Weight_var = main_block._var_recursive(kwargs['Y'][0])
         Out_var = main_block.var(kwargs['Out'][0])

         # TODO infer logic comm presentation
......
@@ -130,9 +130,9 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
@@ -287,9 +287,9 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
         """

         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        src_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.work_block
+        src_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         op_dist_attr = ctx.get_op_dist_attr_for_program(src_op)
         assert op_dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(src_op))
......
@@ -65,9 +65,9 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
         # the backward function only filte the gradient with current rank id
         dist_op_context = ctx.dist_op_context
-        main_block = dist_op_context.get_dst_main_program().global_block()
-        backward_op = dist_op_context.get_cur_src_op()
-        rank_id = dist_op_context.get_rank_id()
+        main_block = dist_op_context.main_block
+        backward_op = dist_op_context.cur_src_op
+        rank_id = dist_op_context.rank_id
         dist_attr = ctx.get_op_dist_attr_for_program(backward_op)
         assert dist_attr is not None, "backward op [{}] don't have dist attribute !".format(
             str(backward_op))
......
@@ -132,7 +132,7 @@ class AutoParallelizer:
                 distop_context=self._dist_context.dist_op_context)
         self._completer = Completer(self._dist_context)
         self._completer.complete_backward_annotation(main_program)
+        self._dist_context.block_state.parse_backward_blocks(main_program)
         return params_grads

     def _apply_optimize(self, main_program, startup_program, params_grads):
@@ -174,6 +174,7 @@ class AutoParallelizer:
         serial_main_program = self._main_program.clone()
         serial_startup_program = self._startup_program.clone()
         serial_loss = serial_main_program.global_block().var(self._loss.name)

         # generating serial
         if dist_context is None:
             # Annotation completion
@@ -186,6 +187,9 @@ class AutoParallelizer:
             completed_main_program = serial_main_program
             self._dist_context = copy.deepcopy(dist_context)

+        # parse forward sub block
+        self._dist_context.block_state.parse_forward_blocks(serial_main_program)
+
         # serial backward pass
         params_grads = self._generate_backward(
             completed_main_program, serial_startup_program, serial_loss,
......
@@ -29,6 +29,9 @@ from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op
 from .operators.common import BACKWARD_ONLY_DIST_OPS

 __varname_not_in_block__ = ["lod_tensor_blocking_queue_0"]
+__not_shape_var_type__ = [
+    core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
+]


 class Partitioner(object):
@@ -75,8 +78,8 @@ class Partitioner(object):
         # init distop helper
         dist_op_context = self._dist_context.dist_op_context
-        dist_op_context.set_varname_mapping(self._serial2dist_varname_mapping)
-        dist_op_context.set_rank_id(self._rank_id)
+        dist_op_context.varname_mapping = self._serial2dist_varname_mapping
+        dist_op_context.rank_id = self._rank_id

         # partition startup program
         if serial_startup_program == None:
@@ -84,7 +87,7 @@ class Partitioner(object):
         else:
             partitioned_startup_prog = self.partition_startup_program(
                 serial_main_program, serial_startup_program)
-        dist_op_context.set_dst_startup_program(partitioned_startup_prog)
+        dist_op_context.dst_startup_program = partitioned_startup_prog

         # partition main program
         partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
@@ -157,15 +160,45 @@ class Partitioner(object):
             2. replace local op with corresponding dist op
         """

-        dist_op_context = self._dist_context.dist_op_context
         partitioned_main_prog = fluid.Program()
-        dist_op_context.set_dst_main_program(partitioned_main_prog)
-        target_block = partitioned_main_prog.global_block()
-        ref_block = serial_main_program.global_block()
-        serial_ops = serial_main_program.global_block().ops
+        dist_op_context = self._dist_context.dist_op_context
+        dist_op_context.dst_main_program = partitioned_main_prog
+
+        for idx in range(self._dist_context.block_state.nblock):
+            ref_block = serial_main_program.blocks[idx]
+
+            if idx == 0:
+                target_block = partitioned_main_prog.blocks[0]
+            else:
+                target_block = partitioned_main_prog._create_block(
+                    parent_idx=ref_block.parent_idx)
+                assert ref_block.idx == target_block.idx
+                target_block._set_forward_block_idx(ref_block.forward_block_idx)
+            dist_op_context.work_block = target_block
+            self.partition_block(ref_block, target_block)
+
+        partitioned_main_prog.current_block_idx = 0
+
+        partitioned_params_and_grads = []
+        for p, g in params_and_grads:
+            assert p.name in self._serial2dist_varname_mapping
+            dist_p = self._get_dist_var_by_serial_var(p, partitioned_main_prog)
+            if g is None:
+                dist_g = None
+            else:
+                assert g.name in self._serial2dist_varname_mapping
+                dist_g = self._get_dist_var_by_serial_var(g,
+                                                          partitioned_main_prog)
+            partitioned_params_and_grads.append((dist_p, dist_g))
+
+        return partitioned_main_prog, partitioned_params_and_grads
+
+    def partition_block(self, ref_block, target_block):
+
+        dist_op_context = self._dist_context.dist_op_context
+        serial_ops = ref_block.ops

         # init mapping
-        first_backward_op_idx = -1
         forward_op_id2forward_op = {}
         for idx in range(len(serial_ops)):
             if is_forward_op(serial_ops[idx]):
@@ -218,23 +251,6 @@ class Partitioner(object):
                     "partitioner only support forward op and backward op, but got {}".
                     format(str(op)))

-        partitioned_params_and_grads = []
-        for p, g in params_and_grads:
-            assert p.name in self._serial2dist_varname_mapping
-            dist_p_name = self._serial2dist_varname_mapping[p.name]
-            assert target_block.has_var(dist_p_name)
-            dist_p = target_block.var(dist_p_name)
-            if g is None:
-                dist_g = None
-            else:
-                assert g.name in self._serial2dist_varname_mapping
-                dist_g_name = self._serial2dist_varname_mapping[g.name]
-                assert target_block.has_var(dist_g_name)
-                dist_g = target_block.var(dist_g_name)
-            partitioned_params_and_grads.append((dist_p, dist_g))
-
-        return partitioned_main_prog, partitioned_params_and_grads

     def _is_valid_annotated_program(self, program):

         # TODO (ZJ-LIANG) should check all block
@@ -245,7 +261,7 @@ class Partitioner(object):
         ]
         var_dist_attrs = [
             self._dist_context.get_tensor_dist_attr_for_program(var)
-            for var in vars_
+            for var in vars_ if (var.type not in __not_shape_var_type__)
         ]

         all_ops_annotated = all(dist_attr is not None
@@ -255,6 +271,14 @@ class Partitioner(object):
         return all_ops_annotated and all_vars_annotated

+    def _get_dist_var_by_serial_var(self, serial_var, partitioned_main_prog):
+
+        block_idx = serial_var.block.idx
+        target_block = partitioned_main_prog.blocks[block_idx]
+        dist_var_name = self._serial2dist_varname_mapping[serial_var.name]
+        assert target_block.has_var(dist_var_name)
+        return target_block.var(dist_var_name)
+

 def _get_dist_shape(var, dist_attr):
@@ -341,7 +365,7 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
     """
     src_var = src_block.var(src_varname)

-    if src_var.type == core.VarDesc.VarType.READER:
+    if src_var.type in __not_shape_var_type__:
         dst_block.create_var(
             type=src_var.type,
             name=dst_varname,
......
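Because partition_main_program now mirrors the serial block hierarchy block by block, a small hypothetical check (not part of the commit, written in the spirit of the asserts in the new unit test below) can verify that the partitioned program preserved the hierarchy and inspect per-block op lists; it only uses Block attributes the diff above already relies on (idx, parent_idx, forward_block_idx).

# Hypothetical helpers for inspecting a partitioned program.
def assert_same_block_hierarchy(serial_prog, dist_prog):
    # the partitioner recreates one target block per serial block and copies
    # its parent and forward-block indices
    assert len(serial_prog.blocks) == len(dist_prog.blocks)
    for serial_block, dist_block in zip(serial_prog.blocks, dist_prog.blocks):
        assert serial_block.idx == dist_block.idx
        assert serial_block.parent_idx == dist_block.parent_idx
        assert serial_block.forward_block_idx == dist_block.forward_block_idx


def block_op_types(prog, block_idx):
    # e.g. the test below expects "c_allreduce_sum" both in blocks[0] and in
    # the while sub-block blocks[1] of the partitioned main program
    return [op.type for op in prog.blocks[block_idx].ops]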
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import paddle.nn as nn
import paddle.utils as utils
import paddle.fluid as fluid
import paddle.static as static
import paddle.nn.functional as F
import paddle.distributed.auto_parallel as auto
from paddle.distributed import fleet
from paddle.distributed.auto_parallel.partitioner import Partitioner
from paddle.distributed.auto_parallel.utils import make_data_unshard
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from paddle.distributed.auto_parallel.dist_context import DistributedContext, get_default_distributed_context
from paddle.distributed.auto_parallel.operators import find_best_compatible_distributed_operator_impl
paddle.enable_static()
batch_size = 4
epoch_num = 10
hidden_size = 1024
sequence_len = 512
_g_process_mesh = auto.ProcessMesh([0, 1])
def get_random_inputs_and_labels(input_shape, label_shape):
input = np.random.random(size=input_shape).astype('float32')
label = np.random.random(size=label_shape).astype('float32')
return input, label
def batch_generator_creator():
def __reader__():
for _ in range(batch_size):
batch_input, batch_label = get_random_inputs_and_labels(
[batch_size, sequence_len, hidden_size],
[batch_size, sequence_len, 1])
yield batch_input, batch_label
return __reader__
class MLPLayer(nn.Layer):
def __init__(self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02):
super(MLPLayer, self).__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
param_initializer = nn.initializer.Normal(
mean=0.0, std=initializer_range)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.linear0 = nn.Linear(
d_model,
dim_feedforward,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
self.linear1 = nn.Linear(
dim_feedforward,
d_model,
weight_attr=paddle.ParamAttr(initializer=param_initializer),
bias_attr=None)
def forward(self, input):
auto.shard_tensor(
self.norm.weight,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
auto.shard_tensor(
self.norm.bias,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
auto.shard_tensor(
self.linear0.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, 0]
})
auto.shard_tensor(
self.linear0.bias,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [0]})
auto.shard_tensor(
self.linear1.weight,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [0, -1]
})
auto.shard_tensor(
self.linear1.bias,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
out = self.norm(input)
auto.shard_tensor(
out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
out = self.linear0(out)
auto.shard_tensor(
out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, 0]
})
out = F.gelu(out, approximate=True)
auto.shard_tensor(
out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, 0]
})
out = self.linear1(out)
auto.shard_tensor(
out,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
return out
def get_program():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
# fleet.init(is_collective=True, strategy=dist_strategy)
train_program = static.Program()
start_program = static.Program()
with fluid.program_guard(train_program, start_program):
# loop counter
i = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
auto.shard_tensor(
i,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
# number of loop iterations
loop_len = fluid.layers.fill_constant(
shape=[1], dtype='int64', value=epoch_num)
auto.shard_tensor(
loop_len,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
# input
input = static.data(
name="input",
shape=[batch_size, sequence_len, hidden_size],
dtype='float32')
label = static.data(
name="label", shape=[batch_size, sequence_len, 1], dtype='float32')
data_holder = [input, label]
# dataloader
dataloader = paddle.io.DataLoader.from_generator(
feed_list=data_holder, capacity=4 * batch_size, iterable=False)
dataloader.set_batch_generator(
batch_generator_creator(), places=paddle.static.cuda_places())
# data dist_attr
auto.shard_tensor(
input,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
auto.shard_tensor(
label,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
mlp_start = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_start(input)
input_array = fluid.layers.array_write(pred, i)
auto.shard_tensor(
input_array,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
cond = fluid.layers.less_than(x=i, y=loop_len)
auto.shard_tensor(
cond,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
while_op = fluid.layers.While(cond=cond)
with while_op.block():
pre_input = fluid.layers.array_read(array=input_array, i=i)
auto.shard_tensor(
pre_input,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
mlp_while = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
cur_pred = mlp_while(pre_input)
# update the loop condition
i = fluid.layers.increment(x=i, value=1, in_place=True)
fluid.layers.array_write(cur_pred, array=input_array, i=i)
fluid.layers.less_than(x=i, y=loop_len, cond=cond)
end_pred = fluid.layers.array_read(array=input_array, i=i)
auto.shard_tensor(
end_pred,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
mlp_end = MLPLayer(
hidden_size=hidden_size,
intermediate_size=4 * hidden_size,
dropout_ratio=0.1,
initializer_range=0.02)
pred = mlp_end(end_pred)
error_cost = paddle.nn.functional.square_error_cost(pred, label)
auto.shard_tensor(
error_cost,
dist_attr={
"process_mesh": _g_process_mesh,
"dims_mapping": [-1, -1, -1]
})
loss = paddle.mean(error_cost)
auto.shard_tensor(
loss,
dist_attr={"process_mesh": _g_process_mesh,
"dims_mapping": [-1]})
return train_program, start_program, dataloader, i, loss
def completion(train_program, start_program, dist_context):
blocks = train_program.blocks
# completion tensors
for block in blocks:
for op in block.ops:
if op.type == "layer_norm":
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
if tensor_dist_attr:
continue
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "elementwise_sub":
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "matmul_v2":
col = False
for in_name in op.input_arg_names:
if ".w_" not in in_name:
continue
if in_name not in block.vars:
in_var = blocks[0].vars[in_name]
else:
in_var = block.vars[in_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
assert tensor_dist_attr is not None
if tensor_dist_attr.dims_mapping == [-1, 0]:
col = True
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
tensor_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
if tensor_dist_attr:
continue
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
if col:
tensor_dist_attr.dims_mapping = [-1, -1, 0]
else:
tensor_dist_attr.dims_mapping = [-1, -1, -1]
dist_context.set_tensor_dist_attr_for_program(
out_var, tensor_dist_attr)
elif op.type == "while":
out_name = op.desc.output("StepScopes")[0]
out_var = block.vars[out_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = _g_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(out_var,
tensor_dist_attr)
# completion ops
for block in blocks:
for op in block.ops:
op_dist_attr = OperatorDistributedAttribute()
op_dist_attr.process_mesh = _g_process_mesh
if op.type == "create_by_read" or op.type == "create_double_buffer_reader":
for in_name in op.input_arg_names:
op_dist_attr.set_input_dims_mapping(in_name, [])
for out_name in op.output_arg_names:
op_dist_attr.set_output_dims_mapping(out_name, [])
elif op.type == "read":
for in_name in op.input_arg_names:
op_dist_attr.set_output_dims_mapping(in_name, [])
for out_name in op.output_arg_names:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
elif op.type == "while":
for in_name in op.input_arg_names:
in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
if out_name == op.desc.output("StepScopes")[0]:
op_dist_attr.set_output_dims_mapping(out_name, [])
else:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name,
out_dist_attr)
else:
for in_name in op.input_arg_names:
if in_name == "lod_tensor_blocking_queue_0":
continue
if in_name not in block.vars:
in_var = blocks[0].vars[in_name]
else:
in_var = block.vars[in_name]
in_dist_attr = dist_context.get_tensor_dist_attr_for_program(
in_var)
op_dist_attr.set_input_dist_attr(in_name, in_dist_attr)
for out_name in op.output_arg_names:
if out_name not in block.vars:
out_var = blocks[0].vars[out_name]
else:
out_var = block.vars[out_name]
out_dist_attr = dist_context.get_tensor_dist_attr_for_program(
out_var)
op_dist_attr.set_output_dist_attr(out_name, out_dist_attr)
if op.type == "matmul_v2":
op_dist_attr.impl_type = "matmul_v2"
for in_name in op_dist_attr.inputs_dist_attrs.keys():
in_dist_attr = op_dist_attr.inputs_dist_attrs[in_name]
if ".w_" in in_name and in_dist_attr.dims_mapping[-1] == 0:
op_dist_attr.impl_idx = 0
else:
op_dist_attr.impl_idx = 1
else:
op_dist_attr.impl_type = "default"
op_dist_attr.impl_idx = 0
dist_context.set_op_dist_attr_for_program(op, op_dist_attr)
make_data_unshard(train_program, start_program, dist_context)
return train_program, start_program
def partition(train_program, start_program, dist_context):
# optimizer = paddle.optimizer.SGD(learning_rate=0.00001)
rank = paddle.distributed.get_rank()
partitioner = Partitioner(dist_context, rank)
dist_main_prog, dist_startup_prog, _ = partitioner.partition(
train_program, start_program, [])
return dist_main_prog, dist_startup_prog
class TestMLP(unittest.TestCase):
def test_partitioner(self):
train_program, start_program, dataloader, i, loss = get_program()
dist_context = get_default_distributed_context()
train_program, start_program = completion(train_program, start_program,
dist_context)
dist_context.block_state.parse_forward_blocks(train_program)
dist_main_prog, dist_startup_prog = partition(
train_program, start_program, dist_context)
global_block_ops = dist_main_prog.blocks[0].ops
global_block_ops = [op.type for op in global_block_ops]
sub_block_ops = dist_main_prog.blocks[1].ops
sub_block_ops = [op.type for op in sub_block_ops]
self.assertTrue("c_allreduce_sum" in global_block_ops)
self.assertTrue("c_allreduce_sum" in sub_block_ops)
if __name__ == "__main__":
unittest.main()
@@ -158,6 +158,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
......
@@ -47,9 +47,7 @@ def get_dist_prog(train_program,
     complete_train_program = completer.complete_forward_annotation(
         train_program
     ) if complete_train_program is None else complete_train_program
+    dist_context.block_state.parse_forward_blocks(complete_train_program)

-    # parallelizer._apply_serial_forward_pass(complete_train_program,
-    #                                         startup_program)

     params_grads = parallelizer._generate_backward(
         complete_train_program,
@@ -95,9 +93,9 @@ class TestDistributedTensor(unittest.TestCase):
         rank_id = 1
         train_program = paddle.static.Program()
         startup_program = paddle.static.Program()
-        dist_main_prog, dist_startup_prog, _ = get_dist_prog(
-            train_program, startup_program, dist_context, rank_id,
-            complete_train_program)
+        dist_context = DistributedContext()
+        dist_main_prog, dist_startup_prog, complete_train_program = get_dist_prog(
+            train_program, startup_program, dist_context, rank_id, None)
         dist_context.dist_main_programs[rank_id] = dist_main_prog
         dist_context.dist_startup_programs[rank_id] = dist_startup_prog
         name = "layer_norm_1.tmp_2"
......
@@ -486,7 +486,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
......
@@ -53,6 +53,7 @@ def get_programs(annotated_func):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)

     rank_id = 3
     dist_strategy = fleet.DistributedStrategy()
......
@@ -885,6 +885,7 @@ class TestGPTPartitioner(unittest.TestCase):
         completer = Completer(dist_context)
         complete_train_program = completer.complete_forward_annotation(
             train_program)
+        dist_context.block_state.parse_forward_blocks(complete_train_program)

         # serial backward pass
         params_grads = parallelizer._generate_backward(
......
@@ -160,7 +160,7 @@ def get_dist_prog(train_program,
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)

     if change_process_mesh:
         global PP_MESH_1
         dist_context.get_tensor_dist_attr_for_program(
......
@@ -120,7 +120,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
......
@@ -136,7 +136,7 @@ def get_dist_prog(train_program, startup_program, dist_context, rank_id):
     completer = Completer(dist_context)
     complete_train_program = completer.complete_forward_annotation(
         train_program)
+    dist_context.block_state.parse_forward_blocks(complete_train_program)
     params_grads = parallelizer._generate_backward(
         complete_train_program,
         startup_program,
@@ -269,6 +269,7 @@ class TestMLPReshard(unittest.TestCase):
         completer = Completer(dist_context)
         complete_train_program = completer.complete_forward_annotation(
             train_program)
+        dist_context.block_state.parse_forward_blocks(complete_train_program)
         partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads = partitioner.partition(
             complete_train_program, startup_program, [])
         reshard(partitioned_main_prog, partitioned_startup_prog, rank_id,
......