Unverified commit d12588d2, authored by Yiqun Liu, committed by GitHub

Broadcast the master weight along with param for distributed training. (#52638)

* Broadcast the master weight along with param for distributed training.

* Fix codestyle.
Parent ba9a22db
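The change itself lands in the static-graph sharding passes shown in the diff below. As a rough, illustrative sketch of the idea only, not the code added by this commit: when a rank broadcasts an updated low-precision parameter, the fp32 master weight kept by the optimizer has to travel with it, otherwise ranks drift apart. The helper name and the `master_weights` mapping below are hypothetical.

import paddle.distributed as dist

def broadcast_param_with_master_weight(param, master_weights, src_rank):
    # Assumes the process group is already initialized, e.g. via dist.init_parallel_env().
    # Broadcast the (possibly fp16) parameter itself from the rank that owns the update.
    dist.broadcast(param, src=src_rank)
    # Broadcast the fp32 master weight alongside it, if the optimizer keeps one,
    # so every rank applies the next update to the same master copy.
    master = master_weights.get(param.name)
    if master is not None:
        dist.broadcast(master, src=src_rank)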
@@ -28,6 +28,7 @@ from .sharding.gradient_clip_helper import GradientClipHelper
from .sharding.offload_helper import OffloadHelper
from .sharding.prune import ProgramDeps
from .sharding import utils

# FIXME: import *
from .sharding.utils import *
import logging

@@ -84,7 +85,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        dist_strategy.sharding_configs = {"segment_broadcast_MB": 32}

    def _get_sharding_segment_strategy(self):
        """get
        self._sharding_segment_strategy
        1. if by_size: self._broadcast_MB
        2. if by_anchors: self._sharding_segment_anchors
@@ -97,21 +98,26 @@ class ShardingOptimizer(MetaOptimizerBase):
        if segment_strategy == "segment_broadcast_MB":
            self._broadcast_MB = sharding_configs["segment_broadcast_MB"]
            assert (
                self._broadcast_MB > 0
            ), "segment size should larger than zero !"
        elif segment_strategy == "segment_anchors":
            self._sharding_segment_anchors = sharding_configs["segment_anchors"]
            assert (
                len(self._sharding_segment_anchors) > 0
            ), "you should set the sharding segment anchors !"
            self._backward_remain_anchors = self._sharding_segment_anchors[:]
            self._forward_remain_anchors = []
        else:
            raise NotImplementedError(
                "the sharding segment strategy [{}] is not implemented".format(
                    str(segment_strategy)
                )
            )
        self._sharding_segment_strategy = segment_strategy

    def _get_hybrid_degree(self):
        """get
        self.hybrid_dp
        self.sharding_degree
        self.mp_degree
@@ -135,21 +141,32 @@ class ShardingOptimizer(MetaOptimizerBase):
            assert strategy.pipeline is True

        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
            assert pp_degree == 2, (
                "For manually set pipeline, only " "pp_degree = 2 is supported."
            )
            assert (
                global_world_size == mp_degree * sharding_degree * dp_degree
            ), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
                global_world_size, mp_degree, sharding_degree, dp_degree
            )
        else:
            assert (
                global_world_size
                == mp_degree * sharding_degree * pp_degree * dp_degree
            ), "global work size [{}], mp_degree [{}], sharding_degree [{}], pp_degree [{}], dp_degree [{}].".format(
                global_world_size,
                mp_degree,
                sharding_degree,
                pp_degree,
                dp_degree,
            )

        # FIXME (JZ-LIANG) deprecated hybrid_dp
        if sharding_configs["hybrid_dp"]:
            logger.warning(
                "[hybrid_dp] API setting is deprecated. Now when "
                "dp_degree >= 2, its will be in hybrid dp mode automatically"
            )
            assert dp_degree >= 1
        self.hybrid_dp = True if dp_degree > 1 else False

@@ -159,7 +176,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        self.dp_degree = dp_degree

    def _get_hybrid_dp_mode(self):
        """get
        self.hybrid_dp_mode = 'pp_hybrid_dp' or 'sharding_hybrid_dp'
        self.gradient_merge_mode = 'pp_gm' or 'sharding_gm'
        self._gradient_merge_acc_step
@@ -183,9 +200,10 @@ class ShardingOptimizer(MetaOptimizerBase):
            if self.pp_degree > 1:
                dp_mode = "pp_hybrid_dp"
            else:
                assert self.sharding_degree > 1, (
                    "by now we only support five kind of hybrid dp: sharding_hybrid_dp, "
                    "mp_sharding_hybrid_dp, pp_hybrid_dp, mp_sharding_pp_hybrid_dp, sharding_pp_hybrid_dp."
                )
                dp_mode = "sharding_hybrid_dp"

        # gradient merge

@@ -198,23 +216,33 @@ class ShardingOptimizer(MetaOptimizerBase):
            gm_mode = "pp_gm"
            gm_acc_step = strategy.pipeline_configs['accumulate_steps']
            gradient_scale_configs = strategy.gradient_scale_configs
            assert gradient_scale_configs['scale_strategy'] == 'avg', (
                'For pipeline mode, the '
                'gradient scale mode should '
                'be "avg", but got {}'.format(
                    gradient_scale_configs['scale_strategy']
                )
            )
            # Note (Yuang Liu): this avg_loss flag determines where to do the average op for grad merge.
            # If True, will do sum firstly for gradient merge, then do scale by gm_acc_step.
            # If False, will scale loss by gm_acc_step first, then do sum for gradient merge.
            self.scale_gradient = gradient_scale_configs['scale_gradient']

        if gm_acc_step > 1:
            logger.info(
                "Gradient merge in [{}], acc step = [{}]".format(
                    gm_mode, gm_acc_step
                )
            )

        optimizer_sharding = False
        # TODO(wangxi): need support dp_as_opt_sharding with sharding
        #               need support without pp in future
        if (
            self.sharding_degree == 1
            and self.dp_degree > 1
            and sharding_configs['_dp_as_optimizer_sharding']
            and self.pp_degree > 1
        ):
            optimizer_sharding = True

        self.hybrid_dp_mode = dp_mode
@@ -224,19 +252,23 @@ class ShardingOptimizer(MetaOptimizerBase):
        # this feature is design for ascend, and should NOT be used in GPU training
        self.pp_allreduce_in_optimize = sharding_configs[
            "pp_allreduce_in_optimize"
        ]

    def _inner_opt_minimize(
        self, loss, startup_program, parameter_list, no_grad_set
    ):
        pipeline_configs = self.user_defined_strategy.pipeline_configs

        if self.inner_opt is None:
            raise ValueError(
                "self.inner_opt of ShardingOptimizer should not be None."
            )

        if self.pp_degree > 1:
            pp_optimizer = fluid.optimizer.PipelineOptimizer(
                self.inner_opt, self._gradient_merge_acc_step
            )
            self._pp_optimizer = pp_optimizer

        global_rank = self.role_maker._worker_index()

@@ -253,17 +285,25 @@ class ShardingOptimizer(MetaOptimizerBase):
                'global_ring_id': 3,
                'mp_degree': self.mp_degree,
                'mp_rank': global_rank % self.mp_degree,
                'scale_gradient': self.scale_gradient,
            }
            main_program = loss.block.program
            main_program._pipeline_opt = pipeline_opt
            (
                optimize_ops,
                params_grads,
                program_list,
                self.pipeline_pair,
                self.pp_ring_map,
            ) = pp_optimizer.minimize(
                loss, startup_program, parameter_list, no_grad_set
            )
            assert self.pp_degree == len(program_list)
        else:
            optimize_ops, params_grads = self.inner_opt.minimize(
                loss, startup_program, parameter_list, no_grad_set
            )

        if startup_program is None:
            startup_program = default_startup_program()
@@ -272,8 +312,9 @@ class ShardingOptimizer(MetaOptimizerBase):
            startup_program = startup_program._pipeline_opt['startup_program']
            print("pp_rank:", self.pp_rank)
            if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
                main_program = program_list[
                    int(os.getenv("PADDLE_MANUAL_PIPELINE_STAGE"))
                ]
            else:
                main_program = program_list[self.pp_rank]
            with open("main_%d" % self.role_maker._worker_index(), 'w') as f:

@@ -299,14 +340,16 @@ class ShardingOptimizer(MetaOptimizerBase):
        return optimize_ops, params_grads

    def _apply_sharding_pass(self, params_grads):
        if self.sharding_degree == 1:
            return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()

        # step1: build shard
        self._build_shard(
            params_grads, self.sharding_rank, self.sharding_degree
        )

        # step2: split_program
        self._split_program(main_block)

@@ -318,13 +361,16 @@ class ShardingOptimizer(MetaOptimizerBase):
        # step4: remove unneeded ops and vars from block
        self._prune_main_program(
            main_block,
            self._shard,
            [self.mp_ring_id, self.sharding_ring_id, self.pp_ring_id],
        )
        self._prune_startup_program(startup_block, self._shard)

    def _apply_opt_sharding_pass(self, params_grads):
        """outer dp as optimizer sharding"""
        if self._optimizer_sharding is False:
            return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
@@ -338,12 +384,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        # step4: remove unneeded ops and vars from block
        self._prune_main_program(
            main_block,
            self._shard,
            [self.mp_ring_id, self.pp_ring_id, self.dp_ring_id],
        )
        self._prune_startup_program(startup_block, self._shard)

    def _insert_allreduce_for_pp(self, params_grads):
        if self.pp_degree == 1:
            return

        strategy = self.user_defined_strategy
        sharding_configs = strategy.sharding_configs

@@ -363,10 +412,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                    main_block._remove_op(idx)

        for idx, op in reversed(list(enumerate(main_block.ops))):
            if op.type != 'cast':
                continue
            in_name = op.input_arg_names[0]
            if in_name not in self._params:
                continue
            # if self._shard.has_param(param_name): continue
            if in_name not in main_block.vars:
                main_block._remove_op(idx)

@@ -376,7 +427,8 @@ class ShardingOptimizer(MetaOptimizerBase):
        shard = self._shard if self._optimizer_sharding else None
        accumulated_grad_names = self._pp_optimizer._accumulate_gradients(
            main_block, strategy=strategy, shard=shard
        )

        len_of_ops = len(main_block.ops)
        if self.scale_gradient:

@@ -384,8 +436,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        first_optimize_op_index = get_first_optimize_op_idx(main_block)

        if self.pp_allreduce_in_optimize:
            logger.info(
                "Pipeline Persistable grad is {}".format(accumulated_grad_names)
            )
            # FIXME(wangxi): accumulated_grad get from pipeline is not
            # include sharding's param@BroadCast grad when
            # pp_allreduce_in_optimize
@@ -397,10 +450,11 @@ class ShardingOptimizer(MetaOptimizerBase):
                self._shard,
                core.op_proto_and_checker_maker.OpRole.Optimize,
                use_calc_stream=True,
                rank=self.sharding_rank,
            )

            logger.info("PP-Sharding grad is {}".format(accumulated_grad_names))
            first_optimize_op_index += len(main_block.ops) - len_of_ops
            len_of_ops = len(main_block.ops)

        if self._optimizer_sharding:

@@ -413,10 +467,12 @@ class ShardingOptimizer(MetaOptimizerBase):
                OpRole.Optimize,
                use_calc_stream=True,
                rank=self.dp_rank,
                strategy=strategy,
            )
            logger.info(
                "Optimizer grad in this rank {}".format(accumulated_grad_names)
            )
            first_optimize_op_index += len(main_block.ops) - len_of_ops
            len_of_ops = len(main_block.ops)

            # NOTE(wangxi): we fused after optimize_cast

@@ -424,14 +480,17 @@ class ShardingOptimizer(MetaOptimizerBase):
            optimizer_param = utils.insert_broadcast_param_ops(
                main_block,
                len_of_ops,
                self.dp_ring_id,
                [x[0].name for x in params_grads],
                self._shard,
                OpRole.Optimize,
                use_calc_stream=True,
                rank=self.dp_rank,
                strategy=None if optimize_cast else strategy,
            )
            logger.info(
                "Optimizer param in this rank {}".format(optimizer_param)
            )
            if not strategy.fuse_grad_merge and not optimize_cast:
                assert len(accumulated_grad_names) == len(optimizer_param)
        elif self.hybrid_dp and self.hybrid_dp_mode == "pp_hybrid_dp":
@@ -442,15 +501,20 @@ class ShardingOptimizer(MetaOptimizerBase):
                accumulated_grad_names,
                core.op_proto_and_checker_maker.OpRole.Optimize,
                use_calc_stream=True,
                user_defined_strategy=strategy,
            )
            first_optimize_op_index += len(main_block.ops) - len_of_ops
            len_of_ops = len(main_block.ops)

        # FIXME(wangxi): if fp16_allreduce, put cast fp16->fp32 to there?

    def _avg_grad_merge_after_sum(self, main_block, accumulated_grad_names):
        if (
            self.user_defined_strategy.amp
            and self.user_defined_strategy.amp_configs[
                'use_dynamic_loss_scaling'
            ]
        ):
            # For AMP, if using dynamic loss scaling the avg
            # operation can be simple done by modify the LossScaling op.
            for idx, op in enumerate(main_block.ops):

@@ -461,7 +525,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                    loss_scale_tmp_var = main_block.create_var(
                        name=loss_scale_tmp_var_name,
                        shape=loss_scaling_var.shape,
                        dtype=loss_scaling_var.dtype,
                    )
                    main_block._insert_op_without_sync(
                        idx,
                        type='scale',

@@ -471,8 +536,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                            'scale': self._gradient_merge_acc_step,
                            'bias': 0.0,
                            'bias_after_scale': False,
                            OP_ROLE_KEY: OpRole.Optimize,
                        },
                    )
                    op._rename_input(loss_scale_name, loss_scale_tmp_var_name)
                    break
        else:

@@ -483,7 +549,9 @@ class ShardingOptimizer(MetaOptimizerBase):
                if is_optimizer_op(op) and op.type != 'c_sync_comm_stream':
                    tmp_first_opt_idx = idx
                    break
            assert (
                tmp_first_opt_idx is not None
            ), 'Occurs some errors, no optimize ops'
            for grad in accumulated_grad_names:
                main_block._insert_op_without_sync(
                    tmp_first_opt_idx,
@@ -494,14 +562,17 @@ class ShardingOptimizer(MetaOptimizerBase):
                        'scale': 1.0 / self._gradient_merge_acc_step,
                        'bias': 0.0,
                        'bias_after_scale': False,
                        OP_ROLE_KEY: OpRole.Optimize,
                    },
                )

    def _adapt_amp_clip_without_sharding(self):
        # if not use sharding, adapt amp/clip, for remain parallelism.
        # cast --> amp --> clip --> opt
        if self.sharding_degree > 1:
            return
        if self._optimizer_sharding:
            return

        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()

@@ -515,9 +586,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        FP16Utils.sync_amp_check_nan_inf(main_block, rings)

        gradientclip_helper = GradientClipHelper(None)
        gradientclip_helper.sync_global_norm(
            main_block, [self.mp_ring_id, self.pp_ring_id], self.mp_rank
        )

    def _insert_loss_grad_scale_op(self):
        main_block = self._main_program.global_block()

@@ -538,8 +609,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        mp_ring_id = self.mp_ring_id if self.mp_degree > 1 else None
        dp_ring_id = self.dp_ring_id if self.dp_degree > 1 else None
        offload_helper = OffloadHelper(
            mp_ring_id=mp_ring_id, dp_ring_id=dp_ring_id
        )

        # optimize offload should be enable while gradient merge is enable and
        # acc_step is quite large (e.g. >> 100). Since its memcpy could not be
@@ -555,32 +627,32 @@ class ShardingOptimizer(MetaOptimizerBase):
            # will take more memory, but will be faster. Trade space for time.
            if self._optimizer_sharding:
                offload_helper.opt_sharding_cast_fp32param(
                    main_block, startup_block, [x[0].name for x in params_grads]
                )
                # NOTE(wangxi): fused after optimize_cast
                utils.fuse_opt_broadcast_param_ops(
                    main_block, dp_ring_id, self._shard, strategy=strategy
                )
            else:
                offload_helper.cast_fp32param_in_optimize(
                    main_block, startup_block
                )

    def _dump_program_for_debug(self):
        main_block = self._main_program.global_block()
        startup_block = self._startup_program.global_block()
        with open(
            "start_sharding_%d" % self.role_maker._worker_index(), 'w'
        ) as f:
            f.writelines(str(startup_block.program))
        with open(
            "main_sharding_%d" % self.role_maker._worker_index(), 'w'
        ) as f:
            f.writelines(str(main_block.program))

    def minimize_impl(
        self, loss, startup_program=None, parameter_list=None, no_grad_set=None
    ):
        # TODO: (JZ-LIANG) support multiple comm in future
        # self._nrings = self.user_defined_strategy.nccl_comm_num
        self._nrings_sharding = 1

@@ -595,7 +667,8 @@ class ShardingOptimizer(MetaOptimizerBase):
        # inner optimize minimize
        optimize_ops, params_grads = self._inner_opt_minimize(
            loss, startup_program, parameter_list, no_grad_set
        )

        self._init_comm()
@@ -644,13 +717,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        ]
        pp_rank = 0 if self.pp_rank == pair[0] else 1
        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                pp_group_endpoints,
                pp_rank,
                ring_id,
                False,
                sync=False,
            )

    def _init_npu_pipeline_comm(self, startup_block):
        # NOTE(wangxi): some bug with hccl, must set pp_degree be even number

@@ -668,15 +743,22 @@ class ShardingOptimizer(MetaOptimizerBase):
                my_pair.append(pair)

        # for example: self.pp_rank=2, self.pp_degree=4
        send_to_next_pair = (
            self.pp_rank,
            (self.pp_rank + 1) % self.pp_degree,
        )  # 2->3
        recv_from_next_pair = (
            (self.pp_rank + 1) % self.pp_degree,
            self.pp_rank,
        )  # 3->2
        recv_from_prev_pair = (
            (self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
            self.pp_rank,
        )  # 1->2
        send_to_prev_pair = (
            self.pp_rank,
            (self.pp_rank - 1 + self.pp_degree) % self.pp_degree,
        )  # 2->1

        even = (self.pp_rank % 2) == 0
@@ -685,54 +767,66 @@ class ShardingOptimizer(MetaOptimizerBase):
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
        my_pair.remove(pair)
        logger.info(
            "pair0(even->odd): pp pair:{}, ring_id: {}".format(pair, ring_id)
        )

        # 2. even recv from next, odd send to prev, 1->0, 3->2
        pair = recv_from_next_pair if even else send_to_prev_pair
        ring_id = self.pp_ring_map[pair[0] * 1000 + pair[1]]
        self._init_pair_comm(pair, ring_id)
        my_pair.remove(pair)
        logger.info(
            "pair1(even<-odd): pp pair:{}, ring_id: {}".format(pair, ring_id)
        )

        # if pp_degree is 2, only need pair(0->1, 1->0)
        if self.pp_degree > 2:
            # 3. odd send to next, even recv from prev, 1->2, 3->0
            pair = send_to_next_pair if not even else recv_from_prev_pair
            ring_id = self.pp_ring_map.get(
                pair[0] * 1000 + pair[1], max_ring_id + 1
            )  # 3->0 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info(
                "pair2(odd->even): pp pair:{}, ring_id: {}".format(
                    pair, ring_id
                )
            )

            # 4. odd recv from next, even send to prev, 2->1, 0->3
            pair = recv_from_next_pair if not even else send_to_prev_pair
            ring_id = self.pp_ring_map.get(
                pair[0] * 1000 + pair[1], max_ring_id + 2
            )  # 0->3 not in pp_ring_map
            self._init_pair_comm(pair, ring_id)
            if self.pp_rank != 0 and self.pp_rank != self.pp_degree - 1:
                my_pair.remove(pair)
            logger.info(
                "pair3(odd<-even): pp pair:{}, ring_id: {}".format(
                    pair, ring_id
                )
            )

        assert len(my_pair) == 0, (
            "Current pipeline does not support cross stage communication, "
            "please check unexpected pair {}".format(my_pair)
        )

    def _init_pipeline_comm(self, startup_block):
        # TODO (JZ-LIANG) to unify pp_rank_ and pp_rank
        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None) is None:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                self.pp_group_endpoints,
                self.pp_rank,
                self.pp_ring_id,
                False,
                sync=False,
            )

        if core.is_compiled_with_npu():
            self._init_npu_pipeline_comm(startup_block)
@@ -752,13 +846,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        # mp ring
        if self.mp_degree > 1:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                self.mp_group_endpoints,
                self.mp_rank,
                self.mp_ring_id,
                False,
                sync=False,
            )

        # sharding ring
        if self.sharding_degree > 1:

@@ -769,7 +865,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                self.sharding_rank,
                self.sharding_ring_id,
                False,
                sync=False,
            )

        # pp ring
        if self.pp_degree > 1:

@@ -777,13 +874,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        # pure dp ring
        if self.dp_degree > 1:
            self._collective_helper._init_communicator(
                self._startup_program,
                self.current_endpoint,
                self.dp_group_endpoints,
                self.dp_rank,
                self.dp_ring_id,
                False,
                sync=False,
            )

        startup_block._sync_with_cpp()

@@ -794,9 +893,12 @@ class ShardingOptimizer(MetaOptimizerBase):
        # step 3: get broadcast vars
        self._broadcast_vars = self._shard.find_broadcast_params(
            self._main_program.global_block()
        )

    def _wait(
        self,
    ):
        endpoints = self.global_endpoints[:]
        current_endpoint = endpoints[self.global_rank]
        if self.global_rank == 0:
@@ -821,7 +923,7 @@ class ShardingOptimizer(MetaOptimizerBase):
        segment._end_idx = last_backward_op_idx
        for op_idx in reversed(range(last_backward_op_idx)):
            op = block.ops[op_idx]
            assert int(op.attr('op_role')) != int(OpRole.Optimize)
            if self._sharding_segment_strategy == "segment_broadcast_MB":
                if segment._param_mem >= self._broadcast_MB:
                    segment = self.collect_segment(segment, op_idx, block)

@@ -835,21 +937,27 @@ class ShardingOptimizer(MetaOptimizerBase):
                            if ".cast_fp16@GRAD" not in input_name:
                                continue
                            else:
                                input_name = input_name[
                                    : input_name.find(".cast_fp16@GRAD")
                                ]

                        if input_name in self._backward_remain_anchors:
                            segment = self.collect_segment(
                                segment, op_idx, block
                            )
                            assert (
                                input_name not in self._forward_remain_anchors
                            ), "segment anchor [{}] met twice !".format(
                                input_name
                            )
                            self._backward_remain_anchors.remove(input_name)
                            self._forward_remain_anchors.append(input_name)
                elif int(op.attr('op_role')) == int(OpRole.Forward):
                    for output_name in op.desc.output_arg_names():
                        if output_name in self._forward_remain_anchors:
                            segment = self.collect_segment(
                                segment, op_idx, block
                            )
                            self._forward_remain_anchors.remove(output_name)

            # find broadcast vars
@@ -865,47 +973,49 @@ class ShardingOptimizer(MetaOptimizerBase):
                if self._shard.has_param(input_name):
                    broadcast_var_name = input_name
                else:
                    broadcast_var_name = unique_name.generate(
                        input_name + "@BroadCast"
                    )
                    segment._fill_constant_vars.append(broadcast_var_name)

                # (JZ-LIANG) should use Param base name ?
                broadcast_var_base_name = input_name
                if "subprog" in broadcast_var_base_name:
                    # remove suffix
                    broadcast_var_base_name = broadcast_var_base_name[
                        : broadcast_var_base_name.find(".subprog")
                    ]

                var2broadcast_time[broadcast_var_base_name] = (
                    var2broadcast_time.get(broadcast_var_base_name, 0) + 1
                )

                segment._param2broadcast[input_name] = broadcast_var_name
                segment._broadcast_vars.append(
                    (broadcast_var_name, self._shard.device(input_name))
                )
                segment._param_mem += get_var_size(
                    self._main_program.global_block().var(input_name)
                )

            # find reduce vars
            if self.pp_degree > 1 and self.pp_allreduce_in_optimize:
                # place pipeline gradient allreduce in optimize
                pass
            else:
                if is_backward_op(op) and OP_ROLE_VAR_KEY in op.attr_names:
                    op_role_var = op.all_attrs()[OP_ROLE_VAR_KEY]
                    if len(op_role_var) != 0:
                        assert len(op_role_var) % 2 == 0
                        for i in range(0, len(op_role_var), 2):
                            param, reduced_grad = (
                                op_role_var[i],
                                op_role_var[i + 1],
                            )
                            segment._allreduce_vars.append(reduced_grad)
                            assert (
                                reduced_grad not in self._reduced_grads_to_param
                            )
                            self._reduced_grads_to_param[reduced_grad] = param

            # find cast op
@@ -920,29 +1030,40 @@ class ShardingOptimizer(MetaOptimizerBase):
            self._segments.insert(0, segment)

        if self._sharding_segment_strategy == "segment_anchors":
            assert (
                len(self._forward_remain_anchors) == 0
            ), "remain anchors {}".format(self._forward_remain_anchors)
            assert (
                len(self._backward_remain_anchors) == 0
            ), "remain anchors {}".format(self._backward_remain_anchors)

        if self._verbose:
            for varname in sorted(
                var2broadcast_time, key=var2broadcast_time.get, reverse=True
            ):
                logger.info(
                    "Sharding broadcast: [{}] times [{}]".format(
                        var2broadcast_time[varname], varname
                    )
                )
            for idx_ in range(len(self._segments)):
                logger.info("segment [{}] :".format(idx_))
                logger.info(
                    "start op: [{}] [{}]".format(
                        block.ops[self._segments[idx_]._start_idx].desc.type(),
                        block.ops[
                            self._segments[idx_]._start_idx
                        ].desc.input_arg_names(),
                    )
                )
                logger.info(
                    "end op: [{}] [{}]".format(
                        block.ops[self._segments[idx_]._end_idx].desc.type(),
                        block.ops[
                            self._segments[idx_]._end_idx
                        ].desc.input_arg_names(),
                    )
                )
        return

    def _prune_main_program(self, block, shard, rings):
@@ -975,17 +1096,18 @@ class ShardingOptimizer(MetaOptimizerBase):
            input_names = op.desc.input_arg_names()
            output_names = op.desc.output_arg_names()
            # FIXME(wangxi): need use grads, pipeline grad is @GRAD@MERGE
            if (
                op.type == "c_allreduce_sum"
                and op.attr('use_model_parallel') is False
            ):
                assert len(output_names) == 1
                output_name = output_names[0]
                reduced_grads.append(output_name)

        # prune optimizer state and param
        pruned_opti_vars = []
        for var_name in list(block.vars.keys()):
            if shard.is_opti_var(var_name) and not shard.has_opt_var(var_name):
                pruned_opti_vars.append(var_name)
        program_deps = ProgramDeps(block, reduced_grads, pruned_opti_vars)

@@ -1006,7 +1128,7 @@ class ShardingOptimizer(MetaOptimizerBase):
            ]:
                pass
            elif op.type == "conditional_block":
                assert op.desc.has_attr("sub_block")
                subblock_idx = op.desc.attr("sub_block").id
                subblock_deps = program_deps.get_sub_block_deps(subblock_idx)
                # only prune amp subblock

@@ -1022,7 +1144,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                        reversed_output_vars.append(output_name)
                # prune
                for sub_op_idx, _ in reversed(
                    list(enumerate(subblock_deps._block.ops))
                ):
                    if subblock_deps.should_remove_op(sub_op_idx):
                        subblock_deps.remove_op(sub_op_idx)
                reversed_input_vars = []

@@ -1038,7 +1161,9 @@ class ShardingOptimizer(MetaOptimizerBase):
            # _should_removed_var: opt state not cur shard
            if program_deps.should_remove_op(idx):
                # NOTE(wangxi): need reserve all param in optimizer_sharding
                reserved_vars = (
                    self._params if self._optimizer_sharding else None
                )
                program_deps.remove_op(idx, reserved_vars)

        # NOTE (JZ-LIANG) revise and unify logic here

@@ -1049,7 +1174,8 @@ class ShardingOptimizer(MetaOptimizerBase):
                # remove inputs that not on this card
                reserved_x = []
                for var_name in op.desc.input("X"):
                    if block.has_var(var_name):
                        reserved_x.append(var_name)
                op.desc.set_input('X', reserved_x)
        block._sync_with_cpp()
        return
@@ -1072,175 +1198,280 @@ class ShardingOptimizer(MetaOptimizerBase):
        # NOTE (JZ-LIANG) revise and unify logic here
        # fix the _end_idx for segments[-1] if pp is used.
        new_end_idx = self._segments[-1]._end_idx
        for idx in range(
            self._segments[-1]._end_idx - 1,
            self._segments[-1]._start_idx - 1,
            -1,
        ):
            op = block.ops[idx]
            if op.type == "fill_constant" or op.type == "sum":
                if "MERGED" in op.output_arg_names[0]:
                    new_end_idx = idx + 1
            elif op.type == "cast":
                if "@TMP" in op.output_arg_names[0]:
                    new_end_idx = idx + 1
        self._segments[-1]._end_idx = new_end_idx

        if self._segments[-1]._allreduce_vars:
            shard_allredue_vars = self._shard.filter_grads(
                self._segments[-1]._allreduce_vars
            )
            if (
                self.gradient_merge_mode != "sharding_gm"
                or self._gradient_merge_acc_step <= 1
            ):
                if (
                    self.hybrid_dp
                    and self.hybrid_dp_mode == "sharding_hybrid_dp"
                    and len(shard_allredue_vars) >= 1
                ):
                    insert_sync_comm_ops(
                        block,
                        self._segments[-1]._end_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                    )
                    insert_allreduce_ops(
                        block,
                        self._segments[-1]._end_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                        user_defined_strategy=self.user_defined_strategy,
                    )
            # gradient merge
            elif (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                self.create_persistable_gradients_and_insert_merge_ops(
                    block,
                    self._startup_program.global_block(),
                    self._segments[-1]._end_idx,
                    shard_allredue_vars,
                    self._shard,
                )

            insert_sync_comm_ops(
                block,
                self._segments[-1]._end_idx,
                self.sharding_ring_id,
                self._segments[-1]._allreduce_vars,
            )
            # allreduce --> reduce
            insert_reduce_ops(
                block,
                self._segments[-1]._end_idx,
                self.sharding_ring_id,
                self._segments[-1]._allreduce_vars,
                self._shard,
                op_role=OpRole.Backward,
                use_calc_stream=False,
            )

        for idx, segment in reversed(list(enumerate(self._segments))):
            allreduce_vars = (
                self._segments[idx - 1]._allreduce_vars if idx > 0 else []
            )
            broadcast_vars = (
                self._segments[idx + 1]._broadcast_vars
                if idx < len(self._segments) - 1
                else []
            )
            fill_constant_vars = (
                self._segments[idx + 2]._fill_constant_vars
                if idx < len(self._segments) - 2
                else []
            )
            cast_ops = (
                self._segments[idx + 2]._cast_ops
                if idx < len(self._segments) - 2
                else {}
            )

            for op_idx in reversed(range(segment._start_idx, segment._end_idx)):
                op = block.ops[op_idx]
                for input_name in op.desc.input_arg_names():
                    if (
                        input_name in segment._param2broadcast
                        and input_name != segment._param2broadcast[input_name]
                    ):
                        op._rename_input(
                            input_name, segment._param2broadcast[input_name]
                        )

            for param_name, broadcast_name in segment._param2broadcast.items():
                if param_name != broadcast_name:
                    block.create_var(
                        name=broadcast_name,
                        shape=self._main_program.global_block()
                        .var(param_name)
                        .shape,
                        dtype=self._main_program.global_block()
                        .var(param_name)
                        .dtype,
                        persistable=False,
                    )

            # step1: remove cast ops
            block._sync_with_cpp()
            segment._end_idx += FP16Utils.remove_cast_op(
                block, self._params, segment, 0
            )
            # step2: add Sync ops
            shard_allredue_vars = self._shard.filter_grads(allreduce_vars)

            if (
                self.gradient_merge_mode != "sharding_gm"
                or self._gradient_merge_acc_step <= 1
            ):
                if (
                    self.hybrid_dp
                    and self.hybrid_dp_mode == "sharding_hybrid_dp"
                    and len(shard_allredue_vars) >= 1
                ):
                    insert_sync_comm_ops(
                        block,
                        segment._end_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                    )

                    broad_cast_vars = [x[0] for x in broadcast_vars]
                    if len(broad_cast_vars) > 0:
                        insert_sync_comm_ops(
                            block,
                            segment._end_idx,
                            self.sharding_ring_id,
                            broad_cast_vars,
                        )
                else:
                    comm_dep_vars = allreduce_vars + [
                        x[0] for x in broadcast_vars
                    ]
                    if len(comm_dep_vars) > 0:
                        insert_sync_comm_ops(
                            block,
                            segment._end_idx,
                            self.sharding_ring_id,
                            comm_dep_vars,
                        )
            # gradient merge
            elif (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                broad_cast_vars = [x[0] for x in broadcast_vars]
                if len(broad_cast_vars) > 0:
                    insert_sync_comm_ops(
                        block,
                        segment._end_idx,
                        self.sharding_ring_id,
                        broad_cast_vars,
                    )

            calc_dep_vars = (
                fill_constant_vars
                + [k for k, v in cast_ops.items()]
                + self._segments[idx]._allreduce_vars
            )

            if len(calc_dep_vars) > 0:
                insert_sync_calc_op(
                    block, segment._end_idx, [calc_dep_vars[-1]]
                )
            # step3: insert `fill_constant` ops
            insert_fill_constant_ops(
                block, segment._end_idx, fill_constant_vars
            )

            # step4: add `cast` ops
            insert_cast_ops(block, segment._end_idx, cast_ops)

            # step5: add broadcast ops
            # gradient merge
            if (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                self.create_persistable_gradients_and_insert_merge_ops(
                    block,
                    self._startup_program.global_block(),
                    segment._start_idx,
                    shard_allredue_vars,
                    self._shard,
                )

            insert_broadcast_ops(
                block, segment._start_idx, self.sharding_ring_id, broadcast_vars
            )

            # step6: add all_reduce ops
            # dp
            if (
                self.gradient_merge_mode != "sharding_gm"
                or self._gradient_merge_acc_step <= 1
            ):
                if (
                    self.hybrid_dp
                    and self.hybrid_dp_mode == "sharding_hybrid_dp"
                    and len(shard_allredue_vars) >= 1
                ):
                    insert_allreduce_ops(
                        block,
                        segment._start_idx,
                        self.dp_ring_id,
                        shard_allredue_vars,
                        user_defined_strategy=self.user_defined_strategy,
                    )
                    insert_sync_comm_ops(
                        block,
                        segment._start_idx,
                        self.sharding_ring_id,
                        allreduce_vars,
                    )
            # gradient merge
            elif (
                self.gradient_merge_mode == "sharding_gm"
                and self._gradient_merge_acc_step > 1
            ):
                insert_sync_comm_ops(
                    block,
                    segment._start_idx,
                    self.sharding_ring_id,
                    allreduce_vars,
                )
# sharding # sharding
# allreduce --> reduce # allreduce --> reduce
# TODO temp change # TODO temp change
if len(allreduce_vars) > 0: if len(allreduce_vars) > 0:
insert_reduce_ops(block, insert_reduce_ops(
block,
segment._start_idx, segment._start_idx,
self.sharding_ring_id, self.sharding_ring_id,
allreduce_vars, allreduce_vars,
self._shard, self._shard,
op_role=OpRole.Backward, op_role=OpRole.Backward,
use_calc_stream=False) use_calc_stream=False,
)
block._sync_with_cpp() block._sync_with_cpp()
        if self._segments[0]._broadcast_vars:
            broadcast_vars = [x[0] for x in self._segments[0]._broadcast_vars]
            insert_sync_comm_ops(
                block,
                self._segments[0]._start_idx,
                self.sharding_ring_id,
                broadcast_vars,
            )
            insert_broadcast_ops(
                block,
                self._segments[0]._start_idx,
                self.sharding_ring_id,
                self._segments[0]._broadcast_vars,
            )

        fill_constant_vars = []
        for x in self._segments[:2]:
@@ -1254,12 +1485,14 @@ class ShardingOptimizer(MetaOptimizerBase):
        calc_deps_vars = fill_constant_vars + [k for k, v in cast_ops.items()]
        if fill_constant_vars or cast_ops:
            insert_sync_calc_op(
                block, self._segments[0]._start_idx, [calc_deps_vars[-1]]
            )

        if fill_constant_vars:
            insert_fill_constant_ops(
                block, self._segments[0]._start_idx, fill_constant_vars
            )

        if cast_ops:
            insert_cast_ops(block, self._segments[0]._start_idx, cast_ops)

@@ -1273,7 +1506,7 @@ class ShardingOptimizer(MetaOptimizerBase):
                    continue
                if self._optimizer_sharding and shard.is_param(output_name):
                    continue
                # TODO why do we remove op, when only one var is removed
                block._remove_op(idx, sync=False)
                break
@@ -1302,16 +1535,29 @@ class ShardingOptimizer(MetaOptimizerBase):
        self.global_rank = self.role_maker._worker_index()
        self.global_endpoints = self.role_maker._get_trainer_endpoints()
        self.current_endpoint = self.global_endpoints[self.global_rank]
        self._collective_helper = CollectiveHelper(
            self.role_maker, nrings=self._nrings_sharding
        )
        assert (
            self.global_word_size % self.mp_degree == 0
        ), "global_word_size: {} should be divisible to the mp_degree: {}".format(
            self.global_word_size, self.mp_degree
        )
        assert (
            self.global_word_size % self.sharding_degree == 0
        ), "global_word_size: {} should be divisible to the sharding_degree: {}".format(
            self.global_word_size, self.sharding_degree
        )
        assert (
            self.global_word_size % self.pp_degree == 0
        ), "global_word_size: {} should be divisible to the pp_degree: {}".format(
            self.global_word_size, self.pp_degree
        )
        assert (
            self.global_word_size % self.dp_degree == 0
        ), "global_word_size: {} should be divisible to the dp_degree: {}".format(
            self.global_word_size, self.dp_degree
        )

        # mp group
        if self.mp_degree > 1:
@@ -1319,14 +1565,16 @@ class ShardingOptimizer(MetaOptimizerBase):
            self.mp_rank = self.global_rank % self.mp_degree
            self.mp_group_id = self.global_rank // self.mp_degree
            self.mp_group_endpoints = [
                ep
                for idx, ep in enumerate(self.global_endpoints)
                if idx // self.mp_degree == self.mp_group_id
            ]
            assert self.current_endpoint in self.mp_group_endpoints
            assert (
                len(self.mp_group_endpoints) == self.mp_degree
            ), "num of mp worker in group is [{}], but mp group size is [{}]".format(
                len(self.mp_group_endpoints), self.mp_degree
            )
        else:
            self.mp_degree = 1
            self.mp_ring_id = -1
@@ -1337,23 +1585,28 @@ class ShardingOptimizer(MetaOptimizerBase):
        # sharding
        if self.sharding_degree > 1:
            self.sharding_ring_id = 1
            self.sharding_rank = (
                self.global_rank // self.mp_degree
            ) % self.sharding_degree
            self.sharding_group_id = self.global_rank // (
                self.mp_degree * self.sharding_degree
            )
            # mp + sharding + ...
            if self.mp_degree > 1:
                self.sharding_group_endpoints = [
                    ep
                    for idx, ep in enumerate(self.global_endpoints)
                    if (idx // (self.mp_degree * self.sharding_degree))
                    == self.sharding_group_id
                    and idx % self.mp_degree == self.mp_rank
                ]
            # sharding + ...
            else:
                self.sharding_group_endpoints = [
                    ep
                    for idx, ep in enumerate(self.global_endpoints)
                    if (idx // (self.mp_degree * self.sharding_degree))
                    == self.sharding_group_id
                ]
            assert self.current_endpoint in self.sharding_group_endpoints
        else:
@@ -1368,20 +1621,28 @@ class ShardingOptimizer(MetaOptimizerBase):
            self.pp_pair_ring_id = 20
            # pipeline global ring_id set to 4 for sharding0, mp1, dp2, global3
            self.pp_ring_id = 4
            self.pp_rank = (
                self.global_rank
                // (self.sharding_degree * self.mp_degree)
                % self.pp_degree
            )
            # (NOTE): Already adjust for (outter-pure) dp
            self.pp_group_id = self.global_rank // (
                self.mp_degree * self.sharding_degree * self.pp_degree
            )
            pp_first_stage_idx = self.global_rank % (
                self.sharding_degree * self.mp_degree
            ) + self.pp_group_id * (
                self.mp_degree * self.sharding_degree * self.pp_degree
            )
            pp_stage_offset = self.sharding_degree * self.mp_degree
            self.pp_group_endpoints = []
            for i in range(self.pp_degree):
                self.pp_group_endpoints.append(
                    self.global_endpoints[
                        pp_first_stage_idx + pp_stage_offset * i
                    ]
                )
            assert self.current_endpoint in self.pp_group_endpoints
        else:
            self.pp_ring_id = -1
@@ -1397,29 +1658,48 @@ class ShardingOptimizer(MetaOptimizerBase):
        # sharding-hybrid-dp as one senario of outter-pure-dp
        local_pp_degree = self.pp_degree
        if os.getenv("PADDLE_MANUAL_PIPELINE_STAGE", None):
            assert self.pp_degree == 2, (
                "For manually set pipeline, only " "pp_degree = 2 is supported."
            )
            assert (
                self.global_word_size
                == self.mp_degree * self.sharding_degree * self.dp_degree
            ), "global work size [{}], mp_degree [{}], sharding_degree [{}], dp_degree [{}].".format(
                self.global_word_size,
                self.mp_degree,
                self.sharding_degree,
                self.dp_degree,
            )
            local_pp_degree = 1
        else:
            assert (
                self.global_word_size
                == self.mp_degree
                * self.sharding_degree
                * self.pp_degree
                * self.dp_degree
            ), "mp_degree: [{}], sharding_degree: [{}], pp_degree: [{}], dp_degree: [{}]; BUT global nrank: [{}]".format(
                self.mp_degree,
                self.sharding_degree,
                self.pp_degree,
                self.dp_degree,
                self.global_word_size,
            )

        if self.dp_degree > 1:
            self.dp_ring_id = 2
            self.dp_rank = self.global_rank // (
                self.sharding_degree * self.mp_degree * local_pp_degree
            )
            dp_first_rank_idx = self.global_rank % (
                self.sharding_degree * self.mp_degree * local_pp_degree
            )
            dp_offset = self.sharding_degree * self.mp_degree * local_pp_degree
            self.dp_group_endpoints = []
            for i in range(self.dp_degree):
                self.dp_group_endpoints.append(
                    self.global_endpoints[dp_first_rank_idx + dp_offset * i]
                )
            assert self.current_endpoint in self.dp_group_endpoints
            logger.info("Hybrid DP mode turn on !")
        else:
@@ -1448,8 +1728,9 @@ class ShardingOptimizer(MetaOptimizerBase):
        logger.info("sharding group size: {}".format(self.sharding_degree))
        logger.info("sharding rank: {}".format(self.sharding_rank))
        logger.info("sharding group id: {}".format(self.sharding_group_id))
        logger.info(
            "sharding group endpoints: {}".format(self.sharding_group_endpoints)
        )
        logger.info("sharding ring id: {}".format(self.sharding_ring_id))
        logger.info("#####" * 6)
@@ -1462,15 +1743,15 @@ class ShardingOptimizer(MetaOptimizerBase):
        logger.info("pure dp group size: {}".format(self.dp_degree))
        logger.info("pure dp rank: {}".format(self.dp_rank))
        logger.info(
            "pure dp group endpoints: {}".format(self.dp_group_endpoints)
        )
        logger.info("pure dp ring id: {}".format(self.dp_ring_id))
        logger.info("#####" * 6)

        return
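
A minimal standalone sketch of the rank layout implied by the formulas in the hunks above (mp is the innermost axis, then sharding, then pipeline, then data parallel). The degree values in the example run are assumptions chosen only for illustration and are not part of this patch.

# Decompose a flat trainer index into hybrid-parallel coordinates,
# mirroring the modulo/floor-division formulas used above.
def decompose_rank(global_rank, mp_degree, sharding_degree, pp_degree, dp_degree):
    mp_rank = global_rank % mp_degree
    sharding_rank = (global_rank // mp_degree) % sharding_degree
    pp_rank = global_rank // (sharding_degree * mp_degree) % pp_degree
    dp_rank = global_rank // (sharding_degree * mp_degree * pp_degree)
    return mp_rank, sharding_rank, pp_rank, dp_rank

if __name__ == "__main__":
    # Assumed degrees 2 x 2 x 2 x 2 = 16 trainers, purely for illustration.
    for rank in range(16):
        print(rank, decompose_rank(rank, 2, 2, 2, 2))
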
    def _recreate_not_persist_param_as_var(self):
        def recreate_not_persist_param_as_var(program):
            block = program.global_block()
            params = block.all_parameters()
@@ -1494,14 +1775,16 @@ class ShardingOptimizer(MetaOptimizerBase):
                is_distributed = param.is_distributed
                block._remove_var(name, sync=False)
                var = block.create_var(
                    name=name,
                    shape=shape,
                    dtype=dtype,
                    type=type,
                    lod_level=lod_level,
                    stop_gradient=stop_gradient,
                    trainable=trainable,
                    persistable=False,
                )
                if have_dist_attr:
                    var.is_distributed = is_distributed
@@ -1516,6 +1799,13 @@ class ShardingOptimizer(MetaOptimizerBase):
        identical when hybrid-dp is used, and the initialization of
        not distributed param between mp group to be identical.
        """

        def _find_master_param(all_vars_name, param_name):
            for var_name in all_vars_name:
                if param_name in var_name and "fp32_master" in var_name:
                    return var_name
            return None

        if self.dp_degree <= 1 and self.mp_degree <= 1:
            return
@@ -1536,8 +1826,10 @@ class ShardingOptimizer(MetaOptimizerBase):
            if op.type == 'c_broadcast':
                broadcast_params.add(op.desc.output_arg_names()[0])

        all_vars_name = startup_block.vars
        for param in params_name:
            if param in broadcast_params:
                continue

            rings = []
            # need sync not distributed param in mp group
@@ -1547,30 +1839,51 @@ class ShardingOptimizer(MetaOptimizerBase):
                rings.append(self.dp_ring_id)

            for ring in rings:
                startup_block.append_op(
                    type='c_broadcast',
                    inputs={'X': param},
                    outputs={'Out': param},
                    attrs={
                        'ring_id': ring,
                        'root': 0,
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Forward,
                    },
                )
                # Broadcast the master weight at the same time for AMP-O2 training.
                master_param = _find_master_param(all_vars_name, param)
                if master_param is not None:
                    startup_block.append_op(
                        type='c_broadcast',
                        inputs={'X': master_param},
                        outputs={'Out': master_param},
                        attrs={
                            'ring_id': ring,
                            'root': 0,
                            'use_calc_stream': True,
                            OP_ROLE_KEY: OpRole.Forward,
                        },
                    )

        startup_block._sync_with_cpp()
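
A standalone sketch of the name matching that _find_master_param performs above: under AMP-O2, an fp16 parameter keeps an fp32 master copy in the startup program, and the helper relies only on the master variable containing both the parameter name and the substring "fp32_master". The exact variable names below are assumptions for illustration.

# Same substring-based lookup, decoupled from Paddle blocks.
def find_master_param(all_vars_name, param_name):
    for var_name in all_vars_name:
        if param_name in var_name and "fp32_master" in var_name:
            return var_name
    return None

if __name__ == "__main__":
    startup_vars = [
        "linear_0.w_0",                 # fp16 parameter
        "linear_0.w_0_fp32_master_0",   # assumed master-weight name
        "linear_0.b_0",                 # parameter without a master copy
    ]
    print(find_master_param(startup_vars, "linear_0.w_0"))  # master copy found
    print(find_master_param(startup_vars, "linear_0.b_0"))  # None
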
    # sharding gradient merge
    def create_persistable_gradients_and_insert_merge_ops(
        self, main_block, startup_block, insert_idx, grad_names, shard
    ):
        for grad_name in grad_names:
            assert (
                get_grad_device(grad_name, shard) == shard.worker_idx
            ), "try to merge gradient not belong to current shard: [{}]".format(
                grad_name
            )
            persistable_grad_name = grad_name + '@GradiantMerge'
            assert (
                grad_name not in self._grad2merged_grad
            ), "grad [{}] already in grad2merged_grad, maybe you meet sharing weight case !".format(
                grad_name
            )
            self._grad2merged_grad[grad_name] = persistable_grad_name
            grad_var = main_block.var(grad_name)

            # create var
@@ -1578,36 +1891,38 @@ class ShardingOptimizer(MetaOptimizerBase):
                name=persistable_grad_name,
                shape=grad_var.shape,
                dtype=grad_var.dtype,
                persistable=True,
            )
            startup_gradient_merge_var = startup_block.create_var(
                name=persistable_grad_name,
                shape=grad_var.shape,
                dtype=grad_var.dtype,
                persistable=True,
            )

            # merge gradient
            main_block._insert_op_without_sync(
                insert_idx,
                type="elementwise_add",
                inputs={'X': grad_name, 'Y': gradient_merge_var},
                outputs={'Out': gradient_merge_var},
                attrs={
                    'axis': -1,
                    'use_mkldnn': False,
                    OP_ROLE_KEY: OpRole.Backward,
                },
            )

            # startup initialization
            startup_block.append_op(
                type="fill_constant",
                outputs={"Out": startup_gradient_merge_var},
                attrs={
                    "shape": grad_var.shape,
                    "dtype": grad_var.dtype,
                    "value": float(0),
                },
            )

        main_block._sync_with_cpp()
        startup_block._sync_with_cpp()
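
A plain NumPy sketch of the accumulate-then-average behaviour set up above: the @GradiantMerge buffer starts at zero (fill_constant in the startup block), each micro-step adds the fresh gradient into it (the elementwise_add inserted here), and a later hunk scales the buffer by 1 / acc_step before the optimizer runs. The tensor values and the acc_step of 4 are made up for illustration.

import numpy as np

acc_step = 4
merged_grad = np.zeros(3)                     # grad@GradiantMerge buffer
for step in range(acc_step):
    micro_grad = np.full(3, float(step + 1))  # stand-in for a real gradient
    merged_grad += micro_grad                 # elementwise_add per micro-step
merged_grad *= 1.0 / acc_step                 # scale op before optimize
print(merged_grad)                            # [2.5 2.5 2.5]
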
@@ -1620,14 +1935,17 @@ class ShardingOptimizer(MetaOptimizerBase):
            value=int(self._gradient_merge_acc_step),
            dtype='int32',
            persistable=True,
            force_cpu=True,
        )

        zero_var = layers.create_global_var(
            name="gradient_merge_zero",
            shape=[1],
            value=int(0),
            dtype='int32',
            persistable=True,
            force_cpu=True,
        )

        # Add step var & cond var
        current_step_var = layers.create_global_var(
@@ -1636,42 +1954,40 @@ class ShardingOptimizer(MetaOptimizerBase):
            value=int(0),
            dtype='int32',
            persistable=True,
            force_cpu=True,
        )

        cond_var = main_block.create_var(
            name="gradient_merge_cond", shape=[1], dtype='bool'
        )

        with device_guard("cpu"):
            # step_var = (step_var + 1) % k_step
            main_block.append_op(
                type='increment',
                inputs={'X': [current_step_var]},
                outputs={'Out': [current_step_var]},
                attrs={'step': float(1), OP_ROLE_KEY: OpRole.Optimize},
            )

            main_block.append_op(
                type='elementwise_mod',
                inputs={'X': current_step_var, 'Y': acc_step_var},
                outputs={'Out': current_step_var},
                attrs={
                    'axis': -1,
                    OP_ROLE_KEY: OpRole.Optimize,
                    'use_mkldnn': False,
                },
            )

            # cond_var = (step_var == 0)
            main_block.append_op(
                type='equal',
                inputs={'X': current_step_var, 'Y': zero_var},
                outputs={'Out': cond_var},
                attrs={OP_ROLE_KEY: OpRole.Optimize},
            )
        # paddle.static.Print(current_step_var, message="in FWBW last conditional")
        return cond_var
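
A minimal sketch of the CPU-side counter built above with increment, elementwise_mod and equal: the counter wraps every k_step micro-steps and the condition fires exactly once per accumulation window. The k_step value of 4 in the example is an assumption for illustration.

def gm_cond(step_var, k_step):
    step_var = (step_var + 1) % k_step  # increment + elementwise_mod
    cond = step_var == 0                # equal against gradient_merge_zero
    return step_var, cond

if __name__ == "__main__":
    step = 0
    for micro_step in range(8):
        step, cond = gm_cond(step, 4)
        print(micro_step, step, cond)   # cond is True on micro-steps 3 and 7
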
@@ -1698,35 +2014,37 @@ class ShardingOptimizer(MetaOptimizerBase):
        # allreduce grad@gradientmerge
        if self.hybrid_dp:
            assert (
                self.dp_ring_id >= 0
            ), "dp_ring_id should larger than 0 when in sharding&DP mode"
            for grad, merged_grad in self._grad2merged_grad.items():
                merged_grad_var = main_block.var(merged_grad)
                cur_block.append_op(
                    type='c_allreduce_sum',
                    inputs={'X': merged_grad_var},
                    outputs={'Out': merged_grad_var},
                    attrs={
                        'ring_id': self.dp_ring_id,
                        'use_calc_stream': True,
                        OP_ROLE_KEY: OpRole.Optimize,
                    },
                )

        # grad@gradientmerge / acc_step
        for grad, merged_grad in self._grad2merged_grad.items():
            # grad /= k_steps
            merged_grad_var = main_block.var(merged_grad)
            cur_block.append_op(
                type='scale',
                inputs={'X': merged_grad_var},
                outputs={'Out': merged_grad_var},
                attrs={
                    'scale': 1.0 / float(self._gradient_merge_acc_step),
                    'bias': 0.0,
                    'bias_after_scale': False,
                    OP_ROLE_KEY: OpRole.Optimize,
                },
            )

        # re-create optimize ops
        already_moved_var_names = []
@@ -1737,15 +2055,19 @@ class ShardingOptimizer(MetaOptimizerBase):
            for input_name in new_op_desc.input_arg_names():
                if input_name in self._grad2merged_grad:
                    new_op_desc._rename_input(
                        input_name, self._grad2merged_grad[input_name]
                    )

            for output_name in new_op_desc.output_arg_names():
                if output_name in self._grad2merged_grad:
                    new_op_desc._rename_output(
                        output_name, self._grad2merged_grad[output_name]
                    )

                # move non temp optimize vars from block0 to cond block
                if (
                    output_name not in already_moved_var_names
                    and output_name not in self._grad2merged_grad.keys()
                ):
                    var_ = self._main_program.global_block().var(output_name)
                    if not var_.persistable:
@@ -1754,11 +2076,14 @@ class ShardingOptimizer(MetaOptimizerBase):
                        shape_ = var_.shape
                        type_ = var_.dtype
                        self._main_program.global_block()._remove_var(
                            var_.name, sync=False
                        )
                        self.cond_block.create_var(
                            name=name_,
                            shape=shape_,
                            dtype=type_,
                            persistable=False,
                        )
                        already_moved_var_names.append(name_)

        self._main_program.global_block()._sync_with_cpp()
@@ -1767,14 +2092,16 @@ class ShardingOptimizer(MetaOptimizerBase):
        # fill zero to grad@gradientmerge
        for grad, merged_grad in self._grad2merged_grad.items():
            merged_grad_var = main_block.var(merged_grad)
            cur_block.append_op(
                type='fill_constant',
                outputs={'Out': merged_grad_var},
                attrs={
                    "shape": merged_grad_var.shape,
                    "dtype": merged_grad_var.dtype,
                    "value": float(0),
                    OP_ROLE_KEY: OpRole.Optimize,
                },
            )

        # lr_var = main_block.var("gradient_merge_current_step")
        # paddle.static.Print(lr_var, message="in OPTIMIZE last conditional")
@@ -1786,7 +2113,10 @@ class ShardingOptimizer(MetaOptimizerBase):
        create cond block
        """
        if (
            self.gradient_merge_mode != "sharding_gm"
            or self._gradient_merge_acc_step <= 1
        ):
            return

        main_block = self._main_program.global_block()
@@ -1805,7 +2135,8 @@ class ShardingOptimizer(MetaOptimizerBase):
            main_block._remove_op(op_idx, sync=False)
        tmp_copy_block._sync_with_cpp()
        self.original_optimize_ops_desc = list(
            reversed(self.original_optimize_ops_desc)
        )

        # back to block 0
        self._main_program._rollback()
@@ -1822,18 +2153,17 @@ class ShardingOptimizer(MetaOptimizerBase):
        # cond op
        step_scope = self._main_program.global_block().create_var(
            type=core.VarDesc.VarType.STEP_SCOPES
        )
        conditional_block_op = self._main_program.global_block().append_op(
            type='conditional_block',
            inputs={
                'Cond': cond,
                'Input': [],
            },
            outputs={'Out': [], 'Scope': [step_scope]},
            attrs={
                'sub_block': cond_block,
                'is_scalar_condition': True,
            },
        )
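
A pure-Python sketch of what the conditional_block op above achieves: the optimizer ops relocated into cond_block run only on the micro-step where the gradient-merge condition is True. apply_optimizer is a hypothetical placeholder standing in for the ops recorded in the sub_block, and the acc_step of 4 is an assumption for illustration.

def maybe_apply(cond, apply_optimizer):
    if cond:               # 'Cond' input of conditional_block
        apply_optimizer()  # body of the sub_block

if __name__ == "__main__":
    applied = []
    for micro_step in range(8):
        cond = (micro_step + 1) % 4 == 0  # assumed acc_step = 4
        maybe_apply(cond, lambda: applied.append(micro_step))
    print(applied)                        # [3, 7]
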