Commit c6756ed2 authored by yaoxuefeng, committed by Jiawei Wang

fix instag op (#19591)

* fix instag op

* fix instag bug: a small logical error that occurs when ins_tag (the 2nd input) contains multiple tags. test=develop
Parent 6c2bc29c
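Why this fixes the multi-tag case: the old kernel recorded x2_lods[i], a cumulative tag offset, in ins_after_filter, and the copy loop then used that value as an instance index into x1_lods. The two coincide only when every instance carries exactly one tag, which is exactly the failure mode the commit message describes. After the change, all sizes come from out_lods and the copy source comes from map_data/mmap_aux, which are correct for any tag count. Below is a minimal standalone sketch of the corrected bookkeeping loop; the input values and the plain std::vector / std::unordered_set stand-ins for the kernel's LoD types are illustrative assumptions, and the break on first match is my reading of the fix (so an instance matching several filter tags is kept once), not a line visible in the truncated hunk.

#include <cstdio>
#include <unordered_map>
#include <unordered_set>
#include <vector>

int main() {
  // Illustrative inputs, not the operator's real API:
  // x2_data holds all tags back to back; x2_lods marks each instance's span.
  // Instance 0 has tags {1, 1}, instance 1 has {2}, instance 2 has {9}.
  std::vector<long> x2_data = {1, 1, 2, 9};
  std::vector<size_t> x2_lods = {0, 2, 3, 4};
  // x1_lods marks each instance's rows in the embedding input.
  std::vector<size_t> x1_lods = {0, 3, 5, 9};
  std::unordered_set<long> filter_tag = {2};

  std::unordered_map<long, long> mmap_aux;  // output offset -> source offset
  std::vector<size_t> out_lods = {0};
  for (size_t i = 0; i + 1 < x2_lods.size(); i++) {
    for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) {
      if (filter_tag.count(x2_data[j])) {
        size_t batch_len = x1_lods[i + 1] - x1_lods[i];
        mmap_aux[out_lods.back()] = x1_lods[i];
        out_lods.push_back(out_lods.back() + batch_len);
        break;  // assumed: keep a multi-tag instance exactly once
      }
    }
  }
  // Instance 1 is kept; its rows start at x1_lods[1] == 3, length 2.
  // The old code stored x2_lods[1] == 2 and would read x1_lods[2] == 5,
  // two rows past the right start, because instance 0 carries two tags.
  printf("kept %zu instance(s), src offset %ld, len %zu\n",
         out_lods.size() - 1, mmap_aux[0], out_lods[1] - out_lods[0]);
  return 0;
}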
@@ -73,14 +73,11 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
     } else {
       x1_lods = context.Input<LoDTensor>("Ins")->lod()[0];
     }
     std::unordered_map<int64_t, int64_t> mmap_aux;
-    std::vector<size_t> ins_after_filter;
     Vector<size_t> out_lods(1, 0);
     for (size_t i = 0; i < x2_lods.size() - 1; i++) {
       for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) {
         if (filter_tag.find(x2_data[j]) != filter_tag.end()) {
-          ins_after_filter.push_back(x2_lods[i]);
           size_t batch_len = x1_lods[i + 1] - x1_lods[i];
           mmap_aux[out_lods.back()] = x1_lods[i];
           out_lods.push_back(out_lods.back() + batch_len);
@@ -88,7 +85,6 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
         }
       }
     }
     // set output value
     // for those whose ins been dropout, set 0 for whole lines.
     // otherwise, copy whole line
@@ -100,12 +96,12 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
     auto* x1_data = x1->data<T>();
     // expected auto = T
     size_t x1_embed_size = x1->dims()[1];
-    if (ins_after_filter.size() > 0) {
+    if (out_lods.size() - 1 > 0) {
       out->Resize(framework::make_ddim(
           {(int64_t)out_lods.back(), (int64_t)x1_embed_size}));
-      map->Resize(framework::make_ddim({(int64_t)ins_after_filter.size(), 3}));
+      map->Resize(framework::make_ddim({(int64_t)out_lods.size() - 1, 3}));
       loss_weight->Resize(
-          framework::make_ddim({(int64_t)ins_after_filter.size(), 1}));
+          framework::make_ddim({(int64_t)out_lods.size() - 1, 1}));
     } else {
       out->Resize(framework::make_ddim({1, (int64_t)x1_embed_size}));
       map->Resize(framework::make_ddim({1, 3}));
@@ -115,15 +111,15 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
     auto* map_data = map->mutable_data<int64_t>(context.GetPlace());
     auto* loss_weight_data =
         loss_weight->mutable_data<float>(context.GetPlace());
-    if (ins_after_filter.size() > 0) {
+    if (out_lods.size() - 1 > 0) {
       Vector<size_t> map_lods;
-      for (size_t i = 0; i < ins_after_filter.size(); i++) {
+      for (size_t i = 0; i < out_lods.size() - 1; i++) {
         map_data[i * 3] = (int64_t)out_lods[i];
         map_data[i * 3 + 1] = mmap_aux[map_data[i * 3]];
         map_data[i * 3 + 2] = out_lods[i + 1] - out_lods[i];
         map_lods.push_back(i);
       }
-      map_lods.push_back(ins_after_filter.size());
+      map_lods.push_back(out_lods.size() - 1);
       std::vector<Vector<size_t>> map_lod_info;
       map_lod_info.push_back(map_lods);
@@ -136,10 +132,11 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
       for (size_t i = 0; i < loss_weight->numel(); i++) {
         loss_weight_data[i] = 1;
       }
-      for (size_t i = 0; i < ins_after_filter.size(); i++) {
+      for (size_t i = 0; i < out_lods.size() - 1; i++) {
         size_t pos = out_lods[i];
-        for (size_t k = x1_lods[ins_after_filter[i]];
-             k < x1_lods[ins_after_filter[i] + 1]; k++) {
+        for (size_t k = map_data[i * 3 + 1];
+             k < map_data[i * 3 + 1] + map_data[i * 3 + 2]; k++) {
           memcpy(out_data + pos * x1_embed_size, x1_data + k * x1_embed_size,
                  x1_embed_size * sizeof(T));
           ++pos;
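For reference, the copy loop in the last hunk now walks map_data triples instead of x1_lods[ins_after_filter[i]]. Each row of the 3-column map output holds (offset of the kept instance in out, its original row offset in Ins, its row count). The sketch below replays that gather step with made-up buffers in place of the real LoDTensors; the values and variable names are illustrative only.

#include <cstdio>
#include <cstring>
#include <vector>

// Gather kept instances row by row using (out_off, src_off, len) triples,
// the same scheme as the kernel's map_data after this commit.
int main() {
  const size_t embed = 2;  // stands in for x1_embed_size
  std::vector<float> x1 = {0, 0, 1, 1, 2, 2, 3, 3};  // 4 input rows
  std::vector<long> map_data = {0, 1, 2,   // out row 0 <- src rows 1..2
                                2, 3, 1};  // out row 2 <- src row 3
  std::vector<float> out(3 * embed, 0.f);

  for (size_t i = 0; i < map_data.size() / 3; i++) {
    size_t pos = map_data[i * 3];    // destination row in out
    long src = map_data[i * 3 + 1];  // first source row in x1
    long len = map_data[i * 3 + 2];  // number of rows to copy
    for (long k = src; k < src + len; k++) {
      std::memcpy(out.data() + pos * embed, x1.data() + k * embed,
                  embed * sizeof(float));
      ++pos;
    }
  }
  for (float v : out) printf("%g ", v);  // prints: 1 1 2 2 3 3
  printf("\n");
  return 0;
}

Note that the triples are self-describing: the sketch needs neither x1_lods nor ins_after_filter, which is what lets the commit delete the auxiliary vector entirely.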