# test_seq_expand.py (1.7 KB)
# Committed by wanghaoshuang.
import unittest
import numpy as np
from op_test import OpTest


class TestSeqExpand(OpTest):
    """Check the ``seq_expand`` operator against a NumPy reference.

    Subclasses override :meth:`set_data` to supply different inputs; the
    expected output is always derived from ``self.inputs`` in :meth:`compute`.
    """

    def set_data(self):
        """Build the default inputs: X without LoD, Y with one LoD level."""
        x_data = np.random.uniform(0.1, 1, [3, 1]).astype('float32')
        y_data = np.random.uniform(0.1, 1, [8, 1]).astype('float32')
        y_lod = [[0, 1, 4, 8]]
        self.inputs = {'X': x_data, 'Y': (y_data, y_lod)}

    def compute(self):
        """Compute the expected ``Out`` from ``self.inputs``.

        Row ``i`` of X is repeated as many times as the length of the
        ``i``-th sequence in the last LoD level of Y. X may be given either
        as a bare array or as a ``(data, lod)`` tuple; X's own LoD does not
        influence the reference output.

        Raises:
            ValueError: if the number of sequences in Y's last LoD level does
                not match the number of rows in X.
        """
        x = self.inputs['X']
        x_data = x[0] if isinstance(x, tuple) else x
        _, y_lod = self.inputs['Y']
        # Differences between consecutive offsets are the sequence lengths.
        repeats = np.diff(y_lod[-1])
        # np.repeat silently broadcasts a single repeat count over every row,
        # so check the lengths explicitly to catch mis-specified fixtures.
        if len(repeats) != x_data.shape[0]:
            raise ValueError(
                'Y last-level LoD describes %d sequences but X has %d rows' %
                (len(repeats), x_data.shape[0]))
        self.outputs = {'Out': x_data.repeat(repeats, axis=0)}

    def setUp(self):
        """Register the operator type, then build inputs and expected outputs."""
        self.op_type = 'seq_expand'
        self.set_data()
        self.compute()

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        # Check the gradient of Out with respect to X only.
        self.check_grad(["X"], "Out")


class TestSeqExpandCase1(TestSeqExpand):
    """X carries one LoD level; Y carries two LoD levels."""

    def set_data(self):
        """Use a 5-row X split into two sequences and a 13-row Y."""
        # Draw X before Y so the random stream is consumed in the same order.
        x = np.random.uniform(0.1, 1, size=(5, 1)).astype('float32')
        y = np.random.uniform(0.1, 1, size=(13, 1)).astype('float32')
        # Y's level 0 equals X's LoD; Y's last level splits its 13 rows into
        # 5 sequences, one per row of X.
        self.inputs = {
            'X': (x, [[0, 2, 5]]),
            'Y': (y, [[0, 2, 5], [0, 2, 4, 7, 10, 13]]),
        }


class TestSeqExpandCase2(TestSeqExpand):
    """Use a single 2x2 X instance against a Y holding one sequence of length 2."""

    def set_data(self):
        """Build rank-3 inputs with a single sequence on each side."""
        # Draw X before Y so the random stream is consumed in the same order.
        x = np.random.uniform(0.1, 1, size=(1, 2, 2)).astype('float32')
        y = np.random.uniform(0.1, 1, size=(2, 2, 2)).astype('float32')
        self.inputs = {
            'X': (x, [[0, 1]]),
            'Y': (y, [[0, 2]]),
        }


if __name__ == '__main__':
    # Run the TestCase classes in this module when executed as a script.
    unittest.main()