Unverified commit c17e6af8, authored by Haohongxiang, committed by GitHub

[Fleet] Reconstruct of Fleet API in Dygraph Mode (#44922)

* reconstruct_of_fleet_api

* update
Parent 6452ab3b
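The hunks below apply one change uniformly across the model-parallel layers: each constructor gains an optional mp_group argument, and the communication group, world size, and rank fall back to the global hybrid-parallel topology only when no group is supplied. A minimal sketch of that resolution logic, written as a standalone helper for illustration (the helper name and the topology import are assumptions; the actual change inlines this logic in each __init__):

```python
# Illustrative only: condenses the fallback pattern repeated in the diff below.
from paddle.distributed.fleet.base import topology as tp


def _resolve_mp_group(mp_group=None):
    """Return (group, world_size, rank) for model parallelism.

    Falls back to the group tracked by the hybrid-parallel topology when
    mp_group is None; otherwise honors the caller-provided group.
    """
    if mp_group is None:
        hcg = tp._HYBRID_PARALLEL_GROUP
        return (hcg.get_model_parallel_group(),
                hcg.get_model_parallel_world_size(),
                hcg.get_model_parallel_rank())
    return mp_group, mp_group.nranks, mp_group.rank
```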
@@ -41,14 +41,16 @@ class VocabParallelEmbedding(Layer):
                  num_embeddings,
                  embedding_dim,
                  weight_attr=None,
+                 mp_group=None,
                  name=None):
         super(VocabParallelEmbedding, self).__init__()
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
         self.origin_num_embeddings = num_embeddings
         self.is_mp = (self.world_size > 1)
@@ -108,14 +110,15 @@ class ColumnParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=None,
                  gather_output=True,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(ColumnParallelLinear, self).__init__()
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
+        ) if mp_group is None else mp_group.nranks
         self._name = name
         self.is_mp = (self.world_size > 1)
@@ -197,8 +200,9 @@ class RowParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=True,
                  input_is_parallel=False,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(RowParallelLinear, self).__init__()
         self.in_features = in_features
@@ -209,10 +213,11 @@ class RowParallelLinear(Layer):
         self._name = name
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
         self.is_mp = (self.world_size > 1)
         assert in_features % self.world_size == 0, (
@@ -288,14 +293,15 @@ class RowParallelLinear(Layer):
 class ParallelCrossEntropy(Layer):
 
-    def __init__(self, name=None):
+    def __init__(self, mp_group=None, name=None):
         super(ParallelCrossEntropy, self).__init__()
         self.name = name
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
 
     def forward(self, input, label):
         loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
...
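With the new keyword, callers can run these layers over a communication group of their own rather than the one created during fleet initialization. A hedged usage sketch, assuming the layers are imported from fleet.meta_parallel (the import path and the two-rank group are illustrative; omitting mp_group preserves the previous behavior):

```python
import paddle.distributed as dist
from paddle.distributed.fleet import meta_parallel as mp

dist.init_parallel_env()
# Illustrative custom model-parallel group spanning ranks 0 and 1.
custom_mp_group = dist.new_group(ranks=[0, 1])

# Feature sizes are divisible by the group size (2), as the layers assert.
embedding = mp.VocabParallelEmbedding(num_embeddings=50304,
                                      embedding_dim=1024,
                                      mp_group=custom_mp_group)
column = mp.ColumnParallelLinear(1024, 4096,
                                 has_bias=True,
                                 gather_output=False,
                                 mp_group=custom_mp_group)
row = mp.RowParallelLinear(4096, 1024,
                           has_bias=True,
                           input_is_parallel=True,
                           mp_group=custom_mp_group)
loss_fn = mp.ParallelCrossEntropy(mp_group=custom_mp_group)
```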
@@ -109,16 +109,19 @@ def _broadcast_data_help(data, shape, dtype, hcg):
 def broadcast_input_data(hcg, *inputs, **kwargs):
+    cur_device = paddle.get_device()
     for v in inputs:
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
         else:
             logger.error("it doesn't support data type {}".format(type(v)))
 
     for k, v in kwargs.items():
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
             kwargs[k] = v
         else:
...
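The second file extends broadcast_input_data to recognize eager-mode Tensors alongside VarBase and, on GPU devices, to copy host tensors onto the current device before broadcasting them over the model-parallel group. A hedged sketch of how a hybrid-parallel training step might invoke it (the import path, hybrid configuration, and toy batch are assumptions; run under a multi-GPU launcher):

```python
import paddle
from paddle.distributed import fleet
from paddle.distributed.fleet.utils.hybrid_parallel_util import (
    broadcast_input_data,  # path assumed for the helper shown in the diff
)

strategy = fleet.DistributedStrategy()
strategy.hybrid_configs = {"dp_degree": 1, "mp_degree": 2, "pp_degree": 1}
fleet.init(is_collective=True, strategy=strategy)
hcg = fleet.get_hybrid_communicate_group()

# Toy batch; with the change above, CPU tensors would first be moved to
# the current GPU before being broadcast across the model-parallel group.
tokens = paddle.randint(0, 50304, shape=[8, 1024])
labels = paddle.randint(0, 50304, shape=[8, 1024])
broadcast_input_data(hcg, tokens, labels)
```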