未验证 提交 548f2be4 编写于 作者: Z zmxdream 提交者: GitHub

[GPUPS]instag cuda kernel (#40377)

* update. test=develop

* fix. test=develop

* fix. test=develop
上级 5ae85131
此差异已折叠。
...@@ -61,7 +61,20 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -61,7 +61,20 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
// expected auto = const int64_t // expected auto = const int64_t
auto* x2_data = x2->data<int64_t>(); auto* x2_data = x2->data<int64_t>();
// e.g get [0, 1, 2, 3, ...] // e.g get [0, 1, 2, 3, ...]
size_t x2_lods_size = x2->dims()[0]; // size_t x2_lods_size = x2->dims()[0];
// size_t instag_num_per_ins = x2->dims()[1];
Vector<size_t> x2_lods(1, 0);
if (x2->lod().size() != 0) { // lod_level = 1
x2_lods = x2->lod()[0];
} else { // lod_level = 0
const size_t x2_lods_size = x2->dims()[0];
const size_t instag_num_per_ins = x2->dims()[1];
for (size_t i = 0; i < x2_lods_size; i++) {
x2_lods.push_back(x2_lods.back() + instag_num_per_ins);
}
}
Vector<size_t> x1_lods(1, 0); Vector<size_t> x1_lods(1, 0);
if (!is_x1_lod) { if (!is_x1_lod) {
for (int i = 0; i < x1->dims()[0]; i++) { for (int i = 0; i < x1->dims()[0]; i++) {
...@@ -79,8 +92,8 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -79,8 +92,8 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
} }
std::unordered_map<int64_t, int64_t> mmap_aux; std::unordered_map<int64_t, int64_t> mmap_aux;
Vector<size_t> out_lods(1, 0); Vector<size_t> out_lods(1, 0);
for (size_t i = 0; i < x2_lods_size; i++) { for (size_t i = 0; i < x2_lods.size() - 1; i++) {
for (size_t j = i; j < i + 1; j++) { for (size_t j = x2_lods[i]; j < x2_lods[i + 1]; j++) {
if (filter_tag.find(x2_data[j]) != filter_tag.end()) { if (filter_tag.find(x2_data[j]) != filter_tag.end()) {
size_t batch_len = x1_lods[i + 1] - x1_lods[i]; size_t batch_len = x1_lods[i + 1] - x1_lods[i];
mmap_aux[out_lods.back()] = x1_lods[i]; mmap_aux[out_lods.back()] = x1_lods[i];
...@@ -165,8 +178,10 @@ class FilterByInstagKernel : public framework::OpKernel<T> { ...@@ -165,8 +178,10 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
out_data[oi] = (int32_t)out_val_if_empty; out_data[oi] = (int32_t)out_val_if_empty;
} else if (std::is_same<T, int64_t>::value) { } else if (std::is_same<T, int64_t>::value) {
out_data[oi] = (int64_t)out_val_if_empty; out_data[oi] = (int64_t)out_val_if_empty;
} else { } else if (std::is_same<T, double>::value) {
out_data[oi] = static_cast<double>(out_val_if_empty); out_data[oi] = static_cast<double>(out_val_if_empty);
} else {
out_data[oi] = static_cast<float>(out_val_if_empty);
} }
} }
loss_weight_data[0] = 0; loss_weight_data[0] = 0;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册