Unverified commit c47853f6, authored by zhaoyingli, committed by GitHub

[AutoParallel] add gradient_merge master_grad & 1F1B pass (#52647)

Parent 6f3c9643
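For orientation before the per-file hunks: the features this commit adds (gradient merge with master grad, the 1F1B pipeline schedule, and a configurable data-parallel optimization pass) are all driven through the auto-parallel `Strategy`. The sketch below is a hedged illustration, assuming the `fleet.auto` aliases available in this Paddle release; `model`, `loss_fn`, `optimizer`, and `train_dataset` are placeholders supplied by the caller.

# Hedged sketch, not part of the commit: how the options touched by this
# change might be enabled through the auto-parallel Strategy/Engine API.
from paddle.distributed.fleet import auto


def build_and_train(model, loss_fn, optimizer, train_dataset):
    strategy = auto.Strategy()

    # 1F1B pipeline schedule handled by the new pass logic below
    strategy.pipeline.enable = True
    strategy.pipeline.schedule_mode = "1F1B"
    strategy.pipeline.accumulate_steps = 4

    # alternatively, gradient merge accumulation; when enabled, its k_steps
    # takes precedence for the accumulation count (see the engine.py hunk below)
    # strategy.gradient_merge.enable = True
    # strategy.gradient_merge.k_steps = 4

    # new data-parallel optimization section added in constants.py below
    strategy.dp_optimization.enable = True
    strategy.dp_optimization.fuse_all_reduce_ops = True

    engine = auto.Engine(model, loss_fn, optimizer, strategy=strategy)
    engine.fit(train_dataset, batch_size=16, epochs=1)
    return engine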
......@@ -110,12 +110,15 @@ void PreventVarsDelete(
std::vector<std::string> GetUnusedVarsAfterWhile(
const framework::ProgramDesc& program_desc,
TaskNode* cond_task,
const std::vector<std::string>& vars_not_gc) {
// NOTE: Since the while op won't appear in a task node, in order to analyze
// the vars which should be freed after calling the while op, we rebuild the
// whole program and get the unused vars after calling the while op.
// Vars in the parent block should not be freed until the while op is finished.
// The local vars will be freed while running ops in the sub block.
// The vars in the while block should not be freed until the while op is finished.
// In a word, the vars that need to be freed after the while op are:
// 1. Vars in the parent block that are used in the while block.
// 2. Local vars only defined in the while block.
// The unused vars above will be freed in the cond interceptor.
std::vector<std::string> while_block_vars;
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
......@@ -129,29 +132,14 @@ std::vector<std::string> GetUnusedVarsAfterWhile(
for (const auto& var_name : pair.second) {
while_block_vars.emplace_back(var_name);
}
for (auto& var : program_desc.Block(1).AllVars()) {
while_block_vars.emplace_back(var->Name());
}
}
}
return while_block_vars;
}
std::unordered_map<const framework::OperatorBase*, std::vector<std::string>>
GetSubUnusedVars(const framework::ProgramDesc& program_desc,
const std::set<TaskNode*>& sub_block_tasks,
const std::vector<std::string>& vars_not_gc) {
std::vector<std::unique_ptr<framework::OperatorBase>> ops;
for (auto* task_node : sub_block_tasks) {
for (const auto& op : task_node->ops()) {
ops.emplace_back(std::unique_ptr<framework::OperatorBase>(op));
}
}
auto unused_vars = framework::GetUnusedVars(program_desc.Block(1), ops, {});
for (auto& unique_op : ops) {
unique_op.release();
}
PreventVarsDelete(&unused_vars, vars_not_gc);
return unused_vars;
}
} // namespace
void FleetExecutor::Init(
......@@ -174,13 +162,8 @@ void FleetExecutor::Init(
for (const auto& task_node : task_nodes) {
if (task_node->type() == "Cond") {
GetSubBlockTask(task_nodes, task_node, &sub_block_tasks);
while_block_vars =
GetUnusedVarsAfterWhile(program_desc, inference_root_scope_vars);
for (auto* task_node : sub_block_tasks) {
for (auto iter : task_node->vars_to_dtype()) {
while_block_vars.emplace_back(iter.first);
}
}
while_block_vars = GetUnusedVarsAfterWhile(
program_desc, task_node, inference_root_scope_vars);
VLOG(3) << "Vars will be gced after while op";
for (auto var : while_block_vars) {
VLOG(3) << var;
......@@ -210,9 +193,6 @@ void FleetExecutor::Init(
unique_op.release();
}
auto sub_unused_vars =
GetSubUnusedVars(program_desc, sub_block_tasks, while_block_vars);
// NOTE: For inference, the vars in inference_root_scope_vars
// shouldn't be deleted during inference, since they may be the result of the
// inference. If they are GCed, it will cause an error when ZeroCopying the result.
......@@ -223,8 +203,6 @@ void FleetExecutor::Init(
for (auto task_node : task_nodes) {
if (sub_block_tasks.find(task_node) == sub_block_tasks.end()) {
task_node->SetUnusedVars(global_unused_vars);
} else {
task_node->SetUnusedVars(sub_unused_vars);
}
int64_t interceptor_id = task_node->task_id();
interceptor_id_to_task.emplace(interceptor_id, task_node);
......@@ -117,9 +117,9 @@ set_field_default_config(QAT, "activation_bits", 8)
set_field_default_config(QAT, "not_quant_pattern", ['skip_quant'])
set_field_default_config(QAT, "algo", None)
# #########################################
#########################################
# auto tuning configuration
# #########################################
#########################################
TUNING = "tuning"
set_field_default_config(TUNING, "enable", False)
set_field_default_config(TUNING, "batch_size", 1)
......@@ -135,3 +135,12 @@ set_field_default_config(TUNING, "verbose", True)
DATASET = "dataset"
set_field_default_config(DATASET, "enable", False)
set_field_default_config(DATASET, "num_shards", 1)
#########################################
# data parallel configuration
#########################################
DP_OPTIMIZATION = "dp_optimization"
set_field_default_config(DP_OPTIMIZATION, "enable", False)
set_field_default_config(DP_OPTIMIZATION, "fuse_all_reduce_ops", True)
set_field_default_config(DP_OPTIMIZATION, "fuse_grad_size_in_MB", 32)
set_field_default_config(DP_OPTIMIZATION, "overlap_comm_cacl", True)
......@@ -17,12 +17,18 @@ import numpy as np
import paddle
from paddle.io import BatchSampler, IterableDataset
from paddle.fluid.dataloader.batch_sampler import _InfiniteIterableSampler, DistributedBatchSampler
from paddle.fluid.dataloader.dataloader_iter import _DatasetKind, default_collate_fn, default_convert_fn
from paddle.fluid.dataloader.batch_sampler import (
_InfiniteIterableSampler,
DistributedBatchSampler,
)
from paddle.fluid.dataloader.dataloader_iter import (
_DatasetKind,
default_collate_fn,
default_convert_fn,
)
class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
@abc.abstractmethod
def __iter__(self):
raise NotImplementedError
......@@ -43,24 +49,26 @@ class DistributedDataLoaderBase(metaclass=abc.ABCMeta):
class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
def __init__(self,
dataset,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
places=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
def __init__(
self,
dataset,
feed_list=None,
capacity=None,
use_double_buffer=True,
iterable=True,
return_list=False,
use_multiprocess=False,
drop_last=True,
places=None,
batch_size=1,
epochs=1,
steps_per_epoch=None,
collate_fn=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[],
acc_steps=1,
):
self.dataset = dataset
self.feed_list = feed_list
self.capacity = capacity
......@@ -79,6 +87,7 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
assert len(data_parallel_rank) == len(feed_list)
self.dp_world_sizes = data_parallel_world_size
self.dp_ranks = data_parallel_rank
self.acc_steps = acc_steps
if isinstance(dataset, IterableDataset):
self.dataset_kind = _DatasetKind.ITER
......@@ -90,12 +99,15 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
else:
if isinstance(dataset, IterableDataset):
self.batch_sampler = _InfiniteIterableSampler(
dataset, batch_size)
dataset, batch_size
)
else:
self.batch_sampler = BatchSampler(dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last)
self.batch_sampler = BatchSampler(
dataset,
batch_size=batch_size,
shuffle=False,
drop_last=drop_last,
)
self.auto_collate_batch = self.batch_sampler is not None
self.sampler_iter = iter(self.index_sampler)
......@@ -106,8 +118,12 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
self.collate_fn = collate_fn or default_convert_fn
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset, self.auto_collate_batch,
self.collate_fn, self.drop_last)
self.dataset_kind,
self.dataset,
self.auto_collate_batch,
self.collate_fn,
self.drop_last,
)
self._steps = self._infer_steps()
self._inner_dataloader = self._create_inner_dataloader()
......@@ -136,9 +152,11 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
if isinstance(self.dataset, IterableDataset):
steps_per_epoch = None
elif self.batch_size is None:
steps_per_epoch = len(self.dataset)
steps_per_epoch = len(self.dataset) // self.acc_steps
else:
steps_per_epoch = len(self.dataset) // self.batch_size
steps_per_epoch = (
len(self.dataset) // self.batch_size // self.acc_steps
)
except:
raise ValueError(
"Pleace set `steps_per_epoch` or implement `__len__` methond in dataset class."
......@@ -156,18 +174,21 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
return _InfiniteIterableSampler(self.dataset, 1)
def _create_inner_dataloader(self):
def data_generator():
while True:
try:
indices = next(self.sampler_iter)
batch = self.dataset_fetcher.fetch(indices)
if batch is None: break
if batch is None:
break
except StopIteration:
self.dataset_fetcher = _DatasetKind.create_fetcher(
self.dataset_kind, self.dataset,
self.auto_collate_batch, self.collate_fn,
self.drop_last)
self.dataset_kind,
self.dataset,
self.auto_collate_batch,
self.collate_fn,
self.drop_last,
)
break
partial_data = []
......@@ -178,11 +199,16 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
continue
batch_size = array.shape[0]
assert batch_size % self.dp_world_sizes[i] == 0, \
"batch_size [{}] is not divisible by dp_world_size [{}]".format(str(batch_size), str(self.dp_world_sizes[i]))
assert (
batch_size % self.dp_world_sizes[i] == 0
), "batch_size [{}] is not divisible by dp_world_size [{}]".format(
str(batch_size), str(self.dp_world_sizes[i])
)
partial_data.append(
np.split(array,
self.dp_world_sizes[i])[self.dp_ranks[i]])
np.split(array, self.dp_world_sizes[i])[
self.dp_ranks[i]
]
)
yield partial_data
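The per-rank slicing above relies on `np.split` along the batch dimension; the following standalone sketch (with a made-up batch of 8 and a data-parallel world size of 2) shows which shard each rank keeps:

# Runnable illustration of the batch splitting in data_generator;
# shapes and dp settings are made up for the example.
import numpy as np

batch = np.arange(8 * 3).reshape(8, 3)  # fake feed tensor, batch_size 8
dp_world_size, dp_rank = 2, 1

assert batch.shape[0] % dp_world_size == 0
shard = np.split(batch, dp_world_size)[dp_rank]
print(shard.shape)  # (4, 3): rank 1 keeps the second half of the batch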
......@@ -194,33 +220,35 @@ class DistributedDataLoaderFromGenerator(DistributedDataLoaderBase):
iterable=False,
return_list=self.return_list,
use_multiprocess=self.use_multiprocess,
drop_last=self.drop_last)
drop_last=self.drop_last,
)
dataloader.set_batch_generator(data_generator, self.places)
return dataloader
class DistributedDataLoader(DistributedDataLoaderBase):
def __init__(self,
dataset,
feed_list=None,
places=None,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[]):
def __init__(
self,
dataset,
feed_list=None,
places=None,
return_list=True,
batch_size=1,
shuffle=False,
drop_last=False,
collate_fn=None,
num_workers=0,
use_buffer_reader=True,
use_shared_memory=True,
timeout=0,
worker_init_fn=None,
epochs=1,
steps_per_epoch=None,
split_data=True,
data_parallel_world_size=[],
data_parallel_rank=[],
):
self.dataset = dataset
self.feed_list = feed_list
self.return_list = return_list
......@@ -241,8 +269,13 @@ class DistributedDataLoader(DistributedDataLoaderBase):
self.split_data = split_data
# TODO: rank info
self.batch_sampler = DistributedBatchSampler(
self.dataset, self.batch_size, self.dp_world_sizes[0],
self.dp_ranks[0], self.shuffle, self.drop_last)
self.dataset,
self.batch_size,
self.dp_world_sizes[0],
self.dp_ranks[0],
self.shuffle,
self.drop_last,
)
self._inner_dataloader = self._create_inner_dataloader()
def __iter__(self):
......@@ -263,7 +296,8 @@ class DistributedDataLoader(DistributedDataLoaderBase):
use_buffer_reader=self.use_buffer_reader,
use_shared_memory=self.use_shared_memory,
timeout=self.timeout,
worker_init_fn=self.worker_init_fn)
worker_init_fn=self.worker_init_fn,
)
self.data = (x for x in dataloader)
return dataloader
......@@ -18,6 +18,7 @@ import errno
import pickle
import warnings
import logging
import collections
import numpy as np
import paddle
......@@ -53,16 +54,13 @@ def _process_path(path):
class DistributedSaver:
def __init__(self):
self._logger = get_logger(logging.INFO)
def save(self, path, serial_program, dist_main_program, dist_context):
def _save_state(program, path, mode="param"):
state = {
k: np.array(v)
for k, v in program.state_dict(mode).items()
k: np.array(v) for k, v in program.state_dict(mode).items()
}
with open(path, "wb") as f:
pickle.dump(state, f)
......@@ -108,8 +106,9 @@ class DistributedSaver:
def _load_file(filename, dirname, suffix="pdparams"):
file_list = []
for file in os.listdir(dirname):
if check_filename('{}(.*)_dist(.*).{}'.format(filename, suffix),
file):
if check_filename(
'{}(.*)_dist(.*).{}'.format(filename, suffix), file
):
file_list.append(os.path.join(dirname, file))
file_list.sort()
return file_list
......@@ -137,14 +136,16 @@ class DistributedSaver:
# load path.pdparam and path.pdopt
param_state_dict = _load_state(filename, dirname)
opt_state_dict = _load_state(filename, dirname,
"pdopt") if load_optimizer else {}
opt_state_dict = (
_load_state(filename, dirname, "pdopt") if load_optimizer else {}
)
state_dict = dict(param_state_dict, **opt_state_dict)
# load path.pdattr
dist_attr_file_list = _load_file(filename, dirname, "pdattr")
self._logger.info(
"Load distributed attribute file: {}".format(dist_attr_file_list))
"Load distributed attribute file: {}".format(dist_attr_file_list)
)
dist_attr = {}
for dist_attr_file in dist_attr_file_list:
with open(dist_attr_file, 'rb') as f:
......@@ -196,12 +197,24 @@ class DistributedSaver:
used_inputs += op.input_arg_names
used_outputs += op.output_arg_names
dist_feed_vars_names = list(set(feed_vars_names) & set(used_inputs))
dist_fetch_vars_names = list(set(fetch_vars_names) & set(used_outputs))
# delete duplicated elements and keep order
feed_vars_names = list({}.fromkeys(feed_vars_names).keys())
used_inputs = list({}.fromkeys(used_inputs).keys())
fetch_vars_names = list({}.fromkeys(fetch_vars_names).keys())
used_outputs = list({}.fromkeys(used_outputs).keys())
dist_feed_vars = [
global_block.vars[name] for name in dist_feed_vars_names
dist_feed_vars_names = [
var_name for var_name in feed_vars_names if var_name in used_inputs
]
dist_fetch_vars_names = [
var_name
for var_name in fetch_vars_names
if var_name in used_outputs
]
dist_feed_vars = list(
reversed([global_block.vars[name] for name in dist_feed_vars_names])
)
dist_fetch_vars = [
global_block.vars[name] for name in dist_fetch_vars_names
]
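The `{}.fromkeys(...)` idiom used above drops duplicate names while keeping their first-seen order, which a plain `set` would not guarantee; a quick standalone check:

# dict.fromkeys preserves insertion order, so duplicates are removed
# without reordering the feed/fetch variable names.
names = ["x", "label", "x", "mask", "label"]
print(list({}.fromkeys(names).keys()))  # ['x', 'label', 'mask']
print(sorted(set(names)))               # ['label', 'mask', 'x'] -- order lost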
......@@ -209,11 +222,13 @@ class DistributedSaver:
# NOTE: `paddle.static.save_inference_model` does not support subblock.
dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename)
paddle.static.save_inference_model(dist_path,
dist_feed_vars,
dist_fetch_vars,
exe,
program=dist_main_prog)
paddle.static.save_inference_model(
dist_path,
dist_feed_vars,
dist_fetch_vars,
exe,
program=dist_main_prog,
)
def _save_rank_mapping(self, dirname):
path = os.path.join(dirname, 'rank_mapping.csv')
......@@ -225,6 +225,11 @@ class Engine:
self._planned_mode = None
self._dygraph_mode = False
self._tuning = self._strategy.tuning
self._acc_steps = 1
if self._strategy.gradient_merge.enable:
self._acc_steps = self._strategy.gradient_merge.k_steps
elif self._strategy.pipeline.enable:
self._acc_steps = self._strategy.pipeline.accumulate_steps
self.history = None
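The accumulation step count is taken from gradient merge first and falls back to the pipeline setting; a minimal restatement of that precedence (the `strategy` argument stands in for an auto-parallel `Strategy` instance):

# Hedged restatement of the precedence implemented above.
def resolve_acc_steps(strategy):
    if strategy.gradient_merge.enable:
        return strategy.gradient_merge.k_steps
    if strategy.pipeline.enable:
        return strategy.pipeline.accumulate_steps
    return 1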
......@@ -388,7 +393,12 @@ class Engine:
if self.main_program._pipeline_opt:
assert "tasks" in self.main_program._pipeline_opt["fleet_opt"]
fleet_opt = self.main_program._pipeline_opt["fleet_opt"]
fwd_task = fleet_opt["tasks"][0]
fwd_task = None
if self._strategy.pipeline.schedule_mode == "1F1B":
fwd_task = fleet_opt["tasks"][1]
elif self._strategy.pipeline.schedule_mode == "stream":
fwd_task = fleet_opt["tasks"][0]
assert fwd_task is not None
fwd_prog = fwd_task.get_program()
fwd_block = fwd_prog.global_block()
......@@ -438,8 +448,6 @@ class Engine:
), "user_fetches must be a list, but receive {}".format(
type(user_fetches).__name__
)
else:
user_fetches = []
fetch_names = []
fetch_indices = []
......@@ -466,7 +474,7 @@ class Engine:
_process_fetch_group("metrics_" + str(i), var_list)
if mode == "predict":
_process_fetch_group("outputs", fetch_vars["outputs"])
for usr_fetch in user_fetches:
for usr_fetch in user_fetches or []:
var_name = _to_name_str(usr_fetch)
fetch(var_name)
user_fetches_collection = [
......@@ -903,6 +911,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
train_data, train_sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -931,7 +940,7 @@ class Engine:
save_dir=save_dir,
verbose=verbose,
metrics=self._metrics_name(),
acc_step=self._k_steps,
acc_step=self._acc_steps,
)
cbks.on_begin('train')
......@@ -965,7 +974,7 @@ class Engine:
val_logs = self.evaluate(
valid_data,
valid_sample_split,
batch_size,
batch_size * self._acc_steps,
valid_steps,
log_freq,
collate_fn,
......@@ -1046,6 +1055,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
valid_data, valid_sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1152,6 +1162,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
test_data, test_sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1214,6 +1225,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1256,6 +1268,7 @@ class Engine:
self._inputs_spec, self._labels_spec = self._prepare_data_spec(
dataset, sample_split, batch_size
)
batch_size = self._validate_batch_size(batch_size)
if not self._has_prepared[self._mode]:
self._prepare_program(self._mode)
else:
......@@ -1371,14 +1384,6 @@ class Engine:
steps_per_epoch=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
......@@ -1440,14 +1445,6 @@ class Engine:
collate_fn=None,
):
if self._strategy.gradient_merge and batch_size is not None:
assert (
batch_size % self._k_steps == 0
), "Requires batch_size:[{}] to be divisible by k_steps:[{}].".format(
batch_size, self._k_steps
)
batch_size //= self._k_steps
dist_context = self._dist_contexts[self._mode]
dist_main_prog = dist_context.dist_main_programs[self._cur_rank]
dist_startup_prog = dist_context.dist_startup_programs[self._cur_rank]
......@@ -1487,6 +1484,9 @@ class Engine:
split_data=self._strategy.split_data,
data_parallel_world_size=self._dp_world_sizes,
data_parallel_rank=self._dp_ranks,
acc_steps=1
if not self._strategy.pipeline.enable
else self._acc_steps,
)
self._prepare_reader(feed_list)
return dataloader
......@@ -1498,9 +1498,18 @@ class Engine:
)
self._optimization_tuning(self._mode, tune_data, batch_size)
def _validate_batch_size(self, batch_size):
if batch_size is None:
return None
assert (
batch_size % self._acc_steps == 0
), "Requires batch_size:[{}] to be divisible by acc_steps:[{}].".format(
batch_size, self._acc_steps
)
return batch_size // self._acc_steps
def _validate_spec(self, specs):
specs = to_list(specs)
self._k_steps = self._strategy.gradient_merge.k_steps
if specs is not None:
for i, spec in enumerate(specs):
if not isinstance(spec, InputSpec):
......@@ -1513,14 +1522,14 @@ class Engine:
i, spec
)
)
if self._k_steps > 1:
if self._acc_steps > 1:
shape = list(spec.shape)
assert (
shape[0] % self._k_steps == 0
shape[0] % self._acc_steps == 0
), "Requires batch_size[{}] to be divisible by k_steps[{}].".format(
spec.shape[0], self._k_steps
spec.shape[0], self._acc_steps
)
shape[0] //= self._k_steps
shape[0] //= self._acc_steps
spec.shape = shape
return specs or []
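Both `_validate_batch_size` and `_validate_spec` divide the user-facing sizes by `acc_steps`; a worked example with assumed numbers shows the per-step values that result:

# Hedged restatement of the acc_steps division with illustrative numbers.
def per_step_sizes(batch_size, spec_batch_dim, acc_steps):
    assert batch_size % acc_steps == 0
    assert spec_batch_dim % acc_steps == 0
    return batch_size // acc_steps, spec_batch_dim // acc_steps


# global batch 64, InputSpec batch dim 64, acc_steps 4 -> 16 per step
print(per_step_sizes(64, 64, 4))  # (16, 16)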
......@@ -297,13 +297,15 @@ class Parallelizer:
if self._strategy is None:
return
# data parallel optimization
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.dp_optimization.enable:
config = copy.deepcopy(self._strategy.dp_optimization.to_dict())
config["dist_context"] = self._dist_context
config["global_rank"] = rank
config["use_sharding"] = self._strategy.sharding.enable
dp_pass = new_pass(
"auto_parallel_data_parallel_optimization", config
)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding.enable:
config = copy.deepcopy(self._strategy.sharding.to_dict())
......@@ -13,24 +13,25 @@
# limitations under the License
import copy
import numpy as np
import paddle
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid import framework as framework
from paddle.fluid import core, unique_name
from paddle.fluid.framework import Program, Parameter, Variable, program_guard
from paddle.distributed.auto_parallel.operators.common import get_distributed_operator_impl_container
from paddle.distributed.auto_parallel.dist_context import DistributedContext, DistributedOperatorContext
from paddle.fluid.framework import Program, Parameter, core
from paddle.distributed.auto_parallel.operators.common import (
get_distributed_operator_impl_container,
)
from paddle.distributed.auto_parallel.dist_context import DistributedContext
from .dist_attribute import OperatorDistributedAttribute
from .process_group import new_process_group
from .utils import set_dist_op_desc_original_id
from .utils import print_program_with_dist_attr, is_forward_op, is_backward_op, is_loss_op, is_optimize_op
from .utils import (
is_forward_op,
is_backward_op,
is_loss_op,
is_optimize_op,
is_fillconst_op_for_micro_batch,
)
from .operators.common import BACKWARD_ONLY_DIST_OPS
__varname_not_in_block__ = ["lod_tensor_blocking_queue"]
__not_shape_var_type__ = [
core.VarDesc.VarType.READER, core.VarDesc.VarType.STEP_SCOPES
core.VarDesc.VarType.READER,
core.VarDesc.VarType.STEP_SCOPES,
]
......@@ -39,7 +40,7 @@ class Partitioner(object):
warning:: Partitioner is experimental and subject to change.
Partitioner converts a program into another program.
Given a serial program which has been auto-completed with shard annotations, the Partitioner
converts the serial program into a "distributed" program. The Partitioner will modify the serial
program in the following two ways, which is also the major difference between the serial and distributed program:
1. partition op: replace a serial op with its corresponding dist op inferred from the shard annotation
......@@ -56,25 +57,29 @@ class Partitioner(object):
"""
if not isinstance(dist_context, DistributedContext):
raise TypeError(
"dist_context be paddle.fluid.DistributedContext, got %s here" %
type(dist_context))
"dist_context be paddle.fluid.DistributedContext, got %s here"
% type(dist_context)
)
self._dist_context = dist_context
self._rank_id = rank_id
self._serial2dist_varname_mapping = {}
self._dist_varname_suffix = ""
def partition(self, serial_main_program, serial_startup_program,
params_grads):
def partition(
self, serial_main_program, serial_startup_program, params_grads
):
if not isinstance(serial_main_program, (Program)):
raise TypeError(
"main_program be paddle.fluid.framework.program, got %s here" %
type(serial_main_program))
"main_program be paddle.fluid.framework.program, got %s here"
% type(serial_main_program)
)
# check if shard annotated serial program valid
if not self._is_valid_annotated_program(serial_main_program):
raise RuntimeError(
"Not all vars or ops are annotated in main program !")
"Not all vars or ops are annotated in main program !"
)
# init distop helper
dist_op_context = self._dist_context.dist_op_context
......@@ -86,24 +91,33 @@ class Partitioner(object):
partitioned_startup_prog = None
else:
partitioned_startup_prog = self.partition_startup_program(
serial_main_program, serial_startup_program)
serial_main_program, serial_startup_program
)
dist_op_context.dst_startup_program = partitioned_startup_prog
# partition main program
partitioned_main_prog, partitioned_params_grads = self.partition_main_program(
serial_main_program, params_grads)
(
partitioned_main_prog,
partitioned_params_grads,
) = self.partition_main_program(serial_main_program, params_grads)
return partitioned_main_prog, partitioned_startup_prog, partitioned_params_grads
return (
partitioned_main_prog,
partitioned_startup_prog,
partitioned_params_grads,
)
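The reformatted `partition` call above still returns the (main program, startup program, params_grads) triple; a hedged usage sketch, assuming the module path of the file being edited and that the inputs come from the auto-parallel completion step:

# Hedged usage sketch for the Partitioner; all arguments are assumed to be
# produced earlier in the auto-parallel pipeline (completion, backward, etc.).
from paddle.distributed.auto_parallel.partitioner import Partitioner


def partition_for_rank(
    serial_main_prog, serial_startup_prog, params_grads, dist_context, rank_id
):
    partitioner = Partitioner(dist_context, rank_id)
    dist_main, dist_startup, dist_params_grads = partitioner.partition(
        serial_main_prog, serial_startup_prog, params_grads
    )
    return dist_main, dist_startup, dist_params_grads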
def partition_startup_program(self, serial_main_program,
serial_startup_program):
def partition_startup_program(
self, serial_main_program, serial_startup_program
):
if not isinstance(serial_startup_program, (Program)):
raise TypeError(
"dist_context be paddle.fluid.framework.program, got %s here" %
type(serial_startup_program))
"dist_context be paddle.fluid.framework.program, got %s here"
% type(serial_startup_program)
)
partitioned_startup_prog = fluid.Program()
partitioned_startup_prog = Program()
ref_block = serial_main_program.global_block()
target_block = partitioned_startup_prog.global_block()
var2shape = {}
......@@ -114,27 +128,33 @@ class Partitioner(object):
assert var.persistable
new_name = var.name + self._dist_varname_suffix
temp_varname_map[var.name] = new_name
target_shape = _partition_var(self._dist_context, ref_block,
target_block, var.name, new_name)
target_shape = _partition_var(
self._dist_context, ref_block, target_block, var.name, new_name
)
var2shape[new_name] = target_shape
# ops
for op in serial_startup_program.global_block().ops:
# TODO if var not belong to this rank, should be filtered
output_vars = op.desc.output_arg_names()
assert len(
output_vars
) == 1, "initializer should output only ONE variable, but got [{}]".format(
str(op.desc))
assert temp_varname_map[output_vars[
0]] in var2shape, "try to initialize [{}] which is not a persistable var".format(
output_vars[0])
assert (
len(output_vars) == 1
), "initializer should output only ONE variable, but got [{}]".format(
str(op.desc)
)
assert (
temp_varname_map[output_vars[0]] in var2shape
), "try to initialize [{}] which is not a persistable var".format(
output_vars[0]
)
new_op_desc = target_block.desc.append_op()
new_op_desc.copy_from(op.desc)
new_op_desc._rename_output(output_vars[0],
temp_varname_map[output_vars[0]])
new_op_desc._set_attr("shape",
var2shape[temp_varname_map[output_vars[0]]])
new_op_desc._rename_output(
output_vars[0], temp_varname_map[output_vars[0]]
)
new_op_desc._set_attr(
"shape", var2shape[temp_varname_map[output_vars[0]]]
)
target_block._sync_with_cpp()
# set distributed attribute
......@@ -142,14 +162,17 @@ class Partitioner(object):
assert new_op.type == new_op_desc.type()
assert new_op.desc == new_op_desc
output_var = target_block.var(output_vars[0])
output_var_attr = self._dist_context.get_tensor_dist_attr_for_program(
output_var)
output_var_attr = (
self._dist_context.get_tensor_dist_attr_for_program(output_var)
)
op_attr = OperatorDistributedAttribute()
op_attr.process_mesh = output_var_attr.process_mesh
op_attr.set_output_dims_mapping(output_var.name,
output_var_attr.dims_mapping)
op_attr.set_input_dims_mapping(output_var.name,
output_var_attr.dims_mapping)
op_attr.set_output_dims_mapping(
output_var.name, output_var_attr.dims_mapping
)
op_attr.set_input_dims_mapping(
output_var.name, output_var_attr.dims_mapping
)
self._dist_context.set_op_dist_attr_for_program(new_op, op_attr)
return partitioned_startup_prog
......@@ -160,7 +183,7 @@ class Partitioner(object):
2. replace local op with corresponding dist op
"""
partitioned_main_prog = fluid.Program()
partitioned_main_prog = Program()
dist_op_context = self._dist_context.dist_op_context
dist_op_context.dst_main_program = partitioned_main_prog
......@@ -171,7 +194,8 @@ class Partitioner(object):
target_block = partitioned_main_prog.blocks[0]
else:
target_block = partitioned_main_prog._create_block(
parent_idx=ref_block.parent_idx)
parent_idx=ref_block.parent_idx
)
assert ref_block.idx == target_block.idx
target_block._set_forward_block_idx(ref_block.forward_block_idx)
dist_op_context.work_block = target_block
......@@ -186,8 +210,9 @@ class Partitioner(object):
for attr_name in op.all_attrs():
if op.attr_type(attr_name) == core.AttrType.BLOCK:
relative_id = op._block_attr_id(attr_name)
op._set_attr(attr_name,
partitioned_main_prog.block(relative_id))
op._set_attr(
attr_name, partitioned_main_prog.block(relative_id)
)
partitioned_params_and_grads = []
for p, g in params_and_grads:
......@@ -198,7 +223,8 @@ class Partitioner(object):
else:
assert g.name in self._serial2dist_varname_mapping
dist_g = self._get_dist_var_by_serial_var(
g, partitioned_main_prog)
g, partitioned_main_prog
)
partitioned_params_and_grads.append((dist_p, dist_g))
return partitioned_main_prog, partitioned_params_and_grads
......@@ -222,71 +248,116 @@ class Partitioner(object):
for idx in range(len(serial_ops)):
if idx <= last_fwd_op_idx:
forward_op_id2forward_op[
serial_ops[idx].desc.original_id()] = serial_ops[idx]
serial_ops[idx].desc.original_id()
] = serial_ops[idx]
# partition
appended_grad_times = 0
for idx, op in enumerate(serial_ops):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
if is_backward_op(op) and (is_forward_op(serial_ops[idx - 1])
or is_loss_op(serial_ops[idx - 1])):
if is_backward_op(op) and (
is_forward_op(serial_ops[idx - 1])
or is_loss_op(serial_ops[idx - 1])
):
if not op_dist_attr.is_recompute:
appended_grad_times += 1
# partititon input variables
for serial_input_varname in op.desc.input_arg_names():
if serial_input_varname not in self._serial2dist_varname_mapping:
new_varname = serial_input_varname + self._dist_varname_suffix
if (
serial_input_varname
not in self._serial2dist_varname_mapping
):
new_varname = (
serial_input_varname + self._dist_varname_suffix
)
if ref_block.has_var(serial_input_varname):
_partition_var(self._dist_context, ref_block,
target_block, serial_input_varname,
new_varname)
_partition_var(
self._dist_context,
ref_block,
target_block,
serial_input_varname,
new_varname,
)
else:
for varname_not_in_block in __varname_not_in_block__:
assert varname_not_in_block in serial_input_varname, \
"{} is not found".format(serial_input_varname)
assert (
varname_not_in_block in serial_input_varname
), "{} is not found".format(serial_input_varname)
self._serial2dist_varname_mapping[
serial_input_varname] = new_varname
serial_input_varname
] = new_varname
# partition output vars
for serial_output_varname in op.desc.output_arg_names():
if serial_output_varname not in self._serial2dist_varname_mapping:
new_varname = serial_output_varname + self._dist_varname_suffix
_partition_var(self._dist_context, ref_block, target_block,
serial_output_varname, new_varname)
if (
serial_output_varname
not in self._serial2dist_varname_mapping
):
new_varname = (
serial_output_varname + self._dist_varname_suffix
)
_partition_var(
self._dist_context,
ref_block,
target_block,
serial_output_varname,
new_varname,
)
self._serial2dist_varname_mapping[
serial_output_varname] = new_varname
serial_output_varname
] = new_varname
# partition op
if is_forward_op(op) or op_dist_attr.is_recompute:
if (
is_forward_op(op)
or op_dist_attr.is_recompute
or is_fillconst_op_for_micro_batch(op)
):
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_forward_impl = _get_dist_op_forward_implement(
op, self._dist_context)
dist_op_forward_impl.forward(self._dist_context, **kinputs,
**koutputs)
op, self._dist_context
)
dist_op_forward_impl.forward(
self._dist_context, **kinputs, **koutputs
)
elif is_backward_op(op):
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_backward_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
grad_var_to_var = self._dist_context.dist_op_context.grad_var_to_var[
appended_grad_times]
op, self._dist_context, forward_op_id2forward_op
)
grad_var_to_var = (
self._dist_context.dist_op_context.grad_var_to_var[
appended_grad_times
]
)
dist_op_backward_impl.backward(
self._dist_context, **kinputs, **koutputs,
**{"grad_var_to_var": grad_var_to_var})
self._dist_context,
**kinputs,
**koutputs,
**{"grad_var_to_var": grad_var_to_var}
)
elif is_optimize_op(op):
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must 2 because of 1F1B PASS
# NOTE: BACKWARD_ONLY_DIST_OPS's op_role must be 2 because of 1F1B PASS
kinputs, koutputs = dist_op_context.prepare_context(op)
dist_op_opt_impl = _get_dist_op_backward_implement(
op, self._dist_context, forward_op_id2forward_op)
dist_op_opt_impl.backward(self._dist_context, **kinputs,
**koutputs, **{"grad_var_to_var": {}})
op, self._dist_context, forward_op_id2forward_op
)
dist_op_opt_impl.backward(
self._dist_context,
**kinputs,
**koutputs,
**{"grad_var_to_var": {}}
)
else:
raise NotImplementedError(
"partitioner only support forward and backward, optimize ops, but got {}"
.format(str(op)))
"partitioner only support forward and backward, optimize ops, but got {}".format(
str(op)
)
)
def _is_valid_annotated_program(self, program):
......@@ -298,13 +369,16 @@ class Partitioner(object):
]
var_dist_attrs = [
self._dist_context.get_tensor_dist_attr_for_program(var)
for var in vars_ if (var.type not in __not_shape_var_type__)
for var in vars_
if (var.type not in __not_shape_var_type__)
]
all_ops_annotated = all(dist_attr is not None
for dist_attr in op_dist_attrs)
all_vars_annotated = all(dist_attr is not None
for dist_attr in var_dist_attrs)
all_ops_annotated = all(
dist_attr is not None for dist_attr in op_dist_attrs
)
all_vars_annotated = all(
dist_attr is not None for dist_attr in var_dist_attrs
)
return all_ops_annotated and all_vars_annotated
......@@ -328,22 +402,26 @@ def _get_dist_shape(var, dist_attr):
assert len(var_shape) == len(
mapping
), "variable shape [{}] and dim_mapping [{}] is NOT match !".format(
var_shape, mapping)
var_shape, mapping
)
new_shape = []
for idx in range(len(var_shape)):
if var_shape[idx] == -1 or mapping[idx] == -1:
new_shape.append(var_shape[idx])
else:
assert var_shape[idx] % mesh[mapping[
idx]] == 0, "un-event partition: var_shape[idx]=[{}], mesh[{}]".format(
var_shape[idx], mesh[mapping[idx]])
assert (
var_shape[idx] % mesh[mapping[idx]] == 0
), "un-event partition: var_shape[idx]=[{}], mesh[{}]".format(
var_shape[idx], mesh[mapping[idx]]
)
new_shape.append(var_shape[idx] // mesh[mapping[idx]])
return new_shape
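The shape arithmetic above divides every sharded dimension by the size of the mesh axis it is mapped to; a standalone sketch with made-up values:

# Standalone sketch of the _get_dist_shape arithmetic; values are made up.
def dist_shape(var_shape, mesh_topology, dims_mapping):
    new_shape = []
    for dim, axis in zip(var_shape, dims_mapping):
        if dim == -1 or axis == -1:
            new_shape.append(dim)  # dynamic or replicated dim: keep as-is
        else:
            assert dim % mesh_topology[axis] == 0
            new_shape.append(dim // mesh_topology[axis])
    return new_shape


# shape [64, 1024] on a [2, 4] mesh with dims_mapping [0, -1] -> [32, 1024]
print(dist_shape([64, 1024], [2, 4], [0, -1]))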
def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
dst_shape):
def _partition_parameter(
dist_context, src_var, dst_block, dst_varname, dst_shape
):
# NOTE hack to copied Parameter
# not initialized parameter, need to initialize it
copied_kwargs = {}
......@@ -353,39 +431,45 @@ def _partition_parameter(dist_context, src_var, dst_block, dst_varname,
copied_kwargs['do_model_average'] = src_var.do_model_average
copied_kwargs['need_clip'] = src_var.need_clip
param = Parameter(block=dst_block,
type=src_var.type,
name=dst_varname,
shape=dst_shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs)
param = Parameter(
block=dst_block,
type=src_var.type,
name=dst_varname,
shape=dst_shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
**copied_kwargs
)
return param
def _partition_intermediate_var(dist_context, src_var, dst_block, dst_varname,
dst_shape):
var = dst_block.create_var(type=src_var.type,
name=dst_varname,
shape=dst_shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer)
def _partition_intermediate_var(
dist_context, src_var, dst_block, dst_varname, dst_shape
):
var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
shape=dst_shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
return var
def _partition_var(dist_context, src_block, dst_block, src_varname,
dst_varname):
def _partition_var(
dist_context, src_block, dst_block, src_varname, dst_varname
):
"""
partition include: split + replicate
"""
......@@ -393,44 +477,53 @@ def _partition_var(dist_context, src_block, dst_block, src_varname,
if src_var.type in __not_shape_var_type__:
persist = getattr(src_var, 'persistable', False)
new_var = dst_block.create_var(type=src_var.type,
name=dst_varname,
persistable=persist,
stop_gradient=True)
new_var = dst_block.create_var(
type=src_var.type,
name=dst_varname,
persistable=persist,
stop_gradient=True,
)
target_shape = None
else:
dist_attr = dist_context.get_tensor_dist_attr_for_program(src_var)
target_shape = _get_dist_shape(src_var, dist_attr)
if isinstance(src_var, Parameter):
new_var = _partition_parameter(dist_context, src_var, dst_block,
dst_varname, target_shape)
new_var = _partition_parameter(
dist_context, src_var, dst_block, dst_varname, target_shape
)
else:
new_var = _partition_intermediate_var(dist_context, src_var,
dst_block, dst_varname,
target_shape)
new_var = _partition_intermediate_var(
dist_context, src_var, dst_block, dst_varname, target_shape
)
dist_attr = copy.deepcopy(
dist_context.get_tensor_dist_attr_for_program(src_var))
dist_context.get_tensor_dist_attr_for_program(src_var)
)
assert dist_attr is not None
dist_context.set_tensor_dist_attr_for_program(new_var, dist_attr)
return target_shape
def _get_dist_op_backward_implement(backward_op, dist_context,
forward_op_id2forward_op):
def _get_dist_op_backward_implement(
backward_op, dist_context, forward_op_id2forward_op
):
dist_op_context = dist_context.dist_op_context
if backward_op.desc.original_id() in dist_op_context.grad_op_id_to_op_id:
forward_op_id = dist_op_context.grad_op_id_to_op_id[
backward_op.desc.original_id()]
backward_op.desc.original_id()
]
forward_op = forward_op_id2forward_op[forward_op_id]
forward_op_dist_attr = dist_context.get_op_dist_attr_for_program(
forward_op)
forward_op
)
dist_op_impl_container = get_distributed_operator_impl_container(
forward_op_dist_attr.impl_type)
forward_op_dist_attr.impl_type
)
dist_op_impl = dist_op_impl_container.get_impl(
forward_op_dist_attr.impl_idx)
forward_op_dist_attr.impl_idx
)
return dist_op_impl
# NOTE: trick for dist ops that only have a backward implementation
......@@ -438,7 +531,8 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
op_dist_attr = dist_context.get_op_dist_attr_for_program(backward_op)
assert op_dist_attr.impl_idx >= 0
dist_op_impl = get_distributed_operator_impl_container(
op_dist_attr.impl_type).get_impl(op_dist_attr.impl_idx)
op_dist_attr.impl_type
).get_impl(op_dist_attr.impl_idx)
return dist_op_impl
dist_op = get_distributed_operator_impl_container("default")
......@@ -448,6 +542,7 @@ def _get_dist_op_backward_implement(backward_op, dist_context,
def _get_dist_op_forward_implement(forward_op, dist_context):
dist_attr = dist_context.get_op_dist_attr_for_program(forward_op)
dist_op_impl_container = get_distributed_operator_impl_container(
dist_attr.impl_type)
dist_attr.impl_type
)
dist_op_impl = dist_op_impl_container.get_impl(dist_attr.impl_idx)
return dist_op_impl
......@@ -422,11 +422,11 @@ class Inserter:
)
inputs = {'X': [tensor]}
outputs = {"Out": [out]}
attrs = {"in_place": False}
slice_op = block._insert_op(
attrs = {"in_place": False, "op_role": op_role}
assign_op = block._insert_op(
idx, type="assign", inputs=inputs, outputs=outputs, attrs=attrs
)
slice_op._set_attr('op_namescope', "/auto_parallel/reshard")
assign_op._set_attr('op_namescope', "/auto_parallel/reshard")
return out
# use split once
......@@ -1217,6 +1217,8 @@ class Resharder:
shape_x[0] <= shape_y[0] < shape_x[1]
):
overlapped = True
if shape_x == [0, 0] and shape_y == [0, 0]:
overlapped = True
return overlapped
def is_unshard(self, dims_mapping):
......@@ -1304,6 +1306,14 @@ class Resharder:
# judge whether need reshard by process_mesh
if tensor_process_mesh != op_process_mesh:
is_reshard = True
# not reshard data in send/recv scene
if (
tensor_process_mesh != op_process_mesh
and len(tensor_process_mesh.process_ids)
== len(op_process_mesh.process_ids)
and dist_tensor.serial_tensor.is_data
):
is_reshard = False
else:
op_output_dims_mapping = dist_attr[1]
if all(
......@@ -1585,10 +1595,10 @@ class Resharder:
if i == 0:
all_partition_index_list.append(process_index[j][1])
for process in group:
# append slice op desc
slice_starts = []
slice_ends = []
slices_axes = []
min_comm_group = copy.deepcopy(group)
all_partition_index_list_copied = copy.deepcopy(
all_partition_index_list
)
target_partition_index = Resharder.compute_partition_index(
process,
complete_shape,
......@@ -1596,12 +1606,56 @@ class Resharder:
target_process_shape,
target_process_group,
)
for idx, item in enumerate(target_partition_index):
slice_starts.append(item[0])
slice_ends.append(item[1])
for _process in group:
source_partition_index = (
Resharder.compute_partition_index(
_process,
complete_shape,
source_dims_mapping,
source_process_shape,
source_process_group,
)
)
if not all(
_
for _ in list(
map(
self.is_overlapped,
source_partition_index,
target_partition_index,
)
)
):
min_comm_group.remove(_process)
all_partition_index_list_copied.remove(
source_partition_index
)
concatenated_partition_index_list = []
for partition_index in all_partition_index_list_copied:
Resharder.concat_partitions(
concatenated_partition_index_list, partition_index
)
concatenated_partition_index = (
concatenated_partition_index_list[0]
)
slice_starts = []
slice_ends = []
slices_axes = []
to_slice_tensor_shape = []
for idx, item in enumerate(concatenated_partition_index):
slice_starts.append(
target_partition_index[idx][0] - item[0]
)
slice_ends.append(
target_partition_index[idx][1] - item[0]
)
slices_axes.append(idx)
to_slice_tensor_shape.append(item[1] - item[0])
to_slice_tensor_shape = dist_tensor.global_sizes()
slice_op_desc = SliceOpDesc(
starts=slice_starts,
ends=slice_ends,
......@@ -1616,16 +1670,16 @@ class Resharder:
op_desc_seq[process] = (
[
AllGatherOpDesc(
group=group,
group=min_comm_group,
shape=allgather_shape,
is_bool=(source_tensor.dtype == paddle.bool),
),
ConcatOpDesc(
partition_index_list=all_partition_index_list
partition_index_list=all_partition_index_list_copied
),
slice_op_desc,
]
if len(group) > 1
if len(min_comm_group) > 1
else [slice_op_desc]
)
......@@ -123,6 +123,12 @@ class DatasetConfig(BaseConfig):
super(DatasetConfig, self).__init__(category, config_dict)
class DPOptimizationConfig(BaseConfig):
def __init__(self, config_dict=None):
category = constants.DP_OPTIMIZATION
super(DPOptimizationConfig, self).__init__(category, config_dict)
class Strategy(BaseConfig):
"""
The `Strategy` object is used to configure the paralleization and optimization beheviors.
......@@ -194,3 +200,6 @@ class Strategy(BaseConfig):
config_dict = self._config_dict.get(constants.DATASET, None)
self.dataset = DatasetConfig(config_dict)
config_dict = self._config_dict.get(constants.DP_OPTIMIZATION, None)
self.dp_optimization = DPOptimizationConfig(config_dict)
......@@ -1252,6 +1252,7 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle_grad",
"flatten_contiguous_range_grad",
"relu_grad",
"exp_grad",
]
forward_list = [
"reshape2",
......@@ -1270,6 +1271,7 @@ def set_grad_var_shape(program, dist_context):
"fused_softmax_mask_upper_triangle",
"flatten_contiguous_range",
"relu",
"exp",
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......@@ -1320,6 +1322,11 @@ def is_forward_op(op):
)
def is_fillconst_op_for_micro_batch(op):
op_role = int(op.attr('op_role'))
return OP_ROLE_KEY in op.attr_names and (op_role == int(OpRole.LRSched))
def is_backward_op(op):
return OP_ROLE_KEY in op.attr_names and int(
op.all_attrs()[OP_ROLE_KEY]
......@@ -18,15 +18,31 @@ import numpy as np
import paddle
from paddle.fluid import core, unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, is_backward_op, ring_id_to_process_group, find_higher_order_backward_op
from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_scale_op,
is_data_parallel_reduce_op,
)
from paddle.distributed.auto_parallel.utils import (
is_loss_grad_op,
is_optimize_op,
is_backward_op,
ring_id_to_process_group,
find_higher_order_backward_op,
)
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
'merge_momentum'
'lars_momentum',
'sparse_momentum',
'dgc_momentum',
'momentum',
'merge_momentum',
]
# a heuristic number
......@@ -41,7 +57,7 @@ def numel(var):
class DataParallelOptimizationPass(PassBase):
"""
Apply optimizations that are specialized for data parallelism in Auto Parallel.
1. prune grad scaling
2. overlap comm and calc
3. fuse allreduce
"""
......@@ -52,6 +68,9 @@ class DataParallelOptimizationPass(PassBase):
self.set_attr("dist_context", None)
self.set_attr("global_rank", -1)
self.set_attr("use_sharding", False)
self.set_attr("fuse_all_reduce_ops", False)
self.set_attr("fuse_grad_size_in_MB", 32)
self.set_attr("overlap_comm_cacl", False)
# {grad1: group1, grad2: group1, grad3: group2}
# record the order for fuse grad data memory
self._grad_name_to_group_map = OrderedDict()
......@@ -62,8 +81,9 @@ class DataParallelOptimizationPass(PassBase):
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if (not isinstance(self.get_attr("global_rank"),
int)) or self.get_attr("global_rank") < 0:
if (not isinstance(self.get_attr("global_rank"), int)) or self.get_attr(
"global_rank"
) < 0:
return False
return True
......@@ -80,13 +100,18 @@ class DataParallelOptimizationPass(PassBase):
self.global_rank = int(self.get_attr("global_rank"))
self.use_sharding = self.get_attr("use_sharding")
overlap_comm_cacl = self.get_attr("overlap_comm_cacl")
fuse_all_reduce_ops = self.get_attr("fuse_all_reduce_ops")
with paddle.static.program_guard(main_program, startup_program):
self._analyze_program()
if self.is_data_parallel_applied():
self._prune_grad_scaling()
self._calc_comm_overlap()
grad_group = self._fuse_allreduce()
if overlap_comm_cacl:
self._prune_grad_scaling()
self._calc_comm_overlap()
if fuse_all_reduce_ops:
grad_group = self._fuse_allreduce()
# self.summary(grad_group)
......@@ -140,8 +165,11 @@ class DataParallelOptimizationPass(PassBase):
), "Unexception: comm op [{}] has NOT ring id.".format(str(op))
group = ring_id_to_process_group(op.attr("ring_id"))
assert group is not None, "Unexception: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op))
assert (
group is not None
), "Unexception: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op)
)
self._grad_name_to_group_map[grad_name] = group
......@@ -156,18 +184,21 @@ class DataParallelOptimizationPass(PassBase):
# TODO support multiple optimizers in one network in the future.
# here we assume that the optimizer is unique in network.
elif is_optimize_op(
op) and op.type in __rescale_grad_supported_opts__:
elif (
is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
self._support_rescale_grad = True
not_synchronized_grads = []
for grad_name in scaled_grads:
if grad_name not in self._grad_name_to_group_map:
not_synchronized_grads.append(grad_name)
assert len(
assert (
len(not_synchronized_grads) == 0
), "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads
) == 0, "Unexception: gradients [{}] is scaled BUT NOT synchronized.".format(
not_synchronized_grads)
)
def is_data_parallel_applied(self):
return len(self._group_to_grad_name_map) > 0
......@@ -175,14 +206,21 @@ class DataParallelOptimizationPass(PassBase):
def _could_be_prune(self):
return self.dist_context.gradient_scale and (
self._support_rescale_grad or self._all_dp_groups_same_degree())
self._support_rescale_grad or self._all_dp_groups_same_degree()
)
def _all_dp_groups_same_degree(self):
return len(
set([
len(group.ranks)
for group in self._group_to_grad_name_map.keys()
])) == 1
return (
len(
set(
[
len(group.ranks)
for group in self._group_to_grad_name_map.keys()
]
)
)
== 1
)
def _scale_backward_initial_grad(self):
......@@ -191,9 +229,10 @@ class DataParallelOptimizationPass(PassBase):
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
assert op.type == 'fill_constant', \
"loss_grad_op must be fill_constant op, " \
assert op.type == 'fill_constant', (
"loss_grad_op must be fill_constant op, "
"but this op is {}".format(op.type)
)
assert op.has_attr('value')
loss_scale = float(op.attr('value'))
loss_scale = loss_scale / dp_degree
......@@ -215,28 +254,35 @@ class DataParallelOptimizationPass(PassBase):
scaled_grads = set()
for idx, op in reversed(list(enumerate(block.ops))):
if is_optimize_op(
op) and op.type in __rescale_grad_supported_opts__:
if (
is_optimize_op(op)
and op.type in __rescale_grad_supported_opts__
):
assert op.has_attr(
'rescale_grad'
), "Unexception: op [{}] is supported to have [rescale_grad] attribute.".format(
str(op))
assert len(
op.input("Grad")
) == 1, "Unexception: op [{}] is supported to have only one input grad var.".format(
str(op))
str(op)
)
assert (
len(op.input("Grad")) == 1
), "Unexception: op [{}] is supported to have only one input grad var.".format(
str(op)
)
grad_name = op.input("Grad")[0]
dp_degree = len(
list(self._grad_name_to_group_map[grad_name].ranks))
list(self._grad_name_to_group_map[grad_name].ranks)
)
scaled_grads.add(grad_name)
rescale_grad = float(op.attr('rescale_grad')) / dp_degree
op._set_attr('rescale_grad', rescale_grad)
assert scaled_grads == set(self._grad_name_to_group_map.keys(
)), "Unexception: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads)
assert scaled_grads == set(
self._grad_name_to_group_map.keys()
), "Unexception: gradients [{}] are unscaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads
)
def _could_be_overlap(self):
# NOTE current different nccl comm will use different cuda stream
......@@ -266,14 +312,13 @@ class DataParallelOptimizationPass(PassBase):
op._set_attr('use_calc_stream', False)
ring_id = op.attr("ring_id")
block._insert_op_without_sync(idx,
type='c_wait_compute',
inputs={'X': []},
outputs={'Out': []},
attrs={
'op_role': OpRole.Backward,
'ring_id': ring_id
})
block._insert_op_without_sync(
idx,
type='c_wait_compute',
inputs={'X': []},
outputs={'Out': []},
attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
)
block._sync_with_cpp()
......@@ -307,8 +352,10 @@ class DataParallelOptimizationPass(PassBase):
# other ops that might use communicating grad
else:
for input_var_name in op.input_arg_names:
for ring_id, unsync_grad_names in ring_id_to_un_sync_grad_map.items(
):
for (
ring_id,
unsync_grad_names,
) in ring_id_to_un_sync_grad_map.items():
if input_var_name in unsync_grad_names:
# need to sync before op_i
if i in op_idx_to_sync_ring_id_map:
......@@ -328,14 +375,13 @@ class DataParallelOptimizationPass(PassBase):
for i in sorted(indices, reverse=True):
for ring_id in op_idx_to_sync_ring_id_map[i]:
block._insert_op_without_sync(i,
type='c_wait_comm',
inputs={'X': []},
outputs={'Out': []},
attrs={
'op_role': OpRole.Backward,
'ring_id': ring_id
})
block._insert_op_without_sync(
i,
type='c_wait_comm',
inputs={'X': []},
outputs={'Out': []},
attrs={'op_role': OpRole.Backward, 'ring_id': ring_id},
)
def _could_be_fuse(self):
# TODO support gradient fuse higher order gradient.
......@@ -350,9 +396,9 @@ class DataParallelOptimizationPass(PassBase):
"""
conditions for gradients to be grouped:
1. group size < max_fuse_numel
2. same dp group
3. same dtype
4. dependency: grad would NOT be used by other ops within the group segment
gradients inside the same group will be fused into one coalesce tensor
"""
......@@ -423,36 +469,51 @@ class DataParallelOptimizationPass(PassBase):
for i, group in enumerate(grad_groups[::-1]):
# create coalesce tensor
group.coalesce_var = block.create_var(name=unique_name.generate(
'coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True)
group.coalesce_var = block.create_var(
name=unique_name.generate('coalecse_grad_{}'.format(i)),
dtype=group.dtype,
persistable=False,
stop_gradient=True,
)
# update allreduce & scale op
if group.scale_op_idx != -1:
scale_op = block.ops[group.scale_op_idx]
assert scale_op.type == 'scale', "should found scale op but found {}".format(
str(scale_op))
scale_op._rename_input(scale_op.input_arg_names[0],
group.coalesce_var.name)
scale_op._rename_output(scale_op.output_arg_names[0],
group.coalesce_var.name)
assert (
scale_op.type == 'scale'
), "should found scale op but found {}".format(str(scale_op))
scale_op._rename_input(
scale_op.input_arg_names[0], group.coalesce_var.name
)
scale_op._rename_output(
scale_op.output_arg_names[0], group.coalesce_var.name
)
allreduce_op = block.ops[group.allreduce_op_idx]
assert allreduce_op.type == 'c_allreduce_sum', "should found c_allreduce_sum op but found {}".format(
str(allreduce_op))
allreduce_op._rename_input(allreduce_op.input_arg_names[0],
group.coalesce_var.name)
allreduce_op._rename_output(allreduce_op.output_arg_names[0],
group.coalesce_var.name)
assert (
allreduce_op.type == 'c_allreduce_sum'
), "should found c_allreduce_sum op but found {}".format(
str(allreduce_op)
)
allreduce_op._rename_input(
allreduce_op.input_arg_names[0], group.coalesce_var.name
)
allreduce_op._rename_output(
allreduce_op.output_arg_names[0], group.coalesce_var.name
)
# remove unused ops
remove_op_indices = group.remove_wait_op_indices + group.remove_allreduce_op_indices + group.remove_scale_op_indices
remove_op_indices = (
group.remove_wait_op_indices
+ group.remove_allreduce_op_indices
+ group.remove_scale_op_indices
)
for idx in sorted(remove_op_indices, reverse=True):
assert block.ops[
idx].type in remove_op_types, "Unexception: try to remove op {}".format(
str(op))
assert (
block.ops[idx].type in remove_op_types
), "Unexception: try to remove op {}".format(
str(block.ops[idx].type())
)
block._remove_op(idx)
# insert coalesce op
......@@ -464,22 +525,23 @@ class DataParallelOptimizationPass(PassBase):
concated_ranks.append(len(shape))
grad_names = [grad.name for grad in group.gradients]
block._insert_op_without_sync(group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes":
concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward
})
block._insert_op_without_sync(
group.coalesce_op_idx,
type="coalesce_tensor",
inputs={"Input": grad_names},
outputs={
"Output": grad_names,
"FusedOutput": group.coalesce_var,
},
attrs={
"copy_data": False,
"use_align": True,
"dtype": group.dtype,
"concated_shapes": concated_shapes,
"concated_ranks": concated_ranks,
OP_ROLE_KEY: OpRole.Backward,
},
)
block._sync_with_cpp()
# TODO update dist attr
......@@ -487,6 +549,7 @@ class DataParallelOptimizationPass(PassBase):
def summary(self, grad_groups=[]):
# TODO: add logger module
import logging
self._logger = logging.getLogger()
self._logger.propagate = False
if not self._logger.handlers:
......@@ -500,26 +563,31 @@ class DataParallelOptimizationPass(PassBase):
if len(grad_groups) > 0:
self._logger.info(
"origin {} allreduce ops are fused into {} coalecse allreduce ops."
.format(len(self._grad_name_to_group_map.keys()),
len(grad_groups)))
"origin {} allreduce ops are fused into {} coalecse allreduce ops.".format(
len(self._grad_name_to_group_map.keys()), len(grad_groups)
)
)
self._logger.info("gradient fusing group are following: ")
fused_grads = set()
for i, group in enumerate(grad_groups):
self._logger.info(
"coalecse gradient [{}] is composed by: {}".format(
i, [grad.name for grad in group.gradients]))
i, [grad.name for grad in group.gradients]
)
)
fused_grads.update([grad.name for grad in group.gradients])
individual_grads = set(
self._grad_name_to_group_map.keys()) - set(fused_grads)
individual_grads = set(self._grad_name_to_group_map.keys()) - set(
fused_grads
)
self._logger.info(
"the following [{}] gradients are not fused: ".format(
len(individual_grads)))
len(individual_grads)
)
)
self._logger.info("individual gradient {}".format(individual_grads))
class GradientsGroup(object):
def __init__(self, ops, max_group_size):
self.max_group_size = max_group_size
self.ops = ops
......@@ -575,8 +643,11 @@ class GradientsGroup(object):
grad_op_idx -= 1
grad_op = self.ops[grad_op_idx]
assert grad_var.name in grad_op.output_arg_names, "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op))
assert (
grad_var.name in grad_op.output_arg_names
), "grad [{}] should be output of {}".format(
grad_var.name, str(grad_op)
)
self.coalesce_op_idx = grad_op_idx
def finalize(self):
......
......@@ -12,23 +12,40 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
from collections import OrderedDict
from typing import List, Tuple, Dict, Any
import paddle
from paddle.framework import core
from paddle.fluid import layers
from paddle.fluid.framework import program_guard, device_guard
from paddle.distributed.fleet.meta_optimizers.common import (
OpRole,
OP_ROLE_KEY,
OP_ROLE_VAR_KEY,
)
from paddle.distributed.auto_parallel.utils import (
set_var_dist_attr,
is_optimize_op,
is_backward_op,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
)
from paddle.distributed.auto_parallel.process_group import (
get_world_process_group,
)
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_reduce_op,
is_data_parallel_scale_op,
)
from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.auto_parallel.utils import set_var_dist_attr, is_optimize_op, OpRole, OP_ROLE_KEY
from paddle.distributed.auto_parallel.utils import naive_set_dist_op_attr_for_program_by_mesh_and_mapping
from paddle.distributed.auto_parallel.process_group import get_world_process_group
world_process_group = get_world_process_group()
def _remove_and_get_optimizer_op(main_program, dist_context):
def is_gradient_clip_op(op_desc):
return op_desc.has_attr("op_namescope") and op_desc.attr(
"op_namescope"
).startswith("/gradient_clip")
def _remove_and_get_ops(main_program, dist_context):
# 1 create tmp block
# 2 mv optimizer op from global program to tmp block
# 3 del the op from dist_context
......@@ -36,101 +53,119 @@ def _remove_and_get_optimizer_op(main_program, dist_context):
temp_block = main_program._create_block()
removed_op_idx = []
optimize_ops_desc = []
allreduce_sum_desc = []
for idx, op in enumerate(main_block.ops):
# append optimizer op to tmp block
if is_optimize_op(op):
# append optimizer op to tmp block
new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc)
optimize_ops_desc.append(new_op_desc)
removed_op_idx.append(idx)
# del op from dist_context
if dist_context:
dist_context.del_dist_op_for_program(op)
# append allreduce_op and scale_op to tmp block
if is_backward_op(op):
if is_data_parallel_reduce_op(op) or is_data_parallel_scale_op(op):
assert len(op.desc.output_arg_names()) == 1
new_op_desc = temp_block.desc.append_op()
new_op_desc.copy_from(op.desc)
allreduce_sum_desc.append(new_op_desc)
removed_op_idx.append(idx)
dist_context.del_dist_op_for_program(op)
for idx in removed_op_idx[::-1]:
main_block._remove_op(idx, sync=False)
main_block._sync_with_cpp()
return optimize_ops_desc
return optimize_ops_desc, allreduce_sum_desc
def _get_gm_cond_var(main_program, k_steps, dist_context):
def _create_gm_cond_var(main_program, k_steps, dist_context):
main_block = main_program.global_block()
# Add const var
k_step_var = layers.create_global_var(name="gradient_merge_k",
shape=[1],
value=int(k_steps),
dtype='int32',
persistable=True,
force_cpu=True)
k_step_var = layers.create_global_var(
name="gradient_merge_k",
shape=[1],
value=int(k_steps),
dtype='int32',
persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, k_step_var, [-1], world_process_group.ranks)
zero_var = layers.create_global_var(name="gradient_merge_zero",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
zero_var = layers.create_global_var(
name="gradient_merge_zero",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, zero_var, [-1], world_process_group.ranks)
# Add step var & cond var
step_var = layers.create_global_var(name="gradient_merge_step",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True)
step_var = layers.create_global_var(
name="gradient_merge_step",
shape=[1],
value=int(0),
dtype='int32',
persistable=True,
force_cpu=True,
)
set_var_dist_attr(dist_context, step_var, [-1], world_process_group.ranks)
cond_var = main_block.create_var(name="gradient_merge_cond",
shape=[1],
dtype='bool')
cond_var = main_block.create_var(
name="gradient_merge_cond", shape=[1], dtype='bool'
)
set_var_dist_attr(dist_context, cond_var, [-1], world_process_group.ranks)
with device_guard("cpu"):
with paddle.static.device_guard("cpu"):
# step_var += 1
increment_op = main_block.append_op(type='increment',
inputs={'X': [step_var]},
outputs={'Out': [step_var]},
attrs={
'step': float(1.0),
OP_ROLE_KEY: OpRole.Backward
})
increment_op = main_block.append_op(
type='increment',
inputs={'X': [step_var]},
outputs={'Out': [step_var]},
attrs={'step': float(1.0), OP_ROLE_KEY: OpRole.Backward},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
increment_op, world_process_group.ranks, [-1], dist_context)
increment_op, world_process_group.ranks, [-1], dist_context
)
# step_var %= k_step
elementwise_mod_op = main_block.append_op(type='elementwise_mod',
inputs={
'X': step_var,
'Y': k_step_var
},
outputs={'Out': step_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY:
OpRole.Backward
})
elementwise_mod_op = main_block.append_op(
type='elementwise_mod',
inputs={'X': step_var, 'Y': k_step_var},
outputs={'Out': step_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
elementwise_mod_op, world_process_group.ranks, [-1], dist_context)
elementwise_mod_op, world_process_group.ranks, [-1], dist_context
)
# cond_var = (step_var == 0)
equal_op = main_block.append_op(type='equal',
inputs={
'X': step_var,
'Y': zero_var
},
outputs={'Out': cond_var},
attrs={OP_ROLE_KEY: OpRole.Backward})
equal_op = main_block.append_op(
type='equal',
inputs={'X': step_var, 'Y': zero_var},
outputs={'Out': cond_var},
attrs={OP_ROLE_KEY: OpRole.Backward},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
equal_op, world_process_group.ranks, [-1], dist_context)
equal_op, world_process_group.ranks, [-1], dist_context
)
return cond_var
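The three ops appended above (increment, elementwise_mod, equal) form a modular step counter, so the condition variable becomes true once every k_steps micro-steps. A plain-Python sketch of the same arithmetic, for illustration only:

def gradient_merge_cond(step, k_steps):
    # step += 1; step %= k_steps; cond = (step == 0)
    step = (step + 1) % k_steps
    return step, step == 0

step = 0
for it in range(8):
    step, cond = gradient_merge_cond(step, k_steps=4)
    # with k_steps=4, cond is True on iterations 3 and 7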
def _append_gradient_merge_backward_op(
main_program, startup_program, params_grads: List[Tuple[Any, Any]],
dist_context) -> Tuple[List[Tuple[Any, Any]], Dict[str, Any]]:
main_program,
startup_program,
params_grads,
master_grad,
dist_context,
):
main_block = main_program.global_block()
startup_block = startup_program.global_block()
......@@ -148,149 +183,260 @@ def _append_gradient_merge_backward_op(
for param, grad in params_grads:
param_name = param.name
param_var = main_block.var(param_name)
assert (param_var is not None)
ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
assert ref_dist_attr is not None
gradient_merge_var = main_block.create_var(name=param_name +
"@GRAD@GradientMerge",
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True)
ref_process_mesh = ref_dist_attr.process_mesh
ref_dims_mapping = ref_dist_attr.dims_mapping
assert param_var is not None
set_var_dist_attr(dist_context, gradient_merge_var, ref_dims_mapping,
ref_process_mesh)
dst_dtype = (
core.VarDesc.VarType.FP32 if master_grad else param_var.dtype
)
# 2.1 create param@GRAD@MERGED var in startup_block
startup_gradient_merge_var = startup_block.create_var(
name=param_name + "@GRAD@GradientMerge",
name=param_name + "@GRAD@MERGED",
shape=param_var.shape,
dtype=dst_dtype,
persistable=True,
)
startup_block.append_op(
type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": param_var.shape,
"dtype": dst_dtype,
"value": float(0),
},
)
# 2.2 create param@GRAD@MERGED var in main_block
ref_dist_attr = dist_context.get_tensor_dist_attr_for_program(param_var)
assert ref_dist_attr is not None
gradient_merge_var = main_block.create_var(
name=param_name + "@GRAD@MERGED",
shape=param_var.shape,
dtype=param_var.dtype,
persistable=True)
startup_block.append_op(type="fill_constant",
outputs={"Out": startup_gradient_merge_var},
attrs={
"shape": param_var.shape,
"dtype": param_var.dtype,
"value": float(0),
})
# grad_merge += grad
new_grad_op = main_block.append_op(type="elementwise_add",
inputs={
'X': grad,
'Y': gradient_merge_var
},
outputs={'Out': gradient_merge_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward
})
dtype=dst_dtype,
persistable=True,
)
ref_process_mesh = ref_dist_attr.process_mesh
ref_dims_mapping = ref_dist_attr.dims_mapping
set_var_dist_attr(
dist_context, gradient_merge_var, ref_dims_mapping, ref_process_mesh
)
# 2.3 grad_merge += grad
grad_name = grad.name
if grad.dtype != dst_dtype:
cast_grad_name = grad_name + "@TMP"
cast_grad_var = main_block.create_var(
name=cast_grad_name,
shape=grad.shape,
dtype=dst_dtype,
persistable=False,
stop_gradient=grad.stop_gradient,
)
set_var_dist_attr(
dist_context, cast_grad_var, ref_dims_mapping, ref_process_mesh
)
cast_op = main_block.append_op(
type="cast",
inputs={"X": grad},
outputs={"Out": cast_grad_var},
attrs={
"in_dtype": grad.dtype,
"out_dtype": dst_dtype,
OP_ROLE_KEY: OpRole.Backward,
},
)
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
cast_op, ref_process_mesh, ref_dims_mapping, dist_context
)
grad = cast_grad_var
new_grad_op = main_block.append_op(
type="elementwise_add",
inputs={'X': grad, 'Y': gradient_merge_var},
outputs={'Out': gradient_merge_var},
attrs={
'axis': -1,
'use_mkldnn': False,
OP_ROLE_KEY: OpRole.Backward,
},
)
new_params_to_grads.append([param, gradient_merge_var])
grad_to_gradient_merge[grad.name] = gradient_merge_var.name
grad_to_gradient_merge[grad_name] = gradient_merge_var.name
naive_set_dist_op_attr_for_program_by_mesh_and_mapping(
new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context)
new_grad_op, ref_process_mesh, ref_dims_mapping, dist_context
)
return new_params_to_grads, grad_to_gradient_merge
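When master_grad is on, each low-precision gradient is cast to FP32 before being added into param@GRAD@MERGED, so repeated accumulation does not drop low-order bits. A small numpy illustration of why the FP32 accumulator matters (hypothetical values, not taken from the pass):

import numpy as np

grad = np.float16(1e-4)        # one small micro-batch gradient
acc_fp16 = np.float16(0.0)     # accumulate directly in fp16
acc_fp32 = np.float32(0.0)     # "master grad": cast to fp32, then accumulate

for _ in range(10000):
    acc_fp16 += grad                 # stalls once the sum is large relative to fp16 spacing
    acc_fp32 += np.float32(grad)     # stays close to the true sum of ~1.0

print(acc_fp16, acc_fp32)  # the fp16 sum ends up far below 1.0, the fp32 sum near 1.0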
def _create_cond_block_and_update_optimizer(
main_program, cond_var, new_params_to_grads: List[Tuple[Any, Any]],
grad_to_gradient_merge: Dict[str, str], optimize_ops_desc: List[Any],
k_steps, avg):
def _rename_arg_names(op_desc, var_name_dict):
for input_name in op_desc.input_arg_names():
if input_name in var_name_dict:
op_desc._rename_input(input_name, var_name_dict[input_name])
for output_name in op_desc.output_arg_names():
if output_name in var_name_dict:
op_desc._rename_output(output_name, var_name_dict[output_name])
def _create_cond_block_and_update_optimizer(
main_program,
cond_var,
params_grads,
new_params_to_grads,
grad_to_gradient_merge,
optimize_ops_desc,
allreduce_sum_desc,
k_steps,
avg,
master_grad,
):
def true_apply_gradient():
cur_block_idx = main_program.current_block_idx
cur_block = main_program.current_block()
# cur_block's forward_block & backward_block is itself
cur_block._set_forward_block_idx(cur_block_idx)
op_maker = core.op_proto_and_checker_maker
# record grads_name to insert c_allreduce_sum op
grads_name = [grad.name for _, grad in params_grads]
# append c_allreduce_sum ops and scale ops
for op_desc in allreduce_sum_desc:
outputs_name = op_desc.output_arg_names()
assert len(outputs_name) == 1
if outputs_name[0] in grads_name:
new_op_desc = cur_block.desc.append_op()
new_op_desc.copy_from(op_desc)
_rename_arg_names(new_op_desc, grad_to_gradient_merge)
new_op_desc._set_attr(OP_ROLE_KEY, OpRole.Optimize)
cur_block._sync_with_cpp()
if avg:
for param, new_grad in new_params_to_grads:
for _, new_grad in new_params_to_grads:
# grad /= k_steps
cur_block.append_op(type='scale',
inputs={'X': new_grad},
outputs={'Out': new_grad},
attrs={
'scale': 1.0 / k_steps,
'bias': 0.0,
'bias_after_scale': False
})
cur_block.append_op(
type='scale',
inputs={'X': new_grad},
outputs={'Out': new_grad},
attrs={
'scale': 1.0 / k_steps,
'bias': 0.0,
'bias_after_scale': False,
},
)
new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
cast_name_dict = {}
# append optimizer ops
for op_desc in optimize_ops_desc:
if master_grad and is_gradient_clip_op(op_desc):
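# NOTE: the raw dtype values compared below are core.VarDesc.VarType enums (4 = FP16, 22 = BF16, 5 = FP32).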
if op_desc.type() == "cast":
if (
op_desc.attr('out_dtype') in [4, 22]
and op_desc.attr('in_dtype') == 5
):
cast_name_dict[
op_desc.output_arg_names()[0]
] = op_desc.input_arg_names()[0]
elif (
op_desc.attr('in_dtype') in [4, 22]
and op_desc.attr('out_dtype') == 5
):
cast_name_dict[
op_desc.output_arg_names()[0]
] = op_desc.input_arg_names()[0]
continue
for out_name in op_desc.output_arg_names():
out_var = cur_block._var_recursive(out_name)
out_var.desc.set_dtype(core.VarDesc.VarType.FP32)
_rename_arg_names(op_desc, cast_name_dict)
new_op_desc = cur_block.desc.append_op()
new_op_desc.copy_from(op_desc)
#update input/output
for input_name in new_op_desc.input_arg_names():
if input_name in grad_to_gradient_merge:
new_op_desc._rename_input(
input_name, grad_to_gradient_merge[input_name])
for output_name in new_op_desc.output_arg_names():
if output_name in grad_to_gradient_merge:
new_op_desc._rename_output(
output_name, grad_to_gradient_merge[output_name])
# update input/output
_rename_arg_names(new_op_desc, grad_to_gradient_merge)
# remove op_role_var
if new_op_desc.has_attr(op_maker.kOpRoleVarAttrName()):
new_op_desc.remove_attr(op_maker.kOpRoleVarAttrName())
if new_op_desc.has_attr(OP_ROLE_VAR_KEY):
new_op_desc.remove_attr(OP_ROLE_VAR_KEY)
# op's update Grad
if core.grad_var_suffix() in new_op_desc.input_arg_names():
grad_value = new_op_desc.input("Grad")[0]
# TODO FIXME(xym) support fp16
grad_merge_value = grad_value + '@GradientMerge'
grad_merge_value = grad_value + '@MERGED'
new_op_desc.set_input("Grad", [grad_merge_value])
main_program.global_block()._sync_with_cpp()
cur_block._sync_with_cpp()
# clear gradient_merge_vars
for param, new_grad in new_params_to_grads:
layers.fill_constant(shape=new_grad.shape,
dtype=new_grad.dtype,
value=0.0,
out=new_grad)
new_grad.op._set_attr(OP_ROLE_KEY, op_maker.OpRole.Optimize)
for _, new_grad in new_params_to_grads:
layers.fill_constant(
shape=new_grad.shape,
dtype=new_grad.dtype,
value=0.0,
out=new_grad,
)
new_grad.op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
layers.cond(cond_var, true_fn=true_apply_gradient, false_fn=None)
cond_op = main_program.global_block().ops[-1]
cond_op._set_attr(OP_ROLE_KEY, OpRole.Optimize)
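Taken together, the conditional block above fires only on every k-th micro-step and runs, in order: the delayed c_allreduce_sum/scale ops on the merged gradients, the scale-by-1/k_steps ops when avg is set, the copied optimizer ops, and finally fill_constant ops that zero the merged gradients for the next round. A numpy sketch of that per-parameter arithmetic, with an SGD stand-in for the real optimizer ops and no allreduce (single process assumed):

import numpy as np

def apply_merged_update(param, grad_merged, k_steps, lr=0.1, avg=True):
    if avg:
        grad_merged = grad_merged / k_steps   # scale: grad /= k_steps
    param = param - lr * grad_merged          # stand-in for the copied optimizer ops
    grad_merged = np.zeros_like(grad_merged)  # clear the merge buffer for the next round
    return param, grad_merged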
def parse_program(main_program, startup_program, params_grads, k_steps, avg,
dist_context):
# 1 remove optimizer_op from main_program
optimize_ops_desc = _remove_and_get_optimizer_op(main_program, dist_context)
def parse_program(
main_program,
startup_program,
params_grads,
k_steps,
avg,
master_grad,
dist_context,
):
# 1 remove optimizer_op, allreduce_sum_op and scale_op from main_program
optimize_ops_desc, allreduce_sum_desc = _remove_and_get_ops(
main_program, dist_context
)
# back to block 0
main_program._rollback()
# 2 append gradient merge backward op to main_program
new_params_to_grads, grad_to_gradient_merge = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, dist_context)
(
new_params_to_grads,
grad_to_gradient_merge,
) = _append_gradient_merge_backward_op(
main_program, startup_program, params_grads, master_grad, dist_context
)
# 3 create gradient_merge_cond
cond_var = _get_gm_cond_var(main_program, k_steps, dist_context)
cond_var = _create_gm_cond_var(main_program, k_steps, dist_context)
# 4 create ConditionalBlock and append gradient merge optimizer ops
_create_cond_block_and_update_optimizer(main_program, cond_var,
new_params_to_grads,
grad_to_gradient_merge,
optimize_ops_desc, k_steps, avg)
_create_cond_block_and_update_optimizer(
main_program,
cond_var,
params_grads,
new_params_to_grads,
grad_to_gradient_merge,
optimize_ops_desc,
allreduce_sum_desc,
k_steps,
avg,
master_grad,
)
@register_pass("auto_parallel_gradient_merge_pass")
class GradientMergePass(PassBase):
def __init__(self):
super(GradientMergePass, self).__init__()
self.set_attr("k_steps", -1)
self.set_attr("avg", True)
self.set_attr("master_grad", False)
def _check_self(self):
if self.get_attr("k_steps") < 1:
......@@ -306,10 +452,20 @@ class GradientMergePass(PassBase):
def _apply_single_impl(self, main_program, startup_program, context):
k_steps = self.get_attr("k_steps", -1)
avg = self.get_attr("avg", False)
master_grad = self.get_attr("master_grad", False)
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
# TODO(zyl): make master_grad configurable
master_grad = True
with paddle.static.program_guard(main_program, startup_program):
parse_program(main_program, startup_program, params_grads, k_steps,
avg, dist_context)
parse_program(
main_program,
startup_program,
params_grads,
k_steps,
avg,
master_grad,
dist_context,
)
main_program._sync_with_cpp()
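For reference, this pass is normally driven through the auto-parallel Strategy, as the unit tests further down do; a minimal sketch (note that in this revision master_grad is hard-coded to True inside _apply_single_impl rather than read from the strategy, per the TODO above):

from paddle.distributed.fleet import auto

strategy = auto.Strategy()
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True   # enables auto_parallel_gradient_merge_pass
gradient_merge.k_steps = 4     # accumulate gradients over 4 micro-steps
gradient_merge.avg = True      # divide the merged gradient by k_steps
# engine = auto.Engine(model, loss, optimizer, strategy=strategy)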
......@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from logging import exception
import os
from paddle.fluid import core
......@@ -26,6 +25,7 @@ from paddle.distributed.auto_parallel.utils import (
is_backward_op,
is_optimize_op,
is_lr_sched_op,
is_fillconst_op_for_micro_batch,
)
......@@ -38,6 +38,12 @@ __not_shape_var_type__ = [
]
def is_reshard_op(op):
return op.has_attr('op_namescope') and "/auto_parallel/reshard" in op.attr(
'op_namescope'
)
@register_pass("auto_parallel_pipeline")
class PipelinePass(PassBase):
def __init__(self):
......@@ -59,8 +65,17 @@ class PipelinePass(PassBase):
self._gen_bsz = self.get_attr("generation_batch_size")
self._program = main_program
self._cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
self._nrank = len(trainer_endpoints)
# compute current pp stage
self._pp_stages = len(self._dist_context.process_meshes)
self._cur_pp_stage = self._get_pp_stage(self._cur_rank)
if self._mode == "1F1B":
raise NotImplementedError("1F1B has not been implemented")
self._insert_sync_ops_for_1f1b()
self._task_1f1b()
elif self._mode == "F-Then-B":
raise NotImplementedError("F-Then-B has not been implemented")
elif self._mode == "stream":
......@@ -103,6 +118,93 @@ class PipelinePass(PassBase):
block._sync_with_cpp()
def _insert_sync_ops_for_1f1b(self):
"""
This implementation largely follows Paddle/python/paddle/fluid/optimizer.py.
The difference between this function and 'PipelineOptimizer' is that the
'send_v2' and 'recv_v2' ops have already been inserted into the program by 'reshard'.
"""
for block in self._program.blocks:
offset = 0
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if is_optimize_op(op):
first_optimize_index = index
break
# insert sync ops
for index, op in enumerate(list(block.ops)):
if op.type == 'send_v2':
# step1: set 'use_calc_stream' to False
op._set_attr("use_calc_stream", False)
op_role = op.attr('op_role')
ring_id = op.attr('ring_id')
# step2: insert 'c_sync_calc_stream' op before 'send_v2' op
var_name = op.input_arg_names[0]
var = block.var(var_name)
block._insert_op_without_sync(
index=index + offset,
type="c_sync_calc_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': op_role},
)
offset += 1
# step3: insert 'c_sync_comm_stream' op after 'send_v2' op or
# before the first optimize op
if int(op_role) == int(OpRole.Backward):
index = first_optimize_index + offset
new_op_role = OpRole.Optimize
else:
index = index + offset + 1
new_op_role = OpRole.Backward
sync_comm_op = block._insert_op_without_sync(
index=index,
type="c_sync_comm_stream",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
'op_role': new_op_role,
'ring_id': ring_id,
},
)
# step4: if the 'send_v2' op is in the forward pass, set 'pipeline_flag' to mark
# that this 'c_sync_comm_stream' op was inserted for the pipeline.
if int(op_role) == int(OpRole.Forward):
sync_comm_op._set_attr('pipeline_flag', '')
offset += 1
block._sync_with_cpp()
offset = 0
backward_recv_index = None
for index, op in enumerate(block.ops):
if op.type == "recv_v2" and is_backward_op(op):
backward_recv_index = index
break
if backward_recv_index is None:
continue
# replace 'c_sync_comm_stream' op with 'nop' op
for index, op in enumerate(list(block.ops)):
if index >= backward_recv_index:
break
if op.type == 'c_sync_comm_stream' and op.has_attr(
'pipeline_flag'
):
var_name = op.output_arg_names[0]
var = block.var(var_name)
block._remove_op(index + offset, sync=False)
offset -= 1
block._insert_op_without_sync(
index=backward_recv_index,
type="nop",
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={'op_role': OpRole.Backward},
)
block._sync_with_cpp()
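In short: every 'send_v2' is moved off the calc stream, guarded by a 'c_sync_calc_stream' right before it, and followed by a 'c_sync_comm_stream' either immediately (forward) or just before the first optimize op (backward); forward-phase syncs that precede the first backward 'recv_v2' are later replaced by 'nop' ops at that recv's position. A toy sketch of the first part on a flat list of op types, for illustration only:

def insert_sync_ops(op_types, first_optimize_index):
    """Toy model of the op order after the 1F1B sync insertion."""
    result, deferred_comm_syncs = [], 0
    for i, op in enumerate(op_types):
        if i == first_optimize_index:
            # backward sends flush their comm stream before the optimizer runs
            result.extend(["c_sync_comm_stream"] * deferred_comm_syncs)
            deferred_comm_syncs = 0
        if op == "send_v2(forward)":
            result += ["c_sync_calc_stream", op, "c_sync_comm_stream"]
        elif op == "send_v2(backward)":
            result += ["c_sync_calc_stream", op]
            deferred_comm_syncs += 1
        else:
            result.append(op)
    return result

ops = ["matmul", "send_v2(forward)", "matmul_grad", "send_v2(backward)", "sgd"]
print(insert_sync_ops(ops, first_optimize_index=4))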
def _create_param(self, dst_block, src_var):
copied_kwargs = {}
copied_kwargs['trainable'] = src_var.trainable
......@@ -190,16 +292,185 @@ class PipelinePass(PassBase):
break
return pp_idx
def _task_stream(self):
cur_rank = int(os.getenv("PADDLE_TRAINER_ID", 0))
trainer_endpoints = os.getenv("PADDLE_TRAINER_ENDPOINTS", "").split(',')
nrank = len(trainer_endpoints)
num_of_functionality = 5
def _task_1f1b(self):
# create lr, fwd, bwd and opt programs by splitting ops on op_role
num_of_functionality = 4
lr_prog = Program()
fwd_prog = Program()
bwd_prog = Program()
opt_prog = Program()
for idx, src_block in enumerate(self._program.blocks):
if idx == 0:
lr_block = lr_prog.block(0)
fwd_block = fwd_prog.block(0)
bwd_block = bwd_prog.block(0)
opt_block = opt_prog.block(0)
else:
lr_block = lr_prog._create_block(
parent_idx=src_block.parent_idx
)
fwd_block = fwd_prog._create_block(
parent_idx=src_block.parent_idx
)
bwd_block = bwd_prog._create_block(
parent_idx=src_block.parent_idx
)
opt_block = opt_prog._create_block(
parent_idx=src_block.parent_idx
)
lr_block._set_forward_block_idx(src_block.forward_block_idx)
fwd_block._set_forward_block_idx(src_block.forward_block_idx)
bwd_block._set_forward_block_idx(src_block.forward_block_idx)
opt_block._set_forward_block_idx(src_block.forward_block_idx)
# split the program based on the op_role
for op in src_block.ops:
if is_lr_sched_op(op):
self._create_program(src_block, lr_block, op)
if is_forward_op(op) or is_fillconst_op_for_micro_batch(op):
self._create_program(src_block, fwd_block, op)
elif is_backward_op(op):
self._create_program(src_block, bwd_block, op)
elif is_optimize_op(op):
self._create_program(src_block, opt_block, op)
else:
raise ValueError(
"The op role: "
+ str(op.attr('op_role'))
+ " isn't one of LRSched, Forward, Backward or Optimizer."
)
# compute current pp stage
pp_stages = len(self._dist_context.process_meshes)
cur_pp_stage = self._get_pp_stage(cur_rank)
lr_prog._sync_with_cpp()
fwd_prog._sync_with_cpp()
bwd_prog._sync_with_cpp()
opt_prog._sync_with_cpp()
lr_prog._rollback()
fwd_prog._rollback()
bwd_prog._rollback()
opt_prog._rollback()
# Create task nodes.
lr_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=lr_prog,
task_id=int(self._cur_rank * num_of_functionality + 0),
node_type="Amplifier",
lazy_initialize=True,
)
lr_task_node.set_run_pre_steps(self._acc_steps)
fwd_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=fwd_prog,
task_id=int(self._cur_rank * num_of_functionality + 1),
node_type="Compute",
lazy_initialize=True,
)
bwd_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=bwd_prog,
task_id=int(self._cur_rank * num_of_functionality + 2),
node_type="Compute",
lazy_initialize=True,
)
opt_task_node = TaskNode(
rank=self._cur_rank,
max_run_times=self._acc_steps,
program=opt_prog,
task_id=int(self._cur_rank * num_of_functionality + 3),
node_type="Amplifier",
lazy_initialize=True,
)
opt_task_node.set_run_pre_steps(self._acc_steps)
opt_task_node.set_run_at_offset(self._acc_steps - 1)
task_nodes = {
"lr": lr_task_node,
"fwd": fwd_task_node,
"bwd": bwd_task_node,
"opt": opt_task_node,
}
# get upstream ranks and downstream ranks of cur_rank
up_down_streams = self._dist_context.up_down_streams
pp_upstream = up_down_streams.ups(self._cur_rank)
pp_downstream = up_down_streams.downs(self._cur_rank)
# set upstream/downstream for task_nodes of cur_rank
for i, (task_role, task_node) in enumerate(task_nodes.items()):
cur_id = int(self._cur_rank * num_of_functionality + i)
ups = []
downs = []
# set upstream/downstream and buffer size within the pipeline stage
pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
prev_id = cur_id - 1
next_id = cur_id + 1
if task_role != "lr":
buf_size = pp_buff_size if task_role == "bwd" else 2
ups.append((prev_id, buf_size))
if task_role != "opt":
buf_size = pp_buff_size if task_role == "fwd" else 2
downs.append((next_id, buf_size))
# set upstream/downstream and buffer size across pipeline stages
for upstream in pp_upstream:
upstream_id = int(upstream * num_of_functionality + i)
if task_role == "fwd":
if upstream != -1:
ups.append((upstream_id, 2))
elif task_role == "bwd":
if upstream != -1:
downs.append((upstream_id, 2))
for downstream in pp_downstream:
downstream_id = int(downstream * num_of_functionality + i)
if task_role == "fwd":
if downstream != -1:
downs.append((downstream_id, 2))
elif task_role == "bwd":
if downstream != -1:
ups.append((downstream_id, 2))
for up in ups:
print(
"Task:",
cur_id,
"'s upstream includes:",
up[0],
", buffer size is:",
up[1],
)
task_node.add_upstream_task(up[0], up[1])
for down in downs:
print(
"Task:",
cur_id,
"'s downstream includes:",
down[0],
", buffer size is:",
down[1],
)
task_node.add_downstream_task(down[0], down[1])
# record global message: task_id_to_rank
task_id_to_rank = {}
for i in range(self._nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
self._program._pipeline_opt = {}
self._program._pipeline_opt['fleet_opt'] = {
"tasks": list(task_nodes.values()),
"task_id_to_rank": task_id_to_rank,
"num_micro_batches": self._acc_steps,
}
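Each rank therefore owns four task nodes (lr, fwd, bwd, opt) with consecutive ids rank * 4 .. rank * 4 + 3, and task_id_to_rank simply inverts that layout. A small sketch for an assumed two-rank pipeline:

num_of_functionality = 4  # lr, fwd, bwd, opt
nrank = 2                 # hypothetical: two pipeline ranks

task_id_to_rank = {
    rank * num_of_functionality + j: rank
    for rank in range(nrank)
    for j in range(num_of_functionality)
}
print(task_id_to_rank)  # {0: 0, 1: 0, 2: 0, 3: 0, 4: 1, 5: 1, 6: 1, 7: 1}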
def _task_stream(self):
num_of_functionality = 5
start_prog = Program()
cond_prog = Program()
end_prog = Program()
......@@ -207,6 +478,7 @@ class PipelinePass(PassBase):
recv_prog = Program()
cond_var_name = None
# record the varnames related to the while cond vars and communicate by nccl
send_vars_name = set()
recv_vars_name = dict()
for ib, src_block in enumerate(self._program.blocks):
......@@ -231,38 +503,23 @@ class PipelinePass(PassBase):
src_block, end_block, op, force_create=True
)
elif ib == 1:
# NOTE: The while block will be split into two separate blocks.
# The send_block:
# includes all ops about transformer generation,
# excluding the nccl ops about the while cond var.
# The recv_block:
# includes all ops about the while cond var,
# excluding the nccl ops about the while cond var.
# For the nccl ops about the cond var:
# put their varnames in the task node and communicate them by brpc.
send_block = send_prog.block(0)
recv_block = recv_prog.block(0)
is_after_send_op = False
is_after_recv_op = False
for op in src_block.ops:
for i, op in enumerate(src_block.ops):
if op.type == "send_v2" and not is_after_send_op:
is_after_send_op = True
if cur_pp_stage == pp_stages - 1:
if op.type in ["c_sync_calc_stream", "nop"]:
continue
if (
op.type not in ["recv_2", "assign"]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
):
if (
len(op.desc.input_arg_names()) > 0
and "@RESHARD"
not in op.desc.input_arg_names()[0]
):
send_vars_name.add(
op.desc.input_arg_names()[0]
)
continue
if op.type == "send_v2":
continue
self._create_program(
src_block, send_block, op, force_create=True
)
continue
if (
is_after_send_op
......@@ -270,45 +527,21 @@ class PipelinePass(PassBase):
and op.type == "recv_v2"
):
is_after_recv_op = True
if op.has_attr(
'op_namescope'
) and "/auto_parallel/reshard" in op.attr(
'op_namescope'
):
var_name = op.desc.output_arg_names()[0]
index = var_name.find("@")
if index > 0:
old_var_name = var_name[:index]
else:
old_var_name = var_name
recv_vars_name[var_name] = old_var_name
if not src_block._find_var_recursive(old_var_name):
src_var = src_block._var_recursive(var_name)
recv_block.create_var(
type=src_var.type,
name=old_var_name,
shape=src_var.shape,
dtype=src_var.dtype,
lod_level=src_var.lod_level,
persistable=src_var.persistable,
error_clip=src_var.error_clip,
stop_gradient=src_var.stop_gradient,
is_data=src_var.is_data,
belong_to_optimizer=src_var.belong_to_optimizer,
)
continue
self._create_program(
src_block, recv_block, op, force_create=True
)
continue
if not is_after_send_op or not is_after_recv_op:
if cur_pp_stage == pp_stages - 1:
if op.type in ["c_sync_calc_stream", "nop"]:
if self._cur_pp_stage == self._pp_stages - 1:
# the c_sync_calc_stream about c_allgather cannot be removed
if (
op.type == "c_sync_calc_stream"
and src_block.ops[i + 1].type == "send_v2"
):
continue
if op.type == "nop":
continue
# HACKCODE: the varnames of the send_v2 and cast ops should be recorded for brpc comm
if (
op.type not in ["recv_2", "assign"]
op.type
not in ["recv_2", "assign", "c_allgather"]
and op.has_attr('op_namescope')
and "/auto_parallel/reshard"
in op.attr('op_namescope')
......@@ -327,13 +560,16 @@ class PipelinePass(PassBase):
self._create_program(
src_block, send_block, op, force_create=True
)
continue
if is_after_send_op and is_after_recv_op:
# HACKCODE: the varnames of the recv_v2 and assign ops should be recorded for brpc comm
if op.has_attr(
'op_namescope'
) and "/auto_parallel/reshard" in op.attr(
'op_namescope'
):
# remove the suffix of "@RESHARD"
var_name = op.desc.output_arg_names()[0]
index = var_name.find("@")
if index > 0:
......@@ -365,6 +601,7 @@ class PipelinePass(PassBase):
self._create_program(
src_block, recv_block, op, force_create=True
)
continue
else:
raise Exception("Only support generation condition.")
......@@ -406,52 +643,52 @@ class PipelinePass(PassBase):
vars_to_shape = recv_task_node_var_shape
start_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Start",
task_id=int(cur_rank * num_of_functionality + 0),
task_id=int(self._cur_rank * num_of_functionality + 0),
program=start_prog,
lazy_initialize=True,
)
cond_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Cond",
task_id=int(cur_rank * num_of_functionality + 1),
task_id=int(self._cur_rank * num_of_functionality + 1),
program=cond_prog,
cond_var_name=cond_var_name,
lazy_initialize=True,
)
send_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(cur_rank * num_of_functionality + 2),
task_id=int(self._cur_rank * num_of_functionality + 2),
program=send_prog,
lazy_initialize=True,
)
recv_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(cur_rank * num_of_functionality + 3),
task_id=int(self._cur_rank * num_of_functionality + 3),
program=recv_prog,
lazy_initialize=True,
vars_to_dtype=vars_to_dtype,
vars_to_shape=vars_to_shape,
)
end_task_node = TaskNode(
rank=cur_rank,
rank=self._cur_rank,
max_run_times=self._acc_steps,
node_type="Compute",
task_id=int(cur_rank * num_of_functionality + 4),
task_id=int(self._cur_rank * num_of_functionality + 4),
program=end_prog,
lazy_initialize=True,
)
# add dependencies for task nodes intra stage
inf = -1
pp_buff_size = int(pp_stages - cur_pp_stage)
pp_buff_size = int(self._pp_stages - self._cur_pp_stage)
start_task_node.add_downstream_task(
cond_task_node.task_id(), self._gen_bsz
)
......@@ -560,12 +797,12 @@ class PipelinePass(PassBase):
# add dependencies for task nodes inter stage
# get upstream ranks and downstream ranks of cur_rank
up_down_streams = self._dist_context.up_down_streams
pp_upstream_ranks = up_down_streams.ups(cur_rank)
pp_downstream_ranks = up_down_streams.downs(cur_rank)
pp_upstream = up_down_streams.ups(self._cur_rank)
pp_downstream = up_down_streams.downs(self._cur_rank)
for upstream_rank in pp_upstream_ranks:
for upstream_rank in pp_upstream:
upstream_pp_stage = self._get_pp_stage(upstream_rank)
if upstream_pp_stage < pp_stages - 1:
if upstream_pp_stage < self._pp_stages - 1:
upstream_task_id = int(upstream_rank * num_of_functionality + 2)
send_task_node.add_upstream_task(upstream_task_id)
print(
......@@ -587,8 +824,8 @@ class PipelinePass(PassBase):
", buffer size is:",
2,
)
for downstream_rank in pp_downstream_ranks:
if cur_pp_stage < pp_stages - 1:
for downstream_rank in pp_downstream:
if self._cur_pp_stage < self._pp_stages - 1:
downstream_task_id = int(
downstream_rank * num_of_functionality + 2
)
......@@ -616,7 +853,7 @@ class PipelinePass(PassBase):
)
task_id_to_rank = {}
for i in range(nrank):
for i in range(self._nrank):
for j in range(num_of_functionality):
task_id_to_rank[int(i * num_of_functionality + j)] = i
self._program._pipeline_opt = {
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random
import numpy as np
import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
def apply_pass(use_1f1b=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_1f1b:
pipeline = strategy.pipeline
pipeline.enable = True
pipeline.schedule_mode = "1F1B"
pipeline.accumulate_steps = 2
else:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 2
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
def reset_prog():
paddle.fluid.framework.switch_main_program(paddle.static.Program())
paddle.fluid.framework.switch_startup_program(paddle.static.Program())
class Test1F1BPass(unittest.TestCase):
def setUp(self):
self.rtol = 1e-5
self.atol = 1e-8
self.batch_size = 2
self.batch_num = 10
self.clip_norm = 0.2
self.dataset = FakeDataset(self.batch_size * self.batch_num)
def init(self, engine):
paddle.seed(2021)
np.random.seed(2021)
random.seed(2021)
paddle.distributed.fleet.init(is_collective=True)
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
engine._executor = paddle.static.Executor(place)
def get_engine(self, use_1f1b=False):
reset_prog()
strategy = apply_pass(use_1f1b)
clip = paddle.nn.ClipGradByGlobalNorm(self.clip_norm)
opt = paddle.optimizer.AdamW(learning_rate=0.00001, grad_clip=clip)
model, loss = generate_model("pp")
engine = auto.Engine(model, loss, opt, strategy=strategy)
self.init(engine)
return engine
def check_results(self, ref_losses, check_losses):
np.testing.assert_allclose(
ref_losses,
check_losses,
rtol=self.rtol,
atol=self.atol,
err_msg='pass {} has wrong results!\nu={}\nv={}\ndiff={}'.format(
__class__, ref_losses, check_losses, ref_losses - check_losses
),
)
def test_1f1b_pass(self):
# naive_pp + gradient_merge training
engine_pp = self.get_engine()
history_pp = engine_pp.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_pp._strategy.pipeline.enable == False
# pp2 1F1B training
engine_1f1b = self.get_engine(True)
history_1f1b = engine_1f1b.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert engine_1f1b._strategy.pipeline.enable == True
# NOTE: every sample in the dataset is identical
if paddle.distributed.get_rank() == 1:
losses_pp = np.array(history_pp.history["loss"])
losses_1f1b = np.array(history_1f1b.history["loss"])
self.check_results(losses_pp, losses_1f1b)
if __name__ == "__main__":
unittest.main()
......@@ -69,6 +69,9 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_engine_callbacks MODULES test_engine_callbacks)
set_tests_properties(test_engine_callbacks
PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE" TIMEOUT 50)
py_test_modules(test_pass_1F1B MODULES test_pass_1F1B)
set_tests_properties(test_pass_1F1B PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE"
TIMEOUT 50)
py_test_modules(test_parallel_tuner MODULES test_parallel_tuner ENVS
${dist_ENVS})
......
......@@ -89,6 +89,12 @@ def generate_model(strategy, dropout_prob=0.0):
modeling._global_parallel_strategy = "mp"
elif strategy == "dp":
modeling._global_parallel_strategy = "dp"
elif strategy == "pp":
modeling._global_parallel_strategy = "pp"
modeling.PP_MESH_LIST = [
auto.ProcessMesh(mesh=[0]),
auto.ProcessMesh(mesh=[1]),
]
else:
raise ValueError("Only support serial, mp2 and dp2.")
......@@ -108,6 +114,7 @@ def generate_model(strategy, dropout_prob=0.0):
eos_token_id=7,
bos_token_id=0,
eol_token_id=3,
pp_degree=2 if strategy == "pp" else None,
)
model = GPTForPretraining(
gpt, vocab_size=1000, hidden_size=64, initializer_range=0.02
......
......@@ -19,7 +19,7 @@ import paddle
from paddle.distributed.fleet import auto
from paddle.fluid.dygraph.parallel import ParallelEnv
from get_gpt_model import generate_model, create_data_holder, FakeDataset
from get_gpt_model import generate_model, FakeDataset
paddle.enable_static()
......@@ -28,12 +28,25 @@ def apply_pass(use_gradient_merge=False):
strategy = auto.Strategy()
strategy.auto_mode = "semi"
strategy.reinit = True
if use_gradient_merge:
gradient_merge = strategy.gradient_merge
gradient_merge.enable = True
gradient_merge.k_steps = 4
gradient_merge.avg = True
amp = strategy.amp
amp.enable = True
amp.custom_white_list = ['softmax', 'layer_norm', 'gelu']
amp.custom_black_list = [
'c_softmax_with_cross_entropy',
'elementwise_div',
'reduce_sum',
]
amp.init_loss_scaling = 32768
amp.use_fp16_guard = False
amp.use_pure_fp16 = True
return strategy
......@@ -88,6 +101,7 @@ class TestGradientMergePass(unittest.TestCase):
history = dp_engine.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert dp_engine._strategy.gradient_merge.enable == False
dp_losses = np.array(history.history["loss"])
# dp2 gradient merge training
......@@ -95,6 +109,7 @@ class TestGradientMergePass(unittest.TestCase):
history = gm_engine.fit(
self.dataset, 3, batch_size=self.batch_size, log_freq=1
)
assert gm_engine._strategy.gradient_merge.enable == True
gm_losses = np.array(history.history["loss"])
# avg_loss = 0
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import tempfile
import unittest
import os
import sys
import shutil
import subprocess
from paddle.distributed.fleet.launch_utils import run_with_coverage
class Test1F1BPass(unittest.TestCase):
def test_pp2(self):
file_dir = os.path.dirname(os.path.abspath(__file__))
launch_model_path = os.path.join(file_dir, "1F1B_pass_unittest.py")
if os.environ.get("WITH_COVERAGE", "OFF") == "ON":
coverage_args = ["-m", "coverage", "run", "--branch", "-p"]
else:
coverage_args = []
tmp_dir = tempfile.TemporaryDirectory()
cmd = (
[sys.executable, "-u"]
+ coverage_args
+ [
"-m",
"paddle.distributed.launch",
"--devices",
"0,1",
"--log_dir",
tmp_dir.name,
launch_model_path,
]
)
process = subprocess.Popen(cmd)
process.wait()
self.assertEqual(process.returncode, 0)
tmp_dir.cleanup()
if __name__ == "__main__":
unittest.main()