test_squared_l2_distance_op.py 2.0 KB
Newer Older
1 2
import unittest
import numpy as np
Q
qijun 已提交
3
from op_test import OpTest
4 5


Q
qijun 已提交
6
class TestSquaredL2DistanceOp_f0(OpTest):
    """Test ``squared_l2_distance`` when X and Y have the same shape (2, 3).

    ``setUp`` builds NumPy reference outputs; presumably ``OpTest`` compares
    the operator's results against them (OpTest is defined elsewhere):
      * ``sub_result`` -- element-wise ``X - Y``, shape (2, 3).
      * ``Out`` -- per-row sum of squared differences, kept as a column of
        shape (2, 1).
    """

    def setUp(self):
        self.op_type = "squared_l2_distance"
        # Inputs are drawn from [0.1, 0.6) and cast to float32.
        self.inputs = {
            'X': np.random.uniform(0.1, 0.6, (2, 3)).astype("float32"),
            'Y': np.random.uniform(0.1, 0.6, (2, 3)).astype("float32")
        }
        sub_res = self.inputs['X'] - self.inputs['Y']
        output = sub_res * sub_res
        self.outputs = {
            'sub_result': sub_res,
            # Sum over axis 1, then re-insert that axis -> shape (rows, 1).
            'Out': np.expand_dims(output.sum(1), 1)
        }

    def test_check_output(self):
        """Check the operator's forward outputs against ``self.outputs``."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients of ``Out`` with respect to both X and Y."""
        self.check_grad(['X', 'Y'], 'Out')
Y
yangyaming 已提交
25 26


Q
qijun 已提交
27
class TestSquaredL2DistanceOp_f1(OpTest):
    """Test ``squared_l2_distance`` with a single-row Y of shape (1, 3).

    X has shape (2, 3). The NumPy reference broadcasts Y's single row
    against every row of X. Presumably the operator is expected to match
    this; ``OpTest`` is defined elsewhere.
      * ``sub_result`` -- ``X - Y`` after broadcasting, shape (2, 3).
      * ``Out`` -- per-row sum of squared differences, shape (2, 1).
    """

    def setUp(self):
        self.op_type = "squared_l2_distance"
        # Inputs are drawn from [0.1, 0.6) and cast to float32.
        self.inputs = {
            'X': np.random.uniform(0.1, 0.6, (2, 3)).astype("float32"),
            'Y': np.random.uniform(0.1, 0.6, (1, 3)).astype("float32")
        }
        # NumPy broadcasting: the (1, 3) Y is subtracted from each row of X.
        sub_res = self.inputs['X'] - self.inputs['Y']
        output = sub_res * sub_res
        self.outputs = {
            'sub_result': sub_res,
            # Sum over axis 1, then re-insert that axis -> shape (rows, 1).
            'Out': np.expand_dims(output.sum(1), 1)
        }

    def test_check_output(self):
        """Check the operator's forward outputs against ``self.outputs``."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients of ``Out`` with respect to both X and Y."""
        self.check_grad(['X', 'Y'], 'Out')
Y
yangyaming 已提交
46

Q
qijun 已提交
47 48

class TestSquaredL2DistanceOp_f2(OpTest):
    """Test ``squared_l2_distance`` with rank-3 inputs and a broadcast Y.

    X has shape (2, 3, 4) and Y has shape (1, 3, 4). The NumPy reference
    broadcasts Y over X's first axis, then flattens each sample's trailing
    dimensions into one row. Presumably the operator is expected to match
    this; ``OpTest`` is defined elsewhere.
      * ``sub_result`` -- ``X - Y`` reshaped to (2, 12).
      * ``Out`` -- per-row sum of squared differences, shape (2, 1).
    """

    def setUp(self):
        self.op_type = "squared_l2_distance"
        # Inputs are drawn from [0.1, 0.6) and cast to float32.
        self.inputs = {
            'X': np.random.uniform(0.1, 0.6, (2, 3, 4)).astype("float32"),
            'Y': np.random.uniform(0.1, 0.6, (1, 3, 4)).astype("float32")
        }
        sub_res = self.inputs['X'] - self.inputs['Y']
        # Keep the leading axis and flatten the rest. With these shapes this
        # is (2, 3 * 4), but it is derived from the data, not hard-coded.
        sub_res = sub_res.reshape((sub_res.shape[0], -1))
        output = sub_res * sub_res
        self.outputs = {
            'sub_result': sub_res,
            # Sum over axis 1, then re-insert that axis -> shape (rows, 1).
            'Out': np.expand_dims(output.sum(1), 1)
        }

    def test_check_output(self):
        """Check the operator's forward outputs against ``self.outputs``."""
        self.check_output()

    def test_check_grad(self):
        """Check gradients of ``Out`` with respect to both X and Y."""
        self.check_grad(['X', 'Y'], 'Out')
68 69


Q
qijun 已提交
70
# Run every test case in this module when the file is executed directly.
if __name__ == "__main__":
    unittest.main()