Unverified commit 7aeec4ed, authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Data Parallel Optimization Pass 1 (#44882)

* bugfix

* remove scaling

* support rescale_grad opt
Parent cf17ae8a
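For context, this commit registers a new pass named "auto_parallel_data_parallel_optimization" and applies it from the Parallelizer (see the hunks below). A minimal sketch of invoking the pass manually, mirroring the unit test further down; here `dist_main_prog`, `dist_startup_prog`, and `rank` are placeholders for the distributed programs and the current global rank:

from paddle.distributed.passes import new_pass, PassContext
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context

# the pass requires a dist_context and a non-negative global rank (see _check_self)
config = {
    "dist_context": get_default_distributed_context(),
    "global_rank": rank,
}
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())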
......@@ -435,3 +435,13 @@ def gradient_synchronization(dist_ctx, op, act_grad_names, out_grad_names,
return
sync_and_scale_gradients(dist_ctx, op, dp_group, out_grad_names)
def is_data_parallel_scale_op(op):
return op.type == "scale" and op.desc.has_attr("op_namescope") \
and ParallelMode.DataParallel in op.desc.attr("op_namescope")
def is_data_parallel_reduce_op(op):
return op.type in ["c_reduce_sum", "c_allreduce_sum"] and op.desc.has_attr("op_namescope") \
and ParallelMode.DataParallel in op.desc.attr("op_namescope")
......@@ -195,6 +195,14 @@ class Parallelizer:
params_grads):
if self._strategy is None:
return
# data parallel optimization
config = {}
config["dist_context"] = self._dist_context
config["global_rank"] = rank
dp_pass = new_pass("auto_parallel_data_parallel_optimization", config)
dp_pass.apply([main_program], [startup_program], self._pass_context)
if self._strategy.sharding:
config = copy.deepcopy(self._strategy.sharding_configs)
config["dist_context"] = self._dist_context
......
......@@ -160,21 +160,24 @@ class ProcessGroup:
def is_member(self):
return True
# def __eq__(self, other):
# if not isinstance(other, ProcessGroup):
# return False
# if self.id != other.id:
# return False
# return True
def __eq__(self, other):
if not isinstance(other, ProcessGroup):
return False
if self.id != other.id:
return False
return True
# def __ne__(self, other):
# return not self.__eq__(other)
def __ne__(self, other):
return not self.__eq__(other)
def __str__(self):
string = "id: {}, nranks: {}, ranks: {}.".format(
self.id, self.nranks, ", ".join(map(str, self.ranks)))
return string
def __hash__(self):
return hash(self.__str__())
# Note that Process group 0 is reserved for representing all ranks.
# At the beginning, group 0 is empty and new ranks will be added automatically.
......
......@@ -266,7 +266,7 @@ class OptimizationTuner:
config["input_data"] = self._baseline_dist_context.serial_feed_vars["inputs"] \
+ self._baseline_dist_context.serial_feed_vars["labels"]
if config["use_pure_fp16"]:
config["base_opt"] = dist_context.optimizer
config["base_opt"] = dist_context.serial_optimizer
auto_parallel_fp16_pass = new_pass("auto_parallel_fp16", config)
auto_parallel_fp16_pass.apply([main_program], [startup_program],
pass_context)
......@@ -363,11 +363,11 @@ class OptimizationTuner:
profile_args = " ".join([
"--rank",
str(self.rank),
"--device_id",
str(self.device_id),
"--ctx_filename",
ctx_path,
str(self.rank), "--device_id",
str(self.device_id), "--ctx_filename", ctx_path,
"--profile_start_step",
str(self._config.profile_start_step), "--profile_end_step",
str(self._config.profile_end_step)
])
cmd_args = "-m paddle.distributed.auto_parallel.tuner.profiler" + " " + profile_args
cmd = [sys.executable, "-u"] + coverage_args + shlex.split(cmd_args)
......
......@@ -23,6 +23,7 @@ from functools import reduce
import paddle.fluid.core as core
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.process_group import get_all_process_groups
from paddle.fluid.io import is_parameter, is_belong_to_optimizer
from paddle.distributed.auto_parallel.dist_attribute import TensorDistributedAttribute, OperatorDistributedAttribute
......@@ -1123,6 +1124,13 @@ def is_loss_op(op):
int(op.all_attrs()[OP_ROLE_KEY]) == (int(OpRole.Forward) | int(OpRole.Loss))
def is_loss_grad_op(op):
if OP_ROLE_KEY not in op.attr_names:
return False
op_role = int(op.all_attrs()[OP_ROLE_KEY])
return op_role & int(OpRole.Backward) and op_role & int(OpRole.Loss)
def is_prim_op(op):
return op.type.endswith("_p")
......@@ -1481,3 +1489,10 @@ def debug_program(program, path, name):
path, name + '_program' + ".%d" % (paddle.distributed.get_rank()))
with open(filename, 'w') as f:
f.write(str(program))
def ring_id_to_process_group(ring_id):
for g in get_all_process_groups():
if g.id == ring_id:
return g
return None
......@@ -19,6 +19,7 @@ from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .auto_parallel_data_parallel_optimization import *
from .cpp_pass import *
import os
from .ps_trainer_pass import *
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import OrderedDict
import paddle
from paddle.fluid.framework import default_main_program
from paddle.distributed.auto_parallel.operators.common import is_data_parallel_scale_op, is_data_parallel_reduce_op
from paddle.distributed.auto_parallel.utils import is_loss_grad_op, is_optimize_op, ring_id_to_process_group
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
'lars_momentum', 'sparse_momentum', 'dgc_momentum', 'momentum',
'merge_momentum'
]
@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
"""
Apply optimizations specialized for data parallelism in Auto Parallel:
1. prune gradient scaling
2. overlap communication and computation
3. fuse allreduce
"""
def __init__(self):
super(DataParallelOptimizationPass, self).__init__()
# NOTE: this pass deliberately does not depend on loss and param_grads
self.set_attr("dist_context", None)
self.set_attr("global_rank", -1)
# e.g. {grad1: group1, grad2: group1, grad3: group2}
# kept ordered so the gradient order is preserved when fusing gradient memory
self._grad_name_to_group_map = OrderedDict()
# e.g. {group1: [grad1, grad2], group2: [grad3]}
self._group_to_grad_name_map = OrderedDict()
self._support_rescale_grad = False
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if (not isinstance(self.get_attr("global_rank"),
int)) or self.get_attr("global_rank") < 0:
return False
return True
def _check_conflict(self, other_pass):
return True
def _type(self):
return PassType.COMM_OPT
def _apply_single_impl(self, main_program, startup_program, context):
self.dist_context = self.get_attr("dist_context")
self.global_rank = int(self.get_attr("global_rank"))
with paddle.static.program_guard(main_program, startup_program):
self._analyze_program()
self._prune_grad_scaling()
self._overlap_comm()
self._fuse_allreduce()
def _prune_grad_scaling(self):
if not self._could_be_prune():
return
if self._all_dp_groups_same_degree():
self._scale_backward_initial_grad()
else:
self._update_opt_rescale_grad()
self._remove_grad_scaling()
def _overlap_comm(self):
pass
def _fuse_allreduce(self):
pass
def _analyze_program(self):
"""
build {param_grad_name: data_parallel_group} and
{data_parallel_group: [param_grad_name, ...]} mappings
"""
block = default_main_program().global_block()
ops = block.ops
scaled_grads = []
for op in ops:
if is_data_parallel_reduce_op(op):
grad_name = op.output_arg_names[0]
if grad_name in self._grad_name_to_group_map:
continue
assert op.has_attr(
"ring_id"
), "Unexpected: comm op [{}] has NO ring_id attribute.".format(str(op))
group = ring_id_to_process_group(op.attr("ring_id"))
assert group is not None, "Unexpected: data parallel group of [{}] from op [{}] is None".format(
grad_name, str(op))
self._grad_name_to_group_map[grad_name] = group
if group not in self._group_to_grad_name_map:
self._group_to_grad_name_map[group] = [grad_name]
else:
self._group_to_grad_name_map[group].append(grad_name)
elif is_data_parallel_scale_op(op):
grad_name = op.output_arg_names[0]
scaled_grads.append(grad_name)
# TODO support multiple optimizers in one network in the future.
# here we assume that the optimizer is unique in the network.
elif is_optimize_op(
op) and op.type in __rescale_grad_supported_opts__:
self._support_rescale_grad = True
not_synchronized_grads = []
for grad_name in scaled_grads:
if grad_name not in self._grad_name_to_group_map:
not_synchronized_grads.append(grad_name)
assert len(
not_synchronized_grads
) == 0, "Unexpected: gradients [{}] are scaled BUT NOT synchronized.".format(
not_synchronized_grads)
def _could_be_prune(self):
return self._support_rescale_grad or self._all_dp_groups_same_degree()
def _all_dp_groups_same_degree(self):
return len(
set([
len(group.ranks)
for group in self._group_to_grad_name_map.keys()
])) == 1
def _scale_backward_initial_grad(self):
block = default_main_program().global_block()
dp_degree = len(list(self._group_to_grad_name_map.keys())[0].ranks)
for idx, op in reversed(list(enumerate(block.ops))):
if is_loss_grad_op(op):
assert op.type == 'fill_constant', \
"loss_grad_op must be fill_constant op, " \
"but this op is {}".format(op.type)
assert op.has_attr('value')
loss_scale = float(op.attr('value'))
loss_scale = loss_scale / dp_degree
op._set_attr('value', loss_scale)
break
def _remove_grad_scaling(self):
block = default_main_program().global_block()
for op_idx, op in reversed(list(enumerate(block.ops))):
if is_data_parallel_scale_op(op):
block._remove_op(op_idx, False)
block._sync_with_cpp()
def _update_opt_rescale_grad(self):
block = default_main_program().global_block()
scaled_grads = set()
for idx, op in reversed(list(enumerate(block.ops))):
if is_optimize_op(
op) and op.type in __rescale_grad_supported_opts__:
assert op.has_attr(
'rescale_grad'
), "Unexpected: op [{}] is expected to have the [rescale_grad] attribute.".format(
str(op))
assert len(
op.input("Grad")
) == 1, "Unexpected: op [{}] is expected to have exactly one input grad var.".format(
str(op))
grad_name = op.input("Grad")[0]
dp_degree = len(
list(self._grad_name_to_group_map[grad_name].ranks))
scaled_grads.add(grad_name)
rescale_grad = float(op.attr('rescale_grad')) / dp_degree
op._set_attr('rescale_grad', rescale_grad)
assert scaled_grads == set(self._grad_name_to_group_map.keys(
)), "Unexpected: gradients [{}] are not rescaled.".format(
set(self._grad_name_to_group_map.keys()) - scaled_grads)
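To make the pruning above concrete, here is a small illustrative NumPy check (not part of this commit) of the identity the pass relies on: dividing every allreduced gradient by dp_degree, which is what the removed scale ops did, is equivalent either to scaling the fill_constant loss gradient by 1/dp_degree (_scale_backward_initial_grad, valid because backprop is linear in the initial loss gradient) or to folding 1/dp_degree into the optimizer's rescale_grad (_update_opt_rescale_grad):

import numpy as np

dp_degree = 2
allreduced_grads = [np.array([2.0, 4.0]), np.array([6.0, 8.0])]

# what the removed per-gradient scale ops computed: grad / dp_degree
reference = [g / dp_degree for g in allreduced_grads]

# equivalent: push the 1/dp_degree factor into the initial loss gradient,
# so every parameter gradient comes out of backprop already divided
via_loss_scale = [g * (1.0 / dp_degree) for g in allreduced_grads]

# equivalent: fold the factor into the momentum-family optimizer's
# rescale_grad, which multiplies the incoming gradient
rescale_grad = 1.0 / dp_degree
via_rescale_grad = [g * rescale_grad for g in allreduced_grads]

for a, b in zip(reference, via_loss_scale):
    np.testing.assert_allclose(a, b)
for a, b in zip(reference, via_rescale_grad):
    np.testing.assert_allclose(a, b)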
......@@ -20,6 +20,8 @@ if((NOT WITH_GPU)
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_sharding_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_fp16_pass")
list(REMOVE_ITEM TEST_OPS "test_auto_parallel_gradient_merge_pass")
list(REMOVE_ITEM TEST_OPS
"test_auto_parallel_data_parallel_optimization_pass")
endif()
foreach(TEST_OP ${TEST_OPS})
......
......@@ -108,7 +108,7 @@ class AutoPallelPassTestBase(DistPassTestBase):
pickle.dump(all_fetch_values, f)
def get_gpt_model(self, strategy, place, batch_size, sequence_len,
vocab_size):
vocab_size, **kwargs):
modeling.init_global()
if strategy == "dp":
modeling._global_parallel_strategy = "dp"
......@@ -179,11 +179,17 @@ class AutoPallelPassTestBase(DistPassTestBase):
criterion = GPTPretrainingCriterion()
loss = criterion(preds, labels, loss_mask)
clip = paddle.nn.ClipGradByNorm(clip_norm=1.0)
optimizer = paddle.fluid.optimizer.AdamOptimizer(learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=clip)
if kwargs.get('optimizer', None) == "LarsMomentum":
optimizer = paddle.fluid.optimizer.LarsMomentumOptimizer(
learning_rate=0.001, momentum=0.9)
else:
optimizer = paddle.fluid.optimizer.AdamOptimizer(
learning_rate=0.00001,
beta1=0.9,
beta2=0.999,
epsilon=1e-08,
grad_clip=clip)
optimizer = fleet.distributed_optimizer(optimizer)
startup_program = paddle.static.default_startup_program()
_, _, dist_startup_prog, dist_main_prog = optimizer.minimize(
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import random
import numpy as np
import unittest
import paddle
import paddle.nn as nn
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.dist_context import get_default_distributed_context
from paddle.distributed.passes import new_pass, PassManager, PassContext
from auto_parallel_pass_test_base import AutoPallelPassTestBase
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
class TestDataParallelPassWithScale1(AutoPallelPassTestBase):
def init(self):
if paddle.is_compiled_with_cuda():
paddle.set_flags({'FLAGS_cudnn_deterministic': 1})
self.rtol = 1e-5
self.atol = 1e-8
# NOTE: a hack to check whether the pass is applied, since there is
# no switch for this pass in dist_strategy
self._apply_pass = False
rank = paddle.distributed.get_rank()
paddle.seed(rank + 2021)
random.seed(rank + 2021)
np.random.seed(rank + 2021)
def apply_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
self._apply_pass = True
def apply_no_passes(self):
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
fleet.init(is_collective=True, strategy=dist_strategy)
self._apply_pass = False
def test_bs_8(self):
self.check_main(gpus=[0, 1],
batch_size=8,
sequence_len=512,
vocab_size=1000)
# test scaling via the fill_constant loss grad op
def get_model(self, place, batch_size, sequence_len, vocab_size):
dist_main_prog, dist_startup_prog, data_holder, [
loss
], gen_data = self.get_gpt_model('dp', place, batch_size, sequence_len,
vocab_size)
if self._apply_pass:
config = {}
config["dist_context"] = get_default_distributed_context()
config["global_rank"] = paddle.distributed.get_rank()
dp_pass = new_pass("auto_parallel_data_parallel_optimization",
config)
dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())
return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data
class TestDataParallelPassWithScale2(TestDataParallelPassWithScale1):
# test scaling with optimizer rescale_grad
def get_model(self, place, batch_size, sequence_len, vocab_size):
dist_main_prog, dist_startup_prog, data_holder, [
loss
], gen_data = self.get_gpt_model('dp',
place,
batch_size,
sequence_len,
vocab_size,
optimizer='LarsMomentum')
if self._apply_pass:
config = {}
config["dist_context"] = get_default_distributed_context()
config["global_rank"] = paddle.distributed.get_rank()
dp_pass = new_pass("auto_parallel_data_parallel_optimization",
config)
dp_pass.apply([dist_main_prog], [dist_startup_prog], PassContext())
return dist_main_prog, dist_startup_prog, data_holder, [loss], gen_data
if __name__ == "__main__":
unittest.main()