From a1821a04493152facc8ff63a2bcd6b339028d7a5 Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sat, 30 Mar 2019 22:52:19 +0800 Subject: [PATCH] remote remote_prefetch in embedding layer test=develop --- paddle/fluid/framework/details/async_ssa_graph_executor.cc | 3 +++ paddle/fluid/operators/lookup_table_op.h | 3 ++- python/paddle/fluid/layers/nn.py | 5 ++--- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index e9aad5d26..8fe4cdc70 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -81,6 +81,9 @@ void ProcessGraph(std::vector graphs, Scope *scope) { nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; + } else if (node->Name() == "lookup_table") { + VLOG(0) << "set lookup_table op remote_prefetch to false"; + node->Op()->SetAttr("remote_prefetch", false); } } } diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 524565a43..62e298e06 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -50,11 +50,12 @@ class LookupTableKernel : public framework::OpKernel { // for remote prefetch auto epmap = context.Attr>("epmap"); + auto remote_prefetch = context.Attr("remote_prefetch"); auto height_sections = context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); - if (!epmap.empty()) { + if (remote_prefetch && !epmap.empty()) { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9743cfa72..f2413f603 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -341,8 +341,7 @@ def embedding(input, is_distributed=False, padding_idx=None, param_attr=None, - dtype='float32', - remote_prefetch=False): + dtype='float32'): """ **Embedding Layer** @@ -381,7 +380,7 @@ def embedding(input, """ helper = LayerHelper('embedding', **locals()) - remote_prefetch = is_sparse and (not is_distributed) and remote_prefetch + remote_prefetch = is_sparse and (not is_distributed) if remote_prefetch: assert is_sparse is True and is_distributed is False w = helper.create_parameter( -- GitLab