Unverified commit e5eb3f55 authored by JZ-LIANG, committed by GitHub

[Auto Parallel] Sharding Optimization: Partition Algorithm & Stage2 Parameter Bucket Communication (#47180)

* partition param by order

* add logging

* reorder opt

* config

* stage2 bucket

* update unittest
Parent 6934ae2b
......@@ -82,7 +82,9 @@ SHARDING = "sharding"
set_field_default_config(SHARDING, "enable", False)
set_field_default_config(SHARDING, "stage", 1)
set_field_default_config(SHARDING, "degree", 8)
set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
set_field_default_config(SHARDING, "overlap_grad_comm", False)
set_field_default_config(SHARDING, "bucket_size_numel", -1)
set_field_default_config(SHARDING, "partition_algor", "greedy_even")
set_field_default_config(SHARDING, "enable_tuning", False)
set_field_default_config(SHARDING, "tuning_range", [])
......
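The three new fields above (overlap_grad_comm, bucket_size_numel, partition_algor) are plain strategy defaults. A minimal sketch of how they might be set through the auto-parallel Strategy API follows; the import path is the usual fleet.auto entry point and the values are illustrative, not taken from this PR:

# illustrative only: field names mirror the defaults registered above
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
strategy.sharding.enable = True
strategy.sharding.stage = 2                      # shard optimizer states + gradients
strategy.sharding.degree = 8                     # number of sharding ranks
strategy.sharding.partition_algor = "use_order"  # or the default "greedy_even"
strategy.sharding.bucket_size_numel = 50 * 1024 * 1024  # fuse params into ~50M-element buckets
strategy.sharding.overlap_grad_comm = True       # gradient-comm overlap (placeholder in this commit)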
......@@ -22,6 +22,7 @@ import logging
from functools import reduce
import paddle.fluid.core as core
from paddle.fluid.framework import Variable
from paddle.distributed.fleet.meta_optimizers.common import OpRole
from paddle.distributed.auto_parallel.process_group import (
get_all_process_groups,
......@@ -1790,6 +1791,18 @@ def find_higher_order_backward_op(program):
return False
def get_var_numel(var):
"""
input:
- var: variable
return:
number of elements in var
"""
assert isinstance(var, Variable)
assert -1 not in var.shape
return reduce(lambda x, y: x * y, var.shape)
def get_lr(optimizer):
if isinstance(optimizer, paddle.optimizer.Optimizer):
return optimizer.get_lr()
......
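As a quick sanity check of the helper above, a hedged usage sketch; the variable name and shape are made up, and get_var_numel only accepts statically shaped variables, hence the `-1 not in var.shape` assert:

# illustrative only: requires a fully static shape (no -1 dimensions)
import paddle
from paddle.distributed.auto_parallel.utils import get_var_numel

paddle.enable_static()
block = paddle.static.default_main_program().global_block()
w = block.create_var(name="w_example", shape=[1024, 4096], dtype="float32")
assert get_var_numel(w) == 1024 * 4096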
......@@ -13,12 +13,12 @@
# limitations under the License.
from collections import OrderedDict
import numpy as np
import paddle
from paddle.fluid import unique_name
from paddle.fluid.framework import default_main_program
from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
from .pass_base import PassBase, PassType, register_pass
from paddle.distributed.auto_parallel.operators.common import (
is_data_parallel_scale_op,
is_data_parallel_reduce_op,
......@@ -28,8 +28,8 @@ from paddle.distributed.auto_parallel.utils import (
is_loss_grad_op,
is_optimize_op,
ring_id_to_process_group,
get_var_numel,
)
from .pass_base import PassBase, PassType, register_pass
# add new optimizers supporting rescale_grad here
__rescale_grad_supported_opts__ = [
......@@ -44,10 +44,6 @@ __rescale_grad_supported_opts__ = [
__max_stream_num_allow__ = 16
def numel(var):
return np.prod(list(var.shape))
@register_pass("auto_parallel_data_parallel_optimization")
class DataParallelOptimizationPass(PassBase):
"""
......@@ -430,7 +426,7 @@ class DataParallelOptimizationPass(PassBase):
ring_id = op.attr("ring_id")
grad_name = op.output_arg_names[0]
grad_var = block.var(grad_name)
grad_numel = numel(grad_var)
grad_numel = get_var_numel(grad_var)
if cur_group.acceptable(grad_var, ring_id):
assert grad_name not in grouped_grad_names
......@@ -594,7 +590,7 @@ class GradientsGroup:
return True
if ring_id != self.ring_id:
return False
if numel(grad_var) + self.numel > self.max_group_size:
if get_var_numel(grad_var) + self.numel > self.max_group_size:
return False
if grad_var.dtype != self.dtype:
return False
......@@ -605,7 +601,7 @@ class GradientsGroup:
self.gradients.append(grad_var)
self.ring_id = ring_id
self.dtype = grad_var.dtype
self.numel += numel(grad_var)
self.numel += get_var_numel(grad_var)
# remove auxiliary ops in non-fuse dp allreduce
self.remove_allreduce_op_indices.append(i)
......
......@@ -13,15 +13,20 @@
# limitations under the License.
from functools import reduce
import logging
import paddle
from paddle.framework import core
from paddle.fluid.framework import default_main_program, default_startup_program
from paddle.fluid import unique_name
from .pass_base import PassBase, register_pass
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size
from paddle.distributed.fleet.meta_optimizers.common import (
is_backward_op,
is_optimizer_op,
)
from paddle.distributed.auto_parallel.process_group import new_process_group
from paddle.distributed.auto_parallel.operators.common import (
is_parameter_related,
is_data_parallel_reduce_op,
......@@ -30,6 +35,8 @@ from paddle.distributed.auto_parallel.utils import (
_get_comm_group,
naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
set_var_dist_attr,
get_var_numel,
get_logger,
)
OpRole = core.op_proto_and_checker_maker.OpRole
......@@ -57,6 +64,8 @@ _supported_optimizer_type = [
"sgd",
]
_logger = get_logger(logging.INFO)
def _is_reshard_op(op):
return op.desc.has_attr(
......@@ -76,6 +85,9 @@ class ShardingPass(PassBase):
self.set_attr("stage", None)
self.set_attr("sharding_degree", None) # for parallelizer
self.set_attr("degree", None) # for parallelizer_v2
self.set_attr("overlap_grad_comm", None)
self.set_attr("bucket_size_numel", None)
self.set_attr("partition_algor", None)
self.set_attr("params_grads", [])
self.set_attr("global_rank", -1)
self.dp_groups = set()
......@@ -109,6 +121,12 @@ class ShardingPass(PassBase):
"global_rank"
) < 0:
return False
if self.get_attr("overlap_grad_comm") is None:
return False
if self.get_attr("bucket_size_numel") is None:
return False
if self.get_attr("partition_algor") is None:
return False
return True
......@@ -122,22 +140,35 @@ class ShardingPass(PassBase):
)
self.stage = int(self.get_attr("stage"))
self.global_rank = int(self.get_attr("global_rank"))
self.overlap_grad_comm = self.get_attr("overlap_grad_comm")
self.bucket_size_numel = int(self.get_attr("bucket_size_numel"))
self.partition_algor = self.get_attr("partition_algor")
params_grads = self.get_attr("params_grads")
main_block, startup_block = (
main_program.global_block(),
startup_program.global_block(),
)
# NOTE Multi / Sub-Block Support
# we assume that only parameters are present and partitioned in main_block,
# there is NO new param in sub_block, and all params in sub_block follow the same
# partition as main_block. the above constraints fulfill the 3 most common sub_block use cases in Paddle:
# 1. sub-block for lr scheduler
# 2. sub-block uses the same or partial network of main-block, e.g. GPT3 generation model
# 3. sub-block used for double backward
self._build_sharding_groups(main_block, params_grads)
self._shard_optimizer(main_block, startup_block, params_grads, context)
self._shard_gradient_synchronization(main_block)
self._shard_parameter(main_block, startup_block)
for block in main_program.blocks:
self._shard_optimizer(block, startup_block, params_grads, context)
self._shard_gradient_synchronization(block)
self._shard_parameter(block, startup_block)
context.set_attr("params_grads", self.shared_params_grads)
self._optimization_pass(main_program, startup_program)
def _build_sharding_groups(self, main_block, params_grads):
self._collective_data_parallel_groups(main_block)
self._build_sharding_infos(params_grads)
self._build_sharding_infos(main_block, params_grads)
def _collective_data_parallel_groups(self, main_block):
for op in main_block.ops:
......@@ -162,8 +193,14 @@ class ShardingPass(PassBase):
)
)
def _build_sharding_infos(self, params_grads):
def _build_sharding_infos(self, main_block, params_grads):
# order params
params_grads = re_order_program(
main_block, params_grads, self._dist_context
)
# partition
for dp_group in self.dp_groups:
assert (
......@@ -204,7 +241,10 @@ class ShardingPass(PassBase):
self._dist_context._sharding_group = sharding_group
# TODO(JZ-LIANG) when multiple dp groups are supported in the future, params should be grouped and bound to their corresponding dp group
sharding_info = ShardingInfo(
sharding_group, self.global_rank, params_grads
sharding_group,
self.global_rank,
params_grads,
self.partition_algor,
)
self.sharding_infos.append(sharding_info)
for param in sharding_info.params:
......@@ -317,7 +357,7 @@ class ShardingPass(PassBase):
reserved_vars.append(input_name)
op.desc.set_input("X", reserved_vars)
sum_op_output = op.desc.output_arg_names()[0]
sum_op_output = op.output_arg_names[0]
for i, sharding_info in enumerate(self.sharding_infos):
new_op = main_block._insert_op(
idx + i + 1,
......@@ -401,7 +441,7 @@ class ShardingPass(PassBase):
def _insert_optimizer_broadcasts(self, main_block, startup_block):
if self.stage > 2:
if self.stage > 2 or self.bucket_size_numel > 1:
return
for sharding_info in self.sharding_infos:
......@@ -508,7 +548,7 @@ class ShardingPass(PassBase):
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
for input_name in op.input_arg_names:
# NOTE hack for embedding op when AMP O2-3
# paddle amp forces embedding (lookup table) to be run in fp32
if _is_param_fp16_cast_op(
......@@ -601,6 +641,24 @@ class ShardingPass(PassBase):
main_block._sync_with_cpp()
startup_block._sync_with_cpp()
def _optimization_pass(self, main_program, startup_program):
with paddle.static.program_guard(main_program, startup_program):
if self.overlap_grad_comm:
_fuse_overlap_gradient_comm()
# TODO support multiple sub_blocks
if self.bucket_size_numel > 1:
if self.stage == 2:
_fuse_overlap_parameter_comm_stage_two(
self.sharding_infos,
self._dist_context,
fuse_size=self.bucket_size_numel,
)
elif self.stage == 3:
_fuse_overlap_parameter_comm_stage_three(
self.sharding_infos, fuse_size=self.bucket_size_numel
)
def _insert_init_and_broadcast_op(
block,
......@@ -723,7 +781,7 @@ def _is_param_grad_fp32_cast_op(block, op):
block, op, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32
):
return False
output_name = op.desc.output_arg_names()[0]
output_name = op.output_arg_names[0]
base_name = output_name[: output_name.find("@")]
if not block.has_var(base_name):
return False
......@@ -736,7 +794,7 @@ def _is_param_fp16_cast_op(block, op, params):
return False
if not _is_desired_cast_op(block, op):
return False
input_name = op.desc.input_arg_names()[0]
input_name = op.input_arg_names[0]
if input_name not in params:
return False
return True
......@@ -750,10 +808,10 @@ def _is_desired_cast_op(
):
if op.type != "cast":
return False
assert len(op.desc.input_arg_names()) == 1
assert len(op.desc.output_arg_names()) == 1
input_var = block.var(op.desc.input_arg_names()[0])
output_var = block.var(op.desc.output_arg_names()[0])
assert len(op.input_arg_names) == 1
assert len(op.output_arg_names) == 1
input_var = block.var(op.input_arg_names[0])
output_var = block.var(op.output_arg_names[0])
if input_var.dtype != src_var_type or output_var.dtype != dst_var_type:
return False
......@@ -828,10 +886,36 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
return dp_group
def shard_parameters(params, group_size):
# TODO(JZ-LIANG) support multiple partition methods
# method1: greedy even but unordered
# method2: roughly even with order preserved
def partition_by_use_order(params, group_size):
"""
shard contiguous params onto the same rank and divide the forward & backward computation into segments,
which will favor the later fuse passes.
we assume that the params are already sorted by utilization order.
"""
mapping = {}
total_param_mem = 0.0
param2mem = []
for param in params:
mem = get_var_size(param)
total_param_mem += mem
param2mem.append((param, mem))
mapping = {x: [] for x in range(group_size)}
cur_rank = 0
mem_accu = 0.0
for param, mem in param2mem:
if mem_accu > total_param_mem * 1.0 * (cur_rank + 1) / group_size:
cur_rank += 1
mapping[cur_rank].append(param)
mem_accu += mem
return mapping
def partition_by_greedy_even(params, group_size):
"""
use a greedy algorithm to partition parameters as evenly as possible.
"""
mapping = {}
for rank_ in range(group_size):
mapping[rank_] = []
......@@ -850,8 +934,212 @@ def shard_parameters(params, group_size):
return mapping
class ShardingInfo:
def __init__(self, group, rank, params_grads):
def partition_parameters(params, group_size, algor="greedy_even"):
if algor == "greedy_even":
rank_to_params = partition_by_greedy_even(params, group_size)
else:
rank_to_params = partition_by_use_order(params, group_size)
_logger.info("Sharding Parameter Partition:")
for k, v in rank_to_params.items():
_logger.info(
"Rank:{}, Parameter Size:{} MB.".format(
k, sum([get_var_size(var) for var in v])
)
)
_logger.info("Params in this rank: {}.".format([var.name for var in v]))
return rank_to_params
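To make the two partition strategies concrete, here is a self-contained toy version working on plain (name, numel) tuples instead of Paddle variables. The body of partition_by_greedy_even is elided by the diff above, so the sketch assumes the common largest-first / lightest-rank greedy variant and may differ from the real implementation in detail:

# simplified sketch, not the actual Paddle implementation
def toy_greedy_even(params, group_size):
    # assign each param (largest first) to the currently lightest rank
    mapping = {r: [] for r in range(group_size)}
    sizes = [0] * group_size
    for name, numel in sorted(params, key=lambda p: p[1], reverse=True):
        rank = sizes.index(min(sizes))
        mapping[rank].append(name)
        sizes[rank] += numel
    return mapping

def toy_use_order(params, group_size):
    # keep usage order, cut into group_size roughly even contiguous chunks
    total = sum(numel for _, numel in params)
    mapping = {r: [] for r in range(group_size)}
    rank, accu = 0, 0
    for name, numel in params:
        if accu > total * (rank + 1) / group_size:
            rank += 1
        mapping[rank].append(name)
        accu += numel
    return mapping

params = [("embedding", 8), ("w1", 4), ("w2", 4), ("w3", 2), ("w4", 2)]
print(toy_greedy_even(params, 2))  # {0: ['embedding', 'w3'], 1: ['w1', 'w2', 'w4']}
print(toy_use_order(params, 2))    # {0: ['embedding', 'w1'], 1: ['w2', 'w3', 'w4']}

The use_order variant keeps parameters of neighbouring layers on the same rank, which is what enables the stage-2 bucket fusion below; greedy_even usually yields a flatter size distribution but scatters adjacent layers across ranks.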
def re_order_program(block, param_grads, dist_context):
# record order
pname_to_pg_pairs = {}
for p, g in param_grads:
pname_to_pg_pairs[p.name] = (p, g)
use_order = []
for op in block.ops:
for input_name in op.input_arg_names:
if (input_name in pname_to_pg_pairs) and (
input_name not in use_order
):
use_order.append(input_name)
if len(use_order) == len(pname_to_pg_pairs):
break
# reorder optimizer
last_op = block.ops[-1]
pname_to_op = {}
num_ops = len(block.ops)
remove_op_indices = []
# TODO support case when optimizer is not the last op
if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type:
# record optimizer
for idx, op in reversed(list(enumerate(block.ops))):
if op.type not in _supported_optimizer_type:
break
assert len(op.input("Param")) == 1
pname_to_op[op.input("Param")[0]] = op
remove_op_indices.append(idx)
assert len(use_order) == len(pname_to_op)
# append new opts
for pname in use_order:
new_op = block.append_op(type='nop')
new_op.desc.copy_from(pname_to_op[pname].desc)
dist_context.set_op_dist_attr_for_program(
new_op,
dist_context.get_op_dist_attr_for_program(pname_to_op[pname]),
)
# remove old opts
for idx in remove_op_indices:
block._remove_op(idx, sync=False)
block._sync_with_cpp()
assert len(block.ops) == num_ops
# TODO also reorder gradient clip ops
_logger.info(
"Sharding the Order of param being used: {}.".format(use_order)
)
return [pname_to_pg_pairs[p] for p in use_order]
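The core of re_order_program is the use-order scan: the first op input that matches a parameter name fixes that parameter's position. A small standalone sketch of just that scan, with ops represented as plain input-name lists (all names invented):

# toy version of the use-order scan in re_order_program, for illustration only
def toy_use_order_scan(ops, param_names):
    order = []
    for op_inputs in ops:  # each op given only as its list of input names
        for name in op_inputs:
            if name in param_names and name not in order:
                order.append(name)
        if len(order) == len(param_names):
            break
    return order

ops = [["x", "embedding_w"], ["h0", "linear1_w"], ["h1", "linear1_w"], ["h1", "linear2_w"]]
print(toy_use_order_scan(ops, {"embedding_w", "linear1_w", "linear2_w"}))
# ['embedding_w', 'linear1_w', 'linear2_w']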
def group_param(sharding_info, fuse_size):
"""
params are grouped by:
rank id
fuse_size
dtype
"""
group_to_param_map = {}
param_to_group_map = {}
bucket = []
cur_group = ParameterGroup(fuse_size)
for param in sharding_info.params:
rank = sharding_info.get_var_rank(param.name)
if cur_group.acceptable(param, rank):
cur_group.collect(param, rank)
else:
cur_group = ParameterGroup(fuse_size)
cur_group.collect(param, rank)
if cur_group in group_to_param_map:
group_to_param_map[cur_group].append(param.name)
else:
group_to_param_map[cur_group] = [param.name]
param_to_group_map[param.name] = cur_group
return group_to_param_map, param_to_group_map
def _fuse_overlap_gradient_comm():
pass
def _fuse_overlap_parameter_comm_stage_two(
sharding_infos, dist_context, fuse_size
):
assert (
len(sharding_infos) == 1
), "fuse overlap optimization only support one sharding group right now, but got [{}].".format(
len(sharding_infos)
)
sharding_info = sharding_infos[0]
main_block = default_main_program().global_block()
startup_block = default_startup_program().global_block()
group_to_param_map, param_to_group_map = group_param(
sharding_info, fuse_size
)
_logger.info("Sharding Stage2 Optimization:")
_logger.info(
"Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format(
fuse_size,
len(param_to_group_map.keys()),
len(group_to_param_map.keys()),
)
)
for i, group in enumerate(group_to_param_map.keys()):
assert len(group) >= 1
if len(group) > 1:
coalesce_var_name = unique_name.generate(
'coalesce_param_{}'.format(i)
)
startup_block.create_var(
name=coalesce_var_name,
dtype=group.dtype,
persistable=True,
stop_gradient=True,
)
group.coalesce_var = main_block.create_var(
name=coalesce_var_name,
dtype=group.dtype,
persistable=True,
stop_gradient=True,
)
startup_block.append_op(
type="coalesce_tensor",
inputs={"Input": group.params},
outputs={
"Output": group.params,
"FusedOutput": group.coalesce_var,
},
attrs={
"copy_data": True,
"use_align": True,
"dtype": group.dtype,
OP_ROLE_KEY: OpRole.Forward,
},
)
else:
group.coalesce_var = group.params[0]
_logger.info(
"Bucket[{}] size [{}]MB : {}".format(
i,
sum([get_var_size(p) for p in group.params]),
[p.name for p in group.params],
)
)
# TODO Overlap broadcast with opt and next forward
new_op = main_block.append_op(
type='c_broadcast',
inputs={'X': group.coalesce_var},
outputs={'Out': group.coalesce_var},
attrs={
'ring_id': sharding_info.group.id,
'root': group.rank,
'use_calc_stream': True,
OP_ROLE_KEY: OpRole.Optimize,
},
)
# NOTE the current dist context lacks a representation for a bucket tensor, which
# is composed of many tensors with different dims_mappings. we assign a fake dist attr
# to it for now.
def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size):
assert (
len(sharding_infos) == 1
), "fuse overlap optimization only support one sharding group right now, but got [{}].".format(
len(sharding_infos)
)
sharding_info = sharding_infos[0]
class ShardingInfo(object):
def __init__(self, group, rank, params_grads, partition_algor):
self.group = group
self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
assert len(self.params_grads) == len(
......@@ -863,8 +1151,11 @@ class ShardingInfo:
self.group_size = group.nranks
self.global_rank = rank
self.local_rank = group.ranks.index(self.global_rank)
self.partition_algor = partition_algor
# ranks in the mapping below are local ranks within this sharding group
self.rank_to_params = shard_parameters(self.params, self.group_size)
self.rank_to_params = partition_parameters(
self.params, self.group_size, self.partition_algor
)
# include fp32 and fp16 param
self.param_to_rank = dict()
self._map_param_to_rank()
......@@ -899,7 +1190,7 @@ class ShardingInfo:
for op in block.ops:
if is_optimizer_op(op):
continue
for input_name in op.desc.input_arg_names():
for input_name in op.input_arg_names:
if input_name in self.param_names:
param_usage[input_name] += 1
......@@ -927,3 +1218,34 @@ class ShardingInfo:
if param_name not in self.params_grads:
raise ValueError('param[{}] not in params_grads'.format(param_name))
return self.params_grads.get(param_name, None)
class ParameterGroup(object):
def __init__(self, max_size):
self.max_size = max_size
self.dtype = None
self.rank = -1
self.numel = 0
self.params = []
self.coalesce_var = None
def acceptable(self, param, rank):
if self.numel == 0:
return True
else:
if param.dtype != self.dtype:
return False
if rank != self.rank:
return False
if self.numel + get_var_numel(param) > self.max_size:
return False
return True
def collect(self, param, rank):
self.dtype = param.dtype
self.rank = rank
self.numel += get_var_numel(param)
self.params.append(param)
def __len__(self):
return len(self.params)
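The bucketing rules encoded in ParameterGroup and group_param above (a bucket is cut whenever the owner rank or dtype changes, or adding the next parameter would exceed the bucket size) can be illustrated with a small standalone sketch on plain tuples; names, sizes and dtypes are invented:

# simplified sketch of the group_param / ParameterGroup bucketing rules
def toy_group_params(params, max_size):
    # params: list of (name, numel, owner_rank, dtype) in usage order
    buckets, cur = [], None
    for name, numel, rank, dtype in params:
        if (cur is None or rank != cur["rank"] or dtype != cur["dtype"]
                or cur["numel"] + numel > max_size):
            cur = {"rank": rank, "dtype": dtype, "numel": 0, "names": []}
            buckets.append(cur)
        cur["names"].append(name)
        cur["numel"] += numel
    return buckets

params = [("w0", 6, 0, "fp16"), ("w1", 6, 0, "fp16"),
          ("w2", 6, 1, "fp16"), ("w3", 2, 1, "fp32")]
for b in toy_group_params(params, max_size=16):
    print(b["rank"], b["dtype"], b["numel"], b["names"])
# 0 fp16 12 ['w0', 'w1']   <- contiguous, same rank/dtype, fits in the bucket
# 1 fp16 6  ['w2']         <- new bucket: owner rank changed
# 1 fp32 2  ['w3']         <- new bucket: dtype changed

Each multi-parameter bucket is then materialised by a coalesce_tensor op and broadcast once per bucket with c_broadcast, as in _fuse_overlap_parameter_comm_stage_two above.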
......@@ -44,7 +44,9 @@ class TestStrategy(unittest.TestCase):
self.assertEqual(sharding.enable, False)
self.assertEqual(sharding.stage, 1)
self.assertEqual(sharding.degree, 8)
self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
self.assertEqual(sharding.overlap_grad_comm, False)
self.assertEqual(sharding.bucket_size_numel, -1)
self.assertEqual(sharding.partition_algor, "greedy_even")
self.assertEqual(sharding.enable_tuning, False)
self.assertEqual(sharding.tuning_range, [])
......