提交 627a6b8b 编写于 作者: T tangwei12

add prefetch in nce

上级 4cb0100c
......@@ -26,6 +26,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
......@@ -166,8 +170,8 @@ class NCEKernel : public framework::OpKernel<T> {
auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names");
framework::Variable *ids = local_scope.Var("Ids");
framework::Variable *weight = local_scope.Var("Weight");
local_scope.Var("Ids");
local_scope.Var("Weight");
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch("Ids", "Weight", table_names, epmap,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册