Commit d98c59fd authored by Qiao Longfei

support non-sliced variables

Parent af2f5fc8
@@ -154,7 +154,7 @@ inline void MergeMultipleVarsIntoOneBySection(
 }
 
 void prefetch(const std::string& id_name, const std::string& out_name,
-              const std::string& table_name,
+              const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context) {
@@ -190,7 +190,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               << " to get " << out_var_names[i] << " back";
       rets.push_back(rpc_client->AsyncPrefetchVar(
           epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i],
-          table_name + ".block" + std::to_string(i)));
+          table_names[i]));
     } else {
       VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
     }
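The behavioral core of the commit is in the second hunk above: the RPC no longer derives each remote block name as table_name + ".block" + i, it receives an explicit name per endpoint. A minimal plain-Python sketch of the naming difference (illustrative names, not Paddle API):

def derived_names(table_name, num_pieces):
    # Old scheme: names were always "<param>.block<i>", which assumes the
    # parameter was sliced into blocks on the parameter servers.
    return [table_name + ".block" + str(i) for i in range(num_pieces)]

def explicit_names(split_var_names):
    # New scheme: the transpiler passes the actual per-endpoint variable
    # names, so a non-sliced parameter can keep its original name.
    return list(split_var_names)

print(derived_names("emb", 2))    # ['emb.block0', 'emb.block1']
print(explicit_names(["emb"]))    # ['emb'] -- the non-sliced case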
@@ -24,7 +24,7 @@ namespace operators {
 namespace distributed {
 
 void prefetch(const std::string& id_name, const std::string& out_name,
-              const std::string& table_name,
+              const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context);
@@ -99,6 +99,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
                   "(string vector, default 127.0.0.1:6164)"
                   "Server endpoints in the order of input variables for mapping")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "table_names",
+        "(string vector, the splited table names that will be fetched from "
+        "parameter server)"
+        "in the order of input variables for mapping")
+        .SetDefault({});
     AddComment(R"DOC(
 Lookup Table Operator.
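Both epmap and the new table_names attribute default to empty lists, so a plain local lookup is unaffected; only the transpiler fills them in, and it fills them together. A sketch of that invariant with a hypothetical helper (not Paddle code):

def wants_remote_prefetch(epmap, table_names):
    # Both attributes default to []; the transpiler sets them together,
    # so either both are populated or neither is.
    assert len(epmap) == len(table_names)
    return bool(epmap)

print(wants_remote_prefetch([], []))                       # False: local lookup
print(wants_remote_prefetch(["127.0.0.1:6164"], ["emb"]))  # True: remote fetch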
@@ -54,13 +54,14 @@ class LookupTableKernel : public framework::OpKernel<T> {
       auto remote_prefetch = context.Attr<bool>("remote_prefetch");
       auto height_sections =
           context.Attr<std::vector<int64_t>>("height_sections");
+      auto table_names = context.Attr<std::vector<std::string>>("table_names");
 
       if (remote_prefetch) {
         // if emap is not empty, then the parameter will be fetched from remote
         // parameter
         // server
 #ifdef PADDLE_WITH_DISTRIBUTE
-        operators::distributed::prefetch(id_name, out_name, table_name, epmap,
+        operators::distributed::prefetch(id_name, out_name, table_names, epmap,
                                          height_sections, context);
 #else
         PADDLE_THROW(
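Tying the C++ pieces together: the kernel now reads table_names from the op's attributes and forwards it to prefetch alongside epmap, so entry i of each list describes one RPC. A hedged plain-Python sketch of that pairing (attribute values are invented for the example):

attrs = {
    "remote_prefetch": True,
    "epmap": ["127.0.0.1:6164", "127.0.0.1:6165"],
    "table_names": ["emb.block0", "emb.block1"],
    "height_sections": [500, 500],
}

if attrs["remote_prefetch"]:
    # Mirrors prefetch(): request table_names[i] from endpoint epmap[i].
    for ep, tname in zip(attrs["epmap"], attrs["table_names"]):
        print("prefetch", tname, "from", ep)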
@@ -247,10 +247,11 @@ class DistributeTranspiler(object):
         return sparse_update_ops
 
     def _update_remote_sparse_update_op(self, param_varname, height_sections,
-                                        endpint_map):
+                                        endpint_map, table_names):
         for op in self.sparse_update_ops:
             if param_varname in op.input_arg_names:
                 op._set_attr('epmap', endpint_map)
+                op._set_attr('table_names', table_names)
                 op._set_attr('height_sections', height_sections)
                 op._set_attr('trainer_id', self.trainer_id)
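For a concrete picture of what this pass now stamps onto a remote lookup_table op, an illustrative attribute set for the case the commit title is about, a parameter that was not sliced (all values invented for the example):

op_attrs = {
    "epmap": ["127.0.0.1:6164"],
    "table_names": ["emb"],       # original name, no ".block0" suffix
    "height_sections": [1000],    # full height of the un-sliced table
    "trainer_id": 0,
}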
@@ -326,6 +327,7 @@ class DistributeTranspiler(object):
         # get all sparse update ops
         self.sparse_update_ops = self._get_all_remote_sparse_update_op(
             self.origin_program)
+        # use_sparse_update_param_name -> split_height_section
         self.sparse_param_to_height_sections = dict()
         # add distributed attrs to program
@@ -365,6 +367,13 @@ class DistributeTranspiler(object):
                 splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(
                     program.global_block(), splited_grad_varname, reverse=True)
+                if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
+                    sparse_param_name = self.grad_name_to_param_name[
+                        splited_grad_varname]
+                    if self._is_input_of_remote_sparse_update_op(
+                            sparse_param_name):
+                        self.sparse_param_to_height_sections[
+                            sparse_param_name] = [splited_vars[0].shape[0]]
             elif len(splited_vars) > 1:
                 orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(
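The new branch records a single height section, the variable's full row count, when a SELECTED_ROWS gradient was not split; the sliced path keeps one entry per piece. A small sketch of the rule for both cases (shapes are made up):

def height_sections(split_shapes):
    # One entry per piece; a non-sliced variable contributes exactly one
    # section equal to its full height.
    return [shape[0] for shape in split_shapes]

print(height_sections([(1000, 64)]))            # non-sliced -> [1000]
print(height_sections([(500, 64), (500, 64)]))  # sliced     -> [500, 500]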
@@ -435,9 +444,11 @@ class DistributeTranspiler(object):
         all_recv_outputs = []
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
+            table_names = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
+                table_names.append(var.name)
             if self.sync_mode:
                 recv_dep_in = send_barrier_out
             else:
@@ -457,8 +468,8 @@ class DistributeTranspiler(object):
             if param_varname in self.sparse_param_to_height_sections:
                 height_sections = self.sparse_param_to_height_sections[
                     param_varname]
-                self._update_remote_sparse_update_op(param_varname,
-                                                     height_sections, eps)
+                self._update_remote_sparse_update_op(
+                    param_varname, height_sections, eps, table_names)
             else:
                 all_recv_outputs.extend(splited_var)
                 program.global_block().append_op(
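Finally, the recv-side bookkeeping: endpoints and table names are appended in lockstep over the split pieces, which is what keeps epmap[i] and table_names[i] aligned by the time the kernel issues its RPCs. A self-contained sketch with invented inputs:

recv_var_names = ["emb.block0", "emb.block1"]  # pieces of one parameter
eplist = ["127.0.0.1:6164", "127.0.0.1:6165"]  # endpoint serving each piece

eps, table_names = [], []
for var_name in recv_var_names:
    index = recv_var_names.index(var_name)  # mirrors the transpiler lookup
    eps.append(eplist[index])
    table_names.append(var_name)

# epmap[i] and table_names[i] now describe the same remote piece.
assert list(zip(eps, table_names)) == [
    ("127.0.0.1:6164", "emb.block0"),
    ("127.0.0.1:6165", "emb.block1"),
]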