提交 a0b53376 编写于 作者: A Aurelius84 提交者: Tao Luo

increase data shape in topk_avg_pooling unittest (#22010)

上级 86c40e20
......@@ -32,18 +32,19 @@ class TestSequenceTopkAvgPoolingOp(OpTest):
self.op_type = "sequence_topk_avg_pooling"
def set_data(self):
    """Configure the base test case and build its inputs.

    Uses three top-k sizes (1, 3 and 5), 3 channels and two sequences
    whose row/column lengths are 30x25 and 45x36. Each value is assigned
    exactly once and ``init_data`` is invoked a single time, so no
    superseded configuration or redundant input generation remains.
    """
    topks = [1, 3, 5]
    channel_num = 3
    dim = 10
    row = [30, 45]
    col = [25, 36]
    self.init_data(topks, channel_num, row, col, dim)
def init_data(self, topks, channel_num, row, col, dim=10):
self.attrs = {"topks": topks, "channel_num": channel_num}
feature = [row[i] * col[i] for i in range(len(row))]
numel = sum(feature) * channel_num
x_data = np.random.random((numel, )).astype('float32')
x_data = np.arange(numel).astype('float32')
x_lod = [[x * channel_num for x in feature]]
row_data = np.random.random((sum(row), dim)).astype('float32')
col_data = np.random.random((sum(col), dim)).astype('float32')
......@@ -53,6 +54,30 @@ class TestSequenceTopkAvgPoolingOp(OpTest):
'COLUMN': (col_data, [col])
}
def calc_gradient(self, pos_data, topks, channel_num, row, col):
    """Compute the reference gradient of the flattened input ``X``.

    Every output element is given the same upstream gradient
    ``1 / out_numel``. For each (sequence, channel, row) and each top-k
    size ``k`` in ``topks``, that gradient is split evenly (divided by
    ``k``) over the first ``k`` recorded positions of the row; slots
    holding ``-1`` are skipped.

    Args:
        pos_data: array of selected column indices; it is flattened and
            read as ``topks[-1]`` slots per (sequence, channel, row).
            Values are used directly as indices, so an integer dtype is
            assumed -- TODO confirm against the producer of ``pos``.
        topks: list of top-k sizes; the last entry is taken as the
            number of slots stored per row.
        channel_num: number of channels per sequence.
        row: row count of each sequence in the batch.
        col: column count of each sequence in the batch.

    Returns:
        1-D float32 array with
        ``sum(row[i] * col[i]) * channel_num`` elements.
    """
    max_k = topks[-1]
    pos_data = pos_data.flatten()
    in_numel = sum(r * c for r, c in zip(row, col)) * channel_num
    out_numel = sum(row) * len(topks) * channel_num
    gradient = np.zeros(shape=(in_numel), dtype="float32")
    dout_val = 1. / out_numel
    pos_offset, in_offset = 0, 0
    for row_size, col_size in zip(row, col):  # batch
        for _ in range(channel_num):  # channel
            for row_idx in range(row_size):  # row
                in_idx = in_offset + row_idx * col_size
                pos_idx = pos_offset + row_idx * max_k
                for topk in topks:
                    for k in range(topk):
                        selected = pos_data[pos_idx + k]
                        if selected != -1:
                            gradient[in_idx + selected] += dout_val / topk
            # Advance once per channel: each channel owns its own
            # row_size * col_size slice of X and row_size * max_k slice
            # of pos, which is what makes the offsets cover in_numel.
            in_offset += row_size * col_size
            pos_offset += row_size * max_k
    return gradient
def compute(self):
topks = self.attrs['topks']
max_k = topks[-1]
......@@ -70,7 +95,6 @@ class TestSequenceTopkAvgPoolingOp(OpTest):
self.assertTrue(
x_len == channel_num * row_lod[0][idx] * col_lod[0][idx],
"x_len: %s can't mod channel_num: %s" % (x_len, channel_num))
# feature = x_len / channel_num
out_tmp = np.zeros((0, ), dtype=x_data.dtype)
pos_tmp = np.zeros((0, ), dtype='int32')
for ch in range(channel_num):
......@@ -94,6 +118,8 @@ class TestSequenceTopkAvgPoolingOp(OpTest):
pos = np.hstack((pos, pos_tmp.flatten()))
self.outputs = {'Out': (out.astype('float32'), out_lod), 'pos': pos}
self.gradient = self.calc_gradient(pos, topks, channel_num, row_lod[0],
col_lod[0])
def get_topk(self, x, topk):
real_topk = topk if topk < len(x) else len(x)
......@@ -118,16 +144,16 @@ class TestSequenceTopkAvgPoolingOp(OpTest):
self.check_output()
def test_check_grad(self):
    """Check the gradient of ``Out`` with respect to ``X``.

    The reference gradient stored in ``self.gradient`` is passed as
    ``user_defined_grads``; the preceding plain ``check_grad`` call that
    this one supersedes is dropped so the check runs once.
    NOTE(review): presumably the plain numeric check is unreliable for
    the enlarged inputs -- confirm before merging.
    """
    self.check_grad(['X'], 'Out', user_defined_grads=[self.gradient])
class TestSequenceTopkAvgPoolingOpCase1(TestSequenceTopkAvgPoolingOp):
def set_data(self):
    """Configure the Case1 variant and build its inputs.

    Uses top-k sizes 2 and 3, 5 channels and a single 36x48 sequence.
    Each value is assigned exactly once, so no superseded configuration
    remains in the method.
    """
    topks = [2, 3]
    channel_num = 5
    dim = 10
    row = [36]
    col = [48]
    self.init_data(topks, channel_num, row, col, dim)
def test_api(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册