未验证 提交 b48545ee 编写于 作者: D danleifeng 提交者: GitHub

fix filter_by_instag op for lod_level=0 without lod;test=develop (#37834)

上级 b154110a
......@@ -65,19 +65,26 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
// expected auto = const int64_t
auto* x2_data = x2->data<int64_t>();
// e.g get [0, 1, 2, 3, ...]
auto x2_lods = x2->lod()[0];
size_t x2_lods_size = x2->dims()[0];
Vector<size_t> x1_lods(1, 0);
if (!is_x1_lod) {
for (int i = 0; i < x1->dims()[0]; i++) {
x1_lods.push_back(i + 1);
}
} else {
x1_lods = context.Input<LoDTensor>("Ins")->lod()[0];
// new: lod_level=0 => lod() return {}
if (x1->lod().size() != 0) {
x1_lods = x1->lod()[0];
} else {
for (int i = 0; i < x1->dims()[0]; i++) {
x1_lods.push_back(i + 1);
}
}
}
std::unordered_map<int64_t, int64_t> mmap_aux;
Vector<size_t> out_lods(1, 0);
for (size_t i = 0; i < x2_lods.size() - 1; i++) {
for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) {
for (size_t i = 0; i < x2_lods_size; i++) {
for (size_t j = i; j < i + 1; j++) {
if (filter_tag.find(x2_data[j]) != filter_tag.end()) {
size_t batch_len = x1_lods[i + 1] - x1_lods[i];
mmap_aux[out_lods.back()] = x1_lods[i];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册