# test_seq_expand.py (3.8 KB): unit tests for the seq_expand operator.
# Committed by wanghaoshuang.
import unittest
import numpy as np
from op_test import OpTest


# Committed by wanghaoshuang.
def repeat(list, starts, times, is_first):
    """Return the LoD offsets that result from repeating sequences.

    Args:
        list: Offsets of one LoD level, starting at ``list[0]``.
        starts: Offsets of the top LoD level. Only used when ``is_first``
            is false.
        times: Repeat count for each top-level sequence.
        is_first: Selects how the repeated level is built.

    Returns:
        A new offset list that starts at ``list[0]``.

        * If ``is_first`` is true, sequence ``i`` of ``list`` becomes a
          single sequence ``times[i]`` times longer.
        * Otherwise, for each ``i``, the run of ``list`` sequences lying
          between the first positions of ``starts[i]`` and
          ``starts[i + 1]`` is emitted ``times[i]`` times, each copy keeping
          its inner boundaries.

    Raises:
        ValueError: a value of ``starts`` does not occur in ``list``
            (from ``list.index``).
    """
    offsets = list
    result = [offsets[0]]

    if is_first:
        for seq_idx, count in enumerate(times):
            length = offsets[seq_idx + 1] - offsets[seq_idx]
            result.append(result[-1] + length * count)
        return result

    for seq_idx, count in enumerate(times):
        # list.index returns the first occurrence of each boundary value.
        first = offsets.index(starts[seq_idx])
        last = offsets.index(starts[seq_idx + 1])
        lengths = [offsets[k + 1] - offsets[k] for k in range(first, last)]
        for _ in range(count):
            for length in lengths:
                result.append(result[-1] + length)
    return result


def repeat_array(array, starts, times):
    """Concatenate repeated row slices of ``array`` into a new list.

    For each ``i``, the rows ``array[starts[i]:starts[i + 1]]`` are appended
    ``times[i]`` times, in order.

    Args:
        array: Any sliceable sequence of rows.
        starts: Slice boundaries, one more entry than ``times`` is used.
        times: Repeat count per slice.

    Returns:
        A flat list of the repeated rows.
    """
    return [
        row
        for idx, count in enumerate(times)
        for _ in range(count)
        for row in array[starts[idx]:starts[idx + 1]]
    ]


# Committed by wanghaoshuang.
class TestSeqExpand(OpTest):
    """Check the ``seq_expand`` operator against a pure-Python reference.

    Subclasses override :meth:`set_data` to supply different inputs and
    repeat settings. :meth:`compute` then derives the expected output and
    its LoD using the module-level ``repeat`` / ``repeat_array`` helpers.
    """

    def set_data(self):
        """Provide an X without LoD and a fixed repeat count of 2.

        Y is supplied too, but because ``self.repeat`` is non-zero,
        :meth:`compute` does not use Y's LoD to derive the expected output.
        """
        x = np.random.uniform(0.1, 1, [3, 2, 2]).astype('float32')
        y = np.zeros((6, 2, 2)).astype('float32')
        y_lod = [[0, 2, 3, 6]]
        self.inputs = {'X': (x, None), 'Y': (y, y_lod)}
        self.repeat = 2

    def compute(self):
        """Fill ``self.outputs`` (and ``self.attrs`` when repeating).

        When ``self.repeat`` is truthy, every top-level sequence of X is
        repeated that many times and the count is passed to the op as the
        ``repeat`` attribute. Otherwise the per-sequence counts are the
        ratio of Y's top-level sequence lengths to X's.
        """
        x_data, x_lod = self.inputs['X']
        if not x_lod:
            # No LoD: treat every row of X as its own length-1 sequence.
            x_lod = [list(range(1 + x_data.shape[0]))]
        else:
            # Duplicate the top level. The first output level (is_first=True)
            # merges all copies of a sequence, and the duplicated level
            # keeps the boundary of each individual copy.
            x_lod = [x_lod[0]] + x_lod

        if self.repeat:
            self.attrs = {'repeat': self.repeat}
            repeats = (len(x_lod[0]) - 1) * [self.repeat]
        else:
            _, y_lod = self.inputs['Y']
            # Floor division keeps the counts integral under Python 3 as well;
            # the helpers pass them to range(), which rejects floats.
            repeats = [(y_lod[0][i + 1] - y_lod[0][i]) //
                       (x_lod[0][i + 1] - x_lod[0][i])
                       for i in range(len(y_lod[0]) - 1)]

        out_lod = [repeat(x_lod[0], x_lod[0], repeats, True)] + [
            repeat(lod, x_lod[0], repeats, False) for lod in x_lod[1:]
        ]
        out = repeat_array(x_data.tolist(), x_lod[0], repeats)
        # Keep the expected output in the same dtype as the input.
        self.outputs = {'Out': (np.array(out, dtype=x_data.dtype), out_lod)}

    def setUp(self):
        self.op_type = 'seq_expand'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    # Gradient check is currently disabled:
    # def test_check_grad(self):
    #     self.check_grad(["X"], "Out")


class TestSeqExpandCase1(TestSeqExpand):
    """X carries a two-level LoD; each top-level sequence is repeated twice."""

    def set_data(self):
        lod = [[0, 5, 7], [0, 2, 5, 7]]
        data = np.random.uniform(0.1, 1, [7, 1]).astype('float32')
        self.inputs = {'X': (data, lod)}
        self.repeat = 2


class TestSeqExpandCase2(TestSeqExpand):
    """X has no LoD, so each of its 4 rows is a sequence repeated twice."""

    def set_data(self):
        rows = np.random.uniform(0.1, 1, [4, 1]).astype('float32')
        self.repeat = 2
        self.inputs = {'X': (rows, None)}


class TestSeqExpandCase3(TestSeqExpand):
    """X has no LoD; repeat counts (1, 3, 4) are inferred from Y's LoD."""

    def set_data(self):
        target_lod = [[0, 1, 4, 8]]
        source = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
        target = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
        # A falsy repeat makes compute() derive the counts from Y.
        self.repeat = None
        self.inputs = {'X': (source, None), 'Y': (target, target_lod)}


class TestSeqExpandCase4(TestSeqExpand):
    """X has a one-level LoD; counts (2, 3) come from Y's top LoD level."""

    def set_data(self):
        source_lod = [[0, 2, 5]]
        target_lod = [[0, 4, 13], [0, 2, 4, 7, 10, 13]]
        source = np.random.uniform(0.1, 1, [5, 1]).astype('float32')
        target = np.random.uniform(0.1, 1, [13, 1]).astype('float32')
        # A falsy repeat makes compute() derive the counts from Y.
        self.repeat = None
        self.inputs = {
            'X': (source, source_lod),
            'Y': (target, target_lod),
        }
# Committed by wanghaoshuang.


if __name__ == '__main__':
    # Run the test cases defined in this module when executed as a script.
    unittest.main()