提交 8a412c0d 编写于 作者: M minqiyang

Complete impl

上级 17c8014f
......@@ -42,8 +42,14 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
// we only support sum now
PADDLE_ENFORCE_EQ(combiner, "sum");
int64_t last_dim = table_dims[1];
for (int i = 1; i != ids_dims.size(); ++i) {
last_dim *= ids_dims[i];
}
if (ctx->IsRuntime()) {
Variable* ids_var = boost::get<Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
framework::Variable* ids_var =
boost::get<framework::Variable*>(ctx->GetInputVarPtrs("Ids")[0]);
const auto& ids_lod = ids_var->Get<LoDTensor>().lod();
// in run time, the LoD of ids must be 1
......@@ -51,20 +57,20 @@ class FusedEmbeddingSeqPoolOp : public framework::OperatorWithKernel {
"The LoD level of Input(Ids) must be 1");
PADDLE_ENFORCE_GE(ids_lod[0].size(), 1u, "The LoD could NOT be empty");
size_t batch_size = ids_lod[0].size() - 1;
int64_t batch_size = ids_lod[0].size() - 1;
// in run time, the shape from Ids -> output
// should be [seq_length, 1] -> [batch_size, embedding_size]
ctx->SetOutputDim("Out",
framework::make_ddim({batch_size, table_dims[1]}));
ctx->SetOutputDim("Out", framework::make_ddim({batch_size, last_dim}));
} else {
// in compile time, the lod level of ids must be 1
VarDesc* ids_desc = boost::get<VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
framework::VarDesc* ids_desc =
boost::get<framework::VarDesc*>(ctx->GetInputVarPtrs("Ids")[0]);
PADDLE_ENFORCE_EQ(ids_desc->GetLoDLevel(), 1);
// in compile time, the shape from Ids -> output
// should be [-1, 1] -> [-1, embedding_size]
ctx->SetOutputDim("Out", framework::make_ddim({-1, table_dims[1]}));
ctx->SetOutputDim("Out", framework::make_ddim({-1, last_dim}));
}
}
......
......@@ -31,31 +31,38 @@ using LoDTensor = framework::LoDTensor;
using SelectedRows = framework::SelectedRows;
using DDim = framework::DDim;
template <typename DeviceContext, typename T>
template <typename T>
struct EmbeddingVSumFunctor {
void operator()(const DeviceContext &context, LoDTensor *table_t,
LoDTensor *ids_t, LoDTensor *output_t) {
void operator()(const framework::ExecutionContext &context,
const LoDTensor *table_t, const LoDTensor *ids_t,
LoDTensor *output_t) {
auto *table = table_t->data<T>();
int64_t row_number = table->dims()[0];
int64_t row_width = table->dims()[1];
int64_t row_number = table_t->dims()[0];
int64_t row_width = table_t->dims()[1];
int64_t last_dim = output_t->dims()[1];
int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
auto ids_lod = ids_t->LoD()[0];
auto ids_lod = ids_t->lod()[0];
int64_t ids_count = ids_t->numel() / ids_lod.back();
auto *output = output_t->mutable_data<T>(context.GetPlace());
auto blas = math::GetBlas<DeviceContext, T>(context);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
size_t begin = ids_lod[i];
for (int64_t j = 0; j != ids_count; ++j) {
size_t begin = ids_lod[i] * ids_count;
PADDLE_ENFORCE_LT(ids[begin], row_number);
PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
blas.VCOPY(row_width, table + ids[begin] * row_width,
output + i * row_width);
PADDLE_ENFORCE_LT(ids[begin], row_number);
PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
blas.VCOPY(row_width, table + ids[begin] * row_width,
output + i * last_dim + j * row_width);
}
for (int64_t r = ids_lod[i] + 1; r < ids_lod[i + 1]; ++r) {
for (int64_t r = (ids_lod[i] + 1) * ids_count;
r < ids_lod[i + 1] * ids_count; ++r) {
PADDLE_ENFORCE_LT(ids[r], row_number);
PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i);
blas.AXPY(row_width, 1., table + ids[r] * row_width,
output + i * row_width);
output + i * row_width + (r % ids_count) * row_width);
}
}
}
......@@ -65,14 +72,14 @@ template <typename T>
class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
LoDTensor *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
LoDTensor *output_t = context.Output<LoDTensor>("Out"); // float tensor
LoDTensor *table_var = context.Input<LoDTensor>("W");
const LoDTensor *ids_t = context.Input<LoDTensor>("Ids"); // int tensor
LoDTensor *output_t = context.Output<LoDTensor>("Out"); // float tensor
const LoDTensor *table_var = context.Input<LoDTensor>("W");
const std::string &combiner_type = context.Attr<std::string>("combiner");
if (combiner_type == "sum") {
EmbeddingVSumFunctor<T> functor;
functor(context.template device_context(), ids_t, output_t, table_var);
functor(context, table_var, ids_t, output_t);
}
}
};
......@@ -105,7 +112,7 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
auto *ids_data = ids->data<int64_t>();
int64_t ids_num = ids->numel();
auto lod = ids->lod()[0];
int64_t row_width = table_dim[1];
int64_t row_width = d_output->dims()[1];
framework::Vector<int64_t> new_rows;
new_rows.resize(ids_num);
......@@ -113,11 +120,11 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, row_width});
d_table_value->Resize({ids_num, table_dim[1]});
T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
const T *d_output_data = d_output->data<T>();
auto blas = math::GetBlas<T>(context);
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
int64_t in_offset = lod[i] * row_width;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册