Unverified commit e5eb3f55, authored by J JZ-LIANG, committed by GitHub

[Auto Parallel] Sharding Optimization: Partition Algorithm & Stage2 Parameter Bucket Communication (#47180)

* partition param by order

* add logging

* reorder opt

* config

* stage2 bucket

* update unittest
Parent 6934ae2b
@@ -82,7 +82,9 @@ SHARDING = "sharding"
 set_field_default_config(SHARDING, "enable", False)
 set_field_default_config(SHARDING, "stage", 1)
 set_field_default_config(SHARDING, "degree", 8)
-set_field_default_config(SHARDING, "segment_broadcast_MB", 32.0)
+set_field_default_config(SHARDING, "overlap_grad_comm", False)
+set_field_default_config(SHARDING, "bucket_size_numel", -1)
+set_field_default_config(SHARDING, "partition_algor", "greedy_even")
 set_field_default_config(SHARDING, "enable_tuning", False)
 set_field_default_config(SHARDING, "tuning_range", [])
......
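The three new options above replace segment_broadcast_MB and are exercised by the updated TestStrategy case at the end of this diff. The sketch below shows how they might be set through the auto-parallel strategy object; it is illustrative only, and the `auto.Strategy` entry point and the concrete values are assumptions rather than part of the commit.

```python
# Rough illustration (not part of this commit) of configuring the new sharding
# options; the auto.Strategy entry point and the bucket size are assumptions.
from paddle.distributed.fleet import auto

strategy = auto.Strategy()
sharding = strategy.sharding
sharding.enable = True
sharding.stage = 2                      # shard optimizer states and gradients
sharding.degree = 8                     # ranks in the sharding group
sharding.overlap_grad_comm = True       # overlap gradient communication
sharding.bucket_size_numel = 50331648   # fuse params into buckets of ~48M elements (arbitrary value)
sharding.partition_algor = "use_order"  # or "greedy_even" (the default)
```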
@@ -22,6 +22,7 @@ import logging
 from functools import reduce
 import paddle.fluid.core as core
+from paddle.fluid.framework import Variable
 from paddle.distributed.fleet.meta_optimizers.common import OpRole
 from paddle.distributed.auto_parallel.process_group import (
     get_all_process_groups,
@@ -1790,6 +1791,18 @@ def find_higher_order_backward_op(program):
     return False
 
 
+def get_var_numel(var):
+    """
+    input:
+        - var: variable
+    return:
+        number of elements in var
+    """
+    assert isinstance(var, Variable)
+    assert -1 not in var.shape
+    return reduce(lambda x, y: x * y, var.shape)
+
+
 def get_lr(optimizer):
     if isinstance(optimizer, paddle.optimizer.Optimizer):
         return optimizer.get_lr()
......
@@ -13,12 +13,12 @@
 # limitations under the License.
 
 from collections import OrderedDict
-import numpy as np
 
 import paddle
 from paddle.fluid import unique_name
 from paddle.fluid.framework import default_main_program
 from paddle.distributed.fleet.meta_optimizers.common import OP_ROLE_KEY, OpRole
-from .pass_base import PassBase, PassType, register_pass
 from paddle.distributed.auto_parallel.operators.common import (
     is_data_parallel_scale_op,
     is_data_parallel_reduce_op,
@@ -28,8 +28,8 @@ from paddle.distributed.auto_parallel.utils import (
     is_loss_grad_op,
     is_optimize_op,
     ring_id_to_process_group,
+    get_var_numel,
 )
+from .pass_base import PassBase, PassType, register_pass
 
 # add new optimizers supporting rescale_grad here
 __rescale_grad_supported_opts__ = [
@@ -44,10 +44,6 @@ __rescale_grad_supported_opts__ = [
 
 __max_stream_num_allow__ = 16
 
-
-def numel(var):
-    return np.prod(list(var.shape))
-
 
 @register_pass("auto_parallel_data_parallel_optimization")
 class DataParallelOptimizationPass(PassBase):
     """
@@ -430,7 +426,7 @@ class DataParallelOptimizationPass(PassBase):
             ring_id = op.attr("ring_id")
             grad_name = op.output_arg_names[0]
             grad_var = block.var(grad_name)
-            grad_numel = numel(grad_var)
+            grad_numel = get_var_numel(grad_var)
 
             if cur_group.acceptable(grad_var, ring_id):
                 assert grad_name not in grouped_grad_names
@@ -594,7 +590,7 @@ class GradientsGroup:
             return True
         if ring_id != self.ring_id:
             return False
-        if numel(grad_var) + self.numel > self.max_group_size:
+        if get_var_numel(grad_var) + self.numel > self.max_group_size:
             return False
         if grad_var.dtype != self.dtype:
             return False
@@ -605,7 +601,7 @@ class GradientsGroup:
         self.gradients.append(grad_var)
         self.ring_id = ring_id
         self.dtype = grad_var.dtype
-        self.numel += numel(grad_var)
+        self.numel += get_var_numel(grad_var)
         # remove auxiliary ops in non-fuse dp allreduce
         self.remove_allreduce_op_indices.append(i)
......
@@ -13,15 +13,20 @@
 # limitations under the License.
 
 from functools import reduce
+import logging
 
+import paddle
 from paddle.framework import core
+from paddle.fluid.framework import default_main_program, default_startup_program
 from paddle.fluid import unique_name
 from .pass_base import PassBase, register_pass
-from paddle.distributed.auto_parallel.process_group import new_process_group
 from paddle.distributed.fleet.meta_optimizers.sharding.utils import get_var_size
 from paddle.distributed.fleet.meta_optimizers.common import (
     is_backward_op,
     is_optimizer_op,
 )
+from paddle.distributed.auto_parallel.process_group import new_process_group
 from paddle.distributed.auto_parallel.operators.common import (
     is_parameter_related,
     is_data_parallel_reduce_op,
@@ -30,6 +35,8 @@ from paddle.distributed.auto_parallel.utils import (
     _get_comm_group,
     naive_set_dist_op_attr_for_program_by_mesh_and_mapping,
     set_var_dist_attr,
+    get_var_numel,
+    get_logger,
 )
 
 OpRole = core.op_proto_and_checker_maker.OpRole
@@ -57,6 +64,8 @@ _supported_optimizer_type = [
     "sgd",
 ]
 
+_logger = get_logger(logging.INFO)
+
 
 def _is_reshard_op(op):
     return op.desc.has_attr(
@@ -76,6 +85,9 @@ class ShardingPass(PassBase):
         self.set_attr("stage", None)
         self.set_attr("sharding_degree", None)  # for parallelizer
         self.set_attr("degree", None)  # for parallelizer_v2
+        self.set_attr("overlap_grad_comm", None)
+        self.set_attr("bucket_size_numel", None)
+        self.set_attr("partition_algor", None)
         self.set_attr("params_grads", [])
         self.set_attr("global_rank", -1)
         self.dp_groups = set()
@@ -109,6 +121,12 @@ class ShardingPass(PassBase):
             "global_rank"
         ) < 0:
             return False
+        if self.get_attr("overlap_grad_comm") is None:
+            return False
+        if self.get_attr("bucket_size_numel") is None:
+            return False
+        if self.get_attr("partition_algor") is None:
+            return False
 
         return True
@@ -122,22 +140,35 @@ class ShardingPass(PassBase):
         )
         self.stage = int(self.get_attr("stage"))
         self.global_rank = int(self.get_attr("global_rank"))
+        self.overlap_grad_comm = self.get_attr("overlap_grad_comm")
+        self.bucket_size_numel = int(self.get_attr("bucket_size_numel"))
+        self.partition_algor = self.get_attr("partition_algor")
         params_grads = self.get_attr("params_grads")
         main_block, startup_block = (
             main_program.global_block(),
             startup_program.global_block(),
         )
 
+        # NOTE Multi / Sub-Block Support
+        # we assume that parameters are present and partitioned only in main_block,
+        # that there is NO new param in any sub_block, and that all params in sub_blocks
+        # follow the same partition as main_block. these constraints cover the 3 most
+        # common sub_block use-cases in Paddle:
+        # 1. sub-block for lr scheduler
+        # 2. sub-block uses the same or a partial network of main-block, e.g. GPT3 generation model
+        # 3. sub-block used for double backward
+
         self._build_sharding_groups(main_block, params_grads)
-        self._shard_optimizer(main_block, startup_block, params_grads, context)
-        self._shard_gradient_synchronization(main_block)
-        self._shard_parameter(main_block, startup_block)
+        for block in main_program.blocks:
+            self._shard_optimizer(block, startup_block, params_grads, context)
+            self._shard_gradient_synchronization(block)
+            self._shard_parameter(block, startup_block)
 
         context.set_attr("params_grads", self.shared_params_grads)
+        self._optimization_pass(main_program, startup_program)
     def _build_sharding_groups(self, main_block, params_grads):
         self._collective_data_parallel_groups(main_block)
-        self._build_sharding_infos(params_grads)
+        self._build_sharding_infos(main_block, params_grads)
 
     def _collective_data_parallel_groups(self, main_block):
         for op in main_block.ops:
@@ -162,8 +193,14 @@ class ShardingPass(PassBase):
                 )
             )
 
-    def _build_sharding_infos(self, params_grads):
+    def _build_sharding_infos(self, main_block, params_grads):
+        # order params
+        params_grads = re_order_program(
+            main_block, params_grads, self._dist_context
+        )
+
+        # partition
         for dp_group in self.dp_groups:
 
             assert (
@@ -204,7 +241,10 @@ class ShardingPass(PassBase):
             self._dist_context._sharding_group = sharding_group
             # TODO(JZ-LIANG) when support multiple dp groups in future, should group param and bind them to corresponding dp group
             sharding_info = ShardingInfo(
-                sharding_group, self.global_rank, params_grads
+                sharding_group,
+                self.global_rank,
+                params_grads,
+                self.partition_algor,
             )
             self.sharding_infos.append(sharding_info)
             for param in sharding_info.params:
@@ -317,7 +357,7 @@ class ShardingPass(PassBase):
                     reserved_vars.append(input_name)
             op.desc.set_input("X", reserved_vars)
 
-            sum_op_output = op.desc.output_arg_names()[0]
+            sum_op_output = op.output_arg_names[0]
             for i, sharding_info in enumerate(self.sharding_infos):
                 new_op = main_block._insert_op(
                     idx + i + 1,
@@ -401,7 +441,7 @@ class ShardingPass(PassBase):
 
     def _insert_optimizer_broadcasts(self, main_block, startup_block):
-        if self.stage > 2:
+        if self.stage > 2 or self.bucket_size_numel > 1:
             return
 
         for sharding_info in self.sharding_infos:
@@ -508,7 +548,7 @@ class ShardingPass(PassBase):
             if is_optimizer_op(op):
                 continue
 
-            for input_name in op.desc.input_arg_names():
+            for input_name in op.input_arg_names:
                 # NOTE hack for embedding op when AMP 02-3
                 # paddle amp force embedding (lookup table) to be run on fp32
                 if _is_param_fp16_cast_op(
@@ -601,6 +641,24 @@ class ShardingPass(PassBase):
         main_block._sync_with_cpp()
         startup_block._sync_with_cpp()
 
+    def _optimization_pass(self, main_program, startup_program):
+
+        with paddle.static.program_guard(main_program, startup_program):
+            if self.overlap_grad_comm:
+                _fuse_overlap_gradient_comm()
+            # TODO support multiple sub_blocks
+            if self.bucket_size_numel > 1:
+                if self.stage == 2:
+                    _fuse_overlap_parameter_comm_stage_two(
+                        self.sharding_infos,
+                        self._dist_context,
+                        fuse_size=self.bucket_size_numel,
+                    )
+                elif self.stage == 3:
+                    _fuse_overlap_parameter_comm_stage_three(
+                        self.sharding_infos, fuse_size=self.bucket_size_numel
+                    )
+
 
 def _insert_init_and_broadcast_op(
     block,
@@ -723,7 +781,7 @@ def _is_param_grad_fp32_cast_op(block, op):
         block, op, core.VarDesc.VarType.FP16, core.VarDesc.VarType.FP32
     ):
         return False
-    output_name = op.desc.output_arg_names()[0]
+    output_name = op.output_arg_names[0]
     base_name = output_name[: output_name.find("@")]
     if not block.has_var(base_name):
         return False
@@ -736,7 +794,7 @@ def _is_param_fp16_cast_op(block, op, params):
         return False
     if not _is_desired_cast_op(block, op):
         return False
-    input_name = op.desc.input_arg_names()[0]
+    input_name = op.input_arg_names[0]
     if input_name not in params:
         return False
     return True
@@ -750,10 +808,10 @@ def _is_desired_cast_op(
 ):
     if op.type != "cast":
         return False
-    assert len(op.desc.input_arg_names()) == 1
-    assert len(op.desc.output_arg_names()) == 1
-    input_var = block.var(op.desc.input_arg_names()[0])
-    output_var = block.var(op.desc.output_arg_names()[0])
+    assert len(op.input_arg_names) == 1
+    assert len(op.output_arg_names) == 1
+    input_var = block.var(op.input_arg_names[0])
+    output_var = block.var(op.output_arg_names[0])
     if input_var.dtype != src_var_type or output_var.dtype != dst_var_type:
         return False
@@ -828,10 +886,36 @@ def _inference_data_parallel_group_for_operator(rank_id, op, dist_context):
     return dp_group
 
-def shard_parameters(params, group_size):
-    # TODO(JZ-LIANG) support multiple partition methods
-    # method1: greedy even but unorder
-    # method2: roughly even with oreder
+def partition_by_use_order(params, group_size):
+    """
+    shard continuous params onto the same rank and divide the forward & backward
+    computation into segments, which will favor the later fuse passes.
+
+    we assume that the params are already sorted by utilization order.
+    """
+    mapping = {}
+    total_param_mem = 0.0
+    param2mem = []
+    for param in params:
+        mem = get_var_size(param)
+        total_param_mem += mem
+        param2mem.append((param, mem))
+    mapping = {x: [] for x in range(group_size)}
+    cur_rank = 0
+    mem_accu = 0.0
+    for param, mem in param2mem:
+        if mem_accu > total_param_mem * 1.0 * (cur_rank + 1) / group_size:
+            cur_rank += 1
+        mapping[cur_rank].append(param)
+        mem_accu += mem
+
+    return mapping
+
+
+def partition_by_greedy_even(params, group_size):
+    """
+    use a greedy algorithm to partition parameters as evenly as possible.
+    """
     mapping = {}
     for rank_ in range(group_size):
         mapping[rank_] = []
@@ -850,8 +934,212 @@ def shard_parameters(params, group_size):
     return mapping
+
+
+def partition_parameters(params, group_size, algor="greedy_even"):
+    if algor == "greedy_even":
+        rank_to_params = partition_by_greedy_even(params, group_size)
+    else:
+        rank_to_params = partition_by_use_order(params, group_size)
+
+    _logger.info("Sharding Parameter Partition:")
+    for k, v in rank_to_params.items():
+        _logger.info(
+            "Rank:{}, Parameter Size:{} MB.".format(
+                k, sum([get_var_size(var) for var in v])
+            )
+        )
+        _logger.info("Params in this rank: {}.".format([var.name for var in v]))
+
+    return rank_to_params
+
+
+def re_order_program(block, param_grads, dist_context):
+    # record order
+    pname_to_pg_pairs = {}
+    for p, g in param_grads:
+        pname_to_pg_pairs[p.name] = (p, g)
+
+    use_order = []
+    for op in block.ops:
+        for input_name in op.input_arg_names:
+            if (input_name in pname_to_pg_pairs) and (
+                input_name not in use_order
+            ):
+                use_order.append(input_name)
+        if len(use_order) == len(pname_to_pg_pairs):
+            break
+
+    # reorder optimizer ops
+    last_op = block.ops[-1]
+    pname_to_op = {}
+    num_ops = len(block.ops)
+    remove_op_indices = []
+    # TODO support case when optimizer is not the last op
+    if is_optimizer_op(last_op) and last_op.type in _supported_optimizer_type:
+        # record optimizer ops
+        for idx, op in reversed(list(enumerate(block.ops))):
+            if op.type not in _supported_optimizer_type:
+                break
+            assert len(op.input("Param")) == 1
+            pname_to_op[op.input("Param")[0]] = op
+            remove_op_indices.append(idx)
+        assert len(use_order) == len(pname_to_op)
+
+        # append new optimizer ops in usage order
+        for pname in use_order:
+            new_op = block.append_op(type='nop')
+            new_op.desc.copy_from(pname_to_op[pname].desc)
+            dist_context.set_op_dist_attr_for_program(
+                new_op,
+                dist_context.get_op_dist_attr_for_program(pname_to_op[pname]),
+            )
+
+        # remove old optimizer ops
+        for idx in remove_op_indices:
+            block._remove_op(idx, sync=False)
+
+        block._sync_with_cpp()
+        assert len(block.ops) == num_ops
+
+    # TODO reorder gradient clip order
+    _logger.info(
+        "Sharding the Order of param being used: {}.".format(use_order)
+    )
+    return [pname_to_pg_pairs[p] for p in use_order]
+
+
+def group_param(sharding_info, fuse_size):
+    """
+    params are grouped by:
+        rank id
+        fuse_size
+        dtype
+    """
+    group_to_param_map = {}
+    param_to_group_map = {}
+    bucket = []
+    cur_group = ParameterGroup(fuse_size)
+
+    for param in sharding_info.params:
+        rank = sharding_info.get_var_rank(param.name)
+
+        if cur_group.acceptable(param, rank):
+            cur_group.collect(param, rank)
+        else:
+            cur_group = ParameterGroup(fuse_size)
+            cur_group.collect(param, rank)
+
+        if cur_group in group_to_param_map:
+            group_to_param_map[cur_group].append(param.name)
+        else:
+            group_to_param_map[cur_group] = [param.name]
+
+        param_to_group_map[param.name] = cur_group
+
+    return group_to_param_map, param_to_group_map
+
+
+def _fuse_overlap_gradient_comm():
+    pass
+
+
+def _fuse_overlap_parameter_comm_stage_two(
+    sharding_infos, dist_context, fuse_size
+):
+
+    assert (
+        len(sharding_infos) == 1
+    ), "fuse overlap optimization only support one sharding group right now, but got [{}].".format(
+        len(sharding_infos)
+    )
+    sharding_info = sharding_infos[0]
+
+    main_block = default_main_program().global_block()
+    startup_block = default_startup_program().global_block()
+
+    group_to_param_map, param_to_group_map = group_param(
+        sharding_info, fuse_size
+    )
+    _logger.info("Sharding Stage2 Optimization:")
+    _logger.info(
+        "Bucket size is [{}], [{}] Parameters are fused into [{}] Buckets".format(
+            fuse_size,
+            len(param_to_group_map.keys()),
+            len(group_to_param_map.keys()),
+        )
+    )
+    for i, group in enumerate(group_to_param_map.keys()):
+
+        assert len(group) >= 1
+        if len(group) > 1:
+            coalesce_var_name = unique_name.generate(
+                'coalecse_param_{}'.format(i)
+            )
+            startup_block.create_var(
+                name=coalesce_var_name,
+                dtype=group.dtype,
+                persistable=True,
+                stop_gradient=True,
+            )
+            group.coalesce_var = main_block.create_var(
+                name=coalesce_var_name,
+                dtype=group.dtype,
+                persistable=True,
+                stop_gradient=True,
+            )
+            startup_block.append_op(
+                type="coalesce_tensor",
+                inputs={"Input": group.params},
+                outputs={
+                    "Output": group.params,
+                    "FusedOutput": group.coalesce_var,
+                },
+                attrs={
+                    "copy_data": True,
+                    "use_align": True,
+                    "dtype": group.dtype,
+                    OP_ROLE_KEY: OpRole.Forward,
+                },
+            )
+        else:
+            group.coalesce_var = group.params[0]
+        _logger.info(
+            "Bucket[{}] size [{}]MB : {}".format(
+                i,
+                sum([get_var_size(p) for p in group.params]),
+                [p.name for p in group.params],
+            )
+        )
+
+        # TODO Overlap broadcast with opt and next forward
+        new_op = main_block.append_op(
+            type='c_broadcast',
+            inputs={'X': group.coalesce_var},
+            outputs={'Out': group.coalesce_var},
+            attrs={
+                'ring_id': sharding_info.group.id,
+                'root': group.rank,
+                'use_calc_stream': True,
+                OP_ROLE_KEY: OpRole.Optimize,
+            },
+        )
+
+    # NOTE the current dist context lacks a representation for a bucket tensor, which
+    # is composed of many tensors with different dims_mapping. we assign a fake dist
+    # attr to it for now.
+
+
+def _fuse_overlap_parameter_comm_stage_three(sharding_infos, fuse_size):
+
+    assert (
+        len(sharding_infos) == 1
+    ), "fuse overlap optimization only support one sharding group right now, but got [{}].".format(
+        len(sharding_infos)
+    )
+    sharding_info = sharding_infos[0]
+
+
-class ShardingInfo:
-    def __init__(self, group, rank, params_grads):
+class ShardingInfo(object):
+    def __init__(self, group, rank, params_grads, partition_algor):
         self.group = group
         self.params_grads = dict([(p.name, (p, g)) for p, g in params_grads])
         assert len(self.params_grads) == len(
@@ -863,8 +1151,11 @@ class ShardingInfo:
         self.group_size = group.nranks
         self.global_rank = rank
         self.local_rank = group.ranks.index(self.global_rank)
+        self.partition_algor = partition_algor
         # rank in below mapping are local rank in this sharding group
-        self.rank_to_params = shard_parameters(self.params, self.group_size)
+        self.rank_to_params = partition_parameters(
+            self.params, self.group_size, self.partition_algor
+        )
         # include fp32 and fp16 param
         self.param_to_rank = dict()
         self._map_param_to_rank()
@@ -899,7 +1190,7 @@ class ShardingInfo:
         for op in block.ops:
             if is_optimizer_op(op):
                 continue
-            for input_name in op.desc.input_arg_names():
+            for input_name in op.input_arg_names:
                 if input_name in self.param_names:
                     param_usage[input_name] += 1
@@ -927,3 +1218,34 @@ class ShardingInfo:
         if param_name not in self.params_grads:
             raise ValueError('param[{}] not in params_grads'.format(param_name))
         return self.params_grads.get(param_name, None)
+
+
+class ParameterGroup(object):
+    def __init__(self, max_size):
+        self.max_size = max_size
+        self.dtype = None
+        self.rank = -1
+        self.numel = 0
+        self.params = []
+        self.coalesce_var = None
+
+    def acceptable(self, param, rank):
+        if self.numel == 0:
+            return True
+        else:
+            if param.dtype != self.dtype:
+                return False
+            if rank != self.rank:
+                return False
+            if self.numel + get_var_numel(param) > self.max_size:
+                return False
+            return True
+
+    def collect(self, param, rank):
+        self.dtype = param.dtype
+        self.rank = rank
+        self.numel += get_var_numel(param)
+        self.params.append(param)
+
+    def __len__(self):
+        return len(self.params)
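The two partition strategies added above trade off differently: partition_by_greedy_even balances total parameter size across ranks without regard to order, while partition_by_use_order keeps parameters in usage order and cuts the sequence into roughly even segments so that later fuse passes see contiguous buckets. Below is a standalone illustration of the two ideas using plain (name, size) pairs; it mirrors the logic visible in the diff but is not the Paddle code itself, which works on program variables via get_var_size.

```python
# Standalone illustration of the two partition strategies, using plain
# (name, size) pairs instead of Paddle program variables.

def greedy_even(params, group_size):
    # always give the next parameter to the currently lightest rank
    mapping = {r: [] for r in range(group_size)}
    sizes = [0] * group_size
    for name, size in params:
        rank = sizes.index(min(sizes))
        mapping[rank].append(name)
        sizes[rank] += size
    return mapping

def use_order(params, group_size):
    # keep parameters in usage order and cut the sequence into roughly even segments
    total = sum(size for _, size in params)
    mapping = {r: [] for r in range(group_size)}
    rank, accumulated = 0, 0.0
    for name, size in params:
        if accumulated > total * (rank + 1) / group_size:
            rank += 1
        mapping[rank].append(name)
        accumulated += size
    return mapping

params = [("embedding", 30), ("linear_0", 8), ("linear_1", 8), ("out_proj", 20)]
print(greedy_even(params, 2))  # {0: ['embedding'], 1: ['linear_0', 'linear_1', 'out_proj']}
print(use_order(params, 2))    # {0: ['embedding', 'linear_0'], 1: ['linear_1', 'out_proj']}
```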
@@ -44,7 +44,9 @@ class TestStrategy(unittest.TestCase):
         self.assertEqual(sharding.enable, False)
         self.assertEqual(sharding.stage, 1)
         self.assertEqual(sharding.degree, 8)
-        self.assertAlmostEqual(sharding.segment_broadcast_MB, 32.0)
+        self.assertEqual(sharding.overlap_grad_comm, False)
+        self.assertEqual(sharding.bucket_size_numel, -1)
+        self.assertEqual(sharding.partition_algor, "greedy_even")
         self.assertEqual(sharding.enable_tuning, False)
         self.assertEqual(sharding.tuning_range, [])
......
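For stage 2, group_param and ParameterGroup above place parameters into the same bucket only when they share an owning rank and dtype and the bucket stays within the bucket_size_numel budget; each bucket is then coalesced and broadcast once from its owning rank. The following standalone sketch of that grouping rule uses illustrative names and sizes rather than Paddle program variables.

```python
# Standalone sketch of the bucket-grouping rule: same owning rank, same dtype,
# total element count within the fuse_size budget. Names are illustrative only.

def group_into_buckets(params, fuse_size):
    # params: list of (name, numel, dtype, owning_rank), already in usage order
    buckets = []
    current = {"rank": None, "dtype": None, "numel": 0, "names": []}
    for name, numel, dtype, rank in params:
        fits = (
            current["numel"] == 0
            or (
                rank == current["rank"]
                and dtype == current["dtype"]
                and current["numel"] + numel <= fuse_size
            )
        )
        if not fits:
            buckets.append(current)
            current = {"rank": None, "dtype": None, "numel": 0, "names": []}
        current["rank"], current["dtype"] = rank, dtype
        current["numel"] += numel
        current["names"].append(name)
    if current["names"]:
        buckets.append(current)
    return buckets

params = [
    ("w0", 4_000_000, "fp16", 0),
    ("w1", 3_000_000, "fp16", 0),
    ("w2", 6_000_000, "fp16", 1),  # different owning rank -> new bucket
    ("w3", 2_000_000, "fp32", 1),  # different dtype -> new bucket
]
for b in group_into_buckets(params, fuse_size=8_000_000):
    print(b["rank"], b["dtype"], b["numel"], b["names"])
# rank 0: w0 + w1 fused (7M fp16); rank 1: w2 alone; rank 1: w3 alone
```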