Commit ec773f90 authored by yi.wu

fix ut merge error

Parent 1b79974a
@@ -21,7 +21,7 @@ import sys
 import six
 import signal
 import subprocess
-import six
+import argparse


 class TestDistRunnerBase(object):
@@ -30,7 +30,7 @@ class TestDistRunnerBase(object):
             "get_model should be implemented by child classes.")

     def get_transpiler(self, trainer_id, main_program, pserver_endpoints,
-                       trainers):
+                       trainers, sync_mode):
         # NOTE: import fluid until runtime, or else forking processes will cause error.
         import paddle
         import paddle.fluid as fluid
@@ -39,33 +39,35 @@ class TestDistRunnerBase(object):
             trainer_id=trainer_id,
             program=main_program,
             pservers=pserver_endpoints,
-            trainers=trainers)
+            trainers=trainers,
+            sync_mode=sync_mode)
         return t

-    def run_pserver(self, pserver_endpoints, trainers, current_endpoint,
-                    trainer_id):
+    def run_pserver(self, args):
         import paddle
         import paddle.fluid as fluid
         self.get_model(batch_size=2)
-        t = self.get_transpiler(trainer_id,
-                                fluid.default_main_program(), pserver_endpoints,
-                                trainers)
-        pserver_prog = t.get_pserver_program(current_endpoint)
-        startup_prog = t.get_startup_program(current_endpoint, pserver_prog)
+        t = self.get_transpiler(args.trainer_id,
+                                fluid.default_main_program(), args.endpoints,
+                                args.trainers, args.sync_mode)
+        pserver_prog = t.get_pserver_program(args.current_endpoint)
+        startup_prog = t.get_startup_program(args.current_endpoint,
+                                             pserver_prog)

         place = fluid.CPUPlace()
         exe = fluid.Executor(place)
         exe.run(startup_prog)
         exe.run(pserver_prog)

-    def run_trainer(self, place, endpoints, trainer_id, trainers, is_dist=True):
+    def run_trainer(self, place, args):
         import paddle
         import paddle.fluid as fluid
         test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
             self.get_model(batch_size=2)
-        if is_dist:
-            t = self.get_transpiler(trainer_id,
-                                    fluid.default_main_program(), endpoints,
-                                    trainers)
+        if args.is_dist:
+            t = self.get_transpiler(args.trainer_id,
+                                    fluid.default_main_program(),
+                                    args.endpoints, args.trainers,
+                                    args.sync_mode)
             trainer_prog = t.get_trainer_program()
         else:
             trainer_prog = fluid.default_main_program()
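The hunk above replaces the long positional parameter lists of run_pserver and run_trainer with a single argparse namespace. The parser itself lives in runtime_main, outside the hunks shown here; below is a minimal sketch of what it might look like, inferred from the attributes the new code reads (args.role, args.endpoints, args.current_endpoint, args.trainer_id, args.trainers, args.is_dist, args.sync_mode). Flag names, types, and defaults are assumptions, not the committed code.

import argparse

# Sketch only: a parser that would supply the args.* attributes used above.
parser = argparse.ArgumentParser(description="Run one worker of a dist test.")
parser.add_argument('--role', type=str, choices=['pserver', 'trainer'])
parser.add_argument('--endpoints', type=str, default='')       # comma-separated pserver endpoints
parser.add_argument('--current_endpoint', type=str, default='')
parser.add_argument('--trainer_id', type=int, default=0)
parser.add_argument('--trainers', type=int, default=1)
parser.add_argument('--is_dist', action='store_true')
parser.add_argument('--sync_mode', action='store_true')
args = parser.parse_args()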
@@ -132,18 +134,21 @@ def runtime_main(test_class):
     args = parser.parse_args()

     model = test_class()
-    if role == "pserver":
-        model.run_pserver(endpoints, trainers, current_endpoint, trainer_id)
+    if args.role == "pserver":
+        model.run_pserver(args)
     else:
         p = fluid.CUDAPlace(0) if core.is_compiled_with_cuda(
         ) else fluid.CPUPlace()
-        model.run_trainer(p, endpoints, trainer_id, trainers, is_dist)
+        model.run_trainer(p, args)


 import paddle.compat as cpt


 class TestDistBase(unittest.TestCase):
+    def _setup_config(self):
+        raise NotImplementedError("tests should have _setup_config implemented")
+
     def setUp(self):
         self._trainers = 2
         self._pservers = 2
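The new _setup_config hook makes configuration explicit: the base class raises NotImplementedError, so every concrete test must declare its own settings. A hypothetical subclass, for illustration only (the class name and the attribute it sets are assumptions, not part of this commit):

# Sketch only: a concrete test declaring its config via the new hook.
class TestDistExample(TestDistBase):
    def _setup_config(self):
        self._sync_mode = True   # assumed attribute; toggles sync/async training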
@@ -221,9 +226,7 @@ class TestDistBase(unittest.TestCase):
         # Run local to get a base line
         env_local = {"CUDA_VISIBLE_DEVICES": "0"}
         env_local.update(required_envs)
-        local_cmd = "%s %s trainer %s 0 %s %d FLASE" % \
-                    (self._python_interp, model_file,
-                     "127.0.0.1:1234", "127.0.0.1:1234", 1)
+        local_cmd = "%s %s --role trainer" % (self._python_interp, model_file)

         if not check_error_log:
             local_proc = subprocess.Popen(
                 local_cmd.split(" "),
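Note how the old positional command string (including the "FLASE" typo for a boolean flag) collapses into named flags. The distributed trainer commands are built outside the shown hunks; a sketch of what one might look like under the same flag style, with every flag name beyond --role and all values hypothetical:

# Sketch only: assembling a distributed trainer command with named flags.
python_interp = "python"              # stand-in for self._python_interp
model_file = "dist_test_model.py"     # hypothetical test entry script
tr_cmd = "%s %s --role trainer --endpoints %s --trainer_id %d --trainers %d --is_dist" % (
    python_interp, model_file, "127.0.0.1:9123,127.0.0.1:9124", 0, 2)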
...