提交 18bff529 编写于 作者: T tensor-tang

extract fused_emb_seq_pool forward function

test=develop
上级 c1e18b13
......@@ -31,38 +31,54 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
// Sum-pools embedding rows for a single sequence.
//
// `idx` is read as a row-major idx_height x idx_width matrix of row ids into
// `table` (rows of table_width values). For each column w, the table rows
// selected by idx[h * idx_width + w] over all h are accumulated into
// out[w * table_width, (w + 1) * table_width).
//
// Enforces table_width * idx_width == out_width and that every id that is
// read lies in [0, table_height).
//
// NOTE(review): the first pass reads idx[0, idx_width) unconditionally, so
// idx_height == 0 is not handled here -- confirm callers never pass an empty
// sequence.
template <typename T>
void emb_seqpool(const framework::ExecutionContext &context, const T *table,
                 const int64_t *idx, T *out, int64_t table_height,
                 int64_t table_width, int64_t idx_height, int64_t idx_width,
                 int64_t out_width) {  // pool type == sum
  PADDLE_ENFORCE_EQ(table_width * idx_width, out_width);

  // All extents are int64_t, so positions into idx are kept as int64_t too;
  // a plain int would narrow h * idx_width + w for large inputs.
  auto check_idx_value_valid = [&](int64_t i) {
    PADDLE_ENFORCE_LT(idx[i], table_height, "idx value: %d, i: %d", idx[i], i);
    PADDLE_ENFORCE_GE(idx[i], 0, "idx value: %d, i: %d", idx[i], i);
  };

  auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);

  // First row of ids: copy the selected table rows, which initializes out.
  for (int64_t w = 0; w != idx_width; ++w) {
    check_idx_value_valid(w);
    blas.VCOPY(table_width, table + idx[w] * table_width,
               out + w * table_width);
  }

  // Remaining rows: add each selected table row onto the matching out slot.
  for (int64_t h = 1; h < idx_height; ++h) {
    for (int64_t w = 0; w < idx_width; ++w) {
      const int64_t i = h * idx_width + w;
      check_idx_value_valid(i);
      blas.AXPY(table_width, static_cast<T>(1), table + idx[i] * table_width,
                out + w * table_width);
    }
  }
}
// CPU functor that performs an embedding lookup followed by sum pooling over
// each level-0 LoD sequence of `ids_t`.
template <typename T>
struct EmbeddingVSumFunctor {
  // For every sequence i delimited by ids_t->lod()[0], sums the rows of
  // `table_t` selected by that sequence's ids and writes the result to row i
  // of `output_t` (allocated here on context.GetPlace()).
  //
  // The per-sequence work, including range checks on each id, is delegated to
  // emb_seqpool so the lookup/accumulate logic lives in one place.
  void operator()(const framework::ExecutionContext &context,
                  const LoDTensor *table_t, const LoDTensor *ids_t,
                  LoDTensor *output_t) {
    auto *table = table_t->data<T>();
    const int64_t table_height = table_t->dims()[0];
    const int64_t table_width = table_t->dims()[1];
    const int64_t out_width = output_t->dims()[1];

    const int64_t *ids = ids_t->data<int64_t>();
    auto ids_lod = ids_t->lod()[0];
    // Ids per step: total element count divided by the last LoD offset.
    const int64_t idx_width = ids_t->numel() / ids_lod.back();

    auto *output = output_t->mutable_data<T>(context.GetPlace());

    PADDLE_ENFORCE_LE(table_width * idx_width, out_width);

    // Signed sequence count avoids comparing a signed counter against the
    // unsigned ids_lod.size() - 1.
    const int64_t seq_count = static_cast<int64_t>(ids_lod.size()) - 1;
    for (int64_t i = 0; i < seq_count; ++i) {
      emb_seqpool(context, table, ids + ids_lod[i] * idx_width,
                  output + i * out_width, table_height, table_width,
                  ids_lod[i + 1] - ids_lod[i], idx_width, out_width);
    }
  }
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册