Unverified commit c47853f6, authored by zhaoyingli, committed by GitHub

[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)

Parent 6f3c9643
@@ -110,12 +110,15 @@ void PreventVarsDelete(
 std::vector<std::string> GetUnusedVarsAfterWhile(
     const framework::ProgramDesc& program_desc,
+    TaskNode* cond_task,
     const std::vector<std::string>& vars_not_gc) {
   // NOTE: Since while op won't appear in task node, in order to analyze
   // the vars which should be free after calling while op, we rebuild the
   // whole program and get the unused vars after calling while op.
-  // vars in parent block should not be free until the while op is finished.
-  // The local vars will be free while running op in sub block.
+  // The vars in while block should not be freed until the while op is finished.
+  // In a word, the vars that need to be freed after the while op are:
+  // 1. Vars in parent block and being used in while block.
+  // 2. Local vars only defined in while block.
   // The unused vars above will be free in cond interceptor.
   std::vector<std::string> while_block_vars;
   std::vector<std::unique_ptr<framework::OperatorBase>> ops;
@@ -129,27 +132,12 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
       for (const auto& var_name : pair.second) {
         while_block_vars.emplace_back(var_name);
       }
+      for (auto& var : program_desc.Block(1).AllVars()) {
+        while_block_vars.emplace_back(var->Name());
+      }
     }
   }
   return while_block_vars;
 }
-
-std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
-GetSubUnusedVars(const framework::ProgramDesc& program_desc,
-                 const std::set<TaskNode*>& sub_block_tasks,
-                 const std::vector<std::string>& vars_not_gc) {
-  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (auto* task_node : sub_block_tasks) {
-    for (const auto& op : task_node->ops()) {
-      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
-    }
-  }
-  auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
-  for (auto& unique_op : ops) {
-    unique_op.release();
-  }
-  PreventVarsDelete(&unused_vars, vars_not_gc);
-  return unused_vars;
-}
 
 }  // namespace
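The rule stated in the NOTE boils down to a small sketch (Python pseudocode with hypothetical names; an illustration, not part of the commit):

def vars_freed_after_while(parent_vars, used_in_while, while_locals, vars_not_gc):
    # Rule 1: vars living in the parent block that the while block reads.
    freed = [v for v in parent_vars if v in used_in_while]
    # Rule 2: vars defined only inside the while block.
    freed += list(while_locals)
    # Vars explicitly protected from GC are never freed here.
    return [v for v in freed if v not in vars_not_gc]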
@@ -174,13 +162,8 @@ void FleetExecutor::Init(
   for (const auto& task_node : task_nodes) {
     if (task_node->type() == "Cond") {
       GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
-      while_block_vars =
-          GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars);
-      for (auto* task_node : sub_block_tasks) {
-        for (auto iter : task_node->vars_to_dtype()) {
-          while_block_vars.emplace_back(iter.first);
-        }
-      }
+      while_block_vars = GetUnusedVarsAfterWhile(
+          program_desc, task_node, inference_root_scope_vars);
       VLOG(3) << "Vars will be gced after while op";
       for (auto var : while_block_vars) {
         VLOG(3) << var;
@@ -210,9 +193,6 @@ void FleetExecutor::Init(
     unique_op.release();
   }
 
-  auto sub_unused_vars =
-      GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
-
   // NOTE: For inference, the vars in inference_root_scope_vars
   // shouldn't be deleted during inf, for that they may be the result of the
   // inf. If they are GCed, it will cause error during ZeroCopy the result.
@@ -223,8 +203,6 @@ void FleetExecutor::Init(
   for (auto task_node : task_nodes) {
     if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
       task_node->SetUnusedVars(global_unused_vars);
-    } else {
-      task_node->SetUnusedVars(sub_unused_vars);
     }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
......
@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
 set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
 set_field_default_config(QAT, "algo", None)
 
-# #########################################
+#########################################
 # auto tuning configuration
-# #########################################
+#########################################
 TUNING = "tuning"
 set_field_default_config(TUNING, "enable", False)
 set_field_default_config(TUNING, "batch_size", 1)
@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
 DATASET = "dataset"
 set_field_default_config(DATASET, "enable", False)
 set_field_default_config(DATASET, "num_shards", 1)
+
+#########################################
+# data parallel configuration
+#########################################
+DP_OPTIMIZATION = "dp_optimization"
+set_field_default_config(DP_OPTIMIZATION, "enable", False)
+set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
+set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
+set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
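These keys surface on `auto.Strategy` as the `dp_optimization` section (wired up in `strategy.py` further down); a minimal usage sketch, values illustrative:

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
dp = strategy.dp_optimization
dp.enable = True                # turn the data-parallel optimization pass on
dp.fuse_all_reduce_ops = True   # fuse per-gradient allreduces into coalesced ones
dp.fuse_grad_size_in_MB = 32    # size heuristic for one fused buffer
dp.overlap_comm_cacl = True     # overlap gradient communication with computation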
@@ -17,12 +17,18 @@ import numpy as np
 import paddle
 from paddle.io import BatchSampler, IterableDataset
-from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
-from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
+from paddle.fluid.dataloader.batch_sampler import (
+    _InfiniteIterableSampler,
+    DistributedBatchSampler,
+)
+from paddle.fluid.dataloader.dataloader_iter import (
+    _DatasetKind,
+    default_collate_fn,
+    default_convert_fn,
+)
 
 
 class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
     @abc.abstractmethod
     def __iter__(self):
         raise NotImplementedError
@@ -43,8 +49,8 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
 class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
-    def __init__(self,
-                 dataset,
-                 feed_list=None,
-                 capacity=None,
+    def __init__(
+        self,
+        dataset,
+        feed_list=None,
+        capacity=None,
@@ -60,7 +66,9 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
-                 collate_fn=None,
-                 split_data=True,
-                 data_parallel_world_size=[],
-                 data_parallel_rank=[]):
+        collate_fn=None,
+        split_data=True,
+        data_parallel_world_size=[],
+        data_parallel_rank=[],
+        acc_steps=1,
+    ):
         self.dataset = dataset
         self.feed_list = feed_list
         self.capacity = capacity
@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         assert len(data_parallel_rank) == len(feed_list)
         self.dp_world_sizes = data_parallel_world_size
         self.dp_ranks = data_parallel_rank
+        self.acc_steps = acc_steps
 
         if isinstance(dataset, IterableDataset):
             self.dataset_kind = _DatasetKind.ITER
@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         else:
             if isinstance(dataset, IterableDataset):
-                self.batch_sampler = _InfiniteIterableSampler(
-                    dataset, batch_size)
+                self.batch_sampler = _InfiniteIterableSampler(
+                    dataset, batch_size
+                )
             else:
-                self.batch_sampler = BatchSampler(dataset,
-                                                  batch_size=batch_size,
-                                                  shuffle=False,
-                                                  drop_last=drop_last)
+                self.batch_sampler = BatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    shuffle=False,
+                    drop_last=drop_last,
+                )
 
         self.auto_collate_batch = self.batch_sampler is not None
         self.sampler_iter = iter(self.index_sampler)
@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         self.collate_fn = collate_fn or default_convert_fn
         self.dataset_fetcher = _DatasetKind.create_fetcher(
-            self.dataset_kind, self.dataset, self.auto_collate_batch,
-            self.collate_fn, self.drop_last)
+            self.dataset_kind,
+            self.dataset,
+            self.auto_collate_batch,
+            self.collate_fn,
+            self.drop_last,
+        )
 
         self._steps = self._infer_steps()
         self._inner_dataloader = self._create_inner_dataloader()
@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             if isinstance(self.dataset, IterableDataset):
                 steps_per_epoch = None
             elif self.batch_size is None:
-                steps_per_epoch = len(self.dataset)
+                steps_per_epoch = len(self.dataset) // self.acc_steps
             else:
-                steps_per_epoch = len(self.dataset) // self.batch_size
+                steps_per_epoch = (
+                    len(self.dataset) // self.batch_size // self.acc_steps
+                )
         except:
             raise ValueError(
                 "Please set `steps_per_epoch` or implement `__len__` method in dataset class."
@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         return _InfiniteIterableSampler(self.dataset, 1)
 
     def _create_inner_dataloader(self):
         def data_generator():
             while True:
                 try:
                     indices = next(self.sampler_iter)
                     batch = self.dataset_fetcher.fetch(indices)
-                    if batch is None: break
+                    if batch is None:
+                        break
                 except StopIteration:
                     self.dataset_fetcher = _DatasetKind.create_fetcher(
-                        self.dataset_kind, self.dataset,
-                        self.auto_collate_batch, self.collate_fn,
-                        self.drop_last)
+                        self.dataset_kind,
+                        self.dataset,
+                        self.auto_collate_batch,
+                        self.collate_fn,
+                        self.drop_last,
+                    )
                     break
 
                 partial_data = []
@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
                         continue
 
                     batch_size = array.shape[0]
-                    assert batch_size % self.dp_world_sizes[i] == 0, \
-                        "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
+                    assert (
+                        batch_size % self.dp_world_sizes[i] == 0
+                    ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
+                        str(batch_size), str(self.dp_world_sizes[i])
+                    )
                     partial_data.append(
-                        np.split(array,
-                                 self.dp_world_sizes[i])[self.dp_ranks[i]])
+                        np.split(array, self.dp_world_sizes[i])[
+                            self.dp_ranks[i]
+                        ]
+                    )
 
                 yield partial_data
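Each data-parallel rank thus receives an equal contiguous shard of every batch; a standalone sketch with made-up rank info:

import numpy as np

array = np.arange(8).reshape(8, 1)               # one batch of 8 samples
dp_world_size, dp_rank = 2, 1                    # hypothetical values
assert array.shape[0] % dp_world_size == 0
shard = np.split(array, dp_world_size)[dp_rank]  # rank 1 gets rows 4..7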
@@ -194,15 +220,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             iterable=False,
             return_list=self.return_list,
             use_multiprocess=self.use_multiprocess,
-            drop_last=self.drop_last)
+            drop_last=self.drop_last,
+        )
         dataloader.set_batch_generator(data_generator, self.places)
 
         return dataloader
 
 
 class DistributedDataLoader(DistributedDataLoaderBase):
-    def __init__(self,
-                 dataset,
-                 feed_list=None,
-                 places=None,
+    def __init__(
+        self,
+        dataset,
+        feed_list=None,
+        places=None,
@@ -220,7 +247,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
-                 steps_per_epoch=None,
-                 split_data=True,
-                 data_parallel_world_size=[],
-                 data_parallel_rank=[]):
+        steps_per_epoch=None,
+        split_data=True,
+        data_parallel_world_size=[],
+        data_parallel_rank=[],
+    ):
         self.dataset = dataset
         self.feed_list = feed_list
         self.return_list = return_list
@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
         self.split_data = split_data
         # TODO: rank info
         self.batch_sampler = DistributedBatchSampler(
-            self.dataset, self.batch_size, self.dp_world_sizes[0],
-            self.dp_ranks[0], self.shuffle, self.drop_last)
+            self.dataset,
+            self.batch_size,
+            self.dp_world_sizes[0],
+            self.dp_ranks[0],
+            self.shuffle,
+            self.drop_last,
+        )
         self._inner_dataloader = self._create_inner_dataloader()
 
     def __iter__(self):
@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
             use_buffer_reader=self.use_buffer_reader,
             use_shared_memory=self.use_shared_memory,
             timeout=self.timeout,
-            worker_init_fn=self.worker_init_fn)
+            worker_init_fn=self.worker_init_fn,
+        )
         self.data = (x for x in dataloader)
         return dataloader
@@ -18,6 +18,7 @@ import errno
 import pickle
 import warnings
 import logging
+import collections
 import numpy as np
 
 import paddle
@@ -53,16 +54,13 @@ def _process_path(path):
 class DistributedSaver:
-
     def __init__(self):
         self._logger = get_logger(logging.INFO)
 
     def save(self, path, serial_program, dist_main_program, dist_context):
-
         def _save_state(program, path, mode="param"):
             state = {
-                k: np.array(v)
-                for k, v in program.state_dict(mode).items()
+                k: np.array(v) for k, v in program.state_dict(mode).items()
             }
             with open(path, "wb") as f:
                 pickle.dump(state, f)
@@ -108,8 +106,9 @@ class DistributedSaver:
         def _load_file(filename, dirname, suffix="pdparams"):
             file_list = []
             for file in os.listdir(dirname):
-                if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
-                                  file):
+                if check_filename(
+                    '{}(.*)_dist(.*).{}'.format(filename, suffix), file
+                ):
                     file_list.append(os.path.join(dirname, file))
             file_list.sort()
             return file_list
@@ -137,14 +136,16 @@ class DistributedSaver:
         # load path.pdparam and path.pdopt
         param_state_dict = _load_state(filename, dirname)
-        opt_state_dict = _load_state(filename, dirname,
-                                     "pdopt") if load_optimizer else {}
+        opt_state_dict = (
+            _load_state(filename, dirname, "pdopt") if load_optimizer else {}
+        )
         state_dict = dict(param_state_dict, **opt_state_dict)
 
         # load path.pdattr
         dist_attr_file_list = _load_file(filename, dirname, "pdattr")
         self._logger.info(
-            "Load distributed attribute file: {}".format(dist_attr_file_list))
+            "Load distributed attribute file: {}".format(dist_attr_file_list)
+        )
         dist_attr = {}
         for dist_attr_file in dist_attr_file_list:
             with open(dist_attr_file, 'rb') as f:
@@ -196,12 +197,24 @@ class DistributedSaver:
             used_inputs += op.input_arg_names
             used_outputs += op.output_arg_names
 
-        dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
-        dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))
+        # delete duplicated elements and keep order
+        feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
+        used_inputs = list({}.fromkeys(used_inputs).keys())
+        fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
+        used_outputs = list({}.fromkeys(used_outputs).keys())
 
-        dist_feed_vars = [
-            global_block.vars[name] for name in dist_feed_vars_names
-        ]
+        dist_feed_vars_names = [
+            var_name for var_name in feed_vars_names if var_name in used_inputs
+        ]
+        dist_fetch_vars_names = [
+            var_name
+            for var_name in fetch_vars_names
+            if var_name in used_outputs
+        ]
+
+        dist_feed_vars = list(
+            reversed([global_block.vars[name] for name in dist_feed_vars_names])
+        )
         dist_fetch_vars = [
             global_block.vars[name] for name in dist_fetch_vars_names
        ]
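The `{}.fromkeys(...)` idiom works because dict keys preserve insertion order (guaranteed since Python 3.7), so duplicates are dropped while the first occurrence keeps its position:

names = ["x", "y", "x", "z", "y"]
deduped = list({}.fromkeys(names).keys())  # ['x', 'y', 'z']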
@@ -209,11 +222,13 @@ class DistributedSaver:
         # NOTE: `paddle.static.save_inference_model` does not support subblock.
         dist_filename = filename + "_dist" + str(rank_id)
         dist_path = os.path.join(dirname, dist_filename)
-        paddle.static.save_inference_model(dist_path,
-                                           dist_feed_vars,
-                                           dist_fetch_vars,
-                                           exe,
-                                           program=dist_main_prog)
+        paddle.static.save_inference_model(
+            dist_path,
+            dist_feed_vars,
+            dist_fetch_vars,
+            exe,
+            program=dist_main_prog,
+        )
 
     def _save_rank_mapping(self, dirname):
         path = os.path.join(dirname, 'rank_mapping.csv')
......
@@ -225,6 +225,11 @@ class Engine:
         self._planned_mode = None
         self._dygraph_mode = False
         self._tuning = self._strategy.tuning
+        self._acc_steps = 1
+        if self._strategy.gradient_merge.enable:
+            self._acc_steps = self._strategy.gradient_merge.k_steps
+        elif self._strategy.pipeline.enable:
+            self._acc_steps = self._strategy.pipeline.accumulate_steps
 
         self.history = None
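A sketch of the two strategy paths that feed `_acc_steps` (illustrative values):

strategy = auto.Strategy()
strategy.gradient_merge.enable = True
strategy.gradient_merge.k_steps = 4        # -> Engine._acc_steps == 4
# or, for a 1F1B pipeline schedule:
# strategy.pipeline.enable = True
# strategy.pipeline.accumulate_steps = 4   # -> Engine._acc_steps == 4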
@@ -388,7 +393,12 @@ class Engine:
         if self.main_program._pipeline_opt:
             assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
             fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
-            fwd_task = fleet_opt["tasks"][0]
+            fwd_task = None
+            if self._strategy.pipeline.schedule_mode == "1F1B":
+                fwd_task = fleet_opt["tasks"][1]
+            elif self._strategy.pipeline.schedule_mode == "stream":
+                fwd_task = fleet_opt["tasks"][0]
+            assert fwd_task is not None
             fwd_prog = fwd_task.get_program()
             fwd_block = fwd_prog.global_block()
@@ -438,8 +448,6 @@ class Engine:
             ), "user_fetches must be a list, but receive {}".format(
                 type(user_fetches).__name__
             )
-        else:
-            user_fetches = []
         fetch_names = []
         fetch_indices = []
@@ -466,7 +474,7 @@ class Engine:
                 _process_fetch_group("metrics_" + str(i), var_list)
         if mode == "predict":
             _process_fetch_group("outputs", fetch_vars["outputs"])
-        for usr_fetch in user_fetches:
+        for usr_fetch in user_fetches or []:
             var_name = _to_name_str(usr_fetch)
             fetch(var_name)
         user_fetches_collection = [
@@ -903,6 +911,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -931,7 +940,7 @@ class Engine:
             save_dir=save_dir,
             verbose=verbose,
             metrics=self._metrics_name(),
-            acc_step=self._k_steps,
+            acc_step=self._acc_steps,
         )
 
         cbks.on_begin('train')
@@ -965,7 +974,7 @@ class Engine:
                 val_logs = self.evaluate(
                     valid_data,
                     valid_sample_split,
-                    batch_size,
+                    batch_size * self._acc_steps,
                     valid_steps,
                     log_freq,
                     collate_fn,
@@ -1046,6 +1055,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1152,6 +1162,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1214,6 +1225,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1256,6 +1268,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1371,14 +1384,6 @@ class Engine:
         steps_per_epoch=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
-
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1440,14 +1445,6 @@ class Engine:
         collate_fn=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
-
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1487,6 +1484,9 @@ class Engine:
             split_data=self._strategy.split_data,
             data_parallel_world_size=self._dp_world_sizes,
             data_parallel_rank=self._dp_ranks,
+            acc_steps=1
+            if not self._strategy.pipeline.enable
+            else self._acc_steps,
         )
         self._prepare_reader(feed_list)
         return dataloader
@@ -1498,9 +1498,18 @@ class Engine:
         )
         self._optimization_tuning(self._mode, tune_data, batch_size)
 
+    def _validate_batch_size(self, batch_size):
+        if batch_size is None:
+            return None
+        assert (
+            batch_size % self._acc_steps == 0
+        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
+            batch_size, self._acc_steps
+        )
+        return batch_size // self._acc_steps
+
     def _validate_spec(self, specs):
         specs = to_list(specs)
-        self._k_steps = self._strategy.gradient_merge.k_steps
         if specs is not None:
             for i, spec in enumerate(specs):
                 if not isinstance(spec, InputSpec):
@@ -1513,14 +1522,14 @@ class Engine:
                             i, spec
                         )
                     )
-                if self._k_steps > 1:
+                if self._acc_steps > 1:
                     shape = list(spec.shape)
                     assert (
-                        shape[0] % self._k_steps == 0
+                        shape[0] % self._acc_steps == 0
                     ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
-                        spec.shape[0], self._k_steps
+                        spec.shape[0], self._acc_steps
                     )
-                    shape[0] //= self._k_steps
+                    shape[0] //= self._acc_steps
                     spec.shape = shape
         return specs or []
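Concretely, both helpers divide the user-facing batch into micro-batches; with illustrative numbers:

acc_steps, batch_size = 4, 16
assert batch_size % acc_steps == 0
micro_batch_size = batch_size // acc_steps  # 4: what one forward pass sees
spec_shape = [16, 512]
spec_shape[0] //= acc_steps                 # [4, 512]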
......
@@ -297,12 +297,14 @@ class Parallelizer:
         if self._strategy is None:
             return
 
-        # data parallel optimization
-        config = {}
-        config["dist_context"] = self._dist_context
-        config["global_rank"] = rank
-        config["use_sharding"] = self._strategy.sharding.enable
-        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
-        dp_pass.apply([main_program], [startup_program], self._pass_context)
+        if self._strategy.dp_optimization.enable:
+            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
+            config["dist_context"] = self._dist_context
+            config["global_rank"] = rank
+            config["use_sharding"] = self._strategy.sharding.enable
+            dp_pass = new_pass(
+                "auto_parallel_data_parallel_optimization", config
+            )
+            dp_pass.apply([main_program], [startup_program], self._pass_context)
 
         if self._strategy.sharding.enable:
@@ -422,11 +422,11 @@ class Inserter:
         )
         inputs = {'X': [tensor]}
         outputs = {"Out": [out]}
-        attrs = {"in_place": False}
-        slice_op = block._insert_op(
+        attrs = {"in_place": False, "op_role": op_role}
+        assign_op = block._insert_op(
             idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
         )
-        slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
+        assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
         return out
 
     # use split once
@@ -1217,6 +1217,8 @@ class Resharder:
                 shape_x[0] <= shape_y[0] < shape_x[1]
             ):
                 overlapped = True
+        if shape_x == [0, 0] and shape_y == [0, 0]:
+            overlapped = True
         return overlapped
 
     def is_unshard(self, dims_mapping):
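`is_overlapped` treats its arguments as half-open index ranges along one tensor axis; a standalone sketch of the semantics, including the new empty-range special case:

def is_overlapped(shape_x, shape_y):
    # shape_* are [start, end) index ranges along one axis
    overlapped = False
    if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
        shape_x[0] <= shape_y[0] < shape_x[1]
    ):
        overlapped = True
    if shape_x == [0, 0] and shape_y == [0, 0]:
        overlapped = True  # two empty partitions count as overlapped
    return overlapped

assert is_overlapped([0, 4], [2, 6])      # ranges intersect
assert not is_overlapped([0, 4], [4, 8])  # only touch at the boundary
assert is_overlapped([0, 0], [0, 0])      # degenerate empty ranges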
@@ -1304,6 +1306,14 @@ class Resharder:
             # judge whether need reshard by process_mesh
             if tensor_process_mesh != op_process_mesh:
                 is_reshard = True
+            # not reshard data in send/recv scene
+            if (
+                tensor_process_mesh != op_process_mesh
+                and len(tensor_process_mesh.process_ids)
+                == len(op_process_mesh.process_ids)
+                and dist_tensor.serial_tensor.is_data
+            ):
+                is_reshard = False
         else:
             op_output_dims_mapping = dist_attr[1]
             if all(
@@ -1585,10 +1595,10 @@ class Resharder:
                     if i == 0:
                         all_partition_index_list.append(process_index[j][1])
                 for process in group:
-                    # append slice op desc
-                    slice_starts = []
-                    slice_ends = []
-                    slices_axes = []
+                    min_comm_group = copy.deepcopy(group)
+                    all_partition_index_list_copied = copy.deepcopy(
+                        all_partition_index_list
+                    )
                     target_partition_index = Resharder.compute_partition_index(
                         process,
                         complete_shape,
@@ -1596,12 +1606,56 @@ class Resharder:
                         target_process_shape,
                         target_process_group,
                     )
-                    for idx, item in enumerate(target_partition_index):
-                        slice_starts.append(item[0])
-                        slice_ends.append(item[1])
-                        slices_axes.append(idx)
+                    for _process in group:
+                        source_partition_index = (
+                            Resharder.compute_partition_index(
+                                _process,
+                                complete_shape,
+                                source_dims_mapping,
+                                source_process_shape,
+                                source_process_group,
+                            )
+                        )
+                        if not all(
+                            _
+                            for _ in list(
+                                map(
+                                    self.is_overlapped,
+                                    source_partition_index,
+                                    target_partition_index,
+                                )
+                            )
+                        ):
+                            min_comm_group.remove(_process)
+                            all_partition_index_list_copied.remove(
+                                source_partition_index
+                            )
 
-                    to_slice_tensor_shape = dist_tensor.global_sizes()
+                    concatenated_partition_index_list = []
+                    for partition_index in all_partition_index_list_copied:
+                        Resharder.concat_partitions(
+                            concatenated_partition_index_list, partition_index
+                        )
+                    concatenated_partition_index = (
+                        concatenated_partition_index_list[0]
+                    )
+
+                    slice_starts = []
+                    slice_ends = []
+                    slices_axes = []
+                    to_slice_tensor_shape = []
+                    for idx, item in enumerate(concatenated_partition_index):
+                        slice_starts.append(
+                            target_partition_index[idx][0] - item[0]
+                        )
+                        slice_ends.append(
+                            target_partition_index[idx][1] - item[0]
+                        )
+                        slices_axes.append(idx)
+                        to_slice_tensor_shape.append(item[1] - item[0])
 
                     slice_op_desc = SliceOpDesc(
                         starts=slice_starts,
                         ends=slice_ends,
@@ -1616,16 +1670,16 @@ class Resharder:
                     op_desc_seq[process] = (
                         [
                             AllGatherOpDesc(
-                                group=group,
+                                group=min_comm_group,
                                 shape=allgather_shape,
                                 is_bool=(source_tensor.dtype == paddle.bool),
                             ),
                             ConcatOpDesc(
-                                partition_index_list=all_partition_index_list
+                                partition_index_list=all_partition_index_list_copied
                             ),
                             slice_op_desc,
                         ]
-                        if len(group) > 1
+                        if len(min_comm_group) > 1
                         else [slice_op_desc]
                     )
......
@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
         super(DatasetConfig, self).__init__(category, config_dict)
 
 
+class DPOptimizationConfig(BaseConfig):
+    def __init__(self, config_dict=None):
+        category = constants.DP_OPTIMIZATION
+        super(DPOptimizationConfig, self).__init__(category, config_dict)
+
+
 class Strategy(BaseConfig):
     """
     The `Strategy` object is used to configure the parallelization and optimization behaviors.
@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
         config_dict = self._config_dict.get(constants.DATASET, None)
         self.dataset = DatasetConfig(config_dict)
+
+        config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
+        self.dp_optimization = DPOptimizationConfig(config_dict)
@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
         "fused_softmax_mask_upper_triangle_grad",
         "flatten_contiguous_range_grad",
         "relu_grad",
+        "exp_grad",
     ]
     forward_list = [
         "reshape2",
@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
         "fused_softmax_mask_upper_triangle",
         "flatten_contiguous_range",
         "relu",
+        "exp",
     ]
     if op.type in need_set_shape_list:
         for forward_op in block.ops:
@@ -1320,6 +1322,11 @@ def is_forward_op(op):
     )
 
 
+def is_fillconst_op_for_micro_batch(op):
+    op_role = int(op.attr('op_role'))
+    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))
+
+
 def is_backward_op(op):
     return OP_ROLE_KEY in op.attr_names and int(
         op.all_attrs()[OP_ROLE_KEY]
......
@@ -18,15 +18,31 @@ import numpy as np
 import paddle
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import default_main_program
-from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
-from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
-from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import (
+    OpRole,
+    OP_ROLE_KEY,
+    OP_ROLE_VAR_KEY,
+)
+from paddle.distributed.auto_parallel.operators.common import (
+    is_data_parallel_scale_op,
+    is_data_parallel_reduce_op,
+)
+from paddle.distributed.auto_parallel.utils import (
+    is_loss_grad_op,
+    is_optimize_op,
+    is_backward_op,
+    ring_id_to_process_group,
+    find_higher_order_backward_op,
+)
 
 from .pass_base import PassBase, PassType, register_pass
 
 # add new optimizers supporting rescale_grad here
 __rescale_grad_supported_opts__ = [
-    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
-    'merge_momentum'
+    'lars_momentum',
+    'sparse_momentum',
+    'dgc_momentum',
+    'momentum',
+    'merge_momentum',
 ]
 
 # a heuristic number
@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
         self.set_attr("dist_context", None)
         self.set_attr("global_rank", -1)
         self.set_attr("use_sharding", False)
+        self.set_attr("fuse_all_reduce_ops", False)
+        self.set_attr("fuse_grad_size_in_MB", 32)
+        self.set_attr("overlap_comm_cacl", False)
         # {grad1: group1, grad2: group1, grad3: group2}
         # record the order for fuse grad data memory
         self._grad_name_to_group_map = OrderedDict()
@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
     def _check_self(self):
         if self.get_attr("dist_context") is None:
             return False
-        if (not isinstance(self.get_attr("global_rank"),
-                           int)) or self.get_attr("global_rank") < 0:
+        if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
+            "global_rank"
+        ) < 0:
             return False
 
         return True
@@ -80,12 +100,17 @@ class DataParallelOptimizationPass(PassBase):
         self.global_rank = int(self.get_attr("global_rank"))
         self.use_sharding = self.get_attr("use_sharding")
+        overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
+        fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")
 
         with paddle.static.program_guard(main_program, startup_program):
             self._analyze_program()
 
             if self.is_data_parallel_applied():
-                self._prune_grad_scaling()
-                self._calc_comm_overlap()
-                grad_group = self._fuse_allreduce()
+                if overlap_comm_cacl:
+                    self._prune_grad_scaling()
+                    self._calc_comm_overlap()
+                if fuse_all_reduce_ops:
+                    grad_group = self._fuse_allreduce()
 
         # self.summary(grad_group)
@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
                 ), "Unexpected: comm op [{}] has NOT ring id.".format(str(op))
                 group = ring_id_to_process_group(op.attr("ring_id"))
 
-                assert group is not None, "Unexpected: data parallel group of [{}] from op [{}] is None".format(
-                    grad_name, str(op))
+                assert (
+                    group is not None
+                ), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
+                    grad_name, str(op)
+                )
 
                 self._grad_name_to_group_map[grad_name] = group
@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
             # TODO support multiple optimizers in one network in future.
             # here we assume that the optimizer is unique in network.
-            elif is_optimize_op(
-                    op) and op.type in __rescale_grad_supported_opts__:
+            elif (
+                is_optimize_op(op)
+                and op.type in __rescale_grad_supported_opts__
+            ):
                 self._support_rescale_grad = True
 
         not_synchronized_grads = []
         for grad_name in scaled_grads:
             if grad_name not in self._grad_name_to_group_map:
                 not_synchronized_grads.append(grad_name)
-        assert len(
-            not_synchronized_grads
-        ) == 0, "Unexpected: gradients [{}] are scaled BUT NOT synchronized.".format(
-            not_synchronized_grads)
+        assert (
+            len(not_synchronized_grads) == 0
+        ), "Unexpected: gradients [{}] are scaled BUT NOT synchronized.".format(
+            not_synchronized_grads
+        )
 
     def is_data_parallel_applied(self):
         return len(self._group_to_grad_name_map) > 0
@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
     def _could_be_prune(self):
 
         return self.dist_context.gradient_scale and (
-            self._support_rescale_grad or self._all_dp_groups_same_degree())
+            self._support_rescale_grad or self._all_dp_groups_same_degree()
+        )
 
     def _all_dp_groups_same_degree(self):
-        return len(
-            set([
-                len(group.ranks)
-                for group in self._group_to_grad_name_map.keys()
-            ])) == 1
+        return (
+            len(
+                set(
+                    [
+                        len(group.ranks)
+                        for group in self._group_to_grad_name_map.keys()
+                    ]
+                )
+            )
+            == 1
+        )

     def _scale_backward_initial_grad(self):
@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
         for idx, op in reversed(list(enumerate(block.ops))):
             if is_loss_grad_op(op):
-                assert op.type == 'fill_constant', \
-                    "loss_grad_op must be fill_constant op, " \
-                    "but this op is {}".format(op.type)
+                assert op.type == 'fill_constant', (
+                    "loss_grad_op must be fill_constant op, "
+                    "but this op is {}".format(op.type)
+                )
                 assert op.has_attr('value')
                 loss_scale = float(op.attr('value'))
                 loss_scale = loss_scale / dp_degree
@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
         scaled_grads = set()
 
         for idx, op in reversed(list(enumerate(block.ops))):
-            if is_optimize_op(
-                    op) and op.type in __rescale_grad_supported_opts__:
+            if (
+                is_optimize_op(op)
+                and op.type in __rescale_grad_supported_opts__
+            ):
                 assert op.has_attr(
                     'rescale_grad'
-                ), "Unexpected: op [{}] is supposed to have [rescale_grad] attribute.".format(
-                    str(op))
-                assert len(
-                    op.input("Grad")
-                ) == 1, "Unexpected: op [{}] is supposed to have only one input grad var.".format(
-                    str(op))
+                ), "Unexpected: op [{}] is supposed to have [rescale_grad] attribute.".format(
+                    str(op)
+                )
+                assert (
+                    len(op.input("Grad")) == 1
+                ), "Unexpected: op [{}] is supposed to have only one input grad var.".format(
+                    str(op)
+                )
 
                 grad_name = op.input("Grad")[0]
                 dp_degree = len(
-                    list(self._grad_name_to_group_map[grad_name].ranks))
+                    list(self._grad_name_to_group_map[grad_name].ranks)
+                )
                 scaled_grads.add(grad_name)
 
                 rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                 op._set_attr('rescale_grad', rescale_grad)
 
-        assert scaled_grads == set(self._grad_name_to_group_map.keys(
-        )), "Unexpected: gradients [{}] are unscaled.".format(
-            set(self._grad_name_to_group_map.keys()) - scaled_grads)
+        assert scaled_grads == set(
+            self._grad_name_to_group_map.keys()
+        ), "Unexpected: gradients [{}] are unscaled.".format(
+            set(self._grad_name_to_group_map.keys()) - scaled_grads
+        )

     def _could_be_overlap(self):
         # NOTE current different nccl comm will use different cuda stream
@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
                 op._set_attr('use_calc_stream', False)
                 ring_id = op.attr("ring_id")
 
-                block._insert_op_without_sync(idx,
-                                              type='c_wait_compute',
-                                              inputs={'X': []},
-                                              outputs={'Out': []},
-                                              attrs={
-                                                  'op_role': OpRole.Backward,
-                                                  'ring_id': ring_id
-                                              })
+                block._insert_op_without_sync(
+                    idx,
+                    type='c_wait_compute',
+                    inputs={'X': []},
+                    outputs={'Out': []},
+                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
+                )
 
         block._sync_with_cpp()
@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
             # other ops that might use communicating grad
             else:
                 for input_var_name in op.input_arg_names:
-                    for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items(
-                    ):
+                    for (
+                        ring_id,
+                        unsync_grad_names,
+                    ) in ring_id_to_un_sync_grad_map.items():
                         if input_var_name in unsync_grad_names:
                             # need to sync before op_i
                             if i in op_idx_to_sync_ring_id_map:
@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
         for i in sorted(indices, reverse=True):
             for ring_id in op_idx_to_sync_ring_id_map[i]:
-                block._insert_op_without_sync(i,
-                                              type='c_wait_comm',
-                                              inputs={'X': []},
-                                              outputs={'Out': []},
-                                              attrs={
-                                                  'op_role': OpRole.Backward,
-                                                  'ring_id': ring_id
-                                              })
+                block._insert_op_without_sync(
+                    i,
+                    type='c_wait_comm',
+                    inputs={'X': []},
+                    outputs={'Out': []},
+                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
+                )

     def _could_be_fuse(self):
         # TODO support gradient fuse higher order gradient.
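Together, the `c_wait_compute` and `c_wait_comm` insertions implement the usual communication/computation overlap pattern; a schematic of the resulting op order (an illustration, not output of the pass):

# grad_w = ...backward compute...     (compute stream)
# c_wait_compute(ring_id=r)           # comm stream waits until grad_w is ready
# c_allreduce_sum(grad_w, ring_id=r)  # runs asynchronously on the comm stream
# ...other backward ops keep running on the compute stream...
# c_wait_comm(ring_id=r)              # compute stream blocks right before the
# ...first op that consumes grad_w    # synchronized gradient is used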
@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
         for i, group in enumerate(grad_groups[::-1]):
 
             # create coalesce tensor
-            group.coalesce_var = block.create_var(name=unique_name.generate(
-                'coalecse_grad_{}'.format(i)),
-                                                  dtype=group.dtype,
-                                                  persistable=False,
-                                                  stop_gradient=True)
+            group.coalesce_var = block.create_var(
+                name=unique_name.generate('coalecse_grad_{}'.format(i)),
+                dtype=group.dtype,
+                persistable=False,
+                stop_gradient=True,
+            )
 
             # update allreduce & scale op
             if group.scale_op_idx != -1:
                 scale_op = block.ops[group.scale_op_idx]
-                assert scale_op.type == 'scale', "should find scale op but found {}".format(
-                    str(scale_op))
-                scale_op._rename_input(scale_op.input_arg_names[0],
-                                       group.coalesce_var.name)
-                scale_op._rename_output(scale_op.output_arg_names[0],
-                                        group.coalesce_var.name)
+                assert (
+                    scale_op.type == 'scale'
+                ), "should find scale op but found {}".format(str(scale_op))
+                scale_op._rename_input(
+                    scale_op.input_arg_names[0], group.coalesce_var.name
+                )
+                scale_op._rename_output(
+                    scale_op.output_arg_names[0], group.coalesce_var.name
+                )
 
             allreduce_op = block.ops[group.allreduce_op_idx]
-            assert allreduce_op.type == 'c_allreduce_sum', "should find c_allreduce_sum op but found {}".format(
-                str(allreduce_op))
-            allreduce_op._rename_input(allreduce_op.input_arg_names[0],
-                                       group.coalesce_var.name)
-            allreduce_op._rename_output(allreduce_op.output_arg_names[0],
-                                        group.coalesce_var.name)
+            assert (
+                allreduce_op.type == 'c_allreduce_sum'
+            ), "should find c_allreduce_sum op but found {}".format(
+                str(allreduce_op)
+            )
+            allreduce_op._rename_input(
+                allreduce_op.input_arg_names[0], group.coalesce_var.name
+            )
+            allreduce_op._rename_output(
+                allreduce_op.output_arg_names[0], group.coalesce_var.name
+            )
 
             # remove un-used op
-            remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
+            remove_op_indices = (
+                group.remove_wait_op_indices
+                + group.remove_allreduce_op_indices
+                + group.remove_scale_op_indices
+            )
             for idx in sorted(remove_op_indices, reverse=True):
-                assert block.ops[
-                    idx].type in remove_op_types, "Unexpected: try to remove op {}".format(
-                        str(op))
+                assert (
+                    block.ops[idx].type in remove_op_types
+                ), "Unexpected: try to remove op {}".format(
+                    str(block.ops[idx].type)
+                )
                 block._remove_op(idx)
@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
                 concated_ranks.append(len(shape))
 
             grad_names = [grad.name for grad in group.gradients]
-            block._insert_op_without_sync(group.coalesce_op_idx,
-                                          type="coalesce_tensor",
-                                          inputs={"Input": grad_names},
-                                          outputs={
-                                              "Output": grad_names,
-                                              "FusedOutput": group.coalesce_var
-                                          },
-                                          attrs={
-                                              "copy_data": False,
-                                              "use_align": True,
-                                              "dtype": group.dtype,
-                                              "concated_shapes":
-                                              concated_shapes,
-                                              "concated_ranks": concated_ranks,
-                                              OP_ROLE_KEY: OpRole.Backward
-                                          })
+            block._insert_op_without_sync(
+                group.coalesce_op_idx,
+                type="coalesce_tensor",
+                inputs={"Input": grad_names},
+                outputs={
+                    "Output": grad_names,
+                    "FusedOutput": group.coalesce_var,
+                },
+                attrs={
+                    "copy_data": False,
+                    "use_align": True,
+                    "dtype": group.dtype,
+                    "concated_shapes": concated_shapes,
+                    "concated_ranks": concated_ranks,
+                    OP_ROLE_KEY: OpRole.Backward,
+                },
+            )
 
         block._sync_with_cpp()
 
         # TODO update dist attr
@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
     def summary(self, grad_groups=[]):
         # TODO: add logger module
         import logging
+
         self._logger = logging.getLogger()
         self._logger.propagate = False
         if not self._logger.handlers:
@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
         if len(grad_groups) > 0:
             self._logger.info(
-                "origin {} allreduce ops are fused into {} coalesce allreduce ops."
-                .format(len(self._grad_name_to_group_map.keys()),
-                        len(grad_groups)))
+                "origin {} allreduce ops are fused into {} coalesce allreduce ops.".format(
+                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
+                )
+            )
             self._logger.info("gradient fusing groups are as follows: ")
             fused_grads = set()
             for i, group in enumerate(grad_groups):
                 self._logger.info(
-                    "coalesce gradient [{}] is composed of: {}".format(
-                        i, [grad.name for grad in group.gradients]))
+                    "coalesce gradient [{}] is composed of: {}".format(
+                        i, [grad.name for grad in group.gradients]
+                    )
+                )
                 fused_grads.update([grad.name for grad in group.gradients])
-            individual_grads = set(
-                self._grad_name_to_group_map.keys()) - set(fused_grads)
+            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
+                fused_grads
+            )
             self._logger.info(
-                "the following [{}] gradients are not fused: ".format(
-                    len(individual_grads)))
+                "the following [{}] gradients are not fused: ".format(
+                    len(individual_grads)
+                )
+            )
             self._logger.info("individual gradient {}".format(individual_grads))
 class GradientsGroup(object):
     def __init__(self, ops, max_group_size):
         self.max_group_size = max_group_size
         self.ops = ops
@@ -575,8 +643,11 @@ class GradientsGroup(object):
                 grad_op_idx -= 1
 
         grad_op = self.ops[grad_op_idx]
-        assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
-            grad_var.name, str(grad_op))
+        assert (
+            grad_var.name in grad_op.output_arg_names
+        ), "grad [{}] should be output of {}".format(
+            grad_var.name, str(grad_op)
+        )
         self.coalesce_op_idx = grad_op_idx
 
     def finalize(self):
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random
import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
def apply_pass(use_1f1b=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_1f1b:
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2
else:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 2
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Test1F1BPass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        paddle.distributed.fleet.init(is_collective=True)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_1f1b=False):
        reset_prog()
        strategy = apply_pass(use_1f1b)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("pp")
        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine
    def check_results(self, ref_losses, check_losses):
        np.testing.assert_allclose(
            ref_losses,
            check_losses,
            rtol=self.rtol,
            atol=self.atol,
            err_msg='pass {} has wrong results!\nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )
    def test_1f1b_pass(self):
        # naive_pp + gradient_merge training
        engine_pp = self.get_engine()
        history_pp = engine_pp.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_pp._strategy.pipeline.enable == False

        # pp2 1f1b merge training
        engine_1f1b = self.get_engine(True)
        history_1f1b = engine_1f1b.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_1f1b._strategy.pipeline.enable == True

        # NOTE: every sample data from dataset is all the same
        if paddle.distributed.get_rank() == 1:
            losses_pp = np.array(history_pp.history["loss"])
            losses_1f1b = np.array(history_1f1b.history["loss"])
            self.check_results(losses_pp, losses_1f1b)
if __name__ == "__main__":
    unittest.main()
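The schedule_mode = "1F1B" knob above changes only the order in which micro-batch forward and backward passes run on each pipeline stage: a short forward warm-up, a steady one-forward-one-backward phase, then a backward drain. A simplified model of that ordering for the two-stage setup in this test (a sketch of the schedule shape, not Paddle's scheduler):

# Simplified 1F1B ordering for one pipeline stage; models ordering only.
def one_f_one_b(stage_id, num_stages, num_micro_batches):
    warmup = min(num_stages - stage_id - 1, num_micro_batches)
    order = [f"F{i}" for i in range(warmup)]     # warm-up: forwards only
    fwd, bwd = warmup, 0
    while fwd < num_micro_batches:               # steady state: 1F1B pairs
        order += [f"F{fwd}", f"B{bwd}"]
        fwd, bwd = fwd + 1, bwd + 1
    while bwd < num_micro_batches:               # cool-down: drain backwards
        order.append(f"B{bwd}")
        bwd += 1
    return order

for stage in range(2):
    print(stage, one_f_one_b(stage, num_stages=2, num_micro_batches=4))
# 0 ['F0', 'F1', 'B0', 'F2', 'B1', 'F3', 'B2', 'B3']
# 1 ['F0', 'B0', 'F1', 'B1', 'F2', 'B2', 'F3', 'B3']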
@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
                     PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                    TIMEOUT 50)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
                ${dist_ENVS})
...
@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
modeling._global_parallel_strategy = "mp" modeling._global_parallel_strategy = "mp"
elif strategy == "dp": elif strategy == "dp":
modeling._global_parallel_strategy = "dp" modeling._global_parallel_strategy = "dp"
elif strategy == "pp":
modeling._global_parallel_strategy = "pp"
modeling.PP_MESH_LIST = [
auto.ProcessMesh(mesh=[0]),
auto.ProcessMesh(mesh=[1]),
]
    else:
        raise ValueError("Only support serial, mp2, dp2 and pp2.")
@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=2 if strategy == "pp" else None,
    )
    model = GPTForPretraining(
        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
...
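PP_MESH_LIST above pins each of the two pipeline stages to its own single-GPU process mesh. Which transformer layers land on which stage is decided inside modeling.py; the sketch below shows one plausible even split, and the even split itself is an assumption, not the file's actual logic.

# Hypothetical even layer-to-stage assignment for pipeline parallelism;
# modeling.py's real mapping may differ.
def layer_to_stage(layer_idx, num_layers, pp_degree):
    per_stage = (num_layers + pp_degree - 1) // pp_degree  # ceil division
    return layer_idx // per_stage

assert [layer_to_stage(i, num_layers=4, pp_degree=2) for i in range(4)] == [0, 0, 1, 1]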
@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()
@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_gradient_merge:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 4
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True
    return strategy
@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = dp_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert dp_engine._strategy.gradient_merge.enable == False
        dp_losses = np.array(history.history["loss"])

        # dp2 gradient merge training
@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = gm_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert gm_engine._strategy.gradient_merge.enable == True
        gm_losses = np.array(history.history["loss"])
        # avg_loss = 0
...
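The comparison in this test leans on a simple identity: with avg = True, averaging the k_steps micro-batch gradients equals the gradient of one k_steps-times-larger batch (when each micro-gradient is itself a per-sample mean), so the merged run should reproduce the plain dp2 losses. A made-up-numbers check of that identity:

# Averaging k_steps micro-batch gradients == one large-batch gradient,
# assuming each micro-gradient is already a per-sample mean. Numbers made up.
import numpy as np

micro_grads = [np.array([1.0]), np.array([3.0]), np.array([5.0]), np.array([7.0])]
merged = sum(micro_grads) / len(micro_grads)       # what avg=True computes
large_batch = np.mean(np.stack(micro_grads), axis=0)
np.testing.assert_allclose(merged, large_batch)    # both equal [4.0]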
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class Test1F1BPass(unittest.TestCase):
    def test_pp2(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )
        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()
if __name__ == "__main__":
    unittest.main()
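To reproduce the distributed launch outside the unit test, the subprocess call above boils down to roughly the following; the log directory and the assumption that 1F1B_pass_unittest.py sits in the current working directory are mine, and two visible GPUs are required just as in the test.

# Roughly the command Test1F1BPass builds, reconstructed for manual runs.
import subprocess
import sys

cmd = [
    sys.executable, "-u", "-m", "paddle.distributed.launch",
    "--devices", "0,1", "--log_dir", "/tmp/1f1b_logs",
    "1F1B_pass_unittest.py",  # assumed to be in the current directory
]
subprocess.run(cmd, check=True)  # raises CalledProcessError on failure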