提交 47280ef8 编写于 作者: Q Qiao Longfei

lookup table op supports prefetch

上级 4ad5fd8f
...@@ -37,7 +37,7 @@ if (WITH_GPU) ...@@ -37,7 +37,7 @@ if (WITH_GPU)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub) SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
endif() endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS}) register_operators(EXCLUDES warpctc_op conv_fusion_op lookup_table_op DEPS ${OP_HEADER_DEPS})
# warpctc_op needs cudnn 7 above # warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32) if (WITH_GPU AND NOT WIN32)
...@@ -55,6 +55,8 @@ else() ...@@ -55,6 +55,8 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale) op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif() endif()
op_library(lookup_table_op DEPS parameter_prefetch)
set(COMMON_OP_DEPS ${OP_HEADER_DEPS}) set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
......
...@@ -23,8 +23,12 @@ limitations under the License. */ ...@@ -23,8 +23,12 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h" #include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h" #include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -43,44 +47,64 @@ class LookupTableKernel : public framework::OpKernel<T> { ...@@ -43,44 +47,64 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W"); auto *table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx"); auto id_name = context.Inputs("Ids").front();
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>()); auto out_name = context.Outputs("Out").front();
int64_t ids_numel = ids_t->numel(); auto table_name = context.Inputs("W").front();
auto epmap = context.Attr<std::vector<std::string>>("epmap");
if (table_var->IsType<LoDTensor>()) { auto height_sections =
auto *table_t = context.Input<LoDTensor>("W"); context.Attr<std::vector<int64_t>>("height_sections");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1]; if (!epmap.empty()) {
// if epmap is not empty, then the parameter will be fetched from remote parameter
auto *table = table_t->data<T>(); // server
auto *output = output_t->mutable_data<T>(context.GetPlace()); #ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_name, epmap,
for (int64_t i = 0; i < ids_numel; ++i) { height_sections, context);
if (padding_idx != kNoPadding && ids[i] == padding_idx) { #else
memset(output + i * row_width, 0, row_width * sizeof(T)); PADDLE_THROW(
} else { "paddle is not compiled with distribute support, can not do "
PADDLE_ENFORCE_LT(ids[i], row_number); "parameter prefetch!");
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i); #endif
memcpy(output + i * row_width, table + ids[i] * row_width, } else {
row_width * sizeof(T)); int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
} }
} } else if (table_var->IsType<SelectedRows>()) {
} else if (table_var->IsType<SelectedRows>()) { const auto &table_t = table_var->Get<SelectedRows>();
const auto &table_t = table_var->Get<SelectedRows>(); int64_t row_width = table_t.value().dims()[1];
int64_t row_width = table_t.value().dims()[1]; const auto *table = table_t.value().data<T>();
const auto *table = table_t.value().data<T>(); auto *output = output_t->mutable_data<T>(context.GetPlace());
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); for (int64_t i = 0; i < ids_numel; ++i) {
for (int64_t i = 0; i < ids_numel; ++i) { if (padding_idx != kNoPadding && ids[i] == padding_idx) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) { memset(output + i * row_width, 0, row_width * sizeof(T));
memset(output + i * row_width, 0, row_width * sizeof(T)); } else {
} else { PADDLE_ENFORCE_GE(ids[i], 0);
PADDLE_ENFORCE_GE(ids[i], 0); auto id_index = table_t.Index(ids[i]);
auto id_index = table_t.Index(ids[i]); PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists."); blas.VCOPY(row_width, table + id_index * row_width,
blas.VCOPY(row_width, table + id_index * row_width, output + i * row_width);
output + i * row_width); }
} }
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册