# test_fc_op.py: unit tests for the "fc" operator.
import unittest
import numpy as np
from op_test import OpTest


class TestFCOp1(OpTest):
    """Check the "fc" op with a single (X, W) pair and no ``activation`` attr.

    Expected outputs are recomputed with NumPy: ``mul_out0 = X0 @ W0``,
    ``sum_out`` (equal to ``mul_out0`` because there is only one pair),
    ``add_out = sum_out + b`` and ``Y``, which this test expects to equal
    ``add_out`` since no activation attribute is set.
    """

    def setUp(self):
        """Build random float32 inputs and the NumPy reference outputs."""
        self.op_type = "fc"
        x0 = np.random.random((16, 32)).astype("float32")
        w0 = np.random.random((32, 10)).astype("float32")
        b = np.random.random(10).astype("float32")
        self.inputs = {"X": [("X0", x0)], "W": [("W0", w0)], "b": b}

        # (16, 32) x (32, 10) -> (16, 10).
        mul_out0 = np.dot(x0, w0)
        # With a single (X, W) pair, the sum over products is that product.
        sum_out = mul_out0
        # b has shape (10,) and broadcasts across the 16 rows.
        add_out = sum_out + b
        # No "activation" attr is set, so Y is expected to equal add_out
        # (presumably identity is the op's default activation -- confirm).
        identity_out = add_out
        self.outputs = {
            "mul_out": [("mul_out0", mul_out0)],
            "sum_out": sum_out,
            "add_out": add_out,
            "Y": identity_out,
        }

    def test_check_output(self):
        """Run OpTest's output check (presumably against ``self.outputs``)."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients of ``Y`` w.r.t. X0, W0 and b (rel. tol. 0.01)."""
        self.check_grad(["X0", "W0", "b"], "Y", max_relative_error=0.01)


class TestFCOp2(OpTest):
    """Check the "fc" op with two (X, W) pairs and a sigmoid activation.

    NumPy reference: ``Y = sigmoid(X0 @ W0 + X1 @ W1 + b)``. The per-pair
    products, their sum and the biased sum are also listed as expected
    outputs.
    """

    def setUp(self):
        """Build random float32 inputs and the NumPy reference outputs."""
        self.op_type = "fc"
        x0 = np.random.random((16, 32)).astype("float32")
        x1 = np.random.random((16, 32)).astype("float32")
        w0 = np.random.random((32, 10)).astype("float32")
        w1 = np.random.random((32, 10)).astype("float32")
        b = np.random.random(10).astype("float32")
        self.inputs = {
            "X": [("X0", x0), ("X1", x1)],
            "W": [("W0", w0), ("W1", w1)],
            "b": b,
        }
        self.attrs = {"activation": "sigmoid"}

        # One (16, 10) product per (X_i, W_i) pair, then their sum.
        mul_out0 = np.dot(x0, w0)
        mul_out1 = np.dot(x1, w1)
        sum_out = mul_out0 + mul_out1
        # b has shape (10,) and broadcasts across the 16 rows.
        add_out = sum_out + b
        # Logistic sigmoid, matching activation="sigmoid" above.
        sigmoid_out = 1 / (1 + np.exp(-add_out))
        self.outputs = {
            "mul_out": [("mul_out0", mul_out0), ("mul_out1", mul_out1)],
            "sum_out": sum_out,
            "add_out": add_out,
            "Y": sigmoid_out,
        }

    def test_check_output(self):
        """Run OpTest's output check (presumably against ``self.outputs``)."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients of ``Y`` w.r.t. all inputs (rel. tol. 0.01)."""
        self.check_grad(
            ["X0", "X1", "W0", "W1", "b"], "Y", max_relative_error=0.01)


if __name__ == "__main__":
    # Executed as a script (not imported): let the unittest runner discover
    # and run the TestCase classes defined in this module.
    unittest.main()