From df45c8c538bddc1d43f933438413d4143c588fce Mon Sep 17 00:00:00 2001 From: Qiao Longfei Date: Sat, 30 Mar 2019 23:00:17 +0800 Subject: [PATCH] update nce and hierarchical_sigmoid remote_prefetch test=develop --- paddle/fluid/framework/details/async_ssa_graph_executor.cc | 5 +++-- paddle/fluid/operators/hierarchical_sigmoid_op.h | 3 ++- paddle/fluid/operators/nce_op.h | 3 ++- 3 files changed, 7 insertions(+), 4 deletions(-) diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 8fe4cdc7099..52641260a6c 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -81,8 +81,9 @@ void ProcessGraph(std::vector graphs, Scope *scope) { nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; - } else if (node->Name() == "lookup_table") { - VLOG(0) << "set lookup_table op remote_prefetch to false"; + } else if (node->Name() == "lookup_table" || node->Name() == "nce" || + node->Name() == "hierarchical_sigmoid") { + VLOG(0) << "set " << node->Name() << " op remote_prefetch to false"; node->Op()->SetAttr("remote_prefetch", false); } } diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index ed978782402..82c8171ca52 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -68,8 +68,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { size_t num_classes = static_cast(ctx.Attr("num_classes")); // for remote prefetch + auto remote_prefetch = ctx.Attr("remote_prefetch"); auto epmap = ctx.Attr>("epmap"); - if (!epmap.empty()) { + if (remote_prefetch && !epmap.empty()) { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 25b6ed851bc..12f3118ec77 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -156,9 +156,10 @@ class NCEKernel : public framework::OpKernel { auto input_mat = EigenMatrix::From(*(context.Input("Input"))); // for remote prefetch + auto remote_prefetch = context.Attr("remote_prefetch"); auto epmap = context.Attr>("epmap"); - if (!epmap.empty()) { + if (remote_prefetch && !epmap.empty()) { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server -- GitLab