Unverified commit 1e60a0c4, authored by: J JZ-LIANG, committed by: GitHub

[3D-parallelism] Hybrid Model Parallelism (#32074)

Parent 363b25aa
......@@ -29,14 +29,18 @@ message RecomputeConfig {
}
message ShardingConfig {
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_degree = 3 [ default = 8 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
optional string sharding_segment_strategy = 1
[ default = 'segment_broadcast_MB' ];
repeated string segment_anchors = 6;
optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
optional float segment_broadcast_MB = 2 [ default = 32.0 ];
repeated string segment_anchors = 3;
optional int32 sharding_degree = 4 [ default = 8 ];
optional int32 mp_degree = 5 [ default = 1 ];
optional int32 dp_degree = 6 [ default = 1 ];
optional bool hybrid_dp = 7 [ default = false ];
optional int32 gradient_merge_acc_step = 8 [ default = 1 ];
optional bool optimize_offload = 9 [ default = false ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
}
message AMPConfig {
......
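The reordered fields above are consumed through fleet's DistributedStrategy. A minimal configuration sketch under the new key layout (values and degrees are illustrative only; the sharding, mp, dp and pp degrees are typically expected to multiply to the total number of trainers):

import paddle.distributed.fleet as fleet

# Illustrative values only; tune the degrees to the cluster topology.
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 32.0,
    "sharding_degree": 8,        # parameter/gradient/optimizer-state sharding
    "mp_degree": 1,              # tensor (model) parallelism
    "dp_degree": 2,              # pure data parallelism on top of sharding
    "pp_degree": 1,              # pipeline parallelism
    "hybrid_dp": True,
    "gradient_merge_acc_step": 4,
    "optimize_offload": False,
}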
......@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
'accumulate_steps']
self.schedule_mode = user_defined_strategy.pipeline_configs[
'schedule_mode']
self.use_sharding = user_defined_strategy.sharding
def _can_apply(self):
if not self.role_maker._is_collective:
return False
# FIXME revise for hybrid parallelism
if self.use_sharding:
return False
if self.user_defined_strategy.pipeline == True:
return True
return False
......
......@@ -81,7 +81,10 @@ class FP16Utils(object):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
param_name = output_name.strip(
"@GRAD@MERGED"
) if "@MERGED" in output_name else output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
......@@ -105,7 +108,11 @@ class FP16Utils(object):
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
......@@ -169,3 +176,58 @@ class FP16Utils(object):
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@staticmethod
def sync_amp_check_nan_inf(block, ring_id):
update_loss_scaling_op_idx = -1
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
# amp is not used
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_global},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_global.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
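The three inserted ops amount to a logical OR of the per-rank FoundInfinite flags over the given ring: the bool flag is cast to int32, reduced with c_allreduce_max across the ring, and cast back, so update_loss_scaling sees a global flag. A rough, non-Paddle sketch of the equivalent reduction (flag values are illustrative):

# max over int-casted bools is a logical OR, which is what the
# cast / c_allreduce_max / cast sequence computes for FoundInfinite.
def global_found_inf(per_rank_found_inf):
    # per_rank_found_inf: one bool per rank in the ring
    return bool(max(int(flag) for flag in per_rank_found_inf))

assert global_found_inf([False, True, False]) is True
assert global_found_inf([False, False, False]) is False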
......@@ -32,6 +32,7 @@ class GradientClipHelper(object):
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
global_norm_sum_op_idx = -1
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
......@@ -41,7 +42,11 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
param_name = input_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
......@@ -51,7 +56,8 @@ class GradientClipHelper(object):
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
deperated_vars.add(output_name)
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
# got no gradient_clip op
......@@ -65,6 +71,7 @@ class GradientClipHelper(object):
continue
reversed_inputs = []
if op.type == "sum":
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
......@@ -86,20 +93,20 @@ class GradientClipHelper(object):
OP_ROLE_KEY: OpRole.Optimize,
})
# global norm should only be sum within each model parallelism word size when use global group
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
# the global norm should only be summed within each model-parallel world size when the global group is used
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
# the grad sum here should take all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
......@@ -115,3 +122,45 @@ class GradientClipHelper(object):
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
# the global norm should only be summed within each model-parallel world size
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
return
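When this allreduce runs on a global ring that also spans the pure data-parallel replicas, every replica contributes an identical copy of the summed squared gradient norms, so the result is pure_dp_degree times too large; the inserted scale op divides it back. A small numeric sketch of that correction (values are illustrative):

pure_dp_degree = 2
true_sum_sq = 3.5                                # sum over one full model replica
allreduced = true_sum_sq * pure_dp_degree        # what c_allreduce_sum over the global ring yields
corrected = allreduced * (1.0 / pure_dp_degree)  # effect of the inserted scale op
assert corrected == true_sum_sq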
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
class OffloadHelper(object):
cpu_place_type = 0
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
if not block.has_var(dst_name):
block.create_var(
name=dst_name,
shape=src_var.shape,
dtype=core.VarDesc.VarType.FP16,
persistable=True)
dst_var = block.var(dst_name)
assert dst_var.dtype == core.VarDesc.VarType.FP16
block._insert_op_without_sync(
idx,
type='cast',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
block._insert_op_without_sync(
idx,
type='memcpy',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'dst_place_type': dst_place_type,
OP_ROLE_KEY: OpRole.Optimize,
})
def _insert_fetch_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_place_type)
def _insert_offload_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_pinned_place_type)
def _get_offload_var_name(self, name):
return unique_name.generate(name + '@offload')
def _create_offload_var(self, var_name, offload_var_name, blocks):
for block in blocks:
var = block.var(var_name)
var.persistable = False
offload_var = block.create_var(
name=offload_var_name,
shape=var.shape,
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx = dict()
param_to_fp16 = dict()
# recompute vars which need to be renamed to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
param_to_idx.pop(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue
# param is really used by an fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by the cast op,
# which casts fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name = dict()
# main_block add offload
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op):
break
vars_name = []
if op.type == "adam":
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name.append(op.desc.input("Moment1")[0])
vars_name.append(op.desc.input("Moment2")[0])
elif op.type == 'momentum':
pass
elif op.type == 'lars':
pass
elif op.type == 'lamb':
pass
# step1: create and init offload_var
for var_name in vars_name:
assert var_name not in vars_name_to_offload_name
offload_var_name = self._get_offload_var_name(var_name)
vars_name_to_offload_name[var_name] = offload_var_name
self._create_offload_var(var_name, offload_var_name,
[block, startup_block])
# step2: insert offload op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_offload_op(block, idx + 1, var_name,
offload_var_name)
# step3: insert fetch op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_fetch_op(block, idx, offload_var_name, var_name)
# startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in vars_name_to_offload_name:
var_name = out_name
offload_var_name = vars_name_to_offload_name[var_name]
# insert offload op after var is generated
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
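A hypothetical usage sketch of the helper above, roughly as a sharding pass might invoke it once the optimizer ops are in place (main_block, startup_block and the two flags are placeholder names, not part of this file):

offload_helper = OffloadHelper()
if offload_fp32_params:
    # keep fp32 master weights in pinned host memory, fp16 working copies on device
    offload_helper.offload_fp32param(main_block, startup_block)
if offload_optimizer_states:
    # keep adam moments in pinned host memory between optimizer steps
    offload_helper.offload(main_block, startup_block)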
......@@ -126,6 +126,10 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
......
......@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
# NOTE (JZ-LIANG) to be checked; may result in an undefined case
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
......@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
def insert_allreduce_ops(block,
insert_idx,
ring_id,
allreduce_vars,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
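A hedged usage sketch of the extended signature: with gradient merging, the merged gradients can be allreduced on the data-parallel ring during the optimize stage, reusing the calculation stream (dp_ring_id, first_optimize_idx and merged_grad_names are placeholders for values computed by the sharding pass):

insert_allreduce_ops(
    main_block,
    insert_idx=first_optimize_idx,
    ring_id=dp_ring_id,
    allreduce_vars=merged_grad_names,
    op_role=OpRole.Optimize,
    use_calc_stream=True)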
def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_reduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id of [{}] should be a non-negative int".format(var)
block._insert_op_without_sync(
......@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
attrs={
'ring_id': ring_id,
'root_id': root_id,
OP_ROLE_KEY: OpRole.Backward
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
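The suffixes must be checked from most to least specific, since the longer ones contain '@GRAD'; a quick sketch of the mapping this produces (the parameter name is illustrative):

import re

# Illustrative only; real names come from the program's gradient variables.
suffixes = ['.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD']
for grad_name in ['fc_0.w_0.cast_fp16@GRAD@MERGED', 'fc_0.w_0@GRAD']:
    for suffix in suffixes:
        if suffix in grad_name:
            print(grad_name, '->', re.sub(suffix, '', grad_name))  # both map to 'fc_0.w_0'
            break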
def get_first_check_finite_and_unscale_op_idx(block):
for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx
raise ValueError("check_finite_and_unscale does not exist in block")
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
......@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
outputs={'Out': loss_grad_var},
attrs={'scale': scale,
OP_ROLE_KEY: OpRole.Backward})
break
def comm_analyse(main_program):
......@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
and some persistable vars are duplicated, existing in all ranks with different values.
This function handles the model saving for sharding training.
"""
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var):
# NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
......
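Given the pipeline redirect added above, a hedged usage sketch of this entry point (exe and train_prog come from the caller's training script; the path is illustrative):

# Illustrative call only; each rank saves the persistables it owns plus the
# duplicated ones, after the fleet sharding optimizer has rewritten train_prog.
save_persistables(exe, dirname="./sharding_ckpt", main_program=train_prog, filename=None)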
......@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
......@@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......
......@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
"""
Find the post op that has variable named var_name as input.
"""
# bugfix for uniform hybrid parallelism
if '.cast_fp32' in var_name:
var_name = var_name.replace('.cast_fp32', '')
if '.cast_fp16' in var_name:
var_name = var_name.replace('.cast_fp16', '')
post_ops = self.input_var_to_op[var_name]
if post_ops == None: return None
result_op = None
......@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
# For LRSched ops, we should put them on all sub-programs to
# make sure each sub-program update the lr correctly
op._set_attr(self._op_device_key, "gpu:all")
elif op.type == "scale" and self._is_backward_op(op):
# bugfix in hybrid parallelism
elif op.type == "sum" and self._is_backward_op(op):
# For sum ops that compute the sum of @RENAME@ vars
for name in op.desc.input_arg_names():
assert '@RENAME@' in name, \
"The op must be a sum op that accumulates renamed (@RENAME@) vars."
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(idx, out_name)
assert post_op.has_attr(
'op_device'), "{} has no op_device attr for var {}".format(
post_op.type, out_name)
device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or
op.type == "scale") and self._is_backward_op(op):
prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
elif op.type == "memcpy" and not self._is_optimize_op(op):
......@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
Insert a pair of send and recv ops for every two
consecutive ops on different devices.
"""
extra_index_info = {'index': 0}
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
input_var_to_device = dict()
# bugfix for hybrid parallelism
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if self._is_optimize_op(op):
first_optimize_index = index
break
extra_index_info = {
'index': 0,
'first_optimize_index': first_optimize_index
}
for index, op in enumerate(list(block.ops)):
cur_device = op.attr(self._op_device_key)
......@@ -4371,17 +4401,26 @@ class PipelineOptimizer(object):
'peer': 1,
})
extra_index_info['index'] += 1
insert_index = None
if int(op_role) == int(self._op_role.Backward):
insert_index = extra_index_info[
'first_optimize_index']
new_op_role = self._op_role.Optimize
else:
insert_index = index
new_op_role = self._op_role.Backward
block._insert_op(
index=index + extra_index_info['index'],
index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: self._op_role.Backward,
self._op_role_key: new_op_role,
'ring_id': ring_id,
})
extra_index_info['index'] += 1
if int(op_role) == int(self._op_role.Forward):
extra_index_info['index'] += 1
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
0] < 0 else var_shape[0]
......@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
# Step4: Special Case: process persistable vars that exist in
# multiple sections
self._process_persistable_vars_in_multi_sections(
main_program, startup_program, program_list)
# FIXME
# self._process_persistable_vars_in_multi_sections(
# main_program, startup_program, program_list)
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)
......
......@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"dp_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 1,
"mp_degree": 1
......@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"dp_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 4,
"mp_degree": 1
......@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
opt_ops = [op.type for op in train_prog.blocks[2].ops]
self.assertEqual(fw_bw_ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
'elementwise_add', 'increment', 'elementwise_mod', 'equal',
'conditional_block'
'fill_constant',
'fill_constant',
'fill_constant',
'c_sync_calc_stream',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_sync_comm_stream',
'mul',
'elementwise_add',
'tanh',
'mul',
'elementwise_add',
'tanh',
'mul',
'elementwise_add',
'softmax',
'cross_entropy2',
'mean',
'fill_constant',
'scale',
'mean_grad',
'cross_entropy_grad2',
'softmax_grad',
'elementwise_add_grad',
'mul_grad',
'tanh_grad',
'elementwise_add_grad',
'mul_grad',
'tanh_grad',
'elementwise_add_grad',
'mul_grad',
'c_sync_calc_stream',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_sync_comm_stream',
'elementwise_add',
'elementwise_add',
'elementwise_add',
'increment',
'elementwise_mod',
'equal',
'conditional_block',
])
self.assertEqual(opt_ops, [
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
......