提交 412ad816 编写于 作者: X Xin Pan

keep seed in dist transpiler

better test
上级 02c31458
...@@ -19,6 +19,7 @@ import math ...@@ -19,6 +19,7 @@ import math
import unittest import unittest
import os import os
import sys
import signal import signal
import subprocess import subprocess
...@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
except os.error: except os.error:
retry_times -= 1 retry_times -= 1
def no_test_with_place(self): def test_with_place(self):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN # *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs = { required_envs = {
"PATH": os.getenv("PATH"), "PATH": os.getenv("PATH"),
...@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase):
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \ local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
(self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1) (self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
local_proc = subprocess.Popen( local_proc = subprocess.Popen(
local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local) local_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_local)
local_proc.wait() local_proc.wait()
local_ret = local_proc.stdout.read() out, err = local_proc.communicate()
local_ret = out
sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err)
# Run dist train to compare with local results # Run dist train to compare with local results
ps0, ps1 = self.start_pserver() ps0, ps1 = self.start_pserver()
...@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase): ...@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase):
FNULL = open(os.devnull, 'w') FNULL = open(os.devnull, 'w')
tr0_proc = subprocess.Popen( tr0_proc = subprocess.Popen(
tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0) tr0_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env0)
tr1_proc = subprocess.Popen( tr1_proc = subprocess.Popen(
tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1) tr1_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env1)
tr0_proc.wait() tr0_proc.wait()
tr1_proc.wait() tr1_proc.wait()
loss_data0 = tr0_proc.stdout.read() out, err = tr0_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err)
loss_data0 = out
sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n") lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0] dist_first_loss = eval(lines[0].replace(" ", ","))[0]
dist_last_loss = eval(lines[1].replace(" ", ","))[0] dist_last_loss = eval(lines[1].replace(" ", ","))[0]
......
...@@ -347,6 +347,7 @@ class DistributeTranspiler(object): ...@@ -347,6 +347,7 @@ class DistributeTranspiler(object):
# step1 # step1
pserver_program = Program() pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
# step2: Create vars to receive vars at parameter servers. # step2: Create vars to receive vars at parameter servers.
recv_inputs = [] recv_inputs = []
for v in self.param_grad_ep_mapping[endpoint]["params"]: for v in self.param_grad_ep_mapping[endpoint]["params"]:
...@@ -544,6 +545,7 @@ class DistributeTranspiler(object): ...@@ -544,6 +545,7 @@ class DistributeTranspiler(object):
""" """
s_prog = Program() s_prog = Program()
orig_s_prog = default_startup_program() orig_s_prog = default_startup_program()
s_prog.random_seed = orig_s_prog.random_seed
params = self.param_grad_ep_mapping[endpoint]["params"] params = self.param_grad_ep_mapping[endpoint]["params"]
def _get_splited_name_and_shape(varname): def _get_splited_name_and_shape(varname):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册