Commit ec773f90 authored by: yi.wu

fix ut merge error

Parent 1b79974a
......@@ -21,7 +21,7 @@ import sys
 import six
 import signal
 import subprocess
-import six
+import argparse
 
 class TestDistRunnerBase(object):
......@@ -30,7 +30,7 @@ class TestDistRunnerBase(object):
             "get_model should be implemented by child classes.")
 
     def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
-                       trainers):
+                       trainers, sync_mode):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         import paddle
         import paddle.fluid as fluid
......@@ -39,33 +39,35 @@ class TestDistRunnerBase(object):
             trainer_id=trainer_id,
             program=main_program,
             pservers=pserver_endpoints,
-            trainers=trainers)
+            trainers=trainers,
+            sync_mode=sync_mode)
         return t
 
-    def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
-                    trainer_id):
+    def run_pserver(self, args):
         import paddle
         import paddle.fluid as fluid
         self.get_model(batch_size=2)
-        t = self.get_transpiler(trainer_id,
-                                fluid.default_main_program(), pserver_endpoints,
-                                trainers)
-        pserver_prog = t.get_pserver_program(current_endpoint)
-        startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
+        t = self.get_transpiler(args.trainer_id,
+                                fluid.default_main_program(), args.endpoints,
+                                args.trainers, args.sync_mode)
+        pserver_prog = t.get_pserver_program(args.current_endpoint)
+        startup_prog = t.get_startup_program(args.current_endpoint,
+                                             pserver_prog)
         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup_prog)
         exe.run(pserver_prog)
 
-    def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
+    def run_trainer(self, place, args):
         import paddle
         import paddle.fluid as fluid
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
-        if is_dist:
-            t = self.get_transpiler(trainer_id,
-                                    fluid.default_main_program(), endpoints,
-                                    trainers)
+        if args.is_dist:
+            t = self.get_transpiler(args.trainer_id,
+                                    fluid.default_main_program(),
+                                    args.endpoints, args.trainers,
+                                    args.sync_mode)
             trainer_prog = t.get_trainer_program()
         else:
             trainer_prog = fluid.default_main_program()
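With this change both entry points read their configuration from a single argparse.Namespace, and sync_mode is threaded through to DistributeTranspiler.transpile, which selects between synchronous and asynchronous parameter-server training. A minimal sketch of driving run_pserver directly; the subclass name and endpoint values are illustrative, not part of this diff:

import argparse

# Hypothetical namespace carrying the fields the new code reads; real
# runs get this from the parser in runtime_main() instead.
args = argparse.Namespace(
    role="pserver",
    endpoints="127.0.0.1:6170,127.0.0.1:6171",  # comma-separated pserver list
    current_endpoint="127.0.0.1:6170",          # endpoint of this pserver
    trainer_id=0,
    trainers=2,
    is_dist=True,
    sync_mode=True)

model = TestDistMnist()  # hypothetical subclass implementing get_model()
model.run_pserver(args)  # blocks, serving parameter updates until killed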
......@@ -132,18 +134,21 @@ def runtime_main(test_class):
     args = parser.parse_args()
 
     model = test_class()
-    if role == "pserver":
-        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
+    if args.role == "pserver":
+        model.run_pserver(args)
     else:
         p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
-        model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
+        model.run_trainer(p, args)
 
 
+import paddle.compat as cpt
+
+
 class TestDistBase(unittest.TestCase):
     def _setup_config(self):
         raise NotImplementedError("tests should have _setup_config implemented")
 
     def setUp(self):
         self._trainers = 2
         self._pservers = 2
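The construction of parser sits above this hunk and is not shown; a sketch consistent with the attributes the new code reads (flag names are inferred from the args.* accesses, so treat them as assumptions):

import argparse

parser = argparse.ArgumentParser(description="Run a distributed unit test.")
parser.add_argument("--role", type=str, choices=["pserver", "trainer"])
parser.add_argument("--endpoints", type=str, default="")  # comma-separated pservers
parser.add_argument("--current_endpoint", type=str, default="")
parser.add_argument("--trainer_id", type=int, default=0)
parser.add_argument("--trainers", type=int, default=1)
parser.add_argument("--is_dist", action="store_true")
parser.add_argument("--sync_mode", action="store_true")
args = parser.parse_args()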
......@@ -221,9 +226,7 @@ class TestDistBase(unittest.TestCase):
         # Run local to get a base line
         env_local = {"CUDA_VISIBLE_DEVICES": "0"}
         env_local.update(required_envs)
-        local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \
-                    (self._python_interp, model_file,
-                     "127.0.0.1:1234", "127.0.0.1:1234", 1)
+        local_cmd = "%s %s --role trainer" % (self._python_interp, model_file)
         if not check_error_log:
             local_proc = subprocess.Popen(
                 local_cmd.split(" "),
......
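Under the new flag-based interface the distributed commands elsewhere in this file would spell each field out explicitly; an illustrative pserver launch mirroring the new local_cmd style (ports and counts are examples, not taken from this diff):

# Hypothetical distributed pserver command in the new --flag style.
ps_cmd = "%s %s --role pserver --endpoints %s --current_endpoint %s " \
         "--trainers %d --is_dist --sync_mode" % (
             self._python_interp, model_file,
             "127.0.0.1:6170,127.0.0.1:6171", "127.0.0.1:6170", 2)
ps_proc = subprocess.Popen(ps_cmd.split(" "), env=required_envs)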