提交 df45c8c5 编写于 作者: Q Qiao Longfei

update nce and hierarchical_sigmoid remote_prefetch

test=develop
上级 a1821a04
......@@ -81,8 +81,9 @@ void ProcessGraph(std::vector<ir::Graph *> graphs, Scope *scope) {
nodes_to_delete.push_back(node);
VLOG(3) << "find and remove an recv op: "
<< recv_varname_to_ctx[recv_var_name];
} else if (node->Name() == "lookup_table") {
VLOG(0) << "set lookup_table op remote_prefetch to false";
} else if (node->Name() == "lookup_table" || node->Name() == "nce" ||
node->Name() == "hierarchical_sigmoid") {
VLOG(0) << "set " << node->Name() << " op remote_prefetch to false";
node->Op()->SetAttr("remote_prefetch", false);
}
}
......
......@@ -68,8 +68,9 @@ class HierarchicalSigmoidOpKernel : public framework::OpKernel<T> {
size_t num_classes = static_cast<size_t>(ctx.Attr<int>("num_classes"));
// for remote prefetch
auto remote_prefetch = ctx.Attr<bool>("remote_prefetch");
auto epmap = ctx.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) {
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server
......
......@@ -156,9 +156,10 @@ class NCEKernel : public framework::OpKernel<T> {
auto input_mat = EigenMatrix<T>::From(*(context.Input<Tensor>("Input")));
// for remote prefetch
auto remote_prefetch = context.Attr<bool>("remote_prefetch");
auto epmap = context.Attr<std::vector<std::string>>("epmap");
if (!epmap.empty()) {
if (remote_prefetch && !epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote
// parameter
// server
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册