diff --git a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h index 758432fd9e4197302e0bd8f76a1ca7c524026a70..744e83541d3d5bc02ee5570460ac960d40408e7b 100644 --- a/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h +++ b/paddle/fluid/operators/fused/fused_embedding_seq_pool_op.h @@ -31,38 +31,54 @@ using LoDTensor = framework::LoDTensor; using SelectedRows = framework::SelectedRows; using DDim = framework::DDim; +template +void emb_seqpool(const framework::ExecutionContext &context, const T *table, + const int64_t *idx, T *out, int64_t table_height, + int64_t table_width, int64_t idx_height, int64_t idx_width, + int64_t out_width) { // pool type == sum + PADDLE_ENFORCE_EQ(table_width * idx_width, out_width); + + auto check_idx_value_valid = [&](int i) { + PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i); + PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i); + }; + auto blas = math::GetBlas(context); + + for (int w = 0; w != idx_width; ++w) { + check_idx_value_valid(w); + blas.VCOPY(table_width, table + idx[w] * table_width, + out + w * table_width); + } + + for (int h = 1; h < idx_height; ++h) { + for (int w = 0; w < idx_width; ++w) { + int i = h * idx_width + w; + check_idx_value_valid(i); + blas.AXPY(table_width, static_cast(1), table + idx[i] * table_width, + out + w * table_width); + } + } +} + template struct EmbeddingVSumFunctor { void operator()(const framework::ExecutionContext &context, const LoDTensor *table_t, const LoDTensor *ids_t, LoDTensor *output_t) { auto *table = table_t->data(); - int64_t row_number = table_t->dims()[0]; - int64_t row_width = table_t->dims()[1]; - int64_t last_dim = output_t->dims()[1]; + int64_t table_height = table_t->dims()[0]; + int64_t table_width = table_t->dims()[1]; + int64_t out_width = output_t->dims()[1]; const int64_t *ids = ids_t->data(); auto ids_lod = ids_t->lod()[0]; - int64_t ids_count = ids_t->numel() / ids_lod.back(); - + int64_t idx_width = ids_t->numel() / ids_lod.back(); auto *output = output_t->mutable_data(context.GetPlace()); - auto blas = math::GetBlas(context); + PADDLE_ENFORCE_LE(table_width * idx_width, out_width); for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { - size_t begin = ids_lod[i] * ids_count; - for (int64_t j = 0; j != ids_count; ++j) { - PADDLE_ENFORCE_LT(ids[begin], row_number); - PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i); - blas.VCOPY(row_width, table + ids[begin + j] * row_width, - output + i * last_dim + j * row_width); - } - - for (int64_t r = (ids_lod[i] + 1) * ids_count; - r < ids_lod[i + 1] * ids_count; ++r) { - PADDLE_ENFORCE_LT(ids[r], row_number); - PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i); - blas.AXPY(row_width, 1., table + ids[r] * row_width, - output + i * last_dim + (r % ids_count) * row_width); - } + emb_seqpool(context, table, ids + ids_lod[i] * idx_width, + output + i * out_width, table_height, table_width, + ids_lod[i + 1] - ids_lod[i], idx_width, out_width); } } };