Unverified commit 81c13b86 authored by zhaoyingli, committed by GitHub

[AutoParallel] update pipeline pass for while control_flow (#54224)

* [AutoParallel] update while control_flow with pipeline

* update process group instantiate

* fix micro_bsz for reshard

* update api for micro batch size

* add strategy for dp optimization
Parent b0e86d55
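For context before the diff: a minimal sketch of how the features touched by this commit are driven from user code. The strategy fields mirror the `stream` schedule and the new `dp_optimization` section added here; the model is a stand-in, not part of the commit:

```python
import paddle
from paddle.distributed.fleet import auto

strategy = auto.Strategy()

# pipeline schedule for while-based generation models
strategy.pipeline.enable = True
strategy.pipeline.schedule_mode = "stream"
strategy.pipeline.accumulate_steps = 4

# the new data-parallel optimization section (defaults in constants.py below)
strategy.dp_optimization.enable = True
strategy.dp_optimization.fuse_all_reduce_ops = True
strategy.dp_optimization.fuse_grad_size_in_MB = 32

model = paddle.nn.Linear(1024, 1024)  # stand-in for a real pipeline model
engine = auto.Engine(model, strategy=strategy)
```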
@@ -216,8 +216,8 @@ bool MessageBus::SendInterRank(int64_t dst_rank,
brpc::Channel channel;
brpc::ChannelOptions options;
options.protocol = "baidu_std";
options.connect_timeout_ms = 1000;
options.timeout_ms = 1000;
options.connect_timeout_ms = 100000;
options.timeout_ms = 100000;
options.max_retry = 5;
PADDLE_ENFORCE_EQ(
channel.Init(dst_addr_for_brpc, &options),
@@ -124,9 +124,9 @@ set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
set_field_default_config(QAT, "onnx_format", True)
# #########################################
#########################################
# auto tuning configuration
# #########################################
#########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "profile_start_step", 1)
@@ -147,3 +147,12 @@ set_field_default_config(DATASET, "num_shards", 1)
FUSED_PASSES = "fused_passes"
set_field_default_config(FUSED_PASSES, "enable", False)
set_field_default_config(FUSED_PASSES, "fused_passes_list", [])
#########################################
# data parallel configuration
#########################################
DP_OPTIMIZATION = "dp_optimization"
set_field_default_config(DP_OPTIMIZATION, "enable", False)
set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
@@ -58,6 +58,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[],
acc_steps=1,
):
self.dataset = dataset
self.feed_list = feed_list
@@ -77,6 +78,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.acc_steps = acc_steps
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
@@ -141,9 +143,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
if isinstance(self.dataset, IterableDataset):
steps_per_epoch = None
elif self.batch_size is None:
steps_per_epoch = len(self.dataset)
steps_per_epoch = len(self.dataset) // self.acc_steps
else:
steps_per_epoch = len(self.dataset) // self.batch_size
steps_per_epoch = (
len(self.dataset) // self.batch_size // self.acc_steps
)
except:
raise ValueError(
"Please set `steps_per_epoch` or implement `__len__` method in dataset class."
@@ -235,6 +235,11 @@ class Engine:
self._planned_mode = None
self._dygraph_mode = False
self._tuning = self._strategy.tuning
self._acc_steps = 1
if self._strategy.gradient_merge.enable:
self._acc_steps = self._strategy.gradient_merge.k_steps
elif self._strategy.pipeline.enable:
self._acc_steps = self._strategy.pipeline.accumulate_steps
self.history = None
@@ -400,7 +405,12 @@ class Engine:
if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = None
if self._strategy.pipeline.schedule_mode == "1F1B":
fwd_task = fleet_opt["tasks"][1]
elif self._strategy.pipeline.schedule_mode == "stream":
fwd_task = fleet_opt["tasks"][0]
assert fwd_task is not None
fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block()
@@ -450,8 +460,6 @@ class Engine:
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
else:
user_fetches = []
fetch_names = []
fetch_indices = []
@@ -478,7 +486,7 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches:
for usr_fetch in user_fetches or []:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
@@ -931,6 +939,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
train_data, train_sample_split, batch_size
)
micro_batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
@@ -940,7 +949,7 @@ class Engine:
dataset=train_data,
capacity=70,
iterable=False,
batch_size=batch_size,
batch_size=micro_batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
collate_fn=collate_fn,
@@ -951,7 +960,7 @@ class Engine:
cbks = config_callbacks(
callbacks,
engine=self,
batch_size=batch_size,
batch_size=micro_batch_size,
epochs=epochs,
steps=train_dataloader._steps,
log_freq=log_freq,
@@ -959,7 +968,7 @@ class Engine:
save_dir=save_dir,
verbose=verbose,
metrics=self._metrics_name(),
acc_step=self._k_steps,
acc_step=self._acc_steps,
)
cbks.on_begin('train')
@@ -977,7 +986,7 @@ class Engine:
)
except core.EOFException:
break
lr = auto_utils.get_lr(self._optimizer)
lr = auto_utils.get_lr(self.optimizer)
logs = self._prepare_logger(
outs,
epoch,
@@ -1074,6 +1083,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
valid_data, valid_sample_split, batch_size
)
micro_batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
@@ -1083,7 +1093,7 @@ class Engine:
dataset=valid_data,
capacity=70,
iterable=False,
batch_size=batch_size,
batch_size=micro_batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn,
)
@@ -1093,7 +1103,7 @@ class Engine:
cbks = config_callbacks(
callbacks,
engine=self,
batch_size=batch_size,
batch_size=micro_batch_size,
log_freq=log_freq,
verbose=verbose,
metrics=self._metrics_name(),
@@ -1180,6 +1190,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
test_data, test_sample_split, batch_size
)
micro_batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
@@ -1189,7 +1200,7 @@ class Engine:
dataset=test_data,
capacity=70,
iterable=False,
batch_size=batch_size,
batch_size=micro_batch_size,
steps_per_epoch=steps,
collate_fn=collate_fn,
)
@@ -1242,6 +1253,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size
)
micro_batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
@@ -1250,7 +1262,7 @@ class Engine:
dataloader = self._prepare_dataloader(
dataset,
return_list=False,
batch_size=batch_size,
batch_size=micro_batch_size,
shuffle=shuffle,
drop_last=drop_last,
collate_fn=collate_fn,
@@ -1284,6 +1296,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size
)
micro_batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
@@ -1297,7 +1310,7 @@ class Engine:
return_list=False,
use_multiprocess=use_multiprocess,
drop_last=drop_last,
batch_size=batch_size,
batch_size=micro_batch_size,
epochs=epochs,
steps_per_epoch=steps_per_epoch,
collate_fn=collate_fn,
@@ -1399,14 +1412,6 @@ class Engine:
steps_per_epoch=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1468,14 +1473,6 @@ class Engine:
collate_fn=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1515,6 +1512,9 @@ class Engine:
split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks,
acc_steps=1
if not self._strategy.pipeline.enable
else self._acc_steps,
)
self._prepare_reader(feed_list)
return dataloader
@@ -1526,9 +1526,18 @@ class Engine:
)
self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_batch_size(self, batch_size):
if batch_size is None:
return None
assert (
batch_size % self._acc_steps == 0
), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
batch_size, self._acc_steps
)
return batch_size // self._acc_steps
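`_validate_batch_size` is now the single place that converts the user-facing batch size into the micro batch size handed to the dataloaders. A small illustration, assuming `_acc_steps` resolved to 4:

```python
# engine._acc_steps == 4 (from gradient_merge.k_steps or pipeline.accumulate_steps)
# engine._validate_batch_size(16)   -> 4     # each accumulation step sees 4 samples
# engine._validate_batch_size(None) -> None  # leave batching to the dataset
# engine._validate_batch_size(10)   -> AssertionError (10 % 4 != 0)
```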
def _validate_spec(self, specs):
specs = auto_utils.to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None:
for i, spec in enumerate(specs):
if not isinstance(spec, InputSpec):
@@ -1541,14 +1550,14 @@ class Engine:
i, spec
)
)
if self._k_steps > 1:
if self._acc_steps > 1:
shape = list(spec.shape)
assert (
shape[0] % self._k_steps == 0
shape[0] % self._acc_steps == 0
), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
spec.shape[0], self._k_steps
spec.shape[0], self._acc_steps
)
shape[0] //= self._k_steps
shape[0] //= self._acc_steps
spec.shape = shape
return specs or []
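`_validate_spec` shrinks the leading (batch) dimension of every `InputSpec` by the same `_acc_steps`; for example:

```python
from paddle.static import InputSpec

# with engine._acc_steps == 4:
spec = InputSpec(shape=[16, 1024], name="input", dtype="float32")
# after engine._validate_spec([spec]): spec.shape == [4, 1024]
```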
@@ -1579,7 +1588,6 @@ class Engine:
mode in self._dist_contexts
), f"{mode} model is not ready, please call `prepare()` first."
self.to_mode(mode)
self._optimizer = self._dist_contexts[mode]._serial_optimizer
def to_mode(self, mode):
assert mode in [
@@ -499,21 +499,12 @@ class AutoParallelizer:
break
if is_pipeline:
with paddle.static.program_guard(dist_main_prog):
paddle.distributed.barrier(get_process_group(0))
paddle.distributed.barrier()
# Traverse the program of each rank and every op in it,
# and instantiate the communication groups given by process_mapping.
all_process_groups = get_all_process_groups()
for process_group in all_process_groups:
if len(_g_process_group_map) > 0:
tmp = paddle.to_tensor([1], dtype="int32")
paddle.distributed.all_reduce(
tmp, sync_op=True, group=_g_process_group_map[0]
)
paddle.device.cuda.synchronize()
if rank not in process_group.ranks:
continue
process_group.instantiate()
# Copy distributed info to the default context
@@ -105,13 +105,6 @@ class Parallelizer:
time.time() - time0, self._mode
)
)
# Do reshard process
time0 = time.time()
micro_bsz = (
1
if not self._strategy.pipeline.enable
else self._strategy.pipeline.micro_batch_size
)
set_grad_var_shape(dist_main_prog, self._dist_context)
resharder = Resharder(
dist_main_prog,
@@ -119,7 +112,6 @@ class Parallelizer:
rank,
self._dist_context,
dist_params_grads,
micro_bsz,
)
resharder.reshard()
self._logger.debug(
@@ -169,13 +161,19 @@ class Parallelizer:
)
)
time0 = time.time()
# Do reshard process
micro_bsz = (
1
if not self._strategy.pipeline.enable
else self._strategy.pipeline.micro_batch_size
)
resharder = Resharder(
dist_main_prog,
dist_startup_prog,
rank,
self._dist_context,
[],
1,
micro_bsz,
)
resharder.reshard()
self._logger.debug(
@@ -305,11 +303,14 @@ class Parallelizer:
return
# data parallel optimization
config = {}
if self._strategy.dp_optimization.enable:
config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass = new_pass(
"auto_parallel_data_parallel_optimization", config
)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding.enable:
@@ -49,6 +49,12 @@ def clear_all_process_groups():
_g_process_group_map[0] = ProcessGroup(0, [])
def remove_process_group(ring_id):
global _g_process_group_map
if ring_id in _g_process_group_map:
_g_process_group_map.pop(ring_id)
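A hedged usage sketch of the new helper, mirroring how the pipeline pass below calls it (`block` is a hypothetical Block whose send/recv communication is being rerouted to brpc):

```python
from paddle.distributed.auto_parallel.static.process_group import (
    remove_process_group,
)

for op in block.ops:
    # nccl groups created for send/recv are no longer needed once the
    # fleet executor takes over this communication via brpc
    if op.type in ["send_v2", "recv_v2"]:
        remove_process_group(op.attr("ring_id"))
```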
def new_process_group(ranks, group_id=None, force_new_group=False):
global _g_process_group_map
@@ -132,6 +132,12 @@ class FusedPassesConfig(BaseConfig):
super().__init__(category, config_dict)
class DPOptimizationConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.DP_OPTIMIZATION
super().__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the parallelization and optimization behaviors.
@@ -206,3 +212,6 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.FUSED_PASSES, None)
self.fused_passes = FusedPassesConfig(config_dict)
config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
self.dp_optimization = DPOptimizationConfig(config_dict)
@@ -122,9 +122,9 @@ def all_reduce(
tensor, op, group, sync_op, use_calc_stream
)
else:
# assert (
# group is None
# ), "Group can not be used in static graph mode for now."
assert (
group is None
), "Group can not be used in static graph mode for now."
return _all_reduce_in_static_mode(
tensor, op, group, sync_op, use_calc_stream
)
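With the assertion restored, static-graph `all_reduce` must run on the default group; a sketch (`tmp` and `pg` are hypothetical):

```python
# under static graph mode, e.g. inside an auto-parallel pass:
#   paddle.distributed.all_reduce(tmp)            # OK: default group
#   paddle.distributed.all_reduce(tmp, group=pg)  # AssertionError: group not supported
```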
@@ -14,6 +14,9 @@
import os
from paddle.distributed.auto_parallel.static.process_group import (
remove_process_group,
)
from paddle.distributed.fleet.fleet_executor_utils import TaskNode
from paddle.fluid import core
from paddle.fluid.framework import Parameter, Program
@@ -50,6 +53,14 @@ class PipelinePass(PassBase):
self._gen_bsz = self.get_attr("generation_batch_size")
self._program = main_program
self._cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
self._nrank = len(trainer_endpoints)
# compute current pp stage
self._pp_stages = len(self._dist_context.process_meshes)
self._cur_pp_stage = self._get_pp_stage(self._cur_rank)
if self._mode == "1F1B":
raise NotImplementedError("1F1B has not been implemented")
elif self._mode == "F-Then-B":
@@ -70,6 +81,10 @@ class PipelinePass(PassBase):
send_vars = []
# insert sync ops
for index, op in enumerate(list(block.ops)):
# NOTE: pipeline might hang when dynamic_shape is True
if op.type in ['send_v2', 'recv_v2']:
op._set_attr("dynamic_shape", False)
# set send op on comm stream
if op.type == 'send_v2':
# step1: set 'use_calc_stream' False
op._set_attr("use_calc_stream", False)
@@ -176,21 +191,13 @@ class PipelinePass(PassBase):
def _get_pp_stage(self, rank):
pp_idx = None
for idx, process_mesh in enumerate(self._dist_context.process_meshes):
if rank in process_mesh.processes:
if rank in process_mesh.process_ids:
pp_idx = idx
break
return pp_idx
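A worked example of `_get_pp_stage`, assuming the dist context holds two process meshes with ranks [0, 1] on stage 0 and ranks [2, 3] on stage 1:

```python
# self._get_pp_stage(1) -> 0     (rank 1 is in the first mesh)
# self._get_pp_stage(3) -> 1     (rank 3 is in the second mesh)
# self._get_pp_stage(7) -> None  (rank is in no mesh)
```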
def _task_stream(self):
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
nrank = len(trainer_endpoints)
num_of_functionality = 5
# compute current pp stage
pp_stages = len(self._dist_context.process_meshes)
cur_pp_stage = self._get_pp_stage(cur_rank)
start_prog = Program()
cond_prog = Program()
end_prog = Program()
@@ -198,6 +205,7 @@ class PipelinePass(PassBase):
recv_prog = Program()
cond_var_name = None
# record the varnames that are related to the while cond vars and communicated by nccl
send_vars_name = set()
recv_vars_name = {}
for ib, src_block in enumerate(self._program.blocks):
@@ -222,38 +230,25 @@ class PipelinePass(PassBase):
src_block, end_block, op, force_create=True
)
elif ib == 1:
# NOTE: for Ernie generation
# The while block will be split into two separate blocks:
# while{transformer_layer(send_block), generation_and_broadcast(recv_block)}
# The send_block:
# includes all ops of the transformer layers computation,
# excluding the nccl ops for the while cond var (on the last pp stage).
# The recv_block:
# includes all computation ops for generation and the while cond var,
# excluding the nccl ops for the while cond var (on the pp stages other than the last one).
# The nccl ops for the while cond var:
# put these varnames in the recv task node and communicate them with brpc instead of nccl.
send_block = send_prog.block(0)
recv_block = recv_prog.block(0)
is_after_send_op = False
is_after_recv_op = False
for op in src_block.ops:
for i, op in enumerate(src_block.ops):
if op.type == "send_v2" and not is_after_send_op:
is_after_send_op = True
if cur_pp_stage == pp_stages - 1:
if op.type in ["c_sync_calc_stream", "nop"]:
continue
if (
op.type not in ["recv_2", "assign"]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
):
if (
len(op.desc.input_arg_names()) > 0
and "@RESHARD"
not in op.desc.input_arg_names()[0]
):
send_vars_name.add(
op.desc.input_arg_names()[0]
)
continue
if op.type == "send_v2":
continue
self._create_program(
src_block, send_block, op, force_create=True
)
continue
if (
is_after_send_op
@@ -261,45 +256,21 @@ class PipelinePass(PassBase):
and op.type == "recv_v2"
):
is_after_recv_op = True
if op.has_attr(
'op_namescope'
) and "/auto_parallel/reshard" in op.attr(
'op_namescope'
):
var_name = op.desc.output_arg_names()[0]
index = var_name.find("@")
if index > 0:
old_var_name = var_name[:index]
else:
old_var_name = var_name
recv_vars_name[var_name] = old_var_name
if not src_block._find_var_recursive(old_var_name):
src_var = src_block._var_recursive(var_name)
recv_block.create_var(
type=src_var.type,
name=old_var_name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
continue
self._create_program(
src_block, recv_block, op, force_create=True
)
continue
if not is_after_send_op or not is_after_recv_op:
if cur_pp_stage == pp_stages - 1:
if op.type in ["c_sync_calc_stream", "nop"]:
if self._cur_pp_stage == self._pp_stages - 1:
# NOTE: the c_sync_calc_stream belonging to c_allgather cannot be removed
if (
op.type == "c_sync_calc_stream"
and src_block.ops[i + 1].type == "send_v2"
):
continue
if op.type == "nop":
continue
# HACKCODE: the varnames of the send_v2 and cast ops should be recorded for brpc comm
if (
op.type not in ["recv_2", "assign"]
op.type
not in ["recv_2", "assign", "c_allgather"]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
@@ -312,19 +283,27 @@ class PipelinePass(PassBase):
send_vars_name.add(
op.desc.input_arg_names()[0]
)
if op.type == "send_v2":
remove_process_group(op.attr("ring_id"))
continue
if op.type == "send_v2":
remove_process_group(op.attr("ring_id"))
continue
self._create_program(
src_block, send_block, op, force_create=True
)
continue
if is_after_send_op and is_after_recv_op:
# HACKCODE: the varnames of the recv_v2 and assign ops should be recorded for brpc comm
if op.has_attr(
'op_namescope'
) and "/auto_parallel/reshard" in op.attr(
'op_namescope'
):
if op.type in ["send_v2", "recv_v2"]:
remove_process_group(op.attr("ring_id"))
# remove the suffix of "@RESHARD"
var_name = op.desc.output_arg_names()[0]
index = var_name.find("@")
if index > 0:
@@ -356,6 +335,7 @@ class PipelinePass(PassBase):
self._create_program(
src_block, recv_block, op, force_create=True
)
continue
else:
raise Exception("Only support generation condition.")
@@ -397,52 +377,52 @@ class PipelinePass(PassBase):
vars_to_shape = recv_task_node_var_shape
start_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Start",
task_id=int(cur_rank * num_of_functionality + 0),
task_id=int(self._cur_rank * num_of_functionality + 0),
program=start_prog,
lazy_initialize=True,
)
cond_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Cond",
task_id=int(cur_rank * num_of_functionality + 1),
task_id=int(self._cur_rank * num_of_functionality + 1),
program=cond_prog,
cond_var_name=cond_var_name,
lazy_initialize=True,
)
send_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(cur_rank * num_of_functionality + 2),
task_id=int(self._cur_rank * num_of_functionality + 2),
program=send_prog,
lazy_initialize=True,
)
recv_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(cur_rank * num_of_functionality + 3),
task_id=int(self._cur_rank * num_of_functionality + 3),
program=recv_prog,
lazy_initialize=True,
vars_to_dtype=vars_to_dtype,
vars_to_shape=vars_to_shape,
)
end_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(cur_rank * num_of_functionality + 4),
task_id=int(self._cur_rank * num_of_functionality + 4),
program=end_prog,
lazy_initialize=True,
)
# add dependencies for task nodes intra stage
inf = -1
pp_buff_size = int(pp_stages - cur_pp_stage)
pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
start_task_node.add_downstream_task(
cond_task_node.task_id(), self._gen_bsz
)
@@ -551,12 +531,12 @@ class PipelinePass(PassBase):
# add dependencies for task nodes inter stage
# get upstream ranks and downstream ranks of cur_rank
up_down_streams = self._dist_context.up_down_streams
pp_upstream_ranks = up_down_streams.ups(cur_rank)
pp_downstream_ranks = up_down_streams.downs(cur_rank)
pp_upstream_ranks = up_down_streams.ups(self._cur_rank)
pp_downstream_ranks = up_down_streams.downs(self._cur_rank)
for upstream_rank in pp_upstream_ranks:
upstream_pp_stage = self._get_pp_stage(upstream_rank)
if upstream_pp_stage < pp_stages - 1:
if upstream_pp_stage < self._pp_stages - 1:
upstream_task_id = int(upstream_rank * num_of_functionality + 2)
send_task_node.add_upstream_task(upstream_task_id)
print(
@@ -579,7 +559,7 @@ class PipelinePass(PassBase):
2,
)
for downstream_rank in pp_downstream_ranks:
if cur_pp_stage < pp_stages - 1:
if self._cur_pp_stage < self._pp_stages - 1:
downstream_task_id = int(
downstream_rank * num_of_functionality + 2
)
@@ -607,7 +587,7 @@ class PipelinePass(PassBase):
)
task_id_to_rank = {}
for i in range(nrank):
for i in range(self._nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
self._program._pipeline_opt = {
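With `num_of_functionality = 5`, every rank owns a contiguous block of task ids (Start, Cond, send Compute, recv Compute, End), so for `self._nrank == 2` the mapping built above is:

```python
# task_id_to_rank == {0: 0, 1: 0, 2: 0, 3: 0, 4: 0,   # rank 0
#                     5: 1, 6: 1, 7: 1, 8: 1, 9: 1}   # rank 1
```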
@@ -1462,12 +1462,6 @@ class Executor:
if "fleet_opt" in program._pipeline_opt:
# Move prepare here for port conflict with nccl in startup program
if self._fleet_executor is None:
# Temporary manual enable standalone executor for fleet executor,
# delete this code after the FLAGS is removed.
if 'tasks' in program._pipeline_opt["fleet_opt"]:
set_flags(
{"FLAGS_fleet_executor_with_standalone": True}
)
self._fleet_executor = _prepare_fleet_executor()
return self._run_using_fleet_executor(
program=program,
@@ -58,6 +58,10 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_generation_pipeline MODULES
test_pass_generation_pipeline)
set_tests_properties(test_pass_generation_pipeline
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
# End of unittests WITH multi cards and timeout
# NOTE(zyl): unittests WITH multi cards and WITHOUT timeout
@@ -98,14 +98,11 @@ class GEN(nn.Layer):
while cur_step < total_step:
out = self.mlp(model_kwargs['input'])
model_kwargs['res'] = out
paddle.increment(cur_step)
out_assign = auto.shard_op(paddle.assign, _g_mesh)(out)
auto.shard_op(paddle.assign, _g_mesh)(model_kwargs['input'], out)
output = F.gelu(model_kwargs['input'], approximate=True)
return output, cur_step
model_kwargs['output'] = paddle.assign(out_assign)
return model_kwargs['output'], cur_step
def get_model():
@@ -125,35 +122,33 @@ class TestGenerationPipeline(unittest.TestCase):
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "stream"
pipeline.generation_batch_size = 4
pipeline.accumulate_steps = 4
pipeline.generation_batch_size = 2 # equal to the number of pp stages
pipeline.accumulate_steps = 20 # the total number of samples
engine = auto.Engine(model, strategy=strategy)
engine.prepare(
inputs_spec=paddle.static.InputSpec(
shape=[2, 1024], name='input', dtype='float32'
shape=[20, 1024], name='input', dtype='float32'
),
labels_spec=paddle.static.InputSpec(
shape=[2, 1024], name='label', dtype='float32'
shape=[20, 1024], name='label', dtype='float32'
),
mode="eval",
)
train_data = MyDataset(50 * 2)
train_data = MyDataset(20)
train_dataloader = engine._prepare_dataloader_from_generator(
dataset=train_data,
capacity=70,
capacity=20,
iterable=False,
batch_size=2,
batch_size=1, # micro_batch_size
epochs=1,
steps_per_epoch=100,
)
engine._prepare_reader()
fleet_opt = engine.main_program._pipeline_opt['fleet_opt']
assert len(fleet_opt['tasks']) == 5
assert fleet_opt['inference_generation']
assert fleet_opt['num_micro_batches'] == 4
assert fleet_opt['num_micro_batches'] == 20
num_task_in_rank = 5
for idx, (task_id, rank_id) in enumerate(
fleet_opt['task_id_to_rank'].items()
@@ -170,7 +165,6 @@ class TestGenerationPipeline(unittest.TestCase):
except paddle.fluid.core.EOFException:
print("test done")
train_dataloader._inner_dataloader.reset()
train_dataloader._inner_dataloader.start()
if __name__ == "__main__":