未验证 提交 c0061ff5 编写于 作者: 1 123malin 提交者: GitHub

【paddle.fleet】geo send sparse optimize (#27719) (#27979)

* test=develop, fix geo sgd communicator and gloo http_init for ps
上级 51dd268c
...@@ -466,41 +466,34 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names, ...@@ -466,41 +466,34 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables, const std::vector<std::string> &var_tables,
const framework::Scope &scope) { const framework::Scope &scope) {
waiting_ = false; waiting_ = false;
PADDLE_ENFORCE_EQ(
var_tables.size(), 1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (table_name == STEP_COUNTER) return;
auto before_send = GetCurrentUS(); auto before_send = GetCurrentUS();
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table; size_t splited_var_nums =
send_varname_to_ctx_[table_name].splited_varnames.size();
for (size_t i = 0; i < var_tables.size(); i++) { std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
auto table_name = var_tables[i];
if (table_name == STEP_COUNTER) {
continue;
} else {
size_t splited_var_nums =
send_varname_to_ctx_[table_name].splited_varnames.size();
for (size_t j = 0; j < splited_var_nums; j++) {
if (ids_table.find(
send_varname_to_ctx_[table_name].splited_varnames[j]) ==
ids_table.end()) {
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
send_varname_to_ctx_[table_name].splited_varnames[j],
std::unordered_set<int64_t>()));
}
}
auto *var = scope.FindVar(var_names[i]); for (size_t j = 0; j < splited_var_nums; j++) {
auto var_tensor = var->Get<framework::LoDTensor>(); ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
int element_number = var_tensor.numel(); send_varname_to_ctx_[table_name].splited_varnames[j],
const int64_t *var_mutable_data = var_tensor.data<int64_t>(); std::unordered_set<int64_t>()));
}
auto *var = scope.FindVar(var_names[0]);
auto &rows = var->Get<framework::SelectedRows>().rows();
// insert ids which has not been record // insert ids which has not been record
for (int j = 0; j < element_number; j++) { for (size_t j = 0; j < rows.size(); j++) {
auto ep_idx = var_mutable_data[j] % splited_var_nums; auto ep_idx = rows[j] % splited_var_nums;
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx]) ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
.insert(var_mutable_data[j]); .insert(rows[j]);
}
}
} }
auto before_push = GetCurrentUS(); auto before_push = GetCurrentUS();
for (auto &iter : ids_table) { for (auto &iter : ids_table) {
auto &key = iter.first; auto &key = iter.first;
...@@ -512,8 +505,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names, ...@@ -512,8 +505,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
<< "'s queue"; << "'s queue";
} }
auto after_send = GetCurrentUS(); auto after_send = GetCurrentUS();
VLOG(3) << "run send_op finish. using " << (before_push - before_send) << "; " VLOG(3) << "run send " << table_name << " op finish. using "
<< (after_send - before_push); << (before_push - before_send) << "; " << (after_send - before_push);
} }
void GeoCommunicator::MainThread() { void GeoCommunicator::MainThread() {
......
...@@ -826,7 +826,7 @@ class PaddleCloudRoleMaker(RoleMakerBase): ...@@ -826,7 +826,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
start_http_server = True start_http_server = True
else: else:
ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "") ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
if self._server_index() == 0: if self._is_server() and self._server_index() == 0:
start_http_server = True start_http_server = True
ip, port = ep_rank_0.split(':') ip, port = ep_rank_0.split(':')
kwargs = { kwargs = {
......
...@@ -895,7 +895,7 @@ class ParameterServerLauncher(object): ...@@ -895,7 +895,7 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINERS_NUM": str(self.worker_num), "PADDLE_TRAINERS_NUM": str(self.worker_num),
"POD_IP": cur_server.endpoint.split(":")[0], "POD_IP": cur_server.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": "1", "PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2", "PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port "PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
} }
...@@ -959,7 +959,7 @@ class ParameterServerLauncher(object): ...@@ -959,7 +959,7 @@ class ParameterServerLauncher(object):
"TRAINING_ROLE": "TRAINER", "TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(cur_worker.rank), "PADDLE_TRAINER_ID": str(cur_worker.rank),
"PADDLE_WITH_GLOO": "1", "PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2", "PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0", "FLAGS_selected_gpus": "0",
"FLAGS_selected_xpus": "0", "FLAGS_selected_xpus": "0",
...@@ -1028,7 +1028,7 @@ class ParameterServerLauncher(object): ...@@ -1028,7 +1028,7 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINERS_NUM": str(self.worker_num), "PADDLE_TRAINERS_NUM": str(self.worker_num),
"POD_IP": cur_heter_worker.endpoint.split(":")[0], "POD_IP": cur_heter_worker.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": "1", "PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2", "PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir, "PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0", "FLAGS_selected_gpus": "0",
"FLAGS_selected_xpus": "0", "FLAGS_selected_xpus": "0",
......
...@@ -169,7 +169,7 @@ def append_send_ops_pass(program, config): ...@@ -169,7 +169,7 @@ def append_send_ops_pass(program, config):
trainer_id = config.get_role_id() trainer_id = config.get_role_id()
pserver_endpoints = config.get_ps_endpoints() pserver_endpoints = config.get_ps_endpoints()
def _append_grad_send_op(union_vars, queue): def _append_send_op(union_vars, queue):
if queue == STEP_COUNTER: if queue == STEP_COUNTER:
send_input_vars = [] send_input_vars = []
...@@ -198,43 +198,6 @@ def append_send_ops_pass(program, config): ...@@ -198,43 +198,6 @@ def append_send_ops_pass(program, config):
return dummy_output return dummy_output
def _append_sparse_ids_send_op():
sparse_var = []
sparse_tables = []
unique_sparse_var = {}
for op in program.global_block().ops:
if "is_sparse" in op.all_attrs():
if op.type == "lookup_table":
op._set_attr('remote_prefetch', False)
for input_var_name, sparse_var_name in zip(
op.input("Ids"), op.input("W")):
if input_var_name in unique_sparse_var:
if unique_sparse_var[input_var_name] == sparse_var_name:
continue
input_var = program.global_block().var(input_var_name)
sparse_var.append(input_var)
sparse_tables.append(sparse_var_name)
unique_sparse_var[input_var_name] = sparse_var_name
dummy_output = []
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
program.global_block().append_op(
type="send",
inputs={"X": sparse_var},
outputs={"Out": dummy_output},
attrs={
"send_varnames": sparse_tables,
"merge_add": True,
"use_send_handler": False,
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
return dummy_output
def _append_barrier_op(dummys): def _append_barrier_op(dummys):
program.global_block().append_op( program.global_block().append_op(
type="send_barrier", type="send_barrier",
...@@ -251,12 +214,8 @@ def append_send_ops_pass(program, config): ...@@ -251,12 +214,8 @@ def append_send_ops_pass(program, config):
sends = config.get_trainer_send_context() sends = config.get_trainer_send_context()
if mode == DistributedMode.GEO: for merged_name, send in sends.items():
dummys.append(_append_sparse_ids_send_op()) dummys.append(_append_send_op(send.origin_varnames(), merged_name))
else:
for merged_name, send in sends.items():
dummys.append(
_append_grad_send_op(send.origin_varnames(), merged_name))
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]: if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
_append_barrier_op(dummys) _append_barrier_op(dummys)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册