Unverified commit 4fa8a676 authored by H Haohongxiang, committed by GitHub

[Dygraph] Fix NCCL_VERSION Check Tools (#53990)

Parent dc3c0de1
......@@ -12,51 +12,42 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import logging
-import subprocess
+from paddle.fluid import core


-def get_nccl_version_str():
-    nccl_version_str = subprocess.check_output(
-        r"ldconfig -v | grep 'libnccl.so' | tail -n1 | sed -r 's/^.*\.so\.//'",
-        stderr=subprocess.DEVNULL,
-        shell=True,
-    ).decode('utf-8')
-
-    # NOTE: This is a hacking method to get nccl version, but it will return None
-    # if current platform is not Linux. So we only check nccl version for Linux
-    # platform while training with pipeline parallelism.
-    if nccl_version_str:
-        nccl_version_str = nccl_version_str.replace("\n", "")
-
-    return nccl_version_str
+def get_nccl_version_str(ver):
+    if ver >= 10000:
+        NCCL_MAJOR_VERSION = int(ver // 10000)
+        ver = ver % 10000
+    else:
+        NCCL_MAJOR_VERSION = int(ver // 1000)
+        ver = ver % 1000
+
+    NCCL_MINOR_VERSION = int(ver // 100)
+    NCCL_PATCH_VERSION = int(ver % 100)
+
+    return "{}.{}.{}".format(
+        NCCL_MAJOR_VERSION, NCCL_MINOR_VERSION, NCCL_PATCH_VERSION
+    )


 def check_nccl_version_for_p2p():
-    nccl_version_str = get_nccl_version_str()
-    if nccl_version_str:
-        nccl_version_str = nccl_version_str.replace("\n", "")
-        nccl_version_int = [int(s) for s in nccl_version_str.split(".")]
-        nccl_version_baseline = [2, 8, 4]
-        assert nccl_version_int >= nccl_version_baseline, (
-            "The version of NCCL is required to be at least v2.8.4 while training with "
-            "pipeline/MoE parallelism, but we found v{}. The previous version of NCCL has "
-            "some bugs in p2p communication, and you can see more detailed description "
-            "about this issue from ReleaseNotes of NCCL v2.8.4 "
-            "(https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4).".format(
-                nccl_version_str
-            )
-        )
-    else:
-        logging.warning("No version for NCCL library found!")
+    nccl_version = core.nccl_version()
+    nccl_version_str = get_nccl_version_str(nccl_version)
+    nccl_version_baseline = 2804
+    assert nccl_version >= nccl_version_baseline, (
+        "The version of NCCL is required to be at least v2.8.4 while training with "
+        "pipeline/MoE parallelism, but we found v{}. The previous version of NCCL has "
+        "some bugs in p2p communication, and you can see more detailed description "
+        "about this issue from ReleaseNotes of NCCL v2.8.4 "
+        "(https://docs.nvidia.com/deeplearning/nccl/release-notes/rel_2-8-4.html#rel_2-8-4).".format(
+            nccl_version_str
+        )
+    )


 def check_nccl_version_for_bf16():
-    nccl_version_str = get_nccl_version_str()
-    if nccl_version_str:
-        nccl_version_str = nccl_version_str.replace("\n", "")
-        nccl_version_int = [int(s) for s in nccl_version_str.split(".")]
-        nccl_version_baseline = [2, 10, 0]
-        return nccl_version_int >= nccl_version_baseline
-    return False
+    nccl_version = core.nccl_version()
+    nccl_version_baseline = 21000
+    return nccl_version >= nccl_version_baseline
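For reference, a small worked example of the conversion implemented by get_nccl_version_str above, together with the two baselines used in this file. The arithmetic mirrors NCCL's version code (major*1000 + minor*100 + patch for older releases, major*10000 + minor*100 + patch for newer ones, which is what the >= 10000 branch handles); the sample version codes are illustrative only.

from paddle.distributed.utils.nccl_utils import get_nccl_version_str

# 2804  -> "2.8.4"   (p2p baseline used by check_nccl_version_for_p2p)
# 21000 -> "2.10.0"  (bf16 baseline used by check_nccl_version_for_bf16)
# 21205 -> "2.12.5"  (arbitrary illustrative value)
for code in (2804, 21000, 21205):
    print(code, "->", get_nccl_version_str(code))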
......@@ -33,7 +33,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.distributed.utils.nccl_utils import get_nccl_version_str
+from paddle.fluid import core
 from paddle.nn import Linear

 epoch = 10
......@@ -119,7 +119,7 @@ class RandomDataset(paddle.io.Dataset):
 def optimizer_setting(model, use_pure_fp16, opt_group=False):
     clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
-    optimizer = paddle.optimizer.Momentum(
+    optimizer = paddle.optimizer.AdamW(
         parameters=[{"params": list(model.parameters())}]
         if opt_group
         else list(model.parameters()),
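For context, a minimal, self-contained sketch of the optimizer_setting pattern above: AdamW with global-norm gradient clipping. The Linear layer sizes and the learning_rate value here are placeholders, not taken from the test file.

import paddle
from paddle.nn import Linear

model = Linear(10, 10)
clip = paddle.nn.ClipGradByGlobalNorm(clip_norm=1.0)
# AdamW with global-norm gradient clipping, mirroring the updated helper.
optimizer = paddle.optimizer.AdamW(
    learning_rate=0.001,
    parameters=list(model.parameters()),
    grad_clip=clip,
)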
......@@ -364,14 +364,9 @@ def test_stage2_stage3():
     )
     # bfp16
-    # NOTE: this is a hack to get int format nccl version, like 2134
-    # if current platform is not linux, version number will be 0
-    nccl_version_str = get_nccl_version_str()
-    nccl_version = (
-        int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
-    )
-    if nccl_version >= 2100:
+    nccl_version = core.nccl_version()
+    if nccl_version >= 21000:
         stage2_params = train_mlp(
             mlp11,
             sharding_stage=2,
......@@ -388,8 +383,8 @@ def test_stage2_stage3():
         )
         for i in range(len(stage2_params)):
             np.testing.assert_allclose(
-                stage2_params[i].numpy(),
-                stage3_params[i].numpy(),
+                stage2_params[i].astype("float32").numpy(),
+                stage3_params[i].astype("float32").numpy(),
                 rtol=1e-4,
                 atol=1e-3,
             )
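The .astype("float32") casts above matter because NumPy has no native bfloat16 dtype, so bf16 parameters are upcast before comparison. A minimal sketch of that pattern follows; the tensors are placeholders, and it assumes this Paddle build supports casting to and from bfloat16.

import numpy as np
import paddle

# Placeholder bf16 tensors standing in for two parameter copies.
p = paddle.ones([4], dtype="float32").astype("bfloat16")
q = paddle.ones([4], dtype="float32").astype("bfloat16")

# Upcast to float32 so NumPy can represent the values, then compare.
np.testing.assert_allclose(
    p.astype("float32").numpy(),
    q.astype("float32").numpy(),
    rtol=1e-4,
    atol=1e-3,
)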
......
......@@ -24,7 +24,7 @@ from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_stage3 import
 from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_utils import (
     GroupShardedScaler,
 )
-from paddle.distributed.utils.nccl_utils import get_nccl_version_str
+from paddle.fluid import core
 from paddle.nn import Linear

 epoch = 10
......@@ -214,22 +214,16 @@ def test_stage3_offload():
     )
     # bfp16 offload
-    # NOTE: this is a hack to get int format nccl version, like 2134
-    # if current platform is not linux, version number will be 0
-    nccl_version_str = get_nccl_version_str()
-    nccl_version = (
-        int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
-    )
-    if nccl_version >= 2100:
+    nccl_version = core.nccl_version()
+    if nccl_version >= 21000:
         stage3_params = train_mlp(mlp7, use_pure_fp16=True, use_bfp16=True)
         stage3_params_offload = train_mlp(
             mlp8, use_pure_fp16=True, offload=True, use_bfp16=True
         )
         for i in range(len(stage3_params)):
             np.testing.assert_allclose(
-                stage3_params[i].numpy(),
-                stage3_params_offload[i].numpy(),
+                stage3_params[i].astype("float32").numpy(),
+                stage3_params_offload[i].astype("float32").numpy(),
                 rtol=1e-2,
                 atol=1e-2,
             )
......
......@@ -55,7 +55,7 @@ class TestCollectiveAllgatherAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......@@ -118,7 +118,7 @@ class TestCollectiveAllgatherAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -46,7 +46,7 @@ class TestCollectiveAllreduceAPI(TestDistBase):
         red_types_to_test = [
             dist.ReduceOp.SUM,
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             for red_type in red_types_to_test:
......@@ -107,7 +107,7 @@ class TestCollectiveAllreduceAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -39,7 +39,7 @@ class TestCollectiveAllToAllAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -32,7 +32,7 @@ class TestCollectiveAllToAllSingleAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -43,7 +43,7 @@ class TestCollectiveBroadcastAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......@@ -92,7 +92,7 @@ class TestCollectiveBroadcastAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -36,7 +36,7 @@ class TestCollectiveGatherAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -32,7 +32,7 @@ class TestCollectiveIsendIrecvAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -44,7 +44,7 @@ class TestCollectiveReduceAPI(TestDistBase):
         red_types_to_test = [
             dist.ReduceOp.SUM,
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             if paddle.fluid.core.is_compiled_with_cuda():
......@@ -102,7 +102,7 @@ class TestCollectiveReduceAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -32,7 +32,7 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......@@ -54,7 +54,7 @@ class TestCollectiveReduceScatterAPI(test_base.TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -44,7 +44,7 @@ class TestCollectiveScatterAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -41,7 +41,7 @@ class TestCollectiveSendRecvAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             if paddle.fluid.core.is_compiled_with_cuda():
......@@ -64,7 +64,7 @@ class TestCollectiveSendRecvAPI(TestDistBase):
             "uint8",
             "bool",
         ]
-        if self._nccl_version >= 2100:
+        if self._nccl_version >= 21000:
             dtypes_to_test.append("bfloat16")
         for dtype in dtypes_to_test:
             self.check_with_place(
......
......@@ -27,7 +27,6 @@ from eager_op_test import convert_float_to_uint16, convert_uint16_to_float
 import paddle
 import paddle.distributed as dist
 from paddle import fluid
-from paddle.distributed.utils.nccl_utils import get_nccl_version_str
 from paddle.fluid import core
......@@ -194,10 +193,7 @@ class TestDistBase(unittest.TestCase):
-        # NOTE: this is a hack to get int format nccl version, like 2134
-        # if current platform is not linux, version number will be 0
-        nccl_version_str = get_nccl_version_str()
-        self._nccl_version = (
-            int("".join(nccl_version_str.split("."))) if nccl_version_str else 0
-        )
+        self._nccl_version = core.nccl_version()

     def tearDown(self):
         self.temp_dir.cleanup()
......
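Taken together, a minimal sketch of how the collective tests use this value to gate bfloat16 coverage. The shortened dtype list is illustrative rather than the exact list from the tests, and the check_with_place call is replaced by a print.

from paddle.fluid import core

# Integer NCCL version code reported by the Paddle build,
# e.g. 2804 for v2.8.4 or 21000 for v2.10.0.
nccl_version = core.nccl_version()

dtypes_to_test = [
    "float32",
    "uint8",
    "bool",
]
# bfloat16 collectives require NCCL >= 2.10.0, i.e. a version code of at least 21000.
if nccl_version >= 21000:
    dtypes_to_test.append("bfloat16")

for dtype in dtypes_to_test:
    print(dtype)  # placeholder for self.check_with_place(...)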