diff --git a/paddle/fluid/framework/details/async_ssa_graph_executor.cc b/paddle/fluid/framework/details/async_ssa_graph_executor.cc index 8fe4cdc7099bbf39fdfdb6dd969a42d4f2bb525a..52641260a6c14337873e732cc1d4b3f397c07cdf 100644 --- a/paddle/fluid/framework/details/async_ssa_graph_executor.cc +++ b/paddle/fluid/framework/details/async_ssa_graph_executor.cc @@ -81,8 +81,9 @@ void ProcessGraph(std::vector graphs, Scope *scope) { nodes_to_delete.push_back(node); VLOG(3) << "find and remove an recv op: " << recv_varname_to_ctx[recv_var_name]; - } else if (node->Name() == "lookup_table") { - VLOG(0) << "set lookup_table op remote_prefetch to false"; + } else if (node->Name() == "lookup_table" || node->Name() == "nce" || + node->Name() == "hierarchical_sigmoid") { + VLOG(0) << "set " << node->Name() << " op remote_prefetch to false"; node->Op()->SetAttr("remote_prefetch", false); } } diff --git a/paddle/fluid/operators/hierarchical_sigmoid_op.h b/paddle/fluid/operators/hierarchical_sigmoid_op.h index ed978782402829207c5bbbe325ddcbae679e825b..82c8171ca52ffb128df103f27bafbdba1e72e52f 100644 --- a/paddle/fluid/operators/hierarchical_sigmoid_op.h +++ b/paddle/fluid/operators/hierarchical_sigmoid_op.h @@ -68,8 +68,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel { size_t num_classes = static_cast(ctx.Attr("num_classes")); // for remote prefetch + auto remote_prefetch = ctx.Attr("remote_prefetch"); auto epmap = ctx.Attr>("epmap"); - if (!epmap.empty()) { + if (remote_prefetch && !epmap.empty()) { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server diff --git a/paddle/fluid/operators/nce_op.h b/paddle/fluid/operators/nce_op.h index 25b6ed851bc5149dcf6d25edc80544c99dd95d34..12f3118ec775dfce13d1f7ff836d82e1d999c65b 100644 --- a/paddle/fluid/operators/nce_op.h +++ b/paddle/fluid/operators/nce_op.h @@ -156,9 +156,10 @@ class NCEKernel : public framework::OpKernel { auto input_mat = EigenMatrix::From(*(context.Input("Input"))); // for remote prefetch + auto remote_prefetch = context.Attr("remote_prefetch"); auto epmap = context.Attr>("epmap"); - if (!epmap.empty()) { + if (remote_prefetch && !epmap.empty()) { // if epmap is not empty, then the parameter will be fetched from remote // parameter // server