Unverified commit c47853f6 authored by zhaoyingli, committed by GitHub

[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)

Parent 6f3c9643
......@@ -110,12 +110,15 @@ void PreventVarsDelete(
std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc,
TaskNode* cond_task,
const std::vector<std::string>& vars_not_gc) {
// NOTE: Since the while op won't appear in a task node, in order to analyze
// the vars which should be freed after calling the while op, we rebuild the
// whole program and get the unused vars after calling the while op.
// Vars in the parent block should not be freed until the while op is finished.
// The local vars will be freed while running ops in the sub block.
// The vars in the while block should not be freed until the while op is finished.
// In a word, the vars that need to be freed after the while op are:
// 1. Vars in the parent block that are used in the while block.
// 2. Local vars only defined in the while block.
// The unused vars above will be freed in the cond interceptor.
std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
......@@ -129,27 +132,12 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name);
}
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
}
return while_block_vars;
}
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
GetSubUnusedVars(const framework::ProgramDesc& program_desc,
const std::set<TaskNode*>& sub_block_tasks,
const std::vector<std::string>& vars_not_gc) {
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (auto* task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
for (auto& unique_op : ops) {
unique_op.release();
}
PreventVarsDelete(&unused_vars, vars_not_gc);
return unused_vars;
return while_block_vars;
}
} // namespace
......@@ -174,13 +162,8 @@ void FleetExecutor::Init(
for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars =
GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars);
for (auto* task_node : sub_block_tasks) {
for (auto iter : task_node->vars_to_dtype()) {
while_block_vars.emplace_back(iter.first);
}
}
while_block_vars = GetUnusedVarsAfterWhile(
program_desc, task_node, inference_root_scope_vars);
VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) {
VLOG(3) << var;
......@@ -210,9 +193,6 @@ void FleetExecutor::Init(
unique_op.release();
}
auto sub_unused_vars =
GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inference, since they may be the
// inference results. If they are GCed, it will cause errors when ZeroCopying the results.
......@@ -223,8 +203,6 @@ void FleetExecutor::Init(
for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
......
......@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
# #########################################
#########################################
# auto tuning configuration
# #########################################
#########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
......@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
#########################################
# data parallel configuration
#########################################
DP_OPTIMIZATION = "dp_optimization"
set_field_default_config(DP_OPTIMIZATION, "enable", False)
set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
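These defaults surface the new data-parallel optimization pass through the `Strategy` API wired up later in this diff. A minimal usage sketch, assuming the config fields are exposed as attributes the same way `gradient_merge` is in the tests below (model/loss/optimizer wiring omitted):

# Hedged sketch: toggling the new dp_optimization knobs from user code.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.dp_optimization.enable = True
strategy.dp_optimization.fuse_all_reduce_ops = True
strategy.dp_optimization.fuse_grad_size_in_MB = 32
strategy.dp_optimization.overlap_comm_cacl = True
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)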
......@@ -17,12 +17,18 @@ import numpy as np
import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
from paddle.fluid.dataloader.batch_sampler import (
_InfiniteIterableSampler,
DistributedBatchSampler,
)
from paddle.fluid.dataloader.dataloader_iter import (
_DatasetKind,
default_collate_fn,
default_convert_fn,
)
class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __iter__(self):
raise NotImplementedError
......@@ -43,8 +49,8 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
def __init__(self,
def __init__(
self,
dataset,
feed_list=None,
capacity=None,
......@@ -60,7 +66,9 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
collate_fn=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
data_parallel_rank=[],
acc_steps=1,
):
self.dataset = dataset
self.feed_list = feed_list
self.capacity = capacity
......@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.acc_steps = acc_steps
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
......@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
else:
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
dataset, batch_size
)
else:
self.batch_sampler = BatchSampler(dataset,
self.batch_sampler = BatchSampler(
dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
drop_last=drop_last,
)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
......@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_last)
self.dataset_kind,
self.dataset,
self.auto_collate_batch,
self.collate_fn,
self.drop_last,
)
self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader()
......@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
if isinstance(self.dataset, IterableDataset):
steps_per_epoch = None
elif self.batch_size is None:
steps_per_epoch = len(self.dataset)
steps_per_epoch = len(self.dataset) // self.acc_steps
else:
steps_per_epoch = len(self.dataset) // self.batch_size
steps_per_epoch = (
len(self.dataset) // self.batch_size // self.acc_steps
)
except:
raise ValueError(
"Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."
......@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
return _InfiniteIterableSampler(self.dataset, 1)
def _create_inner_dataloader(self):
def data_generator():
while True:
try:
indices = next(self.sampler_iter)
batch = self.dataset_fetcher.fetch(indices)
if batch is None: break
if batch is None:
break
except StopIteration:
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn,
self.drop_last)
self.dataset_kind,
self.dataset,
self.auto_collate_batch,
self.collate_fn,
self.drop_last,
)
break
partial_data = []
......@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
continue
batch_size = array.shape[0]
assert batch_size % self.dp_world_sizes[i] == 0, \
"batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
assert (
batch_size % self.dp_world_sizes[i] == 0
), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
str(batch_size), str(self.dp_world_sizes[i])
)
partial_data.append(
np.split(array,
self.dp_world_sizes[i])[self.dp_ranks[i]])
np.split(array, self.dp_world_sizes[i])[
self.dp_ranks[i]
]
)
yield partial_data
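The split above cuts every feed array evenly across the data-parallel ranks and keeps only this rank's shard. A standalone sketch of that behaviour (array contents and the dp layout are made up):

# Hedged sketch of the per-rank split performed by data_generator above.
import numpy as np

batch = np.arange(8).reshape(8, 1)    # one feed array with batch size 8
dp_world_size, dp_rank = 2, 1         # hypothetical data-parallel layout
assert batch.shape[0] % dp_world_size == 0
shard = np.split(batch, dp_world_size)[dp_rank]
print(shard.ravel())                  # rank 1 sees samples [4 5 6 7]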
......@@ -194,15 +220,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
iterable=False,
return_list=self.return_list,
use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last)
drop_last=self.drop_last,
)
dataloader.set_batch_generator(data_generator, self.places)
return dataloader
class DistributedDataLoader(DistributedDataLoaderBase):
def __init__(self,
def __init__(
self,
dataset,
feed_list=None,
places=None,
......@@ -220,7 +247,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
data_parallel_rank=[],
):
self.dataset = dataset
self.feed_list = feed_list
self.return_list = return_list
......@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
self.split_data = split_data
# TODO: rank info
self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0],
self.dp_ranks[0], self.shuffle, self.drop_last)
self.dataset,
self.batch_size,
self.dp_world_sizes[0],
self.dp_ranks[0],
self.shuffle,
self.drop_last,
)
self._inner_dataloader = self._create_inner_dataloader()
def __iter__(self):
......@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory,
timeout=self.timeout,
worker_init_fn=self.worker_init_fn)
worker_init_fn=self.worker_init_fn,
)
self.data = (x for x in dataloader)
return dataloader
......@@ -18,6 +18,7 @@ import errno
import pickle
import warnings
import logging
import collections
import numpy as np
import paddle
......@@ -53,16 +54,13 @@ def _process_path(path):
class DistributedSaver:
def __init__(self):
self._logger = get_logger(logging.INFO)
def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
k: np.array(v) for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)
......@@ -108,8 +106,9 @@ class DistributedSaver:
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
if check_filename(
'{}(.*)_dist(.*).{}'.format(filename, suffix), file
):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list
......@@ -137,14 +136,16 @@ class DistributedSaver:
# load path.pdparam and path.pdopt
param_state_dict = _load_state(filename, dirname)
opt_state_dict = _load_state(filename, dirname,
"pdopt") if load_optimizer else {}
opt_state_dict = (
_load_state(filename, dirname, "pdopt") if load_optimizer else {}
)
state_dict = dict(param_state_dict, **opt_state_dict)
# load path.pdattr
dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list))
"Load distributed attribute file: {}".format(dist_attr_file_list)
)
dist_attr = {}
for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f:
......@@ -196,12 +197,24 @@ class DistributedSaver:
used_inputs += op.input_arg_names
used_outputs += op.output_arg_names
dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))
# delete duplicated elements and keep order
feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
used_inputs = list({}.fromkeys(used_inputs).keys())
fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
used_outputs = list({}.fromkeys(used_outputs).keys())
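The `{}.fromkeys(...)` calls above deduplicate while preserving first-seen order (dict keys keep insertion order in Python 3.7+). A tiny standalone illustration:

# Order-preserving dedup, as used for the feed/fetch var names above.
names = ["x", "y", "x", "z", "y"]
print(list({}.fromkeys(names).keys()))  # ['x', 'y', 'z']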
dist_feed_vars = [
global_block.vars[name] for name in dist_feed_vars_names
dist_feed_vars_names = [
var_name for var_name in feed_vars_names if var_name in used_inputs
]
dist_fetch_vars_names = [
var_name
for var_name in fetch_vars_names
if var_name in used_outputs
]
dist_feed_vars = list(
reversed([global_block.vars[name] for name in dist_feed_vars_names])
)
dist_fetch_vars = [
global_block.vars[name] for name in dist_fetch_vars_names
]
......@@ -209,11 +222,13 @@ class DistributedSaver:
# NOTE: `paddle.static.save_inference_model` does not support subblock.
dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename)
paddle.static.save_inference_model(dist_path,
paddle.static.save_inference_model(
dist_path,
dist_feed_vars,
dist_fetch_vars,
exe,
program=dist_main_prog)
program=dist_main_prog,
)
def _save_rank_mapping(self, dirname):
path = os.path.join(dirname, 'rank_mapping.csv')
......
......@@ -225,6 +225,11 @@ class Engine:
self._planned_mode = None
self._dygraph_mode = False
self._tuning = self._strategy.tuning
self._acc_steps = 1
if self._strategy.gradient_merge.enable:
self._acc_steps = self._strategy.gradient_merge.k_steps
elif self._strategy.pipeline.enable:
self._acc_steps = self._strategy.pipeline.accumulate_steps
self.history = None
......@@ -388,7 +393,12 @@ class Engine:
if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = None
if self._strategy.pipeline.schedule_mode == "1F1B":
fwd_task = fleet_opt["tasks"][1]
elif self._strategy.pipeline.schedule_mode == "stream":
fwd_task = fleet_opt["tasks"][0]
assert fwd_task is not None
fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block()
......@@ -438,8 +448,6 @@ class Engine:
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
else:
user_fetches = []
fetch_names = []
fetch_indices = []
......@@ -466,7 +474,7 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches:
for usr_fetch in user_fetches or []:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
......@@ -903,6 +911,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
train_data, train_sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -931,7 +940,7 @@ class Engine:
save_dir=save_dir,
verbose=verbose,
metrics=self._metrics_name(),
acc_step=self._k_steps,
acc_step=self._acc_steps,
)
cbks.on_begin('train')
......@@ -965,7 +974,7 @@ class Engine:
val_logs = self.evaluate(
valid_data,
valid_sample_split,
batch_size,
batch_size * self._acc_steps,
valid_steps,
log_freq,
collate_fn,
......@@ -1046,6 +1055,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
valid_data, valid_sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1152,6 +1162,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
test_data, test_sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1214,6 +1225,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1256,6 +1268,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1371,14 +1384,6 @@ class Engine:
steps_per_epoch=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
......@@ -1440,14 +1445,6 @@ class Engine:
collate_fn=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
......@@ -1487,6 +1484,9 @@ class Engine:
split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks,
acc_steps=1
if not self._strategy.pipeline.enable
else self._acc_steps,
)
self._prepare_reader(feed_list)
return dataloader
......@@ -1498,9 +1498,18 @@ class Engine:
)
self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_batch_size(self, batch_size):
if batch_size is None:
return None
assert (
batch_size % self._acc_steps == 0
), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
batch_size, self._acc_steps
)
return batch_size // self._acc_steps
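`_validate_batch_size` converts the user-facing batch size into the micro-batch size actually fed to the program. An illustrative trace (values are made up):

# Illustrative only: with _acc_steps == 4 (e.g. gradient_merge.k_steps = 4),
# a user batch_size of 16 becomes a micro-batch of 4; 15 would be rejected.
acc_steps = 4
batch_size = 16
assert batch_size % acc_steps == 0
micro_batch_size = batch_size // acc_steps  # 4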
def _validate_spec(self, specs):
specs = to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None:
for i, spec in enumerate(specs):
if not isinstance(spec, InputSpec):
......@@ -1513,14 +1522,14 @@ class Engine:
i, spec
)
)
if self._k_steps > 1:
if self._acc_steps > 1:
shape = list(spec.shape)
assert (
shape[0] % self._k_steps == 0
shape[0] % self._acc_steps == 0
), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
spec.shape[0], self._k_steps
spec.shape[0], self._acc_steps
)
shape[0] //= self._k_steps
shape[0] //= self._acc_steps
spec.shape = shape
return specs or []
......
......@@ -297,12 +297,14 @@ class Parallelizer:
if self._strategy is None:
return
# data parallel optimization
config = {}
if self._strategy.dp_optimization.enable:
config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass = new_pass(
"auto_parallel_data_parallel_optimization", config
)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding.enable:
......
......@@ -422,11 +422,11 @@ class Inserter:
)
inputs = {'X': [tensor]}
outputs = {"Out": [out]}
attrs = {"in_place": False}
slice_op = block._insert_op(
attrs = {"in_place": False, "op_role": op_role}
assign_op = block._insert_op(
idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
# use split once
......@@ -1217,6 +1217,8 @@ class Resharder:
shape_x[0] <= shape_y[0] < shape_x[1]
):
overlapped = True
if shape_x == [0, 0] and shape_y == [0, 0]:
overlapped = True
return overlapped
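The added branch treats two empty ranges [0, 0] as overlapping, so zero-sized partitions are still matched. A minimal standalone sketch of the half-open interval test (this helper is an illustration, not the class method):

# Hedged sketch of the overlap test, including the new [0, 0] special case.
def overlapped(x, y):
    if x == [0, 0] and y == [0, 0]:
        return True
    return (y[0] <= x[0] < y[1]) or (x[0] <= y[0] < x[1])

print(overlapped([0, 4], [2, 6]))  # True
print(overlapped([0, 2], [2, 4]))  # False, half-open ranges only touch
print(overlapped([0, 0], [0, 0]))  # True, new special case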
def is_unshard(self, dims_mapping):
......@@ -1304,6 +1306,14 @@ class Resharder:
# judge whether need reshard by process_mesh
if tensor_process_mesh != op_process_mesh:
is_reshard = True
# do not reshard data in the send/recv scene
if (
tensor_process_mesh != op_process_mesh
and len(tensor_process_mesh.process_ids)
== len(op_process_mesh.process_ids)
and dist_tensor.serial_tensor.is_data
):
is_reshard = False
else:
op_output_dims_mapping = dist_attr[1]
if all(
......@@ -1585,10 +1595,10 @@ class Resharder:
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
min_comm_group = copy.deepcopy(group)
all_partition_index_list_copied = copy.deepcopy(
all_partition_index_list
)
target_partition_index = Resharder.compute_partition_index(
process,
complete_shape,
......@@ -1596,12 +1606,56 @@ class Resharder:
target_process_shape,
target_process_group,
)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
for _process in group:
source_partition_index = (
Resharder.compute_partition_index(
_process,
complete_shape,
source_dims_mapping,
source_process_shape,
source_process_group,
)
)
if not all(
_
for _ in list(
map(
self.is_overlapped,
source_partition_index,
target_partition_index,
)
)
):
min_comm_group.remove(_process)
all_partition_index_list_copied.remove(
source_partition_index
)
concatenated_partition_index_list = []
for partition_index in all_partition_index_list_copied:
Resharder.concat_partitions(
concatenated_partition_index_list, partition_index
)
concatenated_partition_index = (
concatenated_partition_index_list[0]
)
slice_starts = []
slice_ends = []
slices_axes = []
to_slice_tensor_shape = []
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(
target_partition_index[idx][0] - item[0]
)
slice_ends.append(
target_partition_index[idx][1] - item[0]
)
slices_axes.append(idx)
to_slice_tensor_shape.append(item[1] - item[0])
to_slice_tensor_shape = dist_tensor.global_sizes()
slice_op_desc = SliceOpDesc(
starts=slice_starts,
ends=slice_ends,
......@@ -1616,16 +1670,16 @@ class Resharder:
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
group=min_comm_group,
shape=allgather_shape,
is_bool=(source_tensor.dtype == paddle.bool),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
partition_index_list=all_partition_index_list_copied
),
slice_op_desc,
]
if len(group) > 1
if len(min_comm_group) > 1
else [slice_op_desc]
)
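After shrinking the allgather group, the slice offsets are computed relative to the concatenated partition rather than the full tensor. A small numeric illustration (shapes are made up):

# Illustrative only: a 1-D tensor of length 8 gathered from source partitions
# [0, 4) and [4, 8); the target partition [4, 6) is sliced out of the
# concatenated buffer, so start/end are offset by the buffer's origin.
concatenated_partition_index = [[0, 8]]
target_partition_index = [[4, 6]]
slice_starts, slice_ends, slices_axes = [], [], []
for idx, item in enumerate(concatenated_partition_index):
    slice_starts.append(target_partition_index[idx][0] - item[0])  # 4
    slice_ends.append(target_partition_index[idx][1] - item[0])    # 6
    slices_axes.append(idx)
print(slice_starts, slice_ends, slices_axes)  # [4] [6] [0]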
......
......@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
super(DatasetConfig, self).__init__(category, config_dict)
class DPOptimizationConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.DP_OPTIMIZATION
super(DPOptimizationConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the parallelization and optimization behaviors.
......@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.DATASET, None)
self.dataset = DatasetConfig(config_dict)
config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
self.dp_optimization = DPOptimizationConfig(config_dict)
......@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad",
"relu_grad",
"exp_grad",
]
forward_list = [
"reshape2",
......@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle",
"flatten_contiguous_range",
"relu",
"exp",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......@@ -1320,6 +1322,11 @@ def is_forward_op(op):
)
def is_fillconst_op_for_micro_batch(op):
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))
def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY]
......
......@@ -18,15 +18,31 @@ import numpy as np
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_scale_op,
is_data_parallel_reduce_op,
)
from paddle.distributed.auto_parallel.utils import (
is_loss_grad_op,
is_optimize_op,
is_backward_op,
ring_id_to_process_group,
find_higher_order_backward_op,
)
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
'merge_momentum'
'lars_momentum',
'sparse_momentum',
'dgc_momentum',
'momentum',
'merge_momentum',
]
# a heuristic number
......@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
self.set_attr("dist_context", None)
self.set_attr("global_rank", -1)
self.set_attr("use_sharding", False)
self.set_attr("fuse_all_reduce_ops", False)
self.set_attr("fuse_grad_size_in_MB", 32)
self.set_attr("overlap_comm_cacl", False)
# {grad1: group1, grad2: group1, grad3: group2}
# record the order for fusing grad memory
self._grad_name_to_group_map = OrderedDict()
......@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if (not isinstance(self.get_attr("global_rank"),
int)) or self.get_attr("global_rank") < 0:
if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
"global_rank"
) < 0:
return False
return True
......@@ -80,12 +100,17 @@ class DataParallelOptimizationPass(PassBase):
self.global_rank = int(self.get_attr("global_rank"))
self.use_sharding = self.get_attr("use_sharding")
overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")
with paddle.static.program_guard(main_program, startup_program):
self._analyze_program()
if self.is_data_parallel_applied():
if overlap_comm_cacl:
self._prune_grad_scaling()
self._calc_comm_overlap()
if fuse_all_reduce_ops:
grad_group = self._fuse_allreduce()
# self.summary(grad_group)
......@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
group = ring_id_to_process_group(op.attr("ring_id"))
assert group is not None, "Unexception: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op))
assert (
group is not None
), "Unexception: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op)
)
self._grad_name_to_group_map[grad_name] = group
......@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
# TODO: support multiple optimizers in one network in the future.
# Here we assume that the optimizer is unique in the network.
elif is_optimize_op(
op) and op.type in __rescale_grad_supported_opts__:
elif (
is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
self._support_rescale_grad = True
not_synchronized_grads = []
for grad_name in scaled_grads:
if grad_name not in self._grad_name_to_group_map:
not_synchronized_grads.append(grad_name)
assert len(
assert (
len(not_synchronized_grads) == 0
), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads
) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads)
)
def is_data_parallel_applied(self):
return len(self._group_to_grad_name_map) > 0
......@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
def _could_be_prune(self):
return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree())
self._support_rescale_grad or self._all_dp_groups_same_degree()
)
def _all_dp_groups_same_degree(self):
return len(
set([
return (
len(
set(
[
len(group.ranks)
for group in self._group_to_grad_name_map.keys()
])) == 1
]
)
)
== 1
)
def _scale_backward_initial_grad(self):
......@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
assert op.type == 'fill_constant', \
"loss_grad_op must be fill_constant op, " \
assert op.type == 'fill_constant', (
"loss_grad_op must be fill_constant op, "
"but this op is {}".format(op.type)
)
assert op.has_attr('value')
loss_scale = float(op.attr('value'))
loss_scale = loss_scale / dp_degree
......@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
scaled_grads = set()
for idx, op in reversed(list(enumerate(block.ops))):
if is_optimize_op(
op) and op.type in __rescale_grad_supported_opts__:
if (
is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
assert op.has_attr(
'rescale_grad'
), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
str(op))
assert len(
op.input("Grad")
) == 1, "Unexception: op [{}] is supported to have only one input grad var.".format(
str(op))
str(op)
)
assert (
len(op.input("Grad")) == 1
), "Unexception: op [{}] is supported to have only one input grad var.".format(
str(op)
)
grad_name = op.input("Grad")[0]
dp_degree = len(
list(self._grad_name_to_group_map[grad_name].ranks))
list(self._grad_name_to_group_map[grad_name].ranks)
)
scaled_grads.add(grad_name)
rescale_grad = float(op.attr('rescale_grad')) / dp_degree
op._set_attr('rescale_grad', rescale_grad)
assert scaled_grads == set(self._grad_name_to_group_map.keys(
)), "Unexception: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads)
assert scaled_grads == set(
self._grad_name_to_group_map.keys()
), "Unexception: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads
)
def _could_be_overlap(self):
# NOTE: currently, different nccl comms use different cuda streams
......@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
op._set_attr('use_calc_stream', False)
ring_id = op.attr("ring_id")
block._insert_op_without_sync(idx,
block._insert_op_without_sync(
idx,
type='c_wait_compute',
inputs={'X': []},
outputs={'Out': []},
attrs={
'op_role': OpRole.Backward,
'ring_id': ring_id
})
attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
)
block._sync_with_cpp()
......@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
# other ops that might use communicating grad
else:
for input_var_name in op.input_arg_names:
for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items(
):
for (
ring_id,
unsync_grad_names,
) in ring_id_to_un_sync_grad_map.items():
if input_var_name in unsync_grad_names:
# need to sync before op_i
if i in op_idx_to_sync_ring_id_map:
......@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
for i in sorted(indices, reverse=True):
for ring_id in op_idx_to_sync_ring_id_map[i]:
block._insert_op_without_sync(i,
block._insert_op_without_sync(
i,
type='c_wait_comm',
inputs={'X': []},
outputs={'Out': []},
attrs={
'op_role': OpRole.Backward,
'ring_id': ring_id
})
attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
)
def _could_be_fuse(self):
# TODO support gradient fuse higher order gradient.
......@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
for i, group in enumerate(grad_groups[::-1]):
# create coalesce tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalecse_grad_{}'.format(i)),
group.coalesce_var = block.create_var(
name=unique_name.generate('coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
stop_gradient=True,
)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "should found scale op but found {}".format(
str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
assert (
scale_op.type == 'scale'
), "should found scale op but found {}".format(str(scale_op))
scale_op._rename_input(
scale_op.input_arg_names[0], group.coalesce_var.name
)
scale_op._rename_output(
scale_op.output_arg_names[0], group.coalesce_var.name
)
allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "should found c_allreduce_sum op but found {}".format(
str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
assert (
allreduce_op.type == 'c_allreduce_sum'
), "should found c_allreduce_sum op but found {}".format(
str(allreduce_op)
)
allreduce_op._rename_input(
allreduce_op.input_arg_names[0], group.coalesce_var.name
)
allreduce_op._rename_output(
allreduce_op.output_arg_names[0], group.coalesce_var.name
)
# remove unused ops
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
remove_op_indices = (
group.remove_wait_op_indices
+ group.remove_allreduce_op_indices
+ group.remove_scale_op_indices
)
for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[
idx].type in remove_op_types, "Unexception: try to remove op {}".format(
str(op))
assert (
block.ops[idx].type in remove_op_types
), "Unexception: try to remove op {}".format(
str(block.ops[idx].type)
)
block._remove_op(idx)
# insert coalesce op
......@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
block._insert_op_without_sync(
group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
"FusedOutput": group.coalesce_var,
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_shapes": concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
OP_ROLE_KEY: OpRole.Backward,
},
)
block._sync_with_cpp()
# TODO update dist attr
......@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
......@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
"origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
len(self._grad_name_to_group_map.keys()), len(grad_groups)
)
)
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
i, [grad.name for grad in group.gradients]
)
)
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
individual_grads = set(self._grad_name_to_group_map.keys()) - set(
fused_grads
)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
len(individual_grads)
)
)
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
......@@ -575,8 +643,11 @@ class GradientsGroup(object):
grad_op_idx -= 1
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
assert (
grad_var.name in grad_op.output_arg_names
), "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op)
)
self.coalesce_op_idx = grad_op_idx
def finalize(self):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random
import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
def apply_pass(use_1f1b=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_1f1b:
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2
else:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 2
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Test1F1BPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
paddle.distributed.fleet.init(is_collective=True)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_1f1b=False):
reset_prog()
strategy = apply_pass(use_1f1b)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("pp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def test_1f1b_pass(self):
# naive_pp + gradient_merge training
engine_pp = self.get_engine()
history_pp = engine_pp.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_pp._strategy.pipeline.enable == False
# pp2 1F1B training
engine_1f1b = self.get_engine(True)
history_1f1b = engine_1f1b.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_1f1b._strategy.pipeline.enable == True
# NOTE: every data sample from the dataset is the same
if paddle.distributed.get_rank() == 1:
losses_pp = np.array(history_pp.history["loss"])
losses_1f1b = np.array(history_1f1b.history["loss"])
self.check_results(losses_pp, losses_1f1b)
if __name__ == "__main__":
unittest.main()
......@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
......
......@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
modeling._global_parallel_strategy = "mp"
elif strategy == "dp":
modeling._global_parallel_strategy = "dp"
elif strategy == "pp":
modeling._global_parallel_strategy = "pp"
modeling.PP_MESH_LIST = [
auto.ProcessMesh(mesh=[0]),
auto.ProcessMesh(mesh=[1]),
]
else:
raise ValueError("Only support serial, mp2 and dp2.")
......@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=2 if strategy == "pp" else None,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
......
......@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
......@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
......@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
history = dp_engine.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert dp_engine._strategy.gradient_merge.enable == False
dp_losses = np.array(history.history["loss"])
# dp2 gradient merge training
......@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
history = gm_engine.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert gm_engine._strategy.gradient_merge.enable == True
gm_losses = np.array(history.history["loss"])
# avg_loss = 0
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class Test1F1BPass(unittest.TestCase):
def test_pp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()