# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import unittest

import paddle
from paddle import fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph, Program, program_guard
from paddle.fluid.tests.unittests.eager_op_test import OpTestTool
from paddle.static.quantization import QuantizationTransformPass

# The graph-building code in this module uses the static-graph API
# (Program, program_guard, paddle.static.data), so switch Paddle to
# static mode when the module is imported.
paddle.enable_static()


class TestQuantizationSubGraph(unittest.TestCase):
    """Check that QuantizationTransformPass can be applied to IR sub-graphs.

    A program containing a ``paddle.static.nn.cond`` is converted to an
    ``IrGraph``. The transform pass is then applied to each of its
    sub-graphs, and the test asserts that at least one op whose name
    contains ``'quantize'`` is present afterwards.
    """

    def build_graph_with_sub_graph(self):
        """Build an IrGraph from a program that contains a ``cond``.

        Returns:
            tuple: ``(graph, all_sub_graphs)``. The parent ``graph`` is
            returned together with its sub-graphs. If only the sub-graphs
            were returned, the parent graph would be destructed and the
            sub-graphs would be empty.
        """

        def linear_fc(num):
            # ``num`` stacked ReLU fc layers followed by a mean
            # cross-entropy loss over the 'image' / 'label' inputs.
            data = paddle.static.data(
                name='image', shape=[-1, 1, 32, 32], dtype='float32'
            )
            label = paddle.static.data(
                name='label', shape=[-1, 1], dtype='int64'
            )
            hidden = data
            for _ in range(num):
                hidden = paddle.static.nn.fc(
                    hidden, size=128, activation='relu'
                )
            loss = paddle.nn.functional.cross_entropy(
                input=hidden, label=label, reduction='none', use_softmax=False
            )
            loss = paddle.mean(loss)
            return loss

        main_program = Program()
        startup_program = Program()

        def true_func():
            return linear_fc(3)

        def false_func():
            return linear_fc(5)

        with program_guard(main_program, startup_program):
            x = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=0.1
            )
            y = paddle.tensor.fill_constant(
                shape=[1], dtype='float32', value=0.23
            )
            pred = paddle.less_than(y, x)
            # Called for its effect on main_program (adding the two
            # branches); the returned loss is not needed by this test.
            paddle.static.nn.cond(pred, true_func, false_func)

        core_graph = core.Graph(main_program.desc)
        # We should create the graph for test, otherwise it will throw an
        # error that it cannot find the node of "STEP_COUNTER".
        graph = IrGraph(core_graph, for_test=True)
        # Exercise the single-sub-graph accessor of the Python interface;
        # its result is not used further.
        graph.get_sub_graph(0)
        all_sub_graphs = graph.all_sub_graphs(
            for_test=True
        )  # same reason for subgraph
        # Should return graph and sub_graphs at the same time. If only
        # sub_graphs were returned, the graph would be destructed and the
        # sub_graphs would be empty.
        return graph, all_sub_graphs

    def test_quant_sub_graphs(self, use_cuda=False):
        """Quantize every sub-graph and check a quantize op was inserted.

        Args:
            use_cuda (bool): build the pass with ``CUDAPlace(0)`` instead
                of ``CPUPlace()``. When unittest collects this method
                directly, the default (CPU) is used.
        """
        # Keep ``graph`` bound for the whole test: the sub-graphs are only
        # valid while their parent graph is alive.
        graph, sub_graphs = self.build_graph_with_sub_graph()
        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        transform_pass = QuantizationTransformPass(
            scope=fluid.global_scope(),
            place=place,
            activation_quantize_type='abs_max',
            weight_quantize_type='range_abs_max',
        )
        found_inserted_quant_op = False
        for sub_graph in sub_graphs:
            transform_pass.apply(sub_graph)
            if any(
                'quantize' in op.name() for op in sub_graph.all_op_nodes()
            ):
                found_inserted_quant_op = True
        self.assertTrue(found_inserted_quant_op)

    def test_quant_sub_graphs_cpu(self):
        """Run the sub-graph quantization check on the CPU."""
        self.test_quant_sub_graphs(use_cuda=False)

    @OpTestTool.skip_if(
        not paddle.is_compiled_with_cuda(), "Not GPU version paddle"
    )
    def test_quant_sub_graphs_gpu(self):
        """Run the sub-graph quantization check on GPU 0 (CUDA builds only)."""
        self.test_quant_sub_graphs(use_cuda=True)


if __name__ == '__main__':
    # Run this module's tests when it is executed directly as a script.
    unittest.main()