未验证 提交 c3974d0e 编写于 作者: L lilong12 提交者: GitHub

[3D-parallel] Reformat pipeline parallel (#31786)

* update, test=develop
上级 01aa2526
......@@ -39,13 +39,13 @@ void SectionWorker::RunForward(
int op_role = op->Attr<int>(std::string("op_role"));
// We run op with op_role = kLRSched only for the first microbatch
// to avoid increasing the @LR_DECAY_STEP@ multiple times.
bool run_first_mbatch = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)) ||
op_role == static_cast<int>(OpRole::kLRSched);
bool run_others = op_role == static_cast<int>(OpRole::kForward) ||
op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss));
bool run_first_mbatch = (op_role == static_cast<int>(OpRole::kForward)) ||
(op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss))) ||
(op_role == static_cast<int>(OpRole::kLRSched));
bool run_others = (op_role == static_cast<int>(OpRole::kForward)) ||
(op_role == (static_cast<int>(OpRole::kForward) |
static_cast<int>(OpRole::kLoss)));
if ((micro_id == 0 && run_first_mbatch) || (micro_id != 0 && run_others)) {
VLOG(3) << "Forward: running op " << op->Type() << " for micro-batch "
<< micro_id;
......@@ -64,9 +64,9 @@ void SectionWorker::RunBackward(
&unused_vars_) {
for (auto &op : ops_) {
int op_role = op->Attr<int>(std::string("op_role"));
if (op_role == static_cast<int>(OpRole::kBackward) ||
op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss))) {
if ((op_role == static_cast<int>(OpRole::kBackward)) ||
(op_role == (static_cast<int>(OpRole::kBackward) |
static_cast<int>(OpRole::kLoss)))) {
VLOG(3) << "Backward: running op " << op->Type() << " for micro-batch "
<< micro_id;
op->Run(*microbatch_scopes_[micro_id], place_);
......
......@@ -47,7 +47,7 @@ def is_optimizer_op(op):
class CollectiveHelper(object):
def __init__(self, role_maker, nrings=1, wait_port='6174'):
def __init__(self, role_maker, nrings=1, wait_port=True):
self.nrings = nrings
self.wait_port = wait_port
self.role_maker = role_maker
......@@ -65,14 +65,48 @@ class CollectiveHelper(object):
self.role_maker._worker_index(), ring_id, self.wait_port)
self._broadcast_params()
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
def _init_communicator(self,
program,
current_endpoint,
endpoints,
rank,
ring_id,
wait_port,
global_ring_id=None,
sync=True):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
def _add_sync_by_allreduce(block):
sync_var = block.create_var(
name=unique_name.generate('sync_var'),
dtype=core.VarDesc.VarType.INT32,
persistable=False,
stop_gradient=True)
block.append_op(
type='fill_constant',
inputs={},
outputs={'Out': [sync_var]},
attrs={
'shape': [1],
'dtype': sync_var.dtype,
'value': 1,
'force_cpu': False,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_allreduce_sum',
inputs={'X': [sync_var]},
outputs={'Out': [sync_var]},
attrs={
'ring_id': global_ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Forward
})
block = program.global_block()
if core.is_compiled_with_cuda():
comm_id_var = block.create_var(
......@@ -128,6 +162,7 @@ class CollectiveHelper(object):
raise ValueError(
"comm_id must be generated in paddlepaddle-xpu or paddlepaddle-xpu."
)
if sync: _add_sync_by_allreduce(block)
def _wait(self, current_endpoint, endpoints):
assert (self.wait_port)
......
......@@ -19,130 +19,21 @@ from paddle.fluid import core, unique_name
from ..base.private_helper_function import wait_server_ready
from paddle.fluid.optimizer import PipelineOptimizer as PO
from .meta_optimizer_base import MetaOptimizerBase
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_update_op, is_loss_grad_op, is_backward_op, is_optimizer_op
def _get_node_num(endpoints):
ss = set()
for ep in endpoints:
ip = ep.split(":")[0].strip()
if ip not in ss:
ss.add(ip)
return len(ss)
class PipelineHelper(object):
def __init__(self, role_maker, wait_port='6174'):
self.wait_port = wait_port
self.role_maker = role_maker
def update_startup_program(self,
startup_program=None,
inner_parallelism=None):
self.startup_program = startup_program
nranks = self.role_maker._worker_num()
rank = self.role_maker._worker_index()
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[rank]
node_num = _get_node_num(endpoints)
assert nranks % node_num == 0
# Create ring 0 for all gpus in the same pipeline
if inner_parallelism > 1:
pipeline_rank = rank % inner_parallelism
pipeline_id = rank // inner_parallelism
start_index = pipeline_id * inner_parallelism
pipeline_endpoints = endpoints[start_index:start_index +
inner_parallelism]
self._init_communicator(self.startup_program, current_endpoint,
pipeline_endpoints, pipeline_rank, 0,
self.wait_port)
pipeline_num = len(endpoints) // inner_parallelism
if pipeline_num == 1: return
# Create rings for gpus with the same pipeline id for data parallel
eps = []
pipeline_rank = rank % inner_parallelism
ring_id = pipeline_rank + 1
for i in range(pipeline_num):
eps.append(endpoints[i * inner_parallelism + pipeline_rank])
# rank in a ring of gpus with the same pipeline id for data parallel
dp_rank = rank // inner_parallelism
self._init_communicator(self.startup_program, current_endpoint, eps,
dp_rank, ring_id, self.wait_port)
self._broadcast_params(ring_id)
def _init_communicator(self, program, current_endpoint, endpoints, rank,
ring_id, wait_port):
nranks = len(endpoints)
other_endpoints = endpoints[:]
other_endpoints.remove(current_endpoint)
if rank == 0 and wait_port:
wait_server_ready(other_endpoints)
block = program.global_block()
nccl_id_var = block.create_var(
name=unique_name.generate('nccl_id'),
persistable=True,
type=core.VarDesc.VarType.RAW)
block.append_op(
type='c_gen_nccl_id',
inputs={},
outputs={'Out': nccl_id_var},
attrs={
'rank': rank,
'endpoint': current_endpoint,
'other_endpoints': other_endpoints,
OP_ROLE_KEY: OpRole.Forward,
})
block.append_op(
type='c_comm_init',
inputs={'X': nccl_id_var},
outputs={},
attrs={
'nranks': nranks,
'rank': rank,
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward,
})
def _broadcast_params(self, ring_id):
block = self.startup_program.global_block()
for var_name in block.vars:
if "nccl_id" in var_name: continue
param = block.var(var_name)
if not param.persistable:
continue
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_loss_grad_op, is_backward_op, is_optimizer_op
class PipelineOptimizer(MetaOptimizerBase):
def __init__(self, optimizer):
super(PipelineOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
# we do not allow meta optimizer to be inner optimizer currently
self.meta_optimizers_white_list = [
"RecomputeOptimizer",
"AMPOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self.global_ring_id = 1
self.dp_ring_id = 2
self.start_pipeline_ring_id = 20 # Just a magic number
def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
user_defined_strategy):
......@@ -165,7 +56,11 @@ class PipelineOptimizer(MetaOptimizerBase):
def _disable_strategy(self, dist_strategy):
dist_strategy.pipeline = False
dist_strategy.pipeline_configs = {}
dist_strategy.pipeline_configs = {
"micro_batch_size": 1,
"accumulate_steps": 1,
"schedule_mode": "1F1B",
}
def _enable_strategy(self, dist_strategy, context):
dist_strategy.pipeline = True
......@@ -175,61 +70,134 @@ class PipelineOptimizer(MetaOptimizerBase):
"schedule_mode": "1F1B",
}
def _broadcast_params(self, ring_id):
block = self.startup_program.global_block()
param = None
for param in block.iter_parameters():
if param.is_distributed:
continue
block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
attrs={
'ring_id': ring_id,
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
if not param: return # no parameter on this device
block.append_op(
type='c_sync_comm_stream',
inputs={'X': param},
outputs={'Out': param},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Forward})
def _get_process_group_info(self):
# global ring info
self.global_endpoints = self.endpoints
self.global_rank = self.rank
self.global_nranks = self.nranks
# data parallel ring info
if self.pipeline_num > 1:
self.dp_rank = self.rank // self.inner_parallelism
self.dp_nranks = self.nranks // self.inner_parallelism
start_index = self.rank % self.inner_parallelism
self.dp_endpoints = [
self.endpoints[start_index + i * self.inner_parallelism]
for i in range(self.pipeline_num)
]
def _init_process_group(self, pipeline_pair, pipeline_ring_map):
self._get_process_group_info()
collective_helper = CollectiveHelper(self.role_maker, wait_port=False)
# Create global ring for all gpus (ring_id = 0)
collective_helper._init_communicator(
self.startup_program, self.current_endpoint, self.global_endpoints,
self.global_rank, self.global_ring_id, True, self.global_ring_id,
True)
# Create pipeline rings
if self.inner_parallelism > 1:
pipeline_id = self.rank // self.inner_parallelism
start_index = pipeline_id * self.inner_parallelism
for pair in pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = pipeline_ring_map[pair_key]
assert ring_id >= self.start_pipeline_ring_id
first_node = pair[0] + start_index
second_node = pair[1] + start_index
if self.rank != first_node and self.rank != second_node:
continue
pipeline_endpoints = [
self.endpoints[first_node], self.endpoints[second_node]
]
pipeline_rank = 0 if self.rank == first_node else 1
pipeline_nranks = 2
collective_helper._init_communicator(
self.startup_program, self.current_endpoint,
pipeline_endpoints, pipeline_rank, ring_id, False,
self.global_ring_id, True)
# Create dp rings
if self.pipeline_num > 1:
collective_helper._init_communicator(
self.startup_program, self.current_endpoint, self.dp_endpoints,
self.dp_rank, self.dp_ring_id, True, self.global_ring_id, True)
self._broadcast_params(self.dp_ring_id)
def minimize_impl(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
endpoints = self.role_maker._get_trainer_endpoints()
current_endpoint = endpoints[self.role_maker._worker_index()]
self.wrapped_opt = PO(self.inner_opt,
num_microbatches=self.num_microbatches)
node_num = _get_node_num(endpoints)
gpus_per_node = len(endpoints) // node_num
self.startup_program = startup_program
if startup_program is None:
self.startup_program = fluid.default_startup_program()
self.endpoints = self.role_maker._get_trainer_endpoints()
self.current_endpoint = self.endpoints[self.role_maker._worker_index()]
self.rank = self.role_maker._worker_index()
self.nranks = self.role_maker._worker_num()
assert self.nranks % node_num == 0
loss.block.program._pipeline_opt = dict()
loss.block.program._pipeline_opt['local_rank'] = self.rank
loss.block.program._pipeline_opt[
'micro_batch_size'] = self.micro_batch_size
loss.block.program._pipeline_opt['schedule_mode'] = self.schedule_mode
optimize_ops, params_grads, prog_list = self.wrapped_opt.minimize(
self.wrapped_opt = PO(self.inner_opt,
num_microbatches=self.num_microbatches)
orig_startup_program = startup_program if startup_program else fluid.default_startup_program(
)
block = loss.block
program = block.program
program._pipeline_opt = dict()
program._pipeline_opt['local_rank'] = self.rank
program._pipeline_opt['global_ring_id'] = self.global_ring_id
program._pipeline_opt['ring_id'] = self.start_pipeline_ring_id
program._pipeline_opt['micro_batch_size'] = self.micro_batch_size
program._pipeline_opt['schedule_mode'] = self.schedule_mode
optimize_ops, params_grads, prog_list, pp_pair, ring_map = self.wrapped_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
assert prog_list
self.main_program_list = prog_list
self.main_program = loss.block.program
self.inner_parallelism = loss.block.program._pipeline_opt[
'inner_parallelism']
self.startup_program = orig_startup_program._pipeline_opt[
'startup_program']
self.inner_parallelism = program._pipeline_opt['inner_parallelism']
assert self.nranks % self.inner_parallelism == 0
assert prog_list
self.pipeline_num = len(self.endpoints) // self.inner_parallelism
pipeline_helper = PipelineHelper(self.role_maker)
pipeline_helper.update_startup_program(
self.startup_program._pipeline_opt["startup_program"],
self.inner_parallelism)
self._init_process_group(pp_pair, ring_map)
pipeline_num = self.nranks // self.inner_parallelism
self._transpile_main_program(loss, pipeline_num, self.inner_parallelism)
self.main_program_list = prog_list
self.main_program = program
if self.pipeline_num > 1:
self._transpile_main_program(loss)
return optimize_ops, params_grads
def _transpile_main_program(self, loss, pipeline_num, inner_parallelism):
if pipeline_num <= 1: return
self._insert_loss_grad_ops(loss, pipeline_num)
for ring_id in range(1, inner_parallelism + 1):
self._insert_allreduce_ops(ring_id)
def _transpile_main_program(self, loss):
self._insert_loss_grad_ops(loss, self.pipeline_num)
self._insert_allreduce_ops(self.dp_ring_id)
def _insert_loss_grad_ops(self, loss, pipeline_num):
"""
In order to keep the learning rate consistent in different numbers of
training workers, we scale the loss grad by the number of workers
"""
block = self.main_program_list[-1]['program'].global_block()
block = self.main_program_list[-1].global_block()
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
......@@ -244,57 +212,53 @@ class PipelineOptimizer(MetaOptimizerBase):
})
def _insert_allreduce_ops(self, ring_id):
block = self.main_program_list[ring_id - 1]['program'].global_block()
block = self.main_program._pipeline_opt['section_program'].global_block(
)
origin_block = self.main_program.global_block()
grad = None
processed_param_name = set()
first_optimize_op_idx = None
add_sync_calc_stream = False
for idx, op in reversed(list(enumerate(block.ops))):
if is_backward_op(op) and not first_optimize_op_idx:
first_optimize_op_idx = idx + 1
# no optimize phase
if first_optimize_op_idx == len(block.ops): return
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) == 0:
continue
assert len(op_role_var) % 2 == 0
offset = idx
offset = 0
for i in range(0, len(op_role_var), 2):
param_name = op_role_var[i]
param = block.vars[op_role_var[i]]
if param_name in processed_param_name: continue
processed_param_name.add(param_name)
grad = block.vars[op_role_var[i + 1]]
grad_name = op_role_var[i + 1]
if not 'MERGED' in grad_name: grad_name += '@MERGED'
grad = block.vars[grad_name]
origin_param = origin_block.vars[op_role_var[i]]
if origin_param.is_distributed:
continue
if offset == idx:
offset += 1
if not add_sync_calc_stream:
add_sync_calc_stream = True
block._insert_op(
offset,
first_optimize_op_idx + offset,
type='c_sync_calc_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={OP_ROLE_KEY: OpRole.Backward})
attrs={OP_ROLE_KEY: OpRole.Optimize})
offset += 1
block._insert_op(
offset,
first_optimize_op_idx + offset,
type='c_allreduce_sum',
inputs={'X': grad},
outputs={'Out': grad},
attrs={
'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
if grad is None:
return
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
block._insert_op(
idx,
type='c_sync_comm_stream',
inputs={'X': grad},
outputs={'Out': grad},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
break
......@@ -123,7 +123,8 @@ def _insert_cast_op(block, op, idx, src_dtype, dest_dtype):
outputs={"Out": out_var},
attrs={
"in_dtype": in_var.dtype,
"out_dtype": out_var.dtype
"out_dtype": out_var.dtype,
"op_device": op.attr("op_device")
})
num_cast_ops += 1
_rename_arg(op, in_var.name, out_var.name)
......@@ -171,8 +172,11 @@ def _insert_cast_post_op(block, op, idx, src_dtype, dest_dtype, target_name,
type="cast",
inputs={"X": target_var},
outputs={"Out": cast_var},
attrs={"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype})
attrs={
"in_dtype": target_var.dtype,
"out_dtype": cast_var.dtype,
"op_device": op.attr("op_device")
})
num_cast_ops += 1
op_var_rename_map[block.idx][target_var.name] = cast_var.name
......
......@@ -427,7 +427,7 @@ class Section(DeviceWorker):
section_param.schedule_mode = schedule_mode
cfg = section_param.section_config
program = pipeline_opt["section_program"]
cfg.program_desc.ParseFromString(program["program"]._get_desc()
cfg.program_desc.ParseFromString(program._get_desc()
.serialize_to_string())
# TODO: why does not work
# cfg.program_desc.CopyFrom(program.program._get_desc())
......
......@@ -1458,7 +1458,7 @@ class Executor(object):
dataset._prepare_to_run()
real_fetch_list = []
if program._pipeline_opt:
real_program = program._pipeline_opt["section_program"]['program']
real_program = program._pipeline_opt["section_program"]
for fetch_var in fetch_list:
if isinstance(fetch_var, Variable):
fetch_var_name = fetch_var.name
......@@ -1467,13 +1467,20 @@ class Executor(object):
if fetch_var_name in real_program.global_block().vars:
real_fetch_list.append(fetch_var)
program._pipeline_opt["section_program"][
'program'] = self._add_feed_fetch_ops(
program=program._pipeline_opt["section_program"]['program'],
program._pipeline_opt["section_program"] = self._add_feed_fetch_ops(
program=program._pipeline_opt["section_program"],
feed=[],
fetch_list=real_fetch_list,
feed_var_name='feed',
fetch_var_name='fetch')
main_block = program._pipeline_opt["section_program"].block(0)
for op in main_block.ops:
# set the op_role of fetch op to Optimize to avoid
# erase the fetched vars by gc for pipeline
if op.type == 'fetch':
op._set_attr(
'op_role',
core.op_proto_and_checker_maker.OpRole.Optimize)
fetch_list = None
scope, trainer = self._prepare_trainer(
......
......@@ -3784,6 +3784,12 @@ class PipelineOptimizer(object):
"Optimizer, but the given type is {}.".format(
type(optimizer)))
self._optimizer = optimizer
# Get the original optimizer defined by users, such as SGD
self._origin_optimizer = self._optimizer
while hasattr(self._origin_optimizer, "inner_opt"):
self._origin_optimizer = self._origin_optimizer.inner_opt
assert num_microbatches >= 1, (
"num_microbatches must be a positive value.")
self._num_microbatches = num_microbatches
......@@ -3797,13 +3803,98 @@ class PipelineOptimizer(object):
self._op_role_var_key = op_maker.kOpRoleVarAttrName()
self._op_device_key = op_maker.kOpDeviceAttrName()
self._param_device_map = None
self._pipeline_pair = []
self._pp_ring_map = dict()
self._global_ring_id = None
# insert allreduce op to sync global information for global
# gradient clip and amp
def _insert_allreduce_op(self, op_idx, block):
"""
Insert allreduce op to sync global information for global
gradient clip and amp.
"""
op = block.ops[op_idx]
out_name = op.desc.output_arg_names()[0]
out_var = block.var(out_name)
offset = 0
if op.type == "reduce_any":
# cast the bool var to int32 to use allreduce_max op
temp_var_name = unique_name.generate(out_name + "_cast_int32")
temp_var = block.create_var(
name=temp_var_name, shape=[1], dtype="int32")
block._insert_op(
op_idx + 1 + offset,
type='cast',
inputs={'X': out_var},
outputs={'Out': temp_var},
attrs={
'in_dtype': out_var.dtype,
'out_dtype': temp_var.dtype,
self._op_role_key: self._op_role.Optimize
})
offset += 1
block._insert_op(
op_idx + 1 + offset,
type='c_allreduce_max'
if op.type == "reduce_any" else 'c_allreduce_sum',
inputs={'X': temp_var if op.type == "reduce_any" else out_var},
outputs={'Out': temp_var if op.type == "reduce_any" else out_var},
attrs={
'ring_id': self._global_ring_id,
self._op_role_key: self._op_role.Optimize,
'use_calc_stream': True
})
offset += 1
if op.type == "reduce_any":
block._insert_op(
op_idx + 1 + offset,
type='cast',
inputs={'X': temp_var},
outputs={'Out': out_var},
attrs={
'in_dtype': temp_var.dtype,
'out_dtype': out_var.dtype,
self._op_role_key: self._op_role.Optimize
})
return offset
def _create_vars(self, block, ori_block):
# Create vars for block, copied from main_program's global block
# Create vars for block, copied from ori_block
used_var_set = set()
for op_idx in range(block.desc.op_size()):
op_desc = block.desc.op(op_idx)
vars = op_desc.input_arg_names() + op_desc.output_arg_names()
added_op_num = 0
op_idx = 0
op_size = block.desc.op_size()
while op_idx < op_size + added_op_num:
# Whether to insert allreduce_sum or allreduce_max op.
# For amp and global gradient clip strategies, we should
# get the global information, so allreduce op is needed.
should_insert = False
op = block.ops[op_idx]
# For op process vars on all devices, remove its input
# vars not in this block
reserved_x = []
if op.type == 'reduce_any' and self._is_optimize_op(op):
should_insert = True
elif op.type == 'concat' and self._is_optimize_op(op):
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
elif op.type == 'update_loss_scaling':
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
op.desc.set_output('Out', reserved_x)
elif op.type == 'sum' and self._is_gradient_clip_op(op):
for input_name in op.desc.input("X"):
if block._find_var_recursive(input_name):
reserved_x.append(input_name)
op.desc.set_input('X', reserved_x)
should_insert = True
vars = op.desc.input_arg_names() + op.desc.output_arg_names()
for var in vars:
# a var whose name contains "blocking_queue"
# only exists in startup program
......@@ -3813,27 +3904,39 @@ class PipelineOptimizer(object):
if block._find_var_recursive(str(var)): continue
source_var = ori_block._var_recursive(str(var))
if source_var.type == core.VarDesc.VarType.READER:
block.create_var(
dest_var = block.create_var(
name=var,
type=core.VarDesc.VarType.READER,
persistable=source_var.persistable)
else:
block._clone_variable(source_var, False)
dest_var = block._clone_variable(source_var, False)
dest_var.stop_gradient = source_var.stop_gradient
# When use with sharding, allreduce_sum and allreduce_max
# used for global gradient clip and amp will be added by sharding.
op_idx += 1
if self.use_sharding or not should_insert: continue
inserted_ops = self._insert_allreduce_op(op_idx - 1, block)
added_op_num += inserted_ops
op_idx += inserted_ops
block._sync_with_cpp()
def _is_loss_grad_op(self, op):
if self._op_role_key not in op.attr_names:
return False
op_role = int(op.all_attrs()[self._op_role_key])
assert self._op_role_key in op.attr_names
op_role = int(op.attr(self._op_role_key))
return op_role & int(self._op_role.Backward) and op_role & int(
self._op_role.Loss)
def _is_backward_op(self, op):
return self._op_role_key in op.attr_names and int(op.all_attrs()[
self._op_role_key]) & int(self._op_role.Backward)
return self._op_role_key in op.attr_names and (
int(op.attr(self._op_role_key)) & int(self._op_role.Backward))
def _is_loss_op(self, op):
assert self._op_role_key in op.attr_names
return int(op.attr(self._op_role_key)) == int(self._op_role.Loss)
def _is_optimize_op(self, op):
return self._op_role_key in op.attr_names and int(op.all_attrs()[
self._op_role_key]) & int(self._op_role.Optimize)
return self._op_role_key in op.attr_names and (
int(op.attr(self._op_role_key)) & int(self._op_role.Optimize))
def _is_update_op(self, op):
return 'Param' in op.input_names and 'Grad' in op.input_names and (
......@@ -3842,50 +3945,40 @@ class PipelineOptimizer(object):
def _split_program(self, main_program, devices):
"""
Split a program into sections according to devices that ops run on.
The ops of the role LRSched are copied to all sections.
The op whose op_device attr is "gpu:all" is copied to all sections.
Args:
main_program (Program): the main program
devices: all used devices
"""
programs = []
# Map from device to its corresponding section program info
device_program_map = dict()
for device in devices:
p = {'program': Program()}
device_program_map[device] = p
device_program_map = defaultdict(Program)
block = main_program.block(0)
for op in block.ops:
device = op.attr(self._op_device_key)
op_role = op.attr(self._op_role_key)
if int(op_role) & int(self._op_role.LRSched):
# Copy ops of the role LRSched to all sections.
for device in device_program_map.keys():
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op.copy_from(op_desc)
# ap_op._set_attr(self._op_device_key, "")
elif op.type == "create_py_reader" or op.type == "read" or op.type == "create_double_buffer_reader":
# Copy read related ops to all section to make them exit after each epoch.
for device in device_program_map.keys():
# Copy ops whose op_device set to "gpu:all" to all sections.
if device == "gpu:all":
for device in devices:
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op = program.global_block().desc.append_op()
ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
else:
program = device_program_map[device]
op_desc = op.desc
ap_op = program["program"].block(0).desc.append_op()
ap_op = program.global_block().desc.append_op()
ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
program_list = []
for key in devices:
program = device_program_map[key]
program['program']._sync_with_cpp()
programs.append(program)
program._sync_with_cpp()
program_list.append(program)
return programs
return program_list
def _get_op_device_for_startup_program(self, var_name):
"""
......@@ -3894,21 +3987,22 @@ class PipelineOptimizer(object):
get the real op_device attribute of the fill_constant as the device
where the corresponding parameters on.
"""
assert "beta1_pow_acc" in var_name or "beta2_pow_acc" in var_name
assert "beta1_pow_acc" in var_name or "beta2_pow_acc" in var_name, \
'For accumulators for Adam, the name must contain beta1_pow_acc ' \
'or beta2_pow_acc.'
param_name = var_name[0:var_name.index('_beta')]
device = self._param_device_map[param_name]
return device
def _split_startup_program(self, startup_program, local_rank):
block = startup_program.block(0)
def _split_startup_program(self, startup_program, device_id):
block = startup_program.global_block()
new_startup_program = Program()
for op in block.ops:
device = op.attr(self._op_device_key)
if device == "cpu":
assert op.type == "fill_constant", (
"For ops in startup "
"program that with the op_device attribute of cpu, "
"they must be fill_constant.")
"For ops in startup program with the op_device attribute "
"of cpu, they must be of type fill_constant.")
output_var = op.output_arg_names[0]
device = self._get_op_device_for_startup_program(output_var)
......@@ -3917,14 +4011,13 @@ class PipelineOptimizer(object):
else:
# LR related ops
device = None
if device and device_index != local_rank: continue
if device and device_index != device_id: continue
op_desc = op.desc
ap_op = new_startup_program.block(0).desc.append_op()
ap_op = new_startup_program.global_block().desc.append_op()
ap_op.copy_from(op_desc)
ap_op._set_attr(self._op_device_key, "")
new_startup_program._sync_with_cpp()
self._create_vars(
new_startup_program.block(0), startup_program.global_block())
self._create_vars(new_startup_program.global_block(), block)
return new_startup_program
def _find_post_op(self, ops, cur_op, var_name):
......@@ -3937,6 +4030,11 @@ class PipelineOptimizer(object):
var_name as output.
var_name (string): Variable name.
"""
# To skip the cast op added by amp which has no op_device set
if '.cast_fp32' in var_name:
var_name = var_name.replace('.cast_fp32', '')
elif '.cast_fp16' in var_name:
var_name = var_name.replace('.cast_fp16', '')
post_op = []
before = True
for op in ops:
......@@ -3965,7 +4063,8 @@ class PipelineOptimizer(object):
"""
prev_op = []
for op in ops:
if op.type == 'send_v2' or op.type == 'recv_v2':
if op.type == 'send_v2' or op.type == 'recv_v2' \
or op.type == 'c_broadcast':
continue
if op == cur_op:
break
......@@ -3980,11 +4079,8 @@ class PipelineOptimizer(object):
return None
def _rename_arg(self, op, old_name, new_name):
op_desc = op.desc
if isinstance(op_desc, tuple):
op_desc = op_desc[0]
op_desc._rename_input(old_name, new_name)
op_desc._rename_output(old_name, new_name)
op._rename_input(old_name, new_name)
op._rename_output(old_name, new_name)
def _create_var(self, block, ref_var, name):
"""
......@@ -3998,99 +4094,12 @@ class PipelineOptimizer(object):
dtype=ref_var.dtype,
type=ref_var.type,
lod_level=ref_var.lod_level,
persistable=False,
is_data=False,
persistable=ref_var.persistable,
is_data=ref_var.is_data,
need_check_feed=ref_var.desc.need_check_feed())
new_var.stop_gradient = ref_var.stop_gradient
return new_var
def _get_data_var_info(self, block):
"""
Get info of all vars whose is_data attribute are true.
"""
# map of data vars to devices that that data on
data_devices_map = dict()
for op in block.ops:
dev_spec = op.attr(self._op_device_key)
for var_name in op.input_arg_names:
if "blocking_queue" in var_name: continue
var = block.var(var_name)
if not var.is_data:
continue
if not var_name in data_devices_map:
data_devices_map[var_name] = []
if not dev_spec in data_devices_map[var_name]:
data_devices_map[var_name].append(dev_spec)
return data_devices_map
def _insert_sendrecv_for_data_var(self, main_block, programs, startup,
devices):
"""
Insert send and recv ops for data var that on other devices.
Args:
main_block (Block): Global block for main program
programs (dict): Dictionary for section params
startup (Program): Startup program
devices (list): List of devices in the format (dev:dev_index)
"""
main_program = main_block.program
data_devices_map = self._get_data_var_info(main_block)
first_prog = programs[0]['program']
first_block = first_prog.block(0)
insert_index = 0
for op in first_block.ops:
insert_index += 1
if op.type == "read":
break
first_dev_spec = devices[0]
first_dev_index = int(first_dev_spec.split(':')[1])
for var_name in data_devices_map.keys():
for device in data_devices_map[var_name]:
if device == first_dev_spec: continue
main_var = main_block.var(var_name)
assert main_var.is_data
if not var_name in first_block.vars:
self._create_var(first_block, main_var, var_name)
dev_index = int(device.split(':')[1])
first_block._insert_op(
index=insert_index,
type='send_v2',
inputs={'X': first_block.var(var_name)},
attrs={
self._op_device_key: first_dev_spec,
self._op_role_key: self._op_role.Forward,
'use_calc_stream': True,
'peer': dev_index,
})
# Get the device that that data on
assert device in devices
prog_index = devices.index(device)
prog = programs[prog_index]['program']
block = prog.block(0)
index = 0
for op in block.ops:
index += 1
if op.type == "read":
break
source_var = main_program.block(0).var(var_name)
new_var = self._create_var(block, source_var, var_name)
new_var_shape = list(new_var.shape)
new_var_shape[0] = self.micro_batch_size if new_var_shape[
0] < 0 else new_var_shape[0]
block._insert_op(
index=index,
type='recv_v2',
outputs={'Out': [new_var]},
attrs={
'out_shape': new_var_shape,
'dtype': new_var.dtype,
self._op_device_key: device,
self._op_role_key: self._op_role.Forward,
'peer': first_dev_index,
'use_calc_stream': True,
})
def _strip_grad_suffix(self, name):
"""
Strip the grad suffix from the given variable name
......@@ -4104,95 +4113,161 @@ class PipelineOptimizer(object):
"""
return name + core.grad_var_suffix()
def _add_opdevice_attr_for_regularization_clip(self, block):
def _get_op_device_attr(self, op):
"""
Add op_device attribute for regulization and clip ops.
Get the op_device attribute of a op.
"""
for op in block.ops:
# role for regularization and clip ops is optimize
if int(op.attr(self._op_role_key)) != int(self._op_role.Optimize):
continue
if op.has_attr(self._op_device_key) and (
op.attr(self._op_device_key) != ""):
continue
assert self._op_role_var_key in op.attr_names
op_role_var = op.all_attrs()[self._op_role_var_key]
assert len(op_role_var) == 2
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
device = op.attr(self._op_device_key) \
if op.has_attr(self._op_device_key) else None
if device:
assert device[0:3] == 'gpu', "Now, only gpu devices are " \
"supported in pipeline parallemism."
return device
def _add_default_opdevice_attr(self, block):
def _add_op_device_attr_for_op(self, op, idx, block):
"""
1. Add default op_device attribute for lr-related ops.
The default value is the one that of the first place.
2. Add default op_device attribute for sum ops added during
backward. For these ops, we set the op_device attribute
as the one of its post op, i.e, which op has the output of the
sum op as an input.
Add op_device attrribute for ops that have not that attribute set.
We use "gpu:all" to represent the op should be put on all
sub-programs, such as lr-related ops. Note that: "gpu:all"
is only used by pipeline as an indicator.
"""
first_devcie = ""
# Get the device spec of the first place.
# device_spec: 'cpu' for cpu device and 'gpu:id' for gpu device,
# e.g. 'gpu:0', 'gpu:1', etc.
for op in block.ops:
if op.has_attr(self._op_device_key) and (
op.attr(self._op_device_key) != ""):
first_device = op.attr(self._op_device_key)
break
assert first_device
first_device_type = first_device.split(":")[0]
assert first_device_type == "gpu"
# set op_device attr for lr-related ops
lrsched_role = int(self._op_role.LRSched)
for op in block.ops:
if not op.has_attr(self._op_device_key) or (
op.attr(self._op_device_key) == ""):
if op.type == "sum":
# For sum ops that compute the sum of @RENAMED@ vars
for name in op.desc.input_arg_names():
assert '@RENAME@' in name
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(block.ops, op, out_name)
device = post_op.attr(self._op_device_key)
assert device
if op.attr(self._op_role_key) == lrsched_role:
# For LRSched ops, we should put them on all sub-programs to
# make sure each sub-program update the lr correctly
op._set_attr(self._op_device_key, "gpu:all")
elif (op.type == "cast" or
op.type == "scale") and self._is_backward_op(op):
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
elif op.type == "memcpy" and not self._is_optimize_op(op):
assert len(op.input_arg_names) == 1 and len(
op.output_arg_names) == 1
input_name = op.input_arg_names[0]
output_name = op.output_arg_names[0]
if '@Fetch' in output_name:
post_op = self._find_post_op(block.ops, op, output_name)
op._set_attr(self._op_device_key,
post_op.attr(self._op_device_key))
else:
prev_op = self._find_real_prev_op(block.ops, op,
op.desc.input("X")[0])
op._set_attr(self._op_device_key,
prev_op.attr(self._op_device_key))
elif self._is_loss_op(op):
# For loss * loss_scaling op added by AMP
offset = 1
while (not block.ops[idx + offset].has_attr(self._op_device_key) or
not block.ops[idx + offset].attr(self._op_device_key)):
offset += 1
device = block.ops[idx + offset].attr(self._op_device_key)
assert device, "Please put you program within device_guard scope."
for i in range(offset):
block.ops[idx + i]._set_attr(self._op_device_key, device)
elif self._is_optimize_op(op) and op.type == "check_finite_and_unscale":
op_role_var = op.attr(self._op_role_var_key)
param_name = op_role_var[0]
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
elif self._is_optimize_op(op) and op.type == "cast":
# For fp16-->fp32 cast added by AMP
grad_name = op.output('Out')
assert len(grad_name) == 1
param_name = grad_name[0].strip(core.grad_var_suffix())
device = self._param_device_map[param_name]
op._set_attr(self._op_device_key, device)
elif self._is_gradient_clip_op(op) or self._is_regularization_op(op):
# For gradient clip and regularization ops, we set their op_device
# attribute to the device where their corresponding parameters on.
assert self._op_role_var_key in op.attr_names, "gradient_clip " \
"and regularization ops must have op_role_var attribute."
op_role_var = op.attr(self._op_role_var_key)
assert len(op_role_var) == 2, "op_role_var for gradient_clip " \
"regularization ops must have two elements."
param_name = op_role_var[0]
device = self._param_device_map[param_name]
# For sum op added by global gradient clip, it must be
# put on all devices
if (op.type == 'sum' or op.type == 'sqrt' or
op.type == 'fill_constant' or
op.type == 'elementwise_max' or
op.type == 'elementwise_div'):
device = "gpu:all"
op._set_attr(self._op_device_key, device)
else:
other_known_ops = [
'update_loss_scaling', 'reduce_any', 'concat', 'sum'
]
assert op.type in other_known_ops, "For other ops without " \
"op_device set, they must be one of {}, but it " \
"is {}".format(other_known_ops, op.type)
assert self._is_optimize_op(op)
op._set_attr(self._op_device_key, "gpu:all")
def _add_op_device_attr(self, block):
"""
Add op_device attrribute for ops in block that have
not that attribute set.
"""
for idx, op in enumerate(list(block.ops)):
if (op.type == "create_py_reader" or op.type == "read" or
op.type == "create_double_buffer_reader"):
# Copy read related ops to all section to make them exit
# after each epoch.
# We use "gpu:all" to represent the op should be put on all
# sub-programs, such as lr-related ops. Note that: "gpu:all"
# is only used by pipeline as an indicator.
op._set_attr(self._op_device_key, "gpu:all")
continue
assert op.attr(self._op_role_key) == lrsched_role, (
"Op whose op_device attr has not been set for pipeline"
" must be of the role LRSched.")
op._set_attr(self._op_device_key, first_device)
# op_device attribute has been set
if self._get_op_device_attr(op): continue
self._add_op_device_attr_for_op(op, idx, block)
def _check_validation(self, block):
"""
Check whether ops in a block are all validate (i.e., the
op_device attribute has been set).
Then, return all device specifications in order.
Check whether ops in a block have both the op_device and the
op_role attributes set.
Then, return all devices in order.
"""
device_specs = []
device_list = []
# Section worker only supports the following op_role
valid_op_role_value = [
int(self._op_role.LRSched),
int(self._op_role.Forward),
int(self._op_role.Backward),
int(self._op_role.Loss),
int(self._op_role.Optimize),
int(self._op_role.Backward) | int(self._op_role.Loss),
]
for op in block.ops:
type = op.type
if not op._has_kernel(type):
if not op._has_kernel(op.type):
assert op.type == "conditional_block" and (
op.attr(self._op_role_key) == int(self._op_role.LRSched)), (
"Now, the only supported op without kernel is "
"conditional_block, and its op role must be LRSched.")
assert op.has_attr(self._op_role_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_role_key))
assert int(op.attr(self._op_role_key)) in valid_op_role_value, \
"op_role {} for op {} must be one of {}".format(
op.attr(self._op_role_key),
op.type,
valid_op_role_value)
assert op.has_attr(self._op_device_key), (
"op ({}) has no {} attribute.".format(op.type,
self._op_device_key))
dev_spec = op.attr(self._op_device_key)
assert dev_spec, ("op_device attribute for op "
device = op.attr(self._op_device_key)
assert device, ("op_device attribute for op "
"{} has not been set.".format(op.type))
dev_type = dev_spec.split(':')[0]
if device == "gpu:all": continue
dev_type = device.split(':')[0]
assert dev_type == "gpu", ("Now only gpu devices are supported "
"for pipeline parallelism.")
if not dev_spec in device_specs:
device_specs.append(dev_spec)
return device_specs
if not device in device_list:
device_list.append(device)
return device_list
def _insert_sendrecv_ops_for_boundaries(self, block):
"""
......@@ -4201,49 +4276,105 @@ class PipelineOptimizer(object):
"""
extra_index = 0
# A map from var to device spec where op takes it as input,
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
var_devspec = dict()
var_dev_map = dict()
for index, op in enumerate(list(block.ops)):
# skips lr-related ops and vars, as we will process them later.
if int(op.attr(self._op_role_key)) & int(self._op_role.LRSched):
continue
# skips update ops and vars, as we will process them later.
if self._is_update_op(op): continue
cur_device_spec = op.attr(self._op_device_key)
cur_device = op.attr(self._op_device_key)
if cur_device == "gpu:all": continue
for var_name in op.input_arg_names:
# i.e., lod_tensor_blocking_queue created by DataLoader,
# which only exists in startup program.
if not var_name in block.vars: continue
var = block.var(var_name)
# skip data, because we will process it later
if var.is_data: continue
prev_device = None
if var_name in self._param_device_map:
prev_device = self._param_device_map[var_name]
prev_op = self._find_real_prev_op(block.ops, op, var_name)
if prev_op is None:
continue
prev_device_spec = prev_op.attr(self._op_device_key)
if not prev_device:
prev_device = prev_op.attr(self._op_device_key) \
if prev_op else None
if not prev_device or prev_device == 'gpu:all': continue
if prev_device_spec != cur_device_spec:
if var_name not in var_devspec:
var_devspec[var_name] = []
if cur_device_spec in var_devspec[var_name]: continue
var_devspec[var_name].append(cur_device_spec)
if prev_device != cur_device:
if var_name not in var_dev_map: var_dev_map[var_name] = []
if cur_device in var_dev_map[var_name]: continue
var_dev_map[var_name].append(cur_device)
op_role = op.all_attrs()[self._op_role_key]
var = block.vars[var_name]
prev_device_index = int(prev_device_spec.split(':')[1])
cur_device_index = int(cur_device_spec.split(':')[1])
prev_device_index = int(prev_device.split(':')[1])
cur_device_index = int(cur_device.split(':')[1])
pair = (prev_device_index, cur_device_index)
pair_key = prev_device_index * 1000 + cur_device_index
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
if self.schedule_mode == 'F-then-B': # F-then-B
block._insert_op(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device_spec,
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': 1,
'ring_id': ring_id
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='recv_v2',
outputs={'Out': [var]},
attrs={
'out_shape': var.shape,
'dtype': var.dtype,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': cur_device_index,
'peer': 0,
'ring_id': ring_id
})
extra_index += 1
elif self.schedule_mode == '1F1B': # 1F1B
block._insert_op(
index=index + extra_index,
type='c_sync_calc_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='send_v2',
inputs={'X': var},
attrs={
self._op_device_key: prev_device,
self._op_role_key: op_role,
'use_calc_stream': False,
'ring_id': ring_id,
'peer': 1,
})
extra_index += 1
block._insert_op(
index=index + extra_index,
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_device,
self._op_role_key: self._op_role.Backward,
'ring_id': ring_id,
})
extra_index += 1
var_shape = list(var.shape)
......@@ -4256,93 +4387,156 @@ class PipelineOptimizer(object):
attrs={
'out_shape': var_shape,
'dtype': var.dtype,
self._op_device_key: cur_device_spec,
self._op_device_key: cur_device,
self._op_role_key: op_role,
'use_calc_stream': True,
'peer': prev_device_index,
'peer': 0,
'ring_id': ring_id
})
extra_index += 1
else:
raise ValueError(
"Now only 'F-then-B' and '1F1B' are supported."
"The given value is {}.".format(self.schedule_mode))
def _clear_gradients(self, main_block, dev_spec):
"""
Clear gradients at the begining of each run of a minibatch.
"""
for param_name in self._param_device_map:
device = self._param_device_map[param_name]
if device != dev_spec: continue
grad_name = self._append_grad_suffix(param_name)
if not main_block.has_var(grad_name): continue
grad_var = main_block.vars[grad_name]
grad_var.persistable = True
main_block._insert_op(
index=0,
type='fill_constant',
inputs={},
outputs={'Out': [grad_var]},
attrs={
'shape': grad_var.shape,
'dtype': grad_var.dtype,
'value': float(0),
self._op_device_key: device,
# a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched,
})
def _accumulate_gradients(self, block):
def _insert_loss_scale(self, block):
"""
Accumulate the gradients generated in microbatch to the one in mini-batch.
We also scale the loss corresponding to number of micro-batches as well.
Scale the loss corresponding to number of micro-batches.
"""
if self._num_microbatches == 1: return
for index, op in reversed(tuple(enumerate(list(block.ops)))):
offset = index
device = op.attr(self._op_device_key)
# Backward pass
if self._is_loss_grad_op(op):
loss_grad_var = block.vars[op.output_arg_names[0]]
scale_factor = self._num_microbatches
block._insert_op(
index=index + 1,
type='scale',
inputs={'X': loss_grad_var},
outputs={'Out': loss_grad_var},
attrs={
'scale': 1.0 / scale_factor,
self._op_device_key: device,
'scale': 1.0 / self._num_microbatches,
self._op_role_key: self._op_role.Backward
})
break
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
op_role_var = op.all_attrs()[self._op_role_var_key]
if len(op_role_var) == 0:
def _rename_gradient_var_name(self, block):
for index, op in enumerate(block.ops):
if not self._is_optimize_op(op): continue
input_names = op.input_arg_names
output_names = op.output_arg_names
in_out_names = input_names + output_names
if op.type == 'cast': continue
# append "MERGED" to the names of parameter gradients,
# and mofify the op_role_var attribute (by rename_arg func).
for name in in_out_names:
if not core.grad_var_suffix() in name: continue
param_name = name.strip(core.grad_var_suffix())
new_grad_name = name + "@MERGED"
self._rename_arg(op, name, new_grad_name)
def _accumulate_gradients(self, block, pp_allreduce_in_optimize=False):
"""
Create a new merged gradient for each parameter and accumulate the
corresponding gradient to it.
"""
merged_gradient_names = []
first_opt_op_idx = None
for index, op in reversed(tuple(enumerate(list(block.ops)))):
# remove the cast op of fp16 grad to fp32 grad
if self._is_optimize_op(op) and op.type == 'cast':
in_name = op.input_arg_names[0]
out_name = op.output_arg_names[0]
if out_name.strip('@GRAD') in self._param_device_map:
assert in_name.replace('.cast_fp16', '') == out_name
block._remove_op(index)
continue
if self._is_backward_op(op) and not first_opt_op_idx:
first_opt_op_idx = index + 1
# no optimize phase
if first_opt_op_idx == len(block.ops): return
if block.ops[first_opt_op_idx].type == "c_sync_comm_stream":
first_opt_op_idx += 1
if self._is_backward_op(op) and (
self._op_role_var_key in op.attr_names):
op_role_var = op.attr(self._op_role_var_key)
if len(op_role_var) == 0: continue
assert len(op_role_var) % 2 == 0
offset = index
for i in range(0, len(op_role_var), 2):
offset = 0
param_name = op_role_var[i]
if not block.has_var(param_name): continue
if '@BroadCast' in param_name: continue
param_grad_name = param_name + core.grad_var_suffix()
merged_param_grad_name = param_grad_name + '@MERGED'
if not block.has_var(merged_param_grad_name):
self._create_var(block, block.vars[param_name],
merged_param_grad_name)
assert block.has_var(merged_param_grad_name)
param_grad_var = block.var(param_grad_name)
merged_param_grad_var = block.var(merged_param_grad_name)
merged_param_grad_var.persistable = True
block._insert_op(
index=first_opt_op_idx + offset,
type='fill_constant',
inputs={},
outputs={'Out': [merged_param_grad_var]},
attrs={
'shape': merged_param_grad_var.shape,
'dtype': merged_param_grad_var.dtype,
'value': float(0),
# a trick to run this op once per mini-batch
self._op_role_key: self._op_role.Optimize.LRSched,
})
offset += 1
grad_name = op_role_var[i + 1]
grad_var = block.vars[grad_name]
new_grad_var_name = unique_name.generate(grad_name)
new_var = self._create_var(block, grad_var,
new_grad_var_name)
self._rename_arg(op, grad_name, new_grad_var_name)
if not 'cast_fp16' in grad_name:
block._insert_op(
index=offset + 1,
index=first_opt_op_idx + offset,
type='sum',
inputs={'X': [grad_var, new_var]},
outputs={'Out': grad_var},
inputs={'X': [grad_var, merged_param_grad_var]},
outputs={'Out': merged_param_grad_var},
attrs={
self._op_role_key: self._op_role.Backward,
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
else:
# cast gradient to fp32 to accumulate to merged gradient
cast_grad_var_name = param_grad_name + '@TMP'
cast_grad_var = self._create_var(block, param_grad_var,
cast_grad_var_name)
cast_grad_var.persistable = False
block._insert_op(
index=first_opt_op_idx + offset,
type='cast',
inputs={'X': grad_var},
outputs={'Out': cast_grad_var},
attrs={
self._op_device_key: device,
'in_dtype': grad_var.dtype,
'out_dtype': cast_grad_var.dtype,
self._op_role_key: self._op_role.Backward,
self._op_role_var_key: op_role_var
})
offset += 1
block._insert_op(
index=first_opt_op_idx + offset,
type='sum',
inputs={
'X': [merged_param_grad_var, cast_grad_var]
},
outputs={'Out': merged_param_grad_var},
attrs={
self._op_role_key: self._op_role.Backward,
})
offset += 1
merged_gradient_names.append(merged_param_grad_name)
return merged_gradient_names
def _add_sub_blocks(self, main_block, program_list):
main_program = main_block.program
for prog_info in program_list:
prog = prog_info['program']
for prog in program_list:
for op in prog.block(0).ops:
if not op.has_attr('sub_block'):
continue
......@@ -4372,8 +4566,7 @@ class PipelineOptimizer(object):
# var_info = {var_name: [program1, program2...]},
# persistable var only
var_info = dict()
for prog_info in program_list:
prog = prog_info['program']
for prog in program_list:
block = prog.block(0)
for var_name in block.vars:
if var_name == "double_buffer_0": continue
......@@ -4395,7 +4588,7 @@ class PipelineOptimizer(object):
block = prog.block(0)
for op in block.ops:
if op.type == "recv_v2" or op.type == "create_py_reader" or \
op.type == "read":
op.type == "read" or op.type == "update_loss_scaling":
continue
# We have processed lr related vars
if op.attr(self._op_role_key) == int(
......@@ -4423,6 +4616,15 @@ class PipelineOptimizer(object):
read_block = prog.block(0)
read_device = self._get_device_info(read_block)
read_dev_index = int(read_device.split(':')[1])
pair = (write_dev_index, read_dev_index)
pair_key = write_dev_index * 1000 + read_dev_index
if pair not in self._pipeline_pair:
self._pipeline_pair.append(pair)
self._pp_ring_map[pair_key] = self.ring_id
ring_id = self.ring_id
self.ring_id += 1
else:
ring_id = self._pp_ring_map[pair_key]
write_block._insert_op(
index=0,
......@@ -4430,11 +4632,12 @@ class PipelineOptimizer(object):
inputs={'X': write_block.var(var_name), },
attrs={
self._op_device_key: write_device,
'use_calc_stream': True,
'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': read_dev_index,
'ring_id': ring_id
})
read_block._insert_op(
index=0,
......@@ -4444,36 +4647,68 @@ class PipelineOptimizer(object):
'out_shape': read_block.var(var_name).shape,
'dtype': read_block.var(var_name).dtype,
self._op_device_key: read_device,
'use_calc_stream': True,
'use_calc_stream': False,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
'peer': write_dev_index
'peer': write_dev_index,
'ring_id': ring_id
})
read_block._insert_op(
index=1,
type='c_sync_comm_stream',
inputs={'X': [read_block.var(var_name)]},
outputs={'Out': [read_block.var(var_name)]},
attrs={
self._op_device_key: read_device,
# A trick to make the role LRSched to avoid copy every
# microbatch
self._op_role_key: self._op_role.LRSched,
'ring_id': ring_id
})
def _is_gradient_clip_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/gradient_clip")
def _is_regularization_op(self, op):
return op.desc.has_attr("op_namescope") \
and op.desc.attr("op_namescope").startswith("/regularization")
def minimize(self,
loss,
startup_program=None,
parameter_list=None,
no_grad_set=None):
main_block = loss.block
self.origin_main_block = main_block
if startup_program is None:
startup_program = default_startup_program()
optimize_ops, params_grads = self._optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self._param_device_map = self._optimizer._param_device_map
self._param_device_map = self._origin_optimizer._param_device_map
assert main_block.program._pipeline_opt \
and 'local_rank' in main_block.program._pipeline_opt, \
'Please use pipeline with fleet.'
local_rank = main_block.program._pipeline_opt['local_rank']
self._global_ring_id = main_block.program._pipeline_opt[
'global_ring_id']
schedule_mode = 0
if 'schedule_mode' in main_block.program._pipeline_opt:
schedule_mode = main_block.program._pipeline_opt['schedule_mode']
self.schedule_mode = schedule_mode
# micro batch size
self.micro_batch_size = main_block.program._pipeline_opt[
'micro_batch_size']
# Step1: add default op_device attribute for regulization and clip ops
self._add_opdevice_attr_for_regularization_clip(main_block)
# Step2: add default op_device attribute for ops whose op_device
# attribute have not been set yet. Then check all ops have the
# op_device attribute.
self._add_default_opdevice_attr(main_block)
self.use_sharding = False
if 'use_sharding' in main_block.program._pipeline_opt:
self.use_sharding = main_block.program._pipeline_opt['use_sharding']
self.ring_id = main_block.program._pipeline_opt['ring_id']
device_specs = self._check_validation(main_block)
# Step1: add default op_device attribute for ops.
self._add_op_device_attr(main_block)
device_list = self._check_validation(main_block)
def device_cmp(device1, device2):
dev1_id = int(device1.split(':')[1])
......@@ -4485,70 +4720,59 @@ class PipelineOptimizer(object):
else:
return 0
sorted_device_spec = sorted(device_specs, key=cmp_to_key(device_cmp))
assert sorted_device_spec == device_specs, (
"With pipeline "
"parallelism, you must use gpu devices one after another "
"in the order of their ids.")
# Step3: add send and recv ops between section boundaries
sorted_device_list = sorted(device_list, key=cmp_to_key(device_cmp))
assert sorted_device_list == device_list, (
"With pipeline parallelism, you must use gpu devices one after "
"another in the order of their ids.")
# Step2: add send and recv ops between section boundaries
self._insert_sendrecv_ops_for_boundaries(main_block)
# Step4: split program into sections and add pairs of
# Step3: split program into sections and add pairs of
# send and recv ops for data var.
main_program = main_block.program
program_list = self._split_program(main_program, device_specs)
program_list = self._split_program(main_program, device_list)
for p in program_list:
self._create_vars(p["program"].block(0),
main_program.global_block())
self._insert_sendrecv_for_data_var(main_block, program_list,
startup_program, device_specs)
self._create_vars(p.global_block(), main_block)
# Step5: Special Case: process persistable vars that exist in
# Step4: Special Case: process persistable vars that exist in
# multiple sections
self._process_persistable_vars_in_multi_sections(
main_program, startup_program, program_list)
# Step6: Add sub blocks for section programs
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)
assert (main_program._pipeline_opt and
isinstance(main_program._pipeline_opt, dict) and
'local_rank' in main_program._pipeline_opt), \
"You must use pipeline with fleet"
local_rank = main_program._pipeline_opt['local_rank'] % len(
device_specs)
self.schedule_mode = main_program._pipeline_opt['schedule_mode']
local_rank = main_program._pipeline_opt['local_rank'] % len(device_list)
place_list = []
for dev_spec in device_specs:
dev_index = dev_spec.split(":")[1]
place_list.append(core.CUDAPlace(local_rank))
for dev in device_list:
dev_index = int(dev.split(":")[1])
place_list.append(core.CUDAPlace(dev_index % 8))
# Step7: Split startup program
# Step6: Split startup program
new_startup_program = self._split_startup_program(startup_program,
local_rank)
# Step8: clear gradients before each mini-batch and
# accumulate gradients during backward
self._clear_gradients(
program_list[local_rank]['program'].global_block(),
dev_spec=device_specs[local_rank])
self._accumulate_gradients(program_list[local_rank]['program']
.global_block())
startup_program._pipeline_opt = {
"startup_program": new_startup_program,
}
real_block = program_list[local_rank].global_block()
self._insert_loss_scale(real_block)
if not self.use_sharding:
# Step7: clear gradients before each mini-batch and
# accumulate gradients during backward
self._rename_gradient_var_name(real_block)
real_block._sync_with_cpp()
self._accumulate_gradients(real_block)
real_block._sync_with_cpp()
place_id = int(os.getenv("FLAGS_selected_gpus", "0"))
main_program._pipeline_opt = {
"trainer": "PipelineTrainer",
"device_worker": "Section",
"pipeline_stage": local_rank,
"num_pipeline_stages": len(device_specs),
"num_pipeline_stages": len(device_list),
"schedule_mode": self.schedule_mode,
"inner_parallelism": len(device_specs),
"inner_parallelism": len(device_list),
"section_program": program_list[local_rank],
"place": place_list[local_rank],
"place_id": place_id,
......@@ -4556,7 +4780,7 @@ class PipelineOptimizer(object):
"num_microbatches": self._num_microbatches,
"start_cpu_core_id": self._start_cpu_core_id,
}
return optimize_ops, params_grads, program_list
return optimize_ops, params_grads, program_list, self._pipeline_pair, self._pp_ring_map
class RecomputeOptimizer(Optimizer):
......
......@@ -66,12 +66,21 @@ def cnn_model(data):
param_shape = [reduce(lambda a, b: a * b, input_shape[1:], 1)] + [SIZE]
scale = (2.0 / (param_shape[0]**2 * SIZE))**0.5
with fluid.device_guard("gpu:1"):
predict = fluid.layers.fc(
input=conv_pool_2,
size=SIZE,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)))
# To cover @RENAMED@GRADIENT
predict2 = fluid.layers.fc(
input=conv_pool_1,
size=SIZE,
act="softmax",
param_attr=fluid.param_attr.ParamAttr(
initializer=fluid.initializer.Constant(value=0.01)))
predict += predict2
return predict
......@@ -108,7 +117,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
bd = [steps_per_pass * p for p in passes]
lr = [base_lr * (0.1**i) for i in range(len(bd) + 1)]
lr_val = fluid.layers.piecewise_decay(boundaries=bd, values=lr)
opt = fluid.optimizer.Momentum(learning_rate=lr_val, momentum=0.9)
opt = fluid.optimizer.Momentum(
learning_rate=lr_val,
momentum=0.9,
grad_clip=fluid.clip.GradientClipByGlobalNorm(clip_norm=1.0))
acc_steps = 2 # accumulated steps for pipeline
if dist_strategy:
......@@ -120,6 +132,7 @@ class TestDistMnist2x2(TestDistRunnerBase):
fleet.init(is_collective=True)
strategy = fleet.DistributedStrategy()
strategy.pipeline = True
strategy.amp = True
strategy.pipeline_configs = {
'micro_batch_size': batch_size,
'schedule_mode': '1F1B',
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册