未验证 提交 e3b28d5b 编写于 作者: Y yaoxuefeng 提交者: GitHub

Fix instag (#22632) (#22991)

上级 4bfe5fa9
......@@ -60,6 +60,9 @@ class FilterByInstagOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Ins_tag", "(LoDTensor) ins tag list");
AddInput("Filter_tag", "(1D Tensor) filter tag list");
AddAttr<bool>("is_lod", "is Ins with LoD info or not, default True");
AddAttr<int64_t>("out_val_if_empty",
"if the output after filter is empty, the output value")
.SetDefault(0);
AddOutput("Out", "(LoDTensor) embeded tensor filtered by instag");
AddOutput("LossWeight", "(Tensor) loss weight.");
AddOutput("IndexMap", "(LoDTensor) mapping from Out rows to X1 rows");
......
......@@ -47,6 +47,7 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
// Dim [batch size, embedding size]
auto* x1 = context.Input<LoDTensor>("Ins");
bool is_x1_lod = context.Attr<bool>("is_lod");
int64_t out_val_if_empty = context.Attr<int64_t>("out_val_if_empty");
// X2 is ins tag list
// LoD [[0, Sum(ins1), Sum(ins1, ins2), ... ]]
auto* x2 = context.Input<LoDTensor>("Ins_tag");
......@@ -157,7 +158,15 @@ class FilterByInstagKernel : public framework::OpKernel<T> {
std::vector<Vector<size_t>> out_lod_info;
out_lod_info.push_back(out_lods);
out->set_lod(out_lod_info);
memset(out_data, 0, out->numel() * sizeof(T));
for (int64_t oi = 0; oi < out->numel(); ++oi) {
if (std::is_same<T, int32_t>::value) {
out_data[oi] = (int32_t)out_val_if_empty;
} else if (std::is_same<T, int64_t>::value) {
out_data[oi] = (int64_t)out_val_if_empty;
} else {
out_data[oi] = static_cast<double>(out_val_if_empty);
}
}
loss_weight_data[0] = 0;
}
}
......
......@@ -102,9 +102,9 @@ class AucKernel : public framework::OpKernel<T> {
"The predict data must gather or equal 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
if (label_data[i] > 0) {
origin_stat_pos[binIdx] += 1;
} else {
} else if (label_data[i] == 0) {
origin_stat_neg[binIdx] += 1;
}
}
......@@ -142,9 +142,9 @@ class AucKernel : public framework::OpKernel<T> {
"The predict data must gather or equal 0."));
uint32_t binIdx = static_cast<uint32_t>(predict_data * num_thresholds);
if (label_data[i]) {
if (label_data[i] > 0) {
origin_stat_pos[cur_step_begin + binIdx] += 1;
} else {
} else if (label_data[i] == 0) {
origin_stat_neg[cur_step_begin + binIdx] += 1;
}
}
......
......@@ -9169,7 +9169,7 @@ def stack(x, axis=0):
@templatedoc(op_type="filter_by_instag")
def filter_by_instag(ins, ins_tag, filter_tag, is_lod):
def filter_by_instag(ins, ins_tag, filter_tag, is_lod, out_val_if_empty=0):
"""
**Filter By Instag Layer**
......@@ -9206,6 +9206,8 @@ def filter_by_instag(ins, ins_tag, filter_tag, is_lod):
filter_tag (Variable): Input Variable (1D Tensor/List), usually it is
list that holds the tags.
is_lod (Bool): Boolean value to indicate ins is lod tensor or not.
out_val_if_empty(Int64): If the output after filter is empty, this value
will be set to Output tensor.
Returns:
Variable: filtered ins (LoDTensor) and loss weight (Tensor)
......@@ -9233,7 +9235,8 @@ def filter_by_instag(ins, ins_tag, filter_tag, is_lod):
outputs={'Out': out,
'LossWeight': loss_weight,
'IndexMap': mmap},
attrs={'is_lod': is_lod})
attrs={'is_lod': is_lod,
'out_val_if_empty': out_val_if_empty})
return [out, loss_weight]
......
......@@ -23,6 +23,7 @@ import paddle.fluid.layers as layers
from op_test import OpTest
import random
from decorator_helper import prog_scope
from paddle.fluid.op import Operator
"""This is Test Case 1"""
......@@ -71,7 +72,7 @@ class TestFilterByInstagOp(OpTest):
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': True}
self.attrs = {'is_lod': True, 'out_val_if_empty': 0}
def test_check_output(self):
self.check_output()
......@@ -116,7 +117,7 @@ class TestFilterByInstagOp2(OpTest):
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': True, }
self.attrs = {'is_lod': True, 'out_val_if_empty': 0}
def test_check_output(self):
self.check_output()
......@@ -158,7 +159,7 @@ class TestFilterByInstagOp3(OpTest):
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': True, }
self.attrs = {'is_lod': True, 'out_val_if_empty': 0}
def test_check_output(self):
self.check_output()
......@@ -199,7 +200,7 @@ class TestFilterByInstagOp4(OpTest):
'LossWeight': (loss_weight, mmap_lod),
'IndexMap': (mmap, mmap_lod)
}
self.attrs = {'is_lod': False, }
self.attrs = {'is_lod': False, 'out_val_if_empty': 0}
def test_check_output(self):
self.check_output()
......@@ -209,5 +210,79 @@ class TestFilterByInstagOp4(OpTest):
['Ins'], 'Out', no_grad_set=set(['Ins_tag', 'Filter_tag']))
class TestFilterByInstagOp6(OpTest):
    """filter_by_instag on a non-LoD int64 ``Ins`` where no tag matches.

    The instance tags are 2, 1, 2, 1 and the only filter tag is 3, so no
    instance carries the filter tag. The expected outputs are a single
    all-zero row for ``Out`` and a zero ``LossWeight``.
    """

    def setUp(self):
        self.op_type = 'filter_by_instag'

        # random() draws from [0, 1), so the integer cast leaves only zeros.
        ins = np.random.random((4, 36)).astype('int64')
        # One tag per instance: every LoD segment has length 1.
        ins_tag = np.array([[2], [1], [2], [1]], dtype='int64')
        ins_tag_lod = [[1, 1, 1, 1]]
        # Tag 3 does not appear in ins_tag.
        filter_tag = np.array([3], dtype='int64')

        expected_out = np.zeros((1, 36), dtype='double')
        expected_out_lod = [[1]]
        expected_index_map = np.array([[0, 1, 1]], dtype='int64')
        # The same LoD list is shared by LossWeight and IndexMap.
        index_lod = [[1]]
        expected_loss_weight = np.array([[0]], dtype='double')

        self.inputs = {
            'Ins': ins,
            'Ins_tag': (ins_tag, ins_tag_lod),
            'Filter_tag': filter_tag,
        }
        self.outputs = {
            'Out': (expected_out, expected_out_lod),
            'LossWeight': (expected_loss_weight, index_lod),
            'IndexMap': (expected_index_map, index_lod)
        }
        self.attrs = {'is_lod': False, 'out_val_if_empty': 0}

    def test_check_output(self):
        # Compare the operator's outputs against self.outputs.
        self.check_output()

    def test_check_grad(self):
        # Intentionally a no-op: no gradient check is run for this case.
        return None
class TestFilterByInstagOp7(OpTest):
    """filter_by_instag on a non-LoD int32 ``Ins`` where no tag matches.

    The instance tags are 2, 1, 2, 1 and the only filter tag is 3, so no
    instance carries the filter tag. The expected outputs are a single
    all-zero row for ``Out`` and a zero ``LossWeight``.
    """

    def setUp(self):
        self.op_type = 'filter_by_instag'

        # random() draws from [0, 1), so the integer cast leaves only zeros.
        ins = np.random.random((4, 36)).astype('int32')
        # One tag per instance: every LoD segment has length 1.
        ins_tag = np.array([[2], [1], [2], [1]], dtype='int64')
        ins_tag_lod = [[1, 1, 1, 1]]
        # Tag 3 does not appear in ins_tag.
        filter_tag = np.array([3], dtype='int64')

        expected_out = np.zeros((1, 36), dtype='double')
        expected_out_lod = [[1]]
        expected_index_map = np.array([[0, 1, 1]], dtype='int64')
        # The same LoD list is shared by LossWeight and IndexMap.
        index_lod = [[1]]
        expected_loss_weight = np.array([[0]], dtype='double')

        self.inputs = {
            'Ins': ins,
            'Ins_tag': (ins_tag, ins_tag_lod),
            'Filter_tag': filter_tag,
        }
        self.outputs = {
            'Out': (expected_out, expected_out_lod),
            'LossWeight': (expected_loss_weight, index_lod),
            'IndexMap': (expected_index_map, index_lod)
        }
        self.attrs = {'is_lod': False, 'out_val_if_empty': 0}

    def test_check_output(self):
        # Compare the operator's outputs against self.outputs.
        self.check_output()

    def test_check_grad(self):
        # Intentionally a no-op: no gradient check is run for this case.
        return None
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册