提交 c6756ed2 编写于 作者: Y yaoxuefeng 提交者: Jiawei Wang

fix instag op (#19591)

* fix instag op

* fix instag bug: fix a small logic error that occurred when ins_tag (the 2nd input) contains multiple tags. test=develop
上级 6c2bc29c
...@@ -73,14 +73,11 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -73,14 +73,11 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
} else { } else {
x1_lods = context.Input<LoDTensor>("Ins")->lod()[0]; x1_lods = context.Input<LoDTensor>("Ins")->lod()[0];
} }
std::unordered_map<int64_t, int64_t> mmap_aux; std::unordered_map<int64_t, int64_t> mmap_aux;
std::vector<size_t> ins_after_filter;
Vector<size_t> out_lods(1, 0); Vector<size_t> out_lods(1, 0);
for (size_t i = 0; i < x2_lods.size() - 1; i++) { for (size_t i = 0; i < x2_lods.size() - 1; i++) {
for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) { for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) {
if (filter_tag.find(x2_data[j]) != filter_tag.end()) { if (filter_tag.find(x2_data[j]) != filter_tag.end()) {
ins_after_filter.push_back(x2_lods[i]);
size_t batch_len = x1_lods[i + 1] - x1_lods[i]; size_t batch_len = x1_lods[i + 1] - x1_lods[i];
mmap_aux[out_lods.back()] = x1_lods[i]; mmap_aux[out_lods.back()] = x1_lods[i];
out_lods.push_back(out_lods.back() + batch_len); out_lods.push_back(out_lods.back() + batch_len);
...@@ -88,7 +85,6 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -88,7 +85,6 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
} }
} }
} }
// set output value // set output value
// for those whose ins been dropout, set 0 for whole lines. // for those whose ins been dropout, set 0 for whole lines.
// otherwise, copy whole line // otherwise, copy whole line
...@@ -100,12 +96,12 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -100,12 +96,12 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
auto* x1_data = x1->data<T>(); auto* x1_data = x1->data<T>();
// expected auto = T // expected auto = T
size_t x1_embed_size = x1->dims()[1]; size_t x1_embed_size = x1->dims()[1];
if (ins_after_filter.size() > 0) { if (out_lods.size() - 1 > 0) {
out->Resize(framework::make_ddim( out->Resize(framework::make_ddim(
{(int64_t)out_lods.back(), (int64_t)x1_embed_size})); {(int64_t)out_lods.back(), (int64_t)x1_embed_size}));
map->Resize(framework::make_ddim({(int64_t)ins_after_filter.size(), 3})); map->Resize(framework::make_ddim({(int64_t)out_lods.size() - 1, 3}));
loss_weight->Resize( loss_weight->Resize(
framework::make_ddim({(int64_t)ins_after_filter.size(), 1})); framework::make_ddim({(int64_t)out_lods.size() - 1, 1}));
} else { } else {
out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size})); out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size}));
map->Resize(framework::make_ddim({1, 3})); map->Resize(framework::make_ddim({1, 3}));
...@@ -115,15 +111,15 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -115,15 +111,15 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
auto* map_data = map->mutable_data<int64_t>(context.GetPlace()); auto* map_data = map->mutable_data<int64_t>(context.GetPlace());
auto* loss_weight_data = auto* loss_weight_data =
loss_weight->mutable_data<float>(context.GetPlace()); loss_weight->mutable_data<float>(context.GetPlace());
if (ins_after_filter.size() > 0) { if (out_lods.size() - 1 > 0) {
Vector<size_t> map_lods; Vector<size_t> map_lods;
for (size_t i = 0; i < ins_after_filter.size(); i++) { for (size_t i = 0; i < out_lods.size() - 1; i++) {
map_data[i * 3] = (int64_t)out_lods[i]; map_data[i * 3] = (int64_t)out_lods[i];
map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]]; map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]];
map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i]; map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i];
map_lods.push_back(i); map_lods.push_back(i);
} }
map_lods.push_back(ins_after_filter.size()); map_lods.push_back(out_lods.size() - 1);
std::vector<Vector<size_t>> map_lod_info; std::vector<Vector<size_t>> map_lod_info;
map_lod_info.push_back(map_lods); map_lod_info.push_back(map_lods);
...@@ -136,10 +132,11 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -136,10 +132,11 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
for (size_t i = 0; i < loss_weight->numel(); i++) { for (size_t i = 0; i < loss_weight->numel(); i++) {
loss_weight_data[i] = 1; loss_weight_data[i] = 1;
} }
for (size_t i = 0; i < ins_after_filter.size(); i++) {
for (size_t i = 0; i < out_lods.size() - 1; i++) {
size_t pos = out_lods[i]; size_t pos = out_lods[i];
for (size_t k = x1_lods[ins_after_filter[i]]; for (size_t k = map_data[i * 3 + 1];
k < x1_lods[ins_after_filter[i] + 1]; k++) { k < map_data[i * 3 + 1] + map_data[i * 3 + 2]; k++) {
memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size, memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size,
x1_embed_size * sizeof(T)); x1_embed_size * sizeof(T));
++pos; ++pos;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册