提交 7bb4a4e8 编写于 作者: S seiriosPlus

rectification init_worker and exe.run startup program

上级 fef6f6f9
...@@ -37,12 +37,6 @@ class RecvOp : public framework::OperatorBase { ...@@ -37,12 +37,6 @@ class RecvOp : public framework::OperatorBase {
void RunImpl(const framework::Scope &scope, void RunImpl(const framework::Scope &scope,
const platform::Place &place) const override { const platform::Place &place) const override {
int do_not_run = Attr<int>("do_not_run");
if (do_not_run) {
VLOG(3) << "recv do not run!";
return;
}
std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap"); std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
std::vector<std::string> varnames = std::vector<std::string> varnames =
Attr<std::vector<std::string>>("varnames"); Attr<std::vector<std::string>>("varnames");
...@@ -63,11 +57,10 @@ class RecvOp : public framework::OperatorBase { ...@@ -63,11 +57,10 @@ class RecvOp : public framework::OperatorBase {
if (recv_varnames.size() > 0) { if (recv_varnames.size() > 0) {
auto *communicator = distributed::Communicator::GetInstance(); auto *communicator = distributed::Communicator::GetInstance();
if (communicator == nullptr) { if (communicator != nullptr) {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"need run fleet.init_worker first")); "execute startup program must before fleet.init_worker"));
} }
communicator->RecvNoBarrier();
} else { } else {
std::vector<distributed::VarHandlePtr> rets; std::vector<distributed::VarHandlePtr> rets;
if (with_barrier) { if (with_barrier) {
......
...@@ -216,12 +216,12 @@ class ParameterServerRuntime(RuntimeBase): ...@@ -216,12 +216,12 @@ class ParameterServerRuntime(RuntimeBase):
else: else:
model_dirname = None model_dirname = None
if self.role_maker._is_heter_worker():
self._init_worker()
executor = self._get_executor() executor = self._get_executor()
executor.run(fluid.default_startup_program()) executor.run(fluid.default_startup_program())
if self.role_maker._is_heter_worker():
self._init_worker()
if self.role_maker._is_heter_worker(): if self.role_maker._is_heter_worker():
return return
......
...@@ -191,12 +191,14 @@ class FleetTranspiler(Fleet): ...@@ -191,12 +191,14 @@ class FleetTranspiler(Fleet):
self._communicator = Communicator( self._communicator = Communicator(
trainer_config.mode, kwargs, trainer_config.mode, kwargs,
trainer_config.get_communicator_flags()) trainer_config.get_communicator_flags())
self._communicator.init_with_ctx(send_ctx, recv_ctx) self._communicator.init_with_ctx(send_ctx, recv_ctx)
if not self._communicator.is_running(): if not self._communicator.is_running():
self._communicator.start() self._communicator.start()
else: else:
warnings.warn("communicator has been initialized, skip") raise ValueError(
"Communicator can only be inited once, please check")
def init_worker(self): def init_worker(self):
""" """
......
...@@ -222,22 +222,22 @@ def append_send_ops_pass(program, config): ...@@ -222,22 +222,22 @@ def append_send_ops_pass(program, config):
def init_from_server_pass(program, config): def init_from_server_pass(program, config):
fetch_barrier_out = program.global_block().create_var( fetch_barrier_out = program.global_block().create_var(
name=framework.generate_control_dev_var_name()) name=framework.generate_control_dev_var_name())
#
# recv_ctx = config.get_communicator_recv_context(recv_type=1) recv_ctx = config.get_communicator_recv_context(recv_type=1)
# recv_varnames = [] recv_varnames = []
#
# for name, ctxs in recv_ctx.items(): for name, ctxs in recv_ctx.items():
# recv_varnames.extend(ctxs.origin_varnames()) recv_varnames.extend(ctxs.origin_varnames())
#
# program.global_block().append_op( program.global_block().append_op(
# type="recv", type="recv",
# inputs={"X": []}, inputs={"X": []},
# outputs={"Out": []}, outputs={"Out": []},
# attrs={ attrs={
# "recv_varnames": recv_varnames, "recv_varnames": recv_varnames,
# "trainer_id": config.get_role_id(), "trainer_id": config.get_role_id(),
# RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
# }) })
program.global_block().append_op( program.global_block().append_op(
type="fetch_barrier", type="fetch_barrier",
......
...@@ -164,8 +164,8 @@ def train(args): ...@@ -164,8 +164,8 @@ def train(args):
elif fleet.is_worker(): elif fleet.is_worker():
logger.info("run trainer") logger.info("run trainer")
fleet.init_worker()
exe.run(fleet.startup_program) exe.run(fleet.startup_program)
fleet.init_worker()
thread_num = 2 thread_num = 2
filelist = [] filelist = []
......
...@@ -161,8 +161,10 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -161,8 +161,10 @@ class TestDistCTR2x2(FleetDistRunnerBase):
""" """
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4 batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
self.reader.decorate_sample_list_generator(train_reader) self.reader.decorate_sample_list_generator(train_reader)
...@@ -201,8 +203,8 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -201,8 +203,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fleet.init_worker()
thread_num = 2 thread_num = 2
batch_size = 128 batch_size = 128
......
...@@ -60,8 +60,9 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2): ...@@ -60,8 +60,9 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
device_id = int(os.getenv("FLAGS_selected_gpus", "0")) device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
place = fluid.CUDAPlace(device_id) place = fluid.CUDAPlace(device_id)
exe = fluid.Executor(place) exe = fluid.Executor(place)
fleet.init_worker()
exe.run(fleet.startup_program) exe.run(fleet.startup_program)
fleet.init_worker()
batch_size = 4 batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
...@@ -104,8 +105,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2): ...@@ -104,8 +105,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
place = fluid.CUDAPlace(device_id) place = fluid.CUDAPlace(device_id)
exe = fluid.Executor(place) exe = fluid.Executor(place)
fleet.init_worker()
exe.run(fleet.startup_program) exe.run(fleet.startup_program)
fleet.init_worker()
thread_num = 2 thread_num = 2
batch_size = 128 batch_size = 128
......
...@@ -150,8 +150,9 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): ...@@ -150,8 +150,9 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
""" """
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4 batch_size = 4
train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size) train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
self.reader.decorate_sample_list_generator(train_reader) self.reader.decorate_sample_list_generator(train_reader)
...@@ -174,8 +175,8 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase): ...@@ -174,8 +175,8 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fleet.init_worker()
thread_num = 1 thread_num = 1
batch_size = 128 batch_size = 128
......
...@@ -151,8 +151,9 @@ class TestDistCTR2x2(FleetDistRunnerBase): ...@@ -151,8 +151,9 @@ class TestDistCTR2x2(FleetDistRunnerBase):
""" """
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fleet.init_worker()
batch_size = 4 batch_size = 4
......
...@@ -81,8 +81,8 @@ class TestCommunicatorGeoEnd2End(unittest.TestCase): ...@@ -81,8 +81,8 @@ class TestCommunicatorGeoEnd2End(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
fleet.init_worker()
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
fleet.init_worker()
train_reader = paddle.batch(self.fake_reader(), batch_size=24) train_reader = paddle.batch(self.fake_reader(), batch_size=24)
feeder = fluid.DataFeeder(place=place, feed_list=[x, z, y]) feeder = fluid.DataFeeder(place=place, feed_list=[x, z, y])
......
...@@ -69,8 +69,8 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase): ...@@ -69,8 +69,8 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
optimizer = fleet.distributed_optimizer(optimizer, strategy) optimizer = fleet.distributed_optimizer(optimizer, strategy)
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
fleet.init_worker()
exe.run(fleet.startup_program) exe.run(fleet.startup_program)
fleet.init_worker()
train_reader = paddle.batch(self.fake_reader(), batch_size=24) train_reader = paddle.batch(self.fake_reader(), batch_size=24)
feeder = fluid.DataFeeder(place=place, feed_list=[x, y]) feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册