未验证 提交 19ef4bab 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #12418 from panyx0718/fix_dist_seed

Fix dist seed
......@@ -19,6 +19,7 @@ import math
import unittest
import os
import sys
import signal
import subprocess
......@@ -56,7 +57,7 @@ class TestDistSeResneXt2x2(unittest.TestCase):
except os.error:
retry_times -= 1
def no_test_with_place(self):
def test_with_place(self):
# *ATTENTION* THIS TEST NEEDS AT LEAST 2GPUS TO RUN
required_envs = {
"PATH": os.getenv("PATH"),
......@@ -70,9 +71,15 @@ class TestDistSeResneXt2x2(unittest.TestCase):
local_cmd = "%s dist_se_resnext.py trainer %s 0 %s %d FLASE" % \
(self._python_interp, "127.0.0.1:1234", "127.0.0.1:1234", 1)
local_proc = subprocess.Popen(
local_cmd.split(" "), stdout=subprocess.PIPE, env=env_local)
local_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env_local)
local_proc.wait()
local_ret = local_proc.stdout.read()
out, err = local_proc.communicate()
local_ret = out
sys.stderr.write('local_loss: %s\n' % local_ret)
sys.stderr.write('local_stderr: %s\n' % err)
# Run dist train to compare with local results
ps0, ps1 = self.start_pserver()
......@@ -92,13 +99,22 @@ class TestDistSeResneXt2x2(unittest.TestCase):
FNULL = open(os.devnull, 'w')
tr0_proc = subprocess.Popen(
tr0_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env0)
tr0_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env0)
tr1_proc = subprocess.Popen(
tr1_cmd.split(" "), stdout=subprocess.PIPE, stderr=FNULL, env=env1)
tr1_cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=env1)
tr0_proc.wait()
tr1_proc.wait()
loss_data0 = tr0_proc.stdout.read()
out, err = tr0_proc.communicate()
sys.stderr.write('dist_stderr: %s\n' % err)
loss_data0 = out
sys.stderr.write('dist_loss: %s\n' % loss_data0)
lines = loss_data0.split("\n")
dist_first_loss = eval(lines[0].replace(" ", ","))[0]
dist_last_loss = eval(lines[1].replace(" ", ","))[0]
......
......@@ -347,6 +347,7 @@ class DistributeTranspiler(object):
# step1
pserver_program = Program()
pserver_program.random_seed = self.origin_program.random_seed
# step2: Create vars to receive vars at parameter servers.
recv_inputs = []
for v in self.param_grad_ep_mapping[endpoint]["params"]:
......@@ -544,6 +545,7 @@ class DistributeTranspiler(object):
"""
s_prog = Program()
orig_s_prog = default_startup_program()
s_prog.random_seed = orig_s_prog.random_seed
params = self.param_grad_ep_mapping[endpoint]["params"]
def _get_splited_name_and_shape(varname):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册