未验证 提交 07f33da9 编写于 作者: J JZ-LIANG 提交者: GitHub

[Auto parallel] Accelerate procedure of partitioning and generating dist graphs (#44224)

* avoid sync with cpp in partition op

* delay eval & predict mode

* bugfix for gradient merge pass
上级 daa6cb92
...@@ -85,6 +85,11 @@ class Engine: ...@@ -85,6 +85,11 @@ class Engine:
self._feed_vars = {} self._feed_vars = {}
self._fetch_vars = {} self._fetch_vars = {}
self._planners = {} self._planners = {}
self._mode_init_states = {
"train": False,
"eval": False,
"predict": False
}
self._dygraph_mode = False self._dygraph_mode = False
def prepare(self, def prepare(self,
...@@ -101,6 +106,7 @@ class Engine: ...@@ -101,6 +106,7 @@ class Engine:
" or `paddle.fluid.optimizer.Optimizer`." " or `paddle.fluid.optimizer.Optimizer`."
) )
self._optimizer = optimizer self._optimizer = optimizer
self._all_ranks = all_ranks
if loss and not isinstance(loss, if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss): paddle.nn.Layer) and not callable(loss):
...@@ -116,22 +122,23 @@ class Engine: ...@@ -116,22 +122,23 @@ class Engine:
metric.__class__.__name__) metric.__class__.__name__)
self._metrics = to_list(metrics) self._metrics = to_list(metrics)
self._gradient_scale = gradient_scale self._gradient_scale = gradient_scale
self._planned_mode = None self._planned_mode = None
self._modes = ['train', 'eval', 'predict'] self._prepare_single_mode("train")
# Build program and do auto parallel process def _prepare_single_mode(self, mode):
for mode in self._modes: self._modes = [mode]
# Build forward program self._build(self._modes[0])
self._build(mode) # Do auto parallel process
for mode in self._modes: for mode in self._modes:
# Do the planning process # Do the planning process
self._plan(mode) self._plan(mode)
for mode in self._modes: for mode in self._modes:
# Do the parallel process # Do the parallel process
self._parallel(mode, all_ranks) self._parallel(mode, self._all_ranks)
# Init comm and startup program # Init comm and startup program
self._initialize(mode) self._initialize(mode)
self._mode_init_states[mode] = True
def _build(self, mode): def _build(self, mode):
...@@ -432,6 +439,12 @@ class Engine: ...@@ -432,6 +439,12 @@ class Engine:
return_numpy=True): return_numpy=True):
# TODO: callbacks # TODO: callbacks
# TODO: evaluate after training # TODO: evaluate after training
if not self._mode_init_states['train']:
raise Exception(
"train program is not initialized yet, please call engine.prepare() before calling fit() funtion."
)
self.mode = 'train' self.mode = 'train'
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare()` first." "train model is not ready, please call `engine.prepare()` first."
...@@ -467,6 +480,9 @@ class Engine: ...@@ -467,6 +480,9 @@ class Engine:
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
self.mode = 'eval' self.mode = 'eval'
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first." "eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size) eval_dataloader = self._create_dataloader(eval_data, batch_size)
...@@ -509,6 +525,9 @@ class Engine: ...@@ -509,6 +525,9 @@ class Engine:
use_program_cache=False, use_program_cache=False,
return_numpy=True): return_numpy=True):
self.mode = 'predict' self.mode = 'predict'
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
assert self.mode in self._dist_main_progs, \ assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first." "predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size) test_dataloader = self._create_dataloader(test_data, batch_size)
......
...@@ -113,12 +113,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -113,12 +113,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
filter_vars.append(varname) filter_vars.append(varname)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars) dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars) dist_op_desc.set_output('Out', filter_vars)
main_block._sync_with_cpp()
# sync result # sync result
group = new_process_group(world_process_group.ranks) group = new_process_group(world_process_group.ranks)
...@@ -155,7 +154,6 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl): ...@@ -155,7 +154,6 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
"out_dtype": inf_var.dtype, "out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Optimize OP_ROLE_KEY: OpRole.Optimize
}) })
main_block._sync_with_cpp()
for op in [cast_op1, allreduce_op, cast_op2]: for op in [cast_op1, allreduce_op, cast_op2]:
new_op_dist_attr = OperatorDistributedAttribute() new_op_dist_attr = OperatorDistributedAttribute()
......
...@@ -363,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -363,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
output_name) output_name)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
...@@ -371,8 +371,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -371,8 +371,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# data parallel synchronization for primtive operators # data parallel synchronization for primtive operators
from paddle.incubate.autograd import prim_enabled from paddle.incubate.autograd import prim_enabled
if prim_enabled(): if prim_enabled():
...@@ -431,8 +429,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -431,8 +429,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_attr.set_input_dims_mapping(param.name, dims_mapping) op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr) ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -461,7 +457,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -461,7 +457,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
output_name) output_name)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op # Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
...@@ -470,8 +466,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -470,8 +466,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in backward_op.desc.output_names(): for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# check if need gradient allreduce # check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is splited, # if there is a non-gradient & non-parameter input and its batch dimension is splited,
# we need insert gradient allreduce for the gradient of parameter in its output # we need insert gradient allreduce for the gradient of parameter in its output
...@@ -552,8 +546,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl): ...@@ -552,8 +546,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
dims_mapping) dims_mapping)
ctx.set_op_dist_attr_for_program(op, op_attr) ctx.set_op_dist_attr_for_program(op, op_attr)
main_block._sync_with_cpp()
register_distributed_operator_impl( register_distributed_operator_impl(
"default", DistributedDefaultImpl0("replicate_parallel")) "default", DistributedDefaultImpl0("replicate_parallel"))
...@@ -312,7 +312,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -312,7 +312,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True, 'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
startup_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -412,8 +411,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -412,8 +411,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh, set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
out_grad_dist_attr, ctx) out_grad_dist_attr, ctx)
main_block._sync_with_cpp() c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
c_embedding_grad_op_desc = main_block.desc.append_op()
c_embedding_grad_op_desc.set_type("c_embedding_grad") c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name]) c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name]) c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
...@@ -422,7 +420,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl): ...@@ -422,7 +420,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name]) c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx) c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward) c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
main_block._sync_with_cpp()
c_embedding_grad_op = main_block.ops[-1] c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad" assert c_embedding_grad_op.type == "c_embedding_grad"
......
...@@ -118,7 +118,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl): ...@@ -118,7 +118,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis] shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
op._set_attr("shape", shape_list) op._set_attr("shape", shape_list)
main_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
......
...@@ -38,7 +38,7 @@ from .dist_default import DistributedDefaultImpl0 ...@@ -38,7 +38,7 @@ from .dist_default import DistributedDefaultImpl0
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
dist_op_desc = block.desc.append_op() dist_op_desc = block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
...@@ -48,7 +48,6 @@ def copy_op_with_new_input_output(ctx, block, src_op, **kwargs): ...@@ -48,7 +48,6 @@ def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
assert input_name in kwargs assert input_name in kwargs
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
block._sync_with_cpp()
return dist_op_desc return dist_op_desc
...@@ -387,8 +386,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs): ...@@ -387,8 +386,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
matmul_op_desc = copy_op_with_new_input_output(ctx, main_block, matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
backward_op, **kwargs) backward_op, **kwargs)
main_block._sync_with_cpp()
# check if need gradient allreduce # check if need gradient allreduce
need_gradient_allreduce = False need_gradient_allreduce = False
...@@ -468,7 +465,6 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id): ...@@ -468,7 +465,6 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
'use_calc_stream': True, 'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward OP_ROLE_KEY: OpRole.Forward
}) })
startup_block._sync_with_cpp()
class DistributedMatmul(DistributedOperatorImplContainer): class DistributedMatmul(DistributedOperatorImplContainer):
......
...@@ -248,7 +248,7 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -248,7 +248,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
# rename input # rename input
kwargs['X'] = [allgather_out.name] kwargs['X'] = [allgather_out.name]
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
...@@ -260,8 +260,6 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -260,8 +260,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
allgather_out.name, allgather_out_dist_attr.dims_mapping) allgather_out.name, allgather_out_dist_attr.dims_mapping)
ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr) ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr)
main_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
...@@ -305,7 +303,7 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -305,7 +303,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var) new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var)
ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr) ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr)
# replicate op in dist program with new kwargs # replicate op in dist program with new kwargs
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op # Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
...@@ -319,7 +317,6 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -319,7 +317,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
op_dist_attr.set_output_dims_mapping(new_X_grad.name, op_dist_attr.set_output_dims_mapping(new_X_grad.name,
new_X_var_dist_attr.dims_mapping) new_X_var_dist_attr.dims_mapping)
ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr) ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
main_block._sync_with_cpp()
# 2. insert slice op # 2. insert slice op
process_mesh_shape = op_dist_attr.process_mesh.topology process_mesh_shape = op_dist_attr.process_mesh.topology
...@@ -359,7 +356,6 @@ class DistributedPNormImpl(DistributedOperatorImpl): ...@@ -359,7 +356,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
slice_op_dist_attr.set_output_dims_mapping(X_grad_var.name, slice_op_dist_attr.set_output_dims_mapping(X_grad_var.name,
X_grad_var_dims_mapping) X_grad_var_dims_mapping)
ctx.set_op_dist_attr_for_program(slice_op, slice_op_dist_attr) ctx.set_op_dist_attr_for_program(slice_op, slice_op_dist_attr)
main_block._sync_with_cpp()
register_distributed_operator_impl("p_norm", register_distributed_operator_impl("p_norm",
......
...@@ -109,14 +109,13 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl): ...@@ -109,14 +109,13 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
output_name) output_name)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc) dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names(): for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name]) dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names(): for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name]) dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# batch dimension synchronization # batch dimension synchronization
var_name = src_op.output_arg_names[0] var_name = src_op.output_arg_names[0]
......
...@@ -177,7 +177,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -177,7 +177,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
idx] = shape_list[idx] // process_mesh_shape[axis] idx] = shape_list[idx] // process_mesh_shape[axis]
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.append_op(type='nop').desc
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
...@@ -187,8 +187,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl): ...@@ -187,8 +187,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
new_op_desc.set_output('Out', [Out_var.name]) new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list) new_op_desc._set_attr('shape', shape_list)
main_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs) DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
...@@ -335,7 +333,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -335,7 +333,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
idx] = shape_list[idx] // process_mesh_shape[axis] idx] = shape_list[idx] // process_mesh_shape[axis]
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.append_op(type='nop').desc
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
...@@ -345,8 +343,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl): ...@@ -345,8 +343,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
new_op_desc.set_output('Out', [Out_var.name]) new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list) new_op_desc._set_attr('shape', shape_list)
main_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs) DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
...@@ -486,7 +482,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -486,7 +482,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
idx] = shape_list[idx] // process_mesh_shape[axis] idx] = shape_list[idx] // process_mesh_shape[axis]
# create op # create op
new_op_desc = main_block.desc.append_op() new_op_desc = main_block.append_op(type='nop').desc
new_op_desc.copy_from(src_op.desc) new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx) set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list) new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
...@@ -496,8 +492,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl): ...@@ -496,8 +492,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
new_op_desc.set_output('Out', [Out_var.name]) new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list) new_op_desc._set_attr('shape', shape_list)
main_block._sync_with_cpp()
@staticmethod @staticmethod
def backward(ctx, *args, **kwargs): def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs) DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
......
...@@ -127,12 +127,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl): ...@@ -127,12 +127,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
filter_vars.append(varname) filter_vars.append(varname)
# replicate op in dist program # replicate op in dist program
dist_op_desc = main_block.desc.append_op() dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc) dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx) set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars) dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars) dist_op_desc.set_output('Out', filter_vars)
main_block._sync_with_cpp()
register_distributed_operator_impl( register_distributed_operator_impl(
......
...@@ -13,6 +13,8 @@ ...@@ -13,6 +13,8 @@
# limitations under the License. # limitations under the License.
import copy import copy
import time
import logging
from collections import defaultdict from collections import defaultdict
import paddle import paddle
...@@ -20,6 +22,7 @@ from paddle.fluid import program_guard ...@@ -20,6 +22,7 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode from paddle.fluid.framework import _non_static_mode
from paddle.distributed.passes import new_pass from paddle.distributed.passes import new_pass
from paddle.distributed.utils import get_logger
from .reshard import Resharder from .reshard import Resharder
from .partitioner import Partitioner from .partitioner import Partitioner
...@@ -41,6 +44,7 @@ class Parallelizer: ...@@ -41,6 +44,7 @@ class Parallelizer:
assert self._dist_context._is_initialized assert self._dist_context._is_initialized
self._pass_context = self._dist_context.pass_context self._pass_context = self._dist_context.pass_context
self._strategy = self._dist_context.strategy self._strategy = self._dist_context.strategy
self._logger = get_logger(logging.INFO)
def parallel_all(self): def parallel_all(self):
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
...@@ -61,38 +65,65 @@ class Parallelizer: ...@@ -61,38 +65,65 @@ class Parallelizer:
serial_startup_program, serial_startup_program,
serial_loss) serial_loss)
# Apply pre optimization passes # Apply pre optimization passes
time0 = time.time()
self._apply_pre_optimization(serial_main_program, self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss, serial_startup_program, serial_loss,
serial_optimizer, params_grads) serial_optimizer, params_grads)
self._logger.info(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
# Do logical partition # Do logical partition
time0 = time.time()
partitioner = Partitioner(self._dist_context, rank) partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, params_grads) serial_main_program, serial_startup_program, params_grads)
self._logger.info(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
# Generate optimizer # Generate optimizer
time0 = time.time()
self._generate_optimizer(dist_main_prog, dist_startup_prog, self._generate_optimizer(dist_main_prog, dist_startup_prog,
serial_optimizer, dist_params_grads) serial_optimizer, dist_params_grads)
self._logger.info(
"within parallel optimizer time: {}, mode {}".format(
time.time() - time0, self._mode))
# Do reshard process # Do reshard process
time0 = time.time()
set_grad_var_shape(dist_main_prog, self._dist_context) set_grad_var_shape(dist_main_prog, self._dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank, resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads) self._dist_context, dist_params_grads)
resharder.reshard() resharder.reshard()
self._logger.info(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Apply post optimization passes # Apply post optimization passes
time0 = time.time()
self._apply_post_optimization(dist_main_prog, dist_startup_prog, self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads) rank, dist_params_grads)
self._logger.info(
"within parallel apply_post_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
else: else:
# Apply pre optimization passes # Apply pre optimization passes
# self._apply_pre_optimization(serial_main_program, # self._apply_pre_optimization(serial_main_program,
# serial_startup_program, None, None, # serial_startup_program, None, None,
# None) # None)
# Do logical partition # Do logical partition
time0 = time.time()
partitioner = Partitioner(self._dist_context, rank) partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition( dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, []) serial_main_program, serial_startup_program, [])
# Do reshard process # Do reshard process
self._logger.info(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
time0 = time.time()
resharder = Resharder(dist_main_prog, dist_startup_prog, rank, resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1) self._dist_context, [], 1)
resharder.reshard() resharder.reshard()
self._logger.info(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Clone program for test # Clone program for test
if self._mode != 'train': if self._mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True) dist_main_prog = dist_main_prog.clone(for_test=True)
......
...@@ -58,7 +58,7 @@ def _remove_and_get_optimizer_op(main_program, dist_context): ...@@ -58,7 +58,7 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
def _remove_op_role_var(param, grad): def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker op_maker = core.op_proto_and_checker_maker
op = grad.op op = grad.op
if op.has_attr(op_maker.kOpRoleVarAttrName()): if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName()) op._remove_attr(op_maker.kOpRoleVarAttrName())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册