Unverified commit 81c13b86, authored by zhaoyingli, committed by GitHub

[AutoParallel] update pipeline pass for while control_flow (#54224)

* [AutoParallel] update while control_flow with pipeline

* update process group instantiate

* fix micro_bsz for reshard

* update api for micro batch size

* add strategy for dp optimization
Parent commit: b0e86d55
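Editor's note: before the diff itself, a minimal, hedged sketch of the user-facing switches this change touches. It is not part of the commit; it assumes the `auto.Strategy`/`auto.Engine` API exercised by the unit test at the bottom of the diff, and `MyModel`, `my_loss`, and `train_dataset` are hypothetical placeholders.

```python
# Sketch only: enabling the pipeline "stream" schedule and letting Engine derive
# the micro batch size from acc_steps, as introduced in this commit.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()

pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "stream"   # the Engine change below also handles "1F1B"
pipeline.accumulate_steps = 4       # becomes Engine._acc_steps
pipeline.generation_batch_size = 2

# With pipeline (or gradient_merge) enabled, Engine now splits the user batch
# size by acc_steps before building the dataloader.
engine = auto.Engine(MyModel(), loss=my_loss, strategy=strategy)
engine.fit(train_dataset, batch_size=8)  # micro batch size = 8 // 4 = 2
```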
@@ -216,8 +216,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
   brpc::Channel channel;
   brpc::ChannelOptions options;
   options.protocol = "baidu_std";
-  options.connect_timeout_ms = 1000;
-  options.timeout_ms = 1000;
+  options.connect_timeout_ms = 100000;
+  options.timeout_ms = 100000;
   options.max_retry = 5;
   PADDLE_ENFORCE_EQ(
       channel.Init(dst_addr_for_brpc, &options),
...
@@ -124,9 +124,9 @@ set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
 set_field_default_config(QAT, "algo", None)
 set_field_default_config(QAT, "onnx_format", True)

-# #########################################
+#########################################
 # auto tuning configuration
-# #########################################
+#########################################
 TUNING = "tuning"
 set_field_default_config(TUNING, "enable", False)
 set_field_default_config(TUNING, "profile_start_step", 1)
@@ -147,3 +147,12 @@ set_field_default_config(DATASET, "num_shards", 1)
 FUSED_PASSES = "fused_passes"
 set_field_default_config(FUSED_PASSES, "enable", False)
 set_field_default_config(FUSED_PASSES, "fused_passes_list", [])
+
+#########################################
+# data parallel configuration
+#########################################
+DP_OPTIMIZATION = "dp_optimization"
+set_field_default_config(DP_OPTIMIZATION, "enable", False)
+set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
+set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
+set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
@@ -58,6 +58,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         split_data=True,
         data_parallel_world_size=[],
         data_parallel_rank=[],
+        acc_steps=1,
     ):
         self.dataset = dataset
         self.feed_list = feed_list
@@ -77,6 +78,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         assert len(data_parallel_rank) == len(feed_list)
         self.dp_world_sizes = data_parallel_world_size
         self.dp_ranks = data_parallel_rank
+        self.acc_steps = acc_steps

         if isinstance(dataset, IterableDataset):
             self.dataset_kind = _DatasetKind.ITER
@@ -141,9 +143,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             if isinstance(self.dataset, IterableDataset):
                 steps_per_epoch = None
             elif self.batch_size is None:
-                steps_per_epoch = len(self.dataset)
+                steps_per_epoch = len(self.dataset) // self.acc_steps
             else:
-                steps_per_epoch = len(self.dataset) // self.batch_size
+                steps_per_epoch = (
+                    len(self.dataset) // self.batch_size // self.acc_steps
+                )
         except:
             raise ValueError(
                 "Please set `steps_per_epoch` or implement `__len__` method in dataset class."
...
@@ -235,6 +235,11 @@ class Engine:
         self._planned_mode = None
         self._dygraph_mode = False
         self._tuning = self._strategy.tuning
+        self._acc_steps = 1
+        if self._strategy.gradient_merge.enable:
+            self._acc_steps = self._strategy.gradient_merge.k_steps
+        elif self._strategy.pipeline.enable:
+            self._acc_steps = self._strategy.pipeline.accumulate_steps

         self.history = None
@@ -400,7 +405,12 @@ class Engine:
         if self.main_program._pipeline_opt:
             assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
             fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
-            fwd_task = fleet_opt["tasks"][0]
+            fwd_task = None
+            if self._strategy.pipeline.schedule_mode == "1F1B":
+                fwd_task = fleet_opt["tasks"][1]
+            elif self._strategy.pipeline.schedule_mode == "stream":
+                fwd_task = fleet_opt["tasks"][0]
+            assert fwd_task is not None
             fwd_prog = fwd_task.get_program()
             fwd_block = fwd_prog.global_block()
@@ -450,8 +460,6 @@ class Engine:
             ), "user_fetches must be a list, but receive {}".format(
                 type(user_fetches).__name__
             )
-        else:
-            user_fetches = []
         fetch_names = []
         fetch_indices = []
@@ -478,7 +486,7 @@ class Engine:
             _process_fetch_group("metrics_" + str(i), var_list)
         if mode == "predict":
             _process_fetch_group("outputs", fetch_vars["outputs"])
-        for usr_fetch in user_fetches:
+        for usr_fetch in user_fetches or []:
             var_name = _to_name_str(usr_fetch)
             fetch(var_name)
         user_fetches_collection = [
@@ -931,6 +939,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
+        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -940,7 +949,7 @@ class Engine:
             dataset=train_data,
             capacity=70,
             iterable=False,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             epochs=epochs,
             steps_per_epoch=steps_per_epoch,
             collate_fn=collate_fn,
@@ -951,7 +960,7 @@ class Engine:
         cbks = config_callbacks(
             callbacks,
             engine=self,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             epochs=epochs,
             steps=train_dataloader._steps,
             log_freq=log_freq,
@@ -959,7 +968,7 @@ class Engine:
             save_dir=save_dir,
             verbose=verbose,
             metrics=self._metrics_name(),
-            acc_step=self._k_steps,
+            acc_step=self._acc_steps,
         )

         cbks.on_begin('train')
@@ -977,7 +986,7 @@ class Engine:
                     )
                 except core.EOFException:
                     break
-                lr = auto_utils.get_lr(self._optimizer)
+                lr = auto_utils.get_lr(self.optimizer)
                 logs = self._prepare_logger(
                     outs,
                     epoch,
@@ -1074,6 +1083,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
+        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1083,7 +1093,7 @@ class Engine:
             dataset=valid_data,
             capacity=70,
             iterable=False,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             steps_per_epoch=steps,
             collate_fn=collate_fn,
         )
@@ -1093,7 +1103,7 @@ class Engine:
         cbks = config_callbacks(
             callbacks,
             engine=self,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             log_freq=log_freq,
             verbose=verbose,
             metrics=self._metrics_name(),
@@ -1180,6 +1190,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
+        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1189,7 +1200,7 @@ class Engine:
             dataset=test_data,
             capacity=70,
             iterable=False,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             steps_per_epoch=steps,
             collate_fn=collate_fn,
         )
@@ -1242,6 +1253,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1250,7 +1262,7 @@ class Engine:
         dataloader = self._prepare_dataloader(
             dataset,
             return_list=False,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             shuffle=shuffle,
             drop_last=drop_last,
             collate_fn=collate_fn,
@@ -1284,6 +1296,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        micro_batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1297,7 +1310,7 @@ class Engine:
             return_list=False,
             use_multiprocess=use_multiprocess,
             drop_last=drop_last,
-            batch_size=batch_size,
+            batch_size=micro_batch_size,
             epochs=epochs,
             steps_per_epoch=steps_per_epoch,
             collate_fn=collate_fn,
@@ -1399,14 +1412,6 @@ class Engine:
         steps_per_epoch=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1468,14 +1473,6 @@ class Engine:
         collate_fn=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1515,6 +1512,9 @@ class Engine:
             split_data=self._strategy.split_data,
             data_parallel_world_size=self._dp_world_sizes,
             data_parallel_rank=self._dp_ranks,
+            acc_steps=1
+            if not self._strategy.pipeline.enable
+            else self._acc_steps,
         )
         self._prepare_reader(feed_list)
         return dataloader
@@ -1526,9 +1526,18 @@ class Engine:
         )
         self._optimization_tuning(self._mode, tune_data, batch_size)

+    def _validate_batch_size(self, batch_size):
+        if batch_size is None:
+            return None
+        assert (
+            batch_size % self._acc_steps == 0
+        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
+            batch_size, self._acc_steps
+        )
+        return batch_size // self._acc_steps
+
     def _validate_spec(self, specs):
         specs = auto_utils.to_list(specs)
-        self._k_steps = self._strategy.gradient_merge.k_steps
         if specs is not None:
             for i, spec in enumerate(specs):
                 if not isinstance(spec, InputSpec):
@@ -1541,14 +1550,14 @@ class Engine:
                             i, spec
                         )
                     )
-                if self._k_steps > 1:
+                if self._acc_steps > 1:
                     shape = list(spec.shape)
                     assert (
-                        shape[0] % self._k_steps == 0
+                        shape[0] % self._acc_steps == 0
                     ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
-                        spec.shape[0], self._k_steps
+                        spec.shape[0], self._acc_steps
                     )
-                    shape[0] //= self._k_steps
+                    shape[0] //= self._acc_steps
                     spec.shape = shape
         return specs or []
@@ -1579,7 +1588,6 @@ class Engine:
             mode in self._dist_contexts
         ), f"{mode} model is not ready, please call `prepare()` first."
         self.to_mode(mode)
-        self._optimizer = self._dist_contexts[mode]._serial_optimizer

     def to_mode(self, mode):
         assert mode in [
...
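Editor's note: taken together, the Engine changes above treat the `batch_size` argument of `fit`/`evaluate`/`predict` as the accumulated batch; `_validate_batch_size` divides it by `acc_steps` to get the micro batch size handed to the dataloader and callbacks. A standalone sketch of that arithmetic (not Engine code):

```python
# Standalone sketch of the micro-batch arithmetic; acc_steps comes from
# gradient_merge.k_steps or pipeline.accumulate_steps in the Engine.
def validate_batch_size(batch_size, acc_steps):
    if batch_size is None:
        return None
    assert batch_size % acc_steps == 0, (
        "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
            batch_size, acc_steps
        )
    )
    return batch_size // acc_steps


print(validate_batch_size(8, 4))     # 2 -> micro batch size for the dataloader
print(validate_batch_size(None, 4))  # None -> dataloader keeps its own default
```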
@@ -499,21 +499,12 @@ class AutoParallelizer:
                     break

         if is_pipeline:
             with paddle.static.program_guard(dist_main_prog):
-                paddle.distributed.barrier(get_process_group(0))
+                paddle.distributed.barrier()

         # Traverse different rank programs and traverse each op of them,
         # instantiate communication by process_mapping.
         all_process_groups = get_all_process_groups()

         for process_group in all_process_groups:
-            if len(_g_process_group_map) > 0:
-                tmp = paddle.to_tensor([1], dtype="int32")
-                paddle.distributed.all_reduce(
-                    tmp, sync_op=True, group=_g_process_group_map[0]
-                )
-                paddle.device.cuda.synchronize()
-            if rank not in process_group.ranks:
-                continue
             process_group.instantiate()

         # Copy distributed info to the default context
...
@@ -105,13 +105,6 @@ class Parallelizer:
                     time.time() - time0, self._mode
                 )
             )
-            # Do reshard process
-            time0 = time.time()
-            micro_bsz = (
-                1
-                if not self._strategy.pipeline.enable
-                else self._strategy.pipeline.micro_batch_size
-            )
             set_grad_var_shape(dist_main_prog, self._dist_context)
             resharder = Resharder(
                 dist_main_prog,
@@ -119,7 +112,6 @@ class Parallelizer:
                 rank,
                 self._dist_context,
                 dist_params_grads,
-                micro_bsz,
             )
             resharder.reshard()
             self._logger.debug(
@@ -169,13 +161,19 @@ class Parallelizer:
                 )
             )
             time0 = time.time()
+            # Do reshard process
+            micro_bsz = (
+                1
+                if not self._strategy.pipeline.enable
+                else self._strategy.pipeline.micro_batch_size
+            )
             resharder = Resharder(
                 dist_main_prog,
                 dist_startup_prog,
                 rank,
                 self._dist_context,
                 [],
-                1,
+                micro_bsz,
             )
             resharder.reshard()
             self._logger.debug(
@@ -305,11 +303,14 @@ class Parallelizer:
             return

         # data parallel optimization
-        config = {}
-        config["dist_context"] = self._dist_context
-        config["global_rank"] = rank
-        config["use_sharding"] = self._strategy.sharding.enable
-        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
-        dp_pass.apply([main_program], [startup_program], self._pass_context)
+        if self._strategy.dp_optimization.enable:
+            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
+            config["dist_context"] = self._dist_context
+            config["global_rank"] = rank
+            config["use_sharding"] = self._strategy.sharding.enable
+            dp_pass = new_pass(
+                "auto_parallel_data_parallel_optimization", config
+            )
+            dp_pass.apply([main_program], [startup_program], self._pass_context)

         if self._strategy.sharding.enable:
...
@@ -49,6 +49,12 @@ def clear_all_process_groups():
     _g_process_group_map[0] = ProcessGroup(0, [])


+def remove_process_group(ring_id):
+    global _g_process_group_map
+    if ring_id in _g_process_group_map:
+        _g_process_group_map.pop(ring_id)
+
+
 def new_process_group(ranks, group_id=None, force_new_group=False):
     global _g_process_group_map
...
@@ -132,6 +132,12 @@ class FusedPassesConfig(BaseConfig):
         super().__init__(category, config_dict)


+class DPOptimizationConfig(BaseConfig):
+    def __init__(self, config_dict=None):
+        category = constants.DP_OPTIMIZATION
+        super().__init__(category, config_dict)
+
+
 class Strategy(BaseConfig):
     """
     The `Strategy` object is used to configure the parallelization and optimization behaviors.
@@ -206,3 +212,6 @@ class Strategy(BaseConfig):
         config_dict = self._config_dict.get(constants.FUSED_PASSES, None)
         self.fused_passes = FusedPassesConfig(config_dict)
+
+        config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
+        self.dp_optimization = DPOptimizationConfig(config_dict)
@@ -122,9 +122,9 @@ def all_reduce(
             tensor, op, group, sync_op, use_calc_stream
         )
     else:
-        # assert (
-        #     group is None
-        # ), "Group can not be used in static graph mode for now."
+        assert (
+            group is None
+        ), "Group can not be used in static graph mode for now."
         return _all_reduce_in_static_mode(
             tensor, op, group, sync_op, use_calc_stream
         )
@@ -14,6 +14,9 @@

 import os

+from paddle.distributed.auto_parallel.static.process_group import (
+    remove_process_group,
+)
 from paddle.distributed.fleet.fleet_executor_utils import TaskNode
 from paddle.fluid import core
 from paddle.fluid.framework import Parameter, Program
@@ -50,6 +53,14 @@ class PipelinePass(PassBase):
         self._gen_bsz = self.get_attr("generation_batch_size")
         self._program = main_program

+        self._cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
+        trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
+        self._nrank = len(trainer_endpoints)
+
+        # compute current pp stage
+        self._pp_stages = len(self._dist_context.process_meshes)
+        self._cur_pp_stage = self._get_pp_stage(self._cur_rank)
+
         if self._mode == "1F1B":
             raise NotImplementedError("1F1B has not been implemented")
         elif self._mode == "F-Then-B":
@@ -70,6 +81,10 @@ class PipelinePass(PassBase):
         send_vars = []
         # insert sync ops
         for index, op in enumerate(list(block.ops)):
+            # NOTE: pipeline might hang when dynamic_shape is True
+            if op.type in ['send_v2', 'recv_v2']:
+                op._set_attr("dynamic_shape", False)
+            # set send op on comm stream
             if op.type == 'send_v2':
                 # step1: set 'use_calc_stream' False
                 op._set_attr("use_calc_stream", False)
@@ -176,21 +191,13 @@ class PipelinePass(PassBase):
     def _get_pp_stage(self, rank):
         pp_idx = None
         for idx, process_mesh in enumerate(self._dist_context.process_meshes):
-            if rank in process_mesh.processes:
+            if rank in process_mesh.process_ids:
                 pp_idx = idx
                 break
         return pp_idx

     def _task_stream(self):
-        cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
-        trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
-        nrank = len(trainer_endpoints)
         num_of_functionality = 5
-
-        # compute current pp stage
-        pp_stages = len(self._dist_context.process_meshes)
-        cur_pp_stage = self._get_pp_stage(cur_rank)
-
         start_prog = Program()
         cond_prog = Program()
         end_prog = Program()
@@ -198,6 +205,7 @@ class PipelinePass(PassBase):
         recv_prog = Program()

         cond_var_name = None
+        # record the varnames related to the while cond vars and communicate by nccl
         send_vars_name = set()
         recv_vars_name = {}
         for ib, src_block in enumerate(self._program.blocks):
@@ -222,38 +230,25 @@ class PipelinePass(PassBase):
                         src_block, end_block, op, force_create=True
                     )
             elif ib == 1:
+                # NOTE: for ernie generation
+                # The while block will be split to two separate blocks:
+                #     while{transformer_layer(send_block), generation_and_broadcast(recv_block)}
+                # The send_block:
+                #     include all ops about tansformer layers computation
+                #     execlude the nccl op about the while cond var(the last pp stage).
+                # The recv_block:
+                #     include all computation ops about generation and while cond var
+                #     execlude the nccl op about the while cond var(the pp stages exclude the last one)
+                # the nccl op about the while cond var:
+                #     put these varnames in the recv task node and do communication with brpc instead of nccl.
                 send_block = send_prog.block(0)
                 recv_block = recv_prog.block(0)

                 is_after_send_op = False
                 is_after_recv_op = False
-                for op in src_block.ops:
+                for i, op in enumerate(src_block.ops):
                     if op.type == "send_v2" and not is_after_send_op:
                         is_after_send_op = True
-                        if cur_pp_stage == pp_stages - 1:
-                            if op.type in ["c_sync_calc_stream", "nop"]:
-                                continue
-                            if (
-                                op.type not in ["recv_2", "assign"]
-                                and op.has_attr('op_namescope')
-                                and "/auto_parallel/reshard"
-                                in op.attr('op_namescope')
-                            ):
-                                if (
-                                    len(op.desc.input_arg_names()) > 0
-                                    and "@RESHARD"
-                                    not in op.desc.input_arg_names()[0]
-                                ):
-                                    send_vars_name.add(
-                                        op.desc.input_arg_names()[0]
-                                    )
-                                continue
-                            if op.type == "send_v2":
-                                continue
-                            self._create_program(
-                                src_block, send_block, op, force_create=True
-                            )
-                        continue

                     if (
                         is_after_send_op
@@ -261,45 +256,21 @@ class PipelinePass(PassBase):
                         and op.type == "recv_v2"
                     ):
                         is_after_recv_op = True
-                        if op.has_attr(
-                            'op_namescope'
-                        ) and "/auto_parallel/reshard" in op.attr(
-                            'op_namescope'
-                        ):
-                            var_name = op.desc.output_arg_names()[0]
-                            index = var_name.find("@")
-                            if index > 0:
-                                old_var_name = var_name[:index]
-                            else:
-                                old_var_name = var_name
-                            recv_vars_name[var_name] = old_var_name
-                            if not src_block._find_var_recursive(old_var_name):
-                                src_var = src_block._var_recursive(var_name)
-                                recv_block.create_var(
-                                    type=src_var.type,
-                                    name=old_var_name,
-                                    shape=src_var.shape,
-                                    dtype=src_var.dtype,
-                                    lod_level=src_var.lod_level,
-                                    persistable=src_var.persistable,
-                                    error_clip=src_var.error_clip,
-                                    stop_gradient=src_var.stop_gradient,
-                                    is_data=src_var.is_data,
-                                    belong_to_optimizer=src_var.belong_to_optimizer,
-                                )
-                            continue
-
-                        self._create_program(
-                            src_block, recv_block, op, force_create=True
-                        )
-                        continue

                     if not is_after_send_op or not is_after_recv_op:
-                        if cur_pp_stage == pp_stages - 1:
-                            if op.type in ["c_sync_calc_stream", "nop"]:
+                        if self._cur_pp_stage == self._pp_stages - 1:
+                            # NOTE: the c_sync_calc_stream about c_allgather cannot be removed
+                            if (
+                                op.type == "c_sync_calc_stream"
+                                and src_block.ops[i + 1].type == "send_v2"
+                            ):
                                 continue
+                            if op.type == "nop":
+                                continue
+                            # HACKCODE: the varname of send_v2 op, cast op should be recorded for brpc comm
                             if (
-                                op.type not in ["recv_2", "assign"]
+                                op.type
+                                not in ["recv_2", "assign", "c_allgather"]
                                 and op.has_attr('op_namescope')
                                 and "/auto_parallel/reshard"
                                 in op.attr('op_namescope')
@@ -312,19 +283,27 @@ class PipelinePass(PassBase):
                                     send_vars_name.add(
                                         op.desc.input_arg_names()[0]
                                     )
+                                if op.type == "send_v2":
+                                    remove_process_group(op.attr("ring_id"))
                                 continue
                             if op.type == "send_v2":
+                                remove_process_group(op.attr("ring_id"))
                                 continue
                             self._create_program(
                                 src_block, send_block, op, force_create=True
                             )
+                        continue

                     if is_after_send_op and is_after_recv_op:
+                        # HACKCODE: the varname of recv_v2 op, assign op should be recorded for brpc comm
                         if op.has_attr(
                             'op_namescope'
                         ) and "/auto_parallel/reshard" in op.attr(
                             'op_namescope'
                         ):
+                            if op.type in ["send_v2", "recv_v2"]:
+                                remove_process_group(op.attr("ring_id"))
+                            # remove the suffix of "@RESHARD"
                             var_name = op.desc.output_arg_names()[0]
                             index = var_name.find("@")
                             if index > 0:
@@ -356,6 +335,7 @@ class PipelinePass(PassBase):
                         self._create_program(
                             src_block, recv_block, op, force_create=True
                         )
+                        continue
             else:
                 raise Exception("Only support generation condition.")
@@ -397,52 +377,52 @@ class PipelinePass(PassBase):
         vars_to_shape = recv_task_node_var_shape

         start_task_node = TaskNode(
-            rank=cur_rank,
+            rank=self._cur_rank,
             max_run_times=self._acc_steps,
             node_type="Start",
-            task_id=int(cur_rank * num_of_functionality + 0),
+            task_id=int(self._cur_rank * num_of_functionality + 0),
             program=start_prog,
             lazy_initialize=True,
         )
         cond_task_node = TaskNode(
-            rank=cur_rank,
+            rank=self._cur_rank,
             max_run_times=self._acc_steps,
             node_type="Cond",
-            task_id=int(cur_rank * num_of_functionality + 1),
+            task_id=int(self._cur_rank * num_of_functionality + 1),
             program=cond_prog,
             cond_var_name=cond_var_name,
             lazy_initialize=True,
         )
         send_task_node = TaskNode(
-            rank=cur_rank,
+            rank=self._cur_rank,
             max_run_times=self._acc_steps,
             node_type="Compute",
-            task_id=int(cur_rank * num_of_functionality + 2),
+            task_id=int(self._cur_rank * num_of_functionality + 2),
             program=send_prog,
             lazy_initialize=True,
         )
         recv_task_node = TaskNode(
-            rank=cur_rank,
+            rank=self._cur_rank,
             max_run_times=self._acc_steps,
             node_type="Compute",
-            task_id=int(cur_rank * num_of_functionality + 3),
+            task_id=int(self._cur_rank * num_of_functionality + 3),
             program=recv_prog,
             lazy_initialize=True,
             vars_to_dtype=vars_to_dtype,
             vars_to_shape=vars_to_shape,
         )
         end_task_node = TaskNode(
-            rank=cur_rank,
+            rank=self._cur_rank,
             max_run_times=self._acc_steps,
             node_type="Compute",
-            task_id=int(cur_rank * num_of_functionality + 4),
+            task_id=int(self._cur_rank * num_of_functionality + 4),
             program=end_prog,
             lazy_initialize=True,
         )

         # add dependencies for task nodes intra stage
         inf = -1
-        pp_buff_size = int(pp_stages - cur_pp_stage)
+        pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
         start_task_node.add_downstream_task(
             cond_task_node.task_id(), self._gen_bsz
         )
@@ -551,12 +531,12 @@ class PipelinePass(PassBase):
         # add dependencies for task nodes inter stage
         # get upstream ranks and downstream ranks of cur_rank
         up_down_streams = self._dist_context.up_down_streams
-        pp_upstream_ranks = up_down_streams.ups(cur_rank)
-        pp_downstream_ranks = up_down_streams.downs(cur_rank)
+        pp_upstream_ranks = up_down_streams.ups(self._cur_rank)
+        pp_downstream_ranks = up_down_streams.downs(self._cur_rank)

         for upstream_rank in pp_upstream_ranks:
             upstream_pp_stage = self._get_pp_stage(upstream_rank)
-            if upstream_pp_stage < pp_stages - 1:
+            if upstream_pp_stage < self._pp_stages - 1:
                 upstream_task_id = int(upstream_rank * num_of_functionality + 2)
                 send_task_node.add_upstream_task(upstream_task_id)
                 print(
@@ -579,7 +559,7 @@ class PipelinePass(PassBase):
                     2,
                 )
         for downstream_rank in pp_downstream_ranks:
-            if cur_pp_stage < pp_stages - 1:
+            if self._cur_pp_stage < self._pp_stages - 1:
                 downstream_task_id = int(
                     downstream_rank * num_of_functionality + 2
                 )
@@ -607,7 +587,7 @@ class PipelinePass(PassBase):
                 )

         task_id_to_rank = {}
-        for i in range(nrank):
+        for i in range(self._nrank):
             for j in range(num_of_functionality):
                 task_id_to_rank[int(i * num_of_functionality + j)] = i
         self._program._pipeline_opt = {
...
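Editor's note: in the stream schedule above, every rank owns five task nodes (Start, Cond, send Compute, recv Compute, end Compute), and task ids are laid out as `rank * num_of_functionality + offset`. A standalone sketch of the resulting `task_id_to_rank` table, mirroring the final loop above (not pass code):

```python
# Sketch of the task-id layout: five functionalities per rank, so
# task_id = rank * 5 + offset and rank = task_id // 5.
NUM_OF_FUNCTIONALITY = 5


def build_task_id_to_rank(nrank):
    task_id_to_rank = {}
    for rank in range(nrank):
        for offset in range(NUM_OF_FUNCTIONALITY):
            task_id_to_rank[rank * NUM_OF_FUNCTIONALITY + offset] = rank
    return task_id_to_rank


# With two pipeline ranks, task ids 0-4 map to rank 0 and 5-9 to rank 1;
# the unit test below walks this mapping with num_task_in_rank = 5.
print(build_task_id_to_rank(2))
```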
@@ -1462,12 +1462,6 @@ class Executor:
             if "fleet_opt" in program._pipeline_opt:
                 # Move prepare here for port conflict with nccl in startup program
                 if self._fleet_executor is None:
-                    # Temporary manual enable standalone executor for fleet executor,
-                    # delete this code after the FLAGS is removed.
-                    if 'tasks' in program._pipeline_opt["fleet_opt"]:
-                        set_flags(
-                            {"FLAGS_fleet_executor_with_standalone": True}
-                        )
                     self._fleet_executor = _prepare_fleet_executor()
                 return self._run_using_fleet_executor(
                     program=program,
...
@@ -58,6 +58,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
   set_tests_properties(test_engine_callbacks
                        PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
+  py_test_modules(test_pass_generation_pipeline MODULES
+                  test_pass_generation_pipeline)
+  set_tests_properties(test_pass_generation_pipeline
+                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)

   # End of unittests WITH multi cards and timeout
   # NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
...
@@ -98,14 +98,11 @@ class GEN(nn.Layer):
         while cur_step < total_step:
             out = self.mlp(model_kwargs['input'])

-            model_kwargs['res'] = out
             paddle.increment(cur_step)

-            auto.shard_op(paddle.assign, _g_mesh)(model_kwargs['input'], out)
-
-        output = F.gelu(model_kwargs['input'], approximate=True)
-        return output, cur_step
+            out_assign = auto.shard_op(paddle.assign, _g_mesh)(out)
+            model_kwargs['output'] = paddle.assign(out_assign)
+
+        return model_kwargs['output'], cur_step


 def get_model():
@@ -125,35 +122,33 @@ class TestGenerationPipeline(unittest.TestCase):
         pipeline = strategy.pipeline
         pipeline.enable = True
         pipeline.schedule_mode = "stream"
-        pipeline.generation_batch_size = 4
-        pipeline.accumulate_steps = 4
+        pipeline.generation_batch_size = 2  # equal to the number of pp stages
+        pipeline.accumulate_steps = 20  # the number of all sample

         engine = auto.Engine(model, strategy=strategy)
         engine.prepare(
             inputs_spec=paddle.static.InputSpec(
-                shape=[2, 1024], name='input', dtype='float32'
+                shape=[20, 1024], name='input', dtype='float32'
             ),
             labels_spec=paddle.static.InputSpec(
-                shape=[2, 1024], name='label', dtype='float32'
+                shape=[20, 1024], name='label', dtype='float32'
             ),
             mode="eval",
         )

-        train_data = MyDataset(50 * 2)
+        train_data = MyDataset(20)
         train_dataloader = engine._prepare_dataloader_from_generator(
             dataset=train_data,
-            capacity=70,
+            capacity=20,
             iterable=False,
-            batch_size=2,
+            batch_size=1,  # micro_batch_size
             epochs=1,
-            steps_per_epoch=100,
         )
-        engine._prepare_reader()

         fleet_opt = engine.main_program._pipeline_opt['fleet_opt']
         assert len(fleet_opt['tasks']) == 5
         assert fleet_opt['inference_generation']
-        assert fleet_opt['num_micro_batches'] == 4
+        assert fleet_opt['num_micro_batches'] == 20
         num_task_in_rank = 5
         for idx, (task_id, rank_id) in enumerate(
             fleet_opt['task_id_to_rank'].items()
@@ -170,7 +165,6 @@ class TestGenerationPipeline(unittest.TestCase):
         except paddle.fluid.core.EOFException:
             print("test done")
             train_dataloader._inner_dataloader.reset()
-            train_dataloader._inner_dataloader.start()


 if __name__ == "__main__":
...