提交 47280ef8 编写于 作者: Q Qiao Longfei

lookup table op support prefetch

上级 4ad5fd8f
......@@ -37,7 +37,7 @@ if (WITH_GPU)
SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
endif()
register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS})
register_operators(EXCLUDES warpctc_op conv_fusion_op lookup_table_op DEPS ${OP_HEADER_DEPS})
# warpctc_op needs cudnn 7 above
if (WITH_GPU AND NOT WIN32)
......@@ -55,6 +55,8 @@ else()
op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
endif()
op_library(lookup_table_op DEPS parameter_prefetch)
set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
......
......@@ -23,8 +23,12 @@ limitations under the License. */
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/math/blas.h"
#ifdef PADDLE_WITH_DISTRIBUTE
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
#endif
namespace paddle {
namespace operators {
......@@ -43,44 +47,64 @@ class LookupTableKernel : public framework::OpKernel<T> {
auto *output_t = context.Output<LoDTensor>("Out"); // float tensor
auto *table_var = context.InputVar("W");
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
auto id_name = context.Inputs("Ids").front();
auto out_name = context.Outputs("Out").front();
auto table_name = context.Inputs("W").front();
auto epmap = context.Attr<std::vector<std::string>>("epmap");
auto height_sections =
context.Attr<std::vector<int64_t>>("height_sections");
if (!epmap.empty()) {
// if emap is not empty, then the paramter will be fetched from remote parameter
// server
#ifdef PADDLE_WITH_DISTRIBUTE
operators::distributed::prefetch(id_name, out_name, table_name, epmap,
height_sections, context);
#else
PADDLE_THROW(
"paddle is not compiled with distribute support, can not do "
"parameter prefetch!");
#endif
} else {
int64_t padding_idx = context.Attr<int64_t>("padding_idx");
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
int64_t ids_numel = ids_t->numel();
if (table_var->IsType<LoDTensor>()) {
auto *table_t = context.Input<LoDTensor>("W");
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
auto *table = table_t->data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_LT(ids[i], row_number);
PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
memcpy(output + i * row_width, table + ids[i] * row_width,
row_width * sizeof(T));
}
}
}
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
} else if (table_var->IsType<SelectedRows>()) {
const auto &table_t = table_var->Get<SelectedRows>();
int64_t row_width = table_t.value().dims()[1];
const auto *table = table_t.value().data<T>();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i < ids_numel; ++i) {
if (padding_idx != kNoPadding && ids[i] == padding_idx) {
memset(output + i * row_width, 0, row_width * sizeof(T));
} else {
PADDLE_ENFORCE_GE(ids[i], 0);
auto id_index = table_t.Index(ids[i]);
PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
blas.VCOPY(row_width, table + id_index * row_width,
output + i * row_width);
}
}
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册