Unverified commit c47853f6 authored by: Z zhaoyingli, committed by: GitHub

[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)

Parent 6f3c9643
...@@ -110,12 +110,15 @@ void PreventVarsDelete( ...@@ -110,12 +110,15 @@ void PreventVarsDelete(
std::vector<std::string> GetUnusedVarsAfterWhile( std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc, const framework::ProgramDesc& program_desc,
TaskNode* cond_task,
const std::vector<std::string>& vars_not_gc) { const std::vector<std::string>& vars_not_gc) {
// NOTE: Since while op won't appear in task node, in order to analyze // NOTE: Since while op won't appear in task node, in order to analyze
// the vars which should be free after calling while op, we rebuild the // the vars which should be free after calling while op, we rebuild the
// whole program and get the unused vars after calling while op. // whole program and get the unused vars after calling while op.
// vars in parent block should not be free until the while op is finished. // The vars in the while block should not be freed until the while op is finished.
// The local vars will be free while running op in sub block. // In short, the vars that need to be freed after the while op are:
// 1. Vars defined in the parent block and used in the while block.
// 2. Local vars defined only in the while block.
// The unused vars above will be freed in the cond interceptor. // The unused vars above will be freed in the cond interceptor.
std::vector<std::string> while_block_vars; std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops; std::vector<std::unique_ptr<framework::OperatorBase>> ops;
...@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile( ...@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
for (const auto& var_name : pair.second) { for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name); while_block_vars.emplace_back(var_name);
} }
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
} }
} }
return while_block_vars; return while_block_vars;
} }
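The NOTE above describes which variables the cond interceptor may garbage-collect once the while op finishes: parent-block variables that the while block reads, plus variables local to the while block. A small Python sketch of that classification (the variable names are made up for illustration; only the rule comes from the comment):

    # Hypothetical example of the GC rule described in the NOTE above.
    parent_block_vars = {"x", "cond", "out"}           # defined outside the while block
    while_block_reads = {"x", "cond", "tmp0", "tmp1"}  # referenced inside the while block
    while_block_locals = {"tmp0", "tmp1"}              # defined only inside the while block

    # 1. parent-block vars used by the while block, plus
    # 2. vars local to the while block
    freed_after_while = (parent_block_vars & while_block_reads) | while_block_locals
    assert freed_after_while == {"x", "cond", "tmp0", "tmp1"}  # "out" stays alive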
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
GetSubUnusedVars(const framework::ProgramDesc& program_desc,
const std::set<TaskNode*>& sub_block_tasks,
const std::vector<std::string>& vars_not_gc) {
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (auto* task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
for (auto& unique_op : ops) {
unique_op.release();
}
PreventVarsDelete(&unused_vars, vars_not_gc);
return unused_vars;
}
} // namespace } // namespace
void FleetExecutor::Init( void FleetExecutor::Init(
...@@ -174,13 +162,8 @@ void FleetExecutor::Init( ...@@ -174,13 +162,8 @@ void FleetExecutor::Init(
for (const auto& task_node : task_nodes) { for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") { if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks); GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars = while_block_vars = GetUnusedVarsAfterWhile(
GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars); program_desc, task_node, inference_root_scope_vars);
for (auto* task_node : sub_block_tasks) {
for (auto iter : task_node->vars_to_dtype()) {
while_block_vars.emplace_back(iter.first);
}
}
VLOG(3) << "Vars will be gced after while op"; VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) { for (auto var : while_block_vars) {
VLOG(3) << var; VLOG(3) << var;
...@@ -210,9 +193,6 @@ void FleetExecutor::Init( ...@@ -210,9 +193,6 @@ void FleetExecutor::Init(
unique_op.release(); unique_op.release();
} }
auto sub_unused_vars =
GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
// NOTE: For inference, the vars in inference_root_scope_vars // NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inf, for that they may be the result of the // shouldn't be deleted during inf, for that they may be the result of the
// inf. If they are GCed, it will cause error during ZeroCopy the result. // inf. If they are GCed, it will cause error during ZeroCopy the result.
...@@ -223,8 +203,6 @@ void FleetExecutor::Init( ...@@ -223,8 +203,6 @@ void FleetExecutor::Init(
for (auto task_node : task_nodes) { for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) { if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars); task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
} }
int64_t interceptor_id = task_node->task_id(); int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node); interceptor_id_to_task.emplace(interceptor_id, task_node);
......
...@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8) ...@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant']) set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None) set_field_default_config(QAT, "algo", None)
# ######################################### #########################################
# auto tuning configuration # auto tuning configuration
# ######################################### #########################################
TUNING = "tuning" TUNING = "tuning"
set_field_default_config(TUNING, "enable", False) set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1) set_field_default_config(TUNING, "batch_size", 1)
...@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True) ...@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
DATASET = "dataset" DATASET = "dataset"
set_field_default_config(DATASET, "enable", False) set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1) set_field_default_config(DATASET, "num_shards", 1)
#########################################
# data parallel configuration
#########################################
DP_OPTIMIZATION = "dp_optimization"
set_field_default_config(DP_OPTIMIZATION, "enable", False)
set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
...@@ -17,12 +17,18 @@ import numpy as np ...@@ -17,12 +17,18 @@ import numpy as np
import paddle import paddle
from paddle.io import BatchSampler, IterableDataset from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler from paddle.fluid.dataloader.batch_sampler import (
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn _InfiniteIterableSampler,
DistributedBatchSampler,
)
from paddle.fluid.dataloader.dataloader_iter import (
_DatasetKind,
default_collate_fn,
default_convert_fn,
)
class DistributedDataLoaderBase(metaclass=abc.ABCMeta): class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
@abc.abstractmethod @abc.abstractmethod
def __iter__(self): def __iter__(self):
raise NotImplementedError raise NotImplementedError
...@@ -43,24 +49,26 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta): ...@@ -43,24 +49,26 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
def __init__(
def __init__(self, self,
dataset, dataset,
feed_list=None, feed_list=None,
capacity=None, capacity=None,
use_double_buffer=True, use_double_buffer=True,
iterable=True, iterable=True,
return_list=False, return_list=False,
use_multiprocess=False, use_multiprocess=False,
drop_last=True, drop_last=True,
places=None, places=None,
batch_size=1, batch_size=1,
epochs=1, epochs=1,
steps_per_epoch=None, steps_per_epoch=None,
collate_fn=None, collate_fn=None,
split_data=True, split_data=True,
data_parallel_world_size=[], data_parallel_world_size=[],
data_parallel_rank=[]): data_parallel_rank=[],
acc_steps=1,
):
self.dataset = dataset self.dataset = dataset
self.feed_list = feed_list self.feed_list = feed_list
self.capacity = capacity self.capacity = capacity
...@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
assert len(data_parallel_rank) == len(feed_list) assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank self.dp_ranks = data_parallel_rank
self.acc_steps = acc_steps
if isinstance(dataset, IterableDataset): if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER self.dataset_kind = _DatasetKind.ITER
...@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
else: else:
if isinstance(dataset, IterableDataset): if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler( self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size) dataset, batch_size
)
else: else:
self.batch_sampler = BatchSampler(dataset, self.batch_sampler = BatchSampler(
batch_size=batch_size, dataset,
shuffle=False, batch_size=batch_size,
drop_last=drop_last) shuffle=False,
drop_last=drop_last,
)
self.auto_collate_batch = self.batch_sampler is not None self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler) self.sampler_iter = iter(self.index_sampler)
...@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
self.collate_fn = collate_fn or default_convert_fn self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch, self.dataset_kind,
self.collate_fn, self.drop_last) self.dataset,
self.auto_collate_batch,
self.collate_fn,
self.drop_last,
)
self._steps = self._infer_steps() self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader() self._inner_dataloader = self._create_inner_dataloader()
...@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
if isinstance(self.dataset, IterableDataset): if isinstance(self.dataset, IterableDataset):
steps_per_epoch = None steps_per_epoch = None
elif self.batch_size is None: elif self.batch_size is None:
steps_per_epoch = len(self.dataset) steps_per_epoch = len(self.dataset) // self.acc_steps
else: else:
steps_per_epoch = len(self.dataset) // self.batch_size steps_per_epoch = (
len(self.dataset) // self.batch_size // self.acc_steps
)
except: except:
raise ValueError( raise ValueError(
"Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class." "Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."
...@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
return _InfiniteIterableSampler(self.dataset, 1) return _InfiniteIterableSampler(self.dataset, 1)
def _create_inner_dataloader(self): def _create_inner_dataloader(self):
def data_generator(): def data_generator():
while True: while True:
try: try:
indices = next(self.sampler_iter) indices = next(self.sampler_iter)
batch = self.dataset_fetcher.fetch(indices) batch = self.dataset_fetcher.fetch(indices)
if batch is None: break if batch is None:
break
except StopIteration: except StopIteration:
self.dataset_fetcher = _DatasetKind.create_fetcher( self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.dataset_kind,
self.auto_collate_batch, self.collate_fn, self.dataset,
self.drop_last) self.auto_collate_batch,
self.collate_fn,
self.drop_last,
)
break break
partial_data = [] partial_data = []
...@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
continue continue
batch_size = array.shape[0] batch_size = array.shape[0]
assert batch_size % self.dp_world_sizes[i] == 0, \ assert (
"batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i])) batch_size % self.dp_world_sizes[i] == 0
), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
str(batch_size), str(self.dp_world_sizes[i])
)
partial_data.append( partial_data.append(
np.split(array, np.split(array, self.dp_world_sizes[i])[
self.dp_world_sizes[i])[self.dp_ranks[i]]) self.dp_ranks[i]
]
)
yield partial_data yield partial_data
...@@ -194,33 +220,35 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase): ...@@ -194,33 +220,35 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
iterable=False, iterable=False,
return_list=self.return_list, return_list=self.return_list,
use_multiprocess=self.use_multiprocess, use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last) drop_last=self.drop_last,
)
dataloader.set_batch_generator(data_generator, self.places) dataloader.set_batch_generator(data_generator, self.places)
return dataloader return dataloader
class DistributedDataLoader(DistributedDataLoaderBase): class DistributedDataLoader(DistributedDataLoaderBase):
def __init__(
def __init__(self, self,
dataset, dataset,
feed_list=None, feed_list=None,
places=None, places=None,
return_list=True, return_list=True,
batch_size=1, batch_size=1,
shuffle=False, shuffle=False,
drop_last=False, drop_last=False,
collate_fn=None, collate_fn=None,
num_workers=0, num_workers=0,
use_buffer_reader=True, use_buffer_reader=True,
use_shared_memory=True, use_shared_memory=True,
timeout=0, timeout=0,
worker_init_fn=None, worker_init_fn=None,
epochs=1, epochs=1,
steps_per_epoch=None, steps_per_epoch=None,
split_data=True, split_data=True,
data_parallel_world_size=[], data_parallel_world_size=[],
data_parallel_rank=[]): data_parallel_rank=[],
):
self.dataset = dataset self.dataset = dataset
self.feed_list = feed_list self.feed_list = feed_list
self.return_list = return_list self.return_list = return_list
...@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase): ...@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
self.split_data = split_data self.split_data = split_data
# TODO: rank info # TODO: rank info
self.batch_sampler = DistributedBatchSampler( self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0], self.dataset,
self.dp_ranks[0], self.shuffle, self.drop_last) self.batch_size,
self.dp_world_sizes[0],
self.dp_ranks[0],
self.shuffle,
self.drop_last,
)
self._inner_dataloader = self._create_inner_dataloader() self._inner_dataloader = self._create_inner_dataloader()
def __iter__(self): def __iter__(self):
...@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase): ...@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
use_buffer_reader=self.use_buffer_reader, use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory, use_shared_memory=self.use_shared_memory,
timeout=self.timeout, timeout=self.timeout,
worker_init_fn=self.worker_init_fn) worker_init_fn=self.worker_init_fn,
)
self.data = (x for x in dataloader) self.data = (x for x in dataloader)
return dataloader return dataloader
...@@ -18,6 +18,7 @@ import errno ...@@ -18,6 +18,7 @@ import errno
import pickle import pickle
import warnings import warnings
import logging import logging
import collections
import numpy as np import numpy as np
import paddle import paddle
...@@ -53,16 +54,13 @@ def _process_path(path): ...@@ -53,16 +54,13 @@ def _process_path(path):
class DistributedSaver: class DistributedSaver:
def __init__(self): def __init__(self):
self._logger = get_logger(logging.INFO) self._logger = get_logger(logging.INFO)
def save(self, path, serial_program, dist_main_program, dist_context): def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"): def _save_state(program, path, mode="param"):
state = { state = {
k: np.array(v) k: np.array(v) for k, v in program.state_dict(mode).items()
for k, v in program.state_dict(mode).items()
} }
with open(path, "wb") as f: with open(path, "wb") as f:
pickle.dump(state, f) pickle.dump(state, f)
...@@ -108,8 +106,9 @@ class DistributedSaver: ...@@ -108,8 +106,9 @@ class DistributedSaver:
def _load_file(filename, dirname, suffix="pdparams"): def _load_file(filename, dirname, suffix="pdparams"):
file_list = [] file_list = []
for file in os.listdir(dirname): for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix), if check_filename(
file): '{}(.*)_dist(.*).{}'.format(filename, suffix), file
):
file_list.append(os.path.join(dirname, file)) file_list.append(os.path.join(dirname, file))
file_list.sort() file_list.sort()
return file_list return file_list
...@@ -137,14 +136,16 @@ class DistributedSaver: ...@@ -137,14 +136,16 @@ class DistributedSaver:
# load path.pdparam and path.pdopt # load path.pdparam and path.pdopt
param_state_dict = _load_state(filename, dirname) param_state_dict = _load_state(filename, dirname)
opt_state_dict = _load_state(filename, dirname, opt_state_dict = (
"pdopt") if load_optimizer else {} _load_state(filename, dirname, "pdopt") if load_optimizer else {}
)
state_dict = dict(param_state_dict, **opt_state_dict) state_dict = dict(param_state_dict, **opt_state_dict)
# load path.pdattr # load path.pdattr
dist_attr_file_list = _load_file(filename, dirname, "pdattr") dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info( self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list)) "Load distributed attribute file: {}".format(dist_attr_file_list)
)
dist_attr = {} dist_attr = {}
for dist_attr_file in dist_attr_file_list: for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f: with open(dist_attr_file, 'rb') as f:
...@@ -196,12 +197,24 @@ class DistributedSaver: ...@@ -196,12 +197,24 @@ class DistributedSaver:
used_inputs += op.input_arg_names used_inputs += op.input_arg_names
used_outputs += op.output_arg_names used_outputs += op.output_arg_names
dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs)) # delete duplicated elements and keep order
dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs)) feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
used_inputs = list({}.fromkeys(used_inputs).keys())
fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
used_outputs = list({}.fromkeys(used_outputs).keys())
dist_feed_vars = [ dist_feed_vars_names = [
global_block.vars[name] for name in dist_feed_vars_names var_name for var_name in feed_vars_names if var_name in used_inputs
] ]
dist_fetch_vars_names = [
var_name
for var_name in fetch_vars_names
if var_name in used_outputs
]
dist_feed_vars = list(
reversed([global_block.vars[name] for name in dist_feed_vars_names])
)
dist_fetch_vars = [ dist_fetch_vars = [
global_block.vars[name] for name in dist_fetch_vars_names global_block.vars[name] for name in dist_fetch_vars_names
] ]
...@@ -209,11 +222,13 @@ class DistributedSaver: ...@@ -209,11 +222,13 @@ class DistributedSaver:
# NOTE: `paddle.static.save_inference_model` does not support subblock. # NOTE: `paddle.static.save_inference_model` does not support subblock.
dist_filename = filename + "_dist" + str(rank_id) dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename) dist_path = os.path.join(dirname, dist_filename)
paddle.static.save_inference_model(dist_path, paddle.static.save_inference_model(
dist_feed_vars, dist_path,
dist_fetch_vars, dist_feed_vars,
exe, dist_fetch_vars,
program=dist_main_prog) exe,
program=dist_main_prog,
)
def _save_rank_mapping(self, dirname): def _save_rank_mapping(self, dirname):
path = os.path.join(dirname, 'rank_mapping.csv') path = os.path.join(dirname, 'rank_mapping.csv')
......
...@@ -225,6 +225,11 @@ class Engine: ...@@ -225,6 +225,11 @@ class Engine:
self._planned_mode = None self._planned_mode = None
self._dygraph_mode = False self._dygraph_mode = False
self._tuning = self._strategy.tuning self._tuning = self._strategy.tuning
self._acc_steps = 1
if self._strategy.gradient_merge.enable:
self._acc_steps = self._strategy.gradient_merge.k_steps
elif self._strategy.pipeline.enable:
self._acc_steps = self._strategy.pipeline.accumulate_steps
self.history = None self.history = None
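The branch above derives the accumulation factor from the strategy: gradient merge takes priority, pipeline accumulation comes next, and the default is 1. A standalone sketch of that selection, using `types.SimpleNamespace` objects as stand-ins for the real strategy (the helper name and the stand-in objects are assumptions):

    from types import SimpleNamespace

    def resolve_acc_steps(strategy):
        # Mirrors the logic added to Engine.__init__ above.
        if strategy.gradient_merge.enable:
            return strategy.gradient_merge.k_steps
        if strategy.pipeline.enable:
            return strategy.pipeline.accumulate_steps
        return 1

    strategy = SimpleNamespace(
        gradient_merge=SimpleNamespace(enable=True, k_steps=4),
        pipeline=SimpleNamespace(enable=False, accumulate_steps=1),
    )
    assert resolve_acc_steps(strategy) == 4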
...@@ -388,7 +393,12 @@ class Engine: ...@@ -388,7 +393,12 @@ class Engine:
if self.main_program._pipeline_opt: if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"] assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"] fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = fleet_opt["tasks"][0] fwd_task = None
if self._strategy.pipeline.schedule_mode == "1F1B":
fwd_task = fleet_opt["tasks"][1]
elif self._strategy.pipeline.schedule_mode == "stream":
fwd_task = fleet_opt["tasks"][0]
assert fwd_task is not None
fwd_prog = fwd_task.get_program() fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block() fwd_block = fwd_prog.global_block()
...@@ -438,8 +448,6 @@ class Engine: ...@@ -438,8 +448,6 @@ class Engine:
), "user_fetches must be a list, but receive {}".format( ), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__ type(user_fetches).__name__
) )
else:
user_fetches = []
fetch_names = [] fetch_names = []
fetch_indices = [] fetch_indices = []
...@@ -466,7 +474,7 @@ class Engine: ...@@ -466,7 +474,7 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list) _process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict": if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"]) _process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches: for usr_fetch in user_fetches or []:
var_name = _to_name_str(usr_fetch) var_name = _to_name_str(usr_fetch)
fetch(var_name) fetch(var_name)
user_fetches_collection = [ user_fetches_collection = [
...@@ -903,6 +911,7 @@ class Engine: ...@@ -903,6 +911,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
train_data, train_sample_split, batch_size train_data, train_sample_split, batch_size
) )
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
...@@ -931,7 +940,7 @@ class Engine: ...@@ -931,7 +940,7 @@ class Engine:
save_dir=save_dir, save_dir=save_dir,
verbose=verbose, verbose=verbose,
metrics=self._metrics_name(), metrics=self._metrics_name(),
acc_step=self._k_steps, acc_step=self._acc_steps,
) )
cbks.on_begin('train') cbks.on_begin('train')
...@@ -965,7 +974,7 @@ class Engine: ...@@ -965,7 +974,7 @@ class Engine:
val_logs = self.evaluate( val_logs = self.evaluate(
valid_data, valid_data,
valid_sample_split, valid_sample_split,
batch_size, batch_size * self._acc_steps,
valid_steps, valid_steps,
log_freq, log_freq,
collate_fn, collate_fn,
...@@ -1046,6 +1055,7 @@ class Engine: ...@@ -1046,6 +1055,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
valid_data, valid_sample_split, batch_size valid_data, valid_sample_split, batch_size
) )
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
...@@ -1152,6 +1162,7 @@ class Engine: ...@@ -1152,6 +1162,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
test_data, test_sample_split, batch_size test_data, test_sample_split, batch_size
) )
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
...@@ -1214,6 +1225,7 @@ class Engine: ...@@ -1214,6 +1225,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size dataset, sample_split, batch_size
) )
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
...@@ -1256,6 +1268,7 @@ class Engine: ...@@ -1256,6 +1268,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec( self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size dataset, sample_split, batch_size
) )
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]: if not self._has_prepared[self._mode]:
self._prepare_program(self._mode) self._prepare_program(self._mode)
else: else:
...@@ -1371,14 +1384,6 @@ class Engine: ...@@ -1371,14 +1384,6 @@ class Engine:
steps_per_epoch=None, steps_per_epoch=None,
): ):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode] dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank] dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank] dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
...@@ -1440,14 +1445,6 @@ class Engine: ...@@ -1440,14 +1445,6 @@ class Engine:
collate_fn=None, collate_fn=None,
): ):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode] dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank] dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank] dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
...@@ -1487,6 +1484,9 @@ class Engine: ...@@ -1487,6 +1484,9 @@ class Engine:
split_data=self._strategy.split_data, split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes, data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks, data_parallel_rank=self._dp_ranks,
acc_steps=1
if not self._strategy.pipeline.enable
else self._acc_steps,
) )
self._prepare_reader(feed_list) self._prepare_reader(feed_list)
return dataloader return dataloader
...@@ -1498,9 +1498,18 @@ class Engine: ...@@ -1498,9 +1498,18 @@ class Engine:
) )
self._optimization_tuning(self._mode, tune_data, batch_size) self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_batch_size(self, batch_size):
if batch_size is None:
return None
assert (
batch_size % self._acc_steps == 0
), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
batch_size, self._acc_steps
)
return batch_size // self._acc_steps
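`_validate_batch_size` divides the user-facing batch size evenly across the accumulation steps and rejects sizes that do not divide. A standalone restatement of the check with a worked example (the function name is chosen for illustration):

    def split_batch(batch_size, acc_steps):
        # The user-facing batch size must be a multiple of acc_steps;
        # each accumulation step then feeds batch_size // acc_steps samples.
        if batch_size is None:
            return None
        if batch_size % acc_steps != 0:
            raise ValueError(
                "batch_size [{}] must be divisible by acc_steps [{}]".format(
                    batch_size, acc_steps
                )
            )
        return batch_size // acc_steps

    assert split_batch(64, 8) == 8
    assert split_batch(None, 8) is None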
def _validate_spec(self, specs): def _validate_spec(self, specs):
specs = to_list(specs) specs = to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None: if specs is not None:
for i, spec in enumerate(specs): for i, spec in enumerate(specs):
if not isinstance(spec, InputSpec): if not isinstance(spec, InputSpec):
...@@ -1513,14 +1522,14 @@ class Engine: ...@@ -1513,14 +1522,14 @@ class Engine:
i, spec i, spec
) )
) )
if self._k_steps > 1: if self._acc_steps > 1:
shape = list(spec.shape) shape = list(spec.shape)
assert ( assert (
shape[0] % self._k_steps == 0 shape[0] % self._acc_steps == 0
), "Requires batch_size[{}] to be divisible by k_steps[{}].".format( ), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
spec.shape[0], self._k_steps spec.shape[0], self._acc_steps
) )
shape[0] //= self._k_steps shape[0] //= self._acc_steps
spec.shape = shape spec.shape = shape
return specs or [] return specs or []
......
...@@ -297,13 +297,15 @@ class Parallelizer: ...@@ -297,13 +297,15 @@ class Parallelizer:
if self._strategy is None: if self._strategy is None:
return return
# data parallel optimization if self._strategy.dp_optimization.enable:
config = {} config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
config["dist_context"] = self._dist_context config["dist_context"] = self._dist_context
config["global_rank"] = rank config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding.enable config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config) dp_pass = new_pass(
dp_pass.apply([main_program], [startup_program], self._pass_context) "auto_parallel_data_parallel_optimization", config
)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding.enable: if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding.to_dict()) config = copy.deepcopy(self._strategy.sharding.to_dict())
......
...@@ -13,24 +13,25 @@ ...@@ -13,24 +13,25 @@
# limitations under the License # limitations under the License
import copy import copy
import numpy as np from paddle.fluid.framework import Program, Parameter, core
import paddle from paddle.distributed.auto_parallel.operators.common import (
import paddle.fluid as fluid get_distributed_operator_impl_container,
from paddle.fluid import core )
from paddle.fluid import framework as framework from paddle.distributed.auto_parallel.dist_context import DistributedContext
from paddle.fluid import core, unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from .dist_attribute import OperatorDistributedAttribute from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group from .utils import (
from .utils import set_dist_op_desc_original_id is_forward_op,
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op is_backward_op,
is_loss_op,
is_optimize_op,
is_fillconst_op_for_micro_batch,
)
from .operators.common import BACKWARD_ONLY_DIST_OPS from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue"] __varname_not_in_block__ = ["lod_tensor_blocking_queue"]
__not_shape_var_type__ = [ __not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
] ]
...@@ -39,7 +40,7 @@ class Partitioner(object): ...@@ -39,7 +40,7 @@ class Partitioner(object):
warning:: Partitioner is experimental and subject to change. warning:: Partitioner is experimental and subject to change.
Partitioner converts a program into another program. Partitioner converts a program into another program.
Given a serial program which has been auto-completed with shard annotations, the Partitioner Given a serial program which has been auto-completed with shard annotations, the Partitioner
converts the serial program into a "distributed" program. The Partitioner will modify the serial converts the serial program into a "distributed" program. The Partitioner will modify the serial
program in the following two ways, which is also the major difference between a serial and a distributed program: program in the following two ways, which is also the major difference between a serial and a distributed program:
1. partition op: replace each serial op with its corresponding dist op inferred from the shard annotation 1. partition op: replace each serial op with its corresponding dist op inferred from the shard annotation
...@@ -56,25 +57,29 @@ class Partitioner(object): ...@@ -56,25 +57,29 @@ class Partitioner(object):
""" """
if not isinstance(dist_context, DistributedContext): if not isinstance(dist_context, DistributedContext):
raise TypeError( raise TypeError(
"dist_context be paddle.fluid.DistributedContext, got %s here" % "dist_context be paddle.fluid.DistributedContext, got %s here"
type(dist_context)) % type(dist_context)
)
self._dist_context = dist_context self._dist_context = dist_context
self._rank_id = rank_id self._rank_id = rank_id
self._serial2dist_varname_mapping = {} self._serial2dist_varname_mapping = {}
self._dist_varname_suffix = "" self._dist_varname_suffix = ""
def partition(self, serial_main_program, serial_startup_program, def partition(
params_grads): self, serial_main_program, serial_startup_program, params_grads
):
if not isinstance(serial_main_program, (Program)): if not isinstance(serial_main_program, (Program)):
raise TypeError( raise TypeError(
"main_program be paddle.fluid.framework.program, got %s here" % "main_program be paddle.fluid.framework.program, got %s here"
type(serial_main_program)) % type(serial_main_program)
)
# check if shard annotated serial program valid # check if shard annotated serial program valid
if not self._is_valid_annotated_program(serial_main_program): if not self._is_valid_annotated_program(serial_main_program):
raise RuntimeError( raise RuntimeError(
"Not all vars or ops are annotated in main program !") "Not all vars or ops are annotated in main program !"
)
# init distop helper # init distop helper
dist_op_context = self._dist_context.dist_op_context dist_op_context = self._dist_context.dist_op_context
...@@ -86,24 +91,33 @@ class Partitioner(object): ...@@ -86,24 +91,33 @@ class Partitioner(object):
partitioned_startup_prog = None partitioned_startup_prog = None
else: else:
partitioned_startup_prog = self.partition_startup_program( partitioned_startup_prog = self.partition_startup_program(
serial_main_program, serial_startup_program) serial_main_program, serial_startup_program
)
dist_op_context.dst_startup_program = partitioned_startup_prog dist_op_context.dst_startup_program = partitioned_startup_prog
# partition main program # partition main program
partitioned_main_prog, partitioned_params_grads = self.partition_main_program( (
serial_main_program, params_grads) partitioned_main_prog,
partitioned_params_grads,
) = self.partition_main_program(serial_main_program, params_grads)
return partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads return (
partitioned_main_prog,
partitioned_startup_prog,
partitioned_params_grads,
)
def partition_startup_program(self, serial_main_program, def partition_startup_program(
serial_startup_program): self, serial_main_program, serial_startup_program
):
if not isinstance(serial_startup_program, (Program)): if not isinstance(serial_startup_program, (Program)):
raise TypeError( raise TypeError(
"dist_context be paddle.fluid.framework.program, got %s here" % "dist_context be paddle.fluid.framework.program, got %s here"
type(serial_startup_program)) % type(serial_startup_program)
)
partitioned_startup_prog = fluid.Program() partitioned_startup_prog = Program()
ref_block = serial_main_program.global_block() ref_block = serial_main_program.global_block()
target_block = partitioned_startup_prog.global_block() target_block = partitioned_startup_prog.global_block()
var2shape = {} var2shape = {}
...@@ -114,27 +128,33 @@ class Partitioner(object): ...@@ -114,27 +128,33 @@ class Partitioner(object):
assert var.persistable assert var.persistable
new_name = var.name + self._dist_varname_suffix new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name temp_varname_map[var.name] = new_name
target_shape = _partition_var(self._dist_context, ref_block, target_shape = _partition_var(
target_block, var.name, new_name) self._dist_context, ref_block, target_block, var.name, new_name
)
var2shape[new_name] = target_shape var2shape[new_name] = target_shape
# ops # ops
for op in serial_startup_program.global_block().ops: for op in serial_startup_program.global_block().ops:
# TODO: if a var does not belong to this rank, it should be filtered out # TODO: if a var does not belong to this rank, it should be filtered out
output_vars = op.desc.output_arg_names() output_vars = op.desc.output_arg_names()
assert len( assert (
output_vars len(output_vars) == 1
) == 1, "initializer should output only ONE variable, but got [{}]".format( ), "initializer should output only ONE variable, but got [{}]".format(
str(op.desc)) str(op.desc)
assert temp_varname_map[output_vars[ )
0]] in var2shape, "try to initialize [{}] which is not a persistable var".format( assert (
output_vars[0]) temp_varname_map[output_vars[0]] in var2shape
), "try to initialize [{}] which is not a persistable var".format(
output_vars[0]
)
new_op_desc = target_block.desc.append_op() new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(output_vars[0], new_op_desc._rename_output(
temp_varname_map[output_vars[0]]) output_vars[0], temp_varname_map[output_vars[0]]
new_op_desc._set_attr("shape", )
var2shape[temp_varname_map[output_vars[0]]]) new_op_desc._set_attr(
"shape", var2shape[temp_varname_map[output_vars[0]]]
)
target_block._sync_with_cpp() target_block._sync_with_cpp()
# set distributed attribute # set distributed attribute
...@@ -142,14 +162,17 @@ class Partitioner(object): ...@@ -142,14 +162,17 @@ class Partitioner(object):
assert new_op.type == new_op_desc.type() assert new_op.type == new_op_desc.type()
assert new_op.desc == new_op_desc assert new_op.desc == new_op_desc
output_var = target_block.var(output_vars[0]) output_var = target_block.var(output_vars[0])
output_var_attr = self._dist_context.get_tensor_dist_attr_for_program( output_var_attr = (
output_var) self._dist_context.get_tensor_dist_attr_for_program(output_var)
)
op_attr = OperatorDistributedAttribute() op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = output_var_attr.process_mesh op_attr.process_mesh = output_var_attr.process_mesh
op_attr.set_output_dims_mapping(output_var.name, op_attr.set_output_dims_mapping(
output_var_attr.dims_mapping) output_var.name, output_var_attr.dims_mapping
op_attr.set_input_dims_mapping(output_var.name, )
output_var_attr.dims_mapping) op_attr.set_input_dims_mapping(
output_var.name, output_var_attr.dims_mapping
)
self._dist_context.set_op_dist_attr_for_program(new_op, op_attr) self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)
return partitioned_startup_prog return partitioned_startup_prog
...@@ -160,7 +183,7 @@ class Partitioner(object): ...@@ -160,7 +183,7 @@ class Partitioner(object):
2. replace local op with corresponding dist op 2. replace local op with corresponding dist op
""" """
partitioned_main_prog = fluid.Program() partitioned_main_prog = Program()
dist_op_context = self._dist_context.dist_op_context dist_op_context = self._dist_context.dist_op_context
dist_op_context.dst_main_program = partitioned_main_prog dist_op_context.dst_main_program = partitioned_main_prog
...@@ -171,7 +194,8 @@ class Partitioner(object): ...@@ -171,7 +194,8 @@ class Partitioner(object):
target_block = partitioned_main_prog.blocks[0] target_block = partitioned_main_prog.blocks[0]
else: else:
target_block = partitioned_main_prog._create_block( target_block = partitioned_main_prog._create_block(
parent_idx=ref_block.parent_idx) parent_idx=ref_block.parent_idx
)
assert ref_block.idx == target_block.idx assert ref_block.idx == target_block.idx
target_block._set_forward_block_idx(ref_block.forward_block_idx) target_block._set_forward_block_idx(ref_block.forward_block_idx)
dist_op_context.work_block = target_block dist_op_context.work_block = target_block
...@@ -186,8 +210,9 @@ class Partitioner(object): ...@@ -186,8 +210,9 @@ class Partitioner(object):
for attr_name in op.all_attrs(): for attr_name in op.all_attrs():
if op.attr_type(attr_name) == core.AttrType.BLOCK: if op.attr_type(attr_name) == core.AttrType.BLOCK:
relative_id = op._block_attr_id(attr_name) relative_id = op._block_attr_id(attr_name)
op._set_attr(attr_name, op._set_attr(
partitioned_main_prog.block(relative_id)) attr_name, partitioned_main_prog.block(relative_id)
)
partitioned_params_and_grads = [] partitioned_params_and_grads = []
for p, g in params_and_grads: for p, g in params_and_grads:
...@@ -198,7 +223,8 @@ class Partitioner(object): ...@@ -198,7 +223,8 @@ class Partitioner(object):
else: else:
assert g.name in self._serial2dist_varname_mapping assert g.name in self._serial2dist_varname_mapping
dist_g = self._get_dist_var_by_serial_var( dist_g = self._get_dist_var_by_serial_var(
g, partitioned_main_prog) g, partitioned_main_prog
)
partitioned_params_and_grads.append((dist_p, dist_g)) partitioned_params_and_grads.append((dist_p, dist_g))
return partitioned_main_prog, partitioned_params_and_grads return partitioned_main_prog, partitioned_params_and_grads
...@@ -222,71 +248,116 @@ class Partitioner(object): ...@@ -222,71 +248,116 @@ class Partitioner(object):
for idx in range(len(serial_ops)): for idx in range(len(serial_ops)):
if idx <= last_fwd_op_idx: if idx <= last_fwd_op_idx:
forward_op_id2forward_op[ forward_op_id2forward_op[
serial_ops[idx].desc.original_id()] = serial_ops[idx] serial_ops[idx].desc.original_id()
] = serial_ops[idx]
# partition # partition
appended_grad_times = 0 appended_grad_times = 0
for idx, op in enumerate(serial_ops): for idx, op in enumerate(serial_ops):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op) op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1]) if is_backward_op(op) and (
or is_loss_op(serial_ops[idx - 1])): is_forward_op(serial_ops[idx - 1])
or is_loss_op(serial_ops[idx - 1])
):
if not op_dist_attr.is_recompute: if not op_dist_attr.is_recompute:
appended_grad_times += 1 appended_grad_times += 1
# partititon input variables # partititon input variables
for serial_input_varname in op.desc.input_arg_names(): for serial_input_varname in op.desc.input_arg_names():
if serial_input_varname not in self._serial2dist_varname_mapping: if (
new_varname = serial_input_varname + self._dist_varname_suffix serial_input_varname
not in self._serial2dist_varname_mapping
):
new_varname = (
serial_input_varname + self._dist_varname_suffix
)
if ref_block.has_var(serial_input_varname): if ref_block.has_var(serial_input_varname):
_partition_var(self._dist_context, ref_block, _partition_var(
target_block, serial_input_varname, self._dist_context,
new_varname) ref_block,
target_block,
serial_input_varname,
new_varname,
)
else: else:
for varname_not_in_block in __varname_not_in_block__: for varname_not_in_block in __varname_not_in_block__:
assert varname_not_in_block in serial_input_varname, \ assert (
"{} is not found".format(serial_input_varname) varname_not_in_block in serial_input_varname
), "{} is not found".format(serial_input_varname)
self._serial2dist_varname_mapping[ self._serial2dist_varname_mapping[
serial_input_varname] = new_varname serial_input_varname
] = new_varname
# partition output vars # partition output vars
for serial_output_varname in op.desc.output_arg_names(): for serial_output_varname in op.desc.output_arg_names():
if serial_output_varname not in self._serial2dist_varname_mapping: if (
new_varname = serial_output_varname + self._dist_varname_suffix serial_output_varname
_partition_var(self._dist_context, ref_block, target_block, not in self._serial2dist_varname_mapping
serial_output_varname, new_varname) ):
new_varname = (
serial_output_varname + self._dist_varname_suffix
)
_partition_var(
self._dist_context,
ref_block,
target_block,
serial_output_varname,
new_varname,
)
self._serial2dist_varname_mapping[ self._serial2dist_varname_mapping[
serial_output_varname] = new_varname serial_output_varname
] = new_varname
# partition op # partition op
if is_forward_op(op) or op_dist_attr.is_recompute: if (
is_forward_op(op)
or op_dist_attr.is_recompute
or is_fillconst_op_for_micro_batch(op)
):
kinputs, koutputs = dist_op_context.prepare_context(op) kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement( dist_op_forward_impl = _get_dist_op_forward_implement(
op, self._dist_context) op, self._dist_context
dist_op_forward_impl.forward(self._dist_context, **kinputs, )
**koutputs) dist_op_forward_impl.forward(
self._dist_context, **kinputs, **koutputs
)
elif is_backward_op(op): elif is_backward_op(op):
kinputs, koutputs = dist_op_context.prepare_context(op) kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement( dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op) op, self._dist_context, forward_op_id2forward_op
grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[ )
appended_grad_times] grad_var_to_var = (
self._dist_context.dist_op_context.grad_var_to_var[
appended_grad_times
]
)
dist_op_backward_impl.backward( dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs, self._dist_context,
**{"grad_var_to_var": grad_var_to_var}) **kinputs,
**koutputs,
**{"grad_var_to_var": grad_var_to_var}
)
elif is_optimize_op(op): elif is_optimize_op(op):
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS # NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of 1F1B PASS
kinputs, koutputs = dist_op_context.prepare_context(op) kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_opt_impl = _get_dist_op_backward_implement( dist_op_opt_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op) op, self._dist_context, forward_op_id2forward_op
dist_op_opt_impl.backward(self._dist_context, **kinputs, )
**koutputs, **{"grad_var_to_var": {}}) dist_op_opt_impl.backward(
self._dist_context,
**kinputs,
**koutputs,
**{"grad_var_to_var": {}}
)
else: else:
raise NotImplementedError( raise NotImplementedError(
"partitioner only support forward and backward, optimize ops, but got {}" "partitioner only support forward and backward, optimize ops, but got {}".format(
.format(str(op))) str(op)
)
)
def _is_valid_annotated_program(self, program): def _is_valid_annotated_program(self, program):
...@@ -298,13 +369,16 @@ class Partitioner(object): ...@@ -298,13 +369,16 @@ class Partitioner(object):
] ]
var_dist_attrs = [ var_dist_attrs = [
self._dist_context.get_tensor_dist_attr_for_program(var) self._dist_context.get_tensor_dist_attr_for_program(var)
for var in vars_ if (var.type not in __not_shape_var_type__) for var in vars_
if (var.type not in __not_shape_var_type__)
] ]
all_ops_annotated = all(dist_attr is not None all_ops_annotated = all(
for dist_attr in op_dist_attrs) dist_attr is not None for dist_attr in op_dist_attrs
all_vars_annotated = all(dist_attr is not None )
for dist_attr in var_dist_attrs) all_vars_annotated = all(
dist_attr is not None for dist_attr in var_dist_attrs
)
return all_ops_annotated and all_vars_annotated return all_ops_annotated and all_vars_annotated
...@@ -328,22 +402,26 @@ def _get_dist_shape(var, dist_attr): ...@@ -328,22 +402,26 @@ def _get_dist_shape(var, dist_attr):
assert len(var_shape) == len( assert len(var_shape) == len(
mapping mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format( ), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
var_shape, mapping) var_shape, mapping
)
new_shape = [] new_shape = []
for idx in range(len(var_shape)): for idx in range(len(var_shape)):
if var_shape[idx] == -1 or mapping[idx] == -1: if var_shape[idx] == -1 or mapping[idx] == -1:
new_shape.append(var_shape[idx]) new_shape.append(var_shape[idx])
else: else:
assert var_shape[idx] % mesh[mapping[ assert (
idx]] == 0, "un-event partition: var_shape[idx]=[{}], mesh[{}]".format( var_shape[idx] % mesh[mapping[idx]] == 0
var_shape[idx], mesh[mapping[idx]]) ), "uneven partition: var_shape[idx]=[{}], mesh[{}]".format(
var_shape[idx], mesh[mapping[idx]]
)
new_shape.append(var_shape[idx] // mesh[mapping[idx]]) new_shape.append(var_shape[idx] // mesh[mapping[idx]])
return new_shape return new_shape
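`_get_dist_shape` shards each tensor dimension that is mapped onto a mesh axis by the size of that axis, and leaves unmapped (or unknown, `-1`) dimensions alone. A minimal standalone sketch of the same rule, with an illustrative example (the function name and concrete sizes are assumptions):

    def dist_shape(var_shape, dims_mapping, mesh):
        # A dimension mapped to mesh axis m is split into mesh[m] equal
        # shards; -1 in either the shape or the mapping means "keep as is".
        new_shape = []
        for size, m in zip(var_shape, dims_mapping):
            if size == -1 or m == -1:
                new_shape.append(size)
            else:
                assert size % mesh[m] == 0, "uneven partition"
                new_shape.append(size // mesh[m])
        return new_shape

    # A [1024, 768] weight sharded along its second dim on a 4-way mesh axis:
    assert dist_shape([1024, 768], [-1, 0], [4]) == [1024, 192]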
def _partition_parameter(dist_context, src_var, dst_block, dst_varname, def _partition_parameter(
dst_shape): dist_context, src_var, dst_block, dst_varname, dst_shape
):
# NOTE hack to copied Parameter # NOTE hack to copied Parameter
# not initialized parameter, need to initialize it # not initialized parameter, need to initialize it
copied_kwargs = {} copied_kwargs = {}
...@@ -353,39 +431,45 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname, ...@@ -353,39 +431,45 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
copied_kwargs['do_model_average'] = src_var.do_model_average copied_kwargs['do_model_average'] = src_var.do_model_average
copied_kwargs['need_clip'] = src_var.need_clip copied_kwargs['need_clip'] = src_var.need_clip
param = Parameter(block=dst_block, param = Parameter(
type=src_var.type, block=dst_block,
name=dst_varname, type=src_var.type,
shape=dst_shape, name=dst_varname,
dtype=src_var.dtype, shape=dst_shape,
lod_level=src_var.lod_level, dtype=src_var.dtype,
error_clip=src_var.error_clip, lod_level=src_var.lod_level,
stop_gradient=src_var.stop_gradient, error_clip=src_var.error_clip,
is_data=src_var.is_data, stop_gradient=src_var.stop_gradient,
belong_to_optimizer=src_var.belong_to_optimizer, is_data=src_var.is_data,
**copied_kwargs) belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs
)
return param return param
def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname, def _partition_intermediate_var(
dst_shape): dist_context, src_var, dst_block, dst_varname, dst_shape
var = dst_block.create_var(type=src_var.type, ):
name=dst_varname, var = dst_block.create_var(
shape=dst_shape, type=src_var.type,
dtype=src_var.dtype, name=dst_varname,
lod_level=src_var.lod_level, shape=dst_shape,
persistable=src_var.persistable, dtype=src_var.dtype,
error_clip=src_var.error_clip, lod_level=src_var.lod_level,
stop_gradient=src_var.stop_gradient, persistable=src_var.persistable,
is_data=src_var.is_data, error_clip=src_var.error_clip,
belong_to_optimizer=src_var.belong_to_optimizer) stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
return var return var
def _partition_var(dist_context, src_block, dst_block, src_varname, def _partition_var(
dst_varname): dist_context, src_block, dst_block, src_varname, dst_varname
):
""" """
partition include: split + replicate partition include: split + replicate
""" """
...@@ -393,44 +477,53 @@ def _partition_var(dist_context, src_block, dst_block, src_varname, ...@@ -393,44 +477,53 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
if src_var.type in __not_shape_var_type__: if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False) persist = getattr(src_var, 'persistable', False)
new_var = dst_block.create_var(type=src_var.type, new_var = dst_block.create_var(
name=dst_varname, type=src_var.type,
persistable=persist, name=dst_varname,
stop_gradient=True) persistable=persist,
stop_gradient=True,
)
target_shape = None target_shape = None
else: else:
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var) dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
target_shape = _get_dist_shape(src_var, dist_attr) target_shape = _get_dist_shape(src_var, dist_attr)
if isinstance(src_var, Parameter): if isinstance(src_var, Parameter):
new_var = _partition_parameter(dist_context, src_var, dst_block, new_var = _partition_parameter(
dst_varname, target_shape) dist_context, src_var, dst_block, dst_varname, target_shape
)
else: else:
new_var = _partition_intermediate_var(dist_context, src_var, new_var = _partition_intermediate_var(
dst_block, dst_varname, dist_context, src_var, dst_block, dst_varname, target_shape
target_shape) )
dist_attr = copy.deepcopy( dist_attr = copy.deepcopy(
dist_context.get_tensor_dist_attr_for_program(src_var)) dist_context.get_tensor_dist_attr_for_program(src_var)
)
assert dist_attr is not None assert dist_attr is not None
dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr) dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)
return target_shape return target_shape
def _get_dist_op_backward_implement(backward_op, dist_context, def _get_dist_op_backward_implement(
forward_op_id2forward_op): backward_op, dist_context, forward_op_id2forward_op
):
dist_op_context = dist_context.dist_op_context dist_op_context = dist_context.dist_op_context
if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id: if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
forward_op_id = dist_op_context.grad_op_id_to_op_id[ forward_op_id = dist_op_context.grad_op_id_to_op_id[
backward_op.desc.original_id()] backward_op.desc.original_id()
]
forward_op = forward_op_id2forward_op[forward_op_id] forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program( forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op) forward_op
)
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
forward_op_dist_attr.impl_type) forward_op_dist_attr.impl_type
)
dist_op_impl = dist_op_impl_container.get_impl( dist_op_impl = dist_op_impl_container.get_impl(
forward_op_dist_attr.impl_idx) forward_op_dist_attr.impl_idx
)
return dist_op_impl return dist_op_impl
# # NOTE trick for dist ops that only have backward implement # # NOTE trick for dist ops that only have backward implement
...@@ -438,7 +531,8 @@ def _get_dist_op_backward_implement(backward_op, dist_context, ...@@ -438,7 +531,8 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op) op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0 assert op_dist_attr.impl_idx >= 0
dist_op_impl = get_distributed_operator_impl_container( dist_op_impl = get_distributed_operator_impl_container(
op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx) op_dist_attr.impl_type
).get_impl(op_dist_attr.impl_idx)
return dist_op_impl return dist_op_impl
dist_op = get_distributed_operator_impl_container("default") dist_op = get_distributed_operator_impl_container("default")
...@@ -448,6 +542,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context, ...@@ -448,6 +542,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
def _get_dist_op_forward_implement(forward_op, dist_context): def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op) dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_op_impl_container = get_distributed_operator_impl_container( dist_op_impl_container = get_distributed_operator_impl_container(
dist_attr.impl_type) dist_attr.impl_type
)
dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx) dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
return dist_op_impl return dist_op_impl
...@@ -422,11 +422,11 @@ class Inserter: ...@@ -422,11 +422,11 @@ class Inserter:
) )
inputs = {'X': [tensor]} inputs = {'X': [tensor]}
outputs = {"Out": [out]} outputs = {"Out": [out]}
attrs = {"in_place": False} attrs = {"in_place": False, "op_role": op_role}
slice_op = block._insert_op( assign_op = block._insert_op(
idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
) )
slice_op._set_attr('op_namescope', "/auto_parallel/reshard") assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out return out
# use split once # use split once
...@@ -1217,6 +1217,8 @@ class Resharder: ...@@ -1217,6 +1217,8 @@ class Resharder:
shape_x[0] <= shape_y[0] < shape_x[1] shape_x[0] <= shape_y[0] < shape_x[1]
): ):
overlapped = True overlapped = True
if shape_x == [0, 0] and shape_y == [0, 0]:
overlapped = True
return overlapped return overlapped
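# For example (illustration only): partition ranges are half-open [start, end)
# intervals per axis, so is_overlapped([0, 4], [2, 6]) is True,
# is_overlapped([0, 4], [4, 8]) is False, and two empty ranges [0, 0] are
# treated as overlapped by the special case added above.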
def is_unshard(self, dims_mapping): def is_unshard(self, dims_mapping):
...@@ -1304,6 +1306,14 @@ class Resharder: ...@@ -1304,6 +1306,14 @@ class Resharder:
# judge whether reshard is needed by process_mesh # judge whether reshard is needed by process_mesh
if tensor_process_mesh != op_process_mesh: if tensor_process_mesh != op_process_mesh:
is_reshard = True is_reshard = True
# do not reshard data in the send/recv scene
if (
tensor_process_mesh != op_process_mesh
and len(tensor_process_mesh.process_ids)
== len(op_process_mesh.process_ids)
and dist_tensor.serial_tensor.is_data
):
is_reshard = False
else: else:
op_output_dims_mapping = dist_attr[1] op_output_dims_mapping = dist_attr[1]
if all( if all(
...@@ -1585,10 +1595,10 @@ class Resharder: ...@@ -1585,10 +1595,10 @@ class Resharder:
if i == 0: if i == 0:
all_partition_index_list.append(process_index[j][1]) all_partition_index_list.append(process_index[j][1])
for process in group: for process in group:
# append slice op desc min_comm_group = copy.deepcopy(group)
slice_starts = [] all_partition_index_list_copied = copy.deepcopy(
slice_ends = [] all_partition_index_list
slices_axes = [] )
target_partition_index = Resharder.compute_partition_index( target_partition_index = Resharder.compute_partition_index(
process, process,
complete_shape, complete_shape,
...@@ -1596,12 +1606,56 @@ class Resharder: ...@@ -1596,12 +1606,56 @@ class Resharder:
target_process_shape, target_process_shape,
target_process_group, target_process_group,
) )
for idx, item in enumerate(target_partition_index): for _process in group:
slice_starts.append(item[0]) source_partition_index = (
slice_ends.append(item[1]) Resharder.compute_partition_index(
_process,
complete_shape,
source_dims_mapping,
source_process_shape,
source_process_group,
)
)
if not all(
_
for _ in list(
map(
self.is_overlapped,
source_partition_index,
target_partition_index,
)
)
):
min_comm_group.remove(_process)
all_partition_index_list_copied.remove(
source_partition_index
)
concatenated_partition_index_list = []
for partition_index in all_partition_index_list_copied:
Resharder.concat_partitions(
concatenated_partition_index_list, partition_index
)
concatenated_partition_index = (
concatenated_partition_index_list[0]
)
slice_starts = []
slice_ends = []
slices_axes = []
to_slice_tensor_shape = []
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(
target_partition_index[idx][0] - item[0]
)
slice_ends.append(
target_partition_index[idx][1] - item[0]
)
slices_axes.append(idx) slices_axes.append(idx)
to_slice_tensor_shape.append(item[1] - item[0])
to_slice_tensor_shape = dist_tensor.global_sizes()
slice_op_desc = SliceOpDesc( slice_op_desc = SliceOpDesc(
starts=slice_starts, starts=slice_starts,
ends=slice_ends, ends=slice_ends,
...@@ -1616,16 +1670,16 @@ class Resharder: ...@@ -1616,16 +1670,16 @@ class Resharder:
op_desc_seq[process] = ( op_desc_seq[process] = (
[ [
AllGatherOpDesc( AllGatherOpDesc(
group=group, group=min_comm_group,
shape=allgather_shape, shape=allgather_shape,
is_bool=(source_tensor.dtype == paddle.bool), is_bool=(source_tensor.dtype == paddle.bool),
), ),
ConcatOpDesc( ConcatOpDesc(
partition_index_list=all_partition_index_list partition_index_list=all_partition_index_list_copied
), ),
slice_op_desc, slice_op_desc,
] ]
if len(group) > 1 if len(min_comm_group) > 1
else [slice_op_desc] else [slice_op_desc]
) )
......
...@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig): ...@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
super(DatasetConfig, self).__init__(category, config_dict) super(DatasetConfig, self).__init__(category, config_dict)
class DPOptimizationConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.DP_OPTIMIZATION
super(DPOptimizationConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig): class Strategy(BaseConfig):
""" """
The `Strategy` object is used to configure the parallelization and optimization behaviors. The `Strategy` object is used to configure the parallelization and optimization behaviors.
...@@ -194,3 +200,6 @@ class Strategy(BaseConfig): ...@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.DATASET, None) config_dict = self._config_dict.get(constants.DATASET, None)
self.dataset = DatasetConfig(config_dict) self.dataset = DatasetConfig(config_dict)
config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
self.dp_optimization = DPOptimizationConfig(config_dict)
...@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle_grad", "fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad", "flatten_contiguous_range_grad",
"relu_grad", "relu_grad",
"exp_grad",
] ]
forward_list = [ forward_list = [
"reshape2", "reshape2",
...@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context): ...@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle", "fused_softmax_mask_upper_triangle",
"flatten_contiguous_range", "flatten_contiguous_range",
"relu", "relu",
"exp",
] ]
if op.type in need_set_shape_list: if op.type in need_set_shape_list:
for forward_op in block.ops: for forward_op in block.ops:
...@@ -1320,6 +1322,11 @@ def is_forward_op(op): ...@@ -1320,6 +1322,11 @@ def is_forward_op(op):
) )
def is_fillconst_op_for_micro_batch(op):
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))
def is_backward_op(op): def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and int( return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY] op.all_attrs()[OP_ROLE_KEY]
......
...@@ -18,15 +18,31 @@ import numpy as np ...@@ -18,15 +18,31 @@ import numpy as np
import paddle import paddle
from paddle.fluid import core, unique_name from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY from paddle.distributed.fleet.meta_optimizers.common import (
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op OpRole,
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_scale_op,
is_data_parallel_reduce_op,
)
from paddle.distributed.auto_parallel.utils import (
is_loss_grad_op,
is_optimize_op,
is_backward_op,
ring_id_to_process_group,
find_higher_order_backward_op,
)
from .pass_base import PassBase, PassType, register_pass from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here # add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [ __rescale_grad_supported_opts__ = [
'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum', 'lars_momentum',
'merge_momentum' 'sparse_momentum',
'dgc_momentum',
'momentum',
'merge_momentum',
] ]
# a heuristic number # a heuristic number
...@@ -41,7 +57,7 @@ def numel(var): ...@@ -41,7 +57,7 @@ def numel(var):
class DataParallelOptimizationPass(PassBase): class DataParallelOptimizationPass(PassBase):
""" """
Apply optimizations specialized for data parallelism in Auto Parallel. Apply optimizations specialized for data parallelism in Auto Parallel.
1. prune grad scaling 1. prune grad scaling
2. overlap comm and calc 2. overlap comm and calc
3. fuse allreduce 3. fuse allreduce
""" """
...@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
self.set_attr("dist_context", None) self.set_attr("dist_context", None)
self.set_attr("global_rank", -1) self.set_attr("global_rank", -1)
self.set_attr("use_sharding", False) self.set_attr("use_sharding", False)
self.set_attr("fuse_all_reduce_ops", False)
self.set_attr("fuse_grad_size_in_MB", 32)
self.set_attr("overlap_comm_cacl", False)
# {grad1: group1, grad2: group1, grad3: group2} # {grad1: group1, grad2: group1, grad3: group2}
# record the order for fuse grad data memory # record the order for fuse grad data memory
self._grad_name_to_group_map = OrderedDict() self._grad_name_to_group_map = OrderedDict()
...@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
def _check_self(self): def _check_self(self):
if self.get_attr("dist_context") is None: if self.get_attr("dist_context") is None:
return False return False
if (not isinstance(self.get_attr("global_rank"), if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
int)) or self.get_attr("global_rank") < 0: "global_rank"
) < 0:
return False return False
return True return True
...@@ -80,13 +100,18 @@ class DataParallelOptimizationPass(PassBase): ...@@ -80,13 +100,18 @@ class DataParallelOptimizationPass(PassBase):
self.global_rank = int(self.get_attr("global_rank")) self.global_rank = int(self.get_attr("global_rank"))
self.use_sharding = self.get_attr("use_sharding") self.use_sharding = self.get_attr("use_sharding")
overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
self._analyze_program() self._analyze_program()
if self.is_data_parallel_applied(): if self.is_data_parallel_applied():
self._prune_grad_scaling() if overlap_comm_cacl:
self._calc_comm_overlap() self._prune_grad_scaling()
grad_group = self._fuse_allreduce() self._calc_comm_overlap()
if fuse_all_reduce_ops:
grad_group = self._fuse_allreduce()
# self.summary(grad_group) # self.summary(grad_group)
...@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase): ...@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
), "Unexception: comm op [{}] has NOT ring id.".format(str(op)) ), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
group = ring_id_to_process_group(op.attr("ring_id")) group = ring_id_to_process_group(op.attr("ring_id"))
assert group is not None, "Unexpected: data parallel group of [{}] from op [{}] is None".format( assert (
grad_name, str(op)) group is not None
), "Unexpected: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op)
)
self._grad_name_to_group_map[grad_name] = group self._grad_name_to_group_map[grad_name] = group
...@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase): ...@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
# TODO support multiple optimizers in one network in the future. # TODO support multiple optimizers in one network in the future.
# here we assume that the optimizer is unique in the network. # here we assume that the optimizer is unique in the network.
elif is_optimize_op( elif (
op) and op.type in __rescale_grad_supported_opts__: is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
self._support_rescale_grad = True self._support_rescale_grad = True
not_synchronized_grads = [] not_synchronized_grads = []
for grad_name in scaled_grads: for grad_name in scaled_grads:
if grad_name not in self._grad_name_to_group_map: if grad_name not in self._grad_name_to_group_map:
not_synchronized_grads.append(grad_name) not_synchronized_grads.append(grad_name)
assert len( assert (
len(not_synchronized_grads) == 0
), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads not_synchronized_grads
) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format( )
not_synchronized_grads)
def is_data_parallel_applied(self): def is_data_parallel_applied(self):
return len(self._group_to_grad_name_map) > 0 return len(self._group_to_grad_name_map) > 0
...@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase): ...@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
def _could_be_prune(self): def _could_be_prune(self):
return self.dist_context.gradient_scale and ( return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree()) self._support_rescale_grad or self._all_dp_groups_same_degree()
)
def _all_dp_groups_same_degree(self): def _all_dp_groups_same_degree(self):
return len( return (
set([ len(
len(group.ranks) set(
for group in self._group_to_grad_name_map.keys() [
])) == 1 len(group.ranks)
for group in self._group_to_grad_name_map.keys()
]
)
)
== 1
)
def _scale_backward_initial_grad(self): def _scale_backward_initial_grad(self):
...@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase): ...@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op): if is_loss_grad_op(op):
assert op.type == 'fill_constant', \ assert op.type == 'fill_constant', (
"loss_grad_op must be fill_constant op, " \ "loss_grad_op must be fill_constant op, "
"but this op is {}".format(op.type) "but this op is {}".format(op.type)
)
assert op.has_attr('value') assert op.has_attr('value')
loss_scale = float(op.attr('value')) loss_scale = float(op.attr('value'))
loss_scale = loss_scale / dp_degree loss_scale = loss_scale / dp_degree
...@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase): ...@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
scaled_grads = set() scaled_grads = set()
for idx, op in reversed(list(enumerate(block.ops))): for idx, op in reversed(list(enumerate(block.ops))):
if is_optimize_op( if (
op) and op.type in __rescale_grad_supported_opts__: is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
assert op.has_attr( assert op.has_attr(
'rescale_grad' 'rescale_grad'
), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format( ), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
str(op)) str(op)
assert len( )
op.input("Grad") assert (
) == 1, "Unexception: op [{}] is supported to have only one input grad var.".format( len(op.input("Grad")) == 1
str(op)) ), "Unexception: op [{}] is supported to have only one input grad var.".format(
str(op)
)
grad_name = op.input("Grad")[0] grad_name = op.input("Grad")[0]
dp_degree = len( dp_degree = len(
list(self._grad_name_to_group_map[grad_name].ranks)) list(self._grad_name_to_group_map[grad_name].ranks)
)
scaled_grads.add(grad_name) scaled_grads.add(grad_name)
rescale_grad = float(op.attr('rescale_grad')) / dp_degree rescale_grad = float(op.attr('rescale_grad')) / dp_degree
op._set_attr('rescale_grad', rescale_grad) op._set_attr('rescale_grad', rescale_grad)
assert scaled_grads == set(self._grad_name_to_group_map.keys( assert scaled_grads == set(
)), "Unexception: gradients [{}] are unscaled.".format( self._grad_name_to_group_map.keys()
set(self._grad_name_to_group_map.keys()) - scaled_grads) ), "Unexception: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads
)
def _could_be_overlap(self): def _could_be_overlap(self):
# NOTE current different nccl comm will use different cuda stream # NOTE current different nccl comm will use different cuda stream
...@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase): ...@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
op._set_attr('use_calc_stream', False) op._set_attr('use_calc_stream', False)
ring_id = op.attr("ring_id") ring_id = op.attr("ring_id")
block._insert_op_without_sync(idx, block._insert_op_without_sync(
type='c_wait_compute', idx,
inputs={'X': []}, type='c_wait_compute',
outputs={'Out': []}, inputs={'X': []},
attrs={ outputs={'Out': []},
'op_role': OpRole.Backward, attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
'ring_id': ring_id )
})
block._sync_with_cpp() block._sync_with_cpp()
...@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase): ...@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
# other ops that might use communicating grad # other ops that might use communicating grad
else: else:
for input_var_name in op.input_arg_names: for input_var_name in op.input_arg_names:
for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items( for (
): ring_id,
unsync_grad_names,
) in ring_id_to_un_sync_grad_map.items():
if input_var_name in unsync_grad_names: if input_var_name in unsync_grad_names:
# need to sync before op_i # need to sync before op_i
if i in op_idx_to_sync_ring_id_map: if i in op_idx_to_sync_ring_id_map:
...@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase): ...@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
for i in sorted(indices, reverse=True): for i in sorted(indices, reverse=True):
for ring_id in op_idx_to_sync_ring_id_map[i]: for ring_id in op_idx_to_sync_ring_id_map[i]:
block._insert_op_without_sync(i, block._insert_op_without_sync(
type='c_wait_comm', i,
inputs={'X': []}, type='c_wait_comm',
outputs={'Out': []}, inputs={'X': []},
attrs={ outputs={'Out': []},
'op_role': OpRole.Backward, attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
'ring_id': ring_id )
})
def _could_be_fuse(self): def _could_be_fuse(self):
# TODO support gradient fuse higher order gradient. # TODO support gradient fuse higher order gradient.
...@@ -350,9 +396,9 @@ class DataParallelOptimizationPass(PassBase): ...@@ -350,9 +396,9 @@ class DataParallelOptimizationPass(PassBase):
""" """
conditions for gradients to be grouped: conditions for gradients to be grouped:
1. group size < max_fuse_numel 1. group size < max_fuse_numel
2. same dp group 2. same dp group
3. same dtype 3. same dtype
4. dependency: grad would NOT be used by other ops within group segment 4. dependency: grad would NOT be used by other ops within group segment
gradients inside the same group will be fused into one coalesced tensor gradients inside the same group will be fused into one coalesced tensor
""" """
...@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase): ...@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
for i, group in enumerate(grad_groups[::-1]): for i, group in enumerate(grad_groups[::-1]):
# create coalesce tensor # create coalesce tensor
group.coalesce_var = block.create_var(name=unique_name.generate( group.coalesce_var = block.create_var(
'coalecse_grad_{}'.format(i)), name=unique_name.generate('coalecse_grad_{}'.format(i)),
dtype=group.dtype, dtype=group.dtype,
persistable=False, persistable=False,
stop_gradient=True) stop_gradient=True,
)
# update allreduce & scale op # update allreduce & scale op
if group.scale_op_idx != -1: if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx] scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "expected scale op but found {}".format( assert (
str(scale_op)) scale_op.type == 'scale'
scale_op._rename_input(scale_op.input_arg_names[0], ), "expected scale op but found {}".format(str(scale_op))
group.coalesce_var.name) scale_op._rename_input(
scale_op._rename_output(scale_op.output_arg_names[0], scale_op.input_arg_names[0], group.coalesce_var.name
group.coalesce_var.name) )
scale_op._rename_output(
scale_op.output_arg_names[0], group.coalesce_var.name
)
allreduce_op = block.ops[group.allreduce_op_idx] allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "expected c_allreduce_sum op but found {}".format( assert (
str(allreduce_op)) allreduce_op.type == 'c_allreduce_sum'
allreduce_op._rename_input(allreduce_op.input_arg_names[0], ), "expected c_allreduce_sum op but found {}".format(
group.coalesce_var.name) str(allreduce_op)
allreduce_op._rename_output(allreduce_op.output_arg_names[0], )
group.coalesce_var.name) allreduce_op._rename_input(
allreduce_op.input_arg_names[0], group.coalesce_var.name
)
allreduce_op._rename_output(
allreduce_op.output_arg_names[0], group.coalesce_var.name
)
# remove unused op # remove unused op
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices remove_op_indices = (
group.remove_wait_op_indices
+ group.remove_allreduce_op_indices
+ group.remove_scale_op_indices
)
for idx in sorted(remove_op_indices, reverse=True): for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[ assert (
idx].type in remove_op_types, "Unexpected: trying to remove op {}".format( block.ops[idx].type in remove_op_types
str(op)) ), "Unexpected: trying to remove op {}".format(
str(block.ops[idx].type())
)
block._remove_op(idx) block._remove_op(idx)
# insert coalesce op # insert coalesce op
...@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase): ...@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
concated_ranks.append(len(shape)) concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients] grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx, block._insert_op_without_sync(
type="coalesce_tensor", group.coalesce_op_idx,
inputs={"Input": grad_names}, type="coalesce_tensor",
outputs={ inputs={"Input": grad_names},
"Output": grad_names, outputs={
"FusedOutput": group.coalesce_var "Output": grad_names,
}, "FusedOutput": group.coalesce_var,
attrs={ },
"copy_data": False, attrs={
"use_align": True, "copy_data": False,
"dtype": group.dtype, "use_align": True,
"concated_shapes": "dtype": group.dtype,
concated_shapes, "concated_shapes": concated_shapes,
"concated_ranks": concated_ranks, "concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward OP_ROLE_KEY: OpRole.Backward,
}) },
)
block._sync_with_cpp() block._sync_with_cpp()
# TODO update dist attr # TODO update dist attr
...@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase): ...@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
def summary(self, grad_groups=[]): def summary(self, grad_groups=[]):
# TODO: add logger module # TODO: add logger module
import logging import logging
self._logger = logging.getLogger() self._logger = logging.getLogger()
self._logger.propagate = False self._logger.propagate = False
if not self._logger.handlers: if not self._logger.handlers:
...@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase): ...@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
if len(grad_groups) > 0: if len(grad_groups) > 0:
self._logger.info( self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops." "origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
.format(len(self._grad_name_to_group_map.keys()), len(self._grad_name_to_group_map.keys()), len(grad_groups)
len(grad_groups))) )
)
self._logger.info("gradient fusing group are following: ") self._logger.info("gradient fusing group are following: ")
fused_grads = set() fused_grads = set()
for i, group in enumerate(grad_groups): for i, group in enumerate(grad_groups):
self._logger.info( self._logger.info(
"coalecse gradient [{}] is composed by: {}".format( "coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients])) i, [grad.name for grad in group.gradients]
)
)
fused_grads.update([grad.name for grad in group.gradients]) fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set( individual_grads = set(self._grad_name_to_group_map.keys()) - set(
self._grad_name_to_group_map.keys()) - set(fused_grads) fused_grads
)
self._logger.info( self._logger.info(
"the following [{}] gradients are not fused: ".format( "the following [{}] gradients are not fused: ".format(
len(individual_grads))) len(individual_grads)
)
)
self._logger.info("individual gradient {}".format(individual_grads)) self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object): class GradientsGroup(object):
def __init__(self, ops, max_group_size): def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size self.max_group_size = max_group_size
self.ops = ops self.ops = ops
...@@ -575,8 +643,11 @@ class GradientsGroup(object): ...@@ -575,8 +643,11 @@ class GradientsGroup(object):
grad_op_idx -= 1 grad_op_idx -= 1
grad_op = self.ops[grad_op_idx] grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format( assert (
grad_var.name, str(grad_op)) grad_var.name in grad_op.output_arg_names
), "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op)
)
self.coalesce_op_idx = grad_op_idx self.coalesce_op_idx = grad_op_idx
def finalize(self): def finalize(self):
......
...@@ -12,23 +12,40 @@ ...@@ -12,23 +12,40 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import numpy as np
from collections import OrderedDict
from typing import List, Tuple, Dict, Any
import paddle import paddle
from paddle.framework import core from paddle.framework import core
from paddle.fluid import layers from paddle.fluid import layers
from paddle.fluid.framework import program_guard, device_guard from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.utils import (
set_var_dist_attr,
is_optimize_op,
is_backward_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from .pass_base import PassBase, PassType, register_pass from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_process_group = get_world_process_group() world_process_group = get_world_process_group()
def _remove_and_get_optimizer_op(main_program, dist_context): def is_gradient_clip_op(op_desc):
return op_desc.has_attr("op_namescope") and op_desc.attr(
"op_namescope"
).startswith("/gradient_clip")
def _remove_and_get_ops(main_program, dist_context):
# 1 create tmp block # 1 create tmp block
# 2 mv optimizer op from global program to tmp block # 2 mv optimizer op from global program to tmp block
# 3 del the op from dist_context # 3 del the op from dist_context
...@@ -36,101 +53,119 @@ def _remove_and_get_optimizer_op(main_program, dist_context): ...@@ -36,101 +53,119 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
temp_block = main_program._create_block() temp_block = main_program._create_block()
removed_op_idx = [] removed_op_idx = []
optimize_ops_desc = [] optimize_ops_desc = []
allreduce_sum_desc = []
for idx, op in enumerate(main_block.ops): for idx, op in enumerate(main_block.ops):
# append optimizer op to tmp block
if is_optimize_op(op): if is_optimize_op(op):
# append optimizer op to tmp block
new_op_desc = temp_block.desc.append_op() new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc) new_op_desc.copy_from(op.desc)
optimize_ops_desc.append(new_op_desc) optimize_ops_desc.append(new_op_desc)
removed_op_idx.append(idx) removed_op_idx.append(idx)
dist_context.del_dist_op_for_program(op)
# del op from dist_context
if dist_context: # append allreduce_op and scale_op to tmp block
if is_backward_op(op):
if is_data_parallel_reduce_op(op) or is_data_parallel_scale_op(op):
assert len(op.desc.output_arg_names()) == 1
new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc)
allreduce_sum_desc.append(new_op_desc)
removed_op_idx.append(idx)
dist_context.del_dist_op_for_program(op) dist_context.del_dist_op_for_program(op)
for idx in removed_op_idx[::-1]: for idx in removed_op_idx[::-1]:
main_block._remove_op(idx, sync=False) main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp() main_block._sync_with_cpp()
return optimize_ops_desc return optimize_ops_desc, allreduce_sum_desc
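# After this step the global block no longer contains optimizer ops or the
# data-parallel c_allreduce_sum / scale ops; they now live in the temp block
# and are re-appended later inside the k_steps conditional block.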
def _get_gm_cond_var(main_program, k_steps, dist_context): def _create_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block() main_block = main_program.global_block()
# Add const var # Add const var
k_step_var = layers.create_global_var(name="gradient_merge_k", k_step_var = layers.create_global_var(
shape=[1], name="gradient_merge_k",
value=int(k_steps), shape=[1],
dtype='int32', value=int(k_steps),
persistable=True, dtype='int32',
force_cpu=True) persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks)
zero_var = layers.create_global_var(name="gradient_merge_zero", zero_var = layers.create_global_var(
shape=[1], name="gradient_merge_zero",
value=int(0), shape=[1],
dtype='int32', value=int(0),
persistable=True, dtype='int32',
force_cpu=True) persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks)
# Add step var & cond var # Add step var & cond var
step_var = layers.create_global_var(name="gradient_merge_step", step_var = layers.create_global_var(
shape=[1], name="gradient_merge_step",
value=int(0), shape=[1],
dtype='int32', value=int(0),
persistable=True, dtype='int32',
force_cpu=True) persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)
cond_var = main_block.create_var(name="gradient_merge_cond", cond_var = main_block.create_var(
shape=[1], name="gradient_merge_cond", shape=[1], dtype='bool'
dtype='bool') )
set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks) set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
with device_guard("cpu"): with paddle.static.device_guard("cpu"):
# step_var += 1 # step_var += 1
increment_op = main_block.append_op(type='increment', increment_op = main_block.append_op(
inputs={'X': [step_var]}, type='increment',
outputs={'Out': [step_var]}, inputs={'X': [step_var]},
attrs={ outputs={'Out': [step_var]},
'step': float(1.0), attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward},
OP_ROLE_KEY: OpRole.Backward )
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
increment_op, world_process_group.ranks, [-1], dist_context) increment_op, world_process_group.ranks, [-1], dist_context
)
# step_var %= k_step # step_var %= k_step
elementwise_mod_op = main_block.append_op(type='elementwise_mod', elementwise_mod_op = main_block.append_op(
inputs={ type='elementwise_mod',
'X': step_var, inputs={'X': step_var, 'Y': k_step_var},
'Y': k_step_var outputs={'Out': step_var},
}, attrs={
outputs={'Out': step_var}, 'axis': -1,
attrs={ 'use_mkldnn': False,
'axis': -1, OP_ROLE_KEY: OpRole.Backward,
'use_mkldnn': False, },
OP_ROLE_KEY: )
OpRole.Backward
})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mod_op, world_process_group.ranks, [-1], dist_context) elementwise_mod_op, world_process_group.ranks, [-1], dist_context
)
# cond_var = (step_var == 0) # cond_var = (step_var == 0)
equal_op = main_block.append_op(type='equal', equal_op = main_block.append_op(
inputs={ type='equal',
'X': step_var, inputs={'X': step_var, 'Y': zero_var},
'Y': zero_var outputs={'Out': cond_var},
}, attrs={OP_ROLE_KEY: OpRole.Backward},
outputs={'Out': cond_var}, )
attrs={OP_ROLE_KEY: OpRole.Backward})
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
equal_op, world_process_group.ranks, [-1], dist_context) equal_op, world_process_group.ranks, [-1], dist_context
)
return cond_var return cond_var
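# The three ops above implement a plain modular counter; in eager Python the
# same condition would read (illustration only, assuming k_steps > 0):
def _sketch_gm_cond(step, k_steps):
    step = (step + 1) % k_steps       # increment + elementwise_mod
    return step, step == 0            # equal(step_var, zero_var)

# e.g. with k_steps = 4, the cond is True after micro-steps 4, 8, 12, ...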
def _append_gradient_merge_backward_op( def _append_gradient_merge_backward_op(
main_program, startup_program, params_grads: List[Tuple[Any, Any]], main_program,
dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]: startup_program,
params_grads,
master_grad,
dist_context,
):
main_block = main_program.global_block() main_block = main_program.global_block()
startup_block = startup_program.global_block() startup_block = startup_program.global_block()
...@@ -148,149 +183,260 @@ def _append_gradient_merge_backward_op( ...@@ -148,149 +183,260 @@ def _append_gradient_merge_backward_op(
for param, grad in params_grads: for param, grad in params_grads:
param_name = param.name param_name = param.name
param_var = main_block.var(param_name) param_var = main_block.var(param_name)
assert (param_var is not None) assert param_var is not None
ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
assert ref_dist_attr is not None
gradient_merge_var = main_block.create_var(name=param_name +
"@GRAD@GradientMerge",
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True)
ref_process_mesh = ref_dist_attr.process_mesh
ref_dims_mapping = ref_dist_attr.dims_mapping
set_var_dist_attr(dist_context, gradient_merge_var, ref_dims_mapping, dst_dtype = (
ref_process_mesh) core.VarDesc.VarType.FP32 if master_grad else param_var.dtype
)
# 2.1 create param@GRAD@MERGED var in startup_block
startup_gradient_merge_var = startup_block.create_var( startup_gradient_merge_var = startup_block.create_var(
name=param_name + "@GRAD@GradientMerge", name=param_name + "@GRAD@MERGED",
shape=param_var.shape,
dtype=dst_dtype,
persistable=True,
)
startup_block.append_op(
type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": param_var.shape,
"dtype": dst_dtype,
"value": float(0),
},
)
# 2.2 create param@GRAD@MERGED var in main_block
ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
assert ref_dist_attr is not None
gradient_merge_var = main_block.create_var(
name=param_name + "@GRAD@MERGED",
shape=param_var.shape, shape=param_var.shape,
dtype=param_var.dtype, dtype=dst_dtype,
persistable=True) persistable=True,
startup_block.append_op(type="fill_constant", )
outputs={"Out": startup_gradient_merge_var}, ref_process_mesh = ref_dist_attr.process_mesh
attrs={ ref_dims_mapping = ref_dist_attr.dims_mapping
"shape": param_var.shape, set_var_dist_attr(
"dtype": param_var.dtype, dist_context, gradient_merge_var, ref_dims_mapping, ref_process_mesh
"value": float(0), )
})
# 2.3 grad_merge += grad
# grad_merge += grad grad_name = grad.name
new_grad_op = main_block.append_op(type="elementwise_add", if grad.dtype != dst_dtype:
inputs={ cast_grad_name = grad_name + "@TMP"
'X': grad, cast_grad_var = main_block.create_var(
'Y': gradient_merge_var name=cast_grad_name,
}, shape=grad.shape,
outputs={'Out': gradient_merge_var}, dtype=dst_dtype,
attrs={ persistable=False,
'axis': -1, stop_gradient=grad.stop_gradient,
'use_mkldnn': False, )
OP_ROLE_KEY: OpRole.Backward set_var_dist_attr(
}) dist_context, cast_grad_var, ref_dims_mapping, ref_process_mesh
)
cast_op = main_block.append_op(
type="cast",
inputs={"X": grad},
outputs={"Out": cast_grad_var},
attrs={
"in_dtype": grad.dtype,
"out_dtype": dst_dtype,
OP_ROLE_KEY: OpRole.Backward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_process_mesh, ref_dims_mapping, dist_context
)
grad = cast_grad_var
new_grad_op = main_block.append_op(
type="elementwise_add",
inputs={'X': grad, 'Y': gradient_merge_var},
outputs={'Out': gradient_merge_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward,
},
)
new_params_to_grads.append([param, gradient_merge_var]) new_params_to_grads.append([param, gradient_merge_var])
grad_to_gradient_merge[grad.name] = gradient_merge_var.name grad_to_gradient_merge[grad_name] = gradient_merge_var.name
naive_set_dist_op_attr_for_program_by_mesh_and_mapping( naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context) new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context
)
return new_params_to_grads, grad_to_gradient_merge return new_params_to_grads, grad_to_gradient_merge
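# Numerically, master_grad means the merged gradient buffer is kept in FP32 and
# each (possibly FP16/BF16) micro-batch gradient is cast before being added.
# A NumPy illustration (editorial sketch, not the actual Paddle kernels):
import numpy as np

merged = np.zeros([8], dtype=np.float32)                 # param@GRAD@MERGED buffer
for _ in range(4):                                       # k_steps micro-batches
    micro_grad = np.random.randn(8).astype(np.float16)   # low-precision grad
    merged += micro_grad.astype(np.float32)              # cast + elementwise_add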
def _create_cond_block_and_update_optimizer( def _rename_arg_names(op_desc, var_name_dict):
main_program, cond_var, new_params_to_grads: List[Tuple[Any, Any]], for input_name in op_desc.input_arg_names():
grad_to_gradient_merge: Dict[str, str], optimize_ops_desc: List[Any], if input_name in var_name_dict:
k_steps, avg): op_desc._rename_input(input_name, var_name_dict[input_name])
for output_name in op_desc.output_arg_names():
if output_name in var_name_dict:
op_desc._rename_output(output_name, var_name_dict[output_name])
def _create_cond_block_and_update_optimizer(
main_program,
cond_var,
params_grads,
new_params_to_grads,
grad_to_gradient_merge,
optimize_ops_desc,
allreduce_sum_desc,
k_steps,
avg,
master_grad,
):
def true_apply_gradient(): def true_apply_gradient():
cur_block_idx = main_program.current_block_idx cur_block_idx = main_program.current_block_idx
cur_block = main_program.current_block() cur_block = main_program.current_block()
# cur_block's forward_block & backward_block is itself # cur_block's forward_block & backward_block is itself
cur_block._set_forward_block_idx(cur_block_idx) cur_block._set_forward_block_idx(cur_block_idx)
op_maker = core.op_proto_and_checker_maker
# record grads_name to insert c_allreduce_sum op
grads_name = [grad.name for _, grad in params_grads]
# append c_allreduce_sum ops and scale ops
for op_desc in allreduce_sum_desc:
outputs_name = op_desc.output_arg_names()
assert len(outputs_name) == 1
if outputs_name[0] in grads_name:
new_op_desc = cur_block.desc.append_op()
new_op_desc.copy_from(op_desc)
_rename_arg_names(new_op_desc, grad_to_gradient_merge)
new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
cur_block._sync_with_cpp()
if avg: if avg:
for param, new_grad in new_params_to_grads: for _, new_grad in new_params_to_grads:
# grad /= k_steps # grad /= k_steps
cur_block.append_op(type='scale', cur_block.append_op(
inputs={'X': new_grad}, type='scale',
outputs={'Out': new_grad}, inputs={'X': new_grad},
attrs={ outputs={'Out': new_grad},
'scale': 1.0 / k_steps, attrs={
'bias': 0.0, 'scale': 1.0 / k_steps,
'bias_after_scale': False 'bias': 0.0,
}) 'bias_after_scale': False,
},
)
new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize) new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
cast_name_dict = {}
# append optimizer ops # append optimizer ops
for op_desc in optimize_ops_desc: for op_desc in optimize_ops_desc:
if master_grad and is_gradient_clip_op(op_desc):
if op_desc.type() == "cast":
if (
op_desc.attr('out_dtype') in [4, 22]
and op_desc.attr('in_dtype') == 5
):
cast_name_dict[
op_desc.output_arg_names()[0]
] = op_desc.input_arg_names()[0]
elif (
op_desc.attr('in_dtype') in [4, 22]
and op_desc.attr('out_dtype') == 5
):
cast_name_dict[
op_desc.output_arg_names()[0]
] = op_desc.input_arg_names()[0]
continue
for out_name in op_desc.output_arg_names():
out_var = cur_block._var_recursive(out_name)
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
_rename_arg_names(op_desc, cast_name_dict)
new_op_desc = cur_block.desc.append_op() new_op_desc = cur_block.desc.append_op()
new_op_desc.copy_from(op_desc) new_op_desc.copy_from(op_desc)
#update input/output # update input/output
for input_name in new_op_desc.input_arg_names(): _rename_arg_names(new_op_desc, grad_to_gradient_merge)
if input_name in grad_to_gradient_merge:
new_op_desc._rename_input(
input_name, grad_to_gradient_merge[input_name])
for output_name in new_op_desc.output_arg_names():
if output_name in grad_to_gradient_merge:
new_op_desc._rename_output(
output_name, grad_to_gradient_merge[output_name])
# remove op_role_var # remove op_role_var
if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()): if new_op_desc.has_attr(OP_ROLE_VAR_KEY):
new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName()) new_op_desc.remove_attr(OP_ROLE_VAR_KEY)
# op's update Grad # op's update Grad
if core.grad_var_suffix() in new_op_desc.input_arg_names(): if core.grad_var_suffix() in new_op_desc.input_arg_names():
grad_value = new_op_desc.input("Grad")[0] grad_value = new_op_desc.input("Grad")[0]
# TODO FIXME(xym) support fp16 # TODO FIXME(xym) support fp16
grad_merge_value = grad_value + '@GradientMerge' grad_merge_value = grad_value + '@MERGED'
new_op_desc.set_input("Grad", [grad_merge_value]) new_op_desc.set_input("Grad", [grad_merge_value])
main_program.global_block()._sync_with_cpp()
cur_block._sync_with_cpp() cur_block._sync_with_cpp()
# clear gradient_merge_vars # clear gradient_merge_vars
for param, new_grad in new_params_to_grads: for _, new_grad in new_params_to_grads:
layers.fill_constant(shape=new_grad.shape, layers.fill_constant(
dtype=new_grad.dtype, shape=new_grad.shape,
value=0.0, dtype=new_grad.dtype,
out=new_grad) value=0.0,
new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize) out=new_grad,
)
new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None) layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
cond_op = main_program.global_block().ops[-1] cond_op = main_program.global_block().ops[-1]
cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize) cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
def parse_program(main_program, startup_program, params_grads, k_steps, avg, def parse_program(
dist_context): main_program,
# 1 remove optimizer_op from main_program startup_program,
optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context) params_grads,
k_steps,
avg,
master_grad,
dist_context,
):
# 1 remove optimizer_op, allreduce_sum_op and scale_op from main_program
optimize_ops_desc, allreduce_sum_desc = _remove_and_get_ops(
main_program, dist_context
)
# back to block 0 # back to block 0
main_program._rollback() main_program._rollback()
# 2 append gradient merge backward op to main_program # 2 append gradient merge backward op to main_program
new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op( (
main_program, startup_program, params_grads, dist_context) new_params_to_grads,
grad_to_gradient_merge,
) = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, master_grad, dist_context
)
# 3 create gradient_merge_cond # 3 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context) cond_var = _create_gm_cond_var(main_program, k_steps, dist_context)
# 4 create ConditionalBlock and append gradient merge optimizer ops # 4 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(main_program, cond_var, _create_cond_block_and_update_optimizer(
new_params_to_grads, main_program,
grad_to_gradient_merge, cond_var,
optimize_ops_desc, k_steps, avg) params_grads,
new_params_to_grads,
grad_to_gradient_merge,
optimize_ops_desc,
allreduce_sum_desc,
k_steps,
avg,
master_grad,
)
@register_pass("auto_parallel_gradient_merge_pass") @register_pass("auto_parallel_gradient_merge_pass")
class GradientMergePass(PassBase): class GradientMergePass(PassBase):
def __init__(self): def __init__(self):
super(GradientMergePass, self).__init__() super(GradientMergePass, self).__init__()
self.set_attr("k_steps", -1) self.set_attr("k_steps", -1)
self.set_attr("avg", True) self.set_attr("avg", True)
self.set_attr("master_grad", False)
def _check_self(self): def _check_self(self):
if self.get_attr("k_steps") < 1: if self.get_attr("k_steps") < 1:
...@@ -306,10 +452,20 @@ class GradientMergePass(PassBase): ...@@ -306,10 +452,20 @@ class GradientMergePass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context): def _apply_single_impl(self, main_program, startup_program, context):
k_steps = self.get_attr("k_steps", -1) k_steps = self.get_attr("k_steps", -1)
avg = self.get_attr("avg", False) avg = self.get_attr("avg", False)
master_grad = self.get_attr("master_grad", False)
dist_context = self.get_attr("dist_context") dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads") params_grads = self.get_attr("params_grads")
# TODO(zyl): make master_grad configurable
master_grad = True
with paddle.static.program_guard(main_program, startup_program): with paddle.static.program_guard(main_program, startup_program):
parse_program(main_program, startup_program, params_grads, k_steps, parse_program(
avg, dist_context) main_program,
startup_program,
params_grads,
k_steps,
avg,
master_grad,
dist_context,
)
main_program._sync_with_cpp() main_program._sync_with_cpp()
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from logging import exception
import os import os
from paddle.fluid import core from paddle.fluid import core
...@@ -26,6 +25,7 @@ from paddle.distributed.auto_parallel.utils import ( ...@@ -26,6 +25,7 @@ from paddle.distributed.auto_parallel.utils import (
is_backward_op, is_backward_op,
is_optimize_op, is_optimize_op,
is_lr_sched_op, is_lr_sched_op,
is_fillconst_op_for_micro_batch,
) )
...@@ -38,6 +38,12 @@ __not_shape_var_type__ = [ ...@@ -38,6 +38,12 @@ __not_shape_var_type__ = [
] ]
def is_reshard_op(op):
return op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr(
'op_namescope'
)
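# Ops inserted by the resharder (e.g. the send_v2/recv_v2 and assign ops added
# during reshard) carry the op_namescope "/auto_parallel/reshard", which is
# what this helper matches on.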
@register_pass("auto_parallel_pipeline") @register_pass("auto_parallel_pipeline")
class PipelinePass(PassBase): class PipelinePass(PassBase):
def __init__(self): def __init__(self):
...@@ -59,8 +65,17 @@ class PipelinePass(PassBase): ...@@ -59,8 +65,17 @@ class PipelinePass(PassBase):
self._gen_bsz = self.get_attr("generation_batch_size") self._gen_bsz = self.get_attr("generation_batch_size")
self._program = main_program self._program = main_program
self._cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
self._nrank = len(trainer_endpoints)
# compute current pp stage
self._pp_stages = len(self._dist_context.process_meshes)
self._cur_pp_stage = self._get_pp_stage(self._cur_rank)
if self._mode == "1F1B": if self._mode == "1F1B":
raise NotImplementedError("1F1B has not been implemented") self._insert_sync_ops_for_1f1b()
self._task_1f1b()
elif self._mode == "F-Then-B": elif self._mode == "F-Then-B":
raise NotImplementedError("F-Then-B has not been implemented") raise NotImplementedError("F-Then-B has not been implemented")
elif self._mode == "stream": elif self._mode == "stream":
...@@ -103,6 +118,93 @@ class PipelinePass(PassBase): ...@@ -103,6 +118,93 @@ class PipelinePass(PassBase):
block._sync_with_cpp() block._sync_with_cpp()
def _insert_sync_ops_for_1f1b(self):
"""
This implementation is adapted from Paddle/python/paddle/fluid/optimizer.py.
The difference from 'PipelineOptimizer' is that the 'send_v2' and 'recv_v2'
ops have already been inserted into the program by 'reshard'.
"""
for block in self._program.blocks:
offset = 0
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if is_optimize_op(op):
first_optimize_index = index
break
# insert sync ops
for index, op in enumerate(list(block.ops)):
if op.type == 'send_v2':
# step1: set 'use_calc_stream' False
op._set_attr("use_calc_stream", False)
op_role = op.attr('op_role')
ring_id = op.attr('ring_id')
# step2: insert 'c_sync_calc_stream' op before 'send_v2' op
var_name = op.input_arg_names[0]
var = block.var(var_name)
block._insert_op_without_sync(
index=index + offset,
type="c_sync_calc_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': op_role},
)
offset += 1
# step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
# before the first optimize op
if int(op_role) == int(OpRole.Backward):
index = first_optimize_index + offset
new_op_role = OpRole.Optimize
else:
index = index + offset + 1
new_op_role = OpRole.Backward
sync_comm_op = block._insert_op_without_sync(
index=index,
type="c_sync_comm_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
'op_role': new_op_role,
'ring_id': ring_id,
},
)
# step4: If the 'send_v2' op is in the forward pass, set 'pipeline_flag' to distinguish
# whether the 'c_sync_comm_stream' op was inserted for the pipeline.
if int(op_role) == int(OpRole.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
offset += 1
block._sync_with_cpp()
offset = 0
backward_recv_index = None
for index, op in enumerate(block.ops):
if op.type == "recv_v2" and is_backward_op(op):
backward_recv_index = index
break
if backward_recv_index is None:
continue
# replace 'c_sync_comm_stream' op with 'nop' op
for index, op in enumerate(list(block.ops)):
if index >= backward_recv_index:
break
if op.type == 'c_sync_comm_stream' and op.has_attr(
'pipeline_flag'
):
var_name = op.output_arg_names[0]
var = block.var(var_name)
block._remove_op(index + offset, sync=False)
offset -= 1
block._insert_op_without_sync(
index=backward_recv_index,
type="nop",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': OpRole.Backward},
)
block._sync_with_cpp()
def _create_param(self, dst_block, src_var): def _create_param(self, dst_block, src_var):
copied_kwargs = {} copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable copied_kwargs['trainable'] = src_var.trainable
...@@ -190,16 +292,185 @@ class PipelinePass(PassBase): ...@@ -190,16 +292,185 @@ class PipelinePass(PassBase):
break break
return pp_idx return pp_idx
def _task_stream(self): def _task_1f1b(self):
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0)) # create fwd, bwd, opt program with op_role
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',') num_of_functionality = 4
nrank = len(trainer_endpoints) lr_prog = Program()
num_of_functionality = 5 fwd_prog = Program()
bwd_prog = Program()
opt_prog = Program()
for idx, src_block in enumerate(self._program.blocks):
if idx == 0:
lr_block = lr_prog.block(0)
fwd_block = fwd_prog.block(0)
bwd_block = bwd_prog.block(0)
opt_block = opt_prog.block(0)
else:
lr_block = lr_prog._create_block(
parent_idx=src_block.parent_idx
)
fwd_block = fwd_prog._create_block(
parent_idx=src_block.parent_idx
)
bwd_block = bwd_prog._create_block(
parent_idx=src_block.parent_idx
)
opt_block = opt_prog._create_block(
parent_idx=src_block.parent_idx
)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
# split the program based on the op_role
for op in src_block.ops:
if is_lr_sched_op(op):
self._create_program(src_block, lr_block, op)
if is_forward_op(op) or is_fillconst_op_for_micro_batch(op):
self._create_program(src_block, fwd_block, op)
elif is_backward_op(op):
self._create_program(src_block, bwd_block, op)
elif is_optimize_op(op):
self._create_program(src_block, opt_block, op)
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of LRSched, Forward, Backward or Optimizer."
)
# compute current pp stage lr_prog._sync_with_cpp()
pp_stages = len(self._dist_context.process_meshes) fwd_prog._sync_with_cpp()
cur_pp_stage = self._get_pp_stage(cur_rank) bwd_prog._sync_with_cpp()
opt_prog._sync_with_cpp()
lr_prog._rollback()
fwd_prog._rollback()
bwd_prog._rollback()
opt_prog._rollback()
# Create task nodes.
lr_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=lr_prog,
task_id=int(self._cur_rank * num_of_functionality + 0),
node_type="Amplifier",
lazy_initialize=True,
)
lr_task_node.set_run_pre_steps(self._acc_steps)
fwd_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=fwd_prog,
task_id=int(self._cur_rank * num_of_functionality + 1),
node_type="Compute",
lazy_initialize=True,
)
bwd_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=bwd_prog,
task_id=int(self._cur_rank * num_of_functionality + 2),
node_type="Compute",
lazy_initialize=True,
)
opt_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=opt_prog,
task_id=int(self._cur_rank * num_of_functionality + 3),
node_type="Amplifier",
lazy_initialize=True,
)
opt_task_node.set_run_pre_steps(self._acc_steps)
opt_task_node.set_run_at_offset(self._acc_steps - 1)
task_nodes = {
"lr": lr_task_node,
"fwd": fwd_task_node,
"bwd": bwd_task_node,
"opt": opt_task_node,
}
# get upstream ranks and downstream ranks of cur_rank
up_down_streams = self._dist_context.up_down_streams
pp_upstream = up_down_streams.ups(self._cur_rank)
pp_downstream = up_down_streams.downs(self._cur_rank)
# set upstream/downstream for task_nodes of cur_rank
for i, (task_role, task_node) in enumerate(task_nodes.items()):
cur_id = int(self._cur_rank * num_of_functionality + i)
ups = []
downs = []
# set upstream/downstream and buffersize in pipeline stage
pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
prev_id = cur_id - 1
next_id = cur_id + 1
if task_role != "lr":
buf_size = pp_buff_size if task_role == "bwd" else 2
ups.append((prev_id, buf_size))
if task_role != "opt":
buf_size = pp_buff_size if task_role == "fwd" else 2
downs.append((next_id, buf_size))
# set upstream/downstream and buffersize cross pipeline stage
for upstream in pp_upstream:
upstream_id = int(upstream * num_of_functionality + i)
if task_role == "fwd":
if upstream != -1:
ups.append((upstream_id, 2))
elif task_role == "bwd":
if upstream != -1:
downs.append((upstream_id, 2))
for downstream in pp_downstream:
downstream_id = int(downstream * num_of_functionality + i)
if task_role == "fwd":
if downstream != -1:
downs.append((downstream_id, 2))
elif task_role == "bwd":
if downstream != -1:
ups.append((downstream_id, 2))
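# Example (illustration only): with 4 pipeline stages, pp_buff_size is 4 on
# stage 0 and 1 on the last stage, so the fwd->bwd edge of an early stage can
# buffer several in-flight micro-batches while later stages run 1F1B with a
# shallower buffer; all other edges use a buffer size of 2.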
for up in ups:
print(
"Task:",
cur_id,
"'s upstream includes:",
up[0],
", buffer size is:",
up[1],
)
task_node.add_upstream_task(up[0], up[1])
for down in downs:
print(
"Task:",
cur_id,
"'s downstream includes:",
down[0],
", buffer size is:",
down[1],
)
task_node.add_downstream_task(down[0], down[1])
# record global message: task_id_to_rank
task_id_to_rank = {}
for i in range(self._nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
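# e.g. with 2 ranks and num_of_functionality == 4, task ids 0-3 (lr/fwd/bwd/opt)
# map to rank 0 and task ids 4-7 map to rank 1, matching the
# rank * num_of_functionality + offset scheme used when the TaskNodes were built.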
self._program._pipeline_opt = {}
self._program._pipeline_opt['fleet_opt'] = {
"tasks": list(task_nodes.values()),
"task_id_to_rank": task_id_to_rank,
"num_micro_batches": self._acc_steps,
}
def _task_stream(self):
num_of_functionality = 5
start_prog = Program() start_prog = Program()
cond_prog = Program() cond_prog = Program()
end_prog = Program() end_prog = Program()
...@@ -207,6 +478,7 @@ class PipelinePass(PassBase): ...@@ -207,6 +478,7 @@ class PipelinePass(PassBase):
recv_prog = Program() recv_prog = Program()
cond_var_name = None cond_var_name = None
# record the varnames that relate to the while cond var and are communicated by nccl
send_vars_name = set() send_vars_name = set()
recv_vars_name = dict() recv_vars_name = dict()
for ib, src_block in enumerate(self._program.blocks): for ib, src_block in enumerate(self._program.blocks):
...@@ -231,38 +503,23 @@ class PipelinePass(PassBase): ...@@ -231,38 +503,23 @@ class PipelinePass(PassBase):
src_block, end_block, op, force_create=True src_block, end_block, op, force_create=True
) )
elif ib == 1: elif ib == 1:
# NOTE: The while block will be split into two separate blocks.
# The send_block:
#     includes all ops for transformer generation,
#     excluding the nccl ops for the while cond var.
# The recv_block:
#     includes all ops for the while cond var,
#     excluding the nccl ops for the while cond var.
# The nccl ops for the cond var:
#     their varnames are put in the task node and communicated by brpc.
send_block = send_prog.block(0)
recv_block = recv_prog.block(0)
is_after_send_op = False
is_after_recv_op = False
for i, op in enumerate(src_block.ops):
if op.type == "send_v2" and not is_after_send_op:
is_after_send_op = True
if cur_pp_stage == pp_stages - 1:
if op.type in ["c_sync_calc_stream", "nop"]:
continue
if (
op.type not in ["recv_2", "assign"]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
):
if (
len(op.desc.input_arg_names()) > 0
and "@RESHARD"
not in op.desc.input_arg_names()[0]
):
send_vars_name.add(
op.desc.input_arg_names()[0]
)
continue
if op.type == "send_v2":
continue
self._create_program(
src_block, send_block, op, force_create=True
)
continue
if (
is_after_send_op
@@ -270,45 +527,21 @@ class PipelinePass(PassBase):
and op.type == "recv_v2"
):
is_after_recv_op = True
if op.has_attr(
'op_namescope'
) and "/auto_parallel/reshard" in op.attr(
'op_namescope'
):
var_name = op.desc.output_arg_names()[0]
index = var_name.find("@")
if index > 0:
old_var_name = var_name[:index]
else:
old_var_name = var_name
recv_vars_name[var_name] = old_var_name
if not src_block._find_var_recursive(old_var_name):
src_var = src_block._var_recursive(var_name)
recv_block.create_var(
type=src_var.type,
name=old_var_name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
continue
self._create_program(
src_block, recv_block, op, force_create=True
)
continue
if not is_after_send_op or not is_after_recv_op:
if self._cur_pp_stage == self._pp_stages - 1:
# the c_sync_calc_stream for c_allgather cannot be removed
if (
op.type == "c_sync_calc_stream"
and src_block.ops[i + 1].type == "send_v2"
):
continue
if op.type == "nop":
continue
# HACKCODE: the varname of send_v2 op, cast op should be recorded for brpc comm
if (
op.type
not in ["recv_2", "assign", "c_allgather"]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
@@ -327,13 +560,16 @@ class PipelinePass(PassBase):
self._create_program(
src_block, send_block, op, force_create=True
)
continue
if is_after_send_op and is_after_recv_op:
# HACKCODE: the varname of recv_v2 op, assign op should be recorded for brpc comm
if op.has_attr(
'op_namescope'
) and "/auto_parallel/reshard" in op.attr(
'op_namescope'
):
# remove the suffix of "@RESHARD"
var_name = op.desc.output_arg_names()[0]
index = var_name.find("@")
if index > 0:
@@ -365,6 +601,7 @@ class PipelinePass(PassBase):
self._create_program(
src_block, recv_block, op, force_create=True
)
continue
else:
raise Exception("Only support generation condition.")
@@ -406,52 +643,52 @@ class PipelinePass(PassBase):
vars_to_shape = recv_task_node_var_shape
start_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Start",
task_id=int(self._cur_rank * num_of_functionality + 0),
program=start_prog,
lazy_initialize=True,
)
cond_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Cond",
task_id=int(self._cur_rank * num_of_functionality + 1),
program=cond_prog,
cond_var_name=cond_var_name,
lazy_initialize=True,
)
send_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(self._cur_rank * num_of_functionality + 2),
program=send_prog,
lazy_initialize=True,
)
recv_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(self._cur_rank * num_of_functionality + 3),
program=recv_prog,
lazy_initialize=True,
vars_to_dtype=vars_to_dtype,
vars_to_shape=vars_to_shape,
)
end_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(self._cur_rank * num_of_functionality + 4),
program=end_prog,
lazy_initialize=True,
)
# add dependencies for task nodes intra stage
inf = -1
pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
start_task_node.add_downstream_task(
cond_task_node.task_id(), self._gen_bsz
)
@@ -560,12 +797,12 @@ class PipelinePass(PassBase):
# add dependencies for task nodes inter stage
# get upstream ranks and downstream ranks of cur_rank
up_down_streams = self._dist_context.up_down_streams
pp_upstream = up_down_streams.ups(self._cur_rank)
pp_downstream = up_down_streams.downs(self._cur_rank)
for upstream_rank in pp_upstream:
upstream_pp_stage = self._get_pp_stage(upstream_rank)
if upstream_pp_stage < self._pp_stages - 1:
upstream_task_id = int(upstream_rank * num_of_functionality + 2)
send_task_node.add_upstream_task(upstream_task_id)
print(
@@ -587,8 +824,8 @@ class PipelinePass(PassBase):
", buffer size is:",
2,
)
for downstream_rank in pp_downstream:
if self._cur_pp_stage < self._pp_stages - 1:
downstream_task_id = int(
downstream_rank * num_of_functionality + 2
)
@@ -616,7 +853,7 @@ class PipelinePass(PassBase):
)
task_id_to_rank = {}
for i in range(self._nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
self._program._pipeline_opt = {
......
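Before the new test files, here is a minimal, self-contained sketch of the two-flag split that _task_stream applies to the while block above. It is illustrative only: it works on a plain list of op type names and ignores the reshard filtering and the var-name bookkeeping for brpc that the real pass performs; split_while_block and the example op list are made up for the illustration.

# Minimal sketch of the is_after_send_op / is_after_recv_op partition used above.
def split_while_block(op_types):
    send_ops, recv_ops = [], []
    is_after_send_op = False
    is_after_recv_op = False
    for op_type in op_types:
        if op_type == "send_v2" and not is_after_send_op:
            is_after_send_op = True
        if is_after_send_op and not is_after_recv_op and op_type == "recv_v2":
            is_after_recv_op = True
        if not is_after_send_op or not is_after_recv_op:
            send_ops.append(op_type)  # generation part of the while body
        else:
            recv_ops.append(op_type)  # cond-var part of the while body
    return send_ops, recv_ops


if __name__ == "__main__":
    ops = ["matmul_v2", "softmax", "send_v2", "recv_v2", "assign", "less_than"]
    print(split_while_block(ops))
    # (['matmul_v2', 'softmax', 'send_v2'], ['recv_v2', 'assign', 'less_than'])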
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random
import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
def apply_pass(use_1f1b=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_1f1b:
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2
else:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 2
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
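A note on the two configurations this helper returns: the test compares a naive_pp + gradient_merge run (k_steps = 2) against a 1F1B run (accumulate_steps = 2). If, as assumed in the sketch below, both settings split each fed batch into the same number of micro-batches, every optimizer update consumes the same data, which is what makes the loss curves comparable. The arithmetic is only an assumption about how the engine splits the batch, not code from the engine itself.

# Assumed batch arithmetic for the two configurations compared in this test.
batch_size = 2        # per-step batch used by the test below
accumulate_steps = 2  # 1F1B micro-batches / gradient-merge k_steps

micro_batch_size = batch_size // accumulate_steps  # assumed even split: 1
assert micro_batch_size * accumulate_steps == batch_size  # same data per optimizer update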
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Test1F1BPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
paddle.distributed.fleet.init(is_collective=True)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_1f1b=False):
reset_prog()
strategy = apply_pass(use_1f1b)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("pp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!, \nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def test_1f1b_pass(self):
# naive_pp + gradient_merge training
engine_pp = self.get_engine()
history_pp = engine_pp.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_pp._strategy.pipeline.enable == False
# pp2 1f1b training
engine_1f1b = self.get_engine(True)
history_1f1b = engine_1f1b.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_1f1b._strategy.pipeline.enable == True
# NOTE: every sample data from dataset is all the same
if paddle.distributed.get_rank() == 1:
losses_pp = np.array(history_pp.history["loss"])
losses_1f1b = np.array(history_1f1b.history["loss"])
self.check_results(losses_pp, losses_1f1b)
if __name__ == "__main__":
unittest.main()
@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
......
@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
modeling._global_parallel_strategy = "mp"
elif strategy == "dp":
modeling._global_parallel_strategy = "dp"
elif strategy == "pp":
modeling._global_parallel_strategy = "pp"
modeling.PP_MESH_LIST = [
auto.ProcessMesh(mesh=[0]),
auto.ProcessMesh(mesh=[1]),
]
else:
raise ValueError("Only support serial, mp2 and dp2.")
@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=2 if strategy == "pp" else None,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
......
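The pp branch above registers one ProcessMesh per pipeline stage. Below is a minimal sketch of how such a two-entry mesh list can be used to place layers onto the two stages; stage_of_layer, mesh_for_layer, and the even split are assumptions for illustration, not code taken from modeling.py.

from paddle.distributed.fleet import auto

# Two one-rank meshes, one per pipeline stage, mirroring PP_MESH_LIST above.
PP_MESH_LIST = [auto.ProcessMesh(mesh=[0]), auto.ProcessMesh(mesh=[1])]


def stage_of_layer(layer_idx, num_layers, num_stages=len(PP_MESH_LIST)):
    # Assumed placement policy: split the layers evenly across the stages.
    return min(layer_idx * num_stages // num_layers, num_stages - 1)


def mesh_for_layer(layer_idx, num_layers):
    return PP_MESH_LIST[stage_of_layer(layer_idx, num_layers)]


if __name__ == "__main__":
    print([stage_of_layer(i, 4) for i in range(4)])  # [0, 0, 1, 1]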
@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
history = dp_engine.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert dp_engine._strategy.gradient_merge.enable == False
dp_losses = np.array(history.history["loss"])
# dp2 gradient merge training
@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
history = gm_engine.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert gm_engine._strategy.gradient_merge.enable == True
gm_losses = np.array(history.history["loss"])
# avg_loss = 0
......
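For reference, the semantics that gradient_merge.k_steps = 4 with avg = True targets: gradients from k consecutive micro-steps are accumulated, optionally averaged, and the optimizer update runs once per k steps. The snippet below is a framework-free conceptual sketch of that rule (plain floats, a made-up update step), not the pass's actual static-graph implementation.

# Conceptual sketch of gradient merge: accumulate k_steps gradients, then update once.
def merged_update(param, micro_grads, k_steps=4, avg=True, lr=1e-5):
    assert len(micro_grads) == k_steps
    merged = sum(micro_grads)
    if avg:
        merged /= k_steps
    return param - lr * merged  # single optimizer step per k_steps micro-steps


if __name__ == "__main__":
    print(merged_update(1.0, [0.1, 0.2, 0.3, 0.4]))  # 1.0 - 1e-5 * 0.25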
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class Test1F1BPass(unittest.TestCase):
def test_pp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()
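For manual debugging, the two-rank run this test spawns is roughly equivalent to executing python -u -m paddle.distributed.launch --devices 0,1 --log_dir <log_dir> 1F1B_pass_unittest.py by hand, where <log_dir> stands in for the temporary directory the test creates; the coverage arguments are only inserted when WITH_COVERAGE=ON.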