Unverified commit 74c9b57b authored by JZ-LIANG, committed by GitHub

[Auto parallel] Bug fixed for GPT3 benchmark (#43793)

* fixed bug for pass & engine

* fixed bug for benchmark GPT-3
Parent ccfde2da
@@ -83,6 +83,7 @@ class Engine:
         self._dist_startup_progs = defaultdict(dict)  # dist startup programs
         self._feed_vars = {}
         self._fetch_vars = {}
+        self._planners = {}

     def prepare(self,
                 optimizer=None,
@@ -116,13 +117,13 @@ class Engine:
         self._planned_mode = None
         self._modes = ['train', 'eval', 'predict']

-        # Build forward program
         self._build()
         # Do auto parallel process
         for mode in self._modes:
             # Do the planning process
             self._plan(mode)
+        for mode in self._modes:
             # Do the parallel process
             self._parallel(mode, all_ranks)

         # Init comm and startup program
@@ -185,14 +186,14 @@ class Engine:
         else:
             self._init_dist_context(mode)

-        self.planner = Planner(mode, self._dist_contexts[mode])
-        self.planner.plan()
+        self._planners[mode] = Planner(mode, self._dist_contexts[mode])
+        self._planners[mode].plan()

     def _parallel(self, mode, all_ranks):
         # Parallelize program based on the planner's results
         # For now, the completer has to be passed to the planner,
         # because we may use it to complete the annotation of the backwarkward and update.
-        parallelizer = Parallelizer(mode, self.planner.completer,
+        parallelizer = Parallelizer(mode, self._planners[mode].completer,
                                     self._dist_contexts[mode])
         if not all_ranks:
             parallelizer.parallel(self._cur_rank)
...
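The Engine change above replaces the single self.planner attribute with a per-mode self._planners dict and splits the prepare loop so that all modes are planned before any of them is parallelized; the second loop can then look up each mode's planner (and its completer). A minimal, self-contained sketch of that bookkeeping pattern follows; Planner and Parallelizer here are simplified stand-ins for illustration, not the actual paddle.distributed.auto_parallel classes.

# Minimal sketch of the per-mode planner bookkeeping shown above.
# Planner and Parallelizer are simplified stand-ins, not the real classes.

class Planner:
    def __init__(self, mode, dist_context):
        self.mode = mode
        self.completer = "completer-for-" + mode   # placeholder attribute

    def plan(self):
        print("planning", self.mode)


class Parallelizer:
    def __init__(self, mode, completer, dist_context):
        self.mode = mode
        self.completer = completer

    def parallel(self, rank):
        print("parallelizing", self.mode, "with", self.completer, "on rank", rank)


class Engine:
    def __init__(self):
        self._modes = ['train', 'eval', 'predict']
        self._dist_contexts = {m: object() for m in self._modes}
        self._planners = {}   # one planner per mode, never overwritten

    def prepare(self, rank=0):
        # Plan every mode first; the dict keeps each mode's planner
        # available for the separate parallelization loop below.
        for mode in self._modes:
            self._planners[mode] = Planner(mode, self._dist_contexts[mode])
            self._planners[mode].plan()
        for mode in self._modes:
            parallelizer = Parallelizer(mode, self._planners[mode].completer,
                                        self._dist_contexts[mode])
            parallelizer.parallel(rank)


Engine().prepare()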
@@ -18,7 +18,8 @@ from ..dist_attribute import OperatorDistributedAttribute
 _g_distributed_operator_impl_containers = {}

 _g_elementwise_ops = [
-    "elementwise", "gelu", "dropout", "cast", "gather", "concat"
+    "elementwise", "gelu", "dropout", "cast", "gather", "concat",
+    "fused_softmax_mask_upper_triangle"
 ]

 BACKWARD_ONLY_DIST_OPS = {'check_finite_and_unscale', 'update_loss_scaling'}
...
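Adding "fused_softmax_mask_upper_triangle" to _g_elementwise_ops lets the fused GPT-3 attention-mask kernel be handled like the other elementwise-style ops when distributed attributes are derived. The sketch below only illustrates how such a registry might be consulted; is_elementwise_like is a hypothetical helper written for this example, not the library's API.

# Illustrative only: a name-based lookup of the kind the list above could feed.

_g_elementwise_ops = [
    "elementwise", "gelu", "dropout", "cast", "gather", "concat",
    "fused_softmax_mask_upper_triangle"
]


def is_elementwise_like(op_type):
    # Treat an op as elementwise-like if any registered name occurs in its
    # type string, e.g. "elementwise" matches "elementwise_add".
    return any(name in op_type for name in _g_elementwise_ops)


assert is_elementwise_like("elementwise_add")
assert is_elementwise_like("fused_softmax_mask_upper_triangle")
assert not is_elementwise_like("matmul_v2")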
@@ -80,9 +80,9 @@ class Parallelizer:
                 rank, dist_params_grads)
         else:
             # Apply pre optimization passes
-            self._apply_pre_optimization(serial_main_program,
-                                         serial_startup_program, None, None,
-                                         None)
+            # self._apply_pre_optimization(serial_main_program,
+            #                              serial_startup_program, None, None,
+            #                              None)
             # Do logical partition
             partitioner = Partitioner(self._dist_context, rank)
             dist_main_prog, dist_startup_prog, dist_params_grads = partitioner.partition(
@@ -121,7 +121,9 @@ class Parallelizer:
         if self._strategy is None:
             return
         # apply amp pass
-        if self._strategy.amp:
+        # FIXME we disenable amp for eval since it has a little bug with
+        # eval program and which will be fixed in future
+        if self._mode == 'train' and self._strategy.amp:
             config = copy.deepcopy(self._strategy.amp_configs)
             config["dist_context"] = self._dist_context
             config["params_grads"] = params_grads
@@ -139,7 +141,8 @@ class Parallelizer:
                                              self._pass_context)

         # apply recompute pass
-        if self._strategy.recompute:
+        # recompute is then train-only optimization
+        if self._mode == "train" and self._strategy.recompute:
             config = copy.deepcopy(self._strategy.recompute_configs)
             config["dist_context"] = self._dist_context
             config["no_grad_set"] = None
@@ -164,7 +167,8 @@ class Parallelizer:
             auto_parallel_sharding_pass.apply([main_program], [startup_program],
                                               self._pass_context)

-        if self._strategy.gradient_merge:
+        # recompute is then train-only optimization
+        if self._mode == "train" and self._strategy.gradient_merge:
             config = copy.deepcopy(self._strategy.gradient_merge_configs)
             config["dist_context"] = self._dist_context
             config["params_grads"] = params_grads
...
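The Parallelizer hunks share one pattern: passes that rewrite the backward and optimizer portion of the program (AMP, recompute, gradient merge) are now applied only when self._mode == "train", so the eval and predict programs are built without them. A simplified sketch of that gating, with a hypothetical Strategy object standing in for the real strategy config:

# Simplified sketch of gating train-only passes on the execution mode.
# Strategy and apply_optimization_passes are stand-ins for illustration.

from dataclasses import dataclass


@dataclass
class Strategy:
    amp: bool = True
    recompute: bool = True
    gradient_merge: bool = True


def apply_optimization_passes(mode, strategy):
    applied = []
    # These passes rewrite the backward/update part of the program,
    # so they only make sense when building the train program.
    if mode == "train" and strategy.amp:
        applied.append("amp")
    if mode == "train" and strategy.recompute:
        applied.append("recompute")
    if mode == "train" and strategy.gradient_merge:
        applied.append("gradient_merge")
    return applied


print(apply_optimization_passes("train", Strategy()))  # ['amp', 'recompute', 'gradient_merge']
print(apply_optimization_passes("eval", Strategy()))   # []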
@@ -1057,13 +1057,15 @@ def set_grad_var_shape(program, dist_context):
                 "transpose2_grad", "softmax_grad", "cross_entropy_grad2",
                 "dropout_grad", "tanh_grad", "slice", "assign",
                 "matmul_v2_triple_grad", "elementwise_add_triple_grad",
-                "fill_constant", "sqrt_grad"
+                "fill_constant", "sqrt_grad",
+                "fused_softmax_mask_upper_triangle_grad"
             ]
             forward_list = [
                 "reshape2", "softmax_with_cross_entropy", "transpose2",
                 "softmax", "cross_entropy2", "dropout", "tanh",
                 ["slice_grad", "c_allgather"], "assign", "matmul_v2_grad_grad",
-                "elementwise_add_grad_grad", "shape", "sqrt"
+                "elementwise_add_grad_grad", "shape", "sqrt",
+                "fused_softmax_mask_upper_triangle_grad"
             ]
             if op.type in need_set_shape_list:
                 for forward_op in block.ops:
...
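need_set_shape_list and forward_list act as parallel lists here: when a gradient op's type appears in the first list, the entry at the same index in the second list names the forward op whose output shape the gradient variable should take. A much-reduced sketch of that paired lookup (illustrative only; the real set_grad_var_shape walks program blocks and distributed attributes):

# Simplified illustration of the paired-list lookup used above; not the
# real set_grad_var_shape, which works on Program blocks and dist attrs.

need_set_shape_list = ["dropout_grad", "softmax_grad", "sqrt_grad"]
forward_list = ["dropout", "softmax", "sqrt"]


def forward_type_for(grad_op_type):
    # Map a grad op type to the forward op type at the same index, so the
    # grad var's shape can later be copied from the forward op's output.
    idx = need_set_shape_list.index(grad_op_type)
    return forward_list[idx]


assert forward_type_for("softmax_grad") == "softmax"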
@@ -143,8 +143,8 @@ class AMPState(object):
         """
         num_cast_ops = 0

-        for in_name in op.input_names:
-            var_name_dict = {}
+        var_name_dict = {}
+        for in_name in op.input_names:
             if src_dtype == core.VarDesc.VarType.FP32 and _keep_fp32_input(
                     op, in_name):
                 continue
...
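The AMP hunk hoists var_name_dict = {} out of the for in_name in op.input_names: loop, so the cast-variable renames collected for one input slot are no longer discarded when the next input slot is processed. A small standalone sketch of why the old scoping loses entries (names are made up for illustration):

# Illustration of the scoping fix: resetting the dict inside the loop keeps
# only the last input's renames; initializing it once keeps all of them.

inputs = {"X": ["x_fp32"], "Y": ["y_fp32"]}


def collect_renames_buggy():
    for in_name, var_names in inputs.items():
        var_name_dict = {}                      # reset every iteration (old code)
        for v in var_names:
            var_name_dict[v] = v + ".cast_fp16"
    return var_name_dict                        # only the last input survives


def collect_renames_fixed():
    var_name_dict = {}                          # initialized once (new code)
    for in_name, var_names in inputs.items():
        for v in var_names:
            var_name_dict[v] = v + ".cast_fp16"
    return var_name_dict


print(collect_renames_buggy())  # {'y_fp32': 'y_fp32.cast_fp16'}
print(collect_renames_fixed())  # {'x_fp32': 'x_fp32.cast_fp16', 'y_fp32': 'y_fp32.cast_fp16'}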