test_operator.py 6.1 KB
Newer Older
Y
Yu Yang 已提交
1
import unittest
Y
Yu Yang 已提交
2
import paddle.v2.framework.op as op
3 4 5 6
import paddle.v2.framework.core as core
import paddle.v2.framework.proto.op_proto_pb2 as op_proto_pb2
import paddle.v2.framework.proto.op_desc_pb2 as op_desc_pb2
import paddle.v2.framework.proto.attr_type_pb2 as attr_type_pb2
Y
Yu Yang 已提交
7 8


class TestGetAllProtos(unittest.TestCase):
    """Sanity checks for ``op.get_all_op_protos``."""

    def test_all(self):
        """At least one op proto is returned and each one is initialized."""
        all_protos = op.get_all_op_protos()
        self.assertGreater(len(all_protos), 0)

        for proto in all_protos:
            self.assertTrue(proto.IsInitialized())


class TestOpDescCreationMethod(unittest.TestCase):
    """Tests for ``op.OpDescCreationMethod``, which turns an ``OpProto``
    into a callable that builds ``OpDesc`` messages from keyword arguments.
    """

    @staticmethod
    def _add_var(container, name, comment="", **fields):
        """Append an entry to a repeated ``inputs``/``outputs`` field.

        :param container: repeated field of an ``OpProto``
            (``op_proto.inputs`` or ``op_proto.outputs``).
        :param name: value assigned to the new entry's ``name``.
        :param comment: value assigned to the new entry's ``comment``.
        :param fields: extra fields to set on the entry, e.g.
            ``multiple=True`` or ``temporary=True``.
        :return: the newly added entry.
        """
        var = container.add()
        var.name = name
        var.comment = comment
        for field_name, value in fields.items():
            setattr(var, field_name, value)
        return var

    def test_plain_input_output(self):
        """Each plain input/output keyword maps to exactly one var name."""
        op_proto = op_proto_pb2.OpProto()
        op_proto.type = "test"
        self._add_var(op_proto.inputs, "X", "not matter")
        self._add_var(op_proto.inputs, "Y", "not matter")
        self._add_var(op_proto.outputs, "Z", "not matter")
        op_proto.comment = "not matter"

        self.assertTrue(op_proto.IsInitialized())

        method = op.OpDescCreationMethod(op_proto)
        output = method(X="a", Y="b", Z="c")

        expected = op_desc_pb2.OpDesc()
        expected.type = "test"
        expected.inputs.extend(["a", "b"])
        expected.outputs.append("c")
        self.assertEqual(expected, output)

    def test_multiple_input_plain_output(self):
        """Inputs flagged ``multiple`` accept either one name or a list.

        Per the expectations below, every input is flattened into
        ``inputs`` in proto declaration order (not keyword order), and an
        ``input_format`` INTS attribute holds the start offset of each
        input within that flattened list, followed by the total count.
        """
        op_proto = op_proto_pb2.OpProto()
        op_proto.type = "fc"
        self._add_var(op_proto.inputs, "X", multiple=True)
        self._add_var(op_proto.inputs, "W", multiple=True)
        self._add_var(op_proto.inputs, "b")
        self._add_var(op_proto.outputs, "Y")
        op_proto.comment = ""
        self.assertTrue(op_proto.IsInitialized())
        method = op.OpDescCreationMethod(op_proto)

        # One name per input: X, W, b start at offsets 0, 1, 2; 3 in total.
        generated1 = method(X="x", W="w", b="b", Y="y")
        expected1 = op_desc_pb2.OpDesc()
        expected1.inputs.extend(['x', 'w', 'b'])
        expected1.outputs.extend(['y'])
        expected1.type = 'fc'
        attr = expected1.attrs.add()
        attr.name = 'input_format'
        attr.type = attr_type_pb2.INTS
        attr.ints.extend([0, 1, 2, 3])
        self.assertEqual(expected1, generated1)

        # Lists for the ``multiple`` inputs (passed out of proto order):
        # X starts at 0, W at 3, b at 6; 7 names in total.
        generated2 = method(
            X=['x1', 'x2', 'x3'], b='b', W=['w1', 'w2', 'w3'], Y='y')
        expected2 = op_desc_pb2.OpDesc()
        expected2.inputs.extend(['x1', 'x2', 'x3', 'w1', 'w2', 'w3', 'b'])
        expected2.outputs.extend(['y'])
        expected2.type = 'fc'
        attr = expected2.attrs.add()
        attr.name = 'input_format'
        attr.type = attr_type_pb2.INTS
        attr.ints.extend([0, 3, 6, 7])
        self.assertEqual(expected2, generated2)

    def test_attrs(self):
        """Keyword arguments naming declared attrs become typed ``attrs``
        entries on the desc, in declaration order."""
        op_proto = op_proto_pb2.OpProto()
        op_proto.type = "test"
        self._add_var(op_proto.inputs, 'X')

        def _add_attr(name, attr_type):
            # Declare an attribute on the proto. ``attr_type`` avoids
            # shadowing the ``type`` builtin.
            attr = op_proto.attrs.add()
            attr.name = name
            attr.comment = ""
            attr.type = attr_type

        _add_attr("int_attr", attr_type_pb2.INT)
        _add_attr("float_attr", attr_type_pb2.FLOAT)
        _add_attr("string_attr", attr_type_pb2.STRING)
        _add_attr("ints_attr", attr_type_pb2.INTS)
        _add_attr("floats_attr", attr_type_pb2.FLOATS)
        _add_attr("strings_attr", attr_type_pb2.STRINGS)

        op_proto.comment = ""
        self.assertTrue(op_proto.IsInitialized())

        method = op.OpDescCreationMethod(op_proto)

        generated = method(
            X="a",
            int_attr=10,
            float_attr=3.2,
            string_attr="test_str",
            ints_attr=[0, 1, 2, 3, 4],
            floats_attr=[0.2, 3.2, 4.5],
            strings_attr=["a", "b", "c"])

        expected = op_desc_pb2.OpDesc()
        expected.type = "test"
        expected.inputs.extend(['a'])
        attr = expected.attrs.add()
        attr.name = "int_attr"
        attr.type = attr_type_pb2.INT
        attr.i = 10

        attr = expected.attrs.add()
        attr.name = "float_attr"
        attr.type = attr_type_pb2.FLOAT
        attr.f = 3.2

        attr = expected.attrs.add()
        attr.name = "string_attr"
        attr.type = attr_type_pb2.STRING
        attr.s = "test_str"

        attr = expected.attrs.add()
        attr.name = "ints_attr"
        attr.type = attr_type_pb2.INTS
        attr.ints.extend([0, 1, 2, 3, 4])

        attr = expected.attrs.add()
        attr.name = 'floats_attr'
        attr.type = attr_type_pb2.FLOATS
        attr.floats.extend([0.2, 3.2, 4.5])

        attr = expected.attrs.add()
        attr.name = 'strings_attr'
        attr.type = attr_type_pb2.STRINGS
        attr.strings.extend(['a', 'b', 'c'])

        self.assertEqual(expected, generated)

    def test_input_temporary_output(self):
        """Outputs flagged ``temporary`` are not supplied by the caller.

        The generated desc fills their slot in ``outputs`` with
        ``core.var_names.temp()`` and adds a ``temporary_index`` INTS
        attribute.
        """
        op_proto = op_proto_pb2.OpProto()
        op_proto.type = "test"
        self._add_var(op_proto.outputs, "OUT")
        self._add_var(op_proto.outputs, "TMP", temporary=True)
        self._add_var(op_proto.outputs, "OUT2")
        op_proto.comment = ""

        method = op.OpDescCreationMethod(op_proto)
        generated = method(OUT="a", OUT2="b")

        expected = op_desc_pb2.OpDesc()
        expected.outputs.extend(["a", core.var_names.temp(), "b"])
        expected.type = "test"
        attr = expected.attrs.add()
        attr.name = "temporary_index"
        attr.type = attr_type_pb2.INTS
        # NOTE(review): TMP sits at position 1 of ``outputs`` yet the
        # expected index is 2 -- confirm against how OpDescCreationMethod
        # computes ``temporary_index``.
        attr.ints.append(2)
        self.assertEqual(expected, generated)


class TestOpCreations(unittest.TestCase):
    """Tests for constructing operators through ``op.Operator``."""

    def test_all(self):
        """An ``add_two`` operator can be created and rendered as a string."""
        add_op = op.Operator("add_two", X="a", Y="b", Out="z")
        self.assertIsNotNone(add_op)

        # str() on the operator invokes the C++ DebugString().
        expected_debug_string = 'Op(add_two), inputs:(a, b), outputs:(z).'
        self.assertEqual(expected_debug_string, str(add_op))


if __name__ == "__main__":
    # Run every TestCase defined in this module when executed as a script.
    unittest.main()