未验证 提交 8f0b3c05 编写于 作者: T tangwei12 提交者: GitHub

the integrated communicator (#19849)

* add a base class for the Communicator
* add AsyncCommunicator Impl for async distributed training
上级 8d92b36d
......@@ -86,13 +86,10 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
if (send_varname_to_ctx.size() > 0) {
VLOG(3) << "this is distribute mode, will use communicator";
if (operators::distributed::Communicator::GetInstance() == nullptr) {
operators::distributed::Communicator::Init(send_varname_to_ctx,
auto *instance = operators::distributed::Communicator::InitInstance<
operators::distributed::AsyncCommunicator>(send_varname_to_ctx,
recv_varname_to_ctx, scope);
operators::distributed::Communicator::GetInstance()->Start();
} else {
VLOG(3) << "communicator has been initialized, skip";
}
if (!instance->IsRunning()) instance->Start();
}
#endif
}
......
......@@ -52,14 +52,16 @@ inline double GetCurrentUS() {
return 1e+6 * time.tv_sec + time.tv_usec;
}
std::once_flag Communicator::init_flag_;
std::shared_ptr<Communicator> Communicator::communicator_(nullptr);
Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
void AsyncCommunicator::InitImpl(const RpcCtxMap &send_varname_to_ctx,
const RpcCtxMap &recv_varname_to_ctx,
Scope *recv_scope)
: send_varname_to_ctx_(send_varname_to_ctx),
recv_varname_to_ctx_(recv_varname_to_ctx),
recv_scope_(recv_scope) {
Scope *recv_scope) {
send_varname_to_ctx_ = std::move(send_varname_to_ctx);
recv_varname_to_ctx_ = std::move(recv_varname_to_ctx);
recv_scope_ = std::move(recv_scope);
// get all send information from graph, build vars_to_send
VLOG(0) << "communicator_independent_recv_thread: "
<< FLAGS_communicator_independent_recv_thread;
......@@ -98,7 +100,51 @@ Communicator::Communicator(const RpcCtxMap &send_varname_to_ctx,
}
}
Communicator::~Communicator() {
void AsyncCommunicator::InitImpl(const paddle::framework::ProgramDesc &program,
Scope *param_scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto *op : program.Block(0).AllOps()) {
VLOG(3) << "node name " << op->Type();
if (op->Type() == "send") {
auto send_var_name = op->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("send_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto height_section =
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (op->Type() == "recv") {
auto do_not_run = boost::get<int>(op->GetNullableAttr("do_not_run"));
PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!");
auto recv_var_name = op->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("recv_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
}
}
// init communicator here
if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) {
LOG(WARNING) << "no var need to send and recv!!";
}
operators::distributed::AsyncCommunicator::InitImpl(
send_varname_to_ctx, recv_varname_to_ctx, param_scope);
}
AsyncCommunicator::~AsyncCommunicator() {
if (FLAGS_v >= 3) {
std::string msg("~Communicator");
fwrite(msg.c_str(), msg.length(), 1, stdout);
......@@ -112,7 +158,7 @@ Communicator::~Communicator() {
}
}
void Communicator::SendThread() {
void AsyncCommunicator::SendThread() {
VLOG(3) << "SendThread start!";
while (running_) {
std::vector<std::future<void>> task_futures;
......@@ -175,50 +221,12 @@ void Communicator::SendThread() {
VLOG(3) << "run send graph use time "
<< after_run_send_graph - before_run_send_graph;
RecvNonIndependent();
Recv();
}
VLOG(0) << "communicator stopped, send thread exit";
}
void Communicator::RecvNonIndependent() {
if (FLAGS_communicator_independent_recv_thread) {
return;
}
auto grad_num = grad_num_.load();
if (grad_num > 0) {
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
}
void Communicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
if (!running_) return;
auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
auto after_recv = GetCurrentUS();
VLOG(1) << "run recv graph use time " << after_recv - before_send;
}
void Communicator::RecvThread() {
void AsyncCommunicator::RecvThread() {
VLOG(3) << "RecvThread start!";
while (running_) {
auto grad_num = grad_num_.load();
......@@ -233,7 +241,7 @@ void Communicator::RecvThread() {
VLOG(0) << "communicator stopped, recv thread exit";
}
void Communicator::Send(const std::string &var_name,
void AsyncCommunicator::Send(const std::string &var_name,
const framework::Scope &scope) {
VLOG(3) << "communicator send " << var_name;
// push var into send queue by var_name
......@@ -255,56 +263,45 @@ void Communicator::Send(const std::string &var_name,
}
}
void Communicator::Init(const paddle::framework::ProgramDesc &program,
Scope *param_scope) {
using RpcCtxMap = operators::distributed::RpcCtxMap;
VLOG(3) << "ProcessGraph";
RpcCtxMap send_varname_to_ctx;
RpcCtxMap recv_varname_to_ctx;
for (auto *op : program.Block(0).AllOps()) {
VLOG(3) << "node name " << op->Type();
if (op->Type() == "send") {
auto send_var_name = op->Input("X")[0];
auto send_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("send_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto height_section =
boost::get<std::vector<int64_t>>(op->GetNullableAttr("sections"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
send_varname_to_ctx[send_var_name] = operators::distributed::RpcContext(
send_var_name, send_varnames, epmap, height_section, trainer_id);
VLOG(3) << "find and init an send op: "
<< send_varname_to_ctx[send_var_name];
} else if (op->Type() == "recv") {
auto do_not_run = boost::get<int>(op->GetNullableAttr("do_not_run"));
PADDLE_ENFORCE_GT(do_not_run, 0, "recv should not run!");
auto recv_var_name = op->Output("Out")[0];
auto recv_varnames = boost::get<std::vector<std::string>>(
op->GetNullableAttr("recv_varnames"));
auto epmap =
boost::get<std::vector<std::string>>(op->GetNullableAttr("epmap"));
auto trainer_id = boost::get<int>(op->GetNullableAttr("trainer_id"));
recv_varname_to_ctx[recv_var_name] = operators::distributed::RpcContext(
recv_var_name, recv_varnames, epmap, {}, trainer_id);
}
void AsyncCommunicator::Recv() {
if (FLAGS_communicator_independent_recv_thread) {
return;
}
// init communicator here
if (send_varname_to_ctx.size() == 0 && recv_varname_to_ctx.size() == 0) {
LOG(WARNING) << "no var need to send and recv!!";
auto grad_num = grad_num_.load();
if (grad_num > 0) {
RecvAll();
grad_num_.store(0);
} else {
std::this_thread::sleep_for(std::chrono::milliseconds(10));
}
operators::distributed::Communicator::Init(send_varname_to_ctx,
recv_varname_to_ctx, param_scope);
}
Communicator *Communicator::GetInstance() { return communicator_.get(); }
std::shared_ptr<Communicator> Communicator::GetInstantcePtr() {
return communicator_;
void AsyncCommunicator::RecvAll() {
VLOG(3) << "parallel run recv graph";
if (!running_) return;
auto before_send = GetCurrentUS();
std::vector<std::future<void>> task_futures;
task_futures.reserve(recv_varname_to_ctx_.size());
for (auto &iter : recv_varname_to_ctx_) {
auto recv_task = [this, &iter] {
auto &var_name = iter.first;
VLOG(4) << "recv var " << var_name;
auto recv_functor = distributed::ParameterRecv<float>();
if (!FLAGS_communicator_fake_rpc) {
recv_functor(iter.second, *recv_scope_);
}
};
task_futures.emplace_back(recv_threadpool_->enqueue(std::move(recv_task)));
}
for (auto &task : task_futures) {
task.wait();
}
auto after_recv = GetCurrentUS();
VLOG(1) << "run recv graph use time " << after_recv - before_send;
}
void Communicator::Start() {
void AsyncCommunicator::Start() {
VLOG(0) << "Communicator start";
if (!communicator_) {
VLOG(0) << "Communicator is not inited, do nothing";
......@@ -313,15 +310,15 @@ void Communicator::Start() {
running_ = true;
// start send and recv thread
send_thread_.reset(
new std::thread(std::bind(&Communicator::SendThread, this)));
new std::thread(std::bind(&AsyncCommunicator::SendThread, this)));
if (FLAGS_communicator_independent_recv_thread) {
recv_thread_.reset(
new std::thread(std::bind(&Communicator::RecvThread, this)));
new std::thread(std::bind(&AsyncCommunicator::RecvThread, this)));
}
}
}
void Communicator::Stop() {
void AsyncCommunicator::Stop() {
VLOG(0) << "Communicator stop";
running_ = false;
if (!communicator_) {
......
......@@ -161,27 +161,97 @@ using RpcCtxMap = std::unordered_map<std::string, RpcContext>;
class Communicator {
public:
Communicator(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope);
Communicator() {}
virtual ~Communicator() {}
~Communicator();
virtual void Start() = 0;
virtual void Stop() = 0;
virtual bool IsRunning() { return running_; }
void Start();
void Stop();
virtual void Send(const std::string& var_name,
const framework::Scope& scope) = 0;
virtual void Recv() = 0;
bool IsRunning() { return running_; }
virtual void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) = 0;
// send grad
void Send(const std::string& var_name, const framework::Scope& scope);
virtual void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) = 0;
private:
// recv all parameter
static Communicator* GetInstance() { return communicator_.get(); }
static std::shared_ptr<Communicator> GetInstantcePtr() {
return communicator_;
}
template <typename T>
static Communicator* InitInstance(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithRpcCtx<T>,
send_varname_to_ctx, recv_varname_to_ctx, recv_scope);
return communicator_.get();
}
// Init is called by InitInstance.
template <typename T>
static void InitWithRpcCtx(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(send_varname_to_ctx, recv_varname_to_ctx,
recv_scope);
}
}
template <typename T>
static Communicator* InitInstance(
const paddle::framework::ProgramDesc& program, Scope* recv_scope) {
std::call_once(init_flag_, &Communicator::InitWithProgram<T>, program,
recv_scope);
return communicator_.get();
}
template <typename T>
static void InitWithProgram(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) {
if (communicator_.get() == nullptr) {
communicator_.reset(new T());
communicator_->InitImpl(program, recv_scope);
}
}
protected:
bool running_ = false;
static std::shared_ptr<Communicator> communicator_;
static std::once_flag init_flag_;
};
class AsyncCommunicator : public Communicator {
public:
AsyncCommunicator() {}
~AsyncCommunicator();
void Start() override;
void Stop() override;
void Send(const std::string& var_name,
const framework::Scope& scope) override;
void Recv() override;
void RecvAll();
void RecvNonIndependent();
void InitImpl(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx,
Scope* recv_scope) override;
void InitImpl(const paddle::framework::ProgramDesc& program,
Scope* recv_scope) override;
void SendThread();
void RecvThread();
bool running_ = false;
private:
std::unordered_map<std::string,
std::shared_ptr<BlockingQueue<std::shared_ptr<Variable>>>>
send_varname_to_queue_;
......@@ -194,26 +264,6 @@ class Communicator {
std::unique_ptr<::ThreadPool> send_threadpool_{nullptr};
std::unique_ptr<::ThreadPool> recv_threadpool_{nullptr};
std::atomic_uint grad_num_{0}; // the num of gradient sent since last recv
// the following code is for initialize the commnunicator
public:
static void Init(const RpcCtxMap& send_varname_to_ctx,
const RpcCtxMap& recv_varname_to_ctx, Scope* recv_scope) {
if (communicator_ == nullptr) {
communicator_.reset(new Communicator(send_varname_to_ctx,
recv_varname_to_ctx, recv_scope));
}
}
static void Init(const paddle::framework::ProgramDesc& program,
Scope* param_scope);
static Communicator* GetInstance();
static std::shared_ptr<Communicator> GetInstantcePtr();
private:
static std::shared_ptr<Communicator> communicator_;
};
} // namespace distributed
......
......@@ -48,14 +48,7 @@ class SendOp : public framework::OperatorBase {
if (send_varnames.size() > 0) {
PADDLE_ENFORCE_EQ(ins.size(), 1, "");
if (distributed::Communicator::GetInstance() == nullptr) {
auto send_functor = distributed::ParameterSend<float>();
auto rpc_ctx = distributed::RpcContext(ins[0], send_varnames, epmap,
height_sections, trainer_id);
send_functor(rpc_ctx, scope, true);
} else {
distributed::Communicator::GetInstance()->Send(ins[0], scope);
}
} else {
platform::DeviceContextPool& pool =
platform::DeviceContextPool::Instance();
......
......@@ -26,6 +26,7 @@ namespace py = pybind11;
using paddle::framework::ProgramDesc;
using paddle::operators::distributed::Communicator;
using paddle::operators::distributed::AsyncCommunicator;
using paddle::framework::Scope;
namespace paddle {
......@@ -36,7 +37,7 @@ void BindCommunicator(py::module* m) {
py::class_<Communicator, std::shared_ptr<Communicator>>(*m,
"DistCommunicator")
.def(py::init([](const ProgramDesc& program, Scope* param_scope) {
Communicator::Init(program, param_scope);
Communicator::InitInstance<AsyncCommunicator>(program, param_scope);
return Communicator::GetInstantcePtr();
}))
.def("stop", &Communicator::Stop)
......
......@@ -75,11 +75,14 @@ class TestDistRunnerBase(object):
sync_mode,
dc_asgd=False,
current_endpoint=None,
nccl_comm_num=1):
nccl_comm_num=1,
hogwild_mode=False):
# NOTE: import fluid until runtime, or else forking processes will cause error.
config = fluid.DistributeTranspilerConfig()
config.enable_dc_asgd = dc_asgd
config.sync_mode = sync_mode
config.runtime_split_send_recv = hogwild_mode
if nccl_comm_num > 1:
config.nccl_comm_num = nccl_comm_num
# config.runtime_split_send_recv = True
......@@ -89,6 +92,7 @@ class TestDistRunnerBase(object):
program=main_program,
pservers=pserver_endpoints,
trainers=trainers,
sync_mode=sync_mode,
current_endpoint=current_endpoint)
return t
......@@ -96,9 +100,15 @@ class TestDistRunnerBase(object):
self.lr = args.lr
self.get_model(batch_size=args.batch_size)
# NOTE: pserver should not call memory optimize
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(), args.endpoints,
args.trainers, args.sync_mode, args.dc_asgd)
t = self.get_transpiler(
trainer_id=args.trainer_id,
main_program=fluid.default_main_program(),
pserver_endpoints=args.endpoints,
trainers=args.trainers,
sync_mode=args.sync_mode,
dc_asgd=args.dc_asgd,
hogwild_mode=args.hogwild)
pserver_prog = t.get_pserver_program(args.current_endpoint)
startup_prog = t.get_startup_program(args.current_endpoint,
pserver_prog)
......@@ -120,7 +130,7 @@ class TestDistRunnerBase(object):
dist_strategy = DistributedStrategy()
dist_strategy.exec_strategy = exec_strategy
dist_strategy.fuse_memory_size = 1 #MB
dist_strategy.fuse_memory_size = 1 # MB
dist_strategy.fuse_laryer_size = 1
if args.use_local_sgd:
dist_strategy.use_local_sgd = True
......@@ -130,8 +140,8 @@ class TestDistRunnerBase(object):
role = role_maker.PaddleCloudRoleMaker(is_collective=True)
fleet.init(role)
print_to_err("gpu_fleet", "fleet.node_num:")
#"fleet.node_id:", fleet.node_id(),
#"fleet.trainer_num:", fleet.worker_num())
# "fleet.node_id:", fleet.node_id(),
# "fleet.trainer_num:", fleet.worker_num())
test_program, avg_cost, train_reader, test_reader, batch_acc, predict = \
self.get_model(batch_size=args.batch_size, dist_strategy=dist_strategy)
......@@ -196,10 +206,15 @@ class TestDistRunnerBase(object):
print_to_err(
type(self).__name__,
"begin to run transpile on trainer with pserver mode")
t = self.get_transpiler(args.trainer_id,
fluid.default_main_program(),
args.endpoints, args.trainers,
args.sync_mode, args.dc_asgd)
t = self.get_transpiler(
trainer_id=args.trainer_id,
main_program=fluid.default_main_program(),
pserver_endpoints=args.endpoints,
trainers=args.trainers,
sync_mode=args.sync_mode,
dc_asgd=args.dc_asgd,
hogwild_mode=args.hogwild)
trainer_prog = t.get_trainer_program()
print_to_err(
type(self).__name__,
......@@ -251,6 +266,9 @@ class TestDistRunnerBase(object):
build_stra.enable_inplace = False
build_stra.memory_optimize = False
if args.hogwild:
build_stra.async_mode = True
if args.enable_backward_deps:
build_stra.enable_backward_optimizer_op_deps = True
......@@ -411,6 +429,7 @@ def runtime_main(test_class):
parser.add_argument('--use_dgc', action='store_true')
parser.add_argument('--use_reduce', action='store_true')
parser.add_argument('--dc_asgd', action='store_true')
parser.add_argument('--hogwild', action='store_true')
parser.add_argument(
'--use_reader_alloc', action='store_true', required=False)
parser.add_argument('--batch_size', required=False, type=int, default=2)
......@@ -467,6 +486,7 @@ class TestDistBase(unittest.TestCase):
self._find_free_port(), self._find_free_port())
self._python_interp = sys.executable
self._sync_mode = True
self._hogwild_mode = False
self._enforce_place = None
self._use_reduce = False
self._dc_asgd = False # must use with async mode
......@@ -630,6 +650,9 @@ class TestDistBase(unittest.TestCase):
if self._sync_mode:
tr0_cmd += " --sync_mode"
tr1_cmd += " --sync_mode"
if self._hogwild_mode:
tr0_cmd += " --hogwild"
tr1_cmd += " --hogwild"
if self._use_reduce:
tr0_cmd += " --use_reduce"
tr1_cmd += " --use_reduce"
......@@ -825,7 +848,7 @@ class TestDistBase(unittest.TestCase):
required_envs["GLOG_v"] = "10"
required_envs["GLOG_logtostderr"] = "1"
local_losses\
local_losses \
= self._run_local(model_file, required_envs,
check_error_log)
if self._nccl2_mode:
......
......@@ -11,6 +11,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
......@@ -29,14 +30,13 @@ def skip_ci(func):
return __func__
@skip_ci
class TestDistCTR2x2(TestDistBase):
def _setup_config(self):
self._sync_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
self.check_with_place("dist_ctr.py", delta=1e-7, check_error_log=False)
self.check_with_place("dist_ctr.py", delta=1e-2, check_error_log=False)
@skip_ci
......@@ -54,5 +54,40 @@ class TestDistCTRWithL2Decay2x2(TestDistBase):
need_envs=need_envs)
class TestDistCTR2x2_ASYNC(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
}
self.check_with_place(
"dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs)
class TestDistCTR2x2_ASYNC2(TestDistBase):
def _setup_config(self):
self._sync_mode = False
self._hogwild_mode = True
self._enforce_place = "CPU"
def test_dist_ctr(self):
need_envs = {
"FLAGS_communicator_send_queue_size": "2",
"FLAGS_communicator_max_merge_var_num": "2",
"FLAGS_communicator_max_send_grad_num_before_recv": "2",
"FLAGS_communicator_independent_recv_thread": "0"
}
self.check_with_place(
"dist_ctr.py", delta=100, check_error_log=True, need_envs=need_envs)
if __name__ == "__main__":
unittest.main()
......@@ -30,7 +30,6 @@ def skip_ci(func):
return __func__
@skip_ci
class TestDistMnist2x2(TestFleetBase):
def _setup_config(self):
self._sync_mode = False
......
......@@ -33,7 +33,7 @@ class TestDistSimnetBowDense2x2(TestDistBase):
self.check_with_place(
"dist_simnet_bow.py",
delta=1e-5,
check_error_log=False,
check_error_log=True,
need_envs=need_envs)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册