Unverified commit 1e60a0c4 authored by JZ-LIANG, committed by GitHub

[3D-parallelism] Hybrid Model Parallelism (#32074)

Parent 363b25aa
......@@ -29,14 +29,18 @@ message RecomputeConfig {
}
message ShardingConfig {
optional float segment_broadcast_MB = 1 [ default = 32.0 ];
optional bool hybrid_dp = 2 [ default = false ];
optional int32 sharding_degree = 3 [ default = 8 ];
optional int32 mp_degree = 4 [ default = 1 ];
optional string sharding_segment_strategy = 5
optional string sharding_segment_strategy = 1
[ default = 'segment_broadcast_MB' ];
repeated string segment_anchors = 6;
optional int32 gradient_merge_acc_step = 7 [ default = 1 ];
optional float segment_broadcast_MB = 2 [ default = 32.0 ];
repeated string segment_anchors = 3;
optional int32 sharding_degree = 4 [ default = 8 ];
optional int32 mp_degree = 5 [ default = 1 ];
optional int32 dp_degree = 6 [ default = 1 ];
optional bool hybrid_dp = 7 [ default = false ];
optional int32 gradient_merge_acc_step = 8 [ default = 1 ];
optional bool optimize_offload = 9 [ default = false ];
optional bool pp_allreduce_in_optimize = 10 [ default = false ];
optional int32 pp_degree = 11 [ default = 1 ];
}
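For orientation, a minimal sketch of how the new ShardingConfig fields might be set from user code through fleet.DistributedStrategy; the degrees below (sharding 2 x pp 2 x dp 2 on 8 cards) are hypothetical and only mirror the proto fields above.

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.sharding = True
    strategy.sharding_configs = {
        "sharding_segment_strategy": "segment_broadcast_MB",
        "segment_broadcast_MB": 32,
        "sharding_degree": 2,
        "mp_degree": 1,
        "pp_degree": 2,
        "dp_degree": 2,          # outer pure-dp group
        "hybrid_dp": True,
        "gradient_merge_acc_step": 1,
        "optimize_offload": False,
    }
    # pp_degree > 1 requires the pipeline pass to be enabled as well
    strategy.pipeline = True
    strategy.pipeline_configs = {"schedule_mode": "1F1B", "micro_batch_size": 2}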
message AMPConfig {
......
......@@ -45,11 +45,16 @@ class PipelineOptimizer(MetaOptimizerBase):
'accumulate_steps']
self.schedule_mode = user_defined_strategy.pipeline_configs[
'schedule_mode']
self.use_sharding = user_defined_strategy.sharding
def _can_apply(self):
if not self.role_maker._is_collective:
return False
# FIXME revise for hybrid parallelism
if self.use_sharding:
return False
if self.user_defined_strategy.pipeline == True:
return True
return False
......
......@@ -81,7 +81,10 @@ class FP16Utils(object):
if not FP16Utils.is_fp32_cast_op(block, op):
continue
output_name = op.desc.output_arg_names()[0]
param_name = output_name.strip("@GRAD")
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
param_name = output_name.strip(
"@GRAD@MERGED"
) if "@MERGED" in output_name else output_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError("Output 'X' of cast_op must be a grad of"
"model param, but {} is not a grad".format(
......@@ -105,6 +108,10 @@ class FP16Utils(object):
reversed_x = []
reversed_x_paramname = []
for input_name in op.desc.input('X'):
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if param_name not in shard.global_params:
raise ValueError(
......@@ -169,3 +176,58 @@ class FP16Utils(object):
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
@staticmethod
def sync_amp_check_nan_inf(block, ring_id):
update_loss_scaling_op_idx = -1
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == "update_loss_scaling":
update_loss_scaling_op_idx = idx
inf_var_name = op.desc.input('FoundInfinite')[0]
op._rename_input(inf_var_name, inf_var_name + "@GLOBAL_WORLD")
# amp is not used
if update_loss_scaling_op_idx == -1:
return
inf_var = block.var(inf_var_name)
inf_var_int32 = block.create_var(
name=inf_var_name + "@cast_int32",
shape=inf_var.shape,
dtype=core.VarDesc.VarType.INT32)
inf_var_global = block.create_var(
name=inf_var_name + "@GLOBAL_WORLD",
shape=inf_var.shape,
dtype=inf_var.dtype)
block._insert_op_without_sync(
update_loss_scaling_op_idx,
type='cast',
inputs={'X': inf_var},
outputs={'Out': inf_var_int32},
attrs={
"in_dtype": inf_var.dtype,
"out_dtype": inf_var_int32.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 1,
type='c_allreduce_max',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_int32},
attrs={
'ring_id': ring_id,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize
})
block._insert_op_without_sync(
update_loss_scaling_op_idx + 2,
type='cast',
inputs={'X': inf_var_int32},
outputs={'Out': inf_var_global},
attrs={
"in_dtype": inf_var_int32.dtype,
"out_dtype": inf_var_global.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
block._sync_with_cpp()
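The three inserted ops implement "any rank saw inf/nan, so every rank sees it": the bool flag is cast to int32, max-allreduced over the global ring, and cast back. A tiny self-contained sketch of that semantics (the per-rank flags below are hypothetical):

    def sync_found_inf(found_inf_per_rank):
        # cast: bool -> int32 on every rank
        as_int = [int(flag) for flag in found_inf_per_rank]
        # c_allreduce_max over the global ring: every rank ends up with the max
        global_max = max(as_int)
        # cast back: True iff any rank overflowed
        return bool(global_max)

    assert sync_found_inf([False, True, False]) is True
    assert sync_found_inf([False, False]) is False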
......@@ -32,6 +32,7 @@ class GradientClipHelper(object):
deperated_vars = set()
deperate_op_idx = set()
reversed_x_paramname = []
global_norm_sum_op_idx = -1
for idx, op in enumerate(block.ops):
if not self._is_gradient_clip_op(op):
continue
......@@ -41,6 +42,10 @@ class GradientClipHelper(object):
for input_name in op.desc.input_arg_names():
if input_name in deperated_vars:
deperate_op = True
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if "@MERGED" in input_name:
param_name = input_name.strip("@GRAD@MERGED")
else:
param_name = input_name.strip("@GRAD")
if shard.is_param(param_name) and \
not shard.has_param(param_name):
......@@ -51,6 +56,7 @@ class GradientClipHelper(object):
if deperate_op:
deperate_op_idx.add(idx)
for output_name in op.desc.output_arg_names():
if output_name not in op.desc.input_arg_names():
deperated_vars.add(output_name)
if not deperated_vars:
......@@ -65,6 +71,7 @@ class GradientClipHelper(object):
continue
reversed_inputs = []
if op.type == "sum":
global_norm_sum_op_idx = idx
for input_name in op.desc.input_arg_names():
if input_name not in deperated_vars:
reversed_inputs.append(input_name)
......@@ -115,3 +122,45 @@ class GradientClipHelper(object):
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
return
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
def sync_global_norm(self, block, ring_id, pure_dp_degree=1):
"""
insert an allreduce on the global ring to synchronize the global gradient norm across ranks.
when pure dp is used (pure_dp_degree > 1), the summed square norm is additionally scaled by
1 / pure_dp_degree, since every dp replica contributes an identical copy of the norm.
"""
for idx, op in reversed(list(enumerate(block.ops))):
if not self._is_gradient_clip_op(op):
continue
if op.type == "sum":
sum_res = op.desc.output_arg_names()[0]
block._insert_op_without_sync(
idx + 1,
type='c_allreduce_sum',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'ring_id': ring_id,
'op_namescope': "/gradient_clip_model_parallelism",
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
})
# the global norm should only be summed within each model-parallel group; divide out the pure-dp replication
if pure_dp_degree > 1:
block._insert_op_without_sync(
idx + 2,
type='scale',
inputs={'X': sum_res},
outputs={'Out': sum_res},
attrs={
'scale': 1.0 / float(pure_dp_degree),
'op_namescope': "/gradient_clip_model_parallelism",
'bias': 0.0,
'bias_after_scale': False,
OP_ROLE_KEY: OpRole.Optimize
})
return
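Arithmetically, the inserted c_allreduce_sum plus the optional scale op compute the following (a sketch with hypothetical per-rank squared-norm sums; the division by pure_dp_degree compensates for every dp replica contributing an identical copy):

    import math

    def global_grad_norm(local_sq_sums, pure_dp_degree=1):
        total = sum(local_sq_sums)          # c_allreduce_sum on the global ring
        if pure_dp_degree > 1:
            total /= float(pure_dp_degree)  # undo the dp over-counting
        return math.sqrt(total)

    # two identical dp replicas each reporting a squared-norm sum of 4.0
    assert global_grad_norm([4.0, 4.0], pure_dp_degree=2) == 2.0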
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from ..common import is_optimizer_op, OP_ROLE_KEY, OpRole
from paddle.fluid import core, unique_name
class OffloadHelper(object):
cpu_place_type = 0
cuda_place_type = 1
cuda_pinned_place_type = 2
def __init__(self):
pass
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
def _insert_cast_op(self, block, idx, src_name, dst_name):
src_var = block.var(src_name)
if not block.has_var(dst_name):
block.create_var(
name=dst_name,
shape=src_var.shape,
dtype=core.VarDesc.VarType.FP16,
persistable=True)
dst_var = block.var(dst_name)
assert dst_var.dtype == core.VarDesc.VarType.FP16
block._insert_op_without_sync(
idx,
type='cast',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'in_dtype': src_var.dtype,
'out_dtype': dst_var.dtype,
OP_ROLE_KEY: OpRole.Optimize
})
def _insert_memcpy_op(self, block, idx, src_name, dst_name, dst_place_type):
src_var = block.var(src_name)
dst_var = block.var(dst_name)
block._insert_op_without_sync(
idx,
type='memcpy',
inputs={'X': src_var},
outputs={'Out': dst_var},
attrs={
'dst_place_type': dst_place_type,
OP_ROLE_KEY: OpRole.Optimize,
})
def _insert_fetch_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_place_type)
def _insert_offload_op(self, block, idx, src_name, dst_name):
self._insert_memcpy_op(block, idx, src_name, dst_name,
OffloadHelper.cuda_pinned_place_type)
def _get_offload_var_name(self, name):
return unique_name.generate(name + '@offload')
def _create_offload_var(self, var_name, offload_var_name, blocks):
for block in blocks:
var = block.var(var_name)
var.persistable = False
offload_var = block.create_var(
name=offload_var_name,
shape=var.shape,
dtype=var.dtype,
persistable=True)
def offload_fp32param(self, block, startup_block):
"""
(p_fp16) = cast(p)
(p_fp16_recompute) = cast(p)
(pout,) = adam(p)
===========================>
rename(p_fp16_recompute, p_fp16)
(p,) = prefetch(p@offload)
(pout,) = adam(p)
(p_fp16) = cast(p)
(p@offload) = memcpy(p)
"""
param_to_idx = dict()
param_to_fp16 = dict()
# recompute vars which need to be renamed to fp16_param
fp16_param_to_recompute = dict()
recompute_to_fp16 = dict()
def remove_param(input_name):
param_to_idx.pop(input_name)
if input_name in param_to_fp16:
fp16_param = param_to_fp16.pop(input_name)
if fp16_param in fp16_param_to_recompute:
recompute = fp16_param_to_recompute.pop(fp16_param)
recompute_to_fp16.pop(recompute)
# step1: record param
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
param_to_idx[param] = idx
# step2: remove param which can't offload
for idx, op in enumerate(block.ops):
if is_optimizer_op(op):
break
for input_name in op.desc.input_arg_names():
if input_name not in param_to_idx:
continue
# param is real used by fp32 op
if op.type != 'cast':
remove_param(input_name)
continue
# param is only used by a cast op,
# which casts fp32_param to fp16_param
output_name = op.output_arg_names[0]
if 'cast_fp16' not in output_name:
remove_param(input_name)
continue
if 'subprog' not in output_name:
assert output_name == input_name + '.cast_fp16'
assert input_name not in param_to_fp16, \
"There must be only one cast op from fp32 param to fp16 param."
param_to_fp16[input_name] = output_name
else:
# fp16-->recompute_var
assert input_name in param_to_fp16, \
"param must first be cast to fp16"
fp16_param = param_to_fp16[input_name]
fp16_param_to_recompute[fp16_param] = output_name
recompute_to_fp16[output_name] = fp16_param
param_name_to_offload_name = dict()
# step3: main_block add offload, cast op
# change recompute to fp16, remove cast(param) to fp16
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in ('adam', 'momentum', 'lars', 'lamb'):
param = op.desc.input("Param")[0]
if param not in param_to_idx: continue
# step3.1: create offload_var
offload_var_name = self._get_offload_var_name(param)
param_name_to_offload_name[param] = offload_var_name
self._create_offload_var(param, offload_var_name,
[block, startup_block])
# step3.2: insert cast op and offload op
self._insert_offload_op(block, idx + 1, param, offload_var_name)
assert param in param_to_fp16
fp16_param_name = param_to_fp16[param]
fp16_param_var = block.var(fp16_param_name)
fp16_param_var.persistable = True
self._insert_cast_op(block, idx + 1, param,
param_to_fp16[param])
# step3.3: insert fetch op
self._insert_fetch_op(block, idx, offload_var_name, param)
continue
# step3.4: remove cast op
if op.type == 'cast':
input_name = op.desc.input_arg_names()[0]
if input_name in param_to_idx:
block._remove_op(idx, sync=False)
continue
# step3.5: change recompute_param to fp16_param
for input_name in op.desc.input_arg_names():
if input_name in recompute_to_fp16:
op._rename_input(input_name, recompute_to_fp16[input_name])
for output_name in op.desc.output_arg_names():
if output_name in recompute_to_fp16:
op._rename_output(output_name,
recompute_to_fp16[output_name])
# step4: remove recompute_param
for name in recompute_to_fp16.keys():
block._remove_var(name, sync=False)
# step5: startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in param_name_to_offload_name:
var_name = out_name
offload_var_name = param_name_to_offload_name[var_name]
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
self._insert_cast_op(startup_block, idx + 1, var_name,
param_to_fp16[var_name])
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
def offload(self, block, startup_block):
"""
(m1, m2) = prefetch(m1@offload, m2@offload)
(m1out, m2out, pout) = adam(m1, m2, p)
(m1@offload, m2@offload) = memcpy(m1, m2)
"""
vars_name_to_offload_name = dict()
# main_block add offload
for idx, op in reversed(list(enumerate(block.ops))):
if not is_optimizer_op(op):
break
vars_name = []
if op.type == "adam":
# {Moment1Out = [''], Moment2Out = [''], ParamOut = ['']} =
# adam(inputs={Moment1 = [''], Moment2 = [''], Param = ['']})
vars_name.append(op.desc.input("Moment1")[0])
vars_name.append(op.desc.input("Moment2")[0])
elif op.type == 'momentum':
pass
elif op.type == 'lars':
pass
elif op.type == 'lamb':
pass
# step1: create and init offload_var
for var_name in vars_name:
assert var_name not in vars_name_to_offload_name
offload_var_name = self._get_offload_var_name(var_name)
vars_name_to_offload_name[var_name] = offload_var_name
self._create_offload_var(var_name, offload_var_name,
[block, startup_block])
# step2: insert offload op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_offload_op(block, idx + 1, var_name,
offload_var_name)
# step3: insert fetch op
for var_name in vars_name:
offload_var_name = vars_name_to_offload_name[var_name]
self._insert_fetch_op(block, idx, offload_var_name, var_name)
# startup_block add offload
visited_vars = set()
for idx, op in reversed(list(enumerate(startup_block.ops))):
for out_name in op.output_arg_names:
if out_name in visited_vars:
continue
if out_name in vars_name_to_offload_name:
var_name = out_name
offload_var_name = vars_name_to_offload_name[var_name]
# insert offload op after var is generated
self._insert_offload_op(startup_block, idx + 1, var_name,
offload_var_name)
visited_vars.add(out_name)
block._sync_with_cpp()
startup_block._sync_with_cpp()
......@@ -126,6 +126,10 @@ class ProgramDeps(object):
def should_remove_op(self, op_idx):
op = self._block.ops[op_idx]
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
# remove check_finite_and_unscale op if its input 'X' is empty
if op.type == 'check_finite_and_unscale' and len(op.input('X')) == 0:
return True
for output_name in op.desc.output_arg_names():
if output_name not in self._should_removed_var:
return False
......
......@@ -274,6 +274,10 @@ def insert_sync_comm_ops(block, insert_idx, ring_id, comm_dep_vars):
"""
insert sync_comm_op for vars
"""
# NOTE (JZ-LIANG) to be checked, may result in an undefined case
if len(comm_dep_vars) == 0:
return 0
op_role = get_valid_op_role(block, insert_idx)
block._insert_op_without_sync(
insert_idx,
......@@ -324,27 +328,45 @@ def insert_cast_ops(block, insert_idx, cast_ops):
return
def insert_allreduce_ops(block, insert_idx, ring_id, allreduce_vars):
def insert_allreduce_ops(block,
insert_idx,
ring_id,
allreduce_vars,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_allreduce_ops
"""
if len(allreduce_vars) == 0:
return
for var in allreduce_vars:
block._insert_op_without_sync(
insert_idx,
type='c_allreduce_sum',
inputs={'X': var},
outputs={'Out': var},
attrs={'ring_id': ring_id,
OP_ROLE_KEY: OpRole.Backward})
attrs={
'ring_id': ring_id,
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
def insert_reduce_ops(block,
insert_idx,
ring_id,
reduce_vars,
shard,
op_role=OpRole.Backward,
use_calc_stream=False):
"""
_add_reduce_ops
"""
for var in reduce_vars:
root_id = get_grad_device(var, shard)
assert root_id >= 0, "root id should be a positive int".format(var)
block._insert_op_without_sync(
......@@ -355,12 +377,40 @@ def insert_reduce_ops(block, insert_idx, ring_id, reduce_vars, shard):
attrs={
'ring_id': ring_id,
'root_id': root_id,
OP_ROLE_KEY: OpRole.Backward
'use_calc_stream': use_calc_stream,
OP_ROLE_KEY: op_role
})
return
def get_grad_device(grad_name, shard):
assert "@GRAD" in grad_name, "[{}] should be a grad variable.".format(
grad_name)
base_name = None
# mind the traversal order
possible_suffixes = [
'.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
]
for suffix in possible_suffixes:
if suffix in grad_name:
base_name = re.sub(suffix, '', grad_name)
break
assert base_name in shard.global_param2device, "[{}] should be a param variable.".format(
base_name)
return shard.global_param2device[base_name]
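The suffix list is ordered from most to least specific on purpose: '.cast_fp16@GRAD@MERGED' must be tried before '@GRAD', otherwise the base parameter name would keep a dangling '.cast_fp16'. A small sketch with hypothetical grad names:

    import re

    possible_suffixes = [
        '.cast_fp16@GRAD@MERGED', '.cast_fp16@GRAD', '@GRAD@MERGED', '@GRAD'
    ]

    def base_param_name(grad_name):
        # mirror the traversal order used by get_grad_device above
        for suffix in possible_suffixes:
            if suffix in grad_name:
                return re.sub(suffix, '', grad_name)
        return None

    assert base_param_name('fc_0.w_0.cast_fp16@GRAD@MERGED') == 'fc_0.w_0'
    assert base_param_name('fc_0.w_0@GRAD') == 'fc_0.w_0'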
def get_first_check_finite_and_unscale_op_idx(block):
for idx, op in enumerate(block.ops):
if op.type == "check_finite_and_unscale":
return idx
raise ValueError("check_finite_and_unscale does not exist in block")
def insert_broadcast_ops(block, insert_idx, ring_id, broadcast2root):
"""
_add_broadcast_ops
......@@ -420,6 +470,7 @@ def insert_scale_loss_grad_ops(block, scale=1.0):
outputs={'Out': loss_grad_var},
attrs={'scale': scale,
OP_ROLE_KEY: OpRole.Backward})
break
def comm_analyse(main_program):
......@@ -502,6 +553,9 @@ def save_persistables(exe, dirname, main_program, filename=None):
and part of persistable vars are duplicated and exist in all the ranks with different values.
This function handles the model saving for sharding training.
"""
# TODO (JZ-LIANG) revise this for uniform mixed parallelism
if main_program._pipeline_opt:
main_program = main_program._pipeline_opt['section_program']['program']
def is_opt_vars(var):
# NOTE(JZ-LIANG): The checks should be updated when add new compatible optimizer
......
......@@ -16,16 +16,16 @@ import paddle
from paddle.fluid import unique_name, core
import paddle.fluid as fluid
from paddle.distributed.fleet.meta_optimizers.common import OpRole, OP_ROLE_VAR_KEY, CollectiveHelper
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op
from paddle.distributed.fleet.meta_optimizers.common import is_backward_op, is_optimizer_op, is_update_op
from paddle.distributed.fleet.meta_optimizers.meta_optimizer_base import MetaOptimizerBase
from paddle.distributed.fleet.meta_optimizers.sharding.shard import Shard, ProgramSegment
from paddle.distributed.fleet.meta_optimizers.sharding.fp16_helper import FP16Utils
from paddle.distributed.fleet.meta_optimizers.sharding.weight_decay_helper import WeightDecayHelper
from paddle.distributed.fleet.meta_optimizers.sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from paddle.distributed.fleet.meta_optimizers.sharding.prune import ProgramDeps
from paddle.distributed.fleet.meta_optimizers.sharding.utils import *
from paddle.fluid.framework import Program, Variable, name_scope, default_main_program, default_startup_program, device_guard
from paddle.fluid import layers
import logging
......@@ -38,6 +38,8 @@ __all__ = ["ShardingOptimizer"]
class ShardingOptimizer(MetaOptimizerBase):
"""Sharding Optimizer."""
def __init__(self, optimizer):
super(ShardingOptimizer, self).__init__(optimizer)
self.inner_opt = optimizer
......@@ -46,7 +48,8 @@ class ShardingOptimizer(MetaOptimizerBase):
"AMPOptimizer",
"LarsOptimizer",
"LambOptimizer",
"ModelParallelOptimizer",
# "ModelParallelOptimizer",
# "PipelineOptimizer",
]
self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
self._main_program = None
......@@ -88,26 +91,6 @@ class ShardingOptimizer(MetaOptimizerBase):
self._nrings_sharding = 1
self._nrings_dp = 1
# parallelism
self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
"sharding_degree"])
assert self.sharding_degree > 1, "sharding degree must be larger than zero"
self.mp_degree = int(self.user_defined_strategy.sharding_configs[
"mp_degree"])
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
self.pp_degree = 1
# dp here is the pure dp as the outest parallelism
self.dp_degree = int(self.role_maker._worker_num() // self.mp_degree //
self.sharding_degree)
assert self.role_maker._worker_num(
) == self.dp_degree * self.mp_degree * self.sharding_degree * self.pp_degree
if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
# segment
self._sharding_segment_strategy = str(
self.user_defined_strategy.sharding_configs[
......@@ -128,29 +111,138 @@ class ShardingOptimizer(MetaOptimizerBase):
"the sharding segment strategy [{}] is not implemented".format(
str(self._sharding_segment_strategy)))
# parallelism
self.sharding_degree = int(self.user_defined_strategy.sharding_configs[
"sharding_degree"])
assert self.sharding_degree > 0, "sharding degree must be larger than zero"
self.mp_degree = int(self.user_defined_strategy.sharding_configs[
"mp_degree"])
# pipeline setting
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
self.pp_degree = int(self.user_defined_strategy.sharding_configs[
"pp_degree"])
if self.pp_degree > 1:
assert self.user_defined_strategy.pipeline == True
self.dp_degree = int(self.user_defined_strategy.sharding_configs[
'dp_degree'])
assert self.role_maker._worker_num(
) == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "global world size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
self.role_maker._worker_num(),
self.mp_degree,
self.sharding_degree,
self.pp_degree,
self.dp_degree, )
self.hybrid_dp = self.user_defined_strategy.sharding_configs[
"hybrid_dp"]
# NOTE (JZ-LIANG)
# there are two modes for gradient-merge and hybrid-dp in mixed parallelism: [sharding] and [pipeline].
# we distinguish these two modes since the gm/hybrid-dp related allreduce should be inserted at different places for best performance:
# sharding: communication within a node, so insert it within the backward segment to overlap with backward calc, conducted every micro step
# pipeline: communication across nodes, so insert it in the update segment, conducted just once per global step
self.hybrid_dp_mode = None
# dp here is the pure dp as the outermost parallelism
if self.hybrid_dp:
assert self.dp_degree > 1, "hybrid dp is on, but dp degree is [{}]".format(
self.dp_degree)
if self.pp_degree > 1:
self.hybrid_dp_mode = "pp_hybrid_dp"
else:
assert self.sharding_degree > 1, "by now we only support five kind of hybrid dp: sharding_hybrid_dp, mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
self.hybrid_dp_mode = "sharding_hybrid_dp"
# gradient merge
self._gradient_merge_acc_step = int(
self.user_defined_strategy.sharding_configs[
"gradient_merge_acc_step"])
self.gradient_merge_mode = None
if self.pp_degree <= 1:
self.gradient_merge_mode = "sharding_gm"
self._grad2merged_grad = dict()
else:
self.gradient_merge_mode = "pp_gm"
self._gradient_merge_acc_step = self.user_defined_strategy.pipeline_configs[
'accumulate_steps']
if self._gradient_merge_acc_step > 1:
logging.info("Gradient merge in [{}], acc step = [{}]".format(
self.gradient_merge_mode, self._gradient_merge_acc_step))
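For clarity, the branch above can be read as: with pipeline enabled the merge is delegated to the pipeline pass and the accumulate step comes from pipeline_configs; otherwise sharding performs the merge itself using sharding_configs. A small sketch with hypothetical strategy dicts:

    def resolve_gradient_merge(pp_degree, sharding_configs, pipeline_configs):
        if pp_degree <= 1:
            # "sharding_gm": the sharding pass inserts the merge ops itself
            return "sharding_gm", int(sharding_configs["gradient_merge_acc_step"])
        # "pp_gm": gradient accumulation is handled by the pipeline pass
        return "pp_gm", int(pipeline_configs["accumulate_steps"])

    assert resolve_gradient_merge(1, {"gradient_merge_acc_step": 4},
                                  {"accumulate_steps": 8}) == ("sharding_gm", 4)
    assert resolve_gradient_merge(4, {"gradient_merge_acc_step": 4},
                                  {"accumulate_steps": 8}) == ("pp_gm", 8)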
# optimize offload
self.optimize_offload = self.user_defined_strategy.sharding_configs[
"optimize_offload"]
# this feature is designed for Ascend, and should NOT be used in GPU training
self.pp_allreduce_in_optimize = self.user_defined_strategy.sharding_configs[
"pp_allreduce_in_optimize"]
if self.inner_opt is None:
raise ValueError(
"self.inner_opt of ShardingOptimizer should not be None.")
if self.pp_degree > 1:
pp_optimizer = fluid.optimizer.PipelineOptimizer(
self.inner_opt, self._gradient_merge_acc_step)
main_program = loss.block.program
main_program._pipeline_opt = dict()
self.schedule_mode = self.user_defined_strategy.pipeline_configs[
'schedule_mode']
main_program._pipeline_opt['schedule_mode'] = self.schedule_mode
main_program._pipeline_opt[
'micro_batch_size'] = self.user_defined_strategy.pipeline_configs[
'micro_batch_size']
self.pp_rank_ = self.role_maker._worker_index() // (
self.sharding_degree * self.mp_degree) % self.pp_degree
main_program._pipeline_opt['local_rank'] = self.pp_rank_
main_program._pipeline_opt[
'global_rank'] = self.role_maker._worker_index()
main_program._pipeline_opt['use_sharding'] = True
# TODO (JZ-LIANG) should revise here for support mix parallelism with pipeline
main_program._pipeline_opt['ring_id'] = 20
main_program._pipeline_opt['global_ring_id'] = 3
optimize_ops, params_grads, program_list, self.pipeline_pair, self.pp_ring_map = pp_optimizer.minimize(
loss, startup_program, parameter_list, no_grad_set)
self.pp_degree = len(program_list)
else:
optimize_ops, params_grads = self.inner_opt.minimize(
loss, startup_program, parameter_list, no_grad_set)
if startup_program is None:
startup_program = default_startup_program()
if self.pp_degree > 1:
startup_program = startup_program._pipeline_opt['startup_program']
#main_program = main_program._pipeline_opt['section_program']['program']
print("pp_rank:", self.pp_rank_)
main_program = program_list[self.pp_rank_]
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
main_block = main_program.global_block()
new_params_grads = []
for param, grad in params_grads:
if main_block.has_var(param.name):
new_params_grads.append((param, grad))
params_grads = new_params_grads
else:
main_block = loss.block
startup_block = startup_program.global_block()
self._main_program = main_block.program
self._startup_program = startup_program
if self.pp_degree > 1:
pp_optimizer._rename_gradient_var_name(main_block)
with open("main_%d" % self.role_maker._worker_index(), 'w') as f:
f.writelines(str(main_program))
# step0: _init_comm
self._init_comm()
# step1: _build_shard
if self.sharding_degree > 1:
# step1: build shard
self._build_shard(params_grads)
# step2: split_program
......@@ -161,22 +253,89 @@ class ShardingOptimizer(MetaOptimizerBase):
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
# step4: scale the loss by the num of dp degree
# sharding is also a senario of dp
scale_ = self.dp_degree * self.sharding_degree
if scale_ > 1:
insert_scale_loss_grad_ops(main_block, scale=1.0 / scale_)
main_block._sync_with_cpp()
# step5: remove unneeded ops and vars from block
# step4: remove unneeded ops and vars from block
self._prune_main_program(main_block)
self._prune_startup_program(startup_block)
if self.hybrid_dp:
self._initialization_broadcast(startup_program)
# step6: optional gradient merge
if self._gradient_merge_acc_step > 1:
if self.pp_degree > 1:
# sharding-pp related logic
# pp_optimizer._rename_gradient_var_name(main_block)
# crop ops
if self.sharding_degree > 1:
for idx, op in reversed(list(enumerate(main_block.ops))):
if is_update_op(op):
op_role_var = op.attr('op_role_var')
param_name = op_role_var[0]
if not self._shard.has_param(param_name):
main_block._remove_op(idx)
for idx, op in reversed(list(enumerate(main_block.ops))):
if op.type != 'cast': continue
in_name = op.input_arg_names[0]
if in_name not in self._params: continue
#if self._shard.has_param(param_name): continue
if in_name not in main_block.vars:
main_block._remove_op(idx)
accumulated_grad_names = pp_optimizer._accumulate_gradients(
main_block)
# accumulated_grad_names = sorted(accumulated_grad_names)
if self.pp_allreduce_in_optimize:
print("persistable FP32 grad: ")
print(accumulated_grad_names)
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_reduce_ops(
main_block,
first_optimize_op_index,
self.sharding_ring_id,
accumulated_grad_names,
self._shard,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
if self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
first_optimize_op_index = get_first_check_finite_and_unscale_op_idx(
main_block)
insert_allreduce_ops(
main_block,
first_optimize_op_index,
self.dp_ring_id,
accumulated_grad_names,
core.op_proto_and_checker_maker.OpRole.Optimize,
use_calc_stream=True)
# if not use sharding, adapt amp/clip, for remain parallelism.
# cast --> amp --> clip --> opt
if self.sharding_degree <= 1:
# amp
FP16Utils.sync_amp_check_nan_inf(main_block, self.global_ring_id)
# clip
gradientclip_helper = GradientClipHelper(self.global_ring_id)
gradientclip_helper.sync_global_norm(
main_block, self.global_ring_id, self.dp_degree)
# step6: loss div dp_degree
global_dp_degree = self.sharding_degree * self.dp_degree
assert int(global_dp_degree) == global_dp_degree
if global_dp_degree > 1:
insert_scale_loss_grad_ops(main_block, scale=1.0 / global_dp_degree)
main_block._sync_with_cpp()
# TODO(wangxi): add optimize offload
# opt offload should be enable while gradient merge is enable && acc_step is quite large (e.g. >> 100)
# sync its memcpy could not be overlap with calc, otherwise it will slower down training severely.
if self.optimize_offload:
logging.info("Sharding with optimize offload !")
offload_helper = OffloadHelper()
offload_helper.offload(main_block, startup_block)
offload_helper.offload_fp32param(main_block, startup_block)
# step6: (optional) sharding gradient merge
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self._sharding_gradient_merge(main_block)
# # check op dependency
......@@ -184,14 +343,29 @@ class ShardingOptimizer(MetaOptimizerBase):
# check_broadcast(main_block)
# check_allreduce_sum(main_block, self._shard, self.sharding_ring_id,
# self.dp_ring_id)
if self.hybrid_dp:
# NOTE(JZ-LIANG) ensure in both sharding_hybrid_dp & pp_hybrid_dp
# init param broadcast should be called after startup pruning
self._initialization_broadcast(startup_block)
with open("start_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(startup_block.program))
with open("main_sharding_%d" % self.role_maker._worker_index(),
'w') as f:
f.writelines(str(main_block.program))
self._wait()
return optimize_ops, params_grads
def _init_comm(self):
# config sharding & dp groups
self._build_group()
self._build_groups()
# sync var
startup_block = self._startup_program.global_block()
self.startup_prog_sync_var = startup_block.create_var(
name="startup_prog_sync_var",
......@@ -199,7 +373,7 @@ class ShardingOptimizer(MetaOptimizerBase):
dtype=core.VarDesc.VarType.INT32,
persistable=False)
# global
# global ring
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
......@@ -212,7 +386,7 @@ class ShardingOptimizer(MetaOptimizerBase):
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# mp
# mp ring
if self.mp_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
......@@ -226,7 +400,7 @@ class ShardingOptimizer(MetaOptimizerBase):
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding
# sharding ring
if self.sharding_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
......@@ -240,7 +414,65 @@ class ShardingOptimizer(MetaOptimizerBase):
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# dp
# pp ring
if self.pp_degree > 1:
if self.schedule_mode == 'F-then-B': # GPipe
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
self.pp_group_endpoints,
self.pp_rank,
self.pp_ring_id + 2,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
else:
assert self.schedule_mode == '1F1B'
for pair in self.pipeline_pair:
pair_key = pair[0] * 1000 + pair[1]
ring_id = self.pp_ring_map[pair_key]
print("pp pair:{}, ring_id: {}".format(pair, ring_id))
if self.pp_rank not in pair: continue
pp_group_endpoints = [
self.pp_group_endpoints[pair[0]],
self.pp_group_endpoints[pair[1]],
]
if pair[0] < pair[1]:
start_ring_id = self.pp_ring_id + pair[1] - pair[0] - 1
else:
start_ring_id = self.pp_ring_id + 2 + pair[0] - pair[
1] - 1
pp_rank = 0 if self.pp_rank == pair[0] else 1
self._collective_helper._init_communicator(
self._startup_program,
self.current_endpoint,
pp_group_endpoints,
pp_rank,
ring_id,
False,
global_ring_id=self.global_ring_id,
sync=False)
# append_naive_sync(startup_block, self.startup_prog_sync_var,
# self.global_ring_id)
# TODO (JZ-LIANG) unify this logic
assert self.pp_rank_ == self.pp_rank, "pp rank for pp opt [{}], pp rank for sharding opt [{}]".format(
self.pp_rank_, self.pp_rank)
# pure dp ring
if self.dp_degree > 1:
self._collective_helper._init_communicator(
self._startup_program,
......@@ -360,16 +592,21 @@ class ShardingOptimizer(MetaOptimizerBase):
self._main_program.global_block().var(input_name))
# find reduce vars
if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
# place pipeline gradient allreduce in optimize
pass
else:
if is_backward_op(op) and \
OP_ROLE_VAR_KEY in op.attr_names:
op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
if len(op_role_var) != 0:
assert len(op_role_var) % 2 == 0
for i in range(0, len(op_role_var), 2):
param, reduced_grad = op_role_var[i], op_role_var[i + 1]
param, reduced_grad = op_role_var[i], op_role_var[
i + 1]
segment._allreduce_vars.append(reduced_grad)
assert (
reduced_grad not in self._reduced_grads_to_param)
assert (reduced_grad not in
self._reduced_grads_to_param)
self._reduced_grads_to_param[reduced_grad] = param
# find cast op
......@@ -462,8 +699,13 @@ class ShardingOptimizer(MetaOptimizerBase):
# Prune
for idx, op in reversed(list(enumerate(block.ops))):
if op.type in [
"c_allreduce_sum", "c_sync_comm_stream",
"c_calc_comm_stream", "c_gen_nccl_id", "c_comm_init"
"c_allreduce_sum",
"c_sync_comm_stream",
"c_calc_comm_stream",
"c_gen_nccl_id",
"c_comm_init",
'send_v2',
'recv_v2',
]:
pass
elif op.type == "conditional_block":
......@@ -500,6 +742,16 @@ class ShardingOptimizer(MetaOptimizerBase):
if program_deps.should_remove_op(idx):
program_deps.remove_op(idx)
# NOTE (JZ-LIANG) revise and unify logic here
# sharding support fp16_allreduce logic
block._sync_with_cpp()
for idx, op in reversed(list(enumerate(block.ops))):
if op.type == 'concat' and is_optimizer_op(op):
# remove inputs that not on this card
reserved_x = []
for var_name in op.desc.input("X"):
if block.has_var(var_name): reserved_x.append(var_name)
op.desc.set_input('X', reserved_x)
block._sync_with_cpp()
return
......@@ -507,21 +759,41 @@ class ShardingOptimizer(MetaOptimizerBase):
"""
add broadcast allreduce op
if gradient_merge is enabled, insert the related ops
if combined with pipeline (grad accumulation),
the grad allreduce should be done in the optimize role
"""
if len(self._segments) < 1:
return
# sharding
if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
for idx in range(len(self._segments)):
assert len(self._segments[idx]._allreduce_vars) == 0
# NOTE (JZ-LIANG) revise and unify logic here
# fix the _end_idx for segments[-1] if pp is used.
new_end_idx = self._segments[-1]._end_idx
for idx in range(self._segments[-1]._end_idx - 1,
self._segments[-1]._start_idx - 1, -1):
op = block.ops[idx]
if op.type == "fill_constant" or op.type == "sum":
if "MERGED" in op.output_arg_names[0]: new_end_idx = idx + 1
elif op.type == "cast":
if "@TMP" in op.output_arg_names[0]: new_end_idx = idx + 1
self._segments[-1]._end_idx = new_end_idx
if self._segments[-1]._allreduce_vars:
shard_allredue_vars = self._shard.filter_grads(self._segments[-1]
._allreduce_vars)
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
insert_allreduce_ops(block, self._segments[-1]._end_idx,
self.dp_ring_id, shard_allredue_vars)
# gradient merge
else:
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
block,
self._startup_program.global_block(),
......@@ -532,9 +804,14 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_ring_id,
self._segments[-1]._allreduce_vars)
# allreduce --> reduce
insert_reduce_ops(block, self._segments[-1]._end_idx,
insert_reduce_ops(
block,
self._segments[-1]._end_idx,
self.sharding_ring_id,
self._segments[-1]._allreduce_vars, self._shard)
self._segments[-1]._allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
for idx, segment in reversed(list(enumerate(self._segments))):
allreduce_vars = self._segments[
......@@ -574,8 +851,9 @@ class ShardingOptimizer(MetaOptimizerBase):
# step2: add Sync ops
shard_allredue_vars = self._shard.filter_grads(allreduce_vars)
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_sync_comm_ops(block, segment._end_idx,
self.dp_ring_id, shard_allredue_vars)
......@@ -593,7 +871,7 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_ring_id,
comm_dep_vars)
# gradient merge
else:
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
broad_cast_vars = [x[0] for x in broadcast_vars]
if len(broad_cast_vars) > 0:
insert_sync_comm_ops(block, segment._end_idx,
......@@ -616,7 +894,7 @@ class ShardingOptimizer(MetaOptimizerBase):
# step5: add broadcast ops
# gradient merge
if self._gradient_merge_acc_step > 1:
if self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
self.create_persistable_gradients_and_insert_merge_ops(
block,
self._startup_program.global_block(), segment._start_idx,
......@@ -627,20 +905,29 @@ class ShardingOptimizer(MetaOptimizerBase):
# step6: add all_reduce ops
# dp
if self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and len(shard_allredue_vars) >= 1:
if self.gradient_merge_mode != "sharding_gm" or self._gradient_merge_acc_step <= 1:
if self.hybrid_dp and self.hybrid_dp_mode == "sharding_hybrid_dp" and len(
shard_allredue_vars) >= 1:
insert_allreduce_ops(block, segment._start_idx,
self.dp_ring_id, shard_allredue_vars)
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# gradient merge
else:
elif self.gradient_merge_mode == "sharding_gm" and self._gradient_merge_acc_step > 1:
insert_sync_comm_ops(block, segment._start_idx,
self.sharding_ring_id, allreduce_vars)
# sharding
# allreduce --> reduce
insert_reduce_ops(block, segment._start_idx, self.sharding_ring_id,
allreduce_vars, self._shard)
# TODO temp change
if len(allreduce_vars) > 0:
insert_reduce_ops(
block,
segment._start_idx,
self.sharding_ring_id,
allreduce_vars,
self._shard,
op_role=OpRole.Backward,
use_calc_stream=False)
block._sync_with_cpp()
......@@ -691,7 +978,7 @@ class ShardingOptimizer(MetaOptimizerBase):
block._remove_var(var_name, sync=False)
block._sync_with_cpp()
def _build_group(self):
def _build_groups(self):
"""
pre-assign ring ids
mp: 0
......@@ -768,6 +1055,30 @@ class ShardingOptimizer(MetaOptimizerBase):
self.sharding_group_id = -1
self.sharding_group_endpoints = []
# pp
if self.pp_degree > 1:
self.pp_ring_id = 20
self.pp_rank = self.global_rank // (self.sharding_degree *
self.mp_degree) % self.pp_degree
# (NOTE): already adjusted for (outer-pure) dp
self.pp_group_id = self.global_rank // (
self.mp_degree * self.sharding_degree * self.pp_degree)
pp_first_stage_idx = self.global_rank % (
self.sharding_degree * self.mp_degree) + self.pp_group_id * (
self.mp_degree * self.sharding_degree * self.pp_degree)
pp_stage_offset = self.sharding_degree * self.mp_degree
self.pp_group_endpoints = []
for i in range(self.pp_degree):
self.pp_group_endpoints.append(self.global_endpoints[
pp_first_stage_idx + pp_stage_offset * i])
assert self.current_endpoint in self.pp_group_endpoints
else:
self.pp_degree = 1
self.pp_ring_id = -1
self.pp_rank = -1
self.pp_group_id = -1
self.pp_group_endpoints = []
# outer-pure-dp group
# NOTE (JZ-LIANG) support outer-pure-dp to scale the throughput in 3D parallelism
# e.g. mp-sharding-pp-dp
......@@ -775,6 +1086,7 @@ class ShardingOptimizer(MetaOptimizerBase):
assert self.global_word_size == self.mp_degree * self.sharding_degree * self.pp_degree * self.dp_degree, "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
self.mp_degree, self.sharding_degree, self.pp_degree,
self.dp_degree, self.global_word_size)
if self.dp_degree > 1:
self.dp_ring_id = 2
self.dp_rank = self.global_rank // (self.sharding_degree *
......@@ -794,6 +1106,8 @@ class ShardingOptimizer(MetaOptimizerBase):
self.dp_group_endpoints = []
# global group
# used for gen_nccl_comm_sync, amp check nan/inf, clip by global norm
# NOTE (JZ-LIANG) when the global ring is used to calculate the global norm and dp_degree > 1, the allreduce result should be divided by dp_degree
self.global_ring_id = 3
logging.info("global word size: {}".format(self.global_word_size))
......@@ -817,25 +1131,31 @@ class ShardingOptimizer(MetaOptimizerBase):
logging.info("sharding ring id: {}".format(self.sharding_ring_id))
logging.info("#####" * 6)
logging.info("outter pure dp group size: {}".format(self.dp_degree))
logging.info("outter pure dp rank: {}".format(self.dp_rank))
logging.info("outter pure dp group endpoints: {}".format(
logging.info("pp group size: {}".format(self.pp_degree))
logging.info("pp rank: {}".format(self.pp_rank))
logging.info("pp group id: {}".format(self.pp_group_id))
logging.info("pp group endpoints: {}".format(self.pp_group_endpoints))
logging.info("pp ring id: {}".format(self.pp_ring_id))
logging.info("#####" * 6)
logging.info("pure dp group size: {}".format(self.dp_degree))
logging.info("pure dp rank: {}".format(self.dp_rank))
logging.info("pure dp group endpoints: {}".format(
self.dp_group_endpoints))
logging.info("outter pure dp ring id: {}".format(self.dp_ring_id))
logging.info("pure dp ring id: {}".format(self.dp_ring_id))
logging.info("#####" * 6)
return
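Put together, _build_groups decomposes the global rank into (mp, sharding, pp, dp) coordinates, with mp innermost and pure dp outermost. A worked sketch for a hypothetical 16-card job (mp=2, sharding=2, pp=2, dp=2); the mp and sharding formulas are assumptions mirroring the parts of _build_groups outside this hunk:

    mp_degree, sharding_degree, pp_degree, dp_degree = 2, 2, 2, 2

    def decompose(global_rank):
        mp_rank = global_rank % mp_degree
        sharding_rank = global_rank // mp_degree % sharding_degree
        pp_rank = global_rank // (mp_degree * sharding_degree) % pp_degree
        dp_rank = global_rank // (mp_degree * sharding_degree * pp_degree)
        return mp_rank, sharding_rank, pp_rank, dp_rank

    # rank 11 of 16 -> mp 1, sharding 1, pp 0, dp 1
    assert decompose(11) == (1, 1, 0, 1)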
def _initialization_broadcast(self, startup_prog):
def _initialization_broadcast(self, startup_block):
"""
this function ensures that parameter initialization across dp groups is
identical when hybrid-dp is used.
"""
block = startup_prog.global_block()
params = []
for param in block.iter_parameters():
for param in startup_block.iter_parameters():
params.append(param)
block.append_op(
startup_block.append_op(
type='c_broadcast',
inputs={'X': param},
outputs={'Out': param},
......@@ -844,15 +1164,14 @@ class ShardingOptimizer(MetaOptimizerBase):
'root': 0,
OP_ROLE_KEY: OpRole.Forward
})
block.append_op(
startup_block.append_op(
type='c_sync_comm_stream',
inputs={'X': params},
outputs={'Out': params},
attrs={'ring_id': self.dp_ring_id,
OP_ROLE_KEY: OpRole.Forward})
# sync within global group
append_naive_sync(block, self.startup_prog_sync_var,
append_naive_sync(startup_block, self.startup_prog_sync_var,
self.global_ring_id)
# sharding gradient merge
......
......@@ -233,6 +233,8 @@ def _add_needed_descs_to_block(descs, block, main_block, in_memory_vars):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
......@@ -252,6 +254,8 @@ def _add_descs_to_block(descs, block):
new_op_desc = block.desc.append_op()
new_op_desc.copy_from(desc)
new_op_desc._set_attr(op_role_attr_name, backward)
if desc.has_attr('op_device'):
new_op_desc._set_attr('op_device', desc.attr('op_device'))
result_descs.append(new_op_desc)
return result_descs
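The added lines above simply propagate the forward op's op_device attribute onto the copied backward desc, so pipeline stage placement survives recompute. Conceptually (with plain dicts standing in for op descs):

    def copy_op_device(fwd_attrs, grad_attrs):
        # keep the grad op on the same pipeline device as its forward op
        if 'op_device' in fwd_attrs:
            grad_attrs['op_device'] = fwd_attrs['op_device']
        return grad_attrs

    assert copy_op_device({'op_device': 'gpu:1'}, {})['op_device'] == 'gpu:1'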
......@@ -843,6 +847,7 @@ def _append_backward_ops_with_checkpoints_(
vars_in_memory = vars_should_be_hold + checkpoints_name
max_calculated_op_position = len(ops)
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
if recompute_segments == []:
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
......@@ -852,6 +857,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......@@ -866,6 +876,11 @@ def _append_backward_ops_with_checkpoints_(
_pretty_op_desc_(op.desc, "with_sub_block"))
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op.desc, cpt.to_text(no_grad_dict[block.idx]), [])
# Set device for grad_op according to forward Op
if op.desc.has_attr(device_attr_name):
op_device = op.desc.attr(device_attr_name)
for op_desc in grad_op_desc:
op_desc._set_attr(device_attr_name, op_device)
added_descs = _add_descs_to_block(grad_op_desc, local_block)
grad_op_descs.extend(added_descs)
grad_to_var.update(op_grad_to_var)
......
......@@ -4033,6 +4033,12 @@ class PipelineOptimizer(object):
"""
Find the post op that has variable named var_name as input.
"""
# bugfix for uniform hybrid parallelism
if '.cast_fp32' in var_name:
var_name = var_name.replace('.cast_fp32', '')
if '.cast_fp16' in var_name:
var_name = var_name.replace('.cast_fp16', '')
post_ops = self.input_var_to_op[var_name]
if post_ops == None: return None
result_op = None
......@@ -4114,7 +4120,23 @@ class PipelineOptimizer(object):
# For LRSched ops, we should put them on all sub-programs to
# make sure each sub-program update the lr correctly
op._set_attr(self._op_device_key, "gpu:all")
elif op.type == "scale" and self._is_backward_op(op):
# bugfix in hybrid parallelism
elif op.type == "sum" and self._is_backward_op(op):
# For sum ops that compute the sum of @RENAMED@ vars
for name in op.desc.input_arg_names():
assert '@RENAME@' in name, \
"The op must be sum used to accumulate renamed vars."
assert len(op.desc.output_arg_names()) == 1
out_name = op.desc.output_arg_names()[0]
post_op = self._find_post_op(idx, out_name)
assert post_op.has_attr(
'op_device'), "{} has no op_device attr for var {}".format(
post_op.type, out_name)
device = post_op.attr(self._op_device_key)
assert device, "The post op must have op_device set."
op._set_attr(self._op_device_key, device)
elif (op.type == "cast" or
op.type == "scale") and self._is_backward_op(op):
prev_op = self._find_prev_op(idx, op.desc.input("X")[0])
op._set_attr(self._op_device_key, prev_op.attr(self._op_device_key))
elif op.type == "memcpy" and not self._is_optimize_op(op):
......@@ -4249,11 +4271,19 @@ class PipelineOptimizer(object):
Insert a pair of send and recv ops for every two
consecutive ops on different devices.
"""
extra_index_info = {'index': 0}
# A map from var to device where op takes it as input,
# avoiding multiple send and recv ops.
input_var_to_device = dict()
# bugfix hybrid parallelism
first_optimize_index = None
for index, op in enumerate(list(block.ops)):
if self._is_optimize_op(op):
first_optimize_index = index
break
extra_index_info = {
'index': 0,
'first_optimize_index': first_optimize_index
}
for index, op in enumerate(list(block.ops)):
cur_device = op.attr(self._op_device_key)
......@@ -4371,16 +4401,25 @@ class PipelineOptimizer(object):
'peer': 1,
})
extra_index_info['index'] += 1
insert_index = None
if int(op_role) == int(self._op_role.Backward):
insert_index = extra_index_info[
'first_optimize_index']
new_op_role = self._op_role.Optimize
else:
insert_index = index
new_op_role = self._op_role.Backward
block._insert_op(
index=index + extra_index_info['index'],
index=insert_index + extra_index_info['index'],
type='c_sync_comm_stream',
inputs={'X': [var]},
outputs={'Out': [var]},
attrs={
self._op_device_key: prev_dev,
self._op_role_key: self._op_role.Backward,
self._op_role_key: new_op_role,
'ring_id': ring_id,
})
if int(op_role) == int(self._op_role.Forward):
extra_index_info['index'] += 1
var_shape = list(var.shape)
var_shape[0] = self.micro_batch_size if var_shape[
......@@ -4768,8 +4807,9 @@ class PipelineOptimizer(object):
# Step4: Special Case: process persistable vars that exist in
# multiple sections
self._process_persistable_vars_in_multi_sections(
main_program, startup_program, program_list)
# FIXME
# self._process_persistable_vars_in_multi_sections(
# main_program, startup_program, program_list)
# Step5: Add sub blocks for section programs
self._add_sub_blocks(main_block, program_list)
......
......@@ -354,6 +354,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"dp_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 1,
"mp_degree": 1
......@@ -422,6 +423,7 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
"segment_broadcast_MB": 0.2,
"segment_anchors": None,
"sharding_degree": 2,
"dp_degree": 2,
"hybrid_dp": True,
"gradient_merge_acc_step": 4,
"mp_degree": 1
......@@ -458,20 +460,56 @@ class TestFleetMetaOptimizer(TestFleetMetaOptimizer):
fw_bw_ops = [op.type for op in train_prog.blocks[0].ops]
opt_ops = [op.type for op in train_prog.blocks[2].ops]
self.assertEqual(fw_bw_ops, [
'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'c_sync_comm_stream', 'mul', 'elementwise_add', 'tanh', 'mul',
'elementwise_add', 'tanh', 'mul', 'elementwise_add', 'softmax',
'cross_entropy2', 'mean', 'fill_constant', 'scale', 'mean_grad',
'cross_entropy_grad2', 'softmax_grad', 'elementwise_add_grad',
'mul_grad', 'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'tanh_grad', 'elementwise_add_grad', 'mul_grad',
'c_sync_calc_stream', 'c_reduce_sum', 'c_reduce_sum',
'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum', 'c_reduce_sum',
'c_sync_comm_stream', 'elementwise_add', 'elementwise_add',
'elementwise_add', 'increment', 'elementwise_mod', 'equal',
'conditional_block'
'fill_constant',
'fill_constant',
'fill_constant',
'c_sync_calc_stream',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_broadcast',
'c_sync_comm_stream',
'mul',
'elementwise_add',
'tanh',
'mul',
'elementwise_add',
'tanh',
'mul',
'elementwise_add',
'softmax',
'cross_entropy2',
'mean',
'fill_constant',
'scale',
'mean_grad',
'cross_entropy_grad2',
'softmax_grad',
'elementwise_add_grad',
'mul_grad',
'tanh_grad',
'elementwise_add_grad',
'mul_grad',
'tanh_grad',
'elementwise_add_grad',
'mul_grad',
'c_sync_calc_stream',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_reduce_sum',
'c_sync_comm_stream',
'elementwise_add',
'elementwise_add',
'elementwise_add',
'increment',
'elementwise_mod',
'equal',
'conditional_block',
])
self.assertEqual(opt_ops, [
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'scale',
......