提交 627a6b8b 编写于 作者: T tangwei12

add prefetch in nce

上级 4cb0100c
...@@ -26,6 +26,10 @@ limitations under the License. */ ...@@ -26,6 +26,10 @@ limitations under the License. */
#include "paddle/fluid/operators/math/sampler.h" #include "paddle/fluid/operators/math/sampler.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -166,8 +170,8 @@ class NCEKernel : public framework::OpKernel<T> { ...@@ -166,8 +170,8 @@ class NCEKernel : public framework::OpKernel<T> {
auto height_sections = context.Attr<std::vector<int>>("height_sections"); auto height_sections = context.Attr<std::vector<int>>("height_sections");
auto table_names = context.Attr<std::vector<std::string>>("table_names"); auto table_names = context.Attr<std::vector<std::string>>("table_names");
framework::Variable *ids = local_scope.Var("Ids"); local_scope.Var("Ids");
framework::Variable *weight = local_scope.Var("Weight"); local_scope.Var("Weight");
#ifdef PADDLE_WITH_DISTRIBUTE #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch("Ids", "Weight", table_names, epmap, operators::distributed::prefetch("Ids", "Weight", table_names, epmap,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册