#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import unittest
import paddle.fluid.core as core
import paddle.compat as cpt
from paddle.fluid.framework import Program


class TestOpDesc(unittest.TestCase):
    """Round-trip tests for ``OpDesc`` objects created through ``core``."""

    def test_op_desc(self):
        """Set the type, inputs, outputs and attributes of an op and read
        each of them back."""
        program_desc = core.ProgramDesc()
        self.assertIsNotNone(program_desc)
        block = program_desc.block(0)
        self.assertIsNotNone(block)
        op = block.append_op()
        self.assertIsNotNone(op)
        op.set_type("test")
        self.assertEqual("test", op.type())
        op.set_input("X", ["a", "b", "c"])
        self.assertEqual(["a", "b", "c"], op.input("X"))
        self.assertEqual(["X"], op.input_names())

        op.set_output("Out", ["z"])
        self.assertEqual(['z'], op.output("Out"))
        self.assertEqual(["Out"], op.output_names())

        op.set_attr("int_attr", 1)
        self.assertEqual(1, op.attr("int_attr"))
        self.assertTrue(op.has_attr("int_attr"))
        self.assertEqual(core.AttrType.INT, op.attr_type("int_attr"))

        # Float values are compared with a tolerance because the stored value
        # may not round-trip bit-exactly.
        op.set_attr("float_attr", -1.32)
        self.assertAlmostEqual(-1.32, op.attr("float_attr"), delta=1e-4)
        self.assertTrue(op.has_attr("float_attr"))

        op.set_attr("bool_attr", False)
        self.assertFalse(op.attr("bool_attr"))
        # assertFalse alone would also accept a missing/None value.
        self.assertTrue(op.has_attr("bool_attr"))

        op.set_attr("string_attr", "abc")
        self.assertEqual("abc", op.attr("string_attr"))
        self.assertTrue(op.has_attr("string_attr"))

        op.set_attr("ints_attr", [1, 2, 3])
        self.assertEqual([1, 2, 3], op.attr("ints_attr"))

        expected = [1.2, 2.3, 3.4]
        op.set_attr("floats_attr", expected)
        actual = op.attr("floats_attr")
        # zip() stops at the shorter input, so without this length check an
        # empty or truncated result would make the loop below pass vacuously.
        self.assertEqual(len(expected), len(actual))
        for e, a in zip(expected, actual):
            self.assertAlmostEqual(e, a, delta=1e-4)

        op.set_attr("strings_attr", ["a", "b", "c"])
        self.assertEqual(["a", "b", "c"], op.attr("strings_attr"))

        op.set_attr("bools_attr", [True, False, True])
        self.assertEqual([True, False, True], op.attr("bools_attr"))

        # Eight attributes have been set so far: int, float, bool, string,
        # ints, floats, strings and bools.
        self.assertEqual(8, len(op.attr_names()))

        # A block attribute refers to a block; its id is read back here.
        op.set_block_attr("block_attr", program_desc.block(0))
        self.assertEqual(0, op.block_attr_id("block_attr"))

        # NOTE(review): check_attrs() presumably fills in the registered
        # default attributes of the "mul" op; the assertions below expect
        # both *_num_col_dims to default to 1.
        mul_op = block.append_op()
        mul_op.set_type("mul")
        mul_op.check_attrs()
        self.assertEqual(mul_op.attr("x_num_col_dims"), 1)
        self.assertEqual(mul_op.attr("y_num_col_dims"), 1)


class TestProgramDesc(unittest.TestCase):
    """Tests for creating ``ProgramDesc`` objects and nesting their blocks."""

    def test_instance(self):
        """A ProgramDesc can be created and deleted repeatedly, and it
        exposes a block at index 0."""
        desc = core.ProgramDesc()
        self.assertIsNotNone(desc)
        del desc

        desc = core.ProgramDesc()
        self.assertIsNotNone(desc)
        self.assertIsNotNone(desc.block(0))
        del desc

    def test_append_block(self):
        """append_block() records the block it was given as the parent."""
        desc = core.ProgramDesc()
        self.assertIsNotNone(desc)

        root = desc.block(0)
        self.assertIsNotNone(root)
        self.assertEqual(root.id, 0)

        # root -> child -> grandchild
        child = desc.append_block(root)
        grandchild = desc.append_block(child)
        self.assertIsNotNone(child)
        self.assertEqual(child.id, grandchild.parent)
        self.assertEqual(root.id, child.parent)

        # A second child of root.
        sibling = desc.append_block(root)
        self.assertEqual(sibling.parent, root.id)

        self.assertEqual(desc.block(1).id, 1)
        # root + child + grandchild + sibling
        self.assertEqual(4, desc.num_blocks())


class TestVarDesc(unittest.TestCase):
    """Tests for the type, shape, dtype and LoD-level accessors of VarDesc."""

    def _new_var(self, name, var_type):
        """Declare ``name`` in block 0 of a fresh ProgramDesc and set its
        type to ``var_type``.

        Returns:
            A ``(program_desc, var)`` tuple. Callers keep ``program_desc``
            bound for as long as they use ``var``.
            NOTE(review): ``var`` is obtained through ``program_desc`` and
            presumably does not keep it alive on its own; confirm before
            dropping that reference.
        """
        program_desc = core.ProgramDesc()
        block = program_desc.block(0)
        var = block.var(cpt.to_bytes(name))
        var.set_type(var_type)
        return program_desc, var

    def test_shape(self):
        """set_shape()/shape() round-trip a single shape."""
        program_desc, var = self._new_var(
            'my_var', core.VarDesc.VarType.SELECTED_ROWS)
        src_shape = [3, 2, 10, 8]
        var.set_shape(src_shape)
        self.assertEqual(src_shape, var.shape())
        self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())

    def test_multiple_shape(self):
        """set_shapes()/shapes() round-trip a list of shapes."""
        program_desc, var = self._new_var(
            'my_reader', core.VarDesc.VarType.READER)
        src_shapes = [[2, 3, 3], [4, 5], [6, 7, 8, 9]]
        var.set_shapes(src_shapes)
        self.assertEqual(src_shapes, var.shapes())
        self.assertEqual(core.VarDesc.VarType.READER, var.type())

    def test_dtype(self):
        """set_dtype()/dtype() round-trip a single data type."""
        program_desc, var = self._new_var(
            'my_var', core.VarDesc.VarType.LOD_TENSOR)
        var.set_dtype(core.VarDesc.VarType.INT32)
        self.assertEqual(core.VarDesc.VarType.INT32, var.dtype())
        self.assertEqual(core.VarDesc.VarType.LOD_TENSOR, var.type())

    def test_multiple_dtype(self):
        """set_dtypes()/dtypes() round-trip a list of data types."""
        program_desc, var = self._new_var(
            'my_reader', core.VarDesc.VarType.READER)
        src_types = [
            core.VarDesc.VarType.INT32, core.VarDesc.VarType.FP64,
            core.VarDesc.VarType.FP32
        ]
        var.set_dtypes(src_types)
        self.assertEqual(src_types, var.dtypes())
        self.assertEqual(core.VarDesc.VarType.READER, var.type())

    def test_multiple_lod_level(self):
        """set_lod_levels()/lod_levels() round-trip a list of LoD levels."""
        program_desc, var = self._new_var(
            'my_reader', core.VarDesc.VarType.READER)
        src_lod_levels = [3, 1, 2]
        var.set_lod_levels(src_lod_levels)
        self.assertEqual(src_lod_levels, var.lod_levels())
        self.assertEqual(core.VarDesc.VarType.READER, var.type())


class TestBlockDesc(unittest.TestCase):
    """Tests for adding, prepending and removing vars/ops in a BlockDesc."""

    def test_add_var(self):
        """Vars created via block.var() are listed by all_vars() and can be
        looked up again with find_var()."""
        program_desc = core.ProgramDesc()
        self.assertIsNotNone(program_desc)
        block = program_desc.block(0)
        self.assertIsNotNone(block)
        var1 = block.var(cpt.to_bytes("var1"))
        var2 = block.var(cpt.to_bytes("var2"))
        var3 = block.var(cpt.to_bytes("var3"))
        all_vars = block.all_vars()
        self.assertEqual(set(all_vars), {var1, var2, var3})
        var2_re = block.find_var(cpt.to_bytes("var2"))
        self.assertEqual(var2_re, var2)

    def test_add_op(self):
        """append_op() adds ops at the end and _prepend_op() at the front."""
        program_desc = core.ProgramDesc()
        self.assertIsNotNone(program_desc)
        block = program_desc.block(0)
        self.assertIsNotNone(block)
        op1 = block.append_op()
        op2 = block.append_op()
        op0 = block._prepend_op()
        all_ops = [block.op(idx) for idx in range(block.op_size())]
        self.assertEqual(all_ops, [op0, op1, op2])

    def test__remove_op(self):
        """_remove_op(1, 2) removes only the op at index 1."""
        program = Program()
        program_desc = program.desc
        self.assertIsNotNone(program_desc)
        block = program_desc.block(0)
        self.assertIsNotNone(block)

        op0 = block.append_op()
        op1 = block.append_op()
        op2 = block.append_op()
        op0.set_type("test")
        op1.set_type("test")
        op2.set_type("test")

        # The expected result below shows the range end is exclusive:
        # only op1 (index 1) is removed.
        block._remove_op(1, 2)
        # NOTE(review): presumably re-syncs the Python-side Program with the
        # C++ desc that was edited directly above.
        program._sync_with_cpp()

        all_ops = [block.op(idx) for idx in range(block.op_size())]
        self.assertEqual(all_ops, [op0, op2])


if __name__ == '__main__':
    # Run the test cases defined in this module when executed directly.
    unittest.main()