提交 32ebee9f 编写于 作者: M minqiyang

Polish code

上级 b0afdc4e
...@@ -48,9 +48,8 @@ struct EmbeddingVSumFunctor { ...@@ -48,9 +48,8 @@ struct EmbeddingVSumFunctor {
auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context); auto blas = math::GetBlas<platform::CPUDeviceContext, T>(context);
for (int64_t i = 0; i != ids_lod.size() - 1; ++i) { for (int64_t i = 0; i != ids_lod.size() - 1; ++i) {
size_t begin = ids_lod[i] * ids_count;
for (int64_t j = 0; j != ids_count; ++j) { for (int64_t j = 0; j != ids_count; ++j) {
size_t begin = ids_lod[i] * ids_count;
PADDLE_ENFORCE_LT(ids[begin], row_number); PADDLE_ENFORCE_LT(ids[begin], row_number);
PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i); PADDLE_ENFORCE_GE(ids[begin], 0, "ids %d", i);
blas.VCOPY(row_width, table + ids[begin + j] * row_width, blas.VCOPY(row_width, table + ids[begin + j] * row_width,
...@@ -114,10 +113,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> { ...@@ -114,10 +113,9 @@ class FusedEmbeddingSeqPoolGradKernel : public framework::OpKernel<T> {
auto lod = ids->lod()[0]; auto lod = ids->lod()[0];
int64_t row_width = d_output->dims()[1]; int64_t row_width = d_output->dims()[1];
framework::Vector<int64_t> new_rows; framework::Vector<int64_t> *new_rows = d_table->mutable_rows();
new_rows.resize(ids_num); new_rows->resize(ids_num);
std::memcpy(&new_rows[0], ids_data, ids_num * sizeof(int64_t)); std::memcpy(&(*new_rows)[0], ids_data, ids_num * sizeof(int64_t));
d_table->set_rows(new_rows);
auto *d_table_value = d_table->mutable_value(); auto *d_table_value = d_table->mutable_value();
d_table_value->Resize({ids_num, table_dim[1]}); d_table_value->Resize({ids_num, table_dim[1]});
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册