未验证 提交 bb725e3a 编写于 作者: L LiYuRio 提交者: GitHub

add group argument (#44758)

上级 2cda4e21
......@@ -110,6 +110,10 @@ class Group():
def process_group(self):
return self.pg
@property
def world_size(self):
return self.nranks if self.rank >= 0 else -1
def __repr__(self):
debug_str = "rank: {}, nranks: {}, id: {}, ranks: ".format(
self.rank, self.nranks, self.id)
......
......@@ -359,47 +359,65 @@ def init_parallel_env():
return group
def get_rank():
def get_rank(group=None):
"""
Returns the rank of current trainer.
Returns the rank of current trainer in the given group, ranks are consecutive integers in [0, ``world_size``).
If none of the group is given, the global group will be used as default.
Its value is equal to the value of the environment variable ``PADDLE_TRAINER_ID`` .
The default value is 0.
Args:
group (Group, optional): The communication group you want to get rank of current trainer, use global group as default if group is None.
Returns:
(int) The rank of current trainer.
(int) The rank of current trainer in the given group. Return -1 if the process is not part of the given group.
Warning:
Argument ``group`` only supports in dygraph mode.
Examples:
.. code-block:: python
# Execute this script using distributed launch with one card configs.
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINER_ID=0
dist.init_parallel_env()
print("The rank is %d" % dist.get_rank())
# The rank is 0
"""
if in_dygraph_mode() and group:
return group.rank
assert group is None, "Only support group argument in eager mode."
return _get_global_parallel_env().rank
def get_world_size():
def get_world_size(group=None):
"""
Returns the number of trainers (number of processes participating in current job).
Returns the number of trainers (number of processes participating in current job) in the given group.
If none of the group is given, the global group will be used as default.
Its value is equal to the value of the environment variable ``PADDLE_TRAINERS_NUM`` .
The default value is 1.
Args:
group (Group, optional): The communication group you want to check world size, use global group as default if group is None.
Returns:
(int) The number of trainers.
(int) The number of trainers in the given group. Return -1 if the process if not part of the given group.
Warning:
Argument ``group`` only supports in dygraph mode.
Examples:
.. code-block:: python
# Execute this script using distributed launch with one card configs.
import paddle
import paddle.distributed as dist
# execute this command in terminal: export PADDLE_TRAINERS_NUM=4
dist.init_parallel_env()
print("The world_size is %d" % dist.get_world_size())
# The world_size is 4
# The world_size is 1
"""
if in_dygraph_mode() and group:
return group.world_size
assert group is None, "Only support group argument in eager mode."
return _get_global_parallel_env().world_size
......@@ -304,5 +304,16 @@ if((WITH_GPU OR WITH_ROCM) AND (LINUX))
set_tests_properties(test_communication_stream_allreduce_api
PROPERTIES TIMEOUT "120" LABELS "RUN_TYPE=DIST")
endif()
if((WITH_ROCM OR WITH_GPU) AND (LINUX))
bash_test_modules(
test_world_size_and_rank
START_BASH
test_world_size_and_rank.sh
LABELS
"RUN_TYPE=DIST"
ENVS
"PADDLE_DIST_UT_PORT=21532;http_proxy=;https_proxy=")
set_tests_properties(test_world_size_and_rank PROPERTIES TIMEOUT "120")
endif()
add_subdirectory(fleet)
add_subdirectory(multinode)
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
set -e
CUDA_VISIBLE_DEVICES=0,1 python -m paddle.distributed.launch --gpus=0,1 world_size_and_rank.py
......@@ -36,3 +36,4 @@ test_eager_dist_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_
test_new_group_api,linux,gpu;rocm,120,DIST,test_runner.py,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_gen_nccl_id_op,,gpu;rocm;ASCEND;ASCEND_CL,,DIST,../dist_test.sh,2,,http_proxy=;https_proxy=;PYTHONPATH=..,
test_communication_stream_allreduce_api,linux,gpu;rocm,120,DIST,,2,,PYTHONPATH=..;http_proxy=;https_proxy=,
test_world_size_and_rank,linux,rocm;gpu,120,DIST,test_world_size_and_rank.sh,2,,http_proxy=;https_proxy=,
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import paddle.distributed as dist
class TestWorldSizeAndRankAPI(unittest.TestCase):
def setUp(self):
self._num_of_ranks = 2
self._subgroup_ranks = [0, 1]
dist.init_parallel_env()
self._subgroup = dist.new_group(self._subgroup_ranks)
self._global_rank = dist.get_rank()
def test_default_env_world_size(self):
self.assertEqual(dist.get_world_size(), self._num_of_ranks)
def test_given_group_world_size(self):
world_size = 2 if self._global_rank in self._subgroup_ranks else -1
self.assertEqual(dist.get_world_size(self._subgroup), world_size)
def test_given_group_rank(self):
rank = self._subgroup_ranks.index(
self._global_rank
) if self._global_rank in self._subgroup_ranks else -1
self.assertEqual(dist.get_rank(self._subgroup), rank)
if __name__ == '__main__':
unittest.main()
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class DemoNet(nn.Layer):
......
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class BatchNormActNet(nn.Layer):
......
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class BatchNormAddActNet(nn.Layer):
......
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class DemoNet(nn.Layer):
......
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class ReluDepthwiseConvNet(nn.Layer):
......
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class DemoNet(nn.Layer):
......
......@@ -21,6 +21,8 @@ from paddle.distributed.passes import new_pass, PassManager
import unittest
from dist_pass_test_base import DistPassTestBase
paddle.enable_static()
class DemoNet(nn.Layer):
......
......@@ -627,7 +627,8 @@ class TestParallelDyGraphRunnerBase(object):
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
paddle.distributed.parallel._get_global_parallel_env()
args.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# 3. init parallel env
if args.update_method in ["nccl2", "gloo"]:
......@@ -666,7 +667,8 @@ class TestParallelDyGraphRunnerBase(object):
np.random.seed(seed)
random.seed(seed)
# get trainer id
args.trainer_id = paddle.distributed.get_rank()
paddle.distributed.parallel._get_global_parallel_env()
args.trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
# set strategy
strategy = fleet.DistributedStrategy()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册