未验证 提交 c809530e 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel]Fix precision problem of model parallel (#32897)

* fix precision of mp

* fix bug of seed

* fix dp

* print group
上级 906db719
......@@ -141,6 +141,7 @@ message PipelineConfig {
message TensorParallelConfig {
optional int32 tensor_parallel_degree = 1 [ default = 1 ];
optional int32 tensor_init_seed = 2 [ default = -1 ];
}
message DistributedStrategy {
......
......@@ -99,6 +99,13 @@ class Group():
else:
return -1
def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id)
debug_str += ", ".join(map(str, self.ranks))
debug_str += ". "
return debug_str
_global_env = None
......
......@@ -949,6 +949,8 @@ class DistributedStrategy(object):
**Notes**:
**Detailed arguments for tensor_parallel_configs**
**tensor_parallel_degree**: degree of tensor parallel
**tensor_init_seed**: parameter initialization random seed
Examples:
......@@ -957,7 +959,8 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4}
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
"tensor_init_seed": 123}
"""
return get_msg_dict(self.strategy.tensor_parallel_configs)
......
......@@ -17,6 +17,7 @@ import copy
import warnings
import paddle
import os
import numpy as np
from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
......@@ -28,7 +29,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper
from . import topology as tp
from .topology import ParallelMode
from ..meta_parallel import ModelParallel
from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel
from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler
......@@ -279,6 +280,14 @@ class Fleet(object):
self._hcg = tp.HybridCommunicateGroup(self._topology)
if self.mp_degree > 1:
tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
if tensor_init_seed == -1:
model_parallel_random_seed()
else:
model_parallel_random_seed(tensor_init_seed)
def get_hybrid_communicate_group(self):
assert self._hcg is not None
return self._hcg
......@@ -829,8 +838,8 @@ class Fleet(object):
last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy.
find_unused_parameters)
elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL:
distributed_model = ModelParallel(
elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
distributed_model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel(
......
......@@ -28,7 +28,7 @@ _HYBRID_PARALLEL_GROUP = None
class ParallelMode(object):
DATA_PARALLEL = 0
MODEL_PARALLEL = 1
TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2
......@@ -155,12 +155,12 @@ class HybridCommunicateGroup(object):
_HYBRID_PARALLEL_GROUP = self
def get_parallel_mode(self):
# there are three modes : DataParallel / ModelParallel / PipelineParallel
# there are three modes : DataParallel / TensorParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed
return ParallelMode.MODEL_PARALLEL
return ParallelMode.TENSOR_PARALLEL
elif self._pp_degree > 1:
return ParallelMode.PIPELINE_PARALLEL
......
......@@ -31,7 +31,7 @@ class HybridParallelGradScaler:
self._scaler = scaler
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
def scale(self, var):
return self._scaler.scale(var)
......
......@@ -90,12 +90,12 @@ class HybridParallelOptimizer:
self._strategy = strategy
self._hcg = hcg
self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL)
self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and self._is_mp:
logger.warning("using ClipGradByGlobalNorm in ModelParallel, the origin " \
logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
"optmizer'grad clip will be changed.")
self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg)
......
......@@ -20,7 +20,7 @@ from .parallel_layers import PipelineLayer # noqa: F401
from .parallel_layers import RNGStatesTracker # noqa: F401
from .parallel_layers import model_parallel_random_seed # noqa: F401
from .parallel_layers import get_rng_state_tracker # noqa: F401
from .model_parallel import ModelParallel # noqa: F401
from .tensor_parallel import TensorParallel # noqa: F401
from .pipeline_parallel import PipelineParallel # noqa: F401
__all__ = []
......@@ -41,6 +41,7 @@ class VocabParallelEmbedding(Layer):
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)
per_part_size = (
num_embeddings + self.world_size - 1) // self.world_size
......@@ -50,16 +51,36 @@ class VocabParallelEmbedding(Layer):
per_part_size += 1 # make the last row as the padding index
self.per_part_size = per_part_size
self.embedding = paddle.nn.Embedding(
per_part_size,
embedding_dim,
padding_idx=per_part_size - 1,
sparse=False,
weight_attr=weight_attr,
name=name)
self.embedding.weight.is_distributed = True
self._dtype = self._helper.get_default_dtype()
self._size = [per_part_size, embedding_dim]
self._weight_attr = weight_attr
self._name = name
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight[per_part_size - 1] = 0.0
self.weight.is_distributed = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=[num_embeddings, embedding_dim],
dtype=self._dtype,
is_bias=False)
def forward(self, x):
if not self.is_mp:
return F.embedding(
x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)
origin_input_shape = x.shape
if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1)
......@@ -72,13 +93,18 @@ class VocabParallelEmbedding(Layer):
if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1)
emb_out = self.embedding(x_shard)
if self.world_size > 1:
emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
emb_out = F.embedding(
x_shard,
weight=self.weight,
padding_idx=self.per_part_size - 1,
sparse=False,
name=self._name)
emb_out = paddle.distributed.collective._mp_allreduce(
emb_out,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
return emb_out
......@@ -96,8 +122,9 @@ class ColumnParallelLinear(Layer):
)
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
)
self._name = name
self.is_mp = (self.world_size > 1)
self.name = name
self.gather_output = gather_output
assert out_features % self.world_size == 0, (
"Number of column of the weight for linear ({}) must be"
......@@ -108,10 +135,20 @@ class ColumnParallelLinear(Layer):
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype)
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True
if has_bias:
......@@ -119,18 +156,24 @@ class ColumnParallelLinear(Layer):
self.bias = self.create_parameter(
shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype)
dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True
else:
self.bias = None
def forward(self, x):
# use inner api to process identity
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
if self.is_mp:
input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group)
else:
input_parallel = x
output_parallel = F.linear(
input_parallel, self.weight, self.bias, name=self.name)
if self.gather_output:
input_parallel, self.weight, self.bias, name=self._name)
if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat(
output_parallel,
nranks=self.world_size,
......@@ -155,7 +198,7 @@ class RowParallelLinear(Layer):
self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype()
self.name = name
self._name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
)
......@@ -163,6 +206,7 @@ class RowParallelLinear(Layer):
)
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
self.is_mp = (self.world_size > 1)
assert in_features % self.world_size == 0, (
"Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})".format(in_features,
......@@ -170,22 +214,33 @@ class RowParallelLinear(Layer):
self.input_size_per_partition = in_features // self.world_size
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype)
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True
if has_bias:
self.bias = self.create_parameter(
shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype)
dtype=self._dtype,
is_bias=True)
else:
self.bias = None
def forward(self, x):
if self.input_is_parallel:
if self.input_is_parallel or (not self.is_mp):
input_parallel = x
else:
# split last dim
......@@ -195,12 +250,16 @@ class RowParallelLinear(Layer):
nranks=self.world_size,
group=self.model_parallel_group)
output_parallel = F.linear(input_parallel, self.weight, name=self.name)
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
output_parallel = F.linear(input_parallel, self.weight, name=self._name)
if self.is_mp:
output_ = paddle.distributed.collective._mp_allreduce(
output_parallel,
group=self.model_parallel_group,
use_calc_stream=True,
use_model_parallel=True)
else:
output_ = output_parallel
output = output_ + self.bias if self.bias is not None else output_
return output
......@@ -14,6 +14,7 @@
import paddle
import contextlib
import numpy as np
__all__ = []
......@@ -65,14 +66,18 @@ def get_rng_state_tracker():
return RNG_STATE_TRACKER
def model_parallel_random_seed(seed=2048):
def model_parallel_random_seed(seed=None):
import paddle.distributed.fleet as fleet
hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank()
local_seed = seed + 1024 + rank
global_seed = seed
if seed:
global_seed = seed
local_seed = seed * 1024 + rank * 100
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
RNG_STATE_TRACKER.reset()
paddle.seed(global_seed)
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
......@@ -22,15 +22,15 @@ from ..utils.log_util import logger
__all__ = []
class ModelParallel(MetaParallelBase):
class TensorParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs):
super(ModelParallel, self).__init__(layers, hcg, **kwargs)
super(TensorParallel, self).__init__(layers, hcg, **kwargs)
def _prepare_for_model(self):
logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg)
logger.info("start broadcast mp parameters")
logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg)
logger.info("mp's parameters is ready")
......
......@@ -44,7 +44,15 @@ def _apply_collective_grads(parameters, comm_group):
for coalesced_grad, _, _ in coalesced_grads_and_vars:
# need to div nranks
coalesced_grad = coalesced_grad / comm_group.nranks
div_factor = paddle.to_tensor(
comm_group.nranks, dtype=coalesced_grad.dtype)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': coalesced_grad,
'Y': div_factor},
outputs={'Out': coalesced_grad},
attrs={'axis': -1})
paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
_split_tensors(coalesced_grads_and_vars)
......
......@@ -231,7 +231,7 @@ class TestDistTraning(unittest.TestCase):
# model_b
check_group = dist.new_group(list(range(self.model_parallel_size)))
integral_w = []
partial_w = model_a.embedding.embedding.weight.clone().detach()
partial_w = model_a.embedding.weight.clone().detach()
paddle.distributed.all_gather(integral_w, partial_w, group=check_group)
result_w = []
for idx in range(len(integral_w)):
......
......@@ -27,6 +27,7 @@ class TestNewGroupAPI(object):
def test_all(self):
gp = paddle.distributed.new_group([0, 1])
print("gp info:", gp)
print("test new group api ok")
tmp = np.array([0, 0, 0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册