未验证 提交 c809530e 编写于 作者: S ShenLiang 提交者: GitHub

[HybridParallel]Fix precision problem of model parallel (#32897)

* fix precision of mp

* fix bug of seed

* fix dp

* print group
上级 906db719
...@@ -141,6 +141,7 @@ message PipelineConfig { ...@@ -141,6 +141,7 @@ message PipelineConfig {
message TensorParallelConfig { message TensorParallelConfig {
optional int32 tensor_parallel_degree = 1 [ default = 1 ]; optional int32 tensor_parallel_degree = 1 [ default = 1 ];
optional int32 tensor_init_seed = 2 [ default = -1 ];
} }
message DistributedStrategy { message DistributedStrategy {
......
...@@ -99,6 +99,13 @@ class Group(): ...@@ -99,6 +99,13 @@ class Group():
else: else:
return -1 return -1
def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id)
debug_str += ", ".join(map(str, self.ranks))
debug_str += ". "
return debug_str
_global_env = None _global_env = None
......
...@@ -949,6 +949,8 @@ class DistributedStrategy(object): ...@@ -949,6 +949,8 @@ class DistributedStrategy(object):
**Notes**: **Notes**:
**Detailed arguments for tensor_parallel_configs** **Detailed arguments for tensor_parallel_configs**
**tensor_parallel_degree**: degree of tensor parallel **tensor_parallel_degree**: degree of tensor parallel
**tensor_init_seed**: parameter initialization random seed
Examples: Examples:
...@@ -957,7 +959,8 @@ class DistributedStrategy(object): ...@@ -957,7 +959,8 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy() strategy = fleet.DistributedStrategy()
strategy.tensor_parallel = True strategy.tensor_parallel = True
strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4} strategy.tensor_parallel_configs = {"tensor_parallel_degree": 4,
"tensor_init_seed": 123}
""" """
return get_msg_dict(self.strategy.tensor_parallel_configs) return get_msg_dict(self.strategy.tensor_parallel_configs)
......
...@@ -17,6 +17,7 @@ import copy ...@@ -17,6 +17,7 @@ import copy
import warnings import warnings
import paddle import paddle
import os import os
import numpy as np
from paddle.fluid.framework import dygraph_only from paddle.fluid.framework import dygraph_only
from paddle.fluid import compiler from paddle.fluid import compiler
from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase from .role_maker import UserDefinedRoleMaker, PaddleCloudRoleMaker, RoleMakerBase
...@@ -28,7 +29,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator ...@@ -28,7 +29,7 @@ from paddle.fluid.wrapped_decorator import wrap_decorator
from paddle.fluid.dygraph import parallel_helper from paddle.fluid.dygraph import parallel_helper
from . import topology as tp from . import topology as tp
from .topology import ParallelMode from .topology import ParallelMode
from ..meta_parallel import ModelParallel from ..meta_parallel import TensorParallel, model_parallel_random_seed
from ..meta_parallel import PipelineParallel from ..meta_parallel import PipelineParallel
from ..meta_optimizers import HybridParallelOptimizer from ..meta_optimizers import HybridParallelOptimizer
from ..meta_optimizers import HybridParallelGradScaler from ..meta_optimizers import HybridParallelGradScaler
...@@ -279,6 +280,14 @@ class Fleet(object): ...@@ -279,6 +280,14 @@ class Fleet(object):
self._hcg = tp.HybridCommunicateGroup(self._topology) self._hcg = tp.HybridCommunicateGroup(self._topology)
if self.mp_degree > 1:
tensor_parallel_configs = self._user_defined_strategy.tensor_parallel_configs
tensor_init_seed = tensor_parallel_configs["tensor_init_seed"]
if tensor_init_seed == -1:
model_parallel_random_seed()
else:
model_parallel_random_seed(tensor_init_seed)
def get_hybrid_communicate_group(self): def get_hybrid_communicate_group(self):
assert self._hcg is not None assert self._hcg is not None
return self._hcg return self._hcg
...@@ -829,8 +838,8 @@ class Fleet(object): ...@@ -829,8 +838,8 @@ class Fleet(object):
last_comm_group_size_MB, last_comm_group_size_MB,
find_unused_parameters=self._user_defined_strategy. find_unused_parameters=self._user_defined_strategy.
find_unused_parameters) find_unused_parameters)
elif self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL: elif self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL:
distributed_model = ModelParallel( distributed_model = TensorParallel(
model, self._hcg, strategy=self._user_defined_strategy) model, self._hcg, strategy=self._user_defined_strategy)
elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL: elif self._hcg.get_parallel_mode() == ParallelMode.PIPELINE_PARALLEL:
distributed_model = PipelineParallel( distributed_model = PipelineParallel(
......
...@@ -28,7 +28,7 @@ _HYBRID_PARALLEL_GROUP = None ...@@ -28,7 +28,7 @@ _HYBRID_PARALLEL_GROUP = None
class ParallelMode(object): class ParallelMode(object):
DATA_PARALLEL = 0 DATA_PARALLEL = 0
MODEL_PARALLEL = 1 TENSOR_PARALLEL = 1
PIPELINE_PARALLEL = 2 PIPELINE_PARALLEL = 2
...@@ -155,12 +155,12 @@ class HybridCommunicateGroup(object): ...@@ -155,12 +155,12 @@ class HybridCommunicateGroup(object):
_HYBRID_PARALLEL_GROUP = self _HYBRID_PARALLEL_GROUP = self
def get_parallel_mode(self): def get_parallel_mode(self):
# there are three modes : DataParallel / ModelParallel / PipelineParallel # there are three modes : DataParallel / TensorParallel / PipelineParallel
if self._mp_degree == 1 and self._pp_degree == 1: if self._mp_degree == 1 and self._pp_degree == 1:
return ParallelMode.DATA_PARALLEL return ParallelMode.DATA_PARALLEL
elif self._mp_degree > 1 and self._pp_degree == 1: elif self._mp_degree > 1 and self._pp_degree == 1:
# initialize the seed # initialize the seed
return ParallelMode.MODEL_PARALLEL return ParallelMode.TENSOR_PARALLEL
elif self._pp_degree > 1: elif self._pp_degree > 1:
return ParallelMode.PIPELINE_PARALLEL return ParallelMode.PIPELINE_PARALLEL
......
...@@ -31,7 +31,7 @@ class HybridParallelGradScaler: ...@@ -31,7 +31,7 @@ class HybridParallelGradScaler:
self._scaler = scaler self._scaler = scaler
self._hcg = hcg self._hcg = hcg
self._is_mp = ( self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL) self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
def scale(self, var): def scale(self, var):
return self._scaler.scale(var) return self._scaler.scale(var)
......
...@@ -90,12 +90,12 @@ class HybridParallelOptimizer: ...@@ -90,12 +90,12 @@ class HybridParallelOptimizer:
self._strategy = strategy self._strategy = strategy
self._hcg = hcg self._hcg = hcg
self._is_mp = ( self._is_mp = (
self._hcg.get_parallel_mode() == ParallelMode.MODEL_PARALLEL) self._hcg.get_parallel_mode() == ParallelMode.TENSOR_PARALLEL)
self._need_dp = (self._hcg.get_data_parallel_world_size() > 1) self._need_dp = (self._hcg.get_data_parallel_world_size() > 1)
if isinstance(self._inner_opt._grad_clip, if isinstance(self._inner_opt._grad_clip,
ClipGradByGlobalNorm) and self._is_mp: ClipGradByGlobalNorm) and self._is_mp:
logger.warning("using ClipGradByGlobalNorm in ModelParallel, the origin " \ logger.warning("using ClipGradByGlobalNorm in TensorParallel, the origin " \
"optmizer'grad clip will be changed.") "optmizer'grad clip will be changed.")
self._inner_opt._grad_clip = HybridParallelClipGrad( self._inner_opt._grad_clip = HybridParallelClipGrad(
self._inner_opt._grad_clip, hcg) self._inner_opt._grad_clip, hcg)
......
...@@ -20,7 +20,7 @@ from .parallel_layers import PipelineLayer # noqa: F401 ...@@ -20,7 +20,7 @@ from .parallel_layers import PipelineLayer # noqa: F401
from .parallel_layers import RNGStatesTracker # noqa: F401 from .parallel_layers import RNGStatesTracker # noqa: F401
from .parallel_layers import model_parallel_random_seed # noqa: F401 from .parallel_layers import model_parallel_random_seed # noqa: F401
from .parallel_layers import get_rng_state_tracker # noqa: F401 from .parallel_layers import get_rng_state_tracker # noqa: F401
from .model_parallel import ModelParallel # noqa: F401 from .tensor_parallel import TensorParallel # noqa: F401
from .pipeline_parallel import PipelineParallel # noqa: F401 from .pipeline_parallel import PipelineParallel # noqa: F401
__all__ = [] __all__ = []
...@@ -41,6 +41,7 @@ class VocabParallelEmbedding(Layer): ...@@ -41,6 +41,7 @@ class VocabParallelEmbedding(Layer):
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank() self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
self.origin_num_embeddings = num_embeddings self.origin_num_embeddings = num_embeddings
self.is_mp = (self.world_size > 1)
per_part_size = ( per_part_size = (
num_embeddings + self.world_size - 1) // self.world_size num_embeddings + self.world_size - 1) // self.world_size
...@@ -50,16 +51,36 @@ class VocabParallelEmbedding(Layer): ...@@ -50,16 +51,36 @@ class VocabParallelEmbedding(Layer):
per_part_size += 1 # make the last row as the padding index per_part_size += 1 # make the last row as the padding index
self.per_part_size = per_part_size self.per_part_size = per_part_size
self.embedding = paddle.nn.Embedding( self._dtype = self._helper.get_default_dtype()
per_part_size, self._size = [per_part_size, embedding_dim]
embedding_dim, self._weight_attr = weight_attr
padding_idx=per_part_size - 1, self._name = name
sparse=False,
weight_attr=weight_attr, if self.is_mp:
name=name) with get_rng_state_tracker().rng_state():
self.embedding.weight.is_distributed = True self.weight = self.create_parameter(
attr=self._weight_attr,
shape=self._size,
dtype=self._dtype,
is_bias=False)
self.weight[per_part_size - 1] = 0.0
self.weight.is_distributed = True
else:
self.weight = self.create_parameter(
attr=self._weight_attr,
shape=[num_embeddings, embedding_dim],
dtype=self._dtype,
is_bias=False)
def forward(self, x): def forward(self, x):
if not self.is_mp:
return F.embedding(
x,
weight=self.weight,
padding_idx=None,
sparse=False,
name=self._name)
origin_input_shape = x.shape origin_input_shape = x.shape
if len(origin_input_shape) == 2: if len(origin_input_shape) == 2:
x = paddle.unsqueeze(x, axis=-1) x = paddle.unsqueeze(x, axis=-1)
...@@ -72,8 +93,13 @@ class VocabParallelEmbedding(Layer): ...@@ -72,8 +93,13 @@ class VocabParallelEmbedding(Layer):
if len(origin_input_shape) == 2: if len(origin_input_shape) == 2:
x_shard = paddle.squeeze(x_shard, axis=-1) x_shard = paddle.squeeze(x_shard, axis=-1)
emb_out = self.embedding(x_shard) emb_out = F.embedding(
if self.world_size > 1: x_shard,
weight=self.weight,
padding_idx=self.per_part_size - 1,
sparse=False,
name=self._name)
emb_out = paddle.distributed.collective._mp_allreduce( emb_out = paddle.distributed.collective._mp_allreduce(
emb_out, emb_out,
group=self.model_parallel_group, group=self.model_parallel_group,
...@@ -96,8 +122,9 @@ class ColumnParallelLinear(Layer): ...@@ -96,8 +122,9 @@ class ColumnParallelLinear(Layer):
) )
self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size( self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
) )
self._name = name
self.is_mp = (self.world_size > 1)
self.name = name
self.gather_output = gather_output self.gather_output = gather_output
assert out_features % self.world_size == 0, ( assert out_features % self.world_size == 0, (
"Number of column of the weight for linear ({}) must be" "Number of column of the weight for linear ({}) must be"
...@@ -108,10 +135,20 @@ class ColumnParallelLinear(Layer): ...@@ -108,10 +135,20 @@ class ColumnParallelLinear(Layer):
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter( self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition], shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr, attr=self._weight_attr,
dtype=self._dtype) dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter(
shape=[in_features, self.output_size_per_partition],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True self.weight.is_distributed = True
if has_bias: if has_bias:
...@@ -119,18 +156,24 @@ class ColumnParallelLinear(Layer): ...@@ -119,18 +156,24 @@ class ColumnParallelLinear(Layer):
self.bias = self.create_parameter( self.bias = self.create_parameter(
shape=[self.output_size_per_partition], shape=[self.output_size_per_partition],
attr=paddle.nn.initializer.Constant(value=0.0), attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype) dtype=self._dtype,
is_bias=True)
self.bias.is_distributed = True self.bias.is_distributed = True
else: else:
self.bias = None self.bias = None
def forward(self, x): def forward(self, x):
# use inner api to process identity # use inner api to process identity
if self.is_mp:
input_parallel = paddle.distributed.collective._c_identity( input_parallel = paddle.distributed.collective._c_identity(
x, group=self.model_parallel_group) x, group=self.model_parallel_group)
else:
input_parallel = x
output_parallel = F.linear( output_parallel = F.linear(
input_parallel, self.weight, self.bias, name=self.name) input_parallel, self.weight, self.bias, name=self._name)
if self.gather_output:
if self.gather_output and self.is_mp:
output = paddle.distributed.collective._c_concat( output = paddle.distributed.collective._c_concat(
output_parallel, output_parallel,
nranks=self.world_size, nranks=self.world_size,
...@@ -155,7 +198,7 @@ class RowParallelLinear(Layer): ...@@ -155,7 +198,7 @@ class RowParallelLinear(Layer):
self.input_is_parallel = input_is_parallel self.input_is_parallel = input_is_parallel
self._weight_attr = weight_attr self._weight_attr = weight_attr
self._dtype = self._helper.get_default_dtype() self._dtype = self._helper.get_default_dtype()
self.name = name self._name = name
self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group( self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
) )
...@@ -163,6 +206,7 @@ class RowParallelLinear(Layer): ...@@ -163,6 +206,7 @@ class RowParallelLinear(Layer):
) )
self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank() self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
self.is_mp = (self.world_size > 1)
assert in_features % self.world_size == 0, ( assert in_features % self.world_size == 0, (
"Number of row of the weight for linear ({}) must be" "Number of row of the weight for linear ({}) must be"
" divisible by model parallel size ({})".format(in_features, " divisible by model parallel size ({})".format(in_features,
...@@ -170,22 +214,33 @@ class RowParallelLinear(Layer): ...@@ -170,22 +214,33 @@ class RowParallelLinear(Layer):
self.input_size_per_partition = in_features // self.world_size self.input_size_per_partition = in_features // self.world_size
if self.is_mp:
with get_rng_state_tracker().rng_state():
self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr,
dtype=self._dtype,
is_bias=False)
else:
self.weight = self.create_parameter( self.weight = self.create_parameter(
shape=[self.input_size_per_partition, self.out_features], shape=[self.input_size_per_partition, self.out_features],
attr=self._weight_attr, attr=self._weight_attr,
dtype=self._dtype) dtype=self._dtype,
is_bias=False)
self.weight.is_distributed = True self.weight.is_distributed = True
if has_bias: if has_bias:
self.bias = self.create_parameter( self.bias = self.create_parameter(
shape=[self.out_features], shape=[self.out_features],
attr=paddle.nn.initializer.Constant(value=0.0), attr=paddle.nn.initializer.Constant(value=0.0),
dtype=self._dtype) dtype=self._dtype,
is_bias=True)
else: else:
self.bias = None self.bias = None
def forward(self, x): def forward(self, x):
if self.input_is_parallel: if self.input_is_parallel or (not self.is_mp):
input_parallel = x input_parallel = x
else: else:
# split last dim # split last dim
...@@ -195,12 +250,16 @@ class RowParallelLinear(Layer): ...@@ -195,12 +250,16 @@ class RowParallelLinear(Layer):
nranks=self.world_size, nranks=self.world_size,
group=self.model_parallel_group) group=self.model_parallel_group)
output_parallel = F.linear(input_parallel, self.weight, name=self.name) output_parallel = F.linear(input_parallel, self.weight, name=self._name)
if self.is_mp:
output_ = paddle.distributed.collective._mp_allreduce( output_ = paddle.distributed.collective._mp_allreduce(
output_parallel, output_parallel,
group=self.model_parallel_group, group=self.model_parallel_group,
use_calc_stream=True, use_calc_stream=True,
use_model_parallel=True) use_model_parallel=True)
else:
output_ = output_parallel
output = output_ + self.bias if self.bias is not None else output_ output = output_ + self.bias if self.bias is not None else output_
return output return output
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
import paddle import paddle
import contextlib import contextlib
import numpy as np
__all__ = [] __all__ = []
...@@ -65,14 +66,18 @@ def get_rng_state_tracker(): ...@@ -65,14 +66,18 @@ def get_rng_state_tracker():
return RNG_STATE_TRACKER return RNG_STATE_TRACKER
def model_parallel_random_seed(seed=2048): def model_parallel_random_seed(seed=None):
import paddle.distributed.fleet as fleet import paddle.distributed.fleet as fleet
hcg = fleet.get_hybrid_communicate_group() hcg = fleet.get_hybrid_communicate_group()
rank = hcg.get_model_parallel_rank() rank = hcg.get_model_parallel_rank()
local_seed = seed + 1024 + rank if seed:
global_seed = seed global_seed = seed
local_seed = seed * 1024 + rank * 100
else:
global_seed = np.random.randint(0, 655350)
local_seed = np.random.randint(rank * 10000, (rank + 1) * 10000 - 1)
RNG_STATE_TRACKER.reset() RNG_STATE_TRACKER.reset()
paddle.seed(global_seed)
RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed) RNG_STATE_TRACKER.add(MODEL_PARALLEL_RNG, local_seed)
paddle.seed(global_seed)
...@@ -22,15 +22,15 @@ from ..utils.log_util import logger ...@@ -22,15 +22,15 @@ from ..utils.log_util import logger
__all__ = [] __all__ = []
class ModelParallel(MetaParallelBase): class TensorParallel(MetaParallelBase):
def __init__(self, layers, hcg, **kwargs): def __init__(self, layers, hcg, **kwargs):
super(ModelParallel, self).__init__(layers, hcg, **kwargs) super(TensorParallel, self).__init__(layers, hcg, **kwargs)
def _prepare_for_model(self): def _prepare_for_model(self):
logger.info("start broadcast mp parameters") logger.info("start broadcast mp parameters")
broadcast_mp_parameters(self._layers, self._hcg) broadcast_mp_parameters(self._layers, self._hcg)
logger.info("start broadcast mp parameters") logger.info("start broadcast dp parameters")
broadcast_dp_parameters(self._layers, self._hcg) broadcast_dp_parameters(self._layers, self._hcg)
logger.info("mp's parameters is ready") logger.info("mp's parameters is ready")
......
...@@ -44,7 +44,15 @@ def _apply_collective_grads(parameters, comm_group): ...@@ -44,7 +44,15 @@ def _apply_collective_grads(parameters, comm_group):
for coalesced_grad, _, _ in coalesced_grads_and_vars: for coalesced_grad, _, _ in coalesced_grads_and_vars:
# need to div nranks # need to div nranks
coalesced_grad = coalesced_grad / comm_group.nranks div_factor = paddle.to_tensor(
comm_group.nranks, dtype=coalesced_grad.dtype)
paddle.fluid.framework._dygraph_tracer().trace_op(
type="elementwise_div",
inputs={'X': coalesced_grad,
'Y': div_factor},
outputs={'Out': coalesced_grad},
attrs={'axis': -1})
paddle.distributed.all_reduce(coalesced_grad, group=comm_group) paddle.distributed.all_reduce(coalesced_grad, group=comm_group)
_split_tensors(coalesced_grads_and_vars) _split_tensors(coalesced_grads_and_vars)
......
...@@ -231,7 +231,7 @@ class TestDistTraning(unittest.TestCase): ...@@ -231,7 +231,7 @@ class TestDistTraning(unittest.TestCase):
# model_b # model_b
check_group = dist.new_group(list(range(self.model_parallel_size))) check_group = dist.new_group(list(range(self.model_parallel_size)))
integral_w = [] integral_w = []
partial_w = model_a.embedding.embedding.weight.clone().detach() partial_w = model_a.embedding.weight.clone().detach()
paddle.distributed.all_gather(integral_w, partial_w, group=check_group) paddle.distributed.all_gather(integral_w, partial_w, group=check_group)
result_w = [] result_w = []
for idx in range(len(integral_w)): for idx in range(len(integral_w)):
......
...@@ -27,6 +27,7 @@ class TestNewGroupAPI(object): ...@@ -27,6 +27,7 @@ class TestNewGroupAPI(object):
def test_all(self): def test_all(self):
gp = paddle.distributed.new_group([0, 1]) gp = paddle.distributed.new_group([0, 1])
print("gp info:", gp)
print("test new group api ok") print("test new group api ok")
tmp = np.array([0, 0, 0]) tmp = np.array([0, 0, 0])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册