Commit 7aa0cc3c authored by S sandyhouse

update

Parent fa71ee87
@@ -22,9 +22,10 @@ from .common import OpRole, OP_ROLE_KEY, OP_ROLE_VAR_KEY, CollectiveHelper, is_u
 class ModelParallelHelper(object):
-    def __init__(self, role_maker, wait_port=True):
+    def __init__(self, role_maker, wait_port=True, megatron_dp=False):
         self.wait_port = wait_port
         self.role_maker = role_maker
+        self.megatron_dp = megatron_dp

     def update_startup_program(self,
                                startup_program=None,
@@ -48,6 +49,11 @@ class ModelParallelHelper(object):
             mp_endpoints, mp_rank, 0, self.wait_port)
         self._broadcast_params(0, broadcast_distributed_weight=False)
+        print("megatron group size: {}".format(inner_parallelism))
+        print("megatron rank: {}".format(mp_rank))
+        print("megatron endpoints: {}".format(mp_endpoints))
+
+        if self.megatron_dp:
             mp_num = len(endpoints) // inner_parallelism
             if mp_num == 1: return

             # Create rings for gpus as the same model parallel part
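Note: with `megatron_dp` enabled, the helper above only goes on to build data-parallel rings when there is more than one model-parallel group. Below is a minimal sketch of the grouping idea, assuming consecutive endpoints form one megatron group; the endpoint values and variable names are illustrative only, not the fleet implementation.

```python
# Hedged sketch: derive data-parallel groups from a flat endpoint list.
# Ranks that hold the same model-parallel slice (same offset inside each
# megatron group) talk to each other for data parallelism.
endpoints = ["127.0.0.1:617{}".format(i) for i in range(8)]  # made-up endpoints
inner_parallelism = 2                          # megatron group size
mp_num = len(endpoints) // inner_parallelism   # number of megatron groups, 4 here

dp_groups = []
for offset in range(inner_parallelism):
    # take the rank at this offset from every megatron group
    dp_groups.append(endpoints[offset::inner_parallelism])

print(dp_groups)  # two data-parallel rings of four endpoints each
```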
@@ -129,9 +135,14 @@ class ModelParallelOptimizer(MetaOptimizerBase):
     def __init__(self, optimizer):
         super(ModelParallelOptimizer, self).__init__(optimizer)
         self.inner_opt = optimizer
-        # we do not allow meta optimizer to be inner optimizer currently
-        self.meta_optimizers_white_list = []
+        self.meta_optimizers_white_list = [
+            "RecomputeOptimizer",
+            "AMPOptimizer",
+            "LarsOptimizer",
+            "LambOptimizer",
+        ]
         self.meta_optimizers_black_list = ["GraphExecutionOptimizer", ]
+        self.megatron_dp = False

     def _set_basic_info(self, loss, role_maker, user_defined_optimizer,
                         user_defined_strategy):
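Note: the widened white list above lets Recompute, AMP, Lars, and Lamb be composed with this meta optimizer instead of forbidding any inner meta optimizer. A hedged sketch of how such a white list is typically consulted; the class and function names below are illustrative and not fleet's actual composition logic.

```python
# Toy check: a candidate meta optimizer may wrap this one only if its class
# name appears in the white list.
class RecomputeOptimizer(object):
    pass

class ToyModelParallelOptimizer(object):
    meta_optimizers_white_list = [
        "RecomputeOptimizer", "AMPOptimizer", "LarsOptimizer", "LambOptimizer"
    ]

def can_be_combined(meta_opt, candidate):
    return type(candidate).__name__ in meta_opt.meta_optimizers_white_list

print(can_be_combined(ToyModelParallelOptimizer(), RecomputeOptimizer()))  # True
```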
@@ -156,6 +167,10 @@ class ModelParallelOptimizer(MetaOptimizerBase):
         dist_strategy.model_parallel = True
         dist_strategy.model_parallel_configs = {"parallelism": 1, }

+    # the following function will be used by AMP if both Megatron and AMP are turned on together.
+    def apply_gradients(self, params_grads):
+        return self.minimize_impl(params_grads=params_grads)
+
     def minimize_impl(self,
                       loss,
                       startup_program=None,
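Note: the new `apply_gradients` hook exists so that an outer AMP pass can hand the (already scaled) gradients back to this optimizer. A minimal sketch of that delegation pattern under that assumption; the wrapper class below is hypothetical and not Paddle's AMP implementation.

```python
# Hypothetical outer wrapper: it owns loss scaling and then delegates gradient
# application to the inner meta optimizer through apply_gradients().
class ToyAMPWrapper(object):
    def __init__(self, inner_opt):
        self.inner_opt = inner_opt  # e.g. a ModelParallelOptimizer instance

    def minimize(self, params_grads):
        scaled = [(p, g) for p, g in params_grads]  # loss scaling would happen here
        return self.inner_opt.apply_gradients(scaled)
```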
@@ -167,6 +182,8 @@ class ModelParallelOptimizer(MetaOptimizerBase):
         if startup_program is None:
             self.startup_program = fluid.default_startup_program()
+        # TODO: check the order of meta optimizers
+        # TODO: check the params_grads
         optimize_ops, params_grads = self.inner_opt.minimize(
             loss, self.startup_program, parameter_list, no_grad_set)
@@ -179,6 +196,8 @@ class ModelParallelOptimizer(MetaOptimizerBase):
             self.inner_parallelism)
         assert self.nranks % self.inner_parallelism == 0
+
+        if self.megatron_dp:
             # data parallelism
             dp_parallelism = self.nranks // self.inner_parallelism
...
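Note: a tiny worked example of the arithmetic in the hunk above (the numbers are illustrative): with 8 ranks in total and a megatron group size of 2, each model-parallel slice is replicated across 4 data-parallel workers.

```python
# Illustrative numbers only; mirrors dp_parallelism = nranks // inner_parallelism.
nranks = 8
inner_parallelism = 2
assert nranks % inner_parallelism == 0
dp_parallelism = nranks // inner_parallelism
print(dp_parallelism)  # 4
```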
@@ -73,7 +73,7 @@ class FP16Utils(object):
     @staticmethod
     def prune_fp16(block, shard, reduced_grads_to_param, ring_id):
         """
-        1. prune all cast_fp32_to_fp16 ops if the param does not belong to this shard
+        1. prune all cast_fp16_to_fp32 ops if the param does not belong to this shard
         2. revise amp infinite grad checking for sharding
         """
         # remove cast
@@ -103,6 +103,7 @@ class FP16Utils(object):
                     op._rename_input(inf_var_name, inf_var_name + "@sharding")
             if op.type in ["check_finite_and_unscale", "update_loss_scaling"]:
                 reversed_x = []
+                reversed_x_paramname = []
                 for input_name in op.desc.input('X'):
                     param_name = input_name.strip("@GRAD")
                     if param_name not in shard.global_params:
@@ -111,12 +112,26 @@ class FP16Utils(object):
                             "be grads, but {} is not a grad".format(input_name))
                     if shard.has_param(param_name):
                         reversed_x.append(input_name)
+                        reversed_x_paramname.append(param_name)
                 op.desc.set_input('X', reversed_x)
                 op.desc.set_output('Out', reversed_x)
+
+                # the grad checking should cover all and only the params in the current shard
+                to_check_param = set(reversed_x_paramname)
+                should_check_param = set(shard.global_params).intersection(
+                    set([
+                        param
+                        for param, worker_idx in shard.global_param2device.
                        items() if worker_idx == shard.worker_idx
+                    ]))
+                assert to_check_param == should_check_param, "amp check_finite_and_unscale checking missed [{}] and got unexpected [{}]".format(
+                    should_check_param - to_check_param,
+                    to_check_param - should_check_param)
+
         if update_loss_scaling_op_idx == -1:
             return
         inf_var = block.var(inf_var_name)
-        inf_var_fp32 = block.create_var(
+        inf_var_int32 = block.create_var(
             name=inf_var_name + "@cast_int32",
             shape=inf_var.shape,
             dtype=core.VarDesc.VarType.INT32)
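Note: the assertion added above demands that the gradients fed to `check_finite_and_unscale` cover exactly the parameters placed on this worker's shard. Below is a hedged, self-contained rerun of that set check with toy data; the parameter names and the plain dict standing in for `shard.global_param2device` are made up.

```python
# Toy version of the shard-coverage check: missing or extra parameters on
# either side make the two sets differ and trip the assertion.
global_param2device = {"fc_0.w_0": 0, "fc_0.b_0": 0, "fc_1.w_0": 1}
worker_idx = 0
reversed_x_paramname = ["fc_0.w_0", "fc_0.b_0"]  # params seen by this op on worker 0

to_check_param = set(reversed_x_paramname)
should_check_param = set(p for p, dev in global_param2device.items()
                         if dev == worker_idx)
assert to_check_param == should_check_param, (
    "missed {}, unexpected {}".format(should_check_param - to_check_param,
                                      to_check_param - should_check_param))
```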
@@ -128,32 +143,36 @@ class FP16Utils(object):
             update_loss_scaling_op_idx,
             type='cast',
             inputs={'X': inf_var},
-            outputs={'Out': inf_var_fp32},
+            outputs={'Out': inf_var_int32},
             attrs={
                 "in_dtype": inf_var.dtype,
-                "out_dtype": inf_var_fp32.dtype,
+                "out_dtype": inf_var_int32.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
             })
-        insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
-                            [inf_var_fp32])
+        # this allreduce communication should not overlap with calc
+        # insert_sync_calc_op(block, update_loss_scaling_op_idx + 1,
+        #                     [inf_var_int32])
         block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 2,
+            update_loss_scaling_op_idx + 1,
             type='c_allreduce_max',
-            inputs={'X': inf_var_fp32},
-            outputs={'Out': inf_var_fp32},
-            attrs={'ring_id': ring_id,
-                   OP_ROLE_KEY: OpRole.Optimize})
+            inputs={'X': inf_var_int32},
+            outputs={'Out': inf_var_int32},
+            attrs={
+                'ring_id': ring_id,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Optimize
+            })
-        comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
-                                          ring_id, [inf_var_fp32])
+        # comm_op_num = insert_sync_comm_op(block, update_loss_scaling_op_idx + 3,
+        #                                   ring_id, [inf_var_int32])
         block._insert_op_without_sync(
-            update_loss_scaling_op_idx + 3 + comm_op_num,
+            update_loss_scaling_op_idx + 2,
             type='cast',
-            inputs={'X': inf_var_fp32},
+            inputs={'X': inf_var_int32},
             outputs={'Out': inf_var_sharding},
             attrs={
-                "in_dtype": inf_var_fp32.dtype,
+                "in_dtype": inf_var_int32.dtype,
                 "out_dtype": inf_var_sharding.dtype,
                 OP_ROLE_KEY: OpRole.Optimize
             })
...
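Note: the rewritten sequence above casts the found-inf flag to int32, runs `c_allreduce_max` on the calc stream, and casts the result back, so every sharding worker reaches the same decision before `update_loss_scaling`. An allreduce-max works here because the max over 0/1 flags is a logical OR; a minimal illustration in plain Python with made-up values:

```python
# One 0/1 flag per worker: 1 means "found a non-finite gradient locally".
local_found_inf_flags = [0, 1, 0, 0]
# c_allreduce_max leaves every worker with the element-wise max, i.e. a logical OR.
global_found_inf = max(local_found_inf_flags)
print(bool(global_found_inf))  # True -> skip this step and shrink the loss scale
```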