未验证 提交 981fc2bd 编写于 作者: T tangwei12 提交者: GitHub

fix bug in merge_ids (#15503)

* fix mistakes in merge_ids, test=develop
上级 a7ba07d7
...@@ -43,9 +43,9 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -43,9 +43,9 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ(ids.size(), outs.size(), PADDLE_ENFORCE_EQ(ids.size(), outs.size(),
"the number of Ids and Out should be the same"); "the number of Ids and Out should be the same");
size_t row_ids_size = 0; int64_t row_ids_size = 0;
int row_size = 0; int64_t row_size = 0;
int embedding_size = 0; int64_t embedding_size = 0;
for (size_t i = 0; i < x_tensors.size(); ++i) { for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *x_tensor = x_tensors[i]; const auto *x_tensor = x_tensors[i];
...@@ -69,7 +69,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -69,7 +69,7 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < x_tensors.size(); ++i) { for (size_t i = 0; i < x_tensors.size(); ++i) {
const auto *row_id = row_ids[i]; const auto *row_id = row_ids[i];
for (int j = 0; j < row_id->numel(); ++j) { for (auto j = 0; j < row_id->numel(); ++j) {
int64_t key = row_id->data<int64_t>()[j]; int64_t key = row_id->data<int64_t>()[j];
std::tuple<int64_t, int64_t> val = std::make_tuple(i, j); std::tuple<int64_t, int64_t> val = std::make_tuple(i, j);
selected_rows_idx_map.insert(std::make_pair(key, val)); selected_rows_idx_map.insert(std::make_pair(key, val));
...@@ -84,13 +84,13 @@ class MergeIdsOpKernel : public framework::OpKernel<T> { ...@@ -84,13 +84,13 @@ class MergeIdsOpKernel : public framework::OpKernel<T> {
out->set_lod(out_ids->lod()); out->set_lod(out_ids->lod());
int nums = static_cast<int>(out_ids->dims()[0]); auto nums = out_ids->dims()[0];
auto *out_data = out->mutable_data<T>( auto *out_data = out->mutable_data<T>(
framework::make_ddim({nums, embedding_size}), place); framework::make_ddim({nums, embedding_size}), place);
for (int j = 0; j < nums; ++j) { for (auto j = 0; j < nums; ++j) {
int id = out_ids->data<int64_t>()[j]; auto id = out_ids->data<int64_t>()[j];
auto row_tuple = selected_rows_idx_map[id]; auto row_tuple = selected_rows_idx_map.at(id);
int64_t row_idx = std::get<1>(row_tuple); auto row_idx = std::get<1>(row_tuple);
const auto *x_tensor = x_tensors[std::get<0>(row_tuple)]; const auto *x_tensor = x_tensors[std::get<0>(row_tuple)];
memcpy(out_data + embedding_size * j, memcpy(out_data + embedding_size * j,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册