Unverified commit a0dffd39, authored by LiYuRio, committed via GitHub

Move group and all reduce from collective to communication (#45848)

Parent 45b93325
...@@ -293,6 +293,14 @@ std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
    std::vector<phi::DenseTensor>& inputs,
    std::vector<phi::DenseTensor>& outputs,
    const AllreduceOptions& opts) {
+  return AllReduce(inputs, outputs, opts, true);
+}
+
+std::shared_ptr<ProcessGroup::Task> ProcessGroupGloo::AllReduce(
+    std::vector<phi::DenseTensor>& inputs,
+    std::vector<phi::DenseTensor>& outputs,
+    const AllreduceOptions& opts,
+    bool sync_op) {
  auto tag = next_tag();
  std::shared_ptr<GlooTask> task;
  auto context = get_context();
......
...@@ -120,6 +120,12 @@ class ProcessGroupGloo : public ProcessGroup {
      std::vector<phi::DenseTensor>& outputs,
      const AllreduceOptions& opts = AllreduceOptions()) override;

+  std::shared_ptr<ProcessGroup::Task> AllReduce(
+      std::vector<phi::DenseTensor>& inputs,
+      std::vector<phi::DenseTensor>& outputs,
+      const AllreduceOptions& opts,
+      bool sync_op) override;
+
  std::shared_ptr<ProcessGroup::Task> Barrier(
      const BarrierOptions& = BarrierOptions()) override;
......
...@@ -52,54 +52,12 @@ from .fleet.layers.mpu.mp_ops import _c_softmax_with_cross_entropy
from .fleet.layers.mpu.mp_ops import _linear
from .fleet.layers.mpu.mp_ops import _parallel_linear
from .fleet.layers.mpu.mp_ops import _parallel_embedding
-from .communication.comm_utils import ReduceOp
+from .communication.group import Group, _add_new_group
+from .communication.all_reduce import all_reduce
+from .communication.reduce import _get_reduce_op, ReduceOp

__all__ = []

-class Group():
-    """
-    The abstract representation of group.
-    """
-
-    def __init__(self, rank, rank_num, id=0, ranks=[], pg=None, name=None):
-        self.rank = rank
-        self.nranks = rank_num
-        self.id = id
-        self.ranks = ranks
-        self.pg = pg
-        self.name = name
-
-    def is_member(self):
-        if self.rank < 0:
-            return False
-        if self.nranks < 2:
-            return False
-        return True
-
-    def get_group_rank(self, rank):
-        if self.is_member() and rank in self.ranks:
-            return self.ranks.index(rank)
-        else:
-            return -1
-
-    @property
-    def process_group(self):
-        return self.pg
-
-    @property
-    def world_size(self):
-        return self.nranks if self.rank >= 0 else -1
-
-    def __repr__(self):
-        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
-            self.rank, self.nranks, self.id)
-        debug_str += ", ".join(map(str, self.ranks))
-        debug_str += "; name: "
-        debug_str += self.name if self.name else "None"
-        return debug_str
-
_global_env = None
...@@ -147,9 +105,8 @@ def _get_group_map():
    global _group_map
    if _global_env_gid not in _group_map:
        genv = _get_global_env()
-        _group_map[_global_env_gid] = Group(genv.rank,
-                                            genv.world_size,
-                                            ranks=list(range(genv.world_size)))
+        _group_map[_global_env_gid] = Group(genv.rank, 0,
+                                            list(range(genv.world_size)))
    return _group_map
...@@ -197,19 +154,6 @@ def _new_ring_id():
    return len(_get_group_map()) + max(_get_global_env().nrings, 9)


-def _get_reduce_op(reduce_op, func_name):
-    if reduce_op == ReduceOp.SUM:
-        return core.ReduceOp.SUM
-    elif reduce_op == ReduceOp.MAX:
-        return core.ReduceOp.MAX
-    elif reduce_op == ReduceOp.MIN:
-        return core.ReduceOp.MIN
-    elif reduce_op == ReduceOp.PROD:
-        return core.ReduceOp.PRODUCT
-    else:
-        raise ValueError("Unknown reduce_op type for {}.".format(func_name))
-
-
def get_group(id=0):
    """
...@@ -451,10 +395,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
    else:
        rank = -1
        pg = None
-    group = Group(rank, size, id=gid, ranks=ranks, pg=pg, name=group_name)
+    group = Group(rank, gid, ranks, pg=pg, name=group_name)
    _group_map_by_name[group_name] = group
    _group_map[gid] = group
    _group_map_backend[group] = backend
+    # TODO: The call below is the new mechanism for group management and will
+    # replace the three maps above in the future.
+    _add_new_group(group)

    # TODO(shenliang03): This is a temporary solution to solve the problem of
    # hang caused by tcp
...@@ -476,13 +423,13 @@ def new_group(ranks=None, backend=None, timeout=_default_timeout):
    ring_id = _new_ring_id()

    if global_rank not in ranks:
-        gp = Group(-1, -1, ring_id, ranks)
+        gp = Group(-1, ring_id, ranks)
        _group_map[ring_id] = gp
    else:
        ranks = sorted(ranks)
        group_rank = ranks.index(global_rank)
        group_size = len(ranks)
-        gp = Group(group_rank, group_size, ring_id, ranks)
+        gp = Group(group_rank, ring_id, ranks)
        _group_map[ring_id] = gp

    if group_size >= 2:
...@@ -748,104 +695,6 @@ def broadcast(tensor, src, group=None, sync_op=True):
        })


-def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
-    """
-    Reduce a tensor over all ranks so that all get the result.
-    As shown below, one process is started with a GPU and the data of this process is represented
-    by its group rank. The reduce operator is sum. Through all_reduce operator,
-    each GPU will have the sum of the data from all GPUs.
-
-    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
-        :width: 800
-        :alt: all_reduce
-        :align: center
-
-    Args:
-        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
-            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
-        group (Group, optional): The group instance return by new_group or None for global default group.
-        sync_op (bool, optional): Wether this op is a sync op. Default value is True.
-
-    Returns:
-        None.
-
-    Examples:
-        .. code-block:: python
-
-            # required: distributed
-            import paddle
-            import paddle.distributed as dist
-
-            dist.init_parallel_env()
-            if dist.get_rank() == 0:
-                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
-            else:
-                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
-            dist.all_reduce(data)
-            print(data)
-            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
-    """
-    if group is not None and not group.is_member():
-        return
-
-    if in_dygraph_mode():
-        op_type = _get_reduce_op(op, "all_reduce")
-        group = _get_default_group() if group is None else group
-        task = group.process_group.allreduce(tensor, op_type)
-        if sync_op:
-            task.wait()
-            return None
-        else:
-            return task
-
-    use_calc_stream = sync_op
-    ring_id = 0 if group is None else group.id
-    if _non_static_mode():
-        if op == ReduceOp.SUM:
-            return _legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.MAX:
-            return _legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.MIN:
-            return _legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
-                                                  use_calc_stream, 'ring_id',
-                                                  ring_id)
-        elif op == ReduceOp.PROD:
-            return _legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
-                                                   use_calc_stream, 'ring_id',
-                                                   ring_id)
-        else:
-            raise ValueError("Unknown parameter: {}.".format(op))
-
-    check_variable_and_dtype(tensor, 'tensor', [
-        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
-        'bool'
-    ], 'all_reduce')
-    if op == ReduceOp.SUM:
-        op_type = 'c_allreduce_sum'
-    elif op == ReduceOp.MAX:
-        op_type = 'c_allreduce_max'
-    elif op == ReduceOp.MIN:
-        op_type = 'c_allreduce_min'
-    elif op == ReduceOp.PROD:
-        op_type = 'c_allreduce_prod'
-    if not isinstance(ring_id, int):
-        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
-    helper = LayerHelper(op_type, **locals())
-    helper.append_op(type=op_type,
-                     inputs={'X': [tensor]},
-                     outputs={'Out': [tensor]},
-                     attrs={
-                         'ring_id': ring_id,
-                         'use_calc_stream': use_calc_stream
-                     })
-
-
def reduce(tensor, dst, op=ReduceOp.SUM, group=None, sync_op=True):
    """
......
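Note: the helpers deleted from the collective module above are not gone; they have moved into the new paddle.distributed.communication package, and the import block at the top of this diff pulls them back into collective so existing call sites keep resolving. A quick way to see where each symbol now lives (assuming a Paddle build that contains this commit):

# Import paths taken directly from this diff.
from paddle.distributed.communication.group import Group, _add_new_group
from paddle.distributed.communication.all_reduce import all_reduce
from paddle.distributed.communication.reduce import ReduceOp, _get_reduce_op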
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
import paddle.fluid.framework as framework
from paddle.distributed.communication import stream as stream
from paddle.distributed.communication.reduce import ReduceOp
def all_reduce(tensor, op=ReduceOp.SUM, group=None, sync_op=True):
    """
    Reduce a tensor over all ranks so that all get the result.
    As shown below, one process is started with a GPU and the data of this process is represented
    by its group rank. The reduce operator is sum. Through all_reduce operator,
    each GPU will have the sum of the data from all GPUs.

    .. image:: https://githubraw.cdn.bcebos.com/PaddlePaddle/docs/develop/docs/api/paddle/distributed/img/allreduce.png
        :width: 800
        :alt: all_reduce
        :align: center

    Args:
        tensor (Tensor): The input Tensor. It also works as the output Tensor. Its data type
            should be float16, float32, float64, int32, int64, int8, uint8 or bool.
        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The operation used. Default value is ReduceOp.SUM.
        group (Group, optional): The group instance returned by new_group or None for the global default group.
        sync_op (bool, optional): Whether this op is a sync op. Default value is True.

    Returns:
        Return a task object.

    Examples:
        .. code-block:: python

            # required: distributed
            import paddle
            import paddle.distributed as dist

            dist.init_parallel_env()
            if dist.get_rank() == 0:
                data = paddle.to_tensor([[4, 5, 6], [4, 5, 6]])
            else:
                data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]])
            dist.all_reduce(data)
            print(data)
            # [[5, 7, 9], [5, 7, 9]] (2 GPUs)
    """
    if not framework._in_legacy_dygraph():
        return stream.all_reduce(tensor,
                                 op=op,
                                 group=group,
                                 sync_op=sync_op,
                                 use_calc_stream=False)

    # code below will be removed after we remove the old dygraph
    use_calc_stream = sync_op
    ring_id = 0 if group is None else group.id
    if op == ReduceOp.SUM:
        return paddle._legacy_C_ops.c_allreduce_sum_(tensor, 'use_calc_stream',
                                                     use_calc_stream, 'ring_id',
                                                     ring_id)
    elif op == ReduceOp.MAX:
        return paddle._legacy_C_ops.c_allreduce_max_(tensor, 'use_calc_stream',
                                                     use_calc_stream, 'ring_id',
                                                     ring_id)
    elif op == ReduceOp.MIN:
        return paddle._legacy_C_ops.c_allreduce_min_(tensor, 'use_calc_stream',
                                                     use_calc_stream, 'ring_id',
                                                     ring_id)
    elif op == ReduceOp.PROD:
        return paddle._legacy_C_ops.c_allreduce_prod_(tensor, 'use_calc_stream',
                                                      use_calc_stream,
                                                      'ring_id', ring_id)
    else:
        raise ValueError("Unknown parameter: {}.".format(op))
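The docstring above shows the synchronous form; since the function hands back a task object when it dispatches through the new communication stack, an asynchronous call is the natural companion. A minimal sketch, assuming a 2-GPU job launched with paddle.distributed.launch and the new dynamic-graph mode (so the call goes through stream.all_reduce):

# Sketch only: asynchronous all_reduce driven by sync_op=False.
import paddle
import paddle.distributed as dist

dist.init_parallel_env()
data = paddle.to_tensor([[1, 2, 3], [1, 2, 3]]) * (dist.get_rank() + 1)

task = dist.all_reduce(data, op=dist.ReduceOp.SUM, sync_op=False)
# ... unrelated work can overlap with the communication here ...
task.wait()   # block until the collective finishes
print(data)   # every rank now holds the elementwise sum across ranks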
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
class Group():
    """
    The abstract representation of group.
    """

    def __init__(self, rank_in_group, id, ranks, pg=None, name=None):
        self._rank_in_group = rank_in_group
        self._world_size = len(ranks) if rank_in_group >= 0 else -1
        self._id = id
        self._ranks = ranks
        self._pg = pg
        self._name = name

    @property
    def rank(self):
        return self._rank_in_group

    @property
    def ranks(self):
        return self._ranks

    @property
    def nranks(self):
        return len(self._ranks)

    @property
    def name(self):
        return self._name

    @property
    def process_group(self):
        return self._pg

    @property
    def world_size(self):
        return self._world_size

    @property
    def id(self):
        return self._id

    def is_member(self):
        if self.rank < 0:
            return False
        if self.nranks < 2:
            return False
        return True

    def get_group_rank(self, rank):
        if self.is_member():
            return self.ranks.index(rank)
        else:
            return -1

    def __repr__(self):
        debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
            self.rank, self.nranks, self.id)
        debug_str += ", ".join(map(str, self.ranks))
        debug_str += "; name: "
        debug_str += self.name if self.name else "None"
        return debug_str


class _GroupManager():
    global_group_id = 0
    group_map_by_id = {}


def _get_global_group():
    if _GroupManager.global_group_id not in _GroupManager.group_map_by_id:
        raise RuntimeError("The global group is not initialized.")
    return _GroupManager.group_map_by_id[_GroupManager.global_group_id]


def _add_new_group(group):
    if group.id in _GroupManager.group_map_by_id:
        raise RuntimeError("The group with id {} already exists.".format(
            group.id))
    _GroupManager.group_map_by_id[group.id] = group
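To make the new constructor signature concrete: Group now takes rank_in_group, id, and ranks, and derives the world size from ranks instead of accepting a separate rank_num argument. A small local sketch, assuming a build with this commit; the rank/id/ranks values are made up for illustration, and no communication happens here:

from paddle.distributed.communication.group import Group, _add_new_group

g = Group(rank_in_group=1, id=2, ranks=[4, 5, 6, 7])
print(g.world_size)         # 4, derived from len(ranks)
print(g.is_member())        # True: rank >= 0 and nranks >= 2
print(g.get_group_rank(6))  # 2, the position of global rank 6 inside the group

_add_new_group(g)           # registers the group in _GroupManager.group_map_by_id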
...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.

+import paddle.fluid.framework as framework
+import paddle.fluid.core as core
+

class ReduceOp:
    """
...@@ -48,3 +51,26 @@ class ReduceOp:
    MIN = 2
    PROD = 3
    AVG = 4
+
+
+def _get_reduce_op(reduce_op, func_name):
+    if framework.in_dygraph_mode():
+        if reduce_op == ReduceOp.SUM:
+            return core.ReduceOp.SUM
+        elif reduce_op == ReduceOp.MAX:
+            return core.ReduceOp.MAX
+        elif reduce_op == ReduceOp.MIN:
+            return core.ReduceOp.MIN
+        elif reduce_op == ReduceOp.PROD:
+            return core.ReduceOp.PRODUCT
+    else:
+        if reduce_op == ReduceOp.SUM:
+            return 'c_allreduce_sum'
+        elif reduce_op == ReduceOp.MAX:
+            return 'c_allreduce_max'
+        elif reduce_op == ReduceOp.MIN:
+            return 'c_allreduce_min'
+        elif reduce_op == ReduceOp.PROD:
+            return 'c_allreduce_prod'
+
+    raise ValueError("Unknown reduce_op type for {}.".format(func_name))
...@@ -13,12 +13,16 @@
# limitations under the License.

import paddle.fluid.framework as framework
-from paddle.distributed import collective
+import paddle.fluid.data_feeder as data_feeder
+import paddle.fluid.layer_helper as layer_helper
+from paddle.distributed.communication.reduce import _get_reduce_op, ReduceOp
+from paddle.distributed.communication.group import _get_global_group


def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
-    op_type = collective._get_reduce_op(op, "all_reduce")
-    group = collective._get_default_group() if group is None else group
+    op_type = _get_reduce_op(op, "all_reduce")
+    group = _get_global_group() if group is None else group
+
    if use_calc_stream:
        return group.process_group.allreduce_on_calc_stream(tensor, op_type)

...@@ -29,8 +33,34 @@ def _all_reduce_in_dygraph(tensor, op, group, sync_op, use_calc_stream):
    return task


+def _all_reduce_in_static_mode(tensor, op, group, sync_op, use_calc_stream):
+    data_feeder.check_variable_and_dtype(tensor, 'tensor', [
+        'float16', 'float32', 'float64', 'int32', 'int64', 'int8', 'uint8',
+        'bool'
+    ], 'all_reduce')
+
+    op_type = _get_reduce_op(op, "all_reduce")
+    ring_id = 0 if group is None else group.id
+
+    if not isinstance(ring_id, int):
+        raise ValueError("The type of 'ring_id' for all_reduce should be int.")
+
+    # TODO: Support task and use task.wait in static mode
+    # Use use_calc_stream rather than sync_op
+    helper = layer_helper.LayerHelper(op_type, **locals())
+    helper.append_op(type=op_type,
+                     inputs={'X': [tensor]},
+                     outputs={'Out': [tensor]},
+                     attrs={
+                         'ring_id': ring_id,
+                         'use_calc_stream': sync_op
+                     })
+
+    return None
+
+
def all_reduce(tensor,
-               op=collective.ReduceOp.SUM,
+               op=ReduceOp.SUM,
               group=None,
               sync_op=True,
               use_calc_stream=False):
...@@ -41,7 +71,7 @@ def all_reduce(tensor,

    Args:
        tensor (Tensor): The input tensor on each rank. The result will overwrite this tensor after communication. Support
            float16, float32, float64, int32 or int64 as the input data type.
-        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.Min|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
+        op (ReduceOp.SUM|ReduceOp.MAX|ReduceOp.MIN|ReduceOp.PROD, optional): The reduction used. If none is given, use ReduceOp.SUM as default.
        group (Group, optional): Communicate in which group. If none is given, use the global group as default.
        sync_op (bool, optional): Indicate whether the communication is sync or not. If none is given, use true as default.
        use_calc_stream (bool, optional): Indicate whether the communication is done on calculation stream. If none is given, use false as default. This
...@@ -50,9 +80,6 @@ def all_reduce(tensor,
    Returns:
        Return a task object.

-    Warning:
-        This API only supports the dygraph mode now.
-
    Examples:
        .. code-block:: python
...@@ -84,7 +111,6 @@ def all_reduce(tensor,
    if framework.in_dygraph_mode():
        return _all_reduce_in_dygraph(tensor, op, group, sync_op,
                                      use_calc_stream)
-    else:
-        raise RuntimeError(
-            "paddle.distributed.stream.all_reduce is only supported in dygraph mode now."
-        )
+
+    return _all_reduce_in_static_mode(tensor, op, group, sync_op,
+                                      use_calc_stream)
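With the static-graph branch in place, stream.all_reduce no longer raises outside dynamic graph. For the dygraph path, a hedged usage sketch of the two stream choices (assuming a 2-GPU job launched with paddle.distributed.launch; the import path follows the one used by the new all_reduce module earlier in this diff):

import paddle
import paddle.distributed as dist
from paddle.distributed.communication import stream

dist.init_parallel_env()
data = paddle.to_tensor([1.0, 2.0, 3.0]) * (dist.get_rank() + 1)

# Queued on the calculation stream: ordinary stream ordering applies,
# so no explicit wait is needed before the next op that uses `data`.
stream.all_reduce(data, op=dist.ReduceOp.SUM, use_calc_stream=True)

# Queued on a separate communication stream: wait on the returned task.
task = stream.all_reduce(data, sync_op=False, use_calc_stream=False)
task.wait()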
...@@ -377,8 +377,8 @@ class _CommunicateGroup(object):

    def set_comm_group(self, group_name, group_rank, group_size, ring_id,
                       group_ranks):
-        group = paddle.distributed.collective.Group(group_rank, group_size,
-                                                    ring_id, group_ranks)
+        group = paddle.distributed.collective.Group(group_rank, ring_id,
+                                                    group_ranks)
        self.groups[group_name] = group

    def get_group(self, group_name):
......
...@@ -22,7 +22,7 @@ from paddle.fluid.layer_helper import LayerHelper
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.fluid.dygraph import layers
from paddle.distributed import collective
-from ....communication.comm_utils import ReduceOp
+from ....communication.reduce import ReduceOp
from paddle.fluid.data_feeder import check_dtype
import paddle.fluid.dygraph_utils as dygraph_utils
......
...@@ -43,6 +43,7 @@ from paddle.distributed.collective import _set_default_store
from paddle.distributed.collective import _new_process_group_impl
from paddle.distributed.collective import Group
from paddle.distributed.collective import _set_group_map_backend
+from paddle.distributed.communication.group import _add_new_group

__all__ = []
...@@ -258,15 +259,11 @@ def init_parallel_env():
                                     _default_group_name,
                                     pg_options=None)
        ranks = list(range(world_size))
-        group = Group(rank,
-                      world_size,
-                      id=0,
-                      ranks=ranks,
-                      pg=pg,
-                      name=_default_group_name)
+        group = Group(rank, 0, ranks, pg=pg, name=_default_group_name)
        _set_group_map_by_name(_default_group_name, group)
        _set_group_map(0, group)
        _set_group_map_backend(group, backend)
+        _add_new_group(group)
        parallel_helper._set_parallel_ctx(True)

        paddle.distributed.barrier(group=group)
......
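After this change, init_parallel_env registers the default group both in the legacy maps and, via _add_new_group, in the new _GroupManager registry, so the global group can be looked up through the new module. A minimal sketch, assuming a job launched with paddle.distributed.launch (note that _get_global_group is a private helper, shown here only to illustrate the new registry):

import paddle.distributed as dist
from paddle.distributed.communication.group import _get_global_group

dist.init_parallel_env()   # builds Group(rank, 0, ranks, ...) and calls _add_new_group
g = _get_global_group()    # fetched from _GroupManager.group_map_by_id[0]
print(g.rank, g.world_size, g.ranks)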
...@@ -265,7 +265,6 @@ class MoELayer(nn.Layer):
            from paddle.distributed import fleet

            moe_group = Group(fleet.worker_index(),
-                              fleet.worker_num(),
                              0,
                              list(range(fleet.worker_num())))
            mp_group = None
......