Unverified commit e7547ca7, authored by Yuang Liu, committed by GitHub

Pass NVIDIA_TF32_OVERRIDE to internal (#43646) (#44796)

Co-authored-by: gongweibao <gongweibao@baidu.com>
Parent 6de20581
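The change below forwards the NVIDIA_TF32_OVERRIDE environment variable from the test runner's own environment into the env dict passed to the spawned trainer processes, so both trainers see the same TF32 policy as the parent CI job. A minimal sketch of the same propagation pattern, outside the Paddle test harness, is shown here; launch_trainer and extra_env are illustrative names, not part of this commit:

# Minimal sketch (not the Paddle harness itself): forward NVIDIA_TF32_OVERRIDE
# from the parent environment into a spawned trainer process.
import os
import subprocess

def launch_trainer(cmd, extra_env=None):
    # 'launch_trainer' and 'extra_env' are hypothetical names for illustration.
    env = dict(os.environ)
    if extra_env:
        env.update(extra_env)
    # Only forward the variable when the caller has actually set it, so the
    # child process inherits the same TF32 policy as the parent.
    tf32 = os.getenv('NVIDIA_TF32_OVERRIDE')
    if tf32 is not None:
        env['NVIDIA_TF32_OVERRIDE'] = tf32
    return subprocess.Popen(cmd, stdout=subprocess.PIPE, env=env)

In the commit itself the same idea is applied to required_envs before self._run_cluster is called, as the diff below shows.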
@@ -32,6 +32,7 @@ from paddle.fluid import core


 class TestCollectiveAPIRunnerBase(object):
+
     def get_model(self, train_prog, startup_prog, rank, indata=None):
         raise NotImplementedError(
             "get model should be implemented by child class.")
@@ -91,6 +92,7 @@ from contextlib import closing


 class TestDistBase(unittest.TestCase):
+
     def setUp(self):
         self._port_set = set()
         self._trainers = 2
@@ -104,6 +106,7 @@ class TestDistBase(unittest.TestCase):
         self.temp_dir.cleanup()

     def _find_free_port(self):
+
         def __free_port():
             with closing(socket.socket(socket.AF_INET,
                                        socket.SOCK_STREAM)) as s:
@@ -168,17 +171,15 @@ class TestDistBase(unittest.TestCase):
         tr0_pipe = open(path0, "w")
         tr1_pipe = open(path1, "w")
         #print(tr0_cmd)
-        tr0_proc = subprocess.Popen(
-            tr0_cmd.strip().split(),
-            stdout=subprocess.PIPE,
-            stderr=tr0_pipe,
-            env=env0)
-        tr1_proc = subprocess.Popen(
-            tr0_cmd.strip().split(),
-            stdout=subprocess.PIPE,
-            stderr=tr1_pipe,
-            env=env1)
+        tr0_proc = subprocess.Popen(tr0_cmd.strip().split(),
+                                    stdout=subprocess.PIPE,
+                                    stderr=tr0_pipe,
+                                    env=env0)
+        tr1_proc = subprocess.Popen(tr0_cmd.strip().split(),
+                                    stdout=subprocess.PIPE,
+                                    stderr=tr1_pipe,
+                                    env=env1)
         tr0_out, tr0_err = tr0_proc.communicate()
         tr1_out, tr1_err = tr1_proc.communicate()
@@ -220,8 +221,14 @@ class TestDistBase(unittest.TestCase):
             required_envs["GLOG_v"] = "3"
             required_envs["GLOG_logtostderr"] = "1"
             required_envs["GLOO_LOG_LEVEL"] = "TRACE"
-        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(model_file,
-                                                         required_envs)
+
+        if os.getenv('NVIDIA_TF32_OVERRIDE', '') is not None:
+            required_envs['NVIDIA_TF32_OVERRIDE'] = os.getenv(
+                'NVIDIA_TF32_OVERRIDE', '')
+
+        tr0_out, tr1_out, pid0, pid1 = self._run_cluster(
+            model_file, required_envs)
+
         np.random.seed(pid0)
         input1 = np.random.random((10, 1000))
         np.random.seed(pid1)
@@ -248,11 +255,9 @@ class TestDistBase(unittest.TestCase):
         elif col_type == "allreduce":
             need_result = input1 + input2
             self.assertTrue(
-                np.allclose(
-                    tr0_out, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out, need_result, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "parallel_embedding":
             result_data = tr0_out[0]
             np.random.seed(2020)
@@ -260,24 +265,23 @@ class TestDistBase(unittest.TestCase):
             for i in range(result_data.shape[0]):
                 for j in range(result_data.shape[1]):
                     data = result_data[i][j]
-                    assert np.allclose(
-                        tr0_out[1][i][j], need_result[data], atol=1e-08)
+                    assert np.allclose(tr0_out[1][i][j],
+                                       need_result[data],
+                                       atol=1e-08)
         elif col_type == "row_parallel_linear":
             result_data = tr0_out[0]
             np.random.seed(2020)
             weight = np.random.rand(1000, 16)
             need_result = np.matmul(input1, weight)
             self.assertTrue(
-                np.allclose(
-                    result_data, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(result_data, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "column_parallel_linear":
             result_data = tr0_out[0]
             np.random.seed(2020)
             weight = np.random.rand(1000, 16)
             need_result = np.matmul(input1, weight)
             self.assertTrue(
-                np.allclose(
-                    result_data, need_result, rtol=1e-05, atol=1e-05))
+                np.allclose(result_data, need_result, rtol=1e-05, atol=1e-05))
         elif col_type == "alltoall":
             need_result1 = np.vstack((input1[0:input1.shape[0] // 2, :],
                                       input2[0:input2.shape[0] // 2, :]))
@@ -286,16 +290,13 @@ class TestDistBase(unittest.TestCase):
             tr0_out = np.vstack(tr0_out)
             tr1_out = np.vstack(tr1_out)
             self.assertTrue(
-                np.allclose(
-                    tr0_out, need_result1, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out, need_result1, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out, need_result2, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out, need_result2, rtol=1e-05, atol=1e-05))
         elif col_type == "sendrecv":
             result_data = tr1_out[0]
             self.assertTrue(
-                np.allclose(
-                    input1, result_data, rtol=1e-05, atol=1e-05))
+                np.allclose(input1, result_data, rtol=1e-05, atol=1e-05))
         elif col_type == "global_gather":
             in_feat = 2
             n_expert = 2
@@ -372,15 +373,13 @@ class TestDistBase(unittest.TestCase):
             if result1 == []:
                 output1 = np.array([])
             else:
-                output1 = np.concatenate(
-                    result1, axis=0).reshape(
-                        sum(local_expert_count1), in_feat)
+                output1 = np.concatenate(result1, axis=0).reshape(
+                    sum(local_expert_count1), in_feat)
             if result2 == []:
                 output2 = np.array([])
             else:
-                output2 = np.concatenate(
-                    result2, axis=0).reshape(
-                        sum(local_expert_count2), in_feat)
+                output2 = np.concatenate(result2, axis=0).reshape(
+                    sum(local_expert_count2), in_feat)

             if tr0_out[0] is None or tr0_out[0].shape[0] == 0:
                 tr0_out[0] = np.array([])
@@ -389,24 +388,20 @@ class TestDistBase(unittest.TestCase):
                 tr1_out[0] = np.array([])

             self.assertTrue(
-                np.allclose(
-                    tr0_out[0], output1, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out[0], output1, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out[0], output2, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out[0], output2, rtol=1e-05, atol=1e-05))
             if static_mode == 0:
                 self.assertTrue(
-                    np.allclose(
-                        tr0_out[1],
-                        2 * local_input_buf1,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr0_out[1],
+                                2 * local_input_buf1,
+                                rtol=1e-05,
+                                atol=1e-05))
                 self.assertTrue(
-                    np.allclose(
-                        tr1_out[1],
-                        2 * local_input_buf2,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr1_out[1],
+                                2 * local_input_buf2,
+                                rtol=1e-05,
+                                atol=1e-05))

         elif col_type == "global_scatter":
             np.random.seed(pid0)
@@ -460,23 +455,19 @@ class TestDistBase(unittest.TestCase):
                 tr1_out[0] = np.array([])

             self.assertTrue(
-                np.allclose(
-                    tr0_out[0], output1, rtol=1e-05, atol=1e-05))
+                np.allclose(tr0_out[0], output1, rtol=1e-05, atol=1e-05))
             self.assertTrue(
-                np.allclose(
-                    tr1_out[0], output2, rtol=1e-05, atol=1e-05))
+                np.allclose(tr1_out[0], output2, rtol=1e-05, atol=1e-05))
             if static_mode == 0:
                 self.assertTrue(
-                    np.allclose(
-                        tr0_out[1],
-                        2 * local_input_buf1,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr0_out[1],
+                                2 * local_input_buf1,
+                                rtol=1e-05,
+                                atol=1e-05))
                 self.assertTrue(
-                    np.allclose(
-                        tr1_out[1],
-                        2 * local_input_buf2,
-                        rtol=1e-05,
-                        atol=1e-05))
+                    np.allclose(tr1_out[1],
+                                2 * local_input_buf2,
+                                rtol=1e-05,
+                                atol=1e-05))
         else:
             pass
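For context on why the variable is threaded through at all: setting NVIDIA_TF32_OVERRIDE=0 tells the CUDA math libraries to disable TF32 and compute in full FP32, which keeps the matmul-based checks above within the rtol/atol of 1e-05 used by np.allclose on GPUs that would otherwise use TF32. A hedged usage sketch follows; the test module name is illustrative only, not taken from this commit:

# Hedged sketch: run one of the collective tests with TF32 explicitly disabled.
# 'test_collective_allreduce_api.py' is an illustrative file name.
import os
import subprocess
import sys

env = dict(os.environ)
env['NVIDIA_TF32_OVERRIDE'] = '0'  # 0 disables TF32 in cuBLAS/cuDNN
subprocess.check_call([sys.executable, 'test_collective_allreduce_api.py'], env=env)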