Unverified commit 4fa8a676, authored by Haohongxiang, committed by GitHub

[Dygraph] Fix NCCL_VERSION Check Tools (#53990)

Parent dc3c0de1
@@ -12,51 +12,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
-import subprocess
+from paddle.fluid import core


-def get_nccl_version_str():
-    nccl_version_str = subprocess.check_output(
-        r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
-        stderr=subprocess.DEVNULL,
-        shell=True,
-    ).decode('utf-8')
-
-    # NOTE: This is a hacking method to get nccl version, but it will return None
-    # if current platform is not Linux. So we only check nccl version for Linux
-    # platform while training with pipeline parallelism.
-    if nccl_version_str:
-        nccl_version_str = nccl_version_str.replace("\n", "")
-
-    return nccl_version_str
+def get_nccl_version_str(ver):
+    if ver >= 10000:
+        NCCL_MAJOR_VERSION = int(ver // 10000)
+        ver = ver % 10000
+    else:
+        NCCL_MAJOR_VERSION = int(ver // 1000)
+        ver = ver % 1000
+
+    NCCL_MINOR_VERSION = int(ver // 100)
+    NCCL_PATCH_VERSION = int(ver % 100)
+
+    return "{}.{}.{}".format(
+        NCCL_MAJOR_VERSION, NCCL_MINOR_VERSION, NCCL_PATCH_VERSION
+    )


 def check_nccl_version_for_p2p():
-    nccl_version_str = get_nccl_version_str()
-    if nccl_version_str:
-        nccl_version_str = nccl_version_str.replace("\n", "")
-        nccl_version_int = [int(s) for s in nccl_version_str.split(".")]
-        nccl_version_baseline = [2, 8, 4]
-        assert nccl_version_int >= nccl_version_baseline, (
-            "The version of NCCL is required to be at least v2.8.4 while training with "
-            "pipeline/MoE parallelism, but we found v{}. The previous version of NCCL has "
-            "some bugs in p2p communication, and you can see more detailed description "
-            "about this issue from ReleaseNotes of NCCL v2.8.4 "
-            "(https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4).".format(
-                nccl_version_str
-            )
-        )
-    else:
-        logging.warning("No version for NCCL library found!")
+    nccl_version = core.nccl_version()
+    nccl_version_str = get_nccl_version_str(nccl_version)
+    nccl_version_baseline = 2804
+    assert nccl_version >= nccl_version_baseline, (
+        "The version of NCCL is required to be at least v2.8.4 while training with "
+        "pipeline/MoE parallelism, but we found v{}. The previous version of NCCL has "
+        "some bugs in p2p communication, and you can see more detailed description "
+        "about this issue from ReleaseNotes of NCCL v2.8.4 "
+        "(https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4).".format(
+            nccl_version_str
+        )
+    )


 def check_nccl_version_for_bf16():
-    nccl_version_str = get_nccl_version_str()
-    if nccl_version_str:
-        nccl_version_str = nccl_version_str.replace("\n", "")
-        nccl_version_int = [int(s) for s in nccl_version_str.split(".")]
-        nccl_version_baseline = [2, 10, 0]
-        return nccl_version_int >= nccl_version_baseline
-    return False
+    nccl_version = core.nccl_version()
+    nccl_version_baseline = 21000
+    return nccl_version >= nccl_version_baseline
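The key change above is the version encoding. The old helper shelled out to `ldconfig` and parsed a dotted string, while the new code reads an integer version code from `core.nccl_version()` and only formats it for error messages. The old tests also built their own integer by concatenating the dotted components ("2.10.0" became 2100), which is why every baseline in the hunks below moves from 2100 to 21000. A minimal standalone sketch of the decoding scheme mirrored by the new `get_nccl_version_str` (the function name here is illustrative; the sample values are the baselines used in this commit):

```python
def decode_nccl_version(ver: int) -> str:
    # Mirrors get_nccl_version_str above: five-digit codes are
    # major * 10000 + minor * 100 + patch, older four-digit codes are
    # major * 1000 + minor * 100 + patch.
    if ver >= 10000:
        major, rest = divmod(ver, 10000)
    else:
        major, rest = divmod(ver, 1000)
    minor, patch = divmod(rest, 100)
    return f"{major}.{minor}.{patch}"


assert decode_nccl_version(2804) == "2.8.4"    # p2p baseline
assert decode_nccl_version(21000) == "2.10.0"  # bf16 baseline
```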
@@ -33,7 +33,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.distributed.utils.nccl_utils import get_nccl_version_str
+from paddle.fluid import core
 from paddle.nn import Linear

 epoch = 10
@@ -119,7 +119,7 @@ class RandomDataset(paddle.io.Dataset):
 def optimizer_setting(model, use_pure_fp16, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
-    optimizer = paddle.optimizer.Momentum(
+    optimizer = paddle.optimizer.AdamW(
         parameters=[{"params": list(model.parameters())}]
         if opt_group
         else list(model.parameters()),
@@ -364,14 +364,9 @@ def test_stage2_stage3():
     )

     # bfp16
-    # NOTE: this is a hack to get int format nccl version, like 2134
-    # if current platform is not linux, version number will be 0
-    nccl_version_str = get_nccl_version_str()
-    nccl_version = (
-        int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
-    )
+    nccl_version = core.nccl_version()

-    if nccl_version >= 2100:
+    if nccl_version >= 21000:
         stage2_params = train_mlp(
             mlp11,
             sharding_stage=2,
@@ -388,8 +383,8 @@ def test_stage2_stage3():
        )
        for i in range(len(stage2_params)):
            np.testing.assert_allclose(
-               stage2_params[i].numpy(),
-               stage3_params[i].numpy(),
+               stage2_params[i].astype("float32").numpy(),
+               stage3_params[i].astype("float32").numpy(),
                rtol=1e-4,
                atol=1e-3,
            )
......
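In the test hunks above, bfloat16 parameters are cast to float32 before being handed to numpy, since numpy has no native bfloat16 dtype and the float32 cast keeps the tolerance-based comparison meaningful. A minimal sketch of that comparison pattern, assuming `params_a` and `params_b` are lists of Paddle tensors (possibly bfloat16); the helper name and default tolerances are illustrative, not taken from the test:

```python
import numpy as np


def assert_params_allclose(params_a, params_b, rtol=1e-4, atol=1e-3):
    # Cast to float32 first so bfloat16 parameters become a dtype numpy
    # understands, then compare within the given tolerances.
    for a, b in zip(params_a, params_b):
        np.testing.assert_allclose(
            a.astype("float32").numpy(),
            b.astype("float32").numpy(),
            rtol=rtol,
            atol=atol,
        )
```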
@@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.distributed.utils.nccl_utils import get_nccl_version_str
+from paddle.fluid import core
 from paddle.nn import Linear

 epoch = 10
@@ -214,22 +214,16 @@ def test_stage3_offload():
     )

     # bfp16 offload
-    # NOTE: this is a hack to get int format nccl version, like 2134
-    # if current platform is not linux, version number will be 0
-    nccl_version_str = get_nccl_version_str()
-    nccl_version = (
-        int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
-    )
-
-    if nccl_version >= 2100:
+    nccl_version = core.nccl_version()
+    if nccl_version >= 21000:
         stage3_params = train_mlp(mlp7, use_pure_fp16=True, use_bfp16=True)
         stage3_params_offload = train_mlp(
             mlp8, use_pure_fp16=True, offload=True, use_bfp16=True
         )
         for i in range(len(stage3_params)):
             np.testing.assert_allclose(
-                stage3_params[i].numpy(),
-                stage3_params_offload[i].numpy(),
+                stage3_params[i].astype("float32").numpy(),
+                stage3_params_offload[i].astype("float32").numpy(),
                 rtol=1e-2,
                 atol=1e-2,
             )
......
@@ -55,7 +55,7 @@ class TestCollectiveAllgatherAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
@@ -118,7 +118,7 @@ class TestCollectiveAllgatherAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -46,7 +46,7 @@ class TestCollectiveAllreduceAPI(TestDistBase):
         red_types_to_test = [
             dist.ReduceOp.SUM,
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            for red_type in red_types_to_test:
@@ -107,7 +107,7 @@ class TestCollectiveAllreduceAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -39,7 +39,7 @@ class TestCollectiveAllToAllAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -32,7 +32,7 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -43,7 +43,7 @@ class TestCollectiveBroadcastAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
@@ -92,7 +92,7 @@ class TestCollectiveBroadcastAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -36,7 +36,7 @@ class TestCollectiveGatherAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -32,7 +32,7 @@ class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -44,7 +44,7 @@ class TestCollectiveReduceAPI(TestDistBase):
         red_types_to_test = [
             dist.ReduceOp.SUM,
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            if paddle.fluid.core.is_compiled_with_cuda():
@@ -102,7 +102,7 @@ class TestCollectiveReduceAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -32,7 +32,7 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
@@ -54,7 +54,7 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -44,7 +44,7 @@ class TestCollectiveScatterAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -41,7 +41,7 @@ class TestCollectiveSendRecvAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            if paddle.fluid.core.is_compiled_with_cuda():
@@ -64,7 +64,7 @@ class TestCollectiveSendRecvAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
            dtypes_to_test.append("bfloat16")
        for dtype in dtypes_to_test:
            self.check_with_place(
......
@@ -27,7 +27,6 @@ from eager_op_test import convert_float_to_uint16, convert_uint16_to_float
 import paddle
 import paddle.distributed as dist
 from paddle import fluid
-from paddle.distributed.utils.nccl_utils import get_nccl_version_str
 from paddle.fluid import core
@@ -194,10 +193,7 @@ class TestDistBase(unittest.TestCase):
         # NOTE: this is a hack to get int format nccl version, like 2134
         # if current platform is not linux, version number will be 0
-        nccl_version_str = get_nccl_version_str()
-        self._nccl_version = (
-            int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
-        )
+        self._nccl_version = core.nccl_version()

     def tearDown(self):
         self.temp_dir.cleanup()
......
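Every collective API test touched by this commit gates its bfloat16 case the same way: read `core.nccl_version()` once (stored as `self._nccl_version` in `setUp`) and only append "bfloat16" when the code is at least 21000, i.e. NCCL v2.10.0 or newer. A condensed sketch of that pattern, assuming a CUDA build of Paddle where `core.nccl_version()` is available; the function name is hypothetical and the dtype list is abridged:

```python
from paddle.fluid import core


def dtypes_to_test_for_collectives():
    # Abridged dtype list; the real tests enumerate more dtypes.
    dtypes_to_test = ["float16", "float32", "int32", "bool"]
    # core.nccl_version() returns an integer code such as 21000 for v2.10.0,
    # so bfloat16 cases only run against NCCL 2.10.0 or newer.
    if core.nccl_version() >= 21000:
        dtypes_to_test.append("bfloat16")
    return dtypes_to_test
```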