未验证 提交 8f0b3c05 编写于 作者: T tangwei12 提交者: GitHub

the integrated communicator (#19849)

* add a base class for the Communicator
* add AsyncCommunicator Impl for async distributed training
上级 8d92b36d
...@@ -86,13 +86,10 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -86,13 +86,10 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
if (send_varname_to_ctx.size() > 0) { if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator"; VLOG(3) << "this is distribute mode, will use communicator";
if (operators::distributed::Communicator::GetInstance() == nullptr) { auto *instance = operators::distributed::Communicator::InitInstance<
operators::distributed::Communicator::Init(send_varname_to_ctx, operators::distributed::AsyncCommunicator>(send_varname_to_ctx,
recv_varname_to_ctx, scope); recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start(); if (!instance->IsRunning()) instance->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
} }
#endif #endif
} }
......
...@@ -52,14 +52,16 @@ inline double GetCurrentUS() { ...@@ -52,14 +52,16 @@ inline double GetCurrentUS() {
return 1e+6 * time.tv_sec + time.tv_usec; return 1e+6 * time.tv_sec + time.tv_usec;
} }
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr); std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx, const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope) Scope *recv_scope) {
: send_varname_to_ctx_(send_varname_to_ctx), send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_(recv_varname_to_ctx), recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_(recv_scope) { recv_scope_ = std::move(recv_scope);
// get all send information from graph, build vars_to_send // get all send information from graph, build vars_to_send
VLOG(0) << "communicator_independent_recv_thread: " VLOG(0) << "communicator_independent_recv_thread: "
<< FLAGS_communicator_independent_recv_thread; << FLAGS_communicator_independent_recv_thread;
...@@ -98,7 +100,51 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx, ...@@ -98,7 +100,51 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
} }
} }
Communicator::~Communicator() { void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
Scope *param_scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto *op : program.Block(0).AllOps()) {
VLOG(3) << "node name " << op->Type();
if (op->Type() == "send") {
auto send_var_name = op->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("send_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto height_section =
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (op->Type() == "recv") {
auto do_not_run = boost::get<int>(op->GetNullableAttr("do_not_run"));
PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!");
auto recv_var_name = op->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("recv_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
}
}
// init communicator here
if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) {
LOG(WARNING) << "no var need to send and recv!!";
}
operators::distributed::AsyncCommunicator::InitImpl(
send_varname_to_ctx, recv_varname_to_ctx, param_scope);
}
AsyncCommunicator::~AsyncCommunicator() {
if (FLAGS_v >= 3) { if (FLAGS_v >= 3) {
std::string msg("~Communicator"); std::string msg("~Communicator");
fwrite(msg.c_str(), msg.length(), 1, stdout); fwrite(msg.c_str(), msg.length(), 1, stdout);
...@@ -112,7 +158,7 @@ Communicator::~Communicator() { ...@@ -112,7 +158,7 @@ Communicator::~Communicator() {
} }
} }
void Communicator::SendThread() { void AsyncCommunicator::SendThread() {
VLOG(3) << "SendThread start!"; VLOG(3) << "SendThread start!";
while (running_) { while (running_) {
std::vector<std::future<void>> task_futures; std::vector<std::future<void>> task_futures;
...@@ -175,50 +221,12 @@ void Communicator::SendThread() { ...@@ -175,50 +221,12 @@ void Communicator::SendThread() {
VLOG(3) << "run send graph use time " VLOG(3) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph; << after_run_send_graph - before_run_send_graph;
RecvNonIndependent(); Recv();
} }
VLOG(0) << "communicator stopped, send thread exit"; VLOG(0) << "communicator stopped, send thread exit";
} }
void Communicator::RecvNonIndependent() { void AsyncCommunicator::RecvThread() {
if (FLAGS_communicator_independent_recv_thread) {
return;
}
auto grad_num = grad_num_.load();
if (grad_num > 0) {
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
if (!running_) return;
auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
auto after_recv = GetCurrentUS();
VLOG(1) << "run recv graph use time " << after_recv - before_send;
}
void Communicator::RecvThread() {
VLOG(3) << "RecvThread start!"; VLOG(3) << "RecvThread start!";
while (running_) { while (running_) {
auto grad_num = grad_num_.load(); auto grad_num = grad_num_.load();
...@@ -233,7 +241,7 @@ void Communicator::RecvThread() { ...@@ -233,7 +241,7 @@ void Communicator::RecvThread() {
VLOG(0) << "communicator stopped, recv thread exit"; VLOG(0) << "communicator stopped, recv thread exit";
} }
void Communicator::Send(const std::string &var_name, void AsyncCommunicator::Send(const std::string &var_name,
const framework::Scope &scope) { const framework::Scope &scope) {
VLOG(3) << "communicator send " << var_name; VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name // push var into send queue by var_name
...@@ -255,56 +263,45 @@ void Communicator::Send(const std::string &var_name, ...@@ -255,56 +263,45 @@ void Communicator::Send(const std::string &var_name,
} }
} }
void Communicator::Init(const paddle::framework::ProgramDesc &program, void AsyncCommunicator::Recv() {
Scope *param_scope) { if (FLAGS_communicator_independent_recv_thread) {
using RpcCtxMap = operators::distributed::RpcCtxMap; return;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto *op : program.Block(0).AllOps()) {
VLOG(3) << "node name " << op->Type();
if (op->Type() == "send") {
auto send_var_name = op->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("send_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto height_section =
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (op->Type() == "recv") {
auto do_not_run = boost::get<int>(op->GetNullableAttr("do_not_run"));
PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!");
auto recv_var_name = op->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("recv_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
}
} }
// init communicator here auto grad_num = grad_num_.load();
if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) { if (grad_num > 0) {
LOG(WARNING) << "no var need to send and recv!!"; RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
} }
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, param_scope);
} }
Communicator *Communicator::GetInstance() { return communicator_.get(); } void AsyncCommunicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
std::shared_ptr<Communicator> Communicator::GetInstantcePtr() { if (!running_) return;
return communicator_; auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
auto after_recv = GetCurrentUS();
VLOG(1) << "run recv graph use time " << after_recv - before_send;
} }
void Communicator::Start() { void AsyncCommunicator::Start() {
VLOG(0) << "Communicator start"; VLOG(0) << "Communicator start";
if (!communicator_) { if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing"; VLOG(0) << "Communicator is not inited, do nothing";
...@@ -313,15 +310,15 @@ void Communicator::Start() { ...@@ -313,15 +310,15 @@ void Communicator::Start() {
running_ = true; running_ = true;
// start send and recv thread // start send and recv thread
send_thread_.reset( send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this))); new std::thread(std::bind(&AsyncCommunicator::SendThread, this)));
if (FLAGS_communicator_independent_recv_thread) { if (FLAGS_communicator_independent_recv_thread) {
recv_thread_.reset( recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this))); new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
} }
} }
} }
void Communicator::Stop() { void AsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop"; VLOG(0) << "Communicator stop";
running_ = false; running_ = false;
if (!communicator_) { if (!communicator_) {
......
...@@ -161,27 +161,97 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>; ...@@ -161,27 +161,97 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator { class Communicator {
public: public:
Communicator(const RpcCtxMap& send_varname_to_ctx, Communicator() {}
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope); virtual ~Communicator() {}
~Communicator(); virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
void Start(); virtual void Send(const std::string& var_name,
void Stop(); const framework::Scope& scope) = 0;
virtual void Recv() = 0;
bool IsRunning() { return running_; } virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) = 0;
// send grad virtual void InitImpl(const paddle::framework::ProgramDesc& program,
void Send(const std::string& var_name, const framework::Scope& scope); Scope* recv_scope) = 0;
private: static Communicator* GetInstance() { return communicator_.get(); }
// recv all parameter
static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_;
}
template <typename T>
static Communicator* InitInstance(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>,
send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(send_varname_to_ctx, recv_varname_to_ctx,
recv_scope);
}
}
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope);
return communicator_.get();
}
template <typename T>
static void InitWithProgram(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(program, recv_scope);
}
}
protected:
bool running_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() {}
~AsyncCommunicator();
void Start() override;
void Stop() override;
void Send(const std::string& var_name,
const framework::Scope& scope) override;
void Recv() override;
void RecvAll(); void RecvAll();
void RecvNonIndependent();
void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) override;
void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) override;
void SendThread(); void SendThread();
void RecvThread(); void RecvThread();
bool running_ = false; private:
std::unordered_map<std::string, std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>> std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_; send_varname_to_queue_;
...@@ -194,26 +264,6 @@ class Communicator { ...@@ -194,26 +264,6 @@ class Communicator {
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr}; std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
// the following code is for initialize the commnunicator
public:
static void Init(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
if (communicator_ == nullptr) {
communicator_.reset(new Communicator(send_varname_to_ctx,
recv_varname_to_ctx, recv_scope));
}
}
static void Init(const paddle::framework::ProgramDesc& program,
Scope* param_scope);
static Communicator* GetInstance();
static std::shared_ptr<Communicator> GetInstantcePtr();
private:
static std::shared_ptr<Communicator> communicator_;
}; };
} // namespace distributed } // namespace distributed
......
...@@ -48,14 +48,7 @@ class SendOp : public framework::OperatorBase { ...@@ -48,14 +48,7 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) { if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, ""); PADDLE_ENFORCE_EQ(ins.size(), 1, "");
if (distributed::Communicator::GetInstance() == nullptr) {
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections, trainer_id);
send_functor(rpc_ctx, scope, true);
} else {
distributed::Communicator::GetInstance()->Send(ins[0], scope); distributed::Communicator::GetInstance()->Send(ins[0], scope);
}
} else { } else {
platform::DeviceContextPool& pool = platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance(); platform::DeviceContextPool::Instance();
......
...@@ -26,6 +26,7 @@ namespace py = pybind11; ...@@ -26,6 +26,7 @@ namespace py = pybind11;
using paddle::framework::ProgramDesc; using paddle::framework::ProgramDesc;
using paddle::operators::distributed::Communicator; using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::AsyncCommunicator;
using paddle::framework::Scope; using paddle::framework::Scope;
namespace paddle { namespace paddle {
...@@ -36,7 +37,7 @@ void BindCommunicator(py::module* m) { ...@@ -36,7 +37,7 @@ void BindCommunicator(py::module* m) {
py::class_<Communicator, std::shared_ptr<Communicator>>(*m, py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
"DistCommunicator") "DistCommunicator")
.def(py::init([](const ProgramDesc& program, Scope* param_scope) { .def(py::init([](const ProgramDesc& program, Scope* param_scope) {
Communicator::Init(program, param_scope); Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
return Communicator::GetInstantcePtr(); return Communicator::GetInstantcePtr();
})) }))
.def("stop", &Communicator::Stop) .def("stop", &Communicator::Stop)
......
...@@ -75,11 +75,14 @@ class TestDistRunnerBase(object): ...@@ -75,11 +75,14 @@ class TestDistRunnerBase(object):
sync_mode, sync_mode,
dc_asgd=False, dc_asgd=False,
current_endpoint=None, current_endpoint=None,
nccl_comm_num=1): nccl_comm_num=1,
hogwild_mode=False):
# NOTE: import fluid until runtime, or else forking processes will cause error. # NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig() config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd config.enable_dc_asgd = dc_asgd
config.sync_mode = sync_mode config.sync_mode = sync_mode
config.runtime_split_send_recv = hogwild_mode
if nccl_comm_num > 1: if nccl_comm_num > 1:
config.nccl_comm_num = nccl_comm_num config.nccl_comm_num = nccl_comm_num
# config.runtime_split_send_recv = True # config.runtime_split_send_recv = True
...@@ -89,6 +92,7 @@ class TestDistRunnerBase(object): ...@@ -89,6 +92,7 @@ class TestDistRunnerBase(object):
program=main_program, program=main_program,
pservers=pserver_endpoints, pservers=pserver_endpoints,
trainers=trainers, trainers=trainers,
sync_mode=sync_mode,
current_endpoint=current_endpoint) current_endpoint=current_endpoint)
return t return t
...@@ -96,9 +100,15 @@ class TestDistRunnerBase(object): ...@@ -96,9 +100,15 @@ class TestDistRunnerBase(object):
self.lr = args.lr self.lr = args.lr
self.get_model(batch_size=args.batch_size) self.get_model(batch_size=args.batch_size)
# NOTE: pserver should not call memory optimize # NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints, t = self.get_transpiler(
args.trainers, args.sync_mode, args.dc_asgd) trainer_id=args.trainer_id,
main_program=fluid.default_main_program(),
pserver_endpoints=args.endpoints,
trainers=args.trainers,
sync_mode=args.sync_mode,
dc_asgd=args.dc_asgd,
hogwild_mode=args.hogwild)
pserver_prog = t.get_pserver_program(args.current_endpoint) pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint, startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog) pserver_prog)
...@@ -120,7 +130,7 @@ class TestDistRunnerBase(object): ...@@ -120,7 +130,7 @@ class TestDistRunnerBase(object):
dist_strategy = DistributedStrategy() dist_strategy = DistributedStrategy()
dist_strategy.exec_strategy = exec_strategy dist_strategy.exec_strategy = exec_strategy
dist_strategy.fuse_memory_size = 1 #MB dist_strategy.fuse_memory_size = 1 # MB
dist_strategy.fuse_laryer_size = 1 dist_strategy.fuse_laryer_size = 1
if args.use_local_sgd: if args.use_local_sgd:
dist_strategy.use_local_sgd = True dist_strategy.use_local_sgd = True
...@@ -130,8 +140,8 @@ class TestDistRunnerBase(object): ...@@ -130,8 +140,8 @@ class TestDistRunnerBase(object):
role = role_maker.PaddleCloudRoleMaker(is_collective=True) role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role) fleet.init(role)
print_to_err("gpu_fleet", "fleet.node_num:") print_to_err("gpu_fleet", "fleet.node_num:")
#"fleet.node_id:", fleet.node_id(), # "fleet.node_id:", fleet.node_id(),
#"fleet.trainer_num:", fleet.worker_num()) # "fleet.trainer_num:", fleet.worker_num())
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \ test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy) self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy)
...@@ -196,10 +206,15 @@ class TestDistRunnerBase(object): ...@@ -196,10 +206,15 @@ class TestDistRunnerBase(object):
print_to_err( print_to_err(
type(self).__name__, type(self).__name__,
"begin to run transpile on trainer with pserver mode") "begin to run transpile on trainer with pserver mode")
t = self.get_transpiler(args.trainer_id, t = self.get_transpiler(
fluid.default_main_program(), trainer_id=args.trainer_id,
args.endpoints, args.trainers, main_program=fluid.default_main_program(),
args.sync_mode, args.dc_asgd) pserver_endpoints=args.endpoints,
trainers=args.trainers,
sync_mode=args.sync_mode,
dc_asgd=args.dc_asgd,
hogwild_mode=args.hogwild)
trainer_prog = t.get_trainer_program() trainer_prog = t.get_trainer_program()
print_to_err( print_to_err(
type(self).__name__, type(self).__name__,
...@@ -251,6 +266,9 @@ class TestDistRunnerBase(object): ...@@ -251,6 +266,9 @@ class TestDistRunnerBase(object):
build_stra.enable_inplace = False build_stra.enable_inplace = False
build_stra.memory_optimize = False build_stra.memory_optimize = False
if args.hogwild:
build_stra.async_mode = True
if args.enable_backward_deps: if args.enable_backward_deps:
build_stra.enable_backward_optimizer_op_deps = True build_stra.enable_backward_optimizer_op_deps = True
...@@ -411,6 +429,7 @@ def runtime_main(test_class): ...@@ -411,6 +429,7 @@ def runtime_main(test_class):
parser.add_argument('--use_dgc', action='store_true') parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--use_reduce', action='store_true') parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true') parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument('--hogwild', action='store_true')
parser.add_argument( parser.add_argument(
'--use_reader_alloc', action='store_true', required=False) '--use_reader_alloc', action='store_true', required=False)
parser.add_argument('--batch_size', required=False, type=int, default=2) parser.add_argument('--batch_size', required=False, type=int, default=2)
...@@ -467,6 +486,7 @@ class TestDistBase(unittest.TestCase): ...@@ -467,6 +486,7 @@ class TestDistBase(unittest.TestCase):
self._find_free_port(), self._find_free_port()) self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable self._python_interp = sys.executable
self._sync_mode = True self._sync_mode = True
self._hogwild_mode = False
self._enforce_place = None self._enforce_place = None
self._use_reduce = False self._use_reduce = False
self._dc_asgd = False # must use with async mode self._dc_asgd = False # must use with async mode
...@@ -630,6 +650,9 @@ class TestDistBase(unittest.TestCase): ...@@ -630,6 +650,9 @@ class TestDistBase(unittest.TestCase):
if self._sync_mode: if self._sync_mode:
tr0_cmd += " --sync_mode" tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode" tr1_cmd += " --sync_mode"
if self._hogwild_mode:
tr0_cmd += " --hogwild"
tr1_cmd += " --hogwild"
if self._use_reduce: if self._use_reduce:
tr0_cmd += " --use_reduce" tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce" tr1_cmd += " --use_reduce"
...@@ -825,7 +848,7 @@ class TestDistBase(unittest.TestCase): ...@@ -825,7 +848,7 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_v"] = "10" required_envs["GLOG_v"] = "10"
required_envs["GLOG_logtostderr"] = "1" required_envs["GLOG_logtostderr"] = "1"
local_losses\ local_losses \
= self._run_local(model_file, required_envs, = self._run_local(model_file, required_envs,
check_error_log) check_error_log)
if self._nccl2_mode: if self._nccl2_mode:
......
...@@ -11,6 +11,7 @@ ...@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from __future__ import print_function from __future__ import print_function
import os import os
...@@ -29,14 +30,13 @@ def skip_ci(func): ...@@ -29,14 +30,13 @@ def skip_ci(func):
return __func__ return __func__
@skip_ci
class TestDistCTR2x2(TestDistBase): class TestDistCTR2x2(TestDistBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = True self._sync_mode = True
self._enforce_place = "CPU" self._enforce_place = "CPU"
def test_dist_ctr(self): def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False) self.check_with_place("dist_ctr.py", delta=1e-2, check_error_log=False)
@skip_ci @skip_ci
...@@ -54,5 +54,40 @@ class TestDistCTRWithL2Decay2x2(TestDistBase): ...@@ -54,5 +54,40 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
need_envs=need_envs) need_envs=need_envs)
class TestDistCTR2x2_ASYNC(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
}
self.check_with_place(
"dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs)
class TestDistCTR2x2_ASYNC2(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
"FLAGS_communicator_independent_recv_thread": "0"
}
self.check_with_place(
"dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs)
if __name__ == "__main__": if __name__ == "__main__":
unittest.main() unittest.main()
...@@ -30,7 +30,6 @@ def skip_ci(func): ...@@ -30,7 +30,6 @@ def skip_ci(func):
return __func__ return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase): class TestDistMnist2x2(TestFleetBase):
def _setup_config(self): def _setup_config(self):
self._sync_mode = False self._sync_mode = False
......
...@@ -33,7 +33,7 @@ class TestDistSimnetBowDense2x2(TestDistBase): ...@@ -33,7 +33,7 @@ class TestDistSimnetBowDense2x2(TestDistBase):
self.check_with_place( self.check_with_place(
"dist_simnet_bow.py", "dist_simnet_bow.py",
delta=1e-5, delta=1e-5,
check_error_log=False, check_error_log=True,
need_envs=need_envs) need_envs=need_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册