提交 a1821a04 编写于 作者: Q Qiao Longfei

remote remote_prefetch in embedding layer test=develop

上级 61912e87
...@@ -81,6 +81,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) { ...@@ -81,6 +81,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
nodes_to_delete.push_back(node); nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: " VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name]; << recv_varname_to_ctx[recv_var_name];
} else if (node->Name() == "lookup_table") {
VLOG(0) << "set lookup_table op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
} }
} }
} }
......
...@@ -50,11 +50,12 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -50,11 +50,12 @@ class LookupTableKernel : public framework::OpKernel<T> {
// for remote prefetch // for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap"); auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto height_sections = auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections"); context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) { if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote // if epmap is not empty, then the parameter will be fetched from remote
// parameter // parameter
// server // server
......
...@@ -341,8 +341,7 @@ def embedding(input, ...@@ -341,8 +341,7 @@ def embedding(input,
is_distributed=False, is_distributed=False,
padding_idx=None, padding_idx=None,
param_attr=None, param_attr=None,
dtype='float32', dtype='float32'):
remote_prefetch=False):
""" """
**Embedding Layer** **Embedding Layer**
...@@ -381,7 +380,7 @@ def embedding(input, ...@@ -381,7 +380,7 @@ def embedding(input,
""" """
helper = LayerHelper('embedding', **locals()) helper = LayerHelper('embedding', **locals())
remote_prefetch = is_sparse and (not is_distributed) and remote_prefetch remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch: if remote_prefetch:
assert is_sparse is True and is_distributed is False assert is_sparse is True and is_distributed is False
w = helper.create_parameter( w = helper.create_parameter(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册