Unverified commit cbdba509 authored by Yulong Ao, committed by GitHub

[Auto Parallel] Move some changes or bug fixes from 2.4 to develop (#52721)

* [Auto Parallel] Speedup the completion process

* [Auto Parallel] Skip the property of dist_context when deepcopying

* [Auto Parallel] Remove the unnecessary print

* [Auto Parallel] Move some changes from 2.4 branch to develop

* Update engine.py

* [Auto Parallel] Fix a bug
Parent fd97d7d1
......@@ -102,6 +102,16 @@ set_field_default_config(GRADIENT_MERGE, "enable", False)
set_field_default_config(GRADIENT_MERGE, "k_steps", 1)
set_field_default_config(GRADIENT_MERGE, "avg", True)
#########################################
# pipeline configuration
#########################################
PIPELINE = "pipeline"
set_field_default_config(PIPELINE, "enable", False)
set_field_default_config(PIPELINE, "schedule_mode", "1F1B")
set_field_default_config(PIPELINE, "micro_batch_size", 1)
set_field_default_config(PIPELINE, "accumulate_steps", 1)
set_field_default_config(PIPELINE, "generation_batch_size", 1)
#########################################
# quantization configuration
#########################################
......
......@@ -606,8 +606,8 @@ def get_cost_from_engine(engine, mode):
)
serial_startup_prog = (
engine._serial_startup_progs[mode].clone()
if mode in engine._serial_startup_progs
engine._fwd_dist_contexts[mode]._original_serial_main_program.clone()
if mode in engine._fwd_dist_contexts
else engine._orig_startup_prog.clone()
)
losses = (
......
......@@ -130,6 +130,9 @@ class DistributedContext:
# A flag indicates whether the used parallelism is data parallel
self._data_parallel = False
# record upstream and downstream of cur rank
self._up_down_streams = UpDownStream()
self._json_config = json_config
@property
......@@ -218,6 +221,10 @@ class DistributedContext:
def data_parallel(self):
return self._data_parallel
@property
def up_down_streams(self):
return self._up_down_streams
@data_parallel.setter
def data_parallel(self, dp):
self._data_parallel = dp
......@@ -1220,3 +1227,45 @@ class BlockState:
self.nblock += 1
assert self.nblock == len(program.blocks)
class UpDownStream:
def __init__(self):
self._ups = {}
self._downs = {}
def add_up_stream(self, rank, up_stream):
ups = self._ups.get(rank, None)
if not ups:
self._ups[rank] = [up_stream]
elif up_stream != -1:
ups = list(filter(lambda a: a != -1, ups))
ups.append(up_stream)
self._ups[rank] = ups
def add_down_stream(self, rank, down_stream):
downs = self._downs.get(rank, None)
if not downs:
self._downs[rank] = [down_stream]
elif down_stream != -1:
downs = list(filter(lambda a: a != -1, downs))
downs.append(down_stream)
self._downs[rank] = downs
def add_pair_stream(self, up, down):
self.add_up_stream(up, -1)
self.add_up_stream(down, up)
self.add_down_stream(up, down)
self.add_down_stream(down, -1)
def ups(self, rank):
ups = self._ups.get(rank, None)
if not ups:
return None
return list(set(ups))
def downs(self, rank):
downs = self._downs.get(rank, None)
if not downs:
return None
return list(set(downs))
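A short usage sketch of the helper above (the ranks are illustrative): `add_pair_stream(up, down)` records `up` as an upstream of `down` and `down` as a downstream of `up`, with `-1` acting as a placeholder that is filtered out once a real peer is registered.
streams = UpDownStream()
streams.add_pair_stream(0, 1)   # rank 0 sends to rank 1
streams.add_pair_stream(1, 2)   # rank 1 sends to rank 2

streams.ups(1)    # [0]   rank 1 receives from rank 0
streams.downs(1)  # [2]   rank 1 sends to rank 2
streams.ups(0)    # [-1]  only the placeholder: rank 0 has no real upstream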
......@@ -29,8 +29,6 @@ class DistributedOperator:
def __init__(self, serial_op, dist_attr=None):
self._serial_op = serial_op
if dist_attr is not None and isinstance(dist_attr, OperatorDistAttr):
pass
# TODO: remove this deepcopy after we fix the issue
self._dist_attr = copy.deepcopy(dist_attr)
# self._dist_attr = dist_attr
......@@ -56,21 +54,6 @@ class DistributedOperator:
self._dist_attr = dist_attr
# TODO: Do we really need to write back to serial op?
self._serial_op.dist_attr = dist_attr
# if self._dist_attr is None:
# self._dist_attr = OperatorDistAttr()
# # Create new dist_attr related to current serial_op
# dist_attr = self._filter_dist_attr(dist_attr)
# # Append suffix to mark the inputs or outputs
# if isinstance(dist_attr, dict):
# # Copy the keys since we may add new ones
# for key in list(dist_attr.keys()):
# if isinstance(key, Variable):
# if key.name in self._serial_op.input_arg_names:
# dist_attr[append_op_input_suffix(key.name)] = True
# if key.name in self._serial_op.output_arg_names:
# dist_attr[append_op_output_suffix(key.name)] = True
# self._dist_attr.init(dist_attr)
# self._init_default_dist_attr()
def get_serial_input(self, name):
if self._serial_op.type == "create_py_reader":
......@@ -83,81 +66,6 @@ class DistributedOperator:
tensor = self._serial_op.block._var_recursive(name)
return tensor
# def _init_default_dist_attr(self):
# for tensor_name in self._serial_op.input_arg_names:
# if self._serial_op.type == "create_py_reader":
# tensor = None
# else:
# tensor = self._serial_op.block._var_recursive(tensor_name)
# self._serial_inputs[tensor_name] = tensor
# if tensor is None:
# tensor_shape = []
# else:
# if tensor.type in __no_shape_var_type__:
# tensor_shape = []
# else:
# tensor_shape = tensor.shape
# if self._dist_attr.get_input_dims_mapping(tensor_name) is None:
# tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
# self._dist_attr.set_input_dims_mapping(
# tensor_name, tensor_dims_mapping
# )
# for tensor_name in self._serial_op.output_arg_names:
# tensor = self._serial_op.block._var_recursive(tensor_name)
# if tensor.type in __no_shape_var_type__:
# tensor_shape = []
# else:
# tensor_shape = tensor.shape
# self._serial_outputs[tensor_name] = tensor
# if self._dist_attr.get_output_dims_mapping(tensor_name) is None:
# tensor_dims_mapping = [-1 for _ in range(len(tensor_shape))]
# self._dist_attr.set_output_dims_mapping(
# tensor_name, tensor_dims_mapping
# )
# if self._dist_attr.op_type is None:
# self._dist_attr.op_type = self.serial_op.type
# if self._dist_attr.impl_type is None:
# self._dist_attr.impl_type = "default"
# if self._dist_attr.impl_idx is None:
# self._dist_attr.impl_idx = 0
# if self._dist_attr.is_recompute is None:
# self._dist_attr.is_recompute = False
# def _filter_dist_attr(self, dist_attr):
# if dist_attr is None:
# return None
# new_dist_attr = None
# if isinstance(dist_attr, dict):
# new_dist_attr = {}
# for key, value in dist_attr.items():
# if isinstance(key, Variable):
# if (
# key.name in self._serial_op.input_arg_names
# or key.name in self._serial_op.output_arg_names
# ):
# new_dist_attr[key] = value
# else:
# new_dist_attr[key] = value
# elif isinstance(dist_attr, OperatorDistAttr):
# new_dist_attr = copy.deepcopy(dist_attr)
# new_dist_attr._inputs_dist_attrs.clear()
# new_dist_attr._outputs_dist_attrs.clear()
# for tensor_name in self._serial_op.input_arg_names:
# tensor_dist_attr = dist_attr.get_input_dist_attr(tensor_name)
# if tensor_dist_attr:
# new_dist_attr.set_input_dist_attr(
# tensor_name, tensor_dist_attr
# )
# for tensor_name in self._serial_op.output_arg_names:
# tensor_dist_attr = dist_attr.get_output_dist_attr(tensor_name)
# if tensor_dist_attr:
# new_dist_attr.set_output_dist_attr(
# tensor_name, tensor_dist_attr
# )
# else:
# assert False, "Cannot recognize the {} parameter.".format(dist_attr)
# return new_dist_attr
def validate_dist_attr(self):
if "read" in self.serial_op.type or "while" == self.serial_op.type:
return True
......@@ -402,5 +310,6 @@ class DistributedOperatorHelper:
if self._process_mesh is not None:
dist_op.dist_attr.mark_annotated("process_mesh")
default_dist_ctx.add_dist_op_for_program(dist_op)
default_dist_ctx.add_process_mesh(self._process_mesh)
return output
......@@ -192,17 +192,27 @@ class DistributedSaver:
used_inputs += op.input_arg_names
used_outputs += op.output_arg_names
for idx, var_name in enumerate(feed_vars_names):
if var_name not in used_inputs:
feed_vars_names.pop(idx)
for idx, var_name in enumerate(fetch_vars_names):
if var_name not in used_outputs:
fetch_vars_names.pop(idx)
# delete duplicated elements and keep order
feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
used_inputs = list({}.fromkeys(used_inputs).keys())
fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
used_outputs = list({}.fromkeys(used_outputs).keys())
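The `{}.fromkeys(...)` idiom above deduplicates while preserving first-seen order (dict keys keep insertion order since Python 3.7), which `list(set(...))` would not guarantee:
names = ["x", "y", "x", "z", "y"]
list({}.fromkeys(names).keys())   # ['x', 'y', 'z'] -- order of first occurrence kept
list(set(names))                  # deduplicates too, but the order is unspecified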
dist_feed_vars_names = [
var_name for var_name in feed_vars_names if var_name in used_inputs
]
dist_fetch_vars_names = [
var_name
for var_name in fetch_vars_names
if var_name in used_outputs
]
dist_feed_vars = list(
reversed([global_block.vars[name] for name in feed_vars_names])
reversed([global_block.vars[name] for name in dist_feed_vars_names])
)
dist_fetch_vars = [global_block.vars[name] for name in fetch_vars_names]
dist_fetch_vars = [
global_block.vars[name] for name in dist_fetch_vars_names
]
dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename)
......
......@@ -17,7 +17,6 @@ import logging
import numbers
import os
import random
from collections import defaultdict
import numpy as np
......@@ -154,7 +153,6 @@ class Engine:
" or `paddle.static.Optimizer`."
)
self._optimizer = auto_utils.validate_opt(optimizer)
self._orig_optimizer = copy.deepcopy(self._optimizer)
metrics = metrics or []
for metric in auto_utils.to_list(metrics):
......@@ -185,6 +183,12 @@ class Engine:
)
fleet.init(is_collective=True)
# for compute cost
# TODO: remove _fwd_main_progs and _orig_optimizer
self._fwd_dist_contexts = {}
self._fwd_main_progs = {}
self._orig_optimizer = copy.deepcopy(self._optimizer)
self._executor = None
self._cur_rank = paddle.distributed.get_rank()
self._nranks = paddle.distributed.get_world_size()
......@@ -194,14 +198,6 @@ class Engine:
self._orig_startup_prog = static.default_startup_program()
self._orig_dist_context = get_default_distributed_context()
self._dist_contexts = {}
self._fwd_main_progs = {}
self._fwd_dist_contexts = {}
self._serial_main_progs = {}
self._serial_startup_progs = {}
self._dist_main_progs = defaultdict(dict) # dist main programs
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {}
self._has_prepared = {"train": False, "eval": False, "predict": False}
self._has_prepared_reader = {
......@@ -334,9 +330,9 @@ class Engine:
return inputs, labels
def _prepare_reader(self):
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
def _prepare_reader(self, feed_list=[]):
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block()
# NOTE: this list may be changed if Paddle changes the existing rules.
......@@ -357,10 +353,13 @@ class Engine:
if op.type in related_reader_ops:
reader_op_indices.append(idx)
# Step 2: insert the new reader ops to cpp
# record the read ops' descs so they can be inserted into the forward task_node's program
read_ops_desc = []
new_reader_ops = []
for idx in reversed(reader_op_indices):
new_op_desc = dist_main_block.desc._prepend_op()
new_op_desc.copy_from(dist_main_block.ops[idx].desc)
read_ops_desc.append(new_op_desc)
new_op = Operator(
dist_main_block, new_op_desc, type=new_op_desc.type()
)
......@@ -379,6 +378,29 @@ class Engine:
dist_main_block._sync_with_cpp()
self._has_prepared_reader[self._mode] = True
# Insert the read ops into the forward TaskNode if the 1F1B pass is set
if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = fleet_opt["tasks"][0]
fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block()
for var in feed_list:
if var.name not in fwd_block.vars:
fwd_block._clone_variable(var)
for op_desc in read_ops_desc:
new_op_desc = fwd_block.desc._prepend_op()
new_op_desc.copy_from(op_desc)
new_op = Operator(
fwd_block, new_op_desc, type=new_op_desc.type()
)
fwd_block.ops.insert(0, new_op)
fwd_block._sync_with_cpp()
fwd_task.set_program(fwd_prog)
def _prepare_feed(self, data, user_feeds, mode):
feeds = {}
if data is not None:
......@@ -428,14 +450,16 @@ class Engine:
fetch_names.append([])
fetch_indices.append(group_indices)
dist_context = self._dist_contexts[mode]
fetch_vars = dist_context.serial_fetch_vars
if mode != "predict":
_process_fetch_group("loss", self._fetch_vars[mode]["loss"])
_process_fetch_group("loss", fetch_vars["loss"])
if mode != "predict":
metrics = self._fetch_vars[mode]["metrics"]
metrics = fetch_vars["metrics"]
for i, var_list in enumerate(metrics):
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", self._fetch_vars[mode]["outputs"])
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
......@@ -472,7 +496,8 @@ class Engine:
logs["loss"] = outs[idx][0]
group_idx += 1
# logging metrics
metric_vars = self._fetch_vars[mode]["metrics"]
dist_context = self._dist_contexts[mode]
metric_vars = dist_context.serial_fetch_vars["metrics"]
if metric_vars:
for metric in self._metrics:
metrics_indices = fetch_indices[group_idx]
......@@ -503,15 +528,18 @@ class Engine:
logs["fetches"] = logs_fetch
return logs
def _prepare_program(self, mode):
def _prepare_program(self, mode, init_parameters=True):
# Do the build process
self._build(mode)
# Do the planning process
self._plan(mode)
# Do the parallel process
self._parallel(mode)
# Init comm and startup program
self._initialize(mode)
# Init comm
self._init_comm()
if init_parameters:
# startup program
self._initialize(mode)
self._has_prepared[mode] = True
def _build(self, mode):
......@@ -543,9 +571,9 @@ class Engine:
paddle.enable_static()
else:
# build program in static graph mode
serial_main_prog = self._serial_main_progs.get(mode, None)
if serial_main_prog is not None:
# build program in static mode
dist_context = self._dist_contexts.get(mode, None)
if dist_context is not None:
return
outputs = []
......@@ -735,42 +763,23 @@ class Engine:
)
dist_context.set_op_dist_attr_for_program(op, ref_op_dist_attr)
def _initialize(self, mode):
# Get the current content from the distributed context
self._serial_main_progs[mode] = self._dist_contexts[
mode
].serial_main_program
self._serial_startup_progs[mode] = self._dist_contexts[
mode
].serial_startup_program
self._dist_main_progs[mode] = self._dist_contexts[
mode
].dist_main_programs
self._dist_startup_progs[mode] = self._dist_contexts[
mode
].dist_startup_programs
self._feed_vars[mode] = self._dist_contexts[mode].serial_feed_vars
self._fetch_vars[mode] = self._dist_contexts[mode].serial_fetch_vars
self._optimizer = self._dist_contexts[mode]._serial_optimizer
def _init_comm(self):
if self._nranks > 1:
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
cur_rank = self._cur_rank
# NOTE: After the unified dynamic and static communication group initialization
# is implemented in the future, the full-mode initialization logic will be removed,
# because port occupation errors may occur.
if self._strategy.auto_mode == "full":
auto_utils.initialize_pg_in_full_mode(
all_process_groups, cur_rank
all_process_groups, self._cur_rank
)
else:
for process_group in all_process_groups:
if cur_rank not in process_group.ranks:
if self._cur_rank not in process_group.ranks:
continue
process_group.instantiate()
def _initialize(self, mode):
self._place = _get_device()
if isinstance(self._place, paddle.framework.CUDAPlace):
self._place = paddle.framework.CUDAPlace(
......@@ -782,9 +791,9 @@ class Engine:
np.random.seed(self._strategy.seed + self._dp_ranks[0])
random.seed(self._strategy.seed + self._dp_ranks[0])
dist_context = self._dist_contexts[mode]
if self._dygraph_mode:
dist_context = self._dist_contexts[mode]
dist_main_program = self._dist_main_progs[mode][self._cur_rank]
dist_main_program = dist_context.dist_main_programs[self._cur_rank]
self.program_helper.init(
dist_main_program, self._place, dist_context
)
......@@ -792,7 +801,9 @@ class Engine:
if self._executor is None:
self._executor = paddle.static.Executor(self._place)
uninitialized = []
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
for var in dist_startup_prog.list_vars():
scope_var = global_scope().find_var(var.name)
if scope_var and scope_var.get_tensor()._is_initialized():
......@@ -809,7 +820,9 @@ class Engine:
if self._strategy.reinit:
self._logger.info("NOTE: parameters will be re-initialized.")
dist_startup_prog = self._dist_startup_progs[mode][self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[
self._cur_rank
]
self._executor.run(dist_startup_prog)
def fit(
......@@ -1282,6 +1295,7 @@ class Engine:
main_program=None,
startup_program=None,
mode=None,
init_parameters=True,
):
if mode is not None:
self.to_mode(mode)
......@@ -1324,7 +1338,7 @@ class Engine:
self._inputs_spec, self._labels_spec = inputs_spec, labels_spec
self._inputs, self._labels = inputs, labels
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
self._prepare_program(self._mode, init_parameters)
else:
self._switch_mode(self._mode)
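A hedged usage sketch of the new `init_parameters` flag (engine, spec, and mode are illustrative): with `init_parameters=False`, `prepare()` still builds, plans, and parallelizes the program and initializes communication, but skips running the startup program, e.g. when parameters are to be loaded from a checkpoint instead.
engine = auto.Engine(model, strategy=strategy)      # placeholders
engine.prepare(
    inputs_spec=paddle.static.InputSpec(
        shape=[2, 1024], name='input', dtype='float32'
    ),
    mode="eval",
    init_parameters=False,   # skip the startup program; weights come from elsewhere
)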
......@@ -1375,16 +1389,17 @@ class Engine:
)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Because predict_program does not contain the labels var,
# we add the labels var from serial_program to dist_program,
# so that the length of feed_list stays equal to the length of the dataset's values.
inputs_var = self._feed_vars[self._mode]["inputs"]
labels_var = self._feed_vars[self._mode]["labels"]
inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = dist_context.serial_feed_vars["labels"]
feed_list = []
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
......@@ -1443,16 +1458,17 @@ class Engine:
)
batch_size //= self._k_steps
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
dist_startup_prog = self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
dist_main_block = dist_main_prog.global_block()
# NOTE: Get feed_list, then insert dataloader op with sharded var shape.
# Because predict_program does not contain the labels var,
# we add the labels var from serial_program to dist_program,
# so that the length of feed_list stays equal to the length of the dataset's values.
inputs_var = self._feed_vars[self._mode]["inputs"]
labels_var = self._feed_vars[self._mode]["labels"]
inputs_var = dist_context.serial_feed_vars["inputs"]
labels_var = dist_context.serial_feed_vars["labels"]
feed_list = []
for var in inputs_var + labels_var:
if var.name in dist_main_block.vars:
......@@ -1482,7 +1498,7 @@ class Engine:
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks,
)
self._prepare_reader()
self._prepare_reader(feed_list)
return dataloader
def _tune(self, tune_data, tune_sample_split=None, batch_size=1):
......@@ -1542,7 +1558,7 @@ class Engine:
def _switch_mode(self, mode):
assert (
mode in self._dist_main_progs
mode in self._dist_contexts
), f"{mode} model is not ready, please call `prepare()` first."
self.to_mode(mode)
self._optimizer = self._dist_contexts[mode]._serial_optimizer
......@@ -1556,8 +1572,8 @@ class Engine:
self._mode = mode
def _set_state_dict(self, mode, strict, state_dict, dist_attr):
program = self._dist_main_progs[mode][self._cur_rank]
dist_context = self._dist_contexts[mode]
program = dist_context.dist_main_programs[self._cur_rank]
cur_dist_attr = auto_utils.get_dist_attr(program, dist_context)
converter = Converter(state_dict, dist_attr, cur_dist_attr)
state_dict = converter.convert(strict=strict)
......@@ -1622,10 +1638,10 @@ class Engine:
"""
if training:
assert self._mode in self._serial_main_progs
serial_program = self._serial_main_progs[self._mode]
dist_main_prog = self._dist_main_progs[self._mode][self._cur_rank]
assert self._mode in self._dist_contexts
dist_context = self._dist_contexts[self._mode]
serial_program = dist_context.serial_main_program
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
self._saver.save(
path,
serial_program=serial_program,
......@@ -1633,10 +1649,11 @@ class Engine:
dist_context=dist_context,
)
else:
assert "predict" in self._dist_main_progs
feed_vars = self._feed_vars["predict"]['inputs']
fetch_vars = self._fetch_vars["predict"]['outputs']
dist_main_prog = self._dist_main_progs["predict"][self._cur_rank]
assert "predict" in self._dist_contexts
dist_context = self._dist_contexts["predict"]
feed_vars = dist_context.serial_feed_vars['inputs']
fetch_vars = dist_context.serial_fetch_vars['outputs']
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
if self._strategy.qat.enable and self._strategy.qat.onnx_format:
from paddle.static.quantization import QuantWeightPass
......@@ -1776,11 +1793,13 @@ class Engine:
@property
def main_program(self):
return self._dist_main_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
return dist_context.dist_main_programs[self._cur_rank]
@property
def startup_program(self):
return self._dist_startup_progs[self._mode][self._cur_rank]
dist_context = self._dist_contexts[self._mode]
return dist_context.dist_startup_programs[self._cur_rank]
@property
def dist_context(self):
......@@ -1788,15 +1807,30 @@ class Engine:
@property
def serial_main_program(self):
return self._serial_main_progs[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_main_program
@property
def serial_startup_program(self):
return self._serial_startup_progs[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_startup_program
@property
def feed_vars(self):
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_feed_vars
@property
def fetch_vars(self):
return self._fetch_vars[self._mode]
dist_context = self._dist_contexts[self._mode]
return dist_context.serial_fetch_vars
@property
def optimizer(self):
dist_context = self._dist_contexts[self._mode]
if dist_context._serial_optimizer:
return dist_context._serial_optimizer
return self._optimizer
@property
def inputs(self):
......
......@@ -79,7 +79,15 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
assert isinstance(
shard_spec, list
), f"Argument shard_spec {shard_spec} is not an instance of list"
dist_tensor = DistributedTensor(x)
if isinstance(x, str):
x = (
paddle.static.default_main_program()
.global_block()
._var_recursive(x)
)
dist_tensor = DistributedTensor(x)
else:
dist_tensor = DistributedTensor(x)
serial_tensor = dist_tensor.serial_tensor
dist_tensor.dist_attr.process_mesh = process_mesh
if serial_tensor.type in __no_shape_var_type__:
......@@ -102,6 +110,7 @@ def shard_tensor(x, process_mesh=None, shard_spec=None):
default_dist_ctx = get_default_distributed_context()
default_dist_ctx.add_dist_tensor_for_program(dist_tensor)
dist_tensor = default_dist_ctx.get_dist_tensor_for_program(x)
default_dist_ctx.add_process_mesh(process_mesh)
return x
......
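With this change, `shard_tensor` accepts a variable name in addition to a `Variable`; a sketch under the assumption that the named variable already exists in the default main program and that `shard_tensor` is used through the `auto` namespace (mesh, tensor `w`, and the name below are illustrative):
mesh = auto.ProcessMesh([[0, 1], [2, 3]], dim_names=["x", "y"])

# Passing the Variable itself, as before
auto.shard_tensor(w, process_mesh=mesh, shard_spec=["x", None])

# Passing the variable's name now also works; it is resolved via
# default_main_program().global_block()._var_recursive(name)
auto.shard_tensor("linear_0.w_0", process_mesh=mesh, shard_spec=["x", None])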
......@@ -499,12 +499,19 @@ class AutoParallelizer:
break
if is_pipeline:
with paddle.static.program_guard(dist_main_prog):
paddle.distributed.barrier()
paddle.distributed.barrier(get_process_group(0))
# Traverse different rank programs and traverse each op of them,
# instantiate communication by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if len(_g_process_group_map) > 0:
tmp = paddle.to_tensor([1], dtype="int32")
paddle.distributed.all_reduce(
tmp, sync_op=True, group=_g_process_group_map[0]
)
paddle.device.cuda.synchronize()
if rank not in process_group.ranks:
continue
process_group.instantiate()
......
......@@ -177,10 +177,22 @@ class Parallelizer:
time.time() - time0, self._mode
)
)
# Apply post optimization passes
time0 = time.time()
self._apply_post_optimization(
dist_main_prog, dist_startup_prog, rank, dist_params_grads
)
self._logger.debug(
"within parallel apply_post_optimization time: {}, mode {}".format(
time.time() - time0, self._mode
)
)
# Clone program for test
if self._mode != 'train':
pipeline_opt = dist_main_prog._pipeline_opt
dist_main_prog = dist_main_prog.clone(for_test=True)
dist_startup_prog = dist_startup_prog.clone(for_test=True)
dist_main_prog._pipeline_opt = pipeline_opt
# Store the distributed programs for further usages
self._dist_context.dist_main_programs[rank] = dist_main_prog
......@@ -247,7 +259,7 @@ class Parallelizer:
# apply quantization pass
# The pass can be applied when mode must be 'train'
if self._strategy.qat.enable:
if self._mode == 'train' and self._strategy.qat.enable:
config = copy.deepcopy(self._strategy.qat.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
......@@ -307,8 +319,8 @@ class Parallelizer:
)
params_grads = self._pass_context.get_attr("params_grads")
# GradClip is train-only optimization
if self._mode == "train":
# GradClip is train-only optimization
config = copy.deepcopy(self._strategy.sharding.to_dict())
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
......@@ -330,6 +342,13 @@ class Parallelizer:
[main_program], [startup_program], self._pass_context
)
if self._strategy.pipeline.enable:
self._strategy.gradient_merge.enable = True
self._strategy.gradient_merge.k_steps = (
self._strategy.pipeline.accumulate_steps
)
self._strategy.gradient_merge.avg = True
# gradient_merge is a train-only optimization
if self._mode == "train" and self._strategy.gradient_merge.enable:
config = copy.deepcopy(self._strategy.gradient_merge.to_dict())
......@@ -342,6 +361,16 @@ class Parallelizer:
[main_program], [startup_program], self._pass_context
)
if self._strategy.pipeline.enable:
config = copy.deepcopy(self._strategy.pipeline.to_dict())
config["dist_context"] = self._dist_context
auto_parallel_pipeline_pass = new_pass(
"auto_parallel_pipeline", config
)
auto_parallel_pipeline_pass.apply(
[main_program], [startup_program], self._pass_context
)
if self._mode == "train" and self._strategy.fused_passes.enable:
if len(self._strategy.fused_passes.fused_passes_list) > 0:
new_pass_list = []
......
......@@ -52,9 +52,9 @@ def new_process_group(ranks, group_id=None, force_new_group=False):
global _g_process_group_map
if not force_new_group:
# A key constructed from ranks is used for avoiding duplication
new_key = ''.join(map(str, sorted(ranks)))
new_key = ''.join(map(str, ranks))
for pg_id, pg in _g_process_group_map.items():
cur_key = ''.join(map(str, sorted(pg.ranks)))
cur_key = ''.join(map(str, pg.ranks))
if pg_id != 0 and new_key == cur_key:
return pg
# If not matching the existing one, construct a new process group
......@@ -82,7 +82,7 @@ class ProcessGroup:
group_id != 0
), "Process group id 0 is reserved for all ranks."
self._group_id = group_id
self._ranks = sorted(ranks)
self._ranks = ranks
# Add the current ranks into group 0
if group_id != 0:
global _g_process_group_map
......@@ -109,7 +109,7 @@ class ProcessGroup:
not self.is_instantiate()
), "Cannot add new ranks after instantiating the process group"
self._ranks.extend(new_ranks)
self._ranks = sorted(set(self.ranks))
self._ranks = list(set(self.ranks))
def local_rank(self, global_rank):
if global_rank in self.ranks:
......
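Note that the lookup key is now built from the ranks in the order given and the ranks are no longer sorted, so rank order becomes significant; a sketch of the behaviour implied by the diff (ranks are illustrative):
pg_a = new_process_group([0, 1])   # key "01": a new group is created
pg_b = new_process_group([0, 1])   # key "01" matches -> pg_a is reused
pg_c = new_process_group([1, 0])   # key "10" differs -> a separate group, ranks kept as [1, 0]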
......@@ -848,7 +848,8 @@ class Remover:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
block._remove_op(idx)
block._remove_op(idx, sync=False)
block._sync_with_cpp()
@staticmethod
def remove_no_need_vars(
......@@ -1000,7 +1001,8 @@ class Remover:
if is_no_need_op:
remove_op_idx.append(idx)
for idx in remove_op_idx[::-1]:
startup_block._remove_op(idx)
startup_block._remove_op(idx, sync=False)
startup_block._sync_with_cpp()
class Resharder:
......@@ -1441,6 +1443,8 @@ class Resharder:
target_process_group = target_process_mesh.process_ids
target_process_shape = target_process_mesh.shape
op_role = dist_attr[2]
if source_tensor.shape[0] < 0:
assert source_tensor.shape[0] == -1
new_shape = list(source_tensor.shape)
......@@ -1583,6 +1587,10 @@ class Resharder:
Resharder.concat_partitions(
partition_index_list, source_partition_index
)
if int(op_role) == int(OpRole.Forward):
self.dist_context.up_down_streams.add_pair_stream(
to_send_process, target_process
)
# append concat op desc
op_desc_seq[target_process].append(
......@@ -2037,13 +2045,6 @@ class Resharder:
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
# if (
# old_name
# in op_dist_attr._inputs_dist_attrs
# ):
# op_dist_attr.del_input_dist_attr(
# old_name
# )
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
......@@ -2067,7 +2068,6 @@ class Resharder:
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
# op_dist_attr.del_input_dist_attr(old_name)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
......@@ -2095,7 +2095,6 @@ class Resharder:
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
# op_dist_attr.del_input_dist_attr(old_name)
op_dist_attr.set_input_dims_mapping(
new_name, dims_mapping
)
......@@ -2135,7 +2134,13 @@ class Resharder:
has_exist = True
break
if not has_exist:
input_attrs.append([process_mesh, input_dims_mapping])
input_attrs.append(
[
process_mesh,
input_dims_mapping,
op.attr('op_role'),
]
)
return input_attrs
def _get_subblock_output_attrs(self, op, var_name):
......@@ -2165,7 +2170,13 @@ class Resharder:
has_exist = True
break
if not has_exist:
output_attrs.append([process_mesh, output_dims_mapping])
output_attrs.append(
[
process_mesh,
output_dims_mapping,
op.attr('op_role'),
]
)
return output_attrs
def _get_common_op_input_attrs(self, op, var_name):
......@@ -2188,7 +2199,9 @@ class Resharder:
input_dims_mapping = dist_attr.get_input_dims_mapping(var_name)
input_attrs = []
for process_mesh in process_meshes:
input_attrs.append([process_mesh, input_dims_mapping])
input_attrs.append(
[process_mesh, input_dims_mapping, op.attr('op_role')]
)
return input_attrs
......@@ -2207,7 +2220,7 @@ class Resharder:
assert (
op_input_attrs
), "The input '{}' of op '{}' has no distibution attributes in subblock".format(
), "The input '{}' of op '{}' has no distributed attributes in subblock".format(
op.name, var_name
)
......@@ -2215,30 +2228,24 @@ class Resharder:
def _remove_global_process_mesh(self):
"""Remove global process mesh from dist_context.process_meshes"""
processes = set()
process_ids = set()
process_mesh_count = len(self.dist_context.process_meshes)
if process_mesh_count > 1:
global_process_mesh_idx = None
global_process_mesh_idx = []
has_sub_process_mesh = False
for process_mesh in self.dist_context.process_meshes:
for process in process_mesh.process_ids:
processes.add(process)
for process_id in process_mesh.process_ids:
process_ids.add(process_id)
for idx, process_mesh in enumerate(
self.dist_context.process_meshes
):
if len(set(process_mesh.process_ids)) == len(processes):
global_process_mesh_idx = idx
break
if len(set(process_mesh.process_ids)) == len(process_ids):
global_process_mesh_idx.append(idx)
elif set(process_mesh.process_ids) < process_ids:
has_sub_process_mesh = True
if global_process_mesh_idx is not None:
is_removed = False
global_mesh = self.dist_context.process_meshes[idx]
for i, mesh in enumerate(self.dist_context.process_meshes):
if i == idx:
continue
if set(mesh.process_ids) < set(global_mesh.process_ids):
is_removed = True
if is_removed:
if has_sub_process_mesh:
for idx in reversed(global_process_mesh_idx):
self.dist_context.process_meshes.pop(idx)
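A standalone sketch that mirrors the reworked rule above (mesh ids are illustrative): a mesh is considered global when its process ids cover every id seen across all meshes, and such meshes are removed only if at least one strictly smaller sub-mesh exists.
meshes = [[0, 1], [0], [1]]                  # process_ids of each mesh
all_ids = {p for m in meshes for p in m}     # {0, 1}

global_idx = [i for i, m in enumerate(meshes) if len(set(m)) == len(all_ids)]
has_sub = any(set(m) < all_ids for m in meshes)

if has_sub:
    for i in reversed(global_idx):
        meshes.pop(i)

print(meshes)  # [[0], [1]] -- the global [0, 1] mesh has been removed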
def _change_subblock_op_input_and_output(self, block_idx, block):
......@@ -2278,7 +2285,6 @@ class Resharder:
op_dist_attr.set_input_dist_attr(
new_name, op_input_dist_attr
)
# op_dist_attr.del_input_dist_attr(old_name)
# the outputs also need to be renamed when the output name is the same as the input name in an inplace op
for var_name in op.output_arg_names:
......@@ -2302,7 +2308,6 @@ class Resharder:
op_dist_attr.set_output_dist_attr(
new_name, op_output_dist_attr
)
# op_dist_attr.del_output_dist_attr(old_name)
def _reshard_input(self, block):
idx = 0
......@@ -2450,7 +2455,7 @@ class Resharder:
assert set_lod is True
# cast int64 to bool
block._insert_op(
cast_op = block._insert_op(
idx + 2,
type='cast',
inputs={
......@@ -2465,6 +2470,7 @@ class Resharder:
'op_role': op.attr('op_role'),
},
)
cast_op._set_attr('op_namescope', "/auto_parallel/reshard")
else:
if var.lod_level != 0:
recv_out = block.create_var(
......@@ -2612,6 +2618,10 @@ class Resharder:
]
if recv_rank == item:
continue
if var.shape[0] == -1:
new_shape = list(var.shape)
new_shape[0] = self.batch_size
var.desc.set_shape(new_shape)
if self.rank_id == item:
# if send bool data, cast then send
self._handle_send(
......@@ -2640,6 +2650,10 @@ class Resharder:
item = output_attr[0].process_ids[index]
if recv_rank == item:
continue
if var.shape[0] == -1:
new_shape = list(var.shape)
new_shape[0] = self.batch_size
var.desc.set_shape(new_shape)
if self.rank_id == item:
# if send bool data, cast then send
self._handle_send(
......@@ -2714,7 +2728,11 @@ class Resharder:
tensor.name
)
process_mesh = dist_op.dist_attr.process_mesh
dist_attr = [process_mesh, dims_mapping]
dist_attr = [
process_mesh,
dims_mapping,
dist_op.serial_op.attr('op_role'),
]
if dist_tensor is not None and self.need_reshard(
dist_tensor, dist_attr
):
......
......@@ -102,6 +102,12 @@ class GradientMergeConfig(BaseConfig):
super().__init__(category, config_dict)
class PipelineConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.PIPELINE
super().__init__(category, config_dict)
class QATConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.QAT
......@@ -186,6 +192,9 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.GRADIENT_MERGE, None)
self.gradient_merge = GradientMergeConfig(config_dict)
config_dict = self._config_dict.get(constants.PIPELINE, None)
self.pipeline = PipelineConfig(config_dict)
config_dict = self._config_dict.get(constants.QAT, None)
self.qat = QATConfig(config_dict)
......
......@@ -91,7 +91,7 @@ def init_process_groups(group_map, rank):
# TODO should instantiate global group first
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if rank not in process_group.ranks:
if process_group.id == 0 or rank not in process_group.ranks:
continue
print(process_group)
process_group.instantiate()
......
......@@ -122,9 +122,9 @@ def all_reduce(
tensor, op, group, sync_op, use_calc_stream
)
else:
assert (
group is None
), "Group can not be used in static graph mode for now."
# assert (
# group is None
# ), "Group can not be used in static graph mode for now."
return _all_reduce_in_static_mode(
tensor, op, group, sync_op, use_calc_stream
)
......@@ -23,6 +23,7 @@ from .auto_parallel_quantization import * # noqa: F403
from .auto_parallel_data_parallel_optimization import * # noqa: F403
from .auto_parallel_grad_clip import * # noqa: F403
from .auto_parallel_supplement_explicit_dependencies import * # noqa: F403
from .auto_parallel_pipeline import * # noqa: F403
from .cpp_pass import * # noqa: F403
from .ps_trainer_pass import * # noqa: F403
from .ps_server_pass import * # noqa: F403
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle import nn
from paddle.distributed.fleet import auto
_g_mesh = auto.ProcessMesh([0, 1])
PP_MESH_0 = auto.ProcessMesh([0])
PP_MESH_1 = auto.ProcessMesh([1])
image_size = 1024
class_num = 10
class MyDataset(paddle.io.Dataset):
def __init__(self, num_samples):
super().__init__()
self.num_samples = num_samples
def __getitem__(self, index):
input = np.random.uniform(size=image_size).astype("float32")
input = np.random.uniform(size=image_size).astype("float32")
return input, input
def __len__(self):
return self.num_samples
class MLPLayer(nn.Layer):
def __init__(
self,
hidden_size=1024,
intermediate_size=4 * 1024,
dropout_ratio=0.1,
initializer_range=0.02,
):
super().__init__()
d_model = hidden_size
dim_feedforward = intermediate_size
weight_attr = paddle.ParamAttr(
initializer=nn.initializer.Normal(mean=0.0, std=initializer_range)
)
bias_attr = None
self.linear0 = nn.Linear(
d_model, dim_feedforward, weight_attr, bias_attr=bias_attr
)
self.linear1 = nn.Linear(
dim_feedforward, d_model, weight_attr, bias_attr=bias_attr
)
self.linear2 = nn.Linear(d_model, 1, weight_attr, bias_attr=bias_attr)
self.norm = nn.LayerNorm(d_model, epsilon=1e-5)
self.dropout = nn.Dropout(dropout_ratio, mode="upscale_in_train")
def forward(self, input):
out = auto.shard_op(self.norm, PP_MESH_0)(input)
out = self.linear0(out)
out = F.gelu(out, approximate=True)
out = auto.shard_op(self.linear1, PP_MESH_1)(out)
out = self.dropout(out)
out = self.linear2(out)
return out
class GEN(nn.Layer):
def __init__(self, mlp):
super().__init__()
self.mlp = mlp
def forward(self, input):
model_kwargs = {}
output = self.mlp(input)
cur_step = paddle.full([1], 0, dtype='int64')
total_step = paddle.full([1], 10, dtype='int64')
model_kwargs['input'] = input
model_kwargs['output'] = output
while cur_step < total_step:
out = self.mlp(model_kwargs['input'])
model_kwargs['res'] = out
paddle.increment(cur_step)
auto.shard_op(paddle.assign, _g_mesh)(model_kwargs['input'], out)
output = F.gelu(model_kwargs['input'], approximate=True)
return output, cur_step
def get_model():
with paddle.LazyGuard():
mlp = MLPLayer()
gen = GEN(mlp)
return gen
class TestGenerationPipeline(unittest.TestCase):
def test_pp2(self):
model = get_model()
strategy = auto.Strategy()
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "stream"
pipeline.generation_batch_size = 4
pipeline.accumulate_steps = 4
engine = auto.Engine(model, strategy=strategy)
engine.prepare(
inputs_spec=paddle.static.InputSpec(
shape=[2, 1024], name='input', dtype='float32'
),
labels_spec=paddle.static.InputSpec(
shape=[2, 1024], name='label', dtype='float32'
),
mode="eval",
)
train_data = MyDataset(50 * 2)
train_dataloader = engine._prepare_dataloader_from_generator(
dataset=train_data,
capacity=70,
iterable=False,
batch_size=2,
epochs=1,
steps_per_epoch=100,
)
engine._prepare_reader()
fleet_opt = engine.main_program._pipeline_opt['fleet_opt']
assert len(fleet_opt['tasks']) == 5
assert fleet_opt['inference_generation']
assert fleet_opt['num_micro_batches'] == 4
num_task_in_rank = 5
for idx, (task_id, rank_id) in enumerate(
fleet_opt['task_id_to_rank'].items()
):
assert (
task_id == rank_id * num_task_in_rank + idx % num_task_in_rank
)
train_dataloader._inner_dataloader.start()
try:
engine._executor.run(
engine.main_program, use_program_cache=False, return_numpy=False
)
except paddle.fluid.core.EOFException:
print("test done")
train_dataloader._inner_dataloader.reset()
train_dataloader._inner_dataloader.start()
if __name__ == "__main__":
unittest.main()
......@@ -247,6 +247,7 @@ class TestDistributedContext(unittest.TestCase):
"_backup_serial_main_program_stack",
"_backup_serial_startup_program_stack",
"_pass_context",
"_tensor_nodes_with_same_name",
]
for i in range(len(copy_list)):
......
......@@ -203,7 +203,7 @@ class TestBF16Pass(unittest.TestCase):
bf16_o1_engine.prepare(
inputs_spec=inputs_spec, labels_spec=labels_spec, mode="train"
)
self.check_program(bf16_o1_engine._dist_main_progs["train"][0])
self.check_program(bf16_o1_engine.main_program)
print("BF16!check program successfully!")
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import subprocess
import sys
import tempfile
import unittest
class TestGenerationPipeline(unittest.TestCase):
def test_pp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(
file_dir, "generation_pipeline_pass_unittest.py"
)
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
......@@ -180,6 +180,9 @@ def check_send_recv_result(dist_main_prog, rank_id):
return send_result and recv_result
@unittest.skipIf(
not paddle.is_compiled_with_cuda(), "core is not compiled with CUDA"
)
class TestMLPReshard(unittest.TestCase):
def test_mlp_serial(self):
global _global_parallel_strategy
......