# test_fc_op.py
import unittest
import numpy as np
# Blame annotation from the source viewer: committed by Liu Yiqun.
from op_test import OpTest
4 5


6
class TestFCOp1(OpTest):
    """Test case for the "fc" op with one input/weight pair and a bias.

    No "activation" attribute is supplied, so the reference value for the
    final output "Y" is taken to be the biased product itself.
    """

    def setUp(self):
        """Create random float32 inputs and the NumPy reference outputs."""
        x0 = np.random.random((16, 32)).astype("float32")
        w0 = np.random.random((32, 10)).astype("float32")
        b = np.random.random(10).astype("float32")

        # Reference computation, one local per named op output.
        mul_out0 = np.dot(x0, w0)
        # Only a single product exists, so the sum stage is that product.
        sum_out = mul_out0
        add_out = sum_out + b
        # No activation configured: "Y" reuses the "AddOut" reference.
        identity_out = add_out

        self.op_type = "fc"
        self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)], "B": b}
        self.outputs = {
            "MulOut": [("MulOut0", mul_out0)],
            "SumOut": sum_out,
            "AddOut": add_out,
            "Y": identity_out
        }
        self.attrs = {"xNumColDims": [1], "wNumColDims": [1]}

    def test_check_output(self):
        """Run check_output() (presumably inherited from OpTest)."""
        self.check_output()

    def test_check_grad(self):
        """Run check_grad() for X0, W0 and B against output "Y".

        The maximum relative error passed to the check is 0.01.
        """
        self.check_grad(["X0", "W0", "B"], "Y", max_relative_error=0.01)
32 33 34 35


class TestFCOp2(OpTest):
    """Test case for the "fc" op with two input/weight pairs and sigmoid.

    One pair uses a 3-D input with a 2-D weight, the other a 2-D input with
    a 3-D weight; the NumPy reference flattens the 3-D operand of each pair
    to 2-D before multiplying.
    """

    def setUp(self):
        """Create random float32 inputs and the NumPy reference outputs."""
        x0 = np.random.random((16, 4, 8)).astype("float32")
        x1 = np.random.random((16, 32)).astype("float32")
        w0 = np.random.random((32, 10)).astype("float32")
        w1 = np.random.random((4, 8, 10)).astype("float32")
        b = np.random.random(10).astype("float32")

        # Flatten the 3-D operands so both products are (16, 32) x (32, 10).
        mul_out0 = np.dot(x0.reshape(16, 4 * 8), w0)
        mul_out1 = np.dot(x1, w1.reshape(4 * 8, 10))
        sum_out = mul_out0 + mul_out1
        add_out = np.add(sum_out, b)
        # Reference sigmoid, matching the "activation" attribute set below.
        sigmoid_out = 1 / (1 + np.exp(-add_out))

        self.op_type = "fc"
        self.inputs = {
            "X": [("X0", x0), ("X1", x1)],
            "W": [("W0", w0), ("W1", w1)],
            "B": b
        }
        # NOTE(review): the NumColDims values presumably give, per operand,
        # how many leading dims form the matrix rows -- confirm against the
        # fc op definition.
        self.attrs = {
            "xNumColDims": [1, 1],
            "wNumColDims": [1, 2],
            "activation": "sigmoid"
        }
        self.outputs = {
            "MulOut": [("MulOut0", mul_out0), ("MulOut1", mul_out1)],
            "SumOut": sum_out,
            "AddOut": add_out,
            "Y": sigmoid_out
        }

    def test_check_output(self):
        """Run check_output() (presumably inherited from OpTest)."""
        self.check_output()

    def test_check_grad(self):
        """Run check_grad() for X0, X1, W0, W1 and B against output "Y".

        The maximum relative error passed to the check is 0.01.
        """
        self.check_grad(
            ["X0", "X1", "W0", "W1", "B"], "Y", max_relative_error=0.01)
72 73 74 75


# Run the tests via unittest only when this file is executed as a script.
if __name__ == '__main__':
    unittest.main()