未验证 提交 aa3b4ed7 编写于 作者: 1 123malin 提交者: GitHub

【paddle.fleet】geo send sparse optimize (#27719)

* test=develop, fix geo sgd communicator

* test=develop, gloo_init_method

* test=develop, bug fix for gloo http_init
上级 2ac6c6c3
......@@ -466,41 +466,34 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
const std::vector<std::string> &var_tables,
const framework::Scope &scope) {
waiting_ = false;
PADDLE_ENFORCE_EQ(
var_tables.size(), 1,
platform::errors::InvalidArgument("var_tables.size() == 1 is permitted"));
auto table_name = var_tables[0];
if (table_name == STEP_COUNTER) return;
auto before_send = GetCurrentUS();
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
size_t splited_var_nums =
send_varname_to_ctx_[table_name].splited_varnames.size();
for (size_t i = 0; i < var_tables.size(); i++) {
auto table_name = var_tables[i];
if (table_name == STEP_COUNTER) {
continue;
} else {
size_t splited_var_nums =
send_varname_to_ctx_[table_name].splited_varnames.size();
for (size_t j = 0; j < splited_var_nums; j++) {
if (ids_table.find(
send_varname_to_ctx_[table_name].splited_varnames[j]) ==
ids_table.end()) {
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
send_varname_to_ctx_[table_name].splited_varnames[j],
std::unordered_set<int64_t>()));
}
}
std::unordered_map<std::string, std::unordered_set<int64_t>> ids_table;
auto *var = scope.FindVar(var_names[i]);
auto var_tensor = var->Get<framework::LoDTensor>();
int element_number = var_tensor.numel();
const int64_t *var_mutable_data = var_tensor.data<int64_t>();
for (size_t j = 0; j < splited_var_nums; j++) {
ids_table.insert(std::pair<std::string, std::unordered_set<int64_t>>(
send_varname_to_ctx_[table_name].splited_varnames[j],
std::unordered_set<int64_t>()));
}
auto *var = scope.FindVar(var_names[0]);
auto &rows = var->Get<framework::SelectedRows>().rows();
// insert ids which has not been record
for (int j = 0; j < element_number; j++) {
auto ep_idx = var_mutable_data[j] % splited_var_nums;
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
.insert(var_mutable_data[j]);
}
}
// insert ids which has not been record
for (size_t j = 0; j < rows.size(); j++) {
auto ep_idx = rows[j] % splited_var_nums;
ids_table.at(send_varname_to_ctx_[table_name].splited_varnames[ep_idx])
.insert(rows[j]);
}
auto before_push = GetCurrentUS();
for (auto &iter : ids_table) {
auto &key = iter.first;
......@@ -512,8 +505,8 @@ void GeoCommunicator::Send(const std::vector<std::string> &var_names,
<< "'s queue";
}
auto after_send = GetCurrentUS();
VLOG(3) << "run send_op finish. using " << (before_push - before_send) << "; "
<< (after_send - before_push);
VLOG(3) << "run send " << table_name << " op finish. using "
<< (before_push - before_send) << "; " << (after_send - before_push);
}
void GeoCommunicator::MainThread() {
......
......@@ -826,7 +826,7 @@ class PaddleCloudRoleMaker(RoleMakerBase):
start_http_server = True
else:
ep_rank_0 = os.getenv("PADDLE_GLOO_HTTP_ENDPOINT", "")
if self._server_index() == 0:
if self._is_server() and self._server_index() == 0:
start_http_server = True
ip, port = ep_rank_0.split(':')
kwargs = {
......
......@@ -895,7 +895,7 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"POD_IP": cur_server.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"PADDLE_GLOO_HTTP_ENDPOINT": self.http_port
}
......@@ -959,7 +959,7 @@ class ParameterServerLauncher(object):
"TRAINING_ROLE": "TRAINER",
"PADDLE_TRAINER_ID": str(cur_worker.rank),
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0",
"FLAGS_selected_xpus": "0",
......@@ -1028,7 +1028,7 @@ class ParameterServerLauncher(object):
"PADDLE_TRAINERS_NUM": str(self.worker_num),
"POD_IP": cur_heter_worker.endpoint.split(":")[0],
"PADDLE_WITH_GLOO": "1",
"PADDLE_GLOO_RENDEZVOUS": "2",
"PADDLE_GLOO_RENDEZVOUS": "3",
"PADDLE_GLOO_FS_PATH": self.gloo_rendezvous_dir,
"FLAGS_selected_gpus": "0",
"FLAGS_selected_xpus": "0",
......
......@@ -169,7 +169,7 @@ def append_send_ops_pass(program, config):
trainer_id = config.get_role_id()
pserver_endpoints = config.get_ps_endpoints()
def _append_grad_send_op(union_vars, queue):
def _append_send_op(union_vars, queue):
if queue == STEP_COUNTER:
send_input_vars = []
......@@ -198,43 +198,6 @@ def append_send_ops_pass(program, config):
return dummy_output
def _append_sparse_ids_send_op():
sparse_var = []
sparse_tables = []
unique_sparse_var = {}
for op in program.global_block().ops:
if "is_sparse" in op.all_attrs():
if op.type == "lookup_table":
op._set_attr('remote_prefetch', False)
for input_var_name, sparse_var_name in zip(
op.input("Ids"), op.input("W")):
if input_var_name in unique_sparse_var:
if unique_sparse_var[input_var_name] == sparse_var_name:
continue
input_var = program.global_block().var(input_var_name)
sparse_var.append(input_var)
sparse_tables.append(sparse_var_name)
unique_sparse_var[input_var_name] = sparse_var_name
dummy_output = []
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
dummy_output = program.global_block().create_var(
name=framework.generate_control_dev_var_name())
program.global_block().append_op(
type="send",
inputs={"X": sparse_var},
outputs={"Out": dummy_output},
attrs={
"send_varnames": sparse_tables,
"merge_add": True,
"use_send_handler": False,
"endpoints": pserver_endpoints,
RPC_OP_ROLE_ATTR_NAME: RPC_OP_ROLE_ATTR_VALUE
})
return dummy_output
def _append_barrier_op(dummys):
program.global_block().append_op(
type="send_barrier",
......@@ -251,12 +214,8 @@ def append_send_ops_pass(program, config):
sends = config.get_trainer_send_context()
if mode == DistributedMode.GEO:
dummys.append(_append_sparse_ids_send_op())
else:
for merged_name, send in sends.items():
dummys.append(
_append_grad_send_op(send.origin_varnames(), merged_name))
for merged_name, send in sends.items():
dummys.append(_append_send_op(send.origin_varnames(), merged_name))
if mode in [DistributedMode.SYNC, DistributedMode.HALF_ASYNC]:
_append_barrier_op(dummys)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册