Unverified commit c47853f6, authored by zhaoyingli, committed by GitHub

[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)

Parent 6f3c9643
@@ -110,12 +110,15 @@ void PreventVarsDelete(
 std::vector<std::string> GetUnusedVarsAfterWhile(
     const framework::ProgramDesc& program_desc,
+    TaskNode* cond_task,
     const std::vector<std::string>& vars_not_gc) {
   // NOTE: Since while op won't appear in task node, in order to analyze
   // the vars which should be free after calling while op, we rebuild the
   // whole program and get the unused vars after calling while op.
-  // vars in parent block should not be free until the while op is finished.
-  // The local vars will be free while running op in sub block.
+  // The vars in while block should not be free until the while op is finished.
+  // In a word, the vars need to be free after while op is:
+  // 1. Vars in parent block and being used in while block.
+  // 2. Local vars only defined in while block.
   // The unused vars above will be free in cond interceptor.
   std::vector<std::string> while_block_vars;
   std::vector<std::unique_ptr<framework::OperatorBase>> ops;
@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
       for (const auto& var_name : pair.second) {
         while_block_vars.emplace_back(var_name);
       }
+      for (auto& var : program_desc.Block(1).AllVars()) {
+        while_block_vars.emplace_back(var->Name());
+      }
     }
   }
   return while_block_vars;
 }
-
-std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
-GetSubUnusedVars(const framework::ProgramDesc& program_desc,
-                 const std::set<TaskNode*>& sub_block_tasks,
-                 const std::vector<std::string>& vars_not_gc) {
-  std::vector<std::unique_ptr<framework::OperatorBase>> ops;
-  for (auto* task_node : sub_block_tasks) {
-    for (const auto& op : task_node->ops()) {
-      ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
-    }
-  }
-  auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
-  for (auto& unique_op : ops) {
-    unique_op.release();
-  }
-  PreventVarsDelete(&unused_vars, vars_not_gc);
-  return unused_vars;
-}
 }  // namespace

 void FleetExecutor::Init(
@@ -174,13 +162,8 @@ void FleetExecutor::Init(
   for (const auto& task_node : task_nodes) {
     if (task_node->type() == "Cond") {
       GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
-      while_block_vars =
-          GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars);
-      for (auto* task_node : sub_block_tasks) {
-        for (auto iter : task_node->vars_to_dtype()) {
-          while_block_vars.emplace_back(iter.first);
-        }
-      }
+      while_block_vars = GetUnusedVarsAfterWhile(
+          program_desc, task_node, inference_root_scope_vars);
       VLOG(3) << "Vars will be gced after while op";
       for (auto var : while_block_vars) {
         VLOG(3) << var;
@@ -210,9 +193,6 @@ void FleetExecutor::Init(
     unique_op.release();
   }
-  auto sub_unused_vars =
-      GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
-
   // NOTE: For inference, the vars in inference_root_scope_vars
   // shouldn't be deleted during inf, for that they may be the result of the
   // inf. If they are GCed, it will cause error during ZeroCopy the result.
@@ -223,8 +203,6 @@ void FleetExecutor::Init(
   for (auto task_node : task_nodes) {
     if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
       task_node->SetUnusedVars(global_unused_vars);
-    } else {
-      task_node->SetUnusedVars(sub_unused_vars);
     }
     int64_t interceptor_id = task_node->task_id();
     interceptor_id_to_task.emplace(interceptor_id, task_node);
...
@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
 set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
 set_field_default_config(QAT, "algo", None)

-# #########################################
+#########################################
 # auto tuning configuration
-# #########################################
+#########################################
 TUNING = "tuning"
 set_field_default_config(TUNING, "enable", False)
 set_field_default_config(TUNING, "batch_size", 1)
@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
 DATASET = "dataset"
 set_field_default_config(DATASET, "enable", False)
 set_field_default_config(DATASET, "num_shards", 1)
+
+#########################################
+# data parallel configuration
+#########################################
+DP_OPTIMIZATION = "dp_optimization"
+set_field_default_config(DP_OPTIMIZATION, "enable", False)
+set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
+set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
+set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
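
A quick sketch of how these defaults are meant to be consumed from user code, assuming (as the strategy changes later in this diff suggest) that each registered key surfaces as an attribute of auto.Strategy's dp_optimization config; the values shown simply mirror the defaults registered above:

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.dp_optimization.enable = True
# the remaining knobs already default to the values registered above
strategy.dp_optimization.fuse_all_reduce_ops = True
strategy.dp_optimization.fuse_grad_size_in_MB = 32
strategy.dp_optimization.overlap_comm_cacl = True  # key name exactly as registered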
@@ -17,12 +17,18 @@ import numpy as np
 import paddle
 from paddle.io import BatchSampler, IterableDataset
-from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
-from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
+from paddle.fluid.dataloader.batch_sampler import (
+    _InfiniteIterableSampler,
+    DistributedBatchSampler,
+)
+from paddle.fluid.dataloader.dataloader_iter import (
+    _DatasetKind,
+    default_collate_fn,
+    default_convert_fn,
+)


 class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
-
     @abc.abstractmethod
     def __iter__(self):
         raise NotImplementedError
@@ -43,24 +49,26 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
 class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
-
-    def __init__(self,
-                 dataset,
-                 feed_list=None,
-                 capacity=None,
-                 use_double_buffer=True,
-                 iterable=True,
-                 return_list=False,
-                 use_multiprocess=False,
-                 drop_last=True,
-                 places=None,
-                 batch_size=1,
-                 epochs=1,
-                 steps_per_epoch=None,
-                 collate_fn=None,
-                 split_data=True,
-                 data_parallel_world_size=[],
-                 data_parallel_rank=[]):
+    def __init__(
+        self,
+        dataset,
+        feed_list=None,
+        capacity=None,
+        use_double_buffer=True,
+        iterable=True,
+        return_list=False,
+        use_multiprocess=False,
+        drop_last=True,
+        places=None,
+        batch_size=1,
+        epochs=1,
+        steps_per_epoch=None,
+        collate_fn=None,
+        split_data=True,
+        data_parallel_world_size=[],
+        data_parallel_rank=[],
+        acc_steps=1,
+    ):
         self.dataset = dataset
         self.feed_list = feed_list
         self.capacity = capacity
@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         assert len(data_parallel_rank) == len(feed_list)
         self.dp_world_sizes = data_parallel_world_size
         self.dp_ranks = data_parallel_rank
+        self.acc_steps = acc_steps

         if isinstance(dataset, IterableDataset):
             self.dataset_kind = _DatasetKind.ITER
@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         else:
             if isinstance(dataset, IterableDataset):
                 self.batch_sampler = _InfiniteIterableSampler(
-                    dataset, batch_size)
+                    dataset, batch_size
+                )
             else:
-                self.batch_sampler = BatchSampler(dataset,
-                                                  batch_size=batch_size,
-                                                  shuffle=False,
-                                                  drop_last=drop_last)
+                self.batch_sampler = BatchSampler(
+                    dataset,
+                    batch_size=batch_size,
+                    shuffle=False,
+                    drop_last=drop_last,
+                )

         self.auto_collate_batch = self.batch_sampler is not None
         self.sampler_iter = iter(self.index_sampler)
@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
         self.collate_fn = collate_fn or default_convert_fn
         self.dataset_fetcher = _DatasetKind.create_fetcher(
-            self.dataset_kind, self.dataset, self.auto_collate_batch,
-            self.collate_fn, self.drop_last)
+            self.dataset_kind,
+            self.dataset,
+            self.auto_collate_batch,
+            self.collate_fn,
+            self.drop_last,
+        )

         self._steps = self._infer_steps()
         self._inner_dataloader = self._create_inner_dataloader()
@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             if isinstance(self.dataset, IterableDataset):
                 steps_per_epoch = None
             elif self.batch_size is None:
-                steps_per_epoch = len(self.dataset)
+                steps_per_epoch = len(self.dataset) // self.acc_steps
             else:
-                steps_per_epoch = len(self.dataset) // self.batch_size
+                steps_per_epoch = (
+                    len(self.dataset) // self.batch_size // self.acc_steps
+                )
         except:
             raise ValueError(
                 "Please set `steps_per_epoch` or implement `__len__` method in dataset class."
@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             return _InfiniteIterableSampler(self.dataset, 1)

     def _create_inner_dataloader(self):
-
         def data_generator():
             while True:
                 try:
                     indices = next(self.sampler_iter)
                     batch = self.dataset_fetcher.fetch(indices)
-                    if batch is None: break
+                    if batch is None:
+                        break
                 except StopIteration:
                     self.dataset_fetcher = _DatasetKind.create_fetcher(
-                        self.dataset_kind, self.dataset,
-                        self.auto_collate_batch, self.collate_fn,
-                        self.drop_last)
+                        self.dataset_kind,
+                        self.dataset,
+                        self.auto_collate_batch,
+                        self.collate_fn,
+                        self.drop_last,
+                    )
                     break
@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
                     continue

                 batch_size = array.shape[0]
-                assert batch_size % self.dp_world_sizes[i] == 0, \
-                    "batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
+                assert (
+                    batch_size % self.dp_world_sizes[i] == 0
+                ), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
+                    str(batch_size), str(self.dp_world_sizes[i])
+                )
                 partial_data.append(
-                    np.split(array,
-                             self.dp_world_sizes[i])[self.dp_ranks[i]])
+                    np.split(array, self.dp_world_sizes[i])[
+                        self.dp_ranks[i]
+                    ]
+                )

             yield partial_data
@@ -194,33 +220,35 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
             iterable=False,
             return_list=self.return_list,
             use_multiprocess=self.use_multiprocess,
-            drop_last=self.drop_last)
+            drop_last=self.drop_last,
+        )
         dataloader.set_batch_generator(data_generator, self.places)

         return dataloader
 class DistributedDataLoader(DistributedDataLoaderBase):
-
-    def __init__(self,
-                 dataset,
-                 feed_list=None,
-                 places=None,
-                 return_list=True,
-                 batch_size=1,
-                 shuffle=False,
-                 drop_last=False,
-                 collate_fn=None,
-                 num_workers=0,
-                 use_buffer_reader=True,
-                 use_shared_memory=True,
-                 timeout=0,
-                 worker_init_fn=None,
-                 epochs=1,
-                 steps_per_epoch=None,
-                 split_data=True,
-                 data_parallel_world_size=[],
-                 data_parallel_rank=[]):
+    def __init__(
+        self,
+        dataset,
+        feed_list=None,
+        places=None,
+        return_list=True,
+        batch_size=1,
+        shuffle=False,
+        drop_last=False,
+        collate_fn=None,
+        num_workers=0,
+        use_buffer_reader=True,
+        use_shared_memory=True,
+        timeout=0,
+        worker_init_fn=None,
+        epochs=1,
+        steps_per_epoch=None,
+        split_data=True,
+        data_parallel_world_size=[],
+        data_parallel_rank=[],
+    ):
         self.dataset = dataset
         self.feed_list = feed_list
         self.return_list = return_list
@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
         self.split_data = split_data
         # TODO: rank info
         self.batch_sampler = DistributedBatchSampler(
-            self.dataset, self.batch_size, self.dp_world_sizes[0],
-            self.dp_ranks[0], self.shuffle, self.drop_last)
+            self.dataset,
+            self.batch_size,
+            self.dp_world_sizes[0],
+            self.dp_ranks[0],
+            self.shuffle,
+            self.drop_last,
+        )
         self._inner_dataloader = self._create_inner_dataloader()

     def __iter__(self):
@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
             use_buffer_reader=self.use_buffer_reader,
             use_shared_memory=self.use_shared_memory,
             timeout=self.timeout,
-            worker_init_fn=self.worker_init_fn)
+            worker_init_fn=self.worker_init_fn,
+        )
         self.data = (x for x in dataloader)

         return dataloader
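
As a standalone illustration of the per-rank split performed inside data_generator above (the shapes here are hypothetical): each rank keeps its contiguous slice of dim 0, which is why the assert requires the batch size to divide evenly by dp_world_size.

import numpy as np

array = np.arange(8).reshape(8, 1)  # a global batch of 8 samples
dp_world_size, dp_rank = 2, 1
shard = np.split(array, dp_world_size)[dp_rank]
assert shard.shape == (4, 1)  # rank 1 gets rows 4..7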
@@ -18,6 +18,7 @@ import errno
 import pickle
 import warnings
 import logging
+import collections
 import numpy as np

 import paddle
@@ -53,16 +54,13 @@ def _process_path(path):
 class DistributedSaver:
-
     def __init__(self):
         self._logger = get_logger(logging.INFO)

     def save(self, path, serial_program, dist_main_program, dist_context):
-
         def _save_state(program, path, mode="param"):
             state = {
-                k: np.array(v)
-                for k, v in program.state_dict(mode).items()
+                k: np.array(v) for k, v in program.state_dict(mode).items()
             }
             with open(path, "wb") as f:
                 pickle.dump(state, f)
@@ -108,8 +106,9 @@ class DistributedSaver:
         def _load_file(filename, dirname, suffix="pdparams"):
             file_list = []
             for file in os.listdir(dirname):
-                if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
-                                  file):
+                if check_filename(
+                    '{}(.*)_dist(.*).{}'.format(filename, suffix), file
+                ):
                     file_list.append(os.path.join(dirname, file))
             file_list.sort()
             return file_list
@@ -137,14 +136,16 @@ class DistributedSaver:
         # load path.pdparam and path.pdopt
         param_state_dict = _load_state(filename, dirname)
-        opt_state_dict = _load_state(filename, dirname,
-                                     "pdopt") if load_optimizer else {}
+        opt_state_dict = (
+            _load_state(filename, dirname, "pdopt") if load_optimizer else {}
+        )
         state_dict = dict(param_state_dict, **opt_state_dict)

         # load path.pdattr
         dist_attr_file_list = _load_file(filename, dirname, "pdattr")
         self._logger.info(
-            "Load distributed attribute file: {}".format(dist_attr_file_list))
+            "Load distributed attribute file: {}".format(dist_attr_file_list)
+        )
         dist_attr = {}
         for dist_attr_file in dist_attr_file_list:
             with open(dist_attr_file, 'rb') as f:
@@ -196,12 +197,24 @@ class DistributedSaver:
             used_inputs += op.input_arg_names
             used_outputs += op.output_arg_names

-        dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
-        dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))
+        # delete duplicated elements and keep order
+        feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
+        used_inputs = list({}.fromkeys(used_inputs).keys())
+        fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
+        used_outputs = list({}.fromkeys(used_outputs).keys())

-        dist_feed_vars = [
-            global_block.vars[name] for name in dist_feed_vars_names
-        ]
+        dist_feed_vars_names = [
+            var_name for var_name in feed_vars_names if var_name in used_inputs
+        ]
+        dist_fetch_vars_names = [
+            var_name
+            for var_name in fetch_vars_names
+            if var_name in used_outputs
+        ]
+
+        dist_feed_vars = list(
+            reversed([global_block.vars[name] for name in dist_feed_vars_names])
+        )
         dist_fetch_vars = [
             global_block.vars[name] for name in dist_fetch_vars_names
         ]
@@ -209,11 +222,13 @@ class DistributedSaver:
         # NOTE: `paddle.static.save_inference_model` does not support subblock.
         dist_filename = filename + "_dist" + str(rank_id)
         dist_path = os.path.join(dirname, dist_filename)
-        paddle.static.save_inference_model(dist_path,
-                                           dist_feed_vars,
-                                           dist_fetch_vars,
-                                           exe,
-                                           program=dist_main_prog)
+        paddle.static.save_inference_model(
+            dist_path,
+            dist_feed_vars,
+            dist_fetch_vars,
+            exe,
+            program=dist_main_prog,
+        )

     def _save_rank_mapping(self, dirname):
         path = os.path.join(dirname, 'rank_mapping.csv')
...
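
The {}.fromkeys(...) idiom introduced above deserves a note: since Python 3.7 dict keys preserve insertion order, so it removes duplicates without reordering, unlike the set(...) & set(...) intersection it replaces, whose iteration order was unspecified. A minimal demonstration:

names = ["x", "y", "x", "z"]
assert list({}.fromkeys(names).keys()) == ["x", "y", "z"]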
@@ -225,6 +225,11 @@ class Engine:
         self._planned_mode = None
         self._dygraph_mode = False
         self._tuning = self._strategy.tuning
+        self._acc_steps = 1
+        if self._strategy.gradient_merge.enable:
+            self._acc_steps = self._strategy.gradient_merge.k_steps
+        elif self._strategy.pipeline.enable:
+            self._acc_steps = self._strategy.pipeline.accumulate_steps

         self.history = None
@@ -388,7 +393,12 @@ class Engine:
         if self.main_program._pipeline_opt:
             assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
             fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
-            fwd_task = fleet_opt["tasks"][0]
+            fwd_task = None
+            if self._strategy.pipeline.schedule_mode == "1F1B":
+                fwd_task = fleet_opt["tasks"][1]
+            elif self._strategy.pipeline.schedule_mode == "stream":
+                fwd_task = fleet_opt["tasks"][0]
+            assert fwd_task is not None
             fwd_prog = fwd_task.get_program()
             fwd_block = fwd_prog.global_block()
@@ -438,8 +448,6 @@ class Engine:
             ), "user_fetches must be a list, but receive {}".format(
                 type(user_fetches).__name__
             )
-        else:
-            user_fetches = []
         fetch_names = []
         fetch_indices = []
@@ -466,7 +474,7 @@ class Engine:
                 _process_fetch_group("metrics_" + str(i), var_list)
         if mode == "predict":
             _process_fetch_group("outputs", fetch_vars["outputs"])
-        for usr_fetch in user_fetches:
+        for usr_fetch in user_fetches or []:
             var_name = _to_name_str(usr_fetch)
             fetch(var_name)
         user_fetches_collection = [
@@ -903,6 +911,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             train_data, train_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -931,7 +940,7 @@ class Engine:
             save_dir=save_dir,
             verbose=verbose,
             metrics=self._metrics_name(),
-            acc_step=self._k_steps,
+            acc_step=self._acc_steps,
         )

         cbks.on_begin('train')
@@ -965,7 +974,7 @@ class Engine:
                 val_logs = self.evaluate(
                     valid_data,
                     valid_sample_split,
-                    batch_size,
+                    batch_size * self._acc_steps,
                     valid_steps,
                     log_freq,
                     collate_fn,
@@ -1046,6 +1055,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             valid_data, valid_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1152,6 +1162,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             test_data, test_sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1214,6 +1225,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1256,6 +1268,7 @@ class Engine:
         self._inputs_spec, self._labels_spec = self._prepare_data_spec(
             dataset, sample_split, batch_size
         )
+        batch_size = self._validate_batch_size(batch_size)
         if not self._has_prepared[self._mode]:
             self._prepare_program(self._mode)
         else:
@@ -1371,14 +1384,6 @@ class Engine:
         steps_per_epoch=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
-
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1440,14 +1445,6 @@ class Engine:
         collate_fn=None,
     ):
-        if self._strategy.gradient_merge and batch_size is not None:
-            assert (
-                batch_size % self._k_steps == 0
-            ), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
-                batch_size, self._k_steps
-            )
-            batch_size //= self._k_steps
-
         dist_context = self._dist_contexts[self._mode]
         dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
         dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
@@ -1487,6 +1484,9 @@ class Engine:
             split_data=self._strategy.split_data,
             data_parallel_world_size=self._dp_world_sizes,
             data_parallel_rank=self._dp_ranks,
+            acc_steps=1
+            if not self._strategy.pipeline.enable
+            else self._acc_steps,
         )
         self._prepare_reader(feed_list)
         return dataloader
@@ -1498,9 +1498,18 @@ class Engine:
         )
         self._optimization_tuning(self._mode, tune_data, batch_size)

+    def _validate_batch_size(self, batch_size):
+        if batch_size is None:
+            return None
+        assert (
+            batch_size % self._acc_steps == 0
+        ), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
+            batch_size, self._acc_steps
+        )
+        return batch_size // self._acc_steps
+
     def _validate_spec(self, specs):
         specs = to_list(specs)
-        self._k_steps = self._strategy.gradient_merge.k_steps
         if specs is not None:
             for i, spec in enumerate(specs):
                 if not isinstance(spec, InputSpec):
@@ -1513,14 +1522,14 @@ class Engine:
                             i, spec
                         )
                     )
-                if self._k_steps > 1:
+                if self._acc_steps > 1:
                     shape = list(spec.shape)
                     assert (
-                        shape[0] % self._k_steps == 0
-                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
-                        spec.shape[0], self._k_steps
+                        shape[0] % self._acc_steps == 0
+                    ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
+                        spec.shape[0], self._acc_steps
                     )
-                    shape[0] //= self._k_steps
+                    shape[0] //= self._acc_steps
                     spec.shape = shape
         return specs or []
...
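
A hedged restatement of the new _validate_batch_size contract (values here are illustrative, not from the diff): the user-facing batch size is the global effective batch, and the dataloader is handed 1/acc_steps of it.

batch_size, acc_steps = 8, 2        # e.g. gradient_merge.k_steps = 2
assert batch_size % acc_steps == 0  # otherwise the Engine raises
micro_batch_size = batch_size // acc_steps
assert micro_batch_size == 4        # what each micro-step actually consumes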
@@ -297,13 +297,15 @@ class Parallelizer:
         if self._strategy is None:
             return

-        # data parallel optimization
-        config = {}
-        config["dist_context"] = self._dist_context
-        config["global_rank"] = rank
-        config["use_sharding"] = self._strategy.sharding.enable
-        dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
-        dp_pass.apply([main_program], [startup_program], self._pass_context)
+        if self._strategy.dp_optimization.enable:
+            config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
+            config["dist_context"] = self._dist_context
+            config["global_rank"] = rank
+            config["use_sharding"] = self._strategy.sharding.enable
+            dp_pass = new_pass(
+                "auto_parallel_data_parallel_optimization", config
+            )
+            dp_pass.apply([main_program], [startup_program], self._pass_context)

         if self._strategy.sharding.enable:
             config = copy.deepcopy(self._strategy.sharding.to_dict())
...
@@ -422,11 +422,11 @@ class Inserter:
         )
         inputs = {'X': [tensor]}
         outputs = {"Out": [out]}
-        attrs = {"in_place": False}
-        slice_op = block._insert_op(
+        attrs = {"in_place": False, "op_role": op_role}
+        assign_op = block._insert_op(
             idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
         )
-        slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
+        assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
         return out

     # use split once
@@ -1217,6 +1217,8 @@ class Resharder:
                 shape_x[0] <= shape_y[0] < shape_x[1]
             ):
                 overlapped = True
+            if shape_x == [0, 0] and shape_y == [0, 0]:
+                overlapped = True
             return overlapped

     def is_unshard(self, dims_mapping):
@@ -1304,6 +1306,14 @@ class Resharder:
                 # judge whether need reshard by process_mesh
                 if tensor_process_mesh != op_process_mesh:
                     is_reshard = True
+                # not reshard data in send/recv scene
+                if (
+                    tensor_process_mesh != op_process_mesh
+                    and len(tensor_process_mesh.process_ids)
+                    == len(op_process_mesh.process_ids)
+                    and dist_tensor.serial_tensor.is_data
+                ):
+                    is_reshard = False
         else:
             op_output_dims_mapping = dist_attr[1]
             if all(
@@ -1585,10 +1595,10 @@ class Resharder:
                     if i == 0:
                         all_partition_index_list.append(process_index[j][1])
             for process in group:
-                # append slice op desc
-                slice_starts = []
-                slice_ends = []
-                slices_axes = []
+                min_comm_group = copy.deepcopy(group)
+                all_partition_index_list_copied = copy.deepcopy(
+                    all_partition_index_list
+                )
                 target_partition_index = Resharder.compute_partition_index(
                     process,
                     complete_shape,
@@ -1596,12 +1606,56 @@ class Resharder:
                     target_process_shape,
                     target_process_group,
                 )
-                for idx, item in enumerate(target_partition_index):
-                    slice_starts.append(item[0])
-                    slice_ends.append(item[1])
+                for _process in group:
+                    source_partition_index = (
+                        Resharder.compute_partition_index(
+                            _process,
+                            complete_shape,
+                            source_dims_mapping,
+                            source_process_shape,
+                            source_process_group,
+                        )
+                    )
+                    if not all(
+                        _
+                        for _ in list(
+                            map(
+                                self.is_overlapped,
+                                source_partition_index,
+                                target_partition_index,
+                            )
+                        )
+                    ):
+                        min_comm_group.remove(_process)
+                        all_partition_index_list_copied.remove(
+                            source_partition_index
+                        )
+
+                concatenated_partition_index_list = []
+                for partition_index in all_partition_index_list_copied:
+                    Resharder.concat_partitions(
+                        concatenated_partition_index_list, partition_index
+                    )
+
+                concatenated_partition_index = (
+                    concatenated_partition_index_list[0]
+                )
+
+                slice_starts = []
+                slice_ends = []
+                slices_axes = []
+                to_slice_tensor_shape = []
+                for idx, item in enumerate(concatenated_partition_index):
+                    slice_starts.append(
+                        target_partition_index[idx][0] - item[0]
+                    )
+                    slice_ends.append(
+                        target_partition_index[idx][1] - item[0]
+                    )
                     slices_axes.append(idx)
+                    to_slice_tensor_shape.append(item[1] - item[0])

-                to_slice_tensor_shape = dist_tensor.global_sizes()
                 slice_op_desc = SliceOpDesc(
                     starts=slice_starts,
                     ends=slice_ends,
@@ -1616,16 +1670,16 @@ class Resharder:
                 op_desc_seq[process] = (
                     [
                         AllGatherOpDesc(
-                            group=group,
+                            group=min_comm_group,
                             shape=allgather_shape,
                             is_bool=(source_tensor.dtype == paddle.bool),
                         ),
                         ConcatOpDesc(
-                            partition_index_list=all_partition_index_list
+                            partition_index_list=all_partition_index_list_copied
                         ),
                         slice_op_desc,
                     ]
-                    if len(group) > 1
+                    if len(min_comm_group) > 1
                     else [slice_op_desc]
                 )
...
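
A toy, self-contained restatement (not the Resharder API) of the overlap test this hunk relies on, including the special case added to is_overlapped earlier in this diff where two empty [0, 0] ranges count as overlapping, so zero-size partitions still pair up when the minimal communication group is computed:

def is_overlapped(shape_x, shape_y):
    overlapped = False
    if (shape_y[0] <= shape_x[0] < shape_y[1]) or (
        shape_x[0] <= shape_y[0] < shape_x[1]
    ):
        overlapped = True
    if shape_x == [0, 0] and shape_y == [0, 0]:
        overlapped = True
    return overlapped

assert is_overlapped([0, 4], [2, 6])
assert is_overlapped([0, 0], [0, 0])      # new case from this commit
assert not is_overlapped([0, 2], [2, 4])  # half-open ranges: no overlap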
@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
         super(DatasetConfig, self).__init__(category, config_dict)


+class DPOptimizationConfig(BaseConfig):
+    def __init__(self, config_dict=None):
+        category = constants.DP_OPTIMIZATION
+        super(DPOptimizationConfig, self).__init__(category, config_dict)
+
+
 class Strategy(BaseConfig):
     """
     The `Strategy` object is used to configure the parallelization and optimization behaviors.
@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
         config_dict = self._config_dict.get(constants.DATASET, None)
         self.dataset = DatasetConfig(config_dict)
+
+        config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
+        self.dp_optimization = DPOptimizationConfig(config_dict)
@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
         "fused_softmax_mask_upper_triangle_grad",
         "flatten_contiguous_range_grad",
         "relu_grad",
+        "exp_grad",
     ]
     forward_list = [
         "reshape2",
@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
         "fused_softmax_mask_upper_triangle",
         "flatten_contiguous_range",
         "relu",
+        "exp",
     ]
     if op.type in need_set_shape_list:
         for forward_op in block.ops:
@@ -1320,6 +1322,11 @@ def is_forward_op(op):
     )


+def is_fillconst_op_for_micro_batch(op):
+    op_role = int(op.attr('op_role'))
+    return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))
+
+
 def is_backward_op(op):
     return OP_ROLE_KEY in op.attr_names and int(
         op.all_attrs()[OP_ROLE_KEY]
...
@@ -18,15 +18,31 @@ import numpy as np
 import paddle
 from paddle.fluid import core, unique_name
 from paddle.fluid.framework import default_main_program
-from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
-from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
-from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
+from paddle.distributed.fleet.meta_optimizers.common import (
+    OpRole,
+    OP_ROLE_KEY,
+    OP_ROLE_VAR_KEY,
+)
+from paddle.distributed.auto_parallel.operators.common import (
+    is_data_parallel_scale_op,
+    is_data_parallel_reduce_op,
+)
+from paddle.distributed.auto_parallel.utils import (
+    is_loss_grad_op,
+    is_optimize_op,
+    is_backward_op,
+    ring_id_to_process_group,
+    find_higher_order_backward_op,
+)
 from .pass_base import PassBase, PassType, register_pass

 # add new optimizers supporting rescale_grad here
 __rescale_grad_supported_opts__ = [
-    'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
-    'merge_momentum'
+    'lars_momentum',
+    'sparse_momentum',
+    'dgc_momentum',
+    'momentum',
+    'merge_momentum',
 ]

 # a heuristic number
@@ -41,7 +57,7 @@ def numel(var):
 class DataParallelOptimizationPass(PassBase):
     """
     Apply optimizations specialized for data parallelism in Auto Parallel.
     1. prune grad scaling
     2. overlap comm and calc
     3. fuse allreduce
     """
@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
         self.set_attr("dist_context", None)
         self.set_attr("global_rank", -1)
         self.set_attr("use_sharding", False)
+        self.set_attr("fuse_all_reduce_ops", False)
+        self.set_attr("fuse_grad_size_in_MB", 32)
+        self.set_attr("overlap_comm_cacl", False)
         # {grad1: group1, grad2: group1, grad3: group2}
         # record the order for fuse grad data memory
         self._grad_name_to_group_map = OrderedDict()
@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
     def _check_self(self):
         if self.get_attr("dist_context") is None:
             return False
-        if (not isinstance(self.get_attr("global_rank"),
-                           int)) or self.get_attr("global_rank") < 0:
+        if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
+            "global_rank"
+        ) < 0:
             return False

         return True
@@ -80,13 +100,18 @@ class DataParallelOptimizationPass(PassBase):
         self.global_rank = int(self.get_attr("global_rank"))
         self.use_sharding = self.get_attr("use_sharding")
+        overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
+        fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")

         with paddle.static.program_guard(main_program, startup_program):
             self._analyze_program()

             if self.is_data_parallel_applied():
-                self._prune_grad_scaling()
-                self._calc_comm_overlap()
-                grad_group = self._fuse_allreduce()
+                if overlap_comm_cacl:
+                    self._prune_grad_scaling()
+                    self._calc_comm_overlap()
+                if fuse_all_reduce_ops:
+                    grad_group = self._fuse_allreduce()

         # self.summary(grad_group)
@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
             ), "Unexpected: comm op [{}] has NO ring id.".format(str(op))

             group = ring_id_to_process_group(op.attr("ring_id"))
-            assert group is not None, "Unexpected: data parallel group of [{}] from op [{}] is None".format(
-                grad_name, str(op))
+            assert (
+                group is not None
+            ), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
+                grad_name, str(op)
+            )

             self._grad_name_to_group_map[grad_name] = group
@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
             # TODO support multiple optimizers in one network in future.
             # here we assume that the optimizer is unique in network.
-            elif is_optimize_op(
-                    op) and op.type in __rescale_grad_supported_opts__:
+            elif (
+                is_optimize_op(op)
+                and op.type in __rescale_grad_supported_opts__
+            ):
                 self._support_rescale_grad = True

         not_synchronized_grads = []
         for grad_name in scaled_grads:
             if grad_name not in self._grad_name_to_group_map:
                 not_synchronized_grads.append(grad_name)
-        assert len(
-            not_synchronized_grads
-        ) == 0, "Unexpected: gradients [{}] are scaled BUT NOT synchronized.".format(
-            not_synchronized_grads)
+        assert (
+            len(not_synchronized_grads) == 0
+        ), "Unexpected: gradients [{}] are scaled BUT NOT synchronized.".format(
+            not_synchronized_grads
+        )

     def is_data_parallel_applied(self):
         return len(self._group_to_grad_name_map) > 0
@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
     def _could_be_prune(self):

         return self.dist_context.gradient_scale and (
-            self._support_rescale_grad or self._all_dp_groups_same_degree())
+            self._support_rescale_grad or self._all_dp_groups_same_degree()
+        )

     def _all_dp_groups_same_degree(self):
-        return len(
-            set([
-                len(group.ranks)
-                for group in self._group_to_grad_name_map.keys()
-            ])) == 1
+        return (
+            len(
+                set(
+                    [
+                        len(group.ranks)
+                        for group in self._group_to_grad_name_map.keys()
+                    ]
+                )
+            )
+            == 1
+        )

     def _scale_backward_initial_grad(self):
@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
         for idx, op in reversed(list(enumerate(block.ops))):
             if is_loss_grad_op(op):
-                assert op.type == 'fill_constant', \
-                    "loss_grad_op must be fill_constant op, " \
-                    "but this op is {}".format(op.type)
+                assert op.type == 'fill_constant', (
+                    "loss_grad_op must be fill_constant op, "
+                    "but this op is {}".format(op.type)
+                )
                 assert op.has_attr('value')
                 loss_scale = float(op.attr('value'))
                 loss_scale = loss_scale / dp_degree
@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
         scaled_grads = set()

         for idx, op in reversed(list(enumerate(block.ops))):
-            if is_optimize_op(
-                    op) and op.type in __rescale_grad_supported_opts__:
+            if (
+                is_optimize_op(op)
+                and op.type in __rescale_grad_supported_opts__
+            ):
                 assert op.has_attr(
                     'rescale_grad'
-                ), "Unexpected: op [{}] is supposed to have [rescale_grad] attribute.".format(
-                    str(op))
-                assert len(
-                    op.input("Grad")
-                ) == 1, "Unexpected: op [{}] is supposed to have only one input grad var.".format(
-                    str(op))
+                ), "Unexpected: op [{}] is supposed to have [rescale_grad] attribute.".format(
+                    str(op)
+                )
+                assert (
+                    len(op.input("Grad")) == 1
+                ), "Unexpected: op [{}] is supposed to have only one input grad var.".format(
+                    str(op)
+                )

                 grad_name = op.input("Grad")[0]
                 dp_degree = len(
-                    list(self._grad_name_to_group_map[grad_name].ranks))
+                    list(self._grad_name_to_group_map[grad_name].ranks)
+                )
                 scaled_grads.add(grad_name)

                 rescale_grad = float(op.attr('rescale_grad')) / dp_degree
                 op._set_attr('rescale_grad', rescale_grad)

-        assert scaled_grads == set(self._grad_name_to_group_map.keys(
-        )), "Unexpected: gradients [{}] are unscaled.".format(
-            set(self._grad_name_to_group_map.keys()) - scaled_grads)
+        assert scaled_grads == set(
+            self._grad_name_to_group_map.keys()
+        ), "Unexpected: gradients [{}] are unscaled.".format(
+            set(self._grad_name_to_group_map.keys()) - scaled_grads
+        )
     def _could_be_overlap(self):
         # NOTE current different nccl comm will use different cuda stream
@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
             op._set_attr('use_calc_stream', False)
             ring_id = op.attr("ring_id")

-            block._insert_op_without_sync(idx,
-                                          type='c_wait_compute',
-                                          inputs={'X': []},
-                                          outputs={'Out': []},
-                                          attrs={
-                                              'op_role': OpRole.Backward,
-                                              'ring_id': ring_id
-                                          })
+            block._insert_op_without_sync(
+                idx,
+                type='c_wait_compute',
+                inputs={'X': []},
+                outputs={'Out': []},
+                attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
+            )

         block._sync_with_cpp()
@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
             # other ops that might use communicating grad
             else:
                 for input_var_name in op.input_arg_names:
-                    for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items(
-                    ):
+                    for (
+                        ring_id,
+                        unsync_grad_names,
+                    ) in ring_id_to_un_sync_grad_map.items():
                         if input_var_name in unsync_grad_names:
                             # need to sync before op_i
                             if i in op_idx_to_sync_ring_id_map:
@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
         for i in sorted(indices, reverse=True):
             for ring_id in op_idx_to_sync_ring_id_map[i]:

-                block._insert_op_without_sync(i,
-                                              type='c_wait_comm',
-                                              inputs={'X': []},
-                                              outputs={'Out': []},
-                                              attrs={
-                                                  'op_role': OpRole.Backward,
-                                                  'ring_id': ring_id
-                                              })
+                block._insert_op_without_sync(
+                    i,
+                    type='c_wait_comm',
+                    inputs={'X': []},
+                    outputs={'Out': []},
+                    attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
+                )

     def _could_be_fuse(self):
         # TODO support gradient fuse higher order gradient.
@@ -350,9 +396,9 @@ class DataParallelOptimizationPass(PassBase):
         """
        conditions for gradients to be grouped:
         1. group size < max_fuse_numel
         2. same dp group
         3. same dtype
         4. dependency: grad would NOT be used by other ops within group segment

         gradients inside same group would be fused into one coalesce tensor
         """
@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
         for i, group in enumerate(grad_groups[::-1]):

             # create coalesce tensor
-            group.coalesce_var = block.create_var(name=unique_name.generate(
-                'coalecse_grad_{}'.format(i)),
-                                                  dtype=group.dtype,
-                                                  persistable=False,
-                                                  stop_gradient=True)
+            group.coalesce_var = block.create_var(
+                name=unique_name.generate('coalecse_grad_{}'.format(i)),
+                dtype=group.dtype,
+                persistable=False,
+                stop_gradient=True,
+            )

             # update allreduce & scale op
             if group.scale_op_idx != -1:
                 scale_op = block.ops[group.scale_op_idx]
-                assert scale_op.type == 'scale', "should have found scale op but found {}".format(
-                    str(scale_op))
-                scale_op._rename_input(scale_op.input_arg_names[0],
-                                       group.coalesce_var.name)
-                scale_op._rename_output(scale_op.output_arg_names[0],
-                                        group.coalesce_var.name)
+                assert (
+                    scale_op.type == 'scale'
+                ), "should have found scale op but found {}".format(str(scale_op))
+                scale_op._rename_input(
+                    scale_op.input_arg_names[0], group.coalesce_var.name
+                )
+                scale_op._rename_output(
+                    scale_op.output_arg_names[0], group.coalesce_var.name
+                )

             allreduce_op = block.ops[group.allreduce_op_idx]
-            assert allreduce_op.type == 'c_allreduce_sum', "should have found c_allreduce_sum op but found {}".format(
-                str(allreduce_op))
-            allreduce_op._rename_input(allreduce_op.input_arg_names[0],
-                                       group.coalesce_var.name)
-            allreduce_op._rename_output(allreduce_op.output_arg_names[0],
-                                        group.coalesce_var.name)
+            assert (
+                allreduce_op.type == 'c_allreduce_sum'
+            ), "should have found c_allreduce_sum op but found {}".format(
+                str(allreduce_op)
+            )
+            allreduce_op._rename_input(
+                allreduce_op.input_arg_names[0], group.coalesce_var.name
+            )
+            allreduce_op._rename_output(
+                allreduce_op.output_arg_names[0], group.coalesce_var.name
+            )

             # remove unused ops
-            remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
+            remove_op_indices = (
+                group.remove_wait_op_indices
+                + group.remove_allreduce_op_indices
+                + group.remove_scale_op_indices
+            )
             for idx in sorted(remove_op_indices, reverse=True):
-                assert block.ops[
-                    idx].type in remove_op_types, "Unexpected: try to remove op {}".format(
-                        str(op))
+                assert (
+                    block.ops[idx].type in remove_op_types
+                ), "Unexpected: try to remove op {}".format(
+                    str(block.ops[idx].type())
+                )
                 block._remove_op(idx)

             # insert coalesce op
@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
                 concated_ranks.append(len(shape))

             grad_names = [grad.name for grad in group.gradients]
-            block._insert_op_without_sync(group.coalesce_op_idx,
-                                          type="coalesce_tensor",
-                                          inputs={"Input": grad_names},
-                                          outputs={
-                                              "Output": grad_names,
-                                              "FusedOutput": group.coalesce_var
-                                          },
-                                          attrs={
-                                              "copy_data": False,
-                                              "use_align": True,
-                                              "dtype": group.dtype,
-                                              "concated_shapes":
-                                              concated_shapes,
-                                              "concated_ranks": concated_ranks,
-                                              OP_ROLE_KEY: OpRole.Backward
-                                          })
+            block._insert_op_without_sync(
+                group.coalesce_op_idx,
+                type="coalesce_tensor",
+                inputs={"Input": grad_names},
+                outputs={
+                    "Output": grad_names,
+                    "FusedOutput": group.coalesce_var,
+                },
+                attrs={
+                    "copy_data": False,
+                    "use_align": True,
+                    "dtype": group.dtype,
+                    "concated_shapes": concated_shapes,
+                    "concated_ranks": concated_ranks,
+                    OP_ROLE_KEY: OpRole.Backward,
+                },
+            )

         block._sync_with_cpp()
         # TODO update dist attr
@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
     def summary(self, grad_groups=[]):
         # TODO: add logger module
         import logging
+
         self._logger = logging.getLogger()
         self._logger.propagate = False
         if not self._logger.handlers:
@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
         if len(grad_groups) > 0:
             self._logger.info(
-                "origin {} allreduce ops are fused into {} coalesce allreduce ops."
-                .format(len(self._grad_name_to_group_map.keys()),
-                        len(grad_groups)))
+                "origin {} allreduce ops are fused into {} coalesce allreduce ops.".format(
+                    len(self._grad_name_to_group_map.keys()), len(grad_groups)
+                )
+            )
             self._logger.info("gradient fusing group are following: ")
             fused_grads = set()
             for i, group in enumerate(grad_groups):
                 self._logger.info(
-                    "coalesce gradient [{}] is composed by: {}".format(
-                        i, [grad.name for grad in group.gradients]))
+                    "coalesce gradient [{}] is composed by: {}".format(
+                        i, [grad.name for grad in group.gradients]
+                    )
+                )
                 fused_grads.update([grad.name for grad in group.gradients])
-            individual_grads = set(
-                self._grad_name_to_group_map.keys()) - set(fused_grads)
+            individual_grads = set(self._grad_name_to_group_map.keys()) - set(
+                fused_grads
+            )
             self._logger.info(
-                "the following [{}] gradients are not fused: ".format(
-                    len(individual_grads)))
+                "the following [{}] gradients are not fused: ".format(
+                    len(individual_grads)
+                )
+            )
             self._logger.info("individual gradient {}".format(individual_grads))
 class GradientsGroup(object):
     def __init__(self, ops, max_group_size):
         self.max_group_size = max_group_size
         self.ops = ops
@@ -575,8 +643,11 @@ class GradientsGroup(object):
                 grad_op_idx -= 1

         grad_op = self.ops[grad_op_idx]
-        assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
-            grad_var.name, str(grad_op))
+        assert (
+            grad_var.name in grad_op.output_arg_names
+        ), "grad [{}] should be output of {}".format(
+            grad_var.name, str(grad_op)
+        )
         self.coalesce_op_idx = grad_op_idx

     def finalize(self):
...
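
For intuition, a toy sketch of the grouping conditions listed in the pass's docstring, deliberately simpler than the real GradientsGroup (it skips condition 4, the dependency check): pack gradients greedily in program order, closing a group when the size cap is hit or when dtype or dp group changes. The helper name and its tuple layout are hypothetical.

def group_gradients(grads, max_numel):
    """grads: list of (name, numel, dtype, dp_group) tuples in program order."""
    groups, cur, cur_numel = [], [], 0
    for name, numel, dtype, dp_group in grads:
        # close the current group if adding this grad would break a condition
        if cur and (
            cur_numel + numel > max_numel
            or dtype != cur[-1][2]
            or dp_group != cur[-1][3]
        ):
            groups.append(cur)
            cur, cur_numel = [], 0
        cur.append((name, numel, dtype, dp_group))
        cur_numel += numel
    if cur:
        groups.append(cur)
    return groups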
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random

import numpy as np

import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv

from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()


def apply_pass(use_1f1b=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_1f1b:
        pipeline = strategy.pipeline
        pipeline.enable = True
        pipeline.schedule_mode = "1F1B"
        pipeline.accumulate_steps = 2
    else:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 2
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True
    return strategy


def reset_prog():
    paddle.fluid.framework.switch_main_program(paddle.static.Program())
    paddle.fluid.framework.switch_startup_program(paddle.static.Program())


class Test1F1BPass(unittest.TestCase):
    def setUp(self):
        self.rtol = 1e-5
        self.atol = 1e-8
        self.batch_size = 2
        self.batch_num = 10
        self.clip_norm = 0.2
        self.dataset = FakeDataset(self.batch_size * self.batch_num)

    def init(self, engine):
        paddle.seed(2021)
        np.random.seed(2021)
        random.seed(2021)
        paddle.distributed.fleet.init(is_collective=True)
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
        engine._executor = paddle.static.Executor(place)

    def get_engine(self, use_1f1b=False):
        reset_prog()

        strategy = apply_pass(use_1f1b)
        clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
        opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
        model, loss = generate_model("pp")

        engine = auto.Engine(model, loss, opt, strategy=strategy)
        self.init(engine)
        return engine

    def check_results(self, ref_losses, check_losses):
        np.testing.assert_allclose(
            ref_losses,
            check_losses,
            rtol=self.rtol,
            atol=self.atol,
            err_msg='pass {} has wrong results!\nu={}\nv={}\ndiff={}'.format(
                __class__, ref_losses, check_losses, ref_losses - check_losses
            ),
        )

    def test_1f1b_pass(self):
        # naive_pp + gradient_merge training
        engine_pp = self.get_engine()
        history_pp = engine_pp.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_pp._strategy.pipeline.enable == False

        # pp2 1f1b training
        engine_1f1b = self.get_engine(True)
        history_1f1b = engine_1f1b.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert engine_1f1b._strategy.pipeline.enable == True

        # NOTE: every sample data from dataset is all the same
        if paddle.distributed.get_rank() == 1:
            losses_pp = np.array(history_pp.history["loss"])
            losses_1f1b = np.array(history_1f1b.history["loss"])
            self.check_results(losses_pp, losses_1f1b)


if __name__ == "__main__":
    unittest.main()
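As the launcher test at the end of this commit shows, this unittest is meant to be spawned across two devices rather than executed directly; the launch command there is equivalent to the following, where <tmp_dir> stands for the temporary log directory the launcher creates:

    python -m paddle.distributed.launch --devices 0,1 --log_dir <tmp_dir> 1F1B_pass_unittest.py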
@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
  py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
  set_tests_properties(test_engine_callbacks
                       PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
  py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
  set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
                                                 TIMEOUT 50)
  py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
                  ${dist_ENVS})
......
@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
        modeling._global_parallel_strategy = "mp"
    elif strategy == "dp":
        modeling._global_parallel_strategy = "dp"
    elif strategy == "pp":
        modeling._global_parallel_strategy = "pp"
        modeling.PP_MESH_LIST = [
            auto.ProcessMesh(mesh=[0]),
            auto.ProcessMesh(mesh=[1]),
        ]
    else:
        raise ValueError("Only support serial, mp2, dp2 and pp2.")
@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
        eos_token_id=7,
        bos_token_id=0,
        eol_token_id=3,
        pp_degree=2 if strategy == "pp" else None,
    )
    model = GPTForPretraining(
        gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
......
@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv

from get_gpt_model import generate_model, FakeDataset

paddle.enable_static()
@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
    strategy = auto.Strategy()
    strategy.auto_mode = "semi"
    strategy.reinit = True
    if use_gradient_merge:
        gradient_merge = strategy.gradient_merge
        gradient_merge.enable = True
        gradient_merge.k_steps = 4
        gradient_merge.avg = True

    amp = strategy.amp
    amp.enable = True
    amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
    amp.custom_black_list = [
        'c_softmax_with_cross_entropy',
        'elementwise_div',
        'reduce_sum',
    ]
    amp.init_loss_scaling = 32768
    amp.use_fp16_guard = False
    amp.use_pure_fp16 = True
    return strategy
@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = dp_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert dp_engine._strategy.gradient_merge.enable == False
        dp_losses = np.array(history.history["loss"])

        # dp2 gradient merge training
@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
        history = gm_engine.fit(
            self.dataset, 3, batch_size=self.batch_size, log_freq=1
        )
        assert gm_engine._strategy.gradient_merge.enable == True
        gm_losses = np.array(history.history["loss"])

        # avg_loss = 0
......
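For intuition about what the gradient_merge settings above exercise: with k_steps = 4 the engine accumulates gradients over four micro-steps before a single optimizer update, and avg = True makes the merged gradient the mean rather than the sum. A tiny plain-numpy illustration of that assumed semantics (not the pass implementation):

import numpy as np

# Assumed semantics of k_steps=4, avg=True: one update on the mean gradient.
micro_grads = [np.array([0.1, 0.2]), np.array([0.3, 0.0]),
               np.array([0.2, 0.2]), np.array([0.0, 0.4])]
merged = sum(micro_grads) / len(micro_grads)
print(merged)  # -> [0.15 0.2 ]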
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess

from paddle.distributed.fleet.launch_utils import run_with_coverage


class Test1F1BPass(unittest.TestCase):
    def test_pp2(self):
        file_dir = os.path.dirname(os.path.abspath(__file__))
        launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")

        if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
            coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
        else:
            coverage_args = []

        tmp_dir = tempfile.TemporaryDirectory()
        cmd = (
            [sys.executable, "-u"]
            + coverage_args
            + [
                "-m",
                "paddle.distributed.launch",
                "--devices",
                "0,1",
                "--log_dir",
                tmp_dir.name,
                launch_model_path,
            ]
        )

        process = subprocess.Popen(cmd)
        process.wait()
        self.assertEqual(process.returncode, 0)

        tmp_dir.cleanup()


if __name__ == "__main__":
    unittest.main()
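Design note: this CI entry point only asserts the return code of the spawned two-GPU job; all numerical checks run inside 1F1B_pass_unittest.py on rank 1. Assuming a standard Paddle test environment, it can also be invoked manually, e.g.:

    python test_pass_1F1B.py

with WITH_COVERAGE=ON set in the environment when coverage collection is wanted.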