diff --git a/paddle/fluid/operators/filter_by_instag_op.h b/paddle/fluid/operators/filter_by_instag_op.h index 41bbbeac11e7ef81633e3d4f5a08ff59448eff66..f082d0dfc1273cba4ef9b400022c2ba15f164cec 100644 --- a/paddle/fluid/operators/filter_by_instag_op.h +++ b/paddle/fluid/operators/filter_by_instag_op.h @@ -73,14 +73,11 @@ class FilterByInstagKernel : public framework::OpKernel { } else { x1_lods = context.Input("Ins")->lod()[0]; } - std::unordered_map mmap_aux; - std::vector ins_after_filter; Vector out_lods(1, 0); for (size_t i = 0; i < x2_lods.size() - 1; i++) { for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) { if (filter_tag.find(x2_data[j]) != filter_tag.end()) { - ins_after_filter.push_back(x2_lods[i]); size_t batch_len = x1_lods[i + 1] - x1_lods[i]; mmap_aux[out_lods.back()] = x1_lods[i]; out_lods.push_back(out_lods.back() + batch_len); @@ -88,7 +85,6 @@ class FilterByInstagKernel : public framework::OpKernel { } } } - // set output value // for those whose ins been dropout, set 0 for whole lines. // otherwise, copy whole line @@ -100,12 +96,12 @@ class FilterByInstagKernel : public framework::OpKernel { auto* x1_data = x1->data(); // expected auto = T size_t x1_embed_size = x1->dims()[1]; - if (ins_after_filter.size() > 0) { + if (out_lods.size() - 1 > 0) { out->Resize(framework::make_ddim( {(int64_t)out_lods.back(), (int64_t)x1_embed_size})); - map->Resize(framework::make_ddim({(int64_t)ins_after_filter.size(), 3})); + map->Resize(framework::make_ddim({(int64_t)out_lods.size() - 1, 3})); loss_weight->Resize( - framework::make_ddim({(int64_t)ins_after_filter.size(), 1})); + framework::make_ddim({(int64_t)out_lods.size() - 1, 1})); } else { out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size})); map->Resize(framework::make_ddim({1, 3})); @@ -115,15 +111,15 @@ class FilterByInstagKernel : public framework::OpKernel { auto* map_data = map->mutable_data(context.GetPlace()); auto* loss_weight_data = loss_weight->mutable_data(context.GetPlace()); - if (ins_after_filter.size() > 0) { + if (out_lods.size() - 1 > 0) { Vector map_lods; - for (size_t i = 0; i < ins_after_filter.size(); i++) { + for (size_t i = 0; i < out_lods.size() - 1; i++) { map_data[i * 3] = (int64_t)out_lods[i]; map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]]; map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i]; map_lods.push_back(i); } - map_lods.push_back(ins_after_filter.size()); + map_lods.push_back(out_lods.size() - 1); std::vector> map_lod_info; map_lod_info.push_back(map_lods); @@ -136,10 +132,11 @@ class FilterByInstagKernel : public framework::OpKernel { for (size_t i = 0; i < loss_weight->numel(); i++) { loss_weight_data[i] = 1; } - for (size_t i = 0; i < ins_after_filter.size(); i++) { + + for (size_t i = 0; i < out_lods.size() - 1; i++) { size_t pos = out_lods[i]; - for (size_t k = x1_lods[ins_after_filter[i]]; - k < x1_lods[ins_after_filter[i] + 1]; k++) { + for (size_t k = map_data[i * 3 + 1]; + k < map_data[i * 3 + 1] + map_data[i * 3 + 2]; k++) { memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size, x1_embed_size * sizeof(T)); ++pos;