Unverified commit d186e743, authored by Wu Yi, committed by GitHub

Refine dist ut (#14118)

* fix use_reader_alloc uts

* dist ut fixes test=develop

* update test=develop

* fix test for py3 test=develop
Parent: 06e508ab
@@ -90,8 +90,10 @@ class TestDistMnist2x2(TestDistRunnerBase):
         inference_program = fluid.default_main_program().clone()
         # Optimization
-        opt = fluid.optimizer.AdamOptimizer(
-            learning_rate=0.001, beta1=0.9, beta2=0.999)
+        # TODO(typhoonzero): fix distributed adam optimizer
+        # opt = fluid.optimizer.AdamOptimizer(
+        #     learning_rate=0.001, beta1=0.9, beta2=0.999)
+        opt = fluid.optimizer.Momentum(learning_rate=0.001, momentum=0.9)
         # Reader
         train_reader = paddle.batch(
...
@@ -22,6 +22,8 @@ import signal
 import subprocess
 import six
 import argparse
+import pickle
+import numpy as np
 import paddle.fluid as fluid
@@ -128,10 +130,15 @@ class TestDistRunnerBase(object):
             else:
                 return origin_batch

+        out_losses = []
         for _ in six.moves.xrange(RUN_STEP):
             loss, = exe.run(fetch_list=[avg_cost.name],
                             feed=feeder.feed(get_data()))
-            print(loss)
+            out_losses.append(loss[0])
+        if six.PY2:
+            print(pickle.dumps(out_losses))
+        else:
+            sys.stdout.buffer.write(pickle.dumps(out_losses))


 def runtime_main(test_class):
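The trainer now collects its per-step losses and emits them as one pickled list on stdout instead of printing each loss as text. pickle.dumps returns str on Python 2, so a plain print works there; on Python 3 it returns bytes, which must be written through sys.stdout.buffer to bypass text encoding. A minimal standalone sketch of the emitting side (loss values made up for illustration):

import sys
import pickle
import six

out_losses = [0.987, 0.954, 0.921]  # made-up per-step losses
if six.PY2:
    # str payload on py2, plain print is enough
    print(pickle.dumps(out_losses))
else:
    # bytes payload on py3, write to the raw binary buffer
    sys.stdout.buffer.write(pickle.dumps(out_losses))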
@@ -149,7 +156,7 @@ def runtime_main(test_class):
     parser.add_argument('--use_cuda', action='store_true')
     parser.add_argument('--use_reduce', action='store_true')
     parser.add_argument(
-        '--use_reader_alloc', action='store_true', required=False, default=True)
+        '--use_reader_alloc', action='store_true', required=False)
     parser.add_argument('--batch_size', required=False, type=int, default=2)
     parser.add_argument(
         '--batch_merge_repeat', required=False, type=int, default=1)
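This is the use_reader_alloc fix from the commit message: with action='store_true', argparse already defaults the destination to False, and pinning default=True made the flag True whether or not it was passed, so the non-allocating path could never be exercised. A small demonstration of the difference (the --always_on flag name is illustrative):

import argparse

parser = argparse.ArgumentParser()
# Buggy form: default=True means the value is True even when the
# flag is absent, so it can never be switched off.
parser.add_argument('--always_on', action='store_true', default=True)
# Fixed form: absent -> False, present -> True.
parser.add_argument('--use_reader_alloc', action='store_true', required=False)

args = parser.parse_args([])          # no flags passed
assert args.always_on is True         # cannot be turned off
assert args.use_reader_alloc is False

args = parser.parse_args(['--use_reader_alloc'])
assert args.use_reader_alloc is True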
@@ -237,21 +244,6 @@ class TestDistBase(unittest.TestCase):

         return ps0_proc, ps1_proc, ps0_pipe, ps1_pipe

-    def _wait_ps_ready(self, pid):
-        retry_times = 50
-        while True:
-            assert retry_times >= 0, "wait ps ready failed"
-            time.sleep(3)
-            try:
-                # the listen_and_serv_op would touch a file which contains the listen port
-                # on the /tmp directory until it was ready to process all the RPC call.
-                os.stat("/tmp/paddle.%d.port" % pid)
-                return
-            except os.error as e:
-                sys.stderr.write('waiting for pserver: %s, left retry %d\n' %
-                                 (e, retry_times))
-                retry_times -= 1
-
     def _run_local(self,
                    model,
                    envs,
@@ -288,23 +280,20 @@ class TestDistBase(unittest.TestCase):
             env=envs)

         local_out, local_err = local_proc.communicate()
-        local_ret = cpt.to_text(local_out)

         if check_error_log:
             err_log.close()

-        sys.stderr.write('local_stdout: %s\n' % local_ret)
+        sys.stderr.write('local_stdout: %s\n' % pickle.loads(local_out))
         sys.stderr.write('local_stderr: %s\n' % local_err)

-        local_losses = local_ret.split("\n")
-        return local_losses
+        return pickle.loads(local_out)

     def _run_cluster(self, model, envs, check_error_log):
         # Run dist train to compare with local results
         ps0, ps1, ps0_pipe, ps1_pipe = self.start_pserver(model,
                                                           check_error_log, envs)
-        self._wait_ps_ready(ps0.pid)
-        self._wait_ps_ready(ps1.pid)

         ps0_ep, ps1_ep = self._ps_endpoints.split(",")

         tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --current_endpoint %s --trainers %d --is_dist"
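On the harness side, each child process keeps the same wiring seen in _run_local and in the trainer hunks below: stderr is redirected to a log file under /tmp for post-mortem inspection, while stdout stays a pipe because it carries the pickled loss list. A sketch of that capture pattern (the trainer command is a placeholder):

import subprocess
import pickle

# stderr goes to a file we can read after a failure; stdout stays a
# pipe because it carries pickled bytes, not human-readable text.
err_log = open("/tmp/tr0_err.log", "wb")
proc = subprocess.Popen(
    ["python", "trainer.py", "--role", "trainer"],  # placeholder command
    stdout=subprocess.PIPE,
    stderr=err_log)
out, _ = proc.communicate()
err_log.close()

losses = pickle.loads(out)  # exact floats, no eval() of printed text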
@@ -339,8 +328,8 @@ class TestDistBase(unittest.TestCase):
         env0.update(envs)
         env1.update(envs)

-        print("tr0_cmd:{}, env0: {}".format(tr0_cmd, env0))
-        print("tr1_cmd:{}, env1: {}".format(tr1_cmd, env1))
+        print("tr0_cmd:{}".format(tr0_cmd))
+        print("tr1_cmd:{}".format(tr1_cmd))

         tr0_pipe = open("/tmp/tr0_err.log", "wb")
         tr1_pipe = open("/tmp/tr1_err.log", "wb")
@@ -356,9 +345,7 @@ class TestDistBase(unittest.TestCase):
             env=env1)

         tr0_out, tr0_err = tr0_proc.communicate()
-        tr0_loss_text = cpt.to_text(tr0_out)
         tr1_out, tr1_err = tr1_proc.communicate()
-        tr1_loss_text = cpt.to_text(tr1_out)

         # close trainer file
         tr0_pipe.close()
@@ -373,15 +360,13 @@ class TestDistBase(unittest.TestCase):
         ps1.terminate()

         # print log
-        sys.stderr.write('trainer 0 stdout:\n %s\n' % tr0_loss_text)
-        sys.stderr.write('trainer 0 stderr:\n %s\n' % tr0_err)
-        sys.stderr.write('trainer 1 stdout: %s\n' % tr1_loss_text)
+        sys.stderr.write('trainer 0 stdout: %s\n' % pickle.loads(tr0_out))
+        sys.stderr.write('trainer 0 stderr: %s\n' % tr0_err)
+        sys.stderr.write('trainer 1 stdout: %s\n' % pickle.loads(tr1_out))
         sys.stderr.write('trainer 1 stderr: %s\n' % tr1_err)

-        tr0_losses = tr0_loss_text.split("\n")
-        tr1_losses = tr1_loss_text.split("\n")
-        return tr0_losses, tr1_losses
+        # return tr0_losses, tr1_losses
+        return pickle.loads(tr0_out), pickle.loads(tr1_out)

     def check_with_place(self,
                          model_file,
@@ -411,9 +396,9 @@ class TestDistBase(unittest.TestCase):
                                                    check_error_log)

         for step_id in range(RUN_STEP):
-            local_loss = eval(local_losses[step_id])[0]
-            tr0_loss = eval(tr0_losses[step_id])[0]
-            tr1_loss = eval(tr1_losses[step_id])[0]
-            dist_loss = (tr0_loss + tr1_loss) / 2
-            print(str(local_loss) + ":" + str(dist_loss))
-            self.assertAlmostEqual(local_loss, dist_loss, delta=delta)
+            local_loss = local_losses[step_id]
+            tr0_loss = tr0_losses[step_id]
+            tr1_loss = tr1_losses[step_id]
+            dist_loss = (np.array([tr0_loss]) + np.array([tr1_loss])) / 2
+            print("=======", local_loss, ":", dist_loss[0], "=======")
+            self.assertAlmostEqual(local_loss, dist_loss[0], delta=delta)
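With pickled floats in hand, check_with_place no longer needs eval() on printed text: it averages the two trainers' losses per step and asserts the result is within delta of the single-process loss. A self-contained sketch of the same check with made-up numbers:

import unittest
import numpy as np

RUN_STEP = 2

class LossCompareSketch(unittest.TestCase):
    def test_dist_matches_local(self):
        local_losses = [0.500, 0.480]   # made-up single-process losses
        tr0_losses = [0.501, 0.479]     # made-up trainer 0 losses
        tr1_losses = [0.499, 0.481]     # made-up trainer 1 losses
        for step_id in range(RUN_STEP):
            # average the two trainers, then compare against local
            dist_loss = (np.array([tr0_losses[step_id]]) +
                         np.array([tr1_losses[step_id]])) / 2
            # delta semantics: |local - dist| must be <= delta
            self.assertAlmostEqual(
                local_losses[step_id], dist_loss[0], delta=1e-2)

if __name__ == '__main__':
    unittest.main()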
@@ -23,16 +23,17 @@ class TestDistSeResneXt2x2(TestDistBase):
         self._use_reader_alloc = False

     def test_dist_train(self):
-        self.check_with_place("dist_se_resnext.py", delta=100)
+        self.check_with_place("dist_se_resnext.py", delta=1e-7)


 class TestDistseResnXt2x2WithMemopt(TestDistBase):
     def _setup_config(self):
         self._sync_mode = True
         self._mem_opt = True
+        self._use_reader_alloc = False

     def test_dist_train(self):
-        self.check_with_place("dist_se_resnext.py", delta=100)
+        self.check_with_place("dist_se_resnext.py", delta=1e-7)


 class TestDistSeResneXt2x2Async(TestDistBase):
...