Unverified commit c17e6af8, authored by Haohongxiang, committed by GitHub

[Fleet] Reconstruct of Fleet API in Dygraph Mode (#44922)

* reconstruct_of_fleet_api

* update
Parent 6452ab3b
@@ -41,14 +41,16 @@ class VocabParallelEmbedding(Layer):
                  num_embeddings,
                  embedding_dim,
                  weight_attr=None,
+                 mp_group=None,
                  name=None):
         super(VocabParallelEmbedding, self).__init__()
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
         self.origin_num_embeddings = num_embeddings
         self.is_mp = (self.world_size > 1)
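
Note: the new fallback reads `mp_group.nranks` and `mp_group.rank`, which matches the communication-group objects returned by `paddle.distributed.new_group`. A minimal sketch of building such a group (illustration only, not part of this commit; it must be launched on at least two devices, e.g. with `python -m paddle.distributed.launch`):

```python
# Illustration only: construct an explicit 2-way model-parallel group whose
# `nranks` / `rank` attributes are what the new `mp_group` code path reads.
import paddle.distributed as dist

dist.init_parallel_env()                  # set up the default distributed context
mp_group = dist.new_group(ranks=[0, 1])   # explicit group over ranks 0 and 1

print(mp_group.nranks)  # group size, used as the layer's world_size
print(mp_group.rank)    # this process's rank within the group
```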
@@ -108,14 +110,15 @@ class ColumnParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=None,
                  gather_output=True,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(ColumnParallelLinear, self).__init__()
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
+        ) if mp_group is None else mp_group.nranks
         self._name = name
         self.is_mp = (self.world_size > 1)
@@ -197,8 +200,9 @@ class RowParallelLinear(Layer):
                  weight_attr=None,
                  has_bias=True,
                  input_is_parallel=False,
-                 name=None,
-                 fuse_matmul_bias=False):
+                 fuse_matmul_bias=False,
+                 mp_group=None,
+                 name=None):
         super(RowParallelLinear, self).__init__()
         self.in_features = in_features
@@ -209,10 +213,11 @@ class RowParallelLinear(Layer):
         self._name = name
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
         self.is_mp = (self.world_size > 1)
         assert in_features % self.world_size == 0, (
@@ -288,14 +293,15 @@
 class ParallelCrossEntropy(Layer):
-    def __init__(self, name=None):
+    def __init__(self, mp_group=None, name=None):
         super(ParallelCrossEntropy, self).__init__()
         self.name = name
         self.model_parallel_group = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_group(
-        )
+        ) if mp_group is None else mp_group
         self.world_size = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_world_size(
-        )
-        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank()
+        ) if mp_group is None else mp_group.nranks
+        self.rank = tp._HYBRID_PARALLEL_GROUP.get_model_parallel_rank(
+        ) if mp_group is None else mp_group.rank
     def forward(self, input, label):
         loss = paddle.distributed.collective._c_softmax_with_cross_entropy(
......
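
Taken together, the hunks above add an optional `mp_group` argument to VocabParallelEmbedding, ColumnParallelLinear, RowParallelLinear and ParallelCrossEntropy, so the layers can be driven by an explicit communication group instead of the global `_HYBRID_PARALLEL_GROUP`. A hedged usage sketch, not taken from this PR (it assumes the layers are importable from `paddle.distributed.fleet.meta_parallel` and that an explicit group is sufficient for construction; run under `paddle.distributed.launch` on two devices):

```python
# Usage sketch: drive the model-parallel layers through the new `mp_group` argument.
import paddle
import paddle.distributed as dist
from paddle.distributed.fleet.meta_parallel import (
    ColumnParallelLinear, ParallelCrossEntropy, RowParallelLinear,
    VocabParallelEmbedding)

dist.init_parallel_env()
mp_group = dist.new_group(ranks=[0, 1])          # 2-way model parallelism

embedding = VocabParallelEmbedding(num_embeddings=1000, embedding_dim=64,
                                   mp_group=mp_group)
col_fc = ColumnParallelLinear(64, 256, has_bias=True, gather_output=False,
                              mp_group=mp_group)
row_fc = RowParallelLinear(256, 64, has_bias=True, input_is_parallel=True,
                           mp_group=mp_group)
loss_fn = ParallelCrossEntropy(mp_group=mp_group)   # also takes the explicit group

ids = paddle.randint(low=0, high=1000, shape=[4, 16])
h = embedding(ids)   # [4, 16, 64]: partial embeddings allreduced over mp_group
h = col_fc(h)        # [4, 16, 128]: local column shard (gather_output=False, 2 ranks)
out = row_fc(h)      # [4, 16, 64]: row-parallel partial sums reduced over mp_group
```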
@@ -109,16 +109,19 @@ def _broadcast_data_help(data, shape, dtype, hcg):
 def broadcast_input_data(hcg, *inputs, **kwargs):
+    cur_device = paddle.get_device()
     for v in inputs:
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
         else:
             logger.error("it doesn't support data type {}".format(type(v)))
     for k, v in kwargs.items():
-        if isinstance(v, core.VarBase):
+        if isinstance(v, (core.VarBase, core.eager.Tensor)):
             with framework.no_grad():
+                v = v.cuda() if "gpu" in cur_device else v
                 _broadcast_data_help(v, v.shape, v.dtype, hcg)
                 kwargs[k] = v
         else:
......
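
This last hunk widens the type check so that eager-mode (dygraph) tensors, which are `core.eager.Tensor` rather than `core.VarBase`, are broadcast instead of falling into the error branch, and moves them onto the GPU first when the current device is a GPU. A small sketch of the distinction (hypothetical illustration; the `hcg` handle would come from `fleet.get_hybrid_communicate_group()` inside a launched hybrid-parallel program):

```python
# Illustration of why the isinstance check was widened: under the eager dygraph
# mode a freshly created tensor is core.eager.Tensor, not core.VarBase.
import paddle
from paddle.framework import core

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
print(isinstance(x, core.eager.Tensor))   # True in eager mode
print(isinstance(x, core.VarBase))        # False, so the old check logged an error

# In a program launched with fleet hybrid parallelism, the batch would then be
# shared across the model-parallel group roughly like this:
#   hcg = paddle.distributed.fleet.get_hybrid_communicate_group()
#   broadcast_input_data(hcg, x)
```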