Commit 7bb4a4e8 authored by seiriosPlus

Rectify the ordering of fleet.init_worker() and exe.run(startup_program)

Parent fef6f6f9
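In short, a trainer must now execute its startup program before calling fleet.init_worker(): running them in the old order makes the recv op throw "execute startup program must before fleet.init_worker", and initializing the communicator a second time raises a ValueError instead of only warning. Below is a minimal sketch of the corrected trainer-side sequence, mirroring the test changes in this diff; the import path and the run_trainer wrapper are illustrative assumptions based on the incubate fleet API of this era, not part of the commit.

# Sketch only: assumes a parameter-server role has already been set up via
# fleet.init(role) and the optimizer was wrapped with fleet.distributed_optimizer,
# as in the CTR tests changed below.
import paddle.fluid as fluid
from paddle.fluid.incubate.fleet.parameter_server.distribute_transpiler import fleet


def run_trainer():
    exe = fluid.Executor(fluid.CPUPlace())

    # 1) Run the startup program first. After this commit the recv op expects
    #    the communicator to not be running yet and throws
    #    "execute startup program must before fleet.init_worker" otherwise.
    exe.run(fluid.default_startup_program())

    # 2) Only then bring up the worker-side communicator.
    fleet.init_worker()

    # ... feed data and call exe.run(main_program) in the training loop ...

    fleet.stop_worker()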
@@ -37,12 +37,6 @@ class RecvOp : public framework::OperatorBase {
   void RunImpl(const framework::Scope &scope,
                const platform::Place &place) const override {
-    int do_not_run = Attr<int>("do_not_run");
-    if (do_not_run) {
-      VLOG(3) << "recv do not run!";
-      return;
-    }
     std::vector<std::string> epmap = Attr<std::vector<std::string>>("epmap");
     std::vector<std::string> varnames =
         Attr<std::vector<std::string>>("varnames");
@@ -63,11 +57,10 @@ class RecvOp : public framework::OperatorBase {
     if (recv_varnames.size() > 0) {
       auto *communicator = distributed::Communicator::GetInstance();
-      if (communicator == nullptr) {
+      if (communicator != nullptr) {
         PADDLE_THROW(platform::errors::InvalidArgument(
-            "need run fleet.init_worker first"));
+            "execute startup program must before fleet.init_worker"));
       }
-      communicator->RecvNoBarrier();
     } else {
       std::vector<distributed::VarHandlePtr> rets;
       if (with_barrier) {
......
@@ -216,12 +216,12 @@ class ParameterServerRuntime(RuntimeBase):
         else:
             model_dirname = None
-        if self.role_maker._is_heter_worker():
-            self._init_worker()
         executor = self._get_executor()
         executor.run(fluid.default_startup_program())
+        if self.role_maker._is_heter_worker():
+            self._init_worker()
         if self.role_maker._is_heter_worker():
             return
......
@@ -191,12 +191,14 @@ class FleetTranspiler(Fleet):
         self._communicator = Communicator(
             trainer_config.mode, kwargs,
             trainer_config.get_communicator_flags())
+        self._communicator.init_with_ctx(send_ctx, recv_ctx)
         if not self._communicator.is_running():
             self._communicator.start()
         else:
-            warnings.warn("communicator has been initialized, skip")
+            raise ValueError(
+                "Communicator can only be inited once, please check")
     def init_worker(self):
         """
......
@@ -222,22 +222,22 @@ def append_send_ops_pass(program, config):
 def init_from_server_pass(program, config):
     fetch_barrier_out = program.global_block().create_var(
         name=framework.generate_control_dev_var_name())
-    #
-    # recv_ctx = config.get_communicator_recv_context(recv_type=1)
-    # recv_varnames = []
-    #
-    # for name, ctxs in recv_ctx.items():
-    #     recv_varnames.extend(ctxs.origin_varnames())
-    #
-    # program.global_block().append_op(
-    #     type="recv",
-    #     inputs={"X": []},
-    #     outputs={"Out": []},
-    #     attrs={
-    #         "recv_varnames": recv_varnames,
-    #         "trainer_id": config.get_role_id(),
-    #         RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
-    #     })
+    recv_ctx = config.get_communicator_recv_context(recv_type=1)
+    recv_varnames = []
+    for name, ctxs in recv_ctx.items():
+        recv_varnames.extend(ctxs.origin_varnames())
+    program.global_block().append_op(
+        type="recv",
+        inputs={"X": []},
+        outputs={"Out": []},
+        attrs={
+            "recv_varnames": recv_varnames,
+            "trainer_id": config.get_role_id(),
+            RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
+        })
     program.global_block().append_op(
         type="fetch_barrier",
......
@@ -164,8 +164,8 @@ def train(args):
     elif fleet.is_worker():
         logger.info("run trainer")
-        fleet.init_worker()
         exe.run(fleet.startup_program)
+        fleet.init_worker()
         thread_num = 2
         filelist = []
......
@@ -161,8 +161,10 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         """
         exe = fluid.Executor(fluid.CPUPlace())
-        fleet.init_worker()
+        exe.run(fluid.default_startup_program())
+        fleet.init_worker()
         batch_size = 4
         train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
         self.reader.decorate_sample_list_generator(train_reader)
@@ -201,8 +203,8 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         exe = fluid.Executor(fluid.CPUPlace())
-        fleet.init_worker()
         exe.run(fluid.default_startup_program())
+        fleet.init_worker()
         thread_num = 2
         batch_size = 128
......
@@ -60,8 +60,9 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
         device_id = int(os.getenv("FLAGS_selected_gpus", "0"))
         place = fluid.CUDAPlace(device_id)
         exe = fluid.Executor(place)
-        fleet.init_worker()
+        exe.run(fleet.startup_program)
+        fleet.init_worker()
         batch_size = 4
         train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
@@ -104,8 +105,8 @@ class TestDistGpuPsCTR2x2(TestDistCTR2x2):
         place = fluid.CUDAPlace(device_id)
         exe = fluid.Executor(place)
-        fleet.init_worker()
         exe.run(fleet.startup_program)
+        fleet.init_worker()
         thread_num = 2
         batch_size = 128
......
@@ -150,8 +150,9 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
         """
         exe = fluid.Executor(fluid.CPUPlace())
-        fleet.init_worker()
+        exe.run(fluid.default_startup_program())
+        fleet.init_worker()
         batch_size = 4
         train_reader = paddle.batch(fake_ctr_reader(), batch_size=batch_size)
         self.reader.decorate_sample_list_generator(train_reader)
@@ -174,8 +175,8 @@ class TestHeterPsCTR2x2(FleetDistHeterRunnerBase):
         exe = fluid.Executor(fluid.CPUPlace())
-        fleet.init_worker()
         exe.run(fluid.default_startup_program())
+        fleet.init_worker()
         thread_num = 1
         batch_size = 128
......
@@ -151,8 +151,9 @@ class TestDistCTR2x2(FleetDistRunnerBase):
         """
         exe = fluid.Executor(fluid.CPUPlace())
-        fleet.init_worker()
+        exe.run(fluid.default_startup_program())
+        fleet.init_worker()
         batch_size = 4
......
@@ -81,8 +81,8 @@ class TestCommunicatorGeoEnd2End(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(avg_cost)
-        fleet.init_worker()
         exe.run(fluid.default_startup_program())
+        fleet.init_worker()
         train_reader = paddle.batch(self.fake_reader(), batch_size=24)
         feeder = fluid.DataFeeder(place=place, feed_list=[x, z, y])
......
@@ -69,8 +69,8 @@ class TestCommunicatorHalfAsyncEnd2End(unittest.TestCase):
         optimizer = fleet.distributed_optimizer(optimizer, strategy)
         optimizer.minimize(avg_cost)
-        fleet.init_worker()
         exe.run(fleet.startup_program)
+        fleet.init_worker()
         train_reader = paddle.batch(self.fake_reader(), batch_size=24)
         feeder = fluid.DataFeeder(place=place, feed_list=[x, y])
......