diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt
index de4f23515d8591f28b80ad00322365f8cdce768b..a824fec1e4d2a02c5d7a2d5d375cc0d2e6992b80 100644
--- a/paddle/fluid/operators/CMakeLists.txt
+++ b/paddle/fluid/operators/CMakeLists.txt
@@ -37,7 +37,7 @@ if (WITH_GPU)
     SET(OP_HEADER_DEPS ${OP_HEADER_DEPS} cub)
 endif()
 
-register_operators(EXCLUDES warpctc_op conv_fusion_op DEPS ${OP_HEADER_DEPS})
+register_operators(EXCLUDES warpctc_op conv_fusion_op lookup_table_op DEPS ${OP_HEADER_DEPS})
 
 # warpctc_op needs cudnn 7 above
 if (WITH_GPU AND NOT WIN32)
@@ -55,6 +55,8 @@ else()
     op_library(warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale)
 endif()
 
+op_library(lookup_table_op DEPS parameter_prefetch)
+
 set(COMMON_OP_DEPS ${OP_HEADER_DEPS})
 
 set(COMMON_OP_DEPS ${COMMON_OP_DEPS} selected_rows_functor selected_rows lod_tensor maxouting unpooling pooling lod_rank_table context_project sequence_pooling executor)
diff --git a/paddle/fluid/operators/lookup_table_op.h b/paddle/fluid/operators/lookup_table_op.h
index 69cae78b7059a7c34fd07bb34883ee34a2010297..335e4adafa89ccfd3fc2d5fd9adf57524a4cf4e4 100644
--- a/paddle/fluid/operators/lookup_table_op.h
+++ b/paddle/fluid/operators/lookup_table_op.h
@@ -23,8 +23,12 @@ limitations under the License. */
 #include "paddle/fluid/framework/lod_tensor.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/selected_rows.h"
 #include "paddle/fluid/operators/math/blas.h"
 
+#ifdef PADDLE_WITH_DISTRIBUTE
+#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
+#endif
+
 namespace paddle {
 namespace operators {
 
@@ -43,44 +47,64 @@ class LookupTableKernel : public framework::OpKernel<T> {
     auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
     auto *table_var = context.InputVar("W");
 
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
-    int64_t ids_numel = ids_t->numel();
-
-    if (table_var->IsType<LoDTensor>()) {
-      auto *table_t = context.Input<LoDTensor>("W");
-      int64_t row_number = table_t->dims()[0];
-      int64_t row_width = table_t->dims()[1];
-
-      auto *table = table_t->data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-      for (int64_t i = 0; i < ids_numel; ++i) {
-        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
-          memset(output + i * row_width, 0, row_width * sizeof(T));
-        } else {
-          PADDLE_ENFORCE_LT(ids[i], row_number);
-          PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
-          memcpy(output + i * row_width, table + ids[i] * row_width,
-                 row_width * sizeof(T));
+    auto id_name = context.Inputs("Ids").front();
+    auto out_name = context.Outputs("Out").front();
+    auto table_name = context.Inputs("W").front();
+    auto epmap = context.Attr<std::vector<std::string>>("epmap");
+    auto height_sections =
+        context.Attr<std::vector<int64_t>>("height_sections");
+
+    if (!epmap.empty()) {
+// if epmap is not empty, then the parameter will be fetched from remote
+// parameter server
+#ifdef PADDLE_WITH_DISTRIBUTE
+      operators::distributed::prefetch(id_name, out_name, table_name, epmap,
+                                       height_sections, context);
+#else
+      PADDLE_THROW(
+          "paddle is not compiled with distributed support, cannot do "
+          "parameter prefetch!");
+#endif
+    } else {
+      int64_t padding_idx = context.Attr<int64_t>("padding_idx");
+      int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
+      int64_t ids_numel = ids_t->numel();
+
+      if (table_var->IsType<LoDTensor>()) {
+        auto *table_t = context.Input<LoDTensor>("W");
+        int64_t row_number = table_t->dims()[0];
+        int64_t row_width = table_t->dims()[1];
+
+        auto *table = table_t->data<T>();
+        auto *output = output_t->mutable_data<T>(context.GetPlace());
+
+        for (int64_t i = 0; i < ids_numel; ++i) {
+          if (padding_idx != kNoPadding && ids[i] == padding_idx) {
+            memset(output + i * row_width, 0, row_width * sizeof(T));
+          } else {
+            PADDLE_ENFORCE_LT(ids[i], row_number);
+            PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
+            memcpy(output + i * row_width, table + ids[i] * row_width,
+                   row_width * sizeof(T));
+          }
         }
-      }
-    } else if (table_var->IsType<SelectedRows>()) {
-      const auto &table_t = table_var->Get<SelectedRows>();
-      int64_t row_width = table_t.value().dims()[1];
-      const auto *table = table_t.value().data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-      for (int64_t i = 0; i < ids_numel; ++i) {
-        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
-          memset(output + i * row_width, 0, row_width * sizeof(T));
-        } else {
-          PADDLE_ENFORCE_GE(ids[i], 0);
-          auto id_index = table_t.Index(ids[i]);
-          PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
-          blas.VCOPY(row_width, table + id_index * row_width,
-                     output + i * row_width);
+      } else if (table_var->IsType<SelectedRows>()) {
+        const auto &table_t = table_var->Get<SelectedRows>();
+        int64_t row_width = table_t.value().dims()[1];
+        const auto *table = table_t.value().data<T>();
+        auto *output = output_t->mutable_data<T>(context.GetPlace());
+
+        auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+        for (int64_t i = 0; i < ids_numel; ++i) {
+          if (padding_idx != kNoPadding && ids[i] == padding_idx) {
+            memset(output + i * row_width, 0, row_width * sizeof(T));
+          } else {
+            PADDLE_ENFORCE_GE(ids[i], 0);
+            auto id_index = table_t.Index(ids[i]);
+            PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
+            blas.VCOPY(row_width, table + id_index * row_width,
+                       output + i * row_width);
+          }
         }
       }
     }
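For reference, the control flow this patch introduces: when the `epmap` attribute is non-empty, the kernel skips the local table entirely and asks the remote parameter servers for the rows via `operators::distributed::prefetch()`; otherwise it falls back to the original local gather (a `memcpy` per id for a dense `LoDTensor` table, a `blas.VCOPY` per id for a `SelectedRows` table, and a zeroed row for `padding_idx`). The standalone C++ sketch below mirrors that dispatch so it can be compiled outside the Paddle tree; `fake_remote_prefetch`, `lookup`, and the other helper names are hypothetical stand-ins for illustration, not Paddle APIs.

// lookup_sketch.cc -- minimal, self-contained restatement of the kernel's
// new control flow. Build with: g++ -std=c++11 lookup_sketch.cc
#include <cassert>
#include <cstdint>
#include <cstring>
#include <iostream>
#include <string>
#include <vector>

constexpr int64_t kNoPadding = -1;  // same sentinel the operator uses

// Hypothetical stand-in for operators::distributed::prefetch(): a real
// implementation would shard `ids` by height_sections and issue an RPC to
// each endpoint listed in `epmap` for its slice of the table.
void fake_remote_prefetch(const std::vector<std::string>& epmap,
                          const std::vector<int64_t>& ids,
                          std::vector<float>* out) {
  (void)out;  // a real prefetch would fill this with the RPC results
  std::cout << "would prefetch " << ids.size() << " rows from "
            << epmap.size() << " parameter server(s)\n";
}

void lookup(const std::vector<std::string>& epmap,  // empty => local path
            const std::vector<float>& table,        // row_number x row_width
            int64_t row_number, int64_t row_width,
            const std::vector<int64_t>& ids, int64_t padding_idx,
            std::vector<float>* out) {
  out->assign(ids.size() * row_width, 0.f);
  if (!epmap.empty()) {  // mirrors the PADDLE_WITH_DISTRIBUTE branch
    fake_remote_prefetch(epmap, ids, out);
    return;
  }
  for (size_t i = 0; i < ids.size(); ++i) {
    float* dst = out->data() + i * row_width;
    if (padding_idx != kNoPadding && ids[i] == padding_idx) {
      std::memset(dst, 0, row_width * sizeof(float));  // padding row => zeros
    } else {
      // mirrors the PADDLE_ENFORCE_GE / PADDLE_ENFORCE_LT bounds checks
      assert(ids[i] >= 0 && ids[i] < row_number);
      std::memcpy(dst, table.data() + ids[i] * row_width,
                  row_width * sizeof(float));
    }
  }
}

int main() {
  // 4-row, 2-wide table whose row r holds {r, r}; id 1 acts as padding.
  std::vector<float> table = {0, 0, 1, 1, 2, 2, 3, 3};
  std::vector<float> out;
  lookup({}, table, 4, 2, {2, 1, 3}, /*padding_idx=*/1, &out);
  for (float v : out) std::cout << v << ' ';  // prints: 2 2 0 0 3 3
  std::cout << '\n';
}

The CMakeLists.txt half of the patch exists for the same reason: register_operators() would otherwise build lookup_table_op without the prefetch implementation, so the op is excluded there and registered through an explicit op_library(lookup_table_op DEPS parameter_prefetch) call that wires the new prefetch path to its dependency.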