Commit d98c59fd authored by Qiao Longfei

support non-sliced variables

Parent af2f5fc8
@@ -154,7 +154,7 @@ inline void MergeMultipleVarsIntoOneBySection(
 }
 
 void prefetch(const std::string& id_name, const std::string& out_name,
-              const std::string& table_name,
+              const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context) {
@@ -190,7 +190,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
               << " to get " << out_var_names[i] << " back";
       rets.push_back(rpc_client->AsyncPrefetchVar(
           epmap[i], ctx, local_scope, in_var_names[i], out_var_names[i],
-          table_name + ".block" + std::to_string(i)));
+          table_names[i]));
     } else {
       VLOG(30) << "don't send no-initialied variable: " << out_var_names[i];
     }
...
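Before this change, `prefetch` received a single `table_name` and rebuilt each remote block's name as `table_name + ".block" + i`, which only works when the parameter was sliced into blocks on the parameter servers. The commit passes the concrete `table_names` instead, so a variable that was not sliced (and therefore keeps its original name) can be prefetched too. A minimal sketch of the two naming schemes, with illustrative names only:

```python
# Minimal sketch (illustrative names only, not part of the patch).
def old_remote_names(table_name, num_blocks):
    # Old behaviour: the prefetch path always assumed the table was sliced
    # into ".blockN" pieces, one per endpoint.
    return [table_name + ".block" + str(i) for i in range(num_blocks)]

def new_remote_names(table_names):
    # New behaviour: the transpiler passes the real split-variable names, so a
    # table that was never sliced simply keeps its original name.
    return list(table_names)

print(old_remote_names("emb", 2))   # ['emb.block0', 'emb.block1']
print(new_remote_names(["emb"]))    # ['emb'] -- a non-sliced variable now works
```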
@@ -24,7 +24,7 @@ namespace operators {
 namespace distributed {
 
 void prefetch(const std::string& id_name, const std::string& out_name,
-              const std::string& table_name,
+              const std::vector<std::string>& table_names,
               const std::vector<std::string>& epmap,
               const std::vector<int64_t>& height_sections,
               const framework::ExecutionContext& context);
...
@@ -99,6 +99,12 @@ class LookupTableOpMaker : public framework::OpProtoAndCheckerMaker {
         "(string vector, default 127.0.0.1:6164)"
         "Server endpoints in the order of input variables for mapping")
         .SetDefault({});
+    AddAttr<std::vector<std::string>>(
+        "table_names",
+        "(string vector, the splited table names that will be fetched from "
+        "parameter server)"
+        "in the order of input variables for mapping")
+        .SetDefault({});
     AddComment(R"DOC(
 Lookup Table Operator.
...
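The new `table_names` attribute sits next to the existing `epmap` and `height_sections` attributes on the lookup_table op. A hedged sketch of what the attribute set on a transpiled op might look like (all values below are made up for illustration):

```python
# Illustrative attribute values for a distributed lookup_table op after
# transpilation; endpoints, sections and names are invented for the example.
lookup_table_attrs = {
    "remote_prefetch": True,
    "epmap": ["127.0.0.1:6164", "127.0.0.1:6165"],  # one endpoint per split table
    "height_sections": [512, 512],                  # rows held by each split
    "table_names": ["emb.block0", "emb.block1"],    # new attribute added by this commit
}
```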
@@ -54,13 +54,14 @@ class LookupTableKernel : public framework::OpKernel<T> {
       auto remote_prefetch = context.Attr<bool>("remote_prefetch");
       auto height_sections =
           context.Attr<std::vector<int64_t>>("height_sections");
+      auto table_names = context.Attr<std::vector<std::string>>("table_names");
 
       if (remote_prefetch) {
         // if emap is not empty, then the parameter will be fetched from remote
         // parameter
         // server
 #ifdef PADDLE_WITH_DISTRIBUTE
-        operators::distributed::prefetch(id_name, out_name, table_name, epmap,
+        operators::distributed::prefetch(id_name, out_name, table_names, epmap,
                                          height_sections, context);
 #else
         PADDLE_THROW(
...
@@ -247,10 +247,11 @@ class DistributeTranspiler(object):
         return sparse_update_ops
 
     def _update_remote_sparse_update_op(self, param_varname, height_sections,
-                                        endpint_map):
+                                        endpint_map, table_names):
         for op in self.sparse_update_ops:
             if param_varname in op.input_arg_names:
                 op._set_attr('epmap', endpint_map)
+                op._set_attr('table_names', table_names)
                 op._set_attr('height_sections', height_sections)
                 op._set_attr('trainer_id', self.trainer_id)
@@ -326,6 +327,7 @@ class DistributeTranspiler(object):
         # get all sparse update ops
         self.sparse_update_ops = self._get_all_remote_sparse_update_op(
             self.origin_program)
+        # use_sparse_update_param_name -> split_height_section
         self.sparse_param_to_height_sections = dict()
 
         # add distributed attrs to program
@@ -365,6 +367,13 @@ class DistributeTranspiler(object):
                 splited_grad_varname = splited_vars[0].name
                 index = find_op_by_output_arg(
                     program.global_block(), splited_grad_varname, reverse=True)
+                if splited_vars[0].type == core.VarDesc.VarType.SELECTED_ROWS:
+                    sparse_param_name = self.grad_name_to_param_name[
+                        splited_grad_varname]
+                    if self._is_input_of_remote_sparse_update_op(
+                            sparse_param_name):
+                        self.sparse_param_to_height_sections[
+                            sparse_param_name] = [splited_vars[0].shape[0]]
             elif len(splited_vars) > 1:
                 orig_var = program.global_block().vars[splited_grad_varname]
                 index = find_op_by_output_arg(
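This hunk also records a height section for a gradient that was not split (a single SELECTED_ROWS variable), so a non-sliced parameter still gets an entry in `sparse_param_to_height_sections`. A small sketch of the resulting mapping, with hypothetical parameter names and shapes:

```python
# Hypothetical contents of sparse_param_to_height_sections after transpilation.
# A sliced embedding contributes one section per split; a non-sliced one (the
# case added by this commit) contributes a single section equal to its height.
sparse_param_to_height_sections = {
    "emb_sliced": [512, 512],    # split across two parameter servers
    "emb_not_sliced": [1024],    # kept whole on one parameter server
}
```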
@@ -435,9 +444,11 @@ class DistributeTranspiler(object):
         all_recv_outputs = []
         for param_varname, splited_var in six.iteritems(self.param_var_mapping):
             eps = []
+            table_names = []
             for var in splited_var:
                 index = [v.name for v in recv_vars].index(var.name)
                 eps.append(eplist[index])
+                table_names.append(var.name)
             if self.sync_mode:
                 recv_dep_in = send_barrier_out
             else:
@@ -457,8 +468,8 @@ class DistributeTranspiler(object):
             if param_varname in self.sparse_param_to_height_sections:
                 height_sections = self.sparse_param_to_height_sections[
                     param_varname]
-                self._update_remote_sparse_update_op(param_varname,
-                                                     height_sections, eps)
+                self._update_remote_sparse_update_op(
+                    param_varname, height_sections, eps, table_names)
             else:
                 all_recv_outputs.extend(splited_var)
                 program.global_block().append_op(
...
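Putting the transpiler side together: for every parameter, the names of its recv variables are now collected alongside the endpoints and handed to `_update_remote_sparse_update_op`. A hedged sketch of that bookkeeping with made-up variable names and endpoints:

```python
# Hedged sketch of the new bookkeeping (names and endpoints are illustrative).
splited_var_names = ["emb.block0", "emb.block1"]  # a non-sliced parameter would be ["emb"]
eplist = ["127.0.0.1:6164", "127.0.0.1:6165"]

eps, table_names = [], []
for name, ep in zip(splited_var_names, eplist):
    eps.append(ep)
    table_names.append(name)  # the split variable's own name is the remote table name

# Both lists are then stored on the lookup_table op by _update_remote_sparse_update_op:
#   op._set_attr('epmap', eps)
#   op._set_attr('table_names', table_names)
```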