Unverified commit 1e60a0c4, authored by JZ-LIANG, committed by GitHub

[3D-parallelism] Hybrid Model Parallelism (#32074)

Parent 363b25aa
@@ -29,14 +29,18 @@ message RecomputeConfig {
}

message ShardingConfig {
-  optional float segment_broadcast_MB = 1 [ default = 32.0 ];
-  optional bool hybrid_dp = 2 [ default = false ];
-  optional int32 sharding_degree = 3 [ default = 8 ];
-  optional int32 mp_degree = 4 [ default = 1 ];
-  optional string sharding_segment_strategy = 5
+  optional string sharding_segment_strategy = 1
      [ default = 'segment_broadcast_MB' ];
-  repeated string segment_anchors = 6;
-  optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
+  optional float segment_broadcast_MB = 2 [ default = 32.0 ];
+  repeated string segment_anchors = 3;
+  optional int32 sharding_degree = 4 [ default = 8 ];
+  optional int32 mp_degree = 5 [ default = 1 ];
+  optional int32 dp_degree = 6 [ default = 1 ];
+  optional bool hybrid_dp = 7 [ default = false ];
+  optional int32 gradient_merge_acc_step = 8 [ default = 1 ];
+  optional bool optimize_offload = 9 [ default = false ];
+  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
+  optional int32 pp_degree = 11 [ default = 1 ];
}

message AMPConfig {
......
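The new ShardingConfig fields surface through fleet.DistributedStrategy().sharding_configs. A minimal sketch of a hybrid setup in static graph mode; the degree values are illustrative and mirror the unit tests later in this commit:

import paddle
import paddle.distributed.fleet as fleet

paddle.enable_static()
fleet.init(is_collective=True)

strategy = fleet.DistributedStrategy()
strategy.sharding = True
# Keys mirror the ShardingConfig message above; values are illustrative.
strategy.sharding_configs = {
    "sharding_segment_strategy": "segment_broadcast_MB",
    "segment_broadcast_MB": 32,
    "sharding_degree": 2,          # optimizer-state sharding group size
    "mp_degree": 1,                # tensor (model) parallel group size
    "dp_degree": 2,                # pure data-parallel group size
    "hybrid_dp": True,
    "gradient_merge_acc_step": 1,
    "optimize_offload": False,
}
# The product of the parallel degrees is typically expected to match the
# total number of trainer ranks.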
@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
            'accumulate_steps']
        self.schedule_mode = user_defined_strategy.pipeline_configs[
            'schedule_mode']
+       self.use_sharding = user_defined_strategy.sharding

    def _can_apply(self):
        if not self.role_maker._is_collective:
            return False
+       # FIXME revise for hybrid parallelism
+       if self.use_sharding:
+           return False
        if self.user_defined_strategy.pipeline == True:
            return True
        return False
......
@@ -81,7 +81,10 @@ class FP16Utils(object):
            if not FP16Utils.is_fp32_cast_op(block, op):
                continue
            output_name = op.desc.output_arg_names()[0]
-           param_name = output_name.strip("@GRAD")
+           # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+           param_name = output_name.strip(
+               "@GRAD@MERGED"
+           ) if "@MERGED" in output_name else output_name.strip("@GRAD")
            if param_name not in shard.global_params:
                raise ValueError("Output 'X' of cast_op must be a grad of"
                                 "model param, but {} is not a grad".format(
@@ -105,7 +108,11 @@ class FP16Utils(object):
                reversed_x = []
                reversed_x_paramname = []
                for input_name in op.desc.input('X'):
-                   param_name = input_name.strip("@GRAD")
+                   # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+                   if "@MERGED" in input_name:
+                       param_name = input_name.strip("@GRAD@MERGED")
+                   else:
+                       param_name = input_name.strip("@GRAD")
                    if param_name not in shard.global_params:
                        raise ValueError(
                            "Input 'X' of check_finite_and_unscale must"
@@ -169,3 +176,58 @@ class FP16Utils(object):
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    @staticmethod
    def sync_amp_check_nan_inf(block, ring_id):
        update_loss_scaling_op_idx = -1

        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type == "update_loss_scaling":
                update_loss_scaling_op_idx = idx
                inf_var_name = op.desc.input('FoundInfinite')[0]
                op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")

        # not use amp
        if update_loss_scaling_op_idx == -1:
            return
        inf_var = block.var(inf_var_name)
        inf_var_int32 = block.create_var(
            name=inf_var_name + "@cast_int32",
            shape=inf_var.shape,
            dtype=core.VarDesc.VarType.INT32)
        inf_var_global = block.create_var(
            name=inf_var_name + "@GLOBAL_WORLD",
            shape=inf_var.shape,
            dtype=inf_var.dtype)
        block._insert_op_without_sync(
            update_loss_scaling_op_idx,
            type='cast',
            inputs={'X': inf_var},
            outputs={'Out': inf_var_int32},
            attrs={
                "in_dtype": inf_var.dtype,
                "out_dtype": inf_var_int32.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 1,
            type='c_allreduce_max',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_int32},
            attrs={
                'ring_id': ring_id,
                'use_calc_stream': True,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._insert_op_without_sync(
            update_loss_scaling_op_idx + 2,
            type='cast',
            inputs={'X': inf_var_int32},
            outputs={'Out': inf_var_global},
            attrs={
                "in_dtype": inf_var_int32.dtype,
                "out_dtype": inf_var_global.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
        block._sync_with_cpp()
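sync_amp_check_nan_inf makes the AMP found-infinite flag consistent across the whole hybrid group: the bool flag is cast to int32, max-reduced over ring_id, cast back into a @GLOBAL_WORLD copy, and update_loss_scaling is rewired to read that copy. A hedged call sketch; main_block and global_ring_id are placeholders, not names from this commit:

# Resulting op pattern around loss scaling (sketch):
#   found_inf@cast_int32   = cast(found_inf)              # bool -> int32
#   found_inf@cast_int32   = c_allreduce_max(...)         # max over ring_id
#   found_inf@GLOBAL_WORLD = cast(found_inf@cast_int32)   # int32 -> bool
#   update_loss_scaling(..., FoundInfinite=found_inf@GLOBAL_WORLD)
FP16Utils.sync_amp_check_nan_inf(main_block, ring_id=global_ring_id)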
@@ -32,6 +32,7 @@ class GradientClipHelper(object):
        deperated_vars = set()
        deperate_op_idx = set()
        reversed_x_paramname = []
+       global_norm_sum_op_idx = -1
        for idx, op in enumerate(block.ops):
            if not self._is_gradient_clip_op(op):
                continue
@@ -41,7 +42,11 @@ class GradientClipHelper(object):
            for input_name in op.desc.input_arg_names():
                if input_name in deperated_vars:
                    deperate_op = True
-               param_name = input_name.strip("@GRAD")
+               # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+               if "@MERGED" in input_name:
+                   param_name = input_name.strip("@GRAD@MERGED")
+               else:
+                   param_name = input_name.strip("@GRAD")
                if shard.is_param(param_name) and \
                        not shard.has_param(param_name):
                    deperate_op = True
@@ -51,7 +56,8 @@ class GradientClipHelper(object):
            if deperate_op:
                deperate_op_idx.add(idx)
                for output_name in op.desc.output_arg_names():
-                   deperated_vars.add(output_name)
+                   if output_name not in op.desc.input_arg_names():
+                       deperated_vars.add(output_name)

        if not deperated_vars:
            # got no gradient_clip op
@@ -65,6 +71,7 @@ class GradientClipHelper(object):
                continue
            reversed_inputs = []
            if op.type == "sum":
+               global_norm_sum_op_idx = idx
                for input_name in op.desc.input_arg_names():
                    if input_name not in deperated_vars:
                        reversed_inputs.append(input_name)
@@ -86,20 +93,20 @@ class GradientClipHelper(object):
                        OP_ROLE_KEY: OpRole.Optimize,
                    })

                # global norm should only be sum within each model parallelism world size when use global group
                if pure_dp_degree > 1:
                    block._insert_op_without_sync(
                        idx + 2,
                        type='scale',
                        inputs={'X': sum_res},
                        outputs={'Out': sum_res},
                        attrs={
                            'scale': 1.0 / float(pure_dp_degree),
                            'op_namescope': "/gradient_clip_model_parallelism",
                            'bias': 0.0,
                            'bias_after_scale': False,
                            OP_ROLE_KEY: OpRole.Optimize
                        })

        # the grad sum here should take the all and only param in the current shard
        to_check_param = set(reversed_x_paramname)
@@ -115,3 +122,45 @@ class GradientClipHelper(object):
                block._remove_var(var_name, sync=False)
        block._sync_with_cpp()
        return
    # TODO (JZ-LIANG) revise this for uniform mixed parallelism
    def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
        """
        Sync the global gradient norm across the given ring: insert a
        c_allreduce_sum after the gradient_clip sum op, and rescale the
        result by 1 / pure_dp_degree when a pure data-parallel group
        replicates the gradients.
        """
        for idx, op in reversed(list(enumerate(block.ops))):
            if not self._is_gradient_clip_op(op):
                continue

            if op.type == "sum":
                sum_res = op.desc.output_arg_names()[0]
                block._insert_op_without_sync(
                    idx + 1,
                    type='c_allreduce_sum',
                    inputs={'X': sum_res},
                    outputs={'Out': sum_res},
                    attrs={
                        'ring_id': ring_id,
                        'op_namescope': "/gradient_clip_model_parallelism",
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    })

                # global norm should only be sum within each model parallelism world size
                if pure_dp_degree > 1:
                    block._insert_op_without_sync(
                        idx + 2,
                        type='scale',
                        inputs={'X': sum_res},
                        outputs={'Out': sum_res},
                        attrs={
                            'scale': 1.0 / float(pure_dp_degree),
                            'op_namescope': "/gradient_clip_model_parallelism",
                            'bias': 0.0,
                            'bias_after_scale': False,
                            OP_ROLE_KEY: OpRole.Optimize
                        })

                return
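sync_global_norm complements the pruning above: each sharding rank holds only a partial sum of squared gradients, so that sum is allreduced over the given ring and, when a pure data-parallel group replicates it, divided by pure_dp_degree. A hedged call sketch; clip_helper, main_block, global_ring_id and dp_degree are placeholders, not names from this commit:

# Sketch only: make every rank observe the same global gradient norm.
clip_helper.sync_global_norm(
    main_block, ring_id=global_ring_id, pure_dp_degree=dp_degree)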
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name


class OffloadHelper(object):
    cpu_place_type = 0
    cuda_place_type = 1
    cuda_pinned_place_type = 2

    def __init__(self):
        pass
        "0: dst is on CPUPlace. "
        "1: dst is on CUDAPlace. "
        "2: dst is on CUDAPinnedPlace. "
    def _insert_cast_op(self, block, idx, src_name, dst_name):
        src_var = block.var(src_name)
        if not block.has_var(dst_name):
            block.create_var(
                name=dst_name,
                shape=src_var.shape,
                dtype=core.VarDesc.VarType.FP16,
                persistable=True)
        dst_var = block.var(dst_name)
        assert dst_var.dtype == core.VarDesc.VarType.FP16
        block._insert_op_without_sync(
            idx,
            type='cast',
            inputs={'X': src_var},
            outputs={'Out': dst_var},
            attrs={
                'in_dtype': src_var.dtype,
                'out_dtype': dst_var.dtype,
                OP_ROLE_KEY: OpRole.Optimize
            })
    def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
        src_var = block.var(src_name)
        dst_var = block.var(dst_name)
        block._insert_op_without_sync(
            idx,
            type='memcpy',
            inputs={'X': src_var},
            outputs={'Out': dst_var},
            attrs={
                'dst_place_type': dst_place_type,
                OP_ROLE_KEY: OpRole.Optimize,
            })

    def _insert_fetch_op(self, block, idx, src_name, dst_name):
        self._insert_memcpy_op(block, idx, src_name, dst_name,
                               OffloadHelper.cuda_place_type)

    def _insert_offload_op(self, block, idx, src_name, dst_name):
        self._insert_memcpy_op(block, idx, src_name, dst_name,
                               OffloadHelper.cuda_pinned_place_type)

    def _get_offload_var_name(self, name):
        return unique_name.generate(name + '@offload')

    def _create_offload_var(self, var_name, offload_var_name, blocks):
        for block in blocks:
            var = block.var(var_name)
            var.persistable = False
            offload_var = block.create_var(
                name=offload_var_name,
                shape=var.shape,
                dtype=var.dtype,
                persistable=True)
    def offload_fp32param(self, block, startup_block):
        """
        (p_fp16) = cast(p)
        (p_fp16_recompute) = cast(p)
        (pout,) = adam(p)
        ===========================>
        rename(p_fp16_recompute, p_fp16)
        (p,) = prefetch(p@offload)
        (pout,) = adam(p)
        (p_fp16) = cast(p)
        (p@offload) = memcpy(p)
        """
        param_to_idx = dict()
        param_to_fp16 = dict()
        # recompute_var which need rename to fp16_param
        fp16_param_to_recompute = dict()
        recompute_to_fp16 = dict()

        def remove_param(input_name):
            param_to_idx.pop(input_name)
            if input_name in param_to_fp16:
                fp16_param = param_to_fp16.pop(input_name)
                if fp16_param in fp16_param_to_recompute:
                    recompute = fp16_param_to_recompute.pop(fp16_param)
                    recompute_to_fp16.pop(recompute)

        # step1: record param
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
                param = op.desc.input("Param")[0]
                param_to_idx[param] = idx

        # step2: remove param which can't offload
        for idx, op in enumerate(block.ops):
            if is_optimizer_op(op):
                break
            for input_name in op.desc.input_arg_names():
                if input_name not in param_to_idx:
                    continue

                # param is real used by fp32 op
                if op.type != 'cast':
                    remove_param(input_name)
                    continue

                # param is only used by cast op,
                # which to cast fp32_param to fp16_param
                output_name = op.output_arg_names[0]
                if 'cast_fp16' not in output_name:
                    remove_param(input_name)
                    continue

                if 'subprog' not in output_name:
                    assert output_name == input_name + '.cast_fp16'
                    assert input_name not in param_to_fp16, \
                        "There must be only one cast op from fp32 param to fp16 param."
                    param_to_fp16[input_name] = output_name
                else:
                    # fp16-->recompute_var
                    assert input_name in param_to_fp16, \
                        "param must first be cast to fp16"
                    fp16_param = param_to_fp16[input_name]
                    fp16_param_to_recompute[fp16_param] = output_name
                    recompute_to_fp16[output_name] = fp16_param

        param_name_to_offload_name = dict()
        # step3: main_block add offload, cast op
        # change recompute to fp16, remove cast(param) to fp16
        for idx, op in reversed(list(enumerate(block.ops))):
            if op.type in ('adam', 'momentum', 'lars', 'lamb'):
                param = op.desc.input("Param")[0]
                if param not in param_to_idx: continue
                # step3.1: create offload_var
                offload_var_name = self._get_offload_var_name(param)
                param_name_to_offload_name[param] = offload_var_name
                self._create_offload_var(param, offload_var_name,
                                         [block, startup_block])

                # step3.2: insert cast op and offload op
                self._insert_offload_op(block, idx + 1, param, offload_var_name)

                assert param in param_to_fp16
                fp16_param_name = param_to_fp16[param]
                fp16_param_var = block.var(fp16_param_name)
                fp16_param_var.persistable = True
                self._insert_cast_op(block, idx + 1, param,
                                     param_to_fp16[param])

                # step3.3: insert fetch op
                self._insert_fetch_op(block, idx, offload_var_name, param)
                continue

            # step3.4: remove cast op
            if op.type == 'cast':
                input_name = op.desc.input_arg_names()[0]
                if input_name in param_to_idx:
                    block._remove_op(idx, sync=False)
                    continue

            # step3.5: change recompute_param to fp16_param
            for input_name in op.desc.input_arg_names():
                if input_name in recompute_to_fp16:
                    op._rename_input(input_name, recompute_to_fp16[input_name])
            for output_name in op.desc.output_arg_names():
                if output_name in recompute_to_fp16:
                    op._rename_output(output_name,
                                      recompute_to_fp16[output_name])

        # step4: remove recompute_param
        for name in recompute_to_fp16.keys():
            block._remove_var(name, sync=False)

        # step5: startup_block add offload
        visited_vars = set()
        for idx, op in reversed(list(enumerate(startup_block.ops))):
            for out_name in op.output_arg_names:
                if out_name in visited_vars:
                    continue

                if out_name in param_name_to_offload_name:
                    var_name = out_name
                    offload_var_name = param_name_to_offload_name[var_name]
                    self._insert_offload_op(startup_block, idx + 1, var_name,
                                            offload_var_name)
                    self._insert_cast_op(startup_block, idx + 1, var_name,
                                         param_to_fp16[var_name])

                visited_vars.add(out_name)

        block._sync_with_cpp()
        startup_block._sync_with_cpp()
    def offload(self, block, startup_block):
        """
        (m1, m2) = prefetch(m1@offload, m2@offload)
        (m1out, m2out, pout) = adam(m1, m2, p)
        (m1@offload, m2@offload) = memcpy(m1, m2)
        """
        vars_name_to_offload_name = dict()

        # main_block add offload
        for idx, op in reversed(list(enumerate(block.ops))):
            if not is_optimizer_op(op):
                break

            vars_name = []
            if op.type == "adam":
                # {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
                # adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
                vars_name.append(op.desc.input("Moment1")[0])
                vars_name.append(op.desc.input("Moment2")[0])
            elif op.type == 'momentum':
                pass
            elif op.type == 'lars':
                pass
            elif op.type == 'lamb':
                pass

            # step1: create and init offload_var
            for var_name in vars_name:
                assert var_name not in vars_name_to_offload_name
                offload_var_name = self._get_offload_var_name(var_name)
                vars_name_to_offload_name[var_name] = offload_var_name

                self._create_offload_var(var_name, offload_var_name,
                                         [block, startup_block])

            # step2: insert offload op
            for var_name in vars_name:
                offload_var_name = vars_name_to_offload_name[var_name]
                self._insert_offload_op(block, idx + 1, var_name,
                                        offload_var_name)

            # step3: insert fetch op
            for var_name in vars_name:
                offload_var_name = vars_name_to_offload_name[var_name]
                self._insert_fetch_op(block, idx, offload_var_name, var_name)

        # startup_block add offload
        visited_vars = set()
        for idx, op in reversed(list(enumerate(startup_block.ops))):
            for out_name in op.output_arg_names:
                if out_name in visited_vars:
                    continue

                if out_name in vars_name_to_offload_name:
                    var_name = out_name
                    offload_var_name = vars_name_to_offload_name[var_name]
                    # insert offload op after var is generated
                    self._insert_offload_op(startup_block, idx + 1, var_name,
                                            offload_var_name)
                visited_vars.add(out_name)

        block._sync_with_cpp()
        startup_block._sync_with_cpp()
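A hedged sketch of how OffloadHelper would be driven when optimize_offload is enabled; the call site is not part of this file, and main_block/startup_block are placeholders:

# Sketch only: run after the other sharding passes have rewritten the program.
offload_helper = OffloadHelper()
# Move fp32 master parameters to pinned memory, keeping fp16 copies on GPU.
offload_helper.offload_fp32param(main_block, startup_block)
# Move optimizer moment tensors to pinned memory between steps.
offload_helper.offload(main_block, startup_block)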
@@ -126,6 +126,10 @@ class ProgramDeps(object):
    def should_remove_op(self, op_idx):
        op = self._block.ops[op_idx]
+       # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+       # remove check_finite_and_unscale op if its input 'X' is empty
+       if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
+           return True
        for output_name in op.desc.output_arg_names():
            if output_name not in self._should_removed_var:
                return False
......
@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
    """
    insert sync_comm_op for vars
    """
+   # NOTE (JZ-LIANG) to be checked: may result in an undefined case
+   if len(comm_dep_vars) == 0:
+       return 0
    op_role = get_valid_op_role(block, insert_idx)
    block._insert_op_without_sync(
        insert_idx,
@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
    return


-def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
+def insert_allreduce_ops(block,
+                         insert_idx,
+                         ring_id,
+                         allreduce_vars,
+                         op_role=OpRole.Backward,
+                         use_calc_stream=False):
    """
    _add_allreduce_ops
    """
+   if len(allreduce_vars) == 0:
+       return
+
    for var in allreduce_vars:
        block._insert_op_without_sync(
            insert_idx,
            type='c_allreduce_sum',
            inputs={'X': var},
            outputs={'Out': var},
-           attrs={'ring_id': ring_id,
-                  OP_ROLE_KEY: OpRole.Backward})
+           attrs={
+               'ring_id': ring_id,
+               'use_calc_stream': use_calc_stream,
+               OP_ROLE_KEY: op_role
+           })

    return
-def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
+def insert_reduce_ops(block,
+                      insert_idx,
+                      ring_id,
+                      reduce_vars,
+                      shard,
+                      op_role=OpRole.Backward,
+                      use_calc_stream=False):
    """
    _add_allreduce_ops
    """
    for var in reduce_vars:
        root_id = get_grad_device(var, shard)
        assert root_id >= 0, "root id should be a positive int".format(var)
        block._insert_op_without_sync(

@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
            attrs={
                'ring_id': ring_id,
                'root_id': root_id,
-               OP_ROLE_KEY: OpRole.Backward
+               'use_calc_stream': use_calc_stream,
+               OP_ROLE_KEY: op_role
            })
    return
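The widened signatures let callers run these collectives on the calculation stream and tag them with a non-backward role, for example when gradient merging moves the data-parallel allreduce into the optimize stage. A hedged call sketch; the block, index, ring id and variable list below are placeholders:

# Sketch only: allreduce merged gradients over the pure-DP ring during the
# optimize stage, on the calc stream so no extra sync op is required.
insert_allreduce_ops(
    main_block,
    first_optimize_op_idx,
    dp_ring_id,
    merged_grad_names,
    op_role=OpRole.Optimize,
    use_calc_stream=True)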
def get_grad_device(grad_name, shard):
    assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
        grad_name)
    base_name = None
    # mind the traversal order
    possible_suffixes = [
        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
    ]
    for suffix in possible_suffixes:
        if suffix in grad_name:
            base_name = re.sub(suffix, '', grad_name)
            break

    assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
        base_name)

    return shard.global_param2device[base_name]


def get_first_check_finite_and_unscale_op_idx(block):
    for idx, op in enumerate(block.ops):
        if op.type == "check_finite_and_unscale":
            return idx

    raise ValueError("check_finite_and_unscale does not exist in block")
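get_grad_device strips the longest matching gradient suffix first (note the ordering of possible_suffixes, and that the helper relies on the re module being imported in this file), then maps the base parameter name to its device via shard.global_param2device. A few illustrative inputs and the base names they reduce to:

# Illustrative only: suffix ordering matters, otherwise '@GRAD' would match
# before '@GRAD@MERGED' and leave '@MERGED' behind.
#   "fc_0.w_0@GRAD"                   -> "fc_0.w_0"
#   "fc_0.w_0@GRAD@MERGED"            -> "fc_0.w_0"
#   "fc_0.w_0.cast_fp16@GRAD"         -> "fc_0.w_0"
#   "fc_0.w_0.cast_fp16@GRAD@MERGED"  -> "fc_0.w_0"
device_id = get_grad_device("fc_0.w_0.cast_fp16@GRAD@MERGED", shard)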
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
    """
    _add_broadcast_ops

@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
                outputs={'Out': loss_grad_var},
                attrs={'scale': scale,
                       OP_ROLE_KEY: OpRole.Backward})
+           break
def comm_analyse(main_program):

@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
    and part of persistable vars are duplicated and exist in all the ranks with different values.
    This function handles the model saving for sharding training.
    """
+   # TODO (JZ-LIANG) revise this for uniform mixed parallelism
+   if main_program._pipeline_opt:
+       main_program = main_program._pipeline_opt['section_program']['program']

    def is_opt_vars(var):
        # NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
......
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(desc)
        new_op_desc._set_attr(op_role_attr_name, backward)
+       if desc.has_attr('op_device'):
+           new_op_desc._set_attr('op_device', desc.attr('op_device'))
        result_descs.append(new_op_desc)
    return result_descs

@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
        new_op_desc = block.desc.append_op()
        new_op_desc.copy_from(desc)
        new_op_desc._set_attr(op_role_attr_name, backward)
+       if desc.has_attr('op_device'):
+           new_op_desc._set_attr('op_device', desc.attr('op_device'))
        result_descs.append(new_op_desc)
    return result_descs
@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
    vars_in_memory = vars_should_be_hold + checkpoints_name

    max_calculated_op_position = len(ops)
+   device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
    if recompute_segments == []:
        gap_ops = ops[0:max_calculated_op_position]
        for op in reversed(gap_ops):

@@ -852,6 +857,11 @@
                    _pretty_op_desc_(op.desc, "with_sub_block"))
            grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+           # Set device for grad_op according to forward Op
+           if op.desc.has_attr(device_attr_name):
+               op_device = op.desc.attr(device_attr_name)
+               for op_desc in grad_op_desc:
+                   op_desc._set_attr(device_attr_name, op_device)
            added_descs = _add_descs_to_block(grad_op_desc, local_block)
            grad_op_descs.extend(added_descs)
            grad_to_var.update(op_grad_to_var)

@@ -866,6 +876,11 @@
                    _pretty_op_desc_(op.desc, "with_sub_block"))
            grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+           # Set device for grad_op according to forward Op
+           if op.desc.has_attr(device_attr_name):
+               op_device = op.desc.attr(device_attr_name)
+               for op_desc in grad_op_desc:
+                   op_desc._set_attr(device_attr_name, op_device)
            added_descs = _add_descs_to_block(grad_op_desc, local_block)
            grad_op_descs.extend(added_descs)
            grad_to_var.update(op_grad_to_var)
......
@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
        """
        Find the post op that has variable named var_name as input.
        """
+       # bugfix for uniform hybrid parallelism
+       if '.cast_fp32' in var_name:
+           var_name = var_name.replace('.cast_fp32', '')
+       if '.cast_fp16' in var_name:
+           var_name = var_name.replace('.cast_fp16', '')
        post_ops = self.input_var_to_op[var_name]
        if post_ops == None: return None
        result_op = None
@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
            # For LRSched ops, we should put them on all sub-programs to
            # make sure each sub-program update the lr correctly
            op._set_attr(self._op_device_key, "gpu:all")
-       elif op.type == "scale" and self._is_backward_op(op):
+       # bugfix in hybrid parallelism
+       elif op.type == "sum" and self._is_backward_op(op):
+           # For sum ops that compute the sum of @RENAMED@ vars
+           for name in op.desc.input_arg_names():
+               assert '@RENAME@' in name, \
+                   "The op must be sum used to accumulate renamed vars."
+           assert len(op.desc.output_arg_names()) == 1
+           out_name = op.desc.output_arg_names()[0]
+           post_op = self._find_post_op(idx, out_name)
+           assert post_op.has_attr(
+               'op_device'), "{} has no op_device attr for var {}".format(
+                   post_op.type, out_name)
+           device = post_op.attr(self._op_device_key)
+           assert device, "The post op must have op_device set."
+           op._set_attr(self._op_device_key, device)
+       elif (op.type == "cast" or
+             op.type == "scale") and self._is_backward_op(op):
            prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
            op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
        elif op.type == "memcpy" and not self._is_optimize_op(op):
@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
        Insert a pair of send and recv ops for every two
        consecutive ops on different devices.
        """
-       extra_index_info = {'index': 0}
        # A map from var to device where op takes it as input,
        # avoiding multiple send and recv ops.
        input_var_to_device = dict()
+       # bugfix hybrid parallelism
+       first_optimize_index = None
+       for index, op in enumerate(list(block.ops)):
+           if self._is_optimize_op(op):
+               first_optimize_index = index
+               break
+       extra_index_info = {
+           'index': 0,
+           'first_optimize_index': first_optimize_index
+       }

        for index, op in enumerate(list(block.ops)):
            cur_device = op.attr(self._op_device_key)
@@ -4371,17 +4401,26 @@ class PipelineOptimizer(object):
                            'peer': 1,
                        })
                    extra_index_info['index'] += 1
+                   insert_index = None
+                   if int(op_role) == int(self._op_role.Backward):
+                       insert_index = extra_index_info[
+                           'first_optimize_index']
+                       new_op_role = self._op_role.Optimize
+                   else:
+                       insert_index = index
+                       new_op_role = self._op_role.Backward
                    block._insert_op(
-                       index=index + extra_index_info['index'],
+                       index=insert_index + extra_index_info['index'],
                        type='c_sync_comm_stream',
                        inputs={'X': [var]},
                        outputs={'Out': [var]},
                        attrs={
                            self._op_device_key: prev_dev,
-                           self._op_role_key: self._op_role.Backward,
+                           self._op_role_key: new_op_role,
                            'ring_id': ring_id,
                        })
-                   extra_index_info['index'] += 1
+                   if int(op_role) == int(self._op_role.Forward):
+                       extra_index_info['index'] += 1
                    var_shape = list(var.shape)
                    var_shape[0] = self.micro_batch_size if var_shape[
                        0] < 0 else var_shape[0]
@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
        # Step4: Special Case: process persistable vars that exist in
        # multiple sections
-       self._process_persistable_vars_in_multi_sections(
-           main_program, startup_program, program_list)
+       # FIXME
+       # self._process_persistable_vars_in_multi_sections(
+       #     main_program, startup_program, program_list)

        # Step5: Add sub blocks for section programs
        self._add_sub_blocks(main_block, program_list)
......
@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
            "segment_broadcast_MB": 0.2,
            "segment_anchors": None,
            "sharding_degree": 2,
+           "dp_degree": 2,
            "hybrid_dp": True,
            "gradient_merge_acc_step": 1,
            "mp_degree": 1
@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
            "segment_broadcast_MB": 0.2,
            "segment_anchors": None,
            "sharding_degree": 2,
+           "dp_degree": 2,
            "hybrid_dp": True,
            "gradient_merge_acc_step": 4,
            "mp_degree": 1
@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
        fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
        opt_ops = [op.type for op in train_prog.blocks[2].ops]
        self.assertEqual(fw_bw_ops, [
-           'fill_constant', 'fill_constant', 'fill_constant',
-           'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-           'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
-           'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
-           'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
-           'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
-           'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-           'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-           'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-           'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
-           'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-           'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
-           'elementwise_add', 'increment', 'elementwise_mod', 'equal',
-           'conditional_block'
+           'fill_constant',
+           'fill_constant',
+           'fill_constant',
+           'c_sync_calc_stream',
+           'c_broadcast',
+           'c_broadcast',
+           'c_broadcast',
+           'c_broadcast',
+           'c_broadcast',
+           'c_broadcast',
+           'c_sync_comm_stream',
+           'mul',
+           'elementwise_add',
+           'tanh',
+           'mul',
+           'elementwise_add',
+           'tanh',
+           'mul',
+           'elementwise_add',
+           'softmax',
+           'cross_entropy2',
+           'mean',
+           'fill_constant',
+           'scale',
+           'mean_grad',
+           'cross_entropy_grad2',
+           'softmax_grad',
+           'elementwise_add_grad',
+           'mul_grad',
+           'tanh_grad',
+           'elementwise_add_grad',
+           'mul_grad',
+           'tanh_grad',
+           'elementwise_add_grad',
+           'mul_grad',
+           'c_sync_calc_stream',
+           'c_reduce_sum',
+           'c_reduce_sum',
+           'c_reduce_sum',
+           'c_reduce_sum',
+           'c_reduce_sum',
+           'c_reduce_sum',
+           'c_sync_comm_stream',
+           'elementwise_add',
+           'elementwise_add',
+           'elementwise_add',
+           'increment',
+           'elementwise_mod',
+           'equal',
+           'conditional_block',
        ])
        self.assertEqual(opt_ops, [
            'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
......