提交 47280ef8 编写于 作者: Q Qiao Longfei

lookup table op support prefetch

上级 4ad5fd8f
......@@ -37,7 +37,7 @@ if (WITH_GPU)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS})
register_operators(EXCLUDES warpctc_op conv_fusion_op lookup_table_op DEPS ${OP_HEADER_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32)
......@@ -55,6 +55,8 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
op_library(lookup_table_op DEPS parameter_prefetch)
set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
......
......@@ -23,8 +23,12 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
......@@ -43,6 +47,25 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
auto table_name = context.Inputs("W").front();
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
if (!epmap.empty()) {
// if emap is not empty, then the paramter will be fetched from remote parameter
// server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_name, epmap,
height_sections, context);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
} else {
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
......@@ -85,6 +108,7 @@ class LookupTableKernel : public framework::OpKernel<T> {
}
}
}
}
};
template <typename T>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册