未验证 提交 8a4f85fe 编写于 作者: P Pei Yang 提交者: GitHub

Add unittests and OP version registry for quant_conv2d_dequant_fuse_pass (#27689)

上级 dec53a9c
......@@ -19,6 +19,7 @@
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
#include "paddle/fluid/framework/ir/quant_conv2d_dequant_fuse_pass.h"
#include "paddle/fluid/framework/op_version_registry.h"
namespace paddle {
namespace framework {
......@@ -284,3 +285,15 @@ void QuantDequantFusePass::ApplyImpl(ir::Graph* graph) const {
REGISTER_PASS(quant_conv2d_dequant_fuse_pass,
paddle::framework::ir::QuantDequantFusePass);
REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("conv2d", 0)
.EQ("fc", 0)
.LE("conv2d_transpose", 1)
.EQ("fake_quantize_abs_max", 0)
.EQ("fake_quantize_range_abs_max", 0)
.EQ("fake_quantize_moving_average_abs_max", 0)
.EQ("fake_channel_wise_quantize_abs_max", 0)
.EQ("fake_dequantize_max_abs", 0));
......@@ -27,6 +27,10 @@ from paddle.fluid.core import PaddleDType
from paddle.fluid.core import AnalysisConfig
from paddle.fluid.core import create_paddle_predictor
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
class InferencePassTest(unittest.TestCase):
def __init__(self, methodName='runTest'):
......@@ -48,7 +52,8 @@ class InferencePassTest(unittest.TestCase):
def _get_place(self):
return set([False, core.is_compiled_with_cuda()])
def _save_models(self, executor, program):
def _save_models(self, executor, program, scope):
with fluid.scope_guard(scope):
outs = executor.run(program=program,
feed=self.feeds,
fetch_list=self.fetch_list,
......@@ -133,7 +138,11 @@ class InferencePassTest(unittest.TestCase):
for place_ in use_gpu:
self.check_output_with_option(place_, atol)
def check_output_with_option(self, use_gpu, atol=1e-5, flatten=False):
def check_output_with_option(self,
use_gpu,
atol=1e-5,
flatten=False,
quant=False):
'''
Check whether calculating on CPU and GPU, enable TensorRT
or disable TensorRT, enable MKLDNN or disable MKLDNN
......@@ -141,9 +150,52 @@ class InferencePassTest(unittest.TestCase):
'''
place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace()
executor = fluid.Executor(place)
scope = fluid.Scope()
device = "GPU" if use_gpu else "CPU"
with fluid.scope_guard(scope):
executor.run(self.startup_program)
outs = self._save_models(executor, self.main_program)
if quant:
main_graph = IrGraph(
core.Graph(self.main_program.desc), for_test=True)
transform_pass = QuantizationTransformPass(
scope=scope,
place=place,
activation_quantize_type=self.activation_quant_type,
weight_quantize_type=self.weight_quant_type,
quantizable_op_type=[
'conv2d', 'mul', 'depthwise_conv2d', 'conv2d_transpose'
])
transform_pass.apply(main_graph)
weight_scale_map = {
"conv2d": "conv2d_0.w_0.scale",
"mul": "fc_0.w_0.scale"
}
weight_scale_tensor = scope.var(weight_scale_map[
self.quantized_op_type]).get_tensor()
weight_scale = np.ones(self.channels).astype("float32")
weight_scale_tensor.set(weight_scale, place)
op_nodes = main_graph.all_op_nodes()
for op_node in op_nodes:
if op_node.name() in [self.quantized_op_type, "relu"]:
op_node.op()._set_attr("out_threshold", 0.5)
with fluid.scope_guard(scope):
executor.run(program=self.main_program,
feed=self.feeds,
fetch_list=self.fetch_list)
freeze_pass = QuantizationFreezePass(
scope=scope,
place=place,
weight_quantize_type=self.weight_quant_type)
freeze_pass.apply(main_graph)
self.main_program = main_graph.to_program()
outs = self._save_models(executor, self.main_program, scope)
analysis_outputs = self._get_analysis_outputs(
self._get_analysis_config(use_gpu=use_gpu))
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from inference_pass_test import InferencePassTest
import paddle.fluid as fluid
import paddle.fluid.core as core
from paddle.fluid.framework import IrGraph
from paddle.fluid.core import PassVersionChecker
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.core import AnalysisConfig
class QuantDequantTest(InferencePassTest):
def setUp(self):
with fluid.program_guard(self.main_program, self.startup_program):
data = fluid.data(
name="data", shape=[-1, 3, 32, 32], dtype="float32")
param_attr = fluid.ParamAttr(
initializer=fluid.initializer.Constant(value=0.0),
trainable=False)
quantized_op_out = self.append_quantized_op(data, param_attr)
relu_out = fluid.layers.relu(quantized_op_out)
self.set_quant_pattern()
self.feeds = {
"data": np.random.random([1, 3, 32, 32]).astype("float32"),
}
self.enable_trt = True
self.trt_parameters = QuantDequantTest.TensorRTParam(
1 << 30, 32, 0, AnalysisConfig.Precision.Int8, False, False)
self.fetch_list = [relu_out]
def append_quantized_op(self, x, param_attr):
return fluid.layers.conv2d(
input=x,
num_filters=3,
filter_size=3,
param_attr=param_attr,
bias_attr=False,
act=None)
def set_quant_pattern(self):
self.activation_quant_type = 'moving_average_abs_max'
self.weight_quant_type = 'channel_wise_abs_max'
self.quantized_op_type = 'conv2d'
self.channels = 3
def test_check_output(self):
if core.is_compiled_with_cuda():
use_gpu = True
self.check_output_with_option(use_gpu, flatten=True, quant=True)
self.assertTrue(
PassVersionChecker.IsCompatible(
'quant_conv2d_dequant_fuse_pass'))
class QuantFcDequantTest(QuantDequantTest):
def append_quantized_op(self, x, param_attr):
return fluid.layers.fc(x,
size=100,
num_flatten_dims=1,
param_attr=param_attr,
bias_attr=False,
act=None)
def set_quant_pattern(self):
self.activation_quant_type = 'moving_average_abs_max'
self.weight_quant_type = 'abs_max'
self.quantized_op_type = 'mul'
self.channels = 1
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册