Unverified commit f33f2444 authored by JZ-LIANG, committed by GitHub

Dygraph/sharding (#33633)

* dygraph sharding

* update unitest hybrid_parallel_communicate_group
Parent 3e82a794
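
For context, here is a minimal usage sketch of the new dygraph sharding path, condensed from the unit test added in this PR. The two-layer model is hypothetical and only for illustration; the script is assumed to be launched with the usual multi-GPU dygraph launcher so that two ranks form the sharding group:

import paddle
import paddle.distributed.fleet as fleet
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer

# enable sharding through the new `sharding_degree` field in hybrid_configs
strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {
    "dp_degree": 1,
    "mp_degree": 1,
    "pp_degree": 1,
    "sharding_degree": 2,
}
fleet.init(is_collective=True, strategy=strategy)

# hypothetical model, for illustration only
model = paddle.nn.Sequential(
    paddle.nn.Linear(10, 10), paddle.nn.Linear(10, 1))

# wrap the inner optimizer in DygraphShardingOptimizer, then hand both to fleet;
# in this mode fleet.distributed_model returns a ShardingParallel wrapper
optimizer = DygraphShardingOptimizer(
    hcg=fleet.get_hybrid_communicate_group(),
    user_defined_strategy=strategy,
    params=model.parameters(),
    inner_optimizer_class=paddle.optimizer.Adam,
    learning_rate=0.001)
model = fleet.distributed_model(model)
optimizer = fleet.distributed_optimizer(optimizer)
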
......@@ -47,6 +47,7 @@ message HybridConfig {
optional int32 dp_degree = 1 [ default = -1 ];
optional int32 mp_degree = 2 [ default = 1 ];
optional int32 pp_degree = 3 [ default = 1 ];
optional int32 sharding_degree = 4 [ default = 1 ];
}
message AMPConfig {
......
......@@ -30,7 +30,7 @@ from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel
from ..meta_parallel import PipelineParallel, ShardingParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler
......@@ -295,9 +295,11 @@ class Fleet(object):
self.dp_degree = self.hybrid_configs["dp_degree"]
self.mp_degree = self.hybrid_configs["mp_degree"]
self.pp_degree = self.hybrid_configs["pp_degree"]
self.sharding_degree = self.hybrid_configs["sharding_degree"]
assert self.mp_degree >= 0, "mp_degree should be greater than or equal to 0"
assert self.pp_degree >= 0, "pp_degree should be greater than or equal to 0"
assert self.sharding_degree >= 0, "sharding_degree should be greater than or equal to 0"
self.mp_degree = max(self.mp_degree, 1)
self.pp_degree = max(self.pp_degree, 1)
......@@ -309,8 +311,11 @@ class Fleet(object):
self.dp_degree = max(self.dp_degree, 1)
self._topology = tp.CommunicateTopology(
hybrid_group_names=["data", "pipe", "model"],
dims=[self.dp_degree, self.pp_degree, self.mp_degree])
hybrid_group_names=["data", "pipe", "sharding", "model"],
dims=[
self.dp_degree, self.pp_degree, self.sharding_degree,
self.mp_degree
])
self._hcg = tp.HybridCommunicateGroup(self._topology)
......@@ -886,7 +891,11 @@ class Fleet(object):
assert model is not None, "model should not be None"
if self.worker_num() <= 1:
return model
if self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
if self._hcg.get_parallel_mode() == ParallelMode.SHARDING_PARALLEL:
distributed_model = ShardingParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.DATA_PARALLEL:
distributed_model = paddle.DataParallel(
model,
comm_buffer_size=self._user_defined_strategy.
......@@ -901,6 +910,7 @@ class Fleet(object):
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
model, self._hcg, strategy=self._user_defined_strategy)
return distributed_model
@dygraph_only
......
......@@ -30,12 +30,13 @@ class ParallelMode(object):
DATA_PARALLEL = 0
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2
SHARDING_PARALLEL = 3
class CommunicateTopology(object):
def __init__(self,
hybrid_group_names=["data", "pipe", "model"],
dims=[1, 1, 1]):
hybrid_group_names=["data", "pipe", "sharding", "model"],
dims=[1, 1, 1, 1]):
self._parallel_names = hybrid_group_names
self._dims = dims
self.coordinate = collections.namedtuple('Coordinate',
......@@ -122,15 +123,17 @@ class HybridCommunicateGroup(object):
self._dp_degree = self._topo.get_dim('data')
self._mp_degree = self._topo.get_dim('model')
self._pp_degree = self._topo.get_dim('pipe')
self._sharding_degree = self._topo.get_dim('sharding')
self._data_parallel_id = self._get_data_parallel_id()
self._model_parallel_id = self._get_model_parallel_id()
self._sharding_parallel_id = self._get_sharding_parallel_id()
self.stage_id = self._get_pipe_parallel_id()
assert self._check_vaild_topo(
), "Here is an unreasonable topogy setting. world_size: {}, but" \
"dp_num: {}, mp_num: {}, pp_num: {}".format(self.nranks, self._dp_degree,
self._mp_degree, self._pp_degree)
"mp_num: {}, sharding_num: {}, pp_num: {}, dp_num: {}".format(self.nranks,
self._mp_degree, self._sharding_degree, self._pp_degree, self._dp_degree)
# create comm group for data parallel
self._dp_group, self._dp_comm_group = self._set_comm_group("data")
......@@ -141,6 +144,10 @@ class HybridCommunicateGroup(object):
# create comm group for pipe parallel
self._pp_group, self._pp_comm_group = self._set_comm_group("pipe")
# create comm group for sharding parallel
self._sharding_group, self._sharding_comm_group = self._set_comm_group(
"sharding")
# create global group for check inf_nan / clip global norm
self._check_group, self._check_comm_group = self._set_check_group(
"data")
......@@ -149,19 +156,26 @@ class HybridCommunicateGroup(object):
self.is_first_stage = (self.stage_id == 0)
self.is_last_stage = (self.stage_id == (self._pp_degree - 1))
debug_str = "HybridParallelInfo: rank_id: %d, dp_degree: %d, " \
"mp_degree: %d, pp_degree: %d" % (self.global_rank, self._dp_degree,
self._mp_degree,self._pp_degree)
debug_str += ", dp_group: %s, mp_group: %s, pp_group: %s, check/clip group: %s" % (
self._dp_group, self._mp_group, self._pp_group, self._check_group)
debug_str = "HybridParallelInfo: rank_id: %d, mp_degree: %d, " \
"sharding_degree: %d, pp_degree: %d, dp_degree: %d" % (self.global_rank, self._mp_degree,
self._sharding_degree, self._pp_degree, self._dp_degree)
debug_str += ", mp_group: %s, sharding_group: %s, pp_group: %s, dp_group: %s, check/clip group: %s" % (
self._mp_group, self._sharding_group, self._pp_group,
self._dp_group, self._check_group)
logger.info(debug_str)
global _HYBRID_PARALLEL_GROUP
_HYBRID_PARALLEL_GROUP = self
def get_parallel_mode(self):
# there are three modes : DataParallel / TensorParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
# there are four modes: DataParallel / TensorParallel / PipelineParallel / ShardingParallel
# NOTE when sharding is combined with another parallelism, sharding should act like an optimizer
# and add its parallel logic within that parallelism;
# when sharding is used alone, it has its own parallel mode for its parallel logic
# TODO modify the other 3 parallelisms to support sharding
if self._mp_degree == 1 and self._pp_degree == 1 and self._dp_degree == 1 and self._sharding_degree > 1:
return ParallelMode.SHARDING_PARALLEL
elif self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed
......@@ -170,7 +184,7 @@ class HybridCommunicateGroup(object):
return ParallelMode.PIPELINE_PARALLEL
def _check_vaild_topo(self):
return self._dp_degree * self._mp_degree * self._pp_degree == self.nranks
return self._dp_degree * self._mp_degree * self._pp_degree * self._sharding_degree == self.nranks
def _set_comm_group(self, parallel_method="data"):
parallel_group = []
......@@ -255,6 +269,23 @@ class HybridCommunicateGroup(object):
def get_pipe_parallel_group(self):
return self._pp_comm_group
# sharding parallel message:
def _get_sharding_parallel_id(self):
return self._topo.get_coord(self.global_rank).sharding
def get_sharding_parallel_rank(self):
return self._sharding_parallel_id
def get_sharding_parallel_world_size(self):
return self._sharding_degree
def get_sharding_parallel_group(self):
return self._sharding_comm_group
def get_sharding_parallel_group_src_rank(self):
# TODO should the src rank be related to the shard rank of each parameter?
return self._sharding_comm_group.ranks[0]
# check parallel group
def get_check_parallel_group(self):
return self._check_comm_group
......
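
As a quick reference for the 4-D topology added above, a small offline sketch (no process groups are initialized; it mirrors the new test_topology_4D case further below):

import paddle.distributed.fleet as fleet

# data x pipe x sharding x model, 2 ranks per axis = 16 ranks in total
topo = fleet.CommunicateTopology(["data", "pipe", "sharding", "model"],
                                 [2, 2, 2, 2])
# ranks that end up in the same sharding communication group
print(topo.get_comm_list("sharding"))
# [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11], [12, 14], [13, 15]]
# coordinate of global rank 10 along (data, pipe, sharding, model)
print(topo.get_coord(10))  # Coordinate(data=1, pipe=0, sharding=1, model=0)
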
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
######
from functools import reduce
import paddle
from paddle import framework
from ...utils.log_util import logger
def _is_trainable(param: paddle.Tensor) -> bool:
return not param.stop_gradient
class DygraphShardingOptimizer(object):
"""
A wrapper for Sharding Optimizer in Dygraph.
.. warning: DygraphShardingOptimizer is experimental and subject to change.
.. ZeRO: https://arxiv.org/abs/1910.02054
"""
# TODO (JZ-LIANG)
# To support the following features in the future:
# 1. fused update parameter sync
# 2. parameters_groups
# 3. dynamic trainable params, which is the case between pretraining and finetuning
# 4. option to choose fused comm (needs more GPU memory) or un-fused comm
def __init__(
self,
hcg,
user_defined_strategy,
params,
inner_optimizer_class,
**inner_optimizer_kargs, ):
if not isinstance(params, list):
raise TypeError(
"`parameters` argument given to the DygraphShardingOptimizer should be "
"an iterable of paddle Tensors, but got argument type is `{}`.".
format(type(params)))
self._parameter_list = params
self._reference_is_trainable_params = list(
map(_is_trainable, self._parameter_list))
self._inner_optimizer_class = inner_optimizer_class
self._inner_optimizer_kargs = inner_optimizer_kargs
# sharding parallel information
# TODO better way to get the hcg & user_defined_strategy
self._hcg = hcg
self._user_defined_strategy = user_defined_strategy
self._sharding_world_size = self._hcg.get_sharding_parallel_world_size()
self._sharding_rank = self._hcg.get_sharding_parallel_rank()
# logic partitioning
self._build_sharding_mapping()
# actually create opt ops
self._buid_inner_optimizer()
def clear_grad(self):
"""
Clear gradients for all parameters in the model.
"""
for p in self._parameter_list:
if not p.stop_gradient:
p.clear_gradient()
def _build_sharding_mapping(self):
self._rank2params = self._partition_parameters()
self._param2rank = self._map_param_to_rank()
def _partition_parameters(self):
"""
Partitions parameters among sharding ranks.
Return:
Dict[int, List]
"""
# TODO(JZ-LIANG) support multiple partition methods
# method1: greedy, even but unordered
# method2: roughly even, but ordered
mapping = {}
for rank_ in range(self._sharding_world_size):
mapping[rank_] = []
sizes = [0] * self._sharding_world_size
for param in self._parameter_list:
rank = sizes.index(min(sizes))
mapping[rank].append(param)
numel = reduce(lambda x, y: x * y, param.shape)
assert numel > 0, "numel of param [{}] should be larger than 0, but it is [{}]".format(
param.name, numel)
sizes[rank] += numel
return mapping
def _map_param_to_rank(self):
"""
Map each parameter to the sharding rank that holds it.
Return:
Dict[str, int]
"""
mapping = {}
for rank, params in self._rank2params.items():
for param in params:
mapping[param.name] = rank
return mapping
def _buid_inner_optimizer(self):
# we rely on the inner opt to determine whether a parameter is stop_gradient or not:
# it creates the moments and the
# update-related ops: clip, regularization, opt
self._inner_optimizer = self._inner_optimizer_class(
parameters=self._rank2params[self._sharding_rank],
**self._inner_optimizer_kargs)
def _sharding_sync_parameters(self):
"""
sync parameters across the sharding group
"""
# TODO speed up this function
logger.debug("sharding start sync parameters")
with framework.no_grad():
# TODO detach may not be needed (?)
for rank, params in self._rank2params.items():
for param in params:
paddle.distributed.broadcast(
param,
# the collective API needs the src rank to be the global rank id
# instead of the relative logical rank id within the group
src=self._hcg.get_sharding_parallel_group().ranks[rank],
group=self._hcg.get_sharding_parallel_group(),
use_calc_stream=True)
def _update_trainable(self):
"""
allow the user to update the trainable parameter list during training
"""
raise NotImplementedError
def minimize(self,
loss,
startup_program=None,
parameters=None,
no_grad_set=None):
# NOTE in dygraph mode, the only difference between step and minimize is that minimize
# allows the user to customize the parameters to update on each step
input_param_names = set([param.name for param in parameters])
parameters = list(
filter(lambda x: x.name in input_param_names, self._rank2params[
self._sharding_rank]))
result = self._inner_optimizer.minimize(loss, startup_program,
parameters, no_grad_set)
# sync parameters across sharding ranks
self._sharding_sync_parameters()
return result
def step(self):
# TODO Check whether the model trainable param changed and update state accordingly
# actually updating
self._inner_optimizer.step()
# sync parameters across sharding ranks
self._sharding_sync_parameters()
# TODO is it a good way to make _grad_clip a property
@property
def _grad_clip(self):
assert self._inner_optimizer is not None, "inner opt of sharding is not initialized."
return self._inner_optimizer._grad_clip
def __getattr__(self, item):
return getattr(self._inner_optimizer, item)
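
To make the greedy placement in _partition_parameters above concrete, a small framework-free sketch (the parameter shapes and the helper name are made up for this illustration):

from functools import reduce

def greedy_partition(param_shapes, sharding_world_size):
    # assign each parameter to the sharding rank with the smallest
    # accumulated numel so far, mirroring _partition_parameters
    mapping = {rank: [] for rank in range(sharding_world_size)}
    sizes = [0] * sharding_world_size
    for shape in param_shapes:
        rank = sizes.index(min(sizes))
        mapping[rank].append(shape)
        sizes[rank] += reduce(lambda x, y: x * y, shape)
    return mapping

print(greedy_partition([[20, 10], [10, 8], [8, 10], [10]], 2))
# {0: [[20, 10]], 1: [[10, 8], [8, 10], [10]]}

Each rank then builds its inner optimizer only over its own partition, so the optimizer states stay sharded across the group, and _sharding_sync_parameters broadcasts the updated parameters back to all ranks after each step.
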
......@@ -17,7 +17,7 @@ import sys
import paddle
from paddle.optimizer import Optimizer
from paddle.fluid.clip import ClipGradByGlobalNorm
from ...utils.hybrid_parallel_util import fused_allreduce_gradients
from ...utils.hybrid_parallel_util import fused_allreduce_gradients, sharding_reduce_gradients
from ...base.topology import ParallelMode
from paddle.fluid.dygraph import base as imperative_base
from paddle.fluid import framework
......@@ -98,6 +98,9 @@ class HybridParallelOptimizer:
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
self._sharding_enable = (
self._hcg.get_sharding_parallel_world_size() > 1)
if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and not self._use_dp_mode:
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
......@@ -108,6 +111,11 @@ class HybridParallelOptimizer:
@imperative_base.no_grad
@framework.dygraph_only
def step(self):
# NOTE: the global parameter list should be used here
if self._sharding_enable:
sharding_reduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
if not self._use_dp_mode and self._need_dp:
fused_allreduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
......@@ -119,15 +127,19 @@ class HybridParallelOptimizer:
startup_program=None,
parameters=None,
no_grad_set=None):
assert isinstance(loss, Variable), "The loss should be a Tensor."
parameter_list = parameters if parameters \
else self._parameter_list
else self._inner_opt._parameter_list
# NOTE: the global parameter list should be used here
if self._sharding_enable:
sharding_reduce_gradients(
list(self._inner_opt._parameter_list), self._hcg)
if not self._use_dp_mode and self._need_dp:
fused_allreduce_gradients(list(parameter_list), self._hcg)
return self._inner_opt.minimize(loss, startup_program, parameters,
return self._inner_opt.minimize(loss, startup_program, parameter_list,
no_grad_set)
def __getattr__(self, item):
......
......@@ -24,5 +24,6 @@ from .parallel_layers import model_parallel_random_seed # noqa: F401
from .parallel_layers import get_rng_state_tracker # noqa: F401
from .tensor_parallel import TensorParallel # noqa: F401
from .pipeline_parallel import PipelineParallel # noqa: F401
from .sharding_parallel import ShardingParallel # noqa: F401
__all__ = []
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.fluid.dygraph.layers import Layer
from .meta_parallel_base import MetaParallelBase
from ..utils.hybrid_parallel_util import broadcast_sharding_parameters
from ..utils.log_util import logger
__all__ = []
class ShardingParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs):
super(ShardingParallel, self).__init__(layers, hcg, **kwargs)
def _prepare_for_model(self):
logger.info("start broadcast sharding parameters")
broadcast_sharding_parameters(self._layers, self._hcg)
# TODO (JZ-LIANG) to support Sharding-DP
logger.info("sharding's parameters is ready")
......@@ -119,3 +119,46 @@ def fused_allreduce_gradients(parameter_list, hcg):
logger.debug("dp start fuse allreduce gradients")
with framework.no_grad():
_apply_collective_grads(parameter_list, data_parallel_group)
def sharding_reduce_gradients(parameter_list, hcg):
# TODO allreduce --> reduce
# TODO merge grad / nrank with dp
logger.debug("sharding start gradients sync")
with framework.no_grad():
sharding_nrank = hcg.get_sharding_parallel_group().nranks
for param in parameter_list:
if param.trainable and (param._grad_ivar() is not None):
g_var = param._grad_ivar()
# need to use trace_op to allreduce here
# paddle.distributed.all_reduce(
# g_var, group=hcg.get_sharding_parallel_group(), use_calc_stream=True)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="c_allreduce_sum",
inputs={'X': g_var},
outputs={'Out': g_var},
attrs={
'ring_id': hcg.get_sharding_parallel_group().id,
'use_calc_stream': True
})
# grad /= sharding_nrank (average the summed gradients)
div_factor = paddle.to_tensor(sharding_nrank, dtype=g_var.dtype)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': g_var,
'Y': div_factor},
outputs={'Out': g_var},
attrs={'axis': -1})
def broadcast_sharding_parameters(model, hcg):
# TODO: to save memory, use un-fused broadcast to avoid potential OOM
logger.debug("sharding start init parameters sync")
sharding_parallel_group = hcg.get_sharding_parallel_group()
src_rank = hcg.get_sharding_parallel_group_src_rank()
sync_params_buffers(
model, sharding_parallel_group, src_rank, is_model_parallel=False)
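
Stripped of the collective plumbing, sharding_reduce_gradients above computes the mean gradient over the sharding group: a c_allreduce_sum followed by an elementwise_div by the group size. A tiny numeric sketch of that reduction, with made-up per-rank gradients:

import numpy as np

# per-rank gradients of one parameter across a sharding group of size 4
per_rank_grads = [np.array([1.0, 2.0]), np.array([3.0, 4.0]),
                  np.array([5.0, 6.0]), np.array([7.0, 8.0])]
sharding_nrank = len(per_rank_grads)

summed = sum(per_rank_grads)          # what c_allreduce_sum leaves on every rank
mean_grad = summed / sharding_nrank   # what the traced elementwise_div produces
print(mean_grad)                      # [4. 5.]
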
......@@ -25,6 +25,7 @@ list(APPEND DIST_TEST_OPS test_parallel_dygraph_control_flow)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_dataparallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_pipeline_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_tensor_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_sharding_parallel)
list(APPEND DIST_TEST_OPS test_parallel_dygraph_mp_layers)
set(MIXED_DIST_TEST_OPS ${DIST_TEST_OPS})
#remove distribute unittests.
......@@ -185,6 +186,7 @@ if ((NOT WITH_GPU) AND (NOT WITH_ROCM))
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_dataparallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_pipeline_parallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_tensor_parallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_sharding_parallel)
list(REMOVE_ITEM TEST_OPS test_parallel_dygraph_mp_layers)
LIST(REMOVE_ITEM TEST_OPS test_imperative_auto_mixed_precision)
LIST(REMOVE_ITEM TEST_OPS test_fleet_base_single)
......@@ -882,6 +884,7 @@ if(WITH_DISTRIBUTE AND WITH_GPU AND WITH_NCCL)
set_tests_properties(test_parallel_dygraph_control_flow PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_pipeline_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_tensor_parallel PROPERTIES TIMEOUT 200)
set_tests_properties(test_parallel_dygraph_sharding_parallel PROPERTIES TIMEOUT 120)
set_tests_properties(test_parallel_dygraph_mp_layers PROPERTIES TIMEOUT 120)
if(${NCCL_VERSION} VERSION_GREATER_EQUAL 2212)
set_tests_properties(test_parallel_dygraph_sparse_embedding PROPERTIES TIMEOUT 120)
......
......@@ -21,7 +21,8 @@ from paddle.distributed import fleet
class TestNewGroupAPI(object):
def __init__(self):
paddle.distributed.init_parallel_env()
topo = fleet.CommunicateTopology(["data", "model", "pipe"], [2, 1, 1])
topo = fleet.CommunicateTopology(["data", "model", "sharding", "pipe"],
[2, 1, 1, 1])
self.hcg = fleet.HybridCommunicateGroup(topo)
d1 = np.array([1, 2, 3])
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import division
from __future__ import print_function
import paddle
import numpy as np
import random
import paddle.distributed as dist
import paddle.fluid as fluid
import paddle.distributed.fleet as fleet
from paddle.io import DataLoader, Dataset
from paddle.distributed.fleet.meta_optimizers.dygraph_optimizer.dygraph_sharding_optimizer import DygraphShardingOptimizer
import unittest
vocab_size = 20
hidden_size = 10
inner_size = 8
output_size = 10
seq_length = 2
batch_size = 4
STEPS = 10
def parallel_matmul(lm_output, logit_weights, parallel_output):
hcg = fleet.get_hybrid_communicate_group()
model_parallel_group = hcg.get_model_parallel_group()
world_size = hcg.get_model_parallel_world_size()
rank = hcg.get_model_parallel_rank()
if world_size > 1:
input_parallel = paddle.distributed.collective._c_identity(
lm_output, group=model_parallel_group)
logits = paddle.matmul(input_parallel, logit_weights, transpose_y=True)
if parallel_output:
return logits
return paddle.distributed.collective._c_concat(
logits, group=model_parallel_group)
else:
logits = paddle.matmul(lm_output, logit_weights, transpose_y=True)
return logits
class SimpleMPNet(fluid.dygraph.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1,
np_fc2, mp_id):
super(SimpleMPNet, self).__init__()
if mp_id == 0:
init_fc1_data = np_fc1[:, :(inner_size // 2)]
init_fc2_data = np_fc2[:(inner_size // 2), :]
else:
init_fc1_data = np_fc1[:, (inner_size // 2):]
init_fc2_data = np_fc2[(inner_size // 2):, :]
self.linear1 = fleet.meta_parallel.ColumnParallelLinear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc1_data)),
gather_output=False,
has_bias=True)
self.linear2 = fleet.meta_parallel.RowParallelLinear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(init_fc2_data)),
input_is_parallel=True,
has_bias=True)
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.embedding = fleet.meta_parallel.VocabParallelEmbedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5))
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = parallel_matmul(x, self.embedding.weight, False)
return x
class SimpleDPNet(fluid.dygraph.Layer):
def __init__(self, vocab_size, hidden_size, inner_size, output_size, np_fc1,
np_fc2):
super(SimpleDPNet, self).__init__()
self.linear1 = paddle.nn.Linear(
hidden_size,
inner_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc1)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.linear2 = paddle.nn.Linear(
inner_size,
hidden_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Assign(np_fc2)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.linear3 = paddle.nn.Linear(
hidden_size,
output_size,
weight_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)),
bias_attr=paddle.framework.ParamAttr(
initializer=paddle.nn.initializer.Constant(0.0)))
self.embedding = paddle.nn.Embedding(
vocab_size,
hidden_size,
weight_attr=paddle.nn.initializer.Constant(value=0.5))
def forward(self, x):
x = self.embedding(x)
x = self.linear1(x)
x = self.linear2(x)
x = self.linear3(x)
x = paddle.matmul(x, self.embedding.weight, transpose_y=True)
return x
class TestDistMPTraning(unittest.TestCase):
def setUp(self):
random.seed(2021)
np.random.seed(2021)
paddle.seed(2021)
self.strategy = fleet.DistributedStrategy()
self.strategy.hybrid_configs = {
"sharding_degree": 2,
"dp_degree": 1,
"mp_degree": 1,
"pp_degree": 1,
}
fleet.init(is_collective=True, strategy=self.strategy)
self.data = [
np.random.randint(0, vocab_size, (
batch_size,
seq_length, )) for _ in range(STEPS)
]
def train_batch(self, batch, model, optimizer):
output = model(batch)
loss = output.mean()
loss.backward() # do backward
optimizer.step() # update parameters
optimizer.clear_grad()
return loss
def build_optimizer(self,
model,
strategy=None,
is_sharding=True,
Optimizer="adam"):
if Optimizer == "adam":
if is_sharding:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.Adam,
learning_rate=0.001,
weight_decay=0.00001, )
else:
optimizer = paddle.optimizer.Adam(
parameters=model.parameters(),
learning_rate=0.001,
weight_decay=0.00001, )
else:
if is_sharding:
optimizer = DygraphShardingOptimizer(
hcg=fleet.get_hybrid_communicate_group(),
user_defined_strategy=strategy,
params=model.parameters(),
inner_optimizer_class=paddle.optimizer.Momentum,
learning_rate=0.001, )
else:
optimizer = paddle.optimizer.Momentum(
learning_rate=0.001, parameters=model.parameters())
return optimizer
def build_model_optimizer(self, Optimizer="adam"):
hcg = fleet.get_hybrid_communicate_group()
word_size = hcg.get_model_parallel_world_size()
sharding_id = hcg.get_sharding_parallel_rank()
dp_id = hcg.get_data_parallel_rank()
rank_id = dist.get_rank()
np_fc1 = np.random.random_sample((hidden_size, inner_size))
np_fc2 = np.random.random_sample((inner_size, hidden_size))
model_a = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
np_fc1, np_fc2)
optimizer_a = self.build_optimizer(
model_a,
strategy=self.strategy,
is_sharding=True,
Optimizer=Optimizer)
model_a = fleet.distributed_model(model_a)
optimizer_a = fleet.distributed_optimizer(optimizer_a)
model_b = SimpleDPNet(vocab_size, hidden_size, inner_size, output_size,
np_fc1, np_fc2)
optimizer_b = self.build_optimizer(
model_b,
strategy=self.strategy,
is_sharding=False,
Optimizer=Optimizer)
return model_a, optimizer_a, model_b, optimizer_b
def sharding_model(self, Optimizer, sharded_accumulators):
model_a, optimizer_a, model_b, optimizer_b = self.build_model_optimizer(
Optimizer=Optimizer)
self.assertTrue(
isinstance(optimizer_a._inner_opt, DygraphShardingOptimizer))
for idx in range(STEPS):
if idx == 2 and paddle.distributed.get_rank() == 0:
self.assertTrue(
set(optimizer_a._inner_opt._inner_optimizer.state_dict()
.keys()) == sharded_accumulators)
if paddle.distributed.get_rank() == 0:
batch_sharding = paddle.to_tensor(self.data[idx][:2])
else:
batch_sharding = paddle.to_tensor(self.data[idx][2:])
batch_single = paddle.to_tensor(self.data[idx])
loss_a = self.train_batch(batch_sharding, model_a, optimizer_a)
loss_b = self.train_batch(batch_single, model_b, optimizer_b)
for j in range(len(model_a.parameters())):
np.testing.assert_allclose(
model_a.parameters()[j].numpy(),
model_b.parameters()[j].numpy(),
rtol=1e-6)
def test_sharding_adam(self):
sharded_accumulators = set([
'linear_0.w_0_moment1_0', 'linear_1.b_0_moment1_0',
'linear_2.b_0_moment1_0', 'embedding_0.w_0_moment1_0',
'linear_0.w_0_moment2_0', 'linear_1.b_0_moment2_0',
'linear_2.b_0_moment2_0', 'embedding_0.w_0_moment2_0',
'linear_0.w_0_beta1_pow_acc_0', 'linear_1.b_0_beta1_pow_acc_0',
'linear_2.b_0_beta1_pow_acc_0', 'embedding_0.w_0_beta1_pow_acc_0',
'linear_0.w_0_beta2_pow_acc_0', 'linear_1.b_0_beta2_pow_acc_0',
'linear_2.b_0_beta2_pow_acc_0', 'embedding_0.w_0_beta2_pow_acc_0'
])
self.sharding_model(
Optimizer="adam", sharded_accumulators=sharded_accumulators)
def test_sharding_momentum(self):
sharded_accumulators = set([
'linear_6.w_0_velocity_0', 'linear_7.b_0_velocity_0',
'linear_8.b_0_velocity_0', 'embedding_2.w_0_velocity_0'
])
self.sharding_model(
Optimizer="Momentum", sharded_accumulators=sharded_accumulators)
if __name__ == "__main__":
unittest.main()
......@@ -79,6 +79,99 @@ class TestCommunicateTopology(unittest.TestCase):
self.assertEqual(topo.get_dim_size("mp"), 2)
self.assertEqual(topo.get_dim_size("pp"), 2)
def test_topology_4D(self):
topo = fleet.CommunicateTopology(["dp", "pp", "sharding", "mp"],
[2, 2, 2, 2])
# test get_comm_list
dp_comm_list = [[0, 8], [1, 9], [2, 10], [3, 11], [4, 12], [5, 13],
[6, 14], [7, 15]]
mp_comm_list = [[0, 1], [2, 3], [4, 5], [6, 7], [8, 9], [10, 11],
[12, 13], [14, 15]]
pp_comm_list = [[0, 4], [1, 5], [2, 6], [3, 7], [8, 12], [9, 13],
[10, 14], [11, 15]]
sharding_comm_list = [[0, 2], [1, 3], [4, 6], [5, 7], [8, 10], [9, 11],
[12, 14], [13, 15]]
np.testing.assert_array_equal(dp_comm_list, topo.get_comm_list("dp"))
np.testing.assert_array_equal(mp_comm_list, topo.get_comm_list("mp"))
np.testing.assert_array_equal(pp_comm_list, topo.get_comm_list("pp"))
np.testing.assert_array_equal(sharding_comm_list,
topo.get_comm_list("sharding"))
# test get_hybrid_group_names
parallel_names = ["dp", "pp", "sharding", "mp"]
np.testing.assert_array_equal(parallel_names,
topo.get_hybrid_group_names())
# test get_dims
np.testing.assert_array_equal(2, topo.get_dim("dp"))
np.testing.assert_array_equal(2, topo.get_dim("mp"))
np.testing.assert_array_equal(2, topo.get_dim("pp"))
np.testing.assert_array_equal(2, topo.get_dim("sharding"))
# test world size
self.assertEqual(topo.world_size(), 16)
# test get_rank
self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, mp=0), 0)
self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=0, mp=1), 1)
self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, mp=0), 2)
self.assertEqual(topo.get_rank(dp=0, pp=0, sharding=1, mp=1), 3)
self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, mp=0), 4)
self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=0, mp=1), 5)
self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, mp=0), 6)
self.assertEqual(topo.get_rank(dp=0, pp=1, sharding=1, mp=1), 7)
self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, mp=0), 8)
self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=0, mp=1), 9)
self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, mp=0), 10)
self.assertEqual(topo.get_rank(dp=1, pp=0, sharding=1, mp=1), 11)
self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, mp=0), 12)
self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=0, mp=1), 13)
self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, mp=0), 14)
self.assertEqual(topo.get_rank(dp=1, pp=1, sharding=1, mp=1), 15)
# test get_coord
self.assertEqual(topo.get_coord(0), topo.coordinate(0, 0, 0, 0))
self.assertEqual(topo.get_coord(1), topo.coordinate(0, 0, 0, 1))
self.assertEqual(topo.get_coord(2), topo.coordinate(0, 0, 1, 0))
self.assertEqual(topo.get_coord(3), topo.coordinate(0, 0, 1, 1))
self.assertEqual(topo.get_coord(4), topo.coordinate(0, 1, 0, 0))
self.assertEqual(topo.get_coord(5), topo.coordinate(0, 1, 0, 1))
self.assertEqual(topo.get_coord(6), topo.coordinate(0, 1, 1, 0))
self.assertEqual(topo.get_coord(7), topo.coordinate(0, 1, 1, 1))
self.assertEqual(topo.get_coord(8), topo.coordinate(1, 0, 0, 0))
self.assertEqual(topo.get_coord(9), topo.coordinate(1, 0, 0, 1))
self.assertEqual(topo.get_coord(10), topo.coordinate(1, 0, 1, 0))
self.assertEqual(topo.get_coord(11), topo.coordinate(1, 0, 1, 1))
self.assertEqual(topo.get_coord(12), topo.coordinate(1, 1, 0, 0))
self.assertEqual(topo.get_coord(13), topo.coordinate(1, 1, 0, 1))
self.assertEqual(topo.get_coord(14), topo.coordinate(1, 1, 1, 0))
self.assertEqual(topo.get_coord(15), topo.coordinate(1, 1, 1, 1))
# test get_axis_list
self.assertEqual(topo.get_axis_list("dp", 0), [0, 1, 2, 3, 4, 5, 6, 7])
self.assertEqual(
topo.get_axis_list("dp", 1), [8, 9, 10, 11, 12, 13, 14, 15])
self.assertEqual(
topo.get_axis_list("mp", 0), [0, 2, 4, 6, 8, 10, 12, 14])
self.assertEqual(
topo.get_axis_list("mp", 1), [1, 3, 5, 7, 9, 11, 13, 15])
self.assertEqual(
topo.get_axis_list("pp", 0), [0, 1, 2, 3, 8, 9, 10, 11])
self.assertEqual(
topo.get_axis_list("pp", 1), [4, 5, 6, 7, 12, 13, 14, 15])
self.assertEqual(
topo.get_axis_list("sharding", 0), [0, 1, 4, 5, 8, 9, 12, 13])
self.assertEqual(
topo.get_axis_list("sharding", 1), [2, 3, 6, 7, 10, 11, 14, 15])
# test get_dim_size
self.assertEqual(topo.get_dim_size("dp"), 2)
self.assertEqual(topo.get_dim_size("mp"), 2)
self.assertEqual(topo.get_dim_size("pp"), 2)
self.assertEqual(topo.get_dim_size("sharding"), 2)
if __name__ == '__main__':
unittest.main()
......@@ -124,6 +124,8 @@ class TestMultipleGpus(unittest.TestCase):
break
time.sleep(3)
class TestDataParallelGradientCheck(TestMultipleGpus):
def test_multiple_gpus_dynamic(self):
self.run_mnist_2gpu('parallel_dygraph_gradient_check.py')
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
from test_parallel_dygraph_dataparallel import TestMultipleGpus
class TestHybridParallel(TestMultipleGpus):
# check sharding logic as well as accuracy against single-card mode
def test_hybrid_parallel_sharding_logic(self):
self.run_mnist_2gpu('hybrid_parallel_sharding_model.py')
if __name__ == "__main__":
unittest.main()