Unverified commit 07f33da9, authored by JZ-LIANG, committed by GitHub

[Auto parallel] Accelerate procedure of partitioning and generating dist graphs (#44224)

* avoid sync with cpp in partition op

* delay eval & predict mode

* bugfix for gradient merge pass
Parent daa6cb92
......@@ -85,6 +85,11 @@ class Engine:
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {}
self._mode_init_states = {
"train": False,
"eval": False,
"predict": False
}
self._dygraph_mode = False
def prepare(self,
......@@ -101,6 +106,7 @@ class Engine:
" or `paddle.fluid.optimizer.Optimizer`."
)
self._optimizer = optimizer
self._all_ranks = all_ranks
if loss and not isinstance(loss,
paddle.nn.Layer) and not callable(loss):
......@@ -116,22 +122,23 @@ class Engine:
metric.__class__.__name__)
self._metrics = to_list(metrics)
self._gradient_scale = gradient_scale
self._planned_mode = None
self._modes = ['train', 'eval', 'predict']
self._prepare_single_mode("train")
# Build program and do auto parallel process
for mode in self._modes:
# Build forward program
self._build(mode)
def _prepare_single_mode(self, mode):
self._modes = [mode]
self._build(self._modes[0])
# Do auto parallel process
for mode in self._modes:
# Do the planning process
self._plan(mode)
for mode in self._modes:
# Do the parallel process
self._parallel(mode, all_ranks)
self._parallel(mode, self._all_ranks)
# Init comm and startup program
self._initialize(mode)
self._mode_init_states[mode] = True
def _build(self, mode):
......@@ -432,6 +439,12 @@ class Engine:
return_numpy=True):
# TODO: callbacks
# TODO: evaluate after training
if not self._mode_init_states['train']:
raise Exception(
"train program is not initialized yet, please call engine.prepare() before calling fit() funtion."
)
self.mode = 'train'
assert self.mode in self._dist_main_progs, \
"train model is not ready, please call `engine.prepare()` first."
......@@ -467,6 +480,9 @@ class Engine:
use_program_cache=False,
return_numpy=True):
self.mode = 'eval'
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
assert self.mode in self._dist_main_progs, \
"eval model is not ready, please call `engine.prepare()` first."
eval_dataloader = self._create_dataloader(eval_data, batch_size)
......@@ -509,6 +525,9 @@ class Engine:
use_program_cache=False,
return_numpy=True):
self.mode = 'predict'
if not self._mode_init_states[self.mode]:
self._prepare_single_mode(self.mode)
assert self.mode in self._dist_main_progs, \
"predict model is not ready, please call `engine.prepare()` first."
test_dataloader = self._create_dataloader(test_data, batch_size)
......
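The Engine changes above implement the "delay eval & predict mode" item from the commit message: `prepare()` now eagerly builds, plans, parallelizes and initializes only the `train` program, and `evaluate()`/`predict()` prepare their own programs on first use. Below is a condensed, self-contained sketch of that flow; the method and attribute names follow the diff, while the phase bodies are stubbed out and the details of the real Engine class are assumed.

```python
# Condensed sketch of the deferred-preparation flow added to Engine; the
# build/plan/parallel/initialize phases are stubs standing in for the real ones.
class Engine:
    def __init__(self):
        self._mode_init_states = {"train": False, "eval": False, "predict": False}
        self._all_ranks = True
        self.mode = None

    def _build(self, mode): pass
    def _plan(self, mode): pass
    def _parallel(self, mode, all_ranks): pass
    def _initialize(self, mode): pass

    def prepare(self, all_ranks=True):
        self._all_ranks = all_ranks
        # Eagerly prepare only the train program; eval/predict are delayed.
        self._prepare_single_mode("train")

    def _prepare_single_mode(self, mode):
        self._build(mode)                      # build forward program
        self._plan(mode)                       # planning process
        self._parallel(mode, self._all_ranks)  # partition + reshard
        self._initialize(mode)                 # init comm & startup program
        self._mode_init_states[mode] = True

    def evaluate(self, eval_data=None):
        self.mode = 'eval'
        # Prepare the eval program lazily, on the first evaluate() call.
        if not self._mode_init_states[self.mode]:
            self._prepare_single_mode(self.mode)
```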
......@@ -113,12 +113,11 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
filter_vars.append(varname)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars)
main_block._sync_with_cpp()
# sync result
group = new_process_group(world_process_group.ranks)
......@@ -155,7 +154,6 @@ class DistributedCheckFiniteAndUnscaleImpl(DistributedOperatorImpl):
"out_dtype": inf_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
main_block._sync_with_cpp()
for op in [cast_op1, allreduce_op, cast_op2]:
new_op_dist_attr = OperatorDistributedAttribute()
......
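The hunk above shows the mechanical change behind "avoid sync with cpp in partition op", which repeats across every distributed operator implementation in this PR: instead of appending a raw C++ op desc and then calling `main_block._sync_with_cpp()` (which rescans the whole block on each replicated op), the op is appended through the Python-level `Block.append_op(type='nop')` and its desc is overwritten, so the Python op list stays consistent without an explicit sync. A minimal sketch of the pattern, factored into a hypothetical helper; only calls that appear in the diff are relied upon, and the `set_dist_op_desc_original_id` import path is an assumption based on the repository layout of that time.

```python
# Hypothetical helper illustrating the op-replication pattern used in this PR.
from paddle.distributed.auto_parallel.utils import set_dist_op_desc_original_id


def replicate_op_in_dist_block(main_block, src_op, ctx, kwargs):
    # Old:  dist_op_desc = main_block.desc.append_op() ... main_block._sync_with_cpp()
    # New:  append a Python-tracked 'nop' op and reuse its underlying desc,
    #       so no block-wide _sync_with_cpp() is needed afterwards.
    dist_op_desc = main_block.append_op(type='nop').desc
    dist_op_desc.copy_from(src_op.desc)
    set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
    for input_name in src_op.desc.input_names():
        dist_op_desc.set_input(input_name, kwargs[input_name])
    for output_name in src_op.desc.output_names():
        dist_op_desc.set_output(output_name, kwargs[output_name])
    return dist_op_desc
```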
......@@ -363,7 +363,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
......@@ -371,8 +371,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# data parallel synchronization for primitive operators
from paddle.incubate.autograd import prim_enabled
if prim_enabled():
......@@ -431,8 +429,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
op_attr.set_input_dims_mapping(param.name, dims_mapping)
ctx.set_op_dist_attr_for_program(new_op, op_attr)
startup_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -461,7 +457,7 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
......@@ -470,8 +466,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
for output_name in backward_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# check if need gradient allreduce
# if there is a non-gradient & non-parameter input and its batch dimension is split,
# we need to insert gradient allreduce for the gradient of the parameter in its output
......@@ -552,8 +546,6 @@ class DistributedDefaultImpl0(DistributedOperatorImpl):
dims_mapping)
ctx.set_op_dist_attr_for_program(op, op_attr)
main_block._sync_with_cpp()
register_distributed_operator_impl(
"default", DistributedDefaultImpl0("replicate_parallel"))
......@@ -312,7 +312,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -412,8 +411,7 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
set_comm_op_dist_attr_for_program(c_identity_op, dist_attr.process_mesh,
out_grad_dist_attr, ctx)
main_block._sync_with_cpp()
c_embedding_grad_op_desc = main_block.desc.append_op()
c_embedding_grad_op_desc = main_block.append_op(type='nop').desc
c_embedding_grad_op_desc.set_type("c_embedding_grad")
c_embedding_grad_op_desc.set_input('Ids', [Ids_var.name])
c_embedding_grad_op_desc.set_input('W', [Weight_var.name])
......@@ -422,7 +420,6 @@ class DistributedEmbeddingImpl(DistributedOperatorImpl):
c_embedding_grad_op_desc.set_output('W@GRAD', [Weight_grad.name])
c_embedding_grad_op_desc._set_attr('start_index', relative_idx)
c_embedding_grad_op_desc._set_attr(OP_ROLE_KEY, OpRole.Backward)
main_block._sync_with_cpp()
c_embedding_grad_op = main_block.ops[-1]
assert c_embedding_grad_op.type == "c_embedding_grad"
......
......@@ -118,7 +118,6 @@ class DistributedFillConstantBatchSizeLikeImpl0(DistributedOperatorImpl):
shape_list[idx] = shape_list[idx] // process_mesh_shape[axis]
op._set_attr("shape", shape_list)
main_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
......
......@@ -38,7 +38,7 @@ from .dist_default import DistributedDefaultImpl0
def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
dist_op_desc = block.desc.append_op()
dist_op_desc = block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
......@@ -48,7 +48,6 @@ def copy_op_with_new_input_output(ctx, block, src_op, **kwargs):
assert input_name in kwargs
dist_op_desc.set_output(output_name, kwargs[output_name])
block._sync_with_cpp()
return dist_op_desc
......@@ -387,8 +386,6 @@ def _right_operand_parameter_matmul_backward(ctx, *args, **kwargs):
matmul_op_desc = copy_op_with_new_input_output(ctx, main_block,
backward_op, **kwargs)
main_block._sync_with_cpp()
# check if need gradient allreduce
need_gradient_allreduce = False
......@@ -468,7 +465,6 @@ def _init_param_sync(Weight_var, dist_op_context, startup_block, ctx, rank_id):
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
startup_block._sync_with_cpp()
class DistributedMatmul(DistributedOperatorImplContainer):
......
......@@ -248,7 +248,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
# rename input
kwargs['X'] = [allgather_out.name]
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
......@@ -260,8 +260,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
allgather_out.name, allgather_out_dist_attr.dims_mapping)
ctx.set_op_dist_attr_for_program(pnorm_op, op_dist_attr)
main_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
......@@ -305,7 +303,7 @@ class DistributedPNormImpl(DistributedOperatorImpl):
new_X_var_dist_attr = ctx.get_tensor_dist_attr_for_program(new_X_var)
ctx.set_tensor_dist_attr_for_program(new_X_grad, new_X_var_dist_attr)
# replicate op in dist program with new kwargs
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
# Refer to the related dist op
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
......@@ -319,7 +317,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
op_dist_attr.set_output_dims_mapping(new_X_grad.name,
new_X_var_dist_attr.dims_mapping)
ctx.set_op_dist_attr_for_program(p_norm_grad_op, op_dist_attr)
main_block._sync_with_cpp()
# 2. insert slice op
process_mesh_shape = op_dist_attr.process_mesh.topology
......@@ -359,7 +356,6 @@ class DistributedPNormImpl(DistributedOperatorImpl):
slice_op_dist_attr.set_output_dims_mapping(X_grad_var.name,
X_grad_var_dims_mapping)
ctx.set_op_dist_attr_for_program(slice_op, slice_op_dist_attr)
main_block._sync_with_cpp()
register_distributed_operator_impl("p_norm",
......
......@@ -109,14 +109,13 @@ class DistributedReducePrimtiveImpl0(DistributedOperatorImpl):
output_name)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(dist_op_desc, src_op.desc, ctx)
for input_name in src_op.desc.input_names():
dist_op_desc.set_input(input_name, kwargs[input_name])
for output_name in src_op.desc.output_names():
dist_op_desc.set_output(output_name, kwargs[output_name])
main_block._sync_with_cpp()
# batch dimension synchronization
var_name = src_op.output_arg_names[0]
......
......@@ -177,7 +177,7 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
idx] = shape_list[idx] // process_mesh_shape[axis]
# create op
new_op_desc = main_block.desc.append_op()
new_op_desc = main_block.append_op(type='nop').desc
new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
......@@ -187,8 +187,6 @@ class DistributedReshapeImpl0(DistributedOperatorImpl):
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
main_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
......@@ -335,7 +333,7 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
idx] = shape_list[idx] // process_mesh_shape[axis]
# create op
new_op_desc = main_block.desc.append_op()
new_op_desc = main_block.append_op(type='nop').desc
new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
......@@ -345,8 +343,6 @@ class DistributedReshapeImpl1(DistributedOperatorImpl):
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
main_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
......@@ -486,7 +482,7 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
idx] = shape_list[idx] // process_mesh_shape[axis]
# create op
new_op_desc = main_block.desc.append_op()
new_op_desc = main_block.append_op(type='nop').desc
new_op_desc.copy_from(src_op.desc)
set_dist_op_desc_original_id(new_op_desc, src_op.desc, ctx)
new_op_desc.set_input('ShapeTensor', ShapeTensor_var_list)
......@@ -496,8 +492,6 @@ class DistributedReshapeImpl2(DistributedOperatorImpl):
new_op_desc.set_output('Out', [Out_var.name])
new_op_desc._set_attr('shape', shape_list)
main_block._sync_with_cpp()
@staticmethod
def backward(ctx, *args, **kwargs):
DistributedDefaultImpl0.backward(ctx, *args, **kwargs)
......
......@@ -127,12 +127,11 @@ class DistributedUpdateLossScalingImpl(DistributedOperatorImpl):
filter_vars.append(varname)
# replicate op in dist program
dist_op_desc = main_block.desc.append_op()
dist_op_desc = main_block.append_op(type='nop').desc
dist_op_desc.copy_from(backward_op.desc)
set_dist_op_desc_original_id(dist_op_desc, backward_op.desc, ctx)
dist_op_desc.set_input('X', filter_vars)
dist_op_desc.set_output('Out', filter_vars)
main_block._sync_with_cpp()
register_distributed_operator_impl(
......
......@@ -13,6 +13,8 @@
# limitations under the License.
import copy
import time
import logging
from collections import defaultdict
import paddle
......@@ -20,6 +22,7 @@ from paddle.fluid import program_guard
from paddle.fluid.backward import append_backward
from paddle.fluid.framework import _non_static_mode
from paddle.distributed.passes import new_pass
from paddle.distributed.utils import get_logger
from .reshard import Resharder
from .partitioner import Partitioner
......@@ -41,6 +44,7 @@ class Parallelizer:
assert self._dist_context._is_initialized
self._pass_context = self._dist_context.pass_context
self._strategy = self._dist_context.strategy
self._logger = get_logger(logging.INFO)
def parallel_all(self):
world_process_group = get_world_process_group()
......@@ -61,38 +65,65 @@ class Parallelizer:
serial_startup_program,
serial_loss)
# Apply pre optimization passes
time0 = time.time()
self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss,
serial_optimizer, params_grads)
self._logger.info(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
# Do logical partition
time0 = time.time()
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, params_grads)
self._logger.info(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
# Generate optimizer
time0 = time.time()
self._generate_optimizer(dist_main_prog, dist_startup_prog,
serial_optimizer, dist_params_grads)
self._logger.info(
"within parallel optimizer time: {}, mode {}".format(
time.time() - time0, self._mode))
# Do reshard process
time0 = time.time()
set_grad_var_shape(dist_main_prog, self._dist_context)
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, dist_params_grads)
resharder.reshard()
self._logger.info(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Apply post optimization passes
time0 = time.time()
self._apply_post_optimization(dist_main_prog, dist_startup_prog,
rank, dist_params_grads)
self._logger.info(
"within parallel apply_post_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
else:
# Apply pre optimization passes
# self._apply_pre_optimization(serial_main_program,
# serial_startup_program, None, None,
# None)
# Do logical partition
time0 = time.time()
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
serial_main_program, serial_startup_program, [])
# Do reshard process
self._logger.info(
"within parallel partitioner time: {}, mode {}".format(
time.time() - time0, self._mode))
time0 = time.time()
resharder = Resharder(dist_main_prog, dist_startup_prog, rank,
self._dist_context, [], 1)
resharder.reshard()
self._logger.info(
"within parallel reshard time: {}, mode {}".format(
time.time() - time0, self._mode))
# Clone program for test
if self._mode != 'train':
dist_main_prog = dist_main_prog.clone(for_test=True)
......
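The Parallelizer changes mainly add wall-clock instrumentation around each phase (pre-optimization passes, partition, optimizer generation, reshard, post-optimization passes), each wrapped in a `time0 = time.time()` / `self._logger.info(...)` pair. The sketch below is a hypothetical refactoring, not what the PR does: a small context manager that expresses the repeated measurement once.

```python
import logging
import time
from contextlib import contextmanager


@contextmanager
def timed_phase(logger, phase, mode):
    """Log the wall-clock time of one parallelization phase."""
    start = time.time()
    try:
        yield
    finally:
        logger.info("within parallel %s time: %s, mode %s",
                    phase, time.time() - start, mode)


# Usage sketch (inside Parallelizer.parallel, assuming self._logger/self._mode):
#     with timed_phase(self._logger, "partitioner", self._mode):
#         dist_main_prog, dist_startup_prog, dist_params_grads = \
#             partitioner.partition(serial_main_program,
#                                   serial_startup_program, params_grads)
```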
......@@ -58,7 +58,7 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
def _remove_op_role_var(param, grad):
op_maker = core.op_proto_and_checker_maker
op = grad.op
if op.has_attr(op_maker.kOpRoleVarAttrName()):
if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
op._remove_attr(op_maker.kOpRoleVarAttrName())
......
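The last hunk is the "bugfix for gradient merge pass": `grad.op` can be `None` when a gradient variable has no producing op in the current block, so the attribute check now guards against that before touching the op. The fixed helper, with a comment on the guard added here for clarity (the `core` import is assumed to be the usual `paddle.fluid.core` of that era), reads roughly as:

```python
# Sketch of the fixed helper from the gradient-merge pass; the None guard is
# the actual change, the surrounding lines follow the diff context.
from paddle.fluid import core


def _remove_op_role_var(param, grad):
    op_maker = core.op_proto_and_checker_maker
    op = grad.op
    # grad.op may be None (a gradient var without a producing op in this
    # block), so check it before reading or removing its attributes.
    if op and op.has_attr(op_maker.kOpRoleVarAttrName()):
        op._remove_attr(op_maker.kOpRoleVarAttrName())
```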