提交 a1821a04 编写于 作者: Q Qiao Longfei

remote remote_prefetch in embedding layer test=develop

上级 61912e87
......@@ -81,6 +81,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
} else if (node->Name() == "lookup_table") {
VLOG(0) << "set lookup_table op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
}
}
}
......
......@@ -50,11 +50,12 @@ class LookupTableKernel : public framework::OpKernel<T> {
// for remote prefetch
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
if (!epmap.empty()) {
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server
......
......@@ -341,8 +341,7 @@ def embedding(input,
is_distributed=False,
padding_idx=None,
param_attr=None,
dtype='float32',
remote_prefetch=False):
dtype='float32'):
"""
**Embedding Layer**
......@@ -381,7 +380,7 @@ def embedding(input,
"""
helper = LayerHelper('embedding', **locals())
remote_prefetch = is_sparse and (not is_distributed) and remote_prefetch
remote_prefetch = is_sparse and (not is_distributed)
if remote_prefetch:
assert is_sparse is True and is_distributed is False
w = helper.create_parameter(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册