Commit 17c8014f authored by minqiyang

Complete implementation

test=develop
Parent 7939f835
@@ -93,6 +93,12 @@ class FusedEmbeddingSeqPoolOpMaker : public framework::OpProtoAndCheckerMaker {
                   "are supported, sum computes the weighted sum of the "
                   "embedding results for each row.")
         .SetDefault("sum");
+    // NOTE(minqiyang): grad_inplace is a temporary attribute;
+    // please do NOT set this attribute in the Python layer.
+    AddAttr<bool>("grad_inplace",
+                  "(boolean, default false) "
+                  "If the grad op reuses the input's variable.")
+        .SetDefault(false);
     AddAttr<bool>("is_sparse",
                   "(boolean, default false) "
                   "Sparse update.")
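For readers unfamiliar with the new attribute: `grad_inplace` indicates whether the backward op may reuse the incoming gradient buffer instead of copying it (note that in the second file below, the old sparse backward shared `d_output` via `ShareDataWith`, while the new one copies rows explicitly). A toy, standalone illustration of that reuse-vs-copy distinction, in plain C++ rather than Paddle code, with purely illustrative names:

```cpp
#include <cstring>
#include <utility>
#include <vector>

// Either adopt the upstream gradient buffer directly (in-place reuse),
// or copy its contents into freshly allocated gradient storage.
std::vector<float> MakeTableGrad(std::vector<float> d_out, bool grad_inplace) {
  if (grad_inplace) {
    return d_out;  // reuse the storage, no extra copy
  }
  std::vector<float> d_table(d_out.size());
  std::memcpy(d_table.data(), d_out.data(), d_out.size() * sizeof(float));
  return d_table;
}

int main() {
  std::vector<float> d_out = {1.f, 2.f, 3.f, 4.f};  // pretend upstream gradient
  auto d_table = MakeTableGrad(std::move(d_out), /*grad_inplace=*/false);
  return static_cast<int>(d_table.size()) == 4 ? 0 : 1;
}
```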
@@ -31,62 +31,54 @@ using LoDTensor = framework::LoDTensor;
 using SelectedRows = framework::SelectedRows;
 using DDim = framework::DDim;
 
+template <typename T>
+struct EmbeddingVSumFunctor {
+  void operator()(const framework::ExecutionContext &context,
+                  const LoDTensor *table_t, const LoDTensor *ids_t,
+                  LoDTensor *output_t) {
+    auto *table = table_t->data<T>();
+    int64_t row_number = table_t->dims()[0];
+    int64_t row_width = table_t->dims()[1];
+    const int64_t *ids = ids_t->data<int64_t>();
+    auto ids_lod = ids_t->lod()[0];
+    auto *output = output_t->mutable_data<T>(context.GetPlace());
+
+    auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+    for (size_t i = 0; i != ids_lod.size() - 1; ++i) {
+      size_t begin = ids_lod[i];
+      PADDLE_ENFORCE_LT(ids[begin], row_number);
+      PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
+      blas.VCOPY(row_width, table + ids[begin] * row_width,
+                 output + i * row_width);
+
+      for (int64_t r = ids_lod[i] + 1; r < ids_lod[i + 1]; ++r) {
+        PADDLE_ENFORCE_LT(ids[r], row_number);
+        PADDLE_ENFORCE_GE(ids[r], 0, "ids %d", i);
+        blas.AXPY(row_width, static_cast<T>(1), table + ids[r] * row_width,
+                  output + i * row_width);
+      }
+    }
+  }
+};
+
 template <typename T>
-class LookupTableKernel : public framework::OpKernel<T> {
+class FusedEmbeddingSeqPoolKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
-    auto *ids_t = context.Input<LoDTensor>("Ids");      // int tensor
-    auto *output_t = context.Output<LoDTensor>("Out");  // float tensor
-    auto *table_var = context.InputVar("W");
-    int64_t padding_idx = context.Attr<int64_t>("padding_idx");
-    int64_t *ids = const_cast<int64_t *>(ids_t->data<int64_t>());
-    int64_t ids_numel = ids_t->numel();
-
-    if (table_var->IsType<LoDTensor>()) {
-      auto *table_t = context.Input<LoDTensor>("W");
-      int64_t row_number = table_t->dims()[0];
-      int64_t row_width = table_t->dims()[1];
-      auto *table = table_t->data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
-
-      for (int64_t i = 0; i < ids_numel; ++i) {
-        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
-          memset(output + i * row_width, 0, row_width * sizeof(T));
-        } else {
-          PADDLE_ENFORCE_LT(ids[i], row_number);
-          PADDLE_ENFORCE_GE(ids[i], 0, "ids %d", i);
-          memcpy(output + i * row_width, table + ids[i] * row_width,
-                 row_width * sizeof(T));
-        }
-      }
-    } else if (table_var->IsType<SelectedRows>()) {
-      const auto &table_t = table_var->Get<SelectedRows>();
-      int64_t row_width = table_t.value().dims()[1];
-      const auto *table = table_t.value().data<T>();
-      auto *output = output_t->mutable_data<T>(context.GetPlace());
-      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
-
-      for (int64_t i = 0; i < ids_numel; ++i) {
-        if (padding_idx != kNoPadding && ids[i] == padding_idx) {
-          memset(output + i * row_width, 0, row_width * sizeof(T));
-        } else {
-          PADDLE_ENFORCE_GE(ids[i], 0);
-          auto id_index = table_t.Index(ids[i]);
-          PADDLE_ENFORCE_GE(id_index, 0, "the input key should be exists.");
-          blas.VCOPY(row_width, table + id_index * row_width,
-                     output + i * row_width);
-        }
-      }
-    }
+    const LoDTensor *ids_t = context.Input<LoDTensor>("Ids");  // int tensor
+    LoDTensor *output_t = context.Output<LoDTensor>("Out");    // float tensor
+    const LoDTensor *table_var = context.Input<LoDTensor>("W");
+    const std::string &combiner_type = context.Attr<std::string>("combiner");
+
+    if (combiner_type == "sum") {
+      EmbeddingVSumFunctor<T> functor;
+      functor(context, table_var, ids_t, output_t);
+    }
   }
 };
 
 template <typename T>
-class LookupTableGradKernel : public framework::OpKernel<T> {
+class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
     auto *table_var = context.InputVar("W");
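To make the forward path concrete: `EmbeddingVSumFunctor` treats `Ids` as a batch of variable-length sequences described by LoD offsets, looks up one `row_width`-wide table row per id, and sums the rows of each sequence into a single output row (VCOPY for the first id, AXPY for the rest). Below is a minimal standalone sketch of the same computation in plain C++; the names and shapes are illustrative, not Paddle APIs.

```cpp
#include <cstddef>
#include <vector>

// output[i][:] = sum over ids in segment i of table[id][:]
// lod holds segment offsets into ids, e.g. {0, 2, 5} -> segments [0,2) and [2,5).
std::vector<float> EmbeddingSumPool(const std::vector<float> &table,
                                    std::size_t row_width,
                                    const std::vector<long long> &ids,
                                    const std::vector<std::size_t> &lod) {
  std::vector<float> output((lod.size() - 1) * row_width, 0.f);
  for (std::size_t i = 0; i + 1 < lod.size(); ++i) {
    for (std::size_t r = lod[i]; r < lod[i + 1]; ++r) {
      const float *row = &table[static_cast<std::size_t>(ids[r]) * row_width];
      for (std::size_t j = 0; j < row_width; ++j) {
        output[i * row_width + j] += row[j];  // accumulate, like VCOPY + AXPY
      }
    }
  }
  return output;
}

int main() {
  // 3-row table of width 2; two sequences: ids {0, 2} and {1}.
  std::vector<float> table = {1, 1, 2, 2, 3, 3};
  std::vector<long long> ids = {0, 2, 1};
  std::vector<std::size_t> lod = {0, 2, 3};
  auto out = EmbeddingSumPool(table, 2, ids, lod);
  // out == {4, 4, 2, 2}
  return 0;
}
```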
@@ -106,97 +98,37 @@ class LookupTableGradKernel : public framework::OpKernel<T> {
     // Since paddings are not trainable and fixed in forward, the gradient of
     // paddings makes no sense and we don't deal with it in backward.
     if (is_sparse) {
       auto *ids = context.Input<LoDTensor>("Ids");
       auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
       auto *d_table = context.Output<SelectedRows>(framework::GradVarName("W"));
       auto *ids_data = ids->data<int64_t>();
       int64_t ids_num = ids->numel();
+      auto lod = ids->lod()[0];
+      int64_t row_width = table_dim[1];
 
-      std::vector<int64_t> new_rows;
+      framework::Vector<int64_t> new_rows;
       new_rows.resize(ids_num);
       std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t));
       d_table->set_rows(new_rows);
 
       auto *d_table_value = d_table->mutable_value();
-      d_table_value->Resize({ids_num, table_dim[1]});
-      d_table_value->ShareDataWith(*d_output);
-    } else {
-      auto *ids = context.Input<LoDTensor>("Ids");
-      auto *d_output = context.Input<LoDTensor>(framework::GradVarName("Out"));
-      auto *d_table = context.Output<LoDTensor>(framework::GradVarName("W"));
-      auto *ids_data = ids->data<int64_t>();
-
-      int N = table_dim[0];
-      int D = table_dim[1];
-      auto *d_output_data = d_output->data<T>();
-      auto *d_table_data = d_table->mutable_data<T>(context.GetPlace());
-
-      memset(d_table_data, 0, d_table->numel() * sizeof(T));
-
-      for (int64_t i = 0; i < ids->numel(); ++i) {
-        PADDLE_ENFORCE_LT(ids_data[i], N);
-        PADDLE_ENFORCE_GE(ids_data[i], 0);
-        for (int j = 0; j < D; ++j) {
-          d_table_data[ids_data[i] * D + j] += d_output_data[i * D + j];
-        }
-      }
+      d_table_value->Resize({ids_num, row_width});
+      T *d_table_data = d_table_value->mutable_data<T>(context.GetPlace());
+      const T *d_output_data = d_output->data<T>();
+
+      auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
+      for (int i = 0; i < static_cast<int>(lod.size()) - 1; ++i) {
+        int64_t h = static_cast<int64_t>(lod[i + 1] - lod[i]);
+        int64_t in_offset = lod[i] * row_width;
+        const T *out_pos = d_output_data + i * row_width;
+        T *in_pos = d_table_data + in_offset;
+        for (int64_t r = 0; r != h; ++r) {
+          blas.VCOPY(row_width, out_pos, in_pos + r * row_width);
+        }
+      }
+    } else {
+      LOG(ERROR) << "Dense is not supported in fused_embedding_seq_pool_op now";
     }
   }
 };
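The sparse backward above builds a SelectedRows gradient whose rows are exactly the input ids, and copies the pooled output gradient of sequence i into every row belonging to that sequence (the gradient of a sum with respect to each term is 1). A standalone sketch of that scatter, again in plain C++ with illustrative names:

```cpp
#include <cstddef>
#include <vector>

// d_table has one row per id (same order as ids); the gradient of a sum pool
// replicates the output gradient of segment i into each of that segment's rows.
std::vector<float> EmbeddingSumPoolGrad(const std::vector<float> &d_out,
                                        std::size_t row_width,
                                        const std::vector<std::size_t> &lod) {
  std::size_t ids_num = lod.back();
  std::vector<float> d_table(ids_num * row_width, 0.f);
  for (std::size_t i = 0; i + 1 < lod.size(); ++i) {
    for (std::size_t r = lod[i]; r < lod[i + 1]; ++r) {
      for (std::size_t j = 0; j < row_width; ++j) {
        d_table[r * row_width + j] = d_out[i * row_width + j];  // like blas.VCOPY
      }
    }
  }
  return d_table;
}

int main() {
  // Two sequences of lengths 2 and 1, row_width = 2.
  std::vector<float> d_out = {0.5f, 0.5f, 1.0f, 1.0f};
  std::vector<std::size_t> lod = {0, 2, 3};
  auto d_table = EmbeddingSumPoolGrad(d_out, 2, lod);
  // d_table == {0.5, 0.5, 0.5, 0.5, 1.0, 1.0}; row r corresponds to ids[r].
  return 0;
}
```

Note that duplicate ids simply produce repeated rows in the gradient here, mirroring how the kernel memcpys ids verbatim into `new_rows`.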