diff --git a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py index 1d9f989782962f40f7fc7978a1e0484be137ebc1..b5a3e1a257ef6aa2c2d4a243f75ff962216b000e 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/inference_pass_test.py @@ -28,10 +28,6 @@ from paddle.fluid.core import PaddleDType from paddle.fluid.core import AnalysisConfig from paddle.fluid.core import create_paddle_predictor -from paddle.fluid.framework import IrGraph -from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass -from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass - class InferencePassTest(unittest.TestCase): def __init__(self, methodName='runTest'): @@ -58,25 +54,27 @@ class InferencePassTest(unittest.TestCase): def _get_place(self): return set([False, core.is_compiled_with_cuda()]) - def _save_models(self, executor, program, scope): + def _save_models(self, dirname, feeded_var_names, target_vars, executor, + program, scope): with fluid.scope_guard(scope): - outs = executor.run(program=program, - feed=self.feeds, - fetch_list=self.fetch_list, - return_numpy=False) # save models as combined to ensure that # there won't be too many useless files # after finishing a couple of tests. - fluid.io.save_inference_model( - dirname=self.path, - feeded_var_names=list(self.feeds.keys()), - target_vars=self.fetch_list, - executor=executor, - main_program=program) + fluid.io.save_inference_model(dirname, feeded_var_names, + target_vars, executor, program) + def _get_paddle_outs(self, executor, program, scope): + ''' + Return PaddlePaddle outputs. + ''' + with fluid.scope_guard(scope): + outs = executor.run(program=program, + feed=self.feeds, + fetch_list=self.fetch_list, + return_numpy=False) return outs - def _get_analysis_outputs(self, config): + def _get_inference_outs(self, config): ''' Return AnalysisPredictor outputs. ''' @@ -170,113 +168,75 @@ class InferencePassTest(unittest.TestCase): device = "GPU" if use_gpu else "CPU" with fluid.scope_guard(scope): executor.run(self.startup_program) - - if quant: - main_graph = IrGraph( - core.Graph(self.main_program.desc), for_test=True) - - transform_pass = QuantizationTransformPass( - scope=scope, - place=place, - activation_quantize_type=self.activation_quant_type, - weight_quantize_type=self.weight_quant_type, - quantizable_op_type=[ - 'conv2d', 'mul', 'depthwise_conv2d', 'conv2d_transpose' - ]) - transform_pass.apply(main_graph) - weight_scale_map = { - "conv2d": "conv2d_0.w_0.scale", - "mul": "fc_0.w_0.scale" - } - - weight_scale_tensor = scope.var(weight_scale_map[ - self.quantized_op_type]).get_tensor() - weight_scale = np.ones(self.channels).astype("float32") - weight_scale_tensor.set(weight_scale, place) - - op_nodes = main_graph.all_op_nodes() - for op_node in op_nodes: - if op_node.name() in [self.quantized_op_type, "relu"]: - op_node.op()._set_attr("out_threshold", 0.5) - - with fluid.scope_guard(scope): - executor.run(program=self.main_program, - feed=self.feeds, - fetch_list=self.fetch_list) - - freeze_pass = QuantizationFreezePass( - scope=scope, - place=place, - weight_quantize_type=self.weight_quant_type) - freeze_pass.apply(main_graph) - self.main_program = main_graph.to_program() - - outs = self._save_models(executor, self.main_program, scope) - - analysis_outputs = self._get_analysis_outputs( + self._save_models(self.path, + list(self.feeds.keys()), self.fetch_list, executor, + self.main_program, scope) + paddle_outs = self._get_paddle_outs(executor, self.main_program, scope) + inference_outs = self._get_inference_outs( self._get_analysis_config(use_gpu=use_gpu)) # Check whether the results calculated on CPU and on GPU are the same. self.assertTrue( - len(outs) == len(analysis_outputs), + len(paddle_outs) == len(inference_outs), "The number of outputs is different between inference and training forward at {}". format(device)) - for out, analysis_output in zip(outs, analysis_outputs): - out = np.array(out) + for out, inference_out in zip(paddle_outs, inference_outs): + paddle_out = np.array(out) if flatten: - out = out.flatten() - analysis_output = analysis_output.flatten() + paddle_out = paddle_out.flatten() + inference_out = inference_out.flatten() self.assertTrue( np.allclose( - out, analysis_output, atol=atol), + paddle_out, inference_out, atol=atol), "Output has diff between inference and training forward at {} ". format(device)) # Check whether the trt results and the GPU results are the same. if use_gpu and self.enable_trt: - tensorrt_outputs = self._get_analysis_outputs( + tensorrt_outputs = self._get_inference_outs( self._get_analysis_config( use_gpu=use_gpu, use_trt=self.enable_trt)) if self.trt_parameters.use_static: #deserialize - tensorrt_outputs = self._get_analysis_outputs( + tensorrt_outputs = self._get_inference_outs( self._get_analysis_config( use_gpu=use_gpu, use_trt=self.enable_trt)) self.assertTrue( - len(tensorrt_outputs) == len(outs), + len(tensorrt_outputs) == len(paddle_outs), "The number of outputs is different between GPU and TensorRT. ") - for out, tensorrt_output in zip(outs, tensorrt_outputs): - out = np.array(out) + for paddle_out, tensorrt_output in zip(paddle_outs, + tensorrt_outputs): + paddle_out = np.array(paddle_out) if flatten: - out = out.flatten() + paddle_out = paddle_out.flatten() tensorrt_output = tensorrt_output.flatten() self.assertTrue( np.allclose( - out, tensorrt_output, rtol=rtol, atol=atol), + paddle_out, tensorrt_output, rtol=rtol, atol=atol), "Output has diff between GPU and TensorRT. ") # Check whether the mkldnn results and the CPU results are the same. if (not use_gpu) and self.enable_mkldnn: - mkldnn_outputs = self._get_analysis_outputs( + mkldnn_outputs = self._get_inference_outs( self._get_analysis_config( use_gpu=use_gpu, use_mkldnn=self.enable_mkldnn)) self.assertTrue( - len(outs) == len(mkldnn_outputs), + len(paddle_outs) == len(mkldnn_outputs), "The number of outputs is different between CPU and MKLDNN. ") if self.enable_mkldnn_bfloat16: atol = 0.01 - for out, mkldnn_output in zip(outs, mkldnn_outputs): + for paddle_out, mkldnn_output in zip(paddle_outs, mkldnn_outputs): self.assertTrue( np.allclose( - np.array(out), mkldnn_output, atol=atol), + np.array(paddle_out), mkldnn_output, atol=atol), "Output has diff between CPU and MKLDNN. ") class TensorRTParam: diff --git a/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py b/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py new file mode 100644 index 0000000000000000000000000000000000000000..a75911232c50a1f8057e4277c8b7cf7ef816f0fe --- /dev/null +++ b/python/paddle/fluid/tests/unittests/ir/inference/quant_dequant_test.py @@ -0,0 +1,371 @@ +# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest +import random +import numpy as np +import six +import paddle.fluid as fluid +import paddle +import warnings +from paddle.fluid.framework import IrGraph +from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass +from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass +from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass +from paddle.fluid.contrib.slim.quantization import OutScaleForInferencePass +from paddle.fluid.contrib.slim.quantization import AddQuantDequantPass +from paddle.fluid import (core, Program, Variable, program_guard, layers) +from paddle.fluid.io import prepend_feed_ops, append_fetch_ops +from inference_pass_test import InferencePassTest +from paddle.fluid.core import create_paddle_predictor +from paddle.fluid.core import AnalysisConfig + + +class QuantDequantTest(unittest.TestCase): + def __init__(self, methodName='runTest'): + super(QuantDequantTest, self).__init__(methodName) + paddle.enable_static() + self.main_program = fluid.Program() + self.startup_program = fluid.Program() + self.test_main_program = fluid.Program() + self.test_startup_program = fluid.Program() + self.feeds = None + self.fetch_list = None + self.enable_mkldnn = False + self.enable_mkldnn_bfloat16 = False + self.enable_trt = False + self.enable_tensorrt_oss = True + self.trt_parameters = None + self.dynamic_shape_params = None + self.enable_lite = False + self.lite_parameters = None + self.path = "./inference_pass/" + self.__class__.__name__ + "/" + self.data = None + self.label = None + self.result = None + np.random.seed(1) + random.seed(1) + + # from Paddle release2.1 + def _normalize_program(self, program, feed_vars, fetch_vars): + if not isinstance(program, Program): + raise TypeError( + "program type must be `fluid.Program`, but received `%s`" % + type(program)) + if not isinstance(feed_vars, list): + feed_vars = [feed_vars] + if not all(isinstance(v, Variable) for v in feed_vars): + raise TypeError( + "feed_vars type must be a Variable or a list of Variable.") + if not isinstance(fetch_vars, list): + fetch_vars = [fetch_vars] + if not all(isinstance(v, Variable) for v in fetch_vars): + raise TypeError( + "fetch_vars type must be a Variable or a list of Variable.") + + # remind users to set auc_states to 0 if auc op were found. + for op in program.global_block().ops: + # clear device of Op + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName( + ) + op._set_attr(device_attr_name, "") + if op.type == 'auc': + warnings.warn("Be sure that you have set auc states to 0 " + "before saving inference model.") + break + + # serialize program + copy_program = program.clone() + global_block = copy_program.global_block() + remove_op_idx = [] + for i, op in enumerate(global_block.ops): + op.desc.set_is_target(False) + if op.type == "feed" or op.type == "fetch": + remove_op_idx.append(i) + for idx in remove_op_idx[::-1]: + global_block._remove_op(idx) + copy_program.desc.flush() + + feed_var_names = [var.name for var in feed_vars] + copy_program = copy_program._prune_with_input( + feeded_var_names=feed_var_names, targets=fetch_vars) + copy_program = copy_program._inference_optimize(prune_read_op=True) + fetch_var_names = [var.name for var in fetch_vars] + prepend_feed_ops(copy_program, feed_var_names) + append_fetch_ops(copy_program, fetch_var_names) + copy_program.desc._set_version() + return copy_program + + def _save_models(self, dirname, feeded_var_names, target_vars, executor, + program, scope): + with fluid.scope_guard(scope): + fluid.io.save_inference_model(dirname, feeded_var_names, + target_vars, executor, program) + + def _get_paddle_outs(self, feed, fetch_list, executor, program, scope): + ''' + Return PaddlePaddle outputs. + ''' + with fluid.scope_guard(scope): + outs = executor.run(program=program, + feed=feed, + fetch_list=fetch_list, + return_numpy=True) + return outs + + def _get_inference_outs(self, config): + ''' + Return AnalysisPredictor outputs. + ''' + predictor = create_paddle_predictor(config) + tensor_shapes = predictor.get_input_tensor_shape() + names = predictor.get_input_names() + for i, name in enumerate(names): + shape = tensor_shapes[name] + shape[0] = 1 + tensor = predictor.get_input_tensor(name) + feed_data = list(self.feeds.values())[i] + tensor.copy_from_cpu(np.array(feed_data)) + if type(feed_data) == fluid.LoDTensor: + tensor.set_lod(feed_data.lod()) + + predictor.zero_copy_run() + + output_names = predictor.get_output_names() + outs = [ + predictor.get_output_tensor(out_name).copy_to_cpu() + for out_name in output_names + ] + return outs + + def _get_analysis_config(self, + use_gpu=False, + use_trt=False, + use_mkldnn=False): + ''' + Return a new object of AnalysisConfig. + ''' + config = AnalysisConfig(self.path) + config.disable_gpu() + config.switch_specify_input_names(True) + config.switch_ir_optim(True) + config.switch_use_feed_fetch_ops(False) + if use_gpu: + config.enable_use_gpu(100, 0) + if use_trt: + config.enable_tensorrt_engine( + self.trt_parameters.workspace_size, + self.trt_parameters.max_batch_size, + self.trt_parameters.min_subgraph_size, + self.trt_parameters.precision, + self.trt_parameters.use_static, + self.trt_parameters.use_calib_mode) + + if self.dynamic_shape_params: + config.set_trt_dynamic_shape_info( + self.dynamic_shape_params.min_input_shape, + self.dynamic_shape_params.max_input_shape, + self.dynamic_shape_params.optim_input_shape, + self.dynamic_shape_params.disable_trt_plugin_fp16) + if self.enable_tensorrt_oss: + config.enable_tensorrt_oss() + + elif use_mkldnn: + config.enable_mkldnn() + if self.enable_mkldnn_bfloat16: + config.enable_mkldnn_bfloat16() + print('config summary:', config.summary()) + return config + + def check_output_with_option(self, + use_gpu, + atol=1e-5, + flatten=False, + quant=False, + rtol=1e-5): + ''' + Check whether calculating on CPU and GPU, enable TensorRT + or disable TensorRT, enable MKLDNN or disable MKLDNN + are all the same. + ''' + place = fluid.CUDAPlace(0) if use_gpu else fluid.CPUPlace() + executor = fluid.Executor(place) + scope = fluid.Scope() + device = "GPU" if use_gpu else "CPU" + + with fluid.scope_guard(scope): + executor.run(self.startup_program) + executor.run(self.test_startup_program) + main_graph = IrGraph(core.Graph(self.main_program.desc), for_test=False) + test_graph = IrGraph( + core.Graph(self.test_main_program.desc), for_test=True) + + transform_pass = QuantizationTransformPass( + scope=scope, + place=place, + activation_quantize_type=self.activation_quantize_type, + weight_quantize_type=self.weight_quantize_type) + transform_pass.apply(main_graph) + transform_pass.apply(test_graph) + + add_quant_dequant_pass = AddQuantDequantPass(scope=scope, place=place) + add_quant_dequant_pass.apply(main_graph) + add_quant_dequant_pass.apply(test_graph) + + scale_training_pass = OutScaleForTrainingPass(scope=scope, place=place) + scale_training_pass.apply(main_graph) + + build_strategy = fluid.BuildStrategy() + build_strategy.memory_optimize = False + build_strategy.enable_inplace = False + build_strategy.fuse_all_reduce_ops = False + binary = fluid.CompiledProgram(main_graph.graph) + + iters = 10 + batch_size = 1 + train_reader = paddle.batch( + paddle.reader.shuffle( + paddle.dataset.mnist.train(), buf_size=500), + batch_size=batch_size) + feeder = fluid.DataFeeder( + feed_list=[self.data, self.label], place=place) + with fluid.scope_guard(scope): + for _ in range(iters): + data = next(train_reader()) + loss_v = executor.run(binary, + feed=feeder.feed(data), + fetch_list=[self.loss]) + + scale_inference_pass = OutScaleForInferencePass(scope=scope) + scale_inference_pass.apply(test_graph) + + # Freeze graph for inference, but the weight of fc/conv is still float type. + freeze_pass = QuantizationFreezePass( + scope=scope, + place=place, + weight_quantize_type=self.weight_quantize_type) + freeze_pass.apply(test_graph) + + self.main_program = test_graph.to_program() + + with fluid.scope_guard(scope): + self.main_program = self._normalize_program( + self.main_program, self.data, self.fetch_list) + + self._save_models(self.path, + list(self.feeds.keys()), self.fetch_list, executor, + self.main_program, scope) + + paddle_outs = self._get_paddle_outs(self.feeds, self.fetch_list, + executor, self.main_program, scope) + inference_outs = self._get_inference_outs( + self._get_analysis_config(use_gpu=use_gpu)) + + # Check whether the results calculated on CPU and on GPU are the same. + self.assertTrue( + len(paddle_outs) == len(inference_outs), + "The number of outputs is different between inference and training forward at {}". + format(device)) + + for out, inference_out in zip(paddle_outs, inference_outs): + paddle_out = np.array(out) + + if flatten: + paddle_out = paddle_out.flatten() + inference_out = inference_out.flatten() + + self.assertTrue( + np.allclose( + paddle_out, inference_out, atol=atol), + "Output has diff between inference and training forward at {} ". + format(device)) + + # Check whether the trt results and the GPU results are the same. + if use_gpu and self.enable_trt: + tensorrt_outputs = self._get_inference_outs( + self._get_analysis_config( + use_gpu=use_gpu, use_trt=self.enable_trt)) + + if self.trt_parameters.use_static: + #deserialize + tensorrt_outputs = self._get_inference_outs( + self._get_analysis_config( + use_gpu=use_gpu, use_trt=self.enable_trt)) + + self.assertTrue( + len(tensorrt_outputs) == len(paddle_outs), + "The number of outputs is different between GPU and TensorRT. ") + + for paddle_out, tensorrt_output in zip(paddle_outs, + tensorrt_outputs): + paddle_out = np.array(paddle_out) + + if flatten: + paddle_out = paddle_out.flatten() + tensorrt_output = tensorrt_output.flatten() + + self.assertTrue( + np.allclose( + paddle_out, tensorrt_output, rtol=rtol, atol=atol), + "Output has diff between GPU and TensorRT. ") + + # Check whether the mkldnn results and the CPU results are the same. + if (not use_gpu) and self.enable_mkldnn: + mkldnn_outputs = self._get_inference_outs( + self._get_analysis_config( + use_gpu=use_gpu, use_mkldnn=self.enable_mkldnn)) + + self.assertTrue( + len(paddle_outs) == len(mkldnn_outputs), + "The number of outputs is different between CPU and MKLDNN. ") + + if self.enable_mkldnn_bfloat16: + atol = 0.01 + for paddle_out, mkldnn_output in zip(paddle_outs, mkldnn_outputs): + self.assertTrue( + np.allclose( + np.array(paddle_out), mkldnn_output, atol=atol), + "Output has diff between CPU and MKLDNN. ") + + class TensorRTParam: + ''' + Prepare TensorRT subgraph engine parameters. + ''' + + def __init__(self, workspace_size, max_batch_size, min_subgraph_size, + precision, use_static, use_calib_mode): + self.workspace_size = workspace_size + self.max_batch_size = max_batch_size + self.min_subgraph_size = min_subgraph_size + self.precision = precision + self.use_static = use_static + self.use_calib_mode = use_calib_mode + + class DynamicShapeParam: + ''' + Prepare TensorRT subgraph engine dynamic shape parameters. + ''' + + def __init__(self, min_input_shape, max_input_shape, optim_input_shape, + disable_trt_plugin_fp16): + self.min_input_shape = min_input_shape + self.max_input_shape = max_input_shape + self.optim_input_shape = optim_input_shape + self.disable_trt_plugin_fp16 = disable_trt_plugin_fp16 + + def quant_dequant(self): + place = fluid.CPUPlace() + exe = fluid.Executor(place) + scope = fluid.Scope() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py index cde2fa412d7050c19cdf4e185b8e5307c40021e3..7adfb7574825d09d5b5636d965d81fda6fd85acb 100644 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py +++ b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_fc_fuse_pass.py @@ -17,9 +17,11 @@ from __future__ import print_function import unittest import numpy as np from inference_pass_test import InferencePassTest +from quant_dequant_test import QuantDequantTest import paddle.fluid as fluid import paddle.fluid.core as core from paddle.fluid.core import AnalysisConfig +from paddle.fluid.core import PassVersionChecker class FCFusePassTRTTest(InferencePassTest): @@ -283,5 +285,55 @@ class FCFusePassTRTDynamicDims4Cols3Test(InferencePassTest): self.check_output_with_option(use_gpu[i]) +class FcQuantDequantFusePassTRTTest(QuantDequantTest): + def setUp(self): + def network(): + self.data = fluid.data( + name='data', shape=[1, 28, 28], dtype='float32') + self.label = fluid.data(name='label', shape=[1, 1], dtype='int64') + fc_out = fluid.layers.fc(input=self.data, + size=10, + num_flatten_dims=1, + bias_attr=False, + act=None) + result = fluid.layers.relu(fc_out) + loss = fluid.layers.cross_entropy(input=result, label=self.label) + avg_loss = fluid.layers.mean(loss) + return avg_loss, result + + self.main_program.random_seed = 2 + self.startup_program.random_seed = 2 + self.test_main_program.random_seed = 2 + #self.test_startup_program.random_seed = 2 + with fluid.unique_name.guard(): + with fluid.program_guard(self.main_program, self.startup_program): + self.loss, result = network() + opt = fluid.optimizer.Adam(learning_rate=0.0001) + opt.minimize(self.loss) + with fluid.unique_name.guard(): + with fluid.program_guard(self.test_main_program, + self.startup_program): + network() + self.feeds = {"data": np.random.random((1, 28, 28)).astype("float32")} + self.fetch_list = [result] + + self.enable_trt = True + + self.trt_parameters = FcQuantDequantFusePassTRTTest.TensorRTParam( + 1 << 30, 32, 0, AnalysisConfig.Precision.Int8, False, False) + self.activation_quantize_type = 'moving_average_abs_max' + self.weight_quantize_type = 'channel_wise_abs_max' + + def test_check_output(self): + #self.quant_dequant() + if core.is_compiled_with_cuda(): + use_gpu = True + self.check_output_with_option( + use_gpu, atol=1e-2, flatten=False, rtol=1e-2) + self.assertTrue( + PassVersionChecker.IsCompatible( + 'quant_conv2d_dequant_fuse_pass')) + + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_quant_conv2d_dequant_fuse_pass.py b/python/paddle/fluid/tests/unittests/ir/inference/test_trt_quant_conv2d_dequant_fuse_pass.py deleted file mode 100644 index 8d6f9a23af3fa5f713d6225c0f1ea9b559dc5e71..0000000000000000000000000000000000000000 --- a/python/paddle/fluid/tests/unittests/ir/inference/test_trt_quant_conv2d_dequant_fuse_pass.py +++ /dev/null @@ -1,88 +0,0 @@ -# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import unittest -import numpy as np -from inference_pass_test import InferencePassTest -import paddle.fluid as fluid -import paddle.fluid.core as core -from paddle.fluid.framework import IrGraph -from paddle.fluid.core import PassVersionChecker -from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass -from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass -from paddle.fluid.core import AnalysisConfig - - -class QuantDequantTest(InferencePassTest): - def setUp(self): - with fluid.program_guard(self.main_program, self.startup_program): - data = fluid.data( - name="data", shape=[-1, 3, 32, 32], dtype="float32") - param_attr = fluid.ParamAttr( - initializer=fluid.initializer.Constant(value=0.0), - trainable=False) - quantized_op_out = self.append_quantized_op(data, param_attr) - relu_out = fluid.layers.relu(quantized_op_out) - self.set_quant_pattern() - - self.feeds = { - "data": np.random.random([1, 3, 32, 32]).astype("float32"), - } - self.enable_trt = True - self.trt_parameters = QuantDequantTest.TensorRTParam( - 1 << 30, 32, 0, AnalysisConfig.Precision.Int8, False, False) - self.fetch_list = [relu_out] - - def append_quantized_op(self, x, param_attr): - return fluid.layers.conv2d( - input=x, - num_filters=3, - filter_size=3, - param_attr=param_attr, - bias_attr=False, - act=None) - - def set_quant_pattern(self): - self.activation_quant_type = 'moving_average_abs_max' - self.weight_quant_type = 'channel_wise_abs_max' - self.quantized_op_type = 'conv2d' - self.channels = 3 - - def test_check_output(self): - if core.is_compiled_with_cuda(): - use_gpu = True - self.check_output_with_option(use_gpu, flatten=True, quant=True) - self.assertTrue( - PassVersionChecker.IsCompatible( - 'quant_conv2d_dequant_fuse_pass')) - - -class QuantFcDequantTest(QuantDequantTest): - def append_quantized_op(self, x, param_attr): - return fluid.layers.fc(x, - size=100, - num_flatten_dims=1, - param_attr=param_attr, - bias_attr=False, - act=None) - - def set_quant_pattern(self): - self.activation_quant_type = 'moving_average_abs_max' - self.weight_quant_type = 'abs_max' - self.quantized_op_type = 'mul' - self.channels = 1 - - -if __name__ == "__main__": - unittest.main()