Unverified commit 74c9b57b, authored by JZ-LIANG, committed by GitHub

[Auto parallel] Bug fixed for GPT3 benchmark (#43793)

* fixed bug for pass & engine

* fixed bug for benchmark GPT-3
Parent ccfde2da
@@ -83,6 +83,7 @@ class Engine:
self._dist_startup_progs = defaultdict(dict) # dist startup programs
self._feed_vars = {}
self._fetch_vars = {}
self._planners = {}
def prepare(self,
optimizer=None,
@@ -116,13 +117,13 @@ class Engine:
self._planned_mode = None
self._modes = ['train', 'eval', 'predict']
# Build forward program
self._build()
# Do auto parallel process
for mode in self._modes:
# Do the planning process
self._plan(mode)
for mode in self._modes:
# Do the parallel process
self._parallel(mode, all_ranks)
# Init comm and startup program
@@ -185,14 +186,14 @@ class Engine:
else:
self._init_dist_context(mode)
self.planner = Planner(mode, self._dist_contexts[mode])
self.planner.plan()
self._planners[mode] = Planner(mode, self._dist_contexts[mode])
self._planners[mode].plan()
def _parallel(self, mode, all_ranks):
# Parallelize program based on the planner's results
# For now, the completer has to be passed to the planner,
# because we may use it to complete the annotation of the backward and update ops.
parallelizer = Parallelizer(mode, self.planner.completer,
parallelizer = Parallelizer(mode, self._planners[mode].completer,
self._dist_contexts[mode])
if not all_ranks:
parallelizer.parallel(self._cur_rank)
......
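Note: the engine hunks above replace the single `self.planner` attribute with a per-mode `self._planners` dict, so the planner (and its completer) built for "train" is no longer overwritten when "eval" and "predict" are planned afterwards. A minimal sketch of the pattern, using hypothetical stand-in classes rather than the real Paddle API:

```python
# Minimal sketch: keep one planner per execution mode instead of
# overwriting a single attribute (hypothetical stand-ins, not Paddle code).

class Planner:
    def __init__(self, mode, dist_context):
        self.mode = mode
        self.completer = f"completer-for-{mode}"  # placeholder for the real completer

    def plan(self):
        print(f"planning {self.mode}")


class Engine:
    def __init__(self):
        self._dist_contexts = {m: object() for m in ("train", "eval", "predict")}
        self._planners = {}  # mode -> Planner

    def _plan(self, mode):
        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
        self._planners[mode].plan()

    def _parallel(self, mode):
        # each mode reuses its own completer, e.g. to complete backward/update annotations
        completer = self._planners[mode].completer
        print(f"parallelizing {mode} with {completer}")


engine = Engine()
for mode in ("train", "eval", "predict"):
    engine._plan(mode)
for mode in ("train", "eval", "predict"):
    engine._parallel(mode)
```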
@@ -18,7 +18,8 @@ from ..dist_attribute import OperatorDistributedAttribute
_g_distributed_operator_impl_containers = {}
_g_elementwise_ops = [
"elementwise", "gelu", "dropout", "cast", "gather", "concat"
"elementwise", "gelu", "dropout", "cast", "gather", "concat",
"fused_softmax_mask_upper_triangle"
]
BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
......
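Note: adding `fused_softmax_mask_upper_triangle` to the elementwise-like list lets the default distributed-operator rules (identical dims mapping for inputs and outputs) apply to it. A hedged sketch of how such a membership check could look; `is_elementwise_like` is a hypothetical helper, not the actual Paddle function:

```python
# Hypothetical membership check for elementwise-like ops (illustration only).
_g_elementwise_ops = [
    "elementwise", "gelu", "dropout", "cast", "gather", "concat",
    "fused_softmax_mask_upper_triangle",
]

def is_elementwise_like(op_type):
    # "elementwise" covers a prefix family (elementwise_add, elementwise_mul, ...);
    # the other entries are exact op types.
    if op_type.startswith("elementwise"):
        return True
    return op_type in _g_elementwise_ops

assert is_elementwise_like("elementwise_add")
assert is_elementwise_like("fused_softmax_mask_upper_triangle")
assert not is_elementwise_like("matmul_v2")
```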
@@ -80,9 +80,9 @@ class Parallelizer:
rank, dist_params_grads)
else:
# Apply pre optimization passes
self._apply_pre_optimization(serial_main_program,
serial_startup_program, None, None,
None)
# self._apply_pre_optimization(serial_main_program,
# serial_startup_program, None, None,
# None)
# Do logical partition
partitioner = Partitioner(self._dist_context, rank)
dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
@@ -121,7 +121,9 @@ class Parallelizer:
if self._strategy is None:
return
# apply amp pass
if self._strategy.amp:
# FIXME: we disable amp for eval since it has a small bug with the
# eval program, which will be fixed in the future
if self._mode == 'train' and self._strategy.amp:
config = copy.deepcopy(self._strategy.amp_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
@@ -139,7 +141,8 @@ class Parallelizer:
self._pass_context)
# apply recompute pass
if self._strategy.recompute:
# recompute is a train-only optimization
if self._mode == "train" and self._strategy.recompute:
config = copy.deepcopy(self._strategy.recompute_configs)
config["dist_context"] = self._dist_context
config["no_grad_set"] = None
@@ -164,7 +167,8 @@ class Parallelizer:
auto_parallel_sharding_pass.apply([main_program], [startup_program],
self._pass_context)
if self._strategy.gradient_merge:
# gradient_merge is a train-only optimization
if self._mode == "train" and self._strategy.gradient_merge:
config = copy.deepcopy(self._strategy.gradient_merge_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
......
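Note: the parallelizer hunks gate the amp, recompute, and gradient_merge passes on `self._mode == "train"`, since these passes rewrite backward and optimizer-update ops that only exist in the training program. A simplified sketch of the gating pattern follows; the `Strategy` dataclass and `apply_pass` stub are assumptions for illustration, not the real `paddle.distributed.passes` API:

```python
# Simplified sketch of mode-gated optimization passes (hypothetical names).
from dataclasses import dataclass, field

@dataclass
class Strategy:
    amp: bool = False
    recompute: bool = False
    gradient_merge: bool = False
    amp_configs: dict = field(default_factory=dict)
    recompute_configs: dict = field(default_factory=dict)
    gradient_merge_configs: dict = field(default_factory=dict)

def apply_pass(name, config):
    print(f"applying {name} with {config}")

def apply_optimization_passes(mode, strategy):
    # amp, recompute and gradient_merge touch backward/update ops,
    # so they only make sense for the training program.
    if mode != "train":
        return
    if strategy.amp:
        apply_pass("auto_parallel_amp", strategy.amp_configs)
    if strategy.recompute:
        apply_pass("auto_parallel_recompute", strategy.recompute_configs)
    if strategy.gradient_merge:
        apply_pass("auto_parallel_gradient_merge", strategy.gradient_merge_configs)

apply_optimization_passes("eval", Strategy(amp=True))   # no passes applied
apply_optimization_passes("train", Strategy(amp=True, recompute=True))
```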
@@ -1057,13 +1057,15 @@ def set_grad_var_shape(program, dist_context):
"transpose2_grad", "softmax_grad", "cross_entropy_grad2",
"dropout_grad", "tanh_grad", "slice", "assign",
"matmul_v2_triple_grad", "elementwise_add_triple_grad",
"fill_constant", "sqrt_grad"
"fill_constant", "sqrt_grad",
"fused_softmax_mask_upper_triangle_grad"
]
forward_list = [
"reshape2", "softmax_with_cross_entropy", "transpose2",
"softmax", "cross_entropy2", "dropout", "tanh",
["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
"elementwise_add_grad_grad", "shape", "sqrt"
"elementwise_add_grad_grad", "shape", "sqrt",
"fused_softmax_mask_upper_triangle_grad"
]
if op.type in need_set_shape_list:
for forward_op in block.ops:
......
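Note: `need_set_shape_list` and `forward_list` are index-paired: when a grad op of a type in the first list is found, the forward op type at the same index is looked up so the gradient variable can reuse the matching forward variable's shape. A hedged sketch of that pairing, with made-up shapes and a hypothetical `grad_var_shape` helper:

```python
# Illustration of the grad-op / forward-op index pairing used to recover
# gradient variable shapes (not the real set_grad_var_shape implementation).
need_set_shape_list = ["softmax_grad", "fused_softmax_mask_upper_triangle_grad"]
forward_list = ["softmax", "fused_softmax_mask_upper_triangle"]

# made-up forward output shapes
forward_shapes = {"softmax": (8, 1024), "fused_softmax_mask_upper_triangle": (8, 1024, 1024)}

def grad_var_shape(grad_op_type):
    # find the forward op type paired with this grad op and reuse its output shape
    idx = need_set_shape_list.index(grad_op_type)
    forward_op_type = forward_list[idx]
    return forward_shapes[forward_op_type]

assert grad_var_shape("fused_softmax_mask_upper_triangle_grad") == (8, 1024, 1024)
```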
@@ -143,8 +143,8 @@ class AMPState(object):
"""
num_cast_ops = 0
var_name_dict = {}
for in_name in op.input_names:
var_name_dict = {}
if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
op, in_name):
continue
......
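Note: the AMP-pass hunk hoists `var_name_dict = {}` out of the per-input loop. Previously the mapping was reset for every `in_name`, so cast-variable renames recorded for earlier inputs were lost by the time the dict was used. A tiny sketch of the bug pattern with made-up data (not the real AMP pass):

```python
# Illustration of the reset-inside-loop bug (made-up inputs).
inputs = {"X": ["x_fp32"], "Y": ["y_fp32"]}

def collect_casts_buggy():
    for in_name, args in inputs.items():
        var_name_dict = {}  # reset every iteration: earlier entries are lost
        for arg in args:
            var_name_dict[arg] = arg + ".cast_fp16"
    return var_name_dict    # only holds the last input's mapping

def collect_casts_fixed():
    var_name_dict = {}      # initialized once, accumulates all inputs
    for in_name, args in inputs.items():
        for arg in args:
            var_name_dict[arg] = arg + ".cast_fp16"
    return var_name_dict

assert collect_casts_buggy() == {"y_fp32": "y_fp32.cast_fp16"}
assert collect_casts_fixed() == {"x_fp32": "x_fp32.cast_fp16",
                                 "y_fp32": "y_fp32.cast_fp16"}
```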