Unverified commit 1e60a0c4, authored by JZ-LIANG, committed by GitHub

[3D-parallelism] Hybrid Model Parallelism (#32074)

Parent 363b25aa
...@@ -29,14 +29,18 @@ message RecomputeConfig {
}
message ShardingConfig {
  optional float segment_broadcast_MB = 1 [ default = 32.0 ];
  optional bool hybrid_dp = 2 [ default = false ];
  optional int32 sharding_degree = 3 [ default = 8 ];
  optional int32 mp_degree = 4 [ default = 1 ];
  optional string sharding_segment_strategy = 5
      [ default = 'segment_broadcast_MB' ];
  repeated string segment_anchors = 6;
  optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
  optional string sharding_segment_strategy = 1
      [ default = 'segment_broadcast_MB' ];
  optional float segment_broadcast_MB = 2 [ default = 32.0 ];
  repeated string segment_anchors = 3;
  optional int32 sharding_degree = 4 [ default = 8 ];
  optional int32 mp_degree = 5 [ default = 1 ];
  optional int32 dp_degree = 6 [ default = 1 ];
  optional bool hybrid_dp = 7 [ default = false ];
  optional int32 gradient_merge_acc_step = 8 [ default = 1 ];
  optional bool optimize_offload = 9 [ default = false ];
  optional bool pp_allreduce_in_optimize = 10 [ default = false ];
  optional int32 pp_degree = 11 [ default = 1 ];
}
message AMPConfig {
...
...@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
'accumulate_steps']
self.schedule_mode = user_defined_strategy.pipeline_configs[
'schedule_mode']
self.use_sharding = user_defined_strategy.sharding

def _can_apply(self):
if not self.role_maker._is_collective:
return False
# FIXME revise for hybrid parallelism
if self.use_sharding:
return False
if self.user_defined_strategy.pipeline == True:
return True
return False
...
...@@ -81,7 +81,10 @@ class FP16Utils(object):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
param_name = output_name.strip(
"@GRAD@MERGED"
) if "@MERGED" in output_name else output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
...@@ -105,7 +108,11 @@ class FP16Utils(object):
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
param_name = input_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError(
"Input 'X' of check_finite_and_unscale must"
...@@ -169,3 +176,58 @@ class FP16Utils(object):
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@staticmethod
def sync_amp_check_nan_inf(block, ring_id):
update_loss_scaling_op_idx = -1
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
# not use amp
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_global},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_global.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
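The three ops inserted above implement, per rank, an OR of the local found-infinite flags across the given ring: presumably because the collective cannot max-reduce a bool directly, the flag is cast to int32, max-reduced, then cast back. A minimal plain-Python sketch of that semantics (the per-rank flags below are purely hypothetical):

    local_found_inf = [False, True, False]             # hypothetical flag on each rank
    as_int = [int(flag) for flag in local_found_inf]   # first cast op
    reduced = max(as_int)                              # c_allreduce_max over the ring
    found_inf_global = bool(reduced)                   # cast back to the original dtype
    assert found_inf_global is True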
...@@ -32,6 +32,7 @@ class GradientClipHelper(object):
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
global_norm_sum_op_idx = -1
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
...@@ -41,7 +42,11 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
param_name = input_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
deperate_op = True
...@@ -51,7 +56,8 @@ class GradientClipHelper(object):
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
deperated_vars.add(output_name)
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
# got no gradient_clip op
...@@ -65,6 +71,7 @@ class GradientClipHelper(object):
continue
reversed_inputs = []
if op.type == "sum":
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
...@@ -86,20 +93,20 @@ class GradientClipHelper(object):
OP_ROLE_KEY: OpRole.Optimize,
})
# global norm should only be summed within each model parallel world size when the global group is used
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
# the grad sum here should take all and only the params in the current shard
to_check_param = set(reversed_x_paramname)
...@@ -115,3 +122,45 @@ class GradientClipHelper(object):
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
"""
prune gradient_clip related ops for params that not belong to cur shard
prune: square, reduce_sum, elementwise_mul
keep: sum, sqrt, elementwise_max, elementwise_div
"""
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
# global norm should only be summed within each model parallel world size
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
return
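Why the extra scale op when pure dp is enabled: with outer pure-dp every dp rank holds the same gradients, so summing the squared-norm partials over the global ring counts each contribution dp_degree times, and dividing by pure_dp_degree restores the true global norm. A small illustrative sketch (all numbers hypothetical):

    partial_sq_norms = [4.0, 9.0]                   # per-rank partial sums inside one dp replica
    dp_degree = 2                                   # the same partials exist on every dp rank
    allreduced = sum(partial_sq_norms) * dp_degree  # what c_allreduce_sum over the global ring yields
    global_sq_norm = allreduced / dp_degree         # the inserted scale op undoes the over-count
    assert global_sq_norm == sum(partial_sq_norms)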
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
class OffloadHelper(object):
cpu_place_type = 0
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self):
pass

# dst_place_type values used by the memcpy op below:
# 0: dst is on CPUPlace.
# 1: dst is on CUDAPlace.
# 2: dst is on CUDAPinnedPlace.
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
if not block.has_var(dst_name):
block.create_var(
name=dst_name,
shape=src_var.shape,
dtype=core.VarDesc.VarType.FP16,
persistable=True)
dst_var = block.var(dst_name)
assert dst_var.dtype == core.VarDesc.VarType.FP16
block._insert_op_without_sync(
idx,
type='cast',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
block._insert_op_without_sync(
idx,
type='memcpy',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'dst_place_type': dst_place_type,
OP_ROLE_KEY: OpRole.Optimize,
})
def _insert_fetch_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_place_type)
def _insert_offload_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_pinned_place_type)
def _get_offload_var_name(self, name):
return unique_name.generate(name + '@offload')
def _create_offload_var(self, var_name, offload_var_name, blocks):
for block in blocks:
var = block.var(var_name)
var.persistable = False
offload_var = block.create_var(
name=offload_var_name,
shape=var.shape,
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx = dict()
param_to_fp16 = dict()
# recompute_var which need rename to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
param_to_idx.pop(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue
# param is real used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by cast op,
# which to cast fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name = dict()
# main_block add offload
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op):
break
vars_name = []
if op.type == "adam":
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name.append(op.desc.input("Moment1")[0])
vars_name.append(op.desc.input("Moment2")[0])
elif op.type == 'momentum':
pass
elif op.type == 'lars':
pass
elif op.type == 'lamb':
pass
# step1: create and init offload_var
for var_name in vars_name:
assert var_name not in vars_name_to_offload_name
offload_var_name = self._get_offload_var_name(var_name)
vars_name_to_offload_name[var_name] = offload_var_name
self._create_offload_var(var_name, offload_var_name,
[block, startup_block])
# step2: insert offload op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_offload_op(block, idx + 1, var_name,
offload_var_name)
# step3: insert fetch op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_fetch_op(block, idx, offload_var_name, var_name)
# startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in vars_name_to_offload_name:
var_name = out_name
offload_var_name = vars_name_to_offload_name[var_name]
# insert offload op after var is generated
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
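A minimal usage sketch of the helper above, assuming this commit is applied; main_prog and startup_prog are stand-ins for the programs the optimizer rewrites (empty placeholders here, so both calls are no-ops):

    import paddle.fluid as fluid
    from paddle.distributed.fleet.meta_optimizers.sharding.offload_helper import OffloadHelper

    main_prog, startup_prog = fluid.Program(), fluid.Program()   # placeholders for illustration
    helper = OffloadHelper()
    # keep optimizer moments in CUDA-pinned memory, fetching them only around the update
    helper.offload(main_prog.global_block(), startup_prog.global_block())
    # keep only the fp16 copy of each param resident; fp32 masters live in pinned memory
    helper.offload_fp32param(main_prog.global_block(), startup_prog.global_block())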
...@@ -126,6 +126,10 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
...
...@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
# NOTE (JZ-LIANG) to be checked, may result in an undefined case
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
...@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return

def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
def insert_allreduce_ops(block,
insert_idx,
ring_id,
allreduce_vars,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return

def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
...@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
attrs={
'ring_id': ring_id,
'root_id': root_id,
OP_ROLE_KEY: OpRole.Backward
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
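The suffix list above is ordered longest-first on purpose: a merged fp16 gradient name only strips back to its parameter name if '.cast_fp16@GRAD@MERGED' is tried before its shorter prefixes. A quick check with a hypothetical gradient name:

    import re

    name = "linear_0.w_0.cast_fp16@GRAD@MERGED"      # hypothetical merged fp16 gradient
    for suffix in ['.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD']:
        if suffix in name:
            base = re.sub(suffix, '', name)          # mirrors get_grad_device()
            break
    assert base == "linear_0.w_0"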
def get_first_check_finite_and_unscale_op_idx(block):
for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx
raise ValueError("check_finite_and_unscale does not exist in block")
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
...@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
outputs={'Out': loss_grad_var},
attrs={'scale': scale,
OP_ROLE_KEY: OpRole.Backward})
break

def comm_analyse(main_program):
...@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
and part of persistable vars are duplicated and exist in all the ranks with different values.
This function handles the model saving for sharding training.
"""
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var):
# NOTE(JZ-LIANG): The checks should be updated when adding a new compatible optimizer
...
...@@ -16,16 +16,16 @@ import paddle
from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from paddle.fluid import layers
import logging
...@@ -38,6 +38,8 @@ __all__ = ["ShardingOptimizer"] ...@@ -38,6 +38,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase): class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer): def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer) super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer self.inner_opt = optimizer
...@@ -46,7 +48,8 @@ class ShardingOptimizer(MetaOptimizerBase): ...@@ -46,7 +48,8 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer", "AMPOptimizer",
"LarsOptimizer", "LarsOptimizer",
"LambOptimizer", "LambOptimizer",
"ModelParallelOptimizer", # "ModelParallelOptimizer",
# "PipelineOptimizer",
] ]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ] self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None self._main_program = None
...@@ -88,26 +91,6 @@ class ShardingOptimizer(MetaOptimizerBase):
self._nrings_sharding = 1
self._nrings_dp = 1
# parallelism
self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
"sharding_degree"])
assert self.sharding_degree > 1, "sharding degree must be larger than zero"
self.mp_degree = int(self.user_defined_strategy.sharding_configs[
"mp_degree"])
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
self.pp_degree = 1
# dp here is the pure dp as the outest parallelism
self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
self.sharding_degree)
assert self.role_maker._worker_num(
) == self.dp_degree * self.mp_degree * self.sharding_degree * self.pp_degree
if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
# segment
self._sharding_segment_strategy = str(
self.user_defined_strategy.sharding_configs[
...@@ -128,55 +111,231 @@ class ShardingOptimizer(MetaOptimizerBase):
"the sharding segment strategy [{}] is not implemented".format(
str(self._sharding_segment_strategy)))
# parallelism
self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
"sharding_degree"])
assert self.sharding_degree > 0, "sharding degree must be larger than zero"
self.mp_degree = int(self.user_defined_strategy.sharding_configs[
"mp_degree"])
# pipeline setting
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
self.pp_degree = int(self.user_defined_strategy.sharding_configs[
"pp_degree"])
if self.pp_degree > 1:
assert self.user_defined_strategy.pipeline == True
self.dp_degree = int(self.user_defined_strategy.sharding_configs[
'dp_degree'])
assert self.role_maker._worker_num(
) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
self.role_maker._worker_num(),
self.mp_degree,
self.sharding_degree,
self.pp_degree,
self.dp_degree, )
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
# NOTE (JZ-LIANG)
# there are 2 kinds of modes for gradient-merge and hybrid-dp in mixed parallelism: [sharding] and [pipeline].
# we distinguish these two modes since the gm/hybrid-dp related allreduce should be inserted in different places, according to the mode, to get the best performance:
# sharding: communication within a node, so insert it in the backward segment to overlap with backward calc, conducted every micro step
# pipeline: communication across nodes, so insert it in the update segment, conducted just once per global step
self.hybrid_dp_mode = None
# dp here is the pure dp as the outermost parallelism
if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
if self.pp_degree > 1:
self.hybrid_dp_mode = "pp_hybrid_dp"
else:
assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
self.hybrid_dp_mode = "sharding_hybrid_dp"
# gradient merge
self._gradient_merge_acc_step = int(
self.user_defined_strategy.sharding_configs[
"gradient_merge_acc_step"])
self._grad2merged_grad = dict()
self.gradient_merge_mode = None
if self.pp_degree <= 1:
self.gradient_merge_mode = "sharding_gm"
self._grad2merged_grad = dict()
else:
self.gradient_merge_mode = "pp_gm"
self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
'accumulate_steps']
if self._gradient_merge_acc_step > 1:
logging.info("Gradient merge in [{}], acc step = [{}]".format(
self.gradient_merge_mode, self._gradient_merge_acc_step))
# optimize offload
self.optimize_offload = self.user_defined_strategy.sharding_configs[
"optimize_offload"]
# this feature is designed for Ascend, and should NOT be used in GPU training
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
"pp_allreduce_in_optimize"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if self.pp_degree > 1:
pp_optimizer = fluid.optimizer.PipelineOptimizer(
self.inner_opt, self._gradient_merge_acc_step)
main_program = loss.block.program
main_program._pipeline_opt = dict()
self.schedule_mode = self.user_defined_strategy.pipeline_configs[
'schedule_mode']
main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
main_program._pipeline_opt[
'micro_batch_size'] = self.user_defined_strategy.pipeline_configs[
'micro_batch_size']
self.pp_rank_ = self.role_maker._worker_index() // (
self.sharding_degree * self.mp_degree) % self.pp_degree
main_program._pipeline_opt['local_rank'] = self.pp_rank_
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
main_program._pipeline_opt['ring_id'] = 20
main_program._pipeline_opt['global_ring_id'] = 3
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pp_degree = len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
main_block = loss.block
if self.pp_degree > 1:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", self.pp_rank_)
main_program = program_list[self.pp_rank_]
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
new_params_grads = []
for param, grad in params_grads:
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
if self.pp_degree > 1:
pp_optimizer._rename_gradient_var_name(main_block)
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
# step0: _init_comm
self._init_comm()
# step1: _build_shard
self._build_shard(params_grads)
# step2: split_program
self._split_program(main_block)
# step3: add broadcast and reduce ops
self._add_broadcast_allreduce(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
if self.sharding_degree > 1:
# step1: build shard
self._build_shard(params_grads)
# step2: split_program
self._split_program(main_block)
# step3: add broadcast and reduce ops
self._add_broadcast_allreduce(main_block)
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
main_block._sync_with_cpp()
# step4: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
if self.pp_degree > 1:
# sharding-pp related logic
# pp_optimizer._rename_gradient_var_name(main_block)
# crop ops
if self.sharding_degree > 1:
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
# if sharding is not used, adapt amp/clip for the remaining parallelism.
# cast --> amp --> clip --> opt
if self.sharding_degree <= 1:
# amp
FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id)
# clip
gradientclip_helper = GradientClipHelper(self.global_ring_id)
gradientclip_helper.sync_global_norm(
main_block, self.global_ring_id, self.dp_degree)
# step4: scale the loss by the num of dp degree
# sharding is also a scenario of dp
scale_ = self.dp_degree * self.sharding_degree
if scale_ > 1:
insert_scale_loss_grad_ops(main_block, scale=1.0 / scale_)
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
if self.hybrid_dp:
self._initialization_broadcast(startup_program)
# step6: optional gradient merge
if self._gradient_merge_acc_step > 1:
# step6: loss div dp_degree
global_dp_degree = self.sharding_degree * self.dp_degree
assert int(global_dp_degree) == global_dp_degree
if global_dp_degree > 1:
insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree)
main_block._sync_with_cpp()
# TODO(wangxi): add optimize offload
# opt offload should be enabled when gradient merge is enabled && acc_step is quite large (e.g. >> 100),
# since its memcpy cannot be overlapped with calc, otherwise it will severely slow down training.
if self.optimize_offload:
logging.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block)
# step6: (optional) sharding gradient merge
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self._sharding_gradient_merge(main_block)
# # check op dependency
...@@ -184,14 +343,29 @@ class ShardingOptimizer(MetaOptimizerBase):
# check_broadcast(main_block)
# check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
if self.hybrid_dp:
# NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp
# init param broadcast should be called after startup pruning
self._initialization_broadcast(startup_block)
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
self._wait()
return optimize_ops, params_grads

def _init_comm(self):
# config sharding & dp groups
self._build_group()
self._build_groups()
# sync var
startup_block = self._startup_program.global_block()
self.startup_prog_sync_var = startup_block.create_var(
name="startup_prog_sync_var",
...@@ -199,7 +373,7 @@ class ShardingOptimizer(MetaOptimizerBase):
dtype=core.VarDesc.VarType.INT32,
persistable=False)
# global
# global ring
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
...@@ -212,7 +386,7 @@ class ShardingOptimizer(MetaOptimizerBase):
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# mp
# mp ring
if self.mp_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
...@@ -226,7 +400,7 @@ class ShardingOptimizer(MetaOptimizerBase):
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding
# sharding ring
if self.sharding_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
...@@ -240,7 +414,65 @@ class ShardingOptimizer(MetaOptimizerBase):
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# dp
# pp ring
if self.pp_degree > 1:
if self.schedule_mode == 'F-then-B': # GPipe
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id + 2,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
else:
assert self.schedule_mode == '1F1B'
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
pp_group_endpoints,
pp_rank,
ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
# TODO (JZ-LIANG) unify this
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
# pure dp ring
if self.dp_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
...@@ -360,17 +592,22 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block().var(input_name))
# find reduce vars
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
# place pipeline gradient allreduce in optimize
pass
else:
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[
i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (reduced_grad not in
self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
if FP16Utils.is_fp16_cast_op(block, op, self._params):
...@@ -462,8 +699,13 @@ class ShardingOptimizer(MetaOptimizerBase):
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
"c_allreduce_sum",
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
]:
pass
elif op.type == "conditional_block":
...@@ -500,6 +742,16 @@ class ShardingOptimizer(MetaOptimizerBase):
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
# NOTE (JZ-LIANG) revise and unify logic here
# sharding support fp16_allreduce logic
block._sync_with_cpp()
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == 'concat' and is_optimizer_op(op):
# remove inputs that are not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
return

...@@ -507,21 +759,41 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
add broadcast allreduce op
if enable gradient_merge, insert related ops
if combined with pipeline(grad accumulate),
the grad allreduce should be done in optimize role
"""
if len(self._segments) < 1:
return
# sharding
if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
for idx in range(len(self._segments)):
assert len(self._segments[idx]._allreduce_vars) == 0
# NOTE (JZ-LIANG) revise and unify logic here
# fix the _end_idx for segments[-1] if pp is used.
new_end_idx = self._segments[-1]._end_idx
for idx in range(self._segments[-1]._end_idx - 1,
self._segments[-1]._start_idx - 1, -1):
op = block.ops[idx]
if op.type == "fill_constant" or op.type == "sum":
if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
elif op.type == "cast":
if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
self._segments[-1]._end_idx = new_end_idx
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars)
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
# gradient merge
else:
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
block,
self._startup_program.global_block(),
...@@ -532,9 +804,14 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
# allreduce --> reduce
insert_reduce_ops(block, self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars, self._shard)
insert_reduce_ops(
block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)

for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[
...@@ -574,8 +851,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# step2: add Sync ops
shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx,
self.dp_ring_id, shard_allredue_vars)
...@@ -593,7 +871,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_ring_id,
comm_dep_vars)
# gradient merge
else:
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
...@@ -616,7 +894,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# step5: add broadcast ops
# gradient merge
if self._gradient_merge_acc_step > 1:
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
block,
self._startup_program.global_block(), segment._start_idx,
...@@ -627,20 +905,29 @@ class ShardingOptimizer(MetaOptimizerBase):
# step6: add all_reduce ops
# dp
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_allreduce_ops(block, segment._start_idx,
self.dp_ring_id, shard_allredue_vars)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# gradient merge
else:
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# sharding
# allreduce --> reduce
insert_reduce_ops(block, segment._start_idx, self.sharding_ring_id,
allreduce_vars, self._shard)
# TODO temp change
if len(allreduce_vars) > 0:
insert_reduce_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)

block._sync_with_cpp()
...@@ -691,14 +978,14 @@ class ShardingOptimizer(MetaOptimizerBase):
block._remove_var(var_name, sync=False)
block._sync_with_cpp()

def _build_group(self):
def _build_groups(self):
"""
pre-assign ring ids
mp: 0
sharding: 1
pure-dp: 2
global: 3
pp: >= 20
if one parallelism is not enabled: -1
and only support the parallelism hierarchy: mp --> sharding --> pp --> dp
"""
...@@ -768,6 +1055,30 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_group_id = -1
self.sharding_group_endpoints = []
# pp
if self.pp_degree > 1:
self.pp_ring_id = 20
self.pp_rank = self.global_rank // (self.sharding_degree *
self.mp_degree) % self.pp_degree
# (NOTE): Already adjusted for (outer-pure) dp
self.pp_group_id = self.global_rank // (
self.mp_degree * self.sharding_degree * self.pp_degree)
pp_first_stage_idx = self.global_rank % (
self.sharding_degree * self.mp_degree) + self.pp_group_id * (
self.mp_degree * self.sharding_degree * self.pp_degree)
pp_stage_offset = self.sharding_degree * self.mp_degree
self.pp_group_endpoints = []
for i in range(self.pp_degree):
self.pp_group_endpoints.append(self.global_endpoints[
pp_first_stage_idx + pp_stage_offset * i])
assert self.current_endpoint in self.pp_group_endpoints
else:
self.pp_degree = 1
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_id = -1
self.pp_group_endpoints = []
# outer-pure-dp group
# NOTE (JZ-LIANG) support outer-pure-dp to scale the throughput in 3D parallelism
# e.g. mp-sharding-pp-dp
...@@ -775,6 +1086,7 @@ class ShardingOptimizer(MetaOptimizerBase):
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree, self.sharding_degree, self.pp_degree,
self.dp_degree, self.global_word_size)

if self.dp_degree > 1:
self.dp_ring_id = 2
self.dp_rank = self.global_rank // (self.sharding_degree *
...@@ -794,6 +1106,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_group_endpoints = []

# global group
# used for gen_nccl_comm_sync, amp check nan inf, clip by global norm
# NOTE (JZ-LIANG) when the global ring is used to calc the global norm and dp_degree > 1, the allreduce result should be divided by dp_degree
self.global_ring_id = 3

logging.info("global word size: {}".format(self.global_word_size))
...@@ -817,25 +1131,31 @@ class ShardingOptimizer(MetaOptimizerBase):
logging.info("sharding ring id: {}".format(self.sharding_ring_id))
logging.info("#####" * 6)
logging.info("outter pure dp group size: {}".format(self.dp_degree)) logging.info("pp group size: {}".format(self.pp_degree))
logging.info("outter pure dp rank: {}".format(self.dp_rank)) logging.info("pp rank: {}".format(self.pp_rank))
logging.info("outter pure dp group endpoints: {}".format( logging.info("pp group id: {}".format(self.pp_group_id))
logging.info("pp group endpoints: {}".format(self.pp_group_endpoints))
logging.info("pp ring id: {}".format(self.pp_ring_id))
logging.info("#####" * 6)
logging.info("pure dp group size: {}".format(self.dp_degree))
logging.info("pure dp rank: {}".format(self.dp_rank))
logging.info("pure dp group endpoints: {}".format(
self.dp_group_endpoints)) self.dp_group_endpoints))
logging.info("outter pure dp ring id: {}".format(self.dp_ring_id)) logging.info("pure dp ring id: {}".format(self.dp_ring_id))
logging.info("#####" * 6) logging.info("#####" * 6)
return return
def _initialization_broadcast(self, startup_prog):
def _initialization_broadcast(self, startup_block):
"""
this function is to ensure the initialization across dp groups to be
identical when hybrid-dp is used.
"""
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
for param in startup_block.iter_parameters():
params.append(param)
block.append_op(
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
...@@ -844,15 +1164,14 @@ class ShardingOptimizer(MetaOptimizerBase):
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
# sync within global group
append_naive_sync(block, self.startup_prog_sync_var,
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding gradient merge # sharding gradient merge
......
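Before moving on to the backward-pass changes, note that the assert in the ShardingOptimizer hunk above encodes the topology invariant of this PR: mp_degree * sharding_degree * pp_degree * dp_degree must equal the global number of ranks. The following is a minimal standalone sketch of that check; the helper name check_hybrid_degrees and the example numbers are illustrative assumptions, not part of the patch.

# Illustrative sketch only: mirrors the degree/world-size assert added above.
def check_hybrid_degrees(global_world_size, mp_degree, sharding_degree,
                         pp_degree, dp_degree):
    expected = mp_degree * sharding_degree * pp_degree * dp_degree
    if global_world_size != expected:
        raise ValueError(
            "mp_degree [{}] * sharding_degree [{}] * pp_degree [{}] * "
            "dp_degree [{}] = {}, but global nrank is [{}]".format(
                mp_degree, sharding_degree, pp_degree, dp_degree,
                expected, global_world_size))

# e.g. 32 ranks split as mp=2, sharding=4, pp=2, outer dp=2
check_hybrid_degrees(32, 2, 4, 2, 2)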
@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
             new_op_desc = block.desc.append_op()
             new_op_desc.copy_from(desc)
             new_op_desc._set_attr(op_role_attr_name, backward)
+            if desc.has_attr('op_device'):
+                new_op_desc._set_attr('op_device', desc.attr('op_device'))
             result_descs.append(new_op_desc)
     return result_descs
@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
         new_op_desc = block.desc.append_op()
         new_op_desc.copy_from(desc)
         new_op_desc._set_attr(op_role_attr_name, backward)
+        if desc.has_attr('op_device'):
+            new_op_desc._set_attr('op_device', desc.attr('op_device'))
         result_descs.append(new_op_desc)
     return result_descs
@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
     vars_in_memory = vars_should_be_hold + checkpoints_name
     max_calculated_op_position = len(ops)
+    device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
     if recompute_segments == []:
         gap_ops = ops[0:max_calculated_op_position]
         for op in reversed(gap_ops):
@@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_(
                                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
@@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_(
                                 _pretty_op_desc_(op.desc, "with_sub_block"))
             grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
                 op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
+            # Set device for grad_op according to forward Op
+            if op.desc.has_attr(device_attr_name):
+                op_device = op.desc.attr(device_attr_name)
+                for op_desc in grad_op_desc:
+                    op_desc._set_attr(device_attr_name, op_device)
             added_descs = _add_descs_to_block(grad_op_desc, local_block)
             grad_op_descs.extend(added_descs)
             grad_to_var.update(op_grad_to_var)
......
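The backward.py hunks above apply one rule in several places: every generated grad op inherits the 'op_device' attribute of the forward op that produced it, so recomputed backward ops are scheduled on the correct pipeline stage. Below is a self-contained sketch of that rule; the OpDesc class is a minimal stand-in for illustration only, not Paddle's OpDesc.

# Illustrative, self-contained sketch of the op_device propagation rule above.
class OpDesc(object):
    def __init__(self, attrs=None):
        self._attrs = dict(attrs or {})

    def has_attr(self, name):
        return name in self._attrs

    def attr(self, name):
        return self._attrs[name]

    def _set_attr(self, name, value):
        self._attrs[name] = value


def copy_op_device(fwd_desc, grad_descs, device_attr_name='op_device'):
    # Assign every grad op to the same pipeline device as its forward op.
    if fwd_desc.has_attr(device_attr_name):
        op_device = fwd_desc.attr(device_attr_name)
        for grad_desc in grad_descs:
            grad_desc._set_attr(device_attr_name, op_device)


fwd = OpDesc({'op_device': 'gpu:1'})
grads = [OpDesc(), OpDesc()]
copy_op_device(fwd, grads)
assert all(g.attr('op_device') == 'gpu:1' for g in grads)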
@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
         """
         Find the post op that has variable named var_name as input.
         """
+        # bugfix for uniform hybrid parallelism
+        if '.cast_fp32' in var_name:
+            var_name = var_name.replace('.cast_fp32', '')
+        if '.cast_fp16' in var_name:
+            var_name = var_name.replace('.cast_fp16', '')
         post_ops = self.input_var_to_op[var_name]
         if post_ops == None: return None
         result_op = None
@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
             # For LRSched ops, we should put them on all sub-programs to
             # make sure each sub-program update the lr correctly
             op._set_attr(self._op_device_key, "gpu:all")
-        elif op.type == "scale" and self._is_backward_op(op):
+        # bugfix in hybrid parallelism
+        elif op.type == "sum" and self._is_backward_op(op):
+            # For sum ops that compute the sum of @RENAMED@ vars
+            for name in op.desc.input_arg_names():
+                assert '@RENAME@' in name, \
+                    "The op must be sum used to accumulate renamed vars."
+            assert len(op.desc.output_arg_names()) == 1
+            out_name = op.desc.output_arg_names()[0]
+            post_op = self._find_post_op(idx, out_name)
+            assert post_op.has_attr(
+                'op_device'), "{} has no op_device attr for var {}".format(
+                    post_op.type, out_name)
+            device = post_op.attr(self._op_device_key)
+            assert device, "The post op must have op_device set."
+            op._set_attr(self._op_device_key, device)
+        elif (op.type == "cast" or
+              op.type == "scale") and self._is_backward_op(op):
             prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
             op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
         elif op.type == "memcpy" and not self._is_optimize_op(op):
@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
         Insert a pair of send and recv ops for every two
         consecutive ops on different devices.
         """
-        extra_index_info = {'index': 0}
         # A map from var to device where op takes it as input,
         # avoiding multiple send and recv ops.
         input_var_to_device = dict()
+        # bugfix hybrid parallelism
+        first_optimize_index = None
+        for index, op in enumerate(list(block.ops)):
+            if self._is_optimize_op(op):
+                first_optimize_index = index
+                break
+        extra_index_info = {
+            'index': 0,
+            'first_optimize_index': first_optimize_index
+        }
         for index, op in enumerate(list(block.ops)):
             cur_device = op.attr(self._op_device_key)
@@ -4371,17 +4401,26 @@ class PipelineOptimizer(object):
                         'peer': 1,
                     })
                 extra_index_info['index'] += 1
+                insert_index = None
+                if int(op_role) == int(self._op_role.Backward):
+                    insert_index = extra_index_info[
+                        'first_optimize_index']
+                    new_op_role = self._op_role.Optimize
+                else:
+                    insert_index = index
+                    new_op_role = self._op_role.Backward
                 block._insert_op(
-                    index=index + extra_index_info['index'],
+                    index=insert_index + extra_index_info['index'],
                     type='c_sync_comm_stream',
                     inputs={'X': [var]},
                     outputs={'Out': [var]},
                     attrs={
                         self._op_device_key: prev_dev,
-                        self._op_role_key: self._op_role.Backward,
+                        self._op_role_key: new_op_role,
                         'ring_id': ring_id,
                     })
-                extra_index_info['index'] += 1
+                if int(op_role) == int(self._op_role.Forward):
+                    extra_index_info['index'] += 1
                 var_shape = list(var.shape)
                 var_shape[0] = self.micro_batch_size if var_shape[
                     0] < 0 else var_shape[0]
@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
         # Step4: Special Case: process persistable vars that exist in
         # multiple sections
-        self._process_persistable_vars_in_multi_sections(
-            main_program, startup_program, program_list)
+        # FIXME
+        # self._process_persistable_vars_in_multi_sections(
+        #     main_program, startup_program, program_list)
         # Step5: Add sub blocks for section programs
         self._add_sub_blocks(main_block, program_list)
......
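To summarize the _insert_sendrecv_ops_for_boundaries change above: the c_sync_comm_stream that used to be inserted right after a backward-pass send is now deferred to the first optimizer op and re-tagged with the Optimize role, while forward-pass syncs keep their original position and are the only ones that advance the running insertion offset. The sketch below replays just the index/role selection on plain integers; the role constants and op list are stand-ins, not Paddle objects.

# Illustrative sketch of the insert-index rule added above (stand-in roles only).
FORWARD, BACKWARD, OPTIMIZE = 0, 1, 2


def choose_sync_insert_point(index, op_role, first_optimize_index):
    if op_role == BACKWARD:
        # defer the comm-stream sync until the optimizer section starts
        return first_optimize_index, OPTIMIZE
    # forward-pass syncs stay right after the send they guard
    return index, BACKWARD


op_roles = [FORWARD, FORWARD, BACKWARD, BACKWARD, OPTIMIZE, OPTIMIZE]
first_optimize_index = next(
    i for i, role in enumerate(op_roles) if role == OPTIMIZE)
print(choose_sync_insert_point(3, BACKWARD, first_optimize_index))  # -> (4, 2): moved to the optimize section
print(choose_sync_insert_point(1, FORWARD, first_optimize_index))   # -> (1, 1): inserted in place, Backward role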
@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             "segment_broadcast_MB": 0.2,
             "segment_anchors": None,
             "sharding_degree": 2,
+            "dp_degree": 2,
             "hybrid_dp": True,
             "gradient_merge_acc_step": 1,
             "mp_degree": 1
@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
             "segment_broadcast_MB": 0.2,
             "segment_anchors": None,
             "sharding_degree": 2,
+            "dp_degree": 2,
             "hybrid_dp": True,
             "gradient_merge_acc_step": 4,
             "mp_degree": 1
@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
         fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
         opt_ops = [op.type for op in train_prog.blocks[2].ops]
         self.assertEqual(fw_bw_ops, [
-            'fill_constant', 'fill_constant', 'fill_constant',
-            'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
-            'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
-            'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
-            'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
-            'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
-            'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
-            'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'tanh_grad', 'elementwise_add_grad', 'mul_grad',
-            'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
-            'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
-            'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
-            'elementwise_add', 'increment', 'elementwise_mod', 'equal',
-            'conditional_block'
+            'fill_constant',
+            'fill_constant',
+            'fill_constant',
+            'c_sync_calc_stream',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_broadcast',
+            'c_sync_comm_stream',
+            'mul',
+            'elementwise_add',
+            'tanh',
+            'mul',
+            'elementwise_add',
+            'tanh',
+            'mul',
+            'elementwise_add',
+            'softmax',
+            'cross_entropy2',
+            'mean',
+            'fill_constant',
+            'scale',
+            'mean_grad',
+            'cross_entropy_grad2',
+            'softmax_grad',
+            'elementwise_add_grad',
+            'mul_grad',
+            'tanh_grad',
+            'elementwise_add_grad',
+            'mul_grad',
+            'tanh_grad',
+            'elementwise_add_grad',
+            'mul_grad',
+            'c_sync_calc_stream',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_reduce_sum',
+            'c_sync_comm_stream',
+            'elementwise_add',
+            'elementwise_add',
+            'elementwise_add',
+            'increment',
+            'elementwise_mod',
+            'equal',
+            'conditional_block',
         ])
         self.assertEqual(opt_ops, [
             'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
......
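Finally, the test updates above simply add "dp_degree": 2 to the sharding configuration so that sharding, hybrid-dp and gradient merge are exercised together. For reference, a user-side configuration carrying the same fields might look like the sketch below; it assumes the paddle.distributed.fleet DistributedStrategy API and is not a complete training script.

# Sketch only: mirrors the sharding_configs used in the updated tests above.
import paddle.distributed.fleet as fleet

strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {
    "segment_broadcast_MB": 0.2,
    "segment_anchors": None,
    "sharding_degree": 2,
    "mp_degree": 1,
    "dp_degree": 2,  # new knob exercised by this PR's tests
    "hybrid_dp": True,
    "gradient_merge_acc_step": 1,
}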