diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index e9aad5d264d1745662848d1ba313b573d0974cb7..8fe4cdc7099bbf39fdfdb6dd969a42d4f2bb525a 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -81,6 +81,9 @@ void ProcessGraph(std::vector graphs, Scope *scope) { nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; + } else if (node->Name() == "lookup_table") { + VLOG(0) << "set lookup_table op remote_prefetch to false"; + node->Op()->SetAttr("remote_prefetch", false); } } } diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h index 524565a439fec16e026f4b6414e78d0aff57ea29..62e298e066948c93a84a131a0dffc0a1d53f2a5b 100644 --- a/paddle/fluid/operators/lookup_table_op.h +++ b/paddle/fluid/operators/lookup_table_op.h @@ -50,11 +50,12 @@ class LookupTableKernel : public framework::OpKernel { // for remote prefetch auto epmap = context.Attr>("epmap"); + auto remote_prefetch = context.Attr("remote_prefetch"); auto height_sections = context.Attr>("height_sections"); auto table_names = context.Attr>("table_names"); - if (!epmap.empty()) { + if (remote_prefetch && !epmap.empty()) { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index 9743cfa7272a90cef3fa0e95b18ae5fc49272060..f2413f603304f8262476ca3ae2b820c89d009c3d 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -341,8 +341,7 @@ def embedding(input, is_distributed=False, padding_idx=None, param_attr=None, - dtype='float32', - remote_prefetch=False): + dtype='float32'): """ **Embedding Layer** @@ -381,7 +380,7 @@ def embedding(input, """ helper = LayerHelper('embedding', **locals()) - remote_prefetch = is_sparse and (not is_distributed) and remote_prefetch + remote_prefetch = is_sparse and (not is_distributed) if remote_prefetch: assert is_sparse is True and is_distributed is False w = helper.create_parameter(