test_seq_concat_op.py 3.4 KB
Newer Older
Y
Yancey1989 已提交
1 2
import unittest
import numpy as np
Y
Yu Yang 已提交
3
import sys
Y
Yancey1989 已提交
4
from op_test import OpTest
D
dzhwinter 已提交
5
exit(0)
Y
Yancey1989 已提交
6 7


8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34
def to_abs_lod(lod):
    """Return ``lod`` with its first level expressed as second-level offsets.

    Each entry of ``lod[0]`` is used as an index into ``lod[1]`` and replaced
    by the value found there. Only the first level is rewritten; every other
    level is carried over unchanged.

    Args:
        lod: list of offset lists, outermost level first.

    Returns:
        ``lod`` itself (the same object) when it has fewer than two levels,
        otherwise a deep copy with the rewritten first level. The argument
        is never modified.
    """
    if len(lod) in (0, 1):
        # Nothing to translate without a second level to look offsets up in.
        return lod
    import copy
    converted = copy.deepcopy(lod)
    outer_level = lod[0]
    inner_level = lod[1]
    for position, inner_index in enumerate(outer_level):
        converted[0][position] = inner_level[inner_index]
    return converted


def seq_concat(inputs, level):
    """Build the reference result for concatenating two LoD inputs on axis 0.

    Args:
        inputs: dict whose ``'X'`` entry is a list of two
            ``(name, (ndarray, lod))`` pairs, laid out the way the test
            classes in this file build ``self.inputs``.
        level: LoD level to concatenate at, counted from the innermost
            level (0 selects the last entry of the LoD).

    Returns:
        numpy.ndarray: for every sequence at the selected level, the rows of
        the first input followed by the rows of the second input, all
        stacked along axis 0.
    """
    lod0 = inputs['X'][0][1][1]
    lod1 = inputs['X'][1][1][1]
    x0 = inputs['X'][0][1][0]
    x1 = inputs['X'][1][1][0]
    level_idx = len(lod0) - level - 1
    # to_abs_lod deep-copies its argument and the result does not depend on
    # the loop index, so resolve the row offsets once instead of four times
    # per iteration.
    offsets0 = to_abs_lod(lod0)[level_idx]
    offsets1 = to_abs_lod(lod1)[level_idx]
    outs = []
    for i in range(len(lod0[level_idx]) - 1):
        sub_x0 = x0[offsets0[i]:offsets0[i + 1], :]
        sub_x1 = x1[offsets1[i]:offsets1[i + 1], :]
        outs.append(np.concatenate((sub_x0, sub_x1), axis=0))
    return np.concatenate(outs, axis=0)


class TestSeqConcatOp(OpTest):
    """Base case for ``sequence_concat``: join two LoD inputs on axis 1.

    The subclasses in this file override only ``set_data`` and reuse
    ``setUp`` and the two ``test_*`` methods defined here.
    """

    def set_data(self):
        # Two-level LoD: 2 top-level sequences, each holding 2 one-row
        # sub-sequences, so both inputs have 4 rows.
        x0 = np.random.random((4, 6, 3)).astype('float32')
        lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]]
        x1 = np.random.random((4, 8, 3)).astype('float32')
        lod1 = [[0, 2, 4], [0, 1, 2, 3, 4]]
        axis = 1
        level = 1
        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
        self.attrs = {'axis': axis, 'level': level}
        # Joining on axis 1 leaves the row layout alone, so the expected
        # output reuses lod0.
        self.outputs = {'Out': (np.concatenate([x0, x1], axis=1), lod0)}

    def setUp(self):
        self.op_type = "sequence_concat"
        self.set_data()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Only 'x0' is passed as the input to check against 'Out'.
        self.check_grad(['x0'], 'Out')


58
class TestSeqConcatOpLevelZeroNestedSequence(TestSeqConcatOp):
    """Axis-0 concat at level 0 on inputs that carry a two-level LoD."""

    def set_data(self):
        # Two-level LoD: both inputs hold 2 top-level sequences of 2
        # sub-sequences each (x0 has 4 rows, x1 has 7 rows).
        x0 = np.random.random((4, 6, 3)).astype('float32')
        lod0 = [[0, 2, 4], [0, 1, 2, 3, 4]]
        x1 = np.random.random((7, 6, 3)).astype('float32')
        lod1 = [[0, 2, 4], [0, 1, 3, 5, 7]]
        axis = 0
        level = 0
        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
        self.attrs = {'axis': axis, 'level': level}
        # Sub-sequence lengths add up pairwise (1+1, 1+2, 1+2, 1+2); the
        # top level keeps 2 sub-sequences per sequence.
        out_lod = [[0, 2, 4], [0, 2, 5, 8, 11]]
        self.outputs = {'Out': (seq_concat(self.inputs, level), out_lod)}
Y
Yancey1989 已提交
71

72 73 74 75 76 77 78 79 80 81 82 83 84 85

class TestSeqConcatOplevelOneNestedSequence(TestSeqConcatOp):
    """Axis-0 concat at level 1 on nested inputs whose top levels differ."""

    def set_data(self):
        concat_level = 1
        # Draw x0 before x1 so the random stream is consumed in the same
        # order as the tensors appear in the input list.
        first = np.random.random((4, 6, 3)).astype('float32')
        second = np.random.random((7, 6, 3)).astype('float32')
        # Both inputs carry a two-level LoD with 2 top-level sequences.
        first_lod = [[0, 2, 4], [0, 1, 2, 3, 4]]
        second_lod = [[0, 3, 4], [0, 1, 3, 5, 7]]
        self.inputs = {
            'X': [('x0', (first, first_lod)), ('x1', (second, second_lod))],
        }
        self.attrs = dict(axis=0, level=concat_level)
        # Each output sequence holds the sub-sequences of x0 followed by
        # those of x1: 2+3 for the first sequence, 2+1 for the second.
        expected_lod = [[0, 5, 8], [0, 1, 2, 3, 5, 7, 8, 9, 11]]
        expected = seq_concat(self.inputs, concat_level)
        self.outputs = {'Out': (expected, expected_lod)}
Y
Yancey1989 已提交
86 87


88
class TestSeqConcatOpLevelZeroSequence(TestSeqConcatOp):
    """Axis-0 concat at level 0 on inputs that carry a single-level LoD."""

    def set_data(self):
        # Single-level LoD with 4 sequences per input (x0 has 4 rows, x1
        # has 7 rows).
        x0 = np.random.random((4, 3, 4)).astype('float32')
        lod0 = [[0, 1, 2, 3, 4]]
        x1 = np.random.random((7, 3, 4)).astype('float32')
        lod1 = [[0, 1, 3, 5, 7]]
        axis = 0
        level = 0
        self.inputs = {'X': [('x0', (x0, lod0)), ('x1', (x1, lod1))]}
        self.attrs = {'axis': axis, 'level': level}
        # Sequence lengths add up pairwise: 1+1, 1+2, 1+2, 1+2.
        out_lod = [[0, 2, 5, 8, 11]]
        self.outputs = {'Out': (seq_concat(self.inputs, level), out_lod)}
Y
Yancey1989 已提交
101 102 103 104


if __name__ == '__main__':
    unittest.main()