# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import print_function

import itertools
import math
import unittest

import numpy as np
from op_test import OpTest


# numpy.round rounds ties to even, while the C++ round function rounds ties
# away from zero, so we use round_c instead of numpy.round to align the
# output data with the operator.
def round_c_single_element(val):
    dtype = type(val)
    if val >= 0:
        return dtype(np.floor(val + 0.5))
    return dtype(np.ceil(val - 0.5))


round_c = np.vectorize(round_c_single_element)


def get_compute_type(dtype):
    assert dtype in [np.float16, np.float32, np.float64]
    if dtype == np.float16:
        return np.float32
    return dtype


class TestFakeQuantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_abs_max(self, dtype, input_shape, distribution):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        scale = np.max(np.abs(input_data))
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        # Guard against division by (near) zero for all-zero inputs.
        inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
        output_data = round_c(
            input_data.astype(compute_type) * inv_scale * bnt)
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.check_output()

    def test_fake_quantize_abs_max(self):
        self._fake_quantize_abs_max(np.float32, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_float16(self):
        self._fake_quantize_abs_max(np.float16, (124, 240), np.random.random)

    def test_fake_quantize_abs_max_underflow(self):
        self._fake_quantize_abs_max(np.float32, (10, 10), np.zeros)

    def test_fake_quantize_abs_max_underflow2(self):
        self._fake_quantize_abs_max(np.float32, (10, 10),
                                    lambda shape: np.full(shape, 1e-40))
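
# Illustrative sketch (hypothetical helper, not used by the tests): the
# abs-max scheme above reduces the whole tensor to one scale,
# scale = max(|x|), and quantizes with q = round_c(x / scale * bnt),
# where bnt = 2^(bit_length - 1) - 1.
def _reference_abs_max_quantize(x, bit_length=8):
    """Quantize ``x`` with a single abs-max scale (exposition only)."""
    bnt = (1 << (bit_length - 1)) - 1
    scale = np.max(np.abs(x))
    # Mirror the underflow guard exercised by the tests above.
    inv_scale = 1.0 / (scale + 1e-6) if scale < 1e-30 else 1.0 / scale
    return round_c(x * inv_scale * bnt), scale
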
class TestFakeChannelWiseQuantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_abs_max(self, dtype, input_shape,
                                            quant_axis, distribution):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis)
        # The test distributions are non-negative, so amax equals abs-max.
        scale_broadcast = np.amax(input_data, axis=compute_axis,
                                  keepdims=True)
        output_data = round_c(bnt * input_data.astype(compute_type) /
                              scale_broadcast)
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast,
                                           (1, ) + compute_axis)
        scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output()

    def test_fake_channel_wise_quantize_abs_max(self):
        dtype_options = [np.float32, np.float16]
        input_shape_quant_axis_options = [[(20, 15, 6, 6), 0],
                                          [(15, 20, 5, 5), 1], [(30, 15), 0],
                                          [(30, 15), 1]]
        for dtype, input_shape_quant_axis in itertools.product(
                dtype_options, input_shape_quant_axis_options):
            input_shape, quant_axis = input_shape_quant_axis
            with self.subTest(
                    dtype=dtype, input_shape=input_shape,
                    quant_axis=quant_axis):
                self._fake_channel_wise_quantize_abs_max(
                    dtype, input_shape, quant_axis, np.random.random)


class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_range_abs_max'
        self.attrs = {'bit_length': 5, 'window_size': 1}

    def _fake_quantize_range_abs_max(self,
                                     dtype,
                                     input_shape,
                                     distribution,
                                     is_test=False):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_scale = np.zeros(1).astype(dtype)
        out_scale = np.zeros(self.attrs['window_size']).astype(dtype)
        out_scale[0] = np.max(np.abs(input_data))
        if is_test:
            out_scale[0] = in_scale[0] = out_scale[0] - 1.0
            clip_data = np.clip(input_data, -in_scale, in_scale)
        else:
            clip_data = input_data
        output_data = round_c(
            clip_data.astype(compute_type) / out_scale[0] * bnt)
        self.inputs = {
            'X': input_data,
            'Iter': np.zeros(1).astype(np.int64),
            'InScale': in_scale
        }
        self.outputs = {
            'Out': output_data,
            'OutScale': out_scale[0],
            'OutScales': out_scale
        }
        self.dtype = dtype
        self.attrs['is_test'] = is_test
        self.check_output()

    def test_fake_quantize_range_abs_max(self):
        dtype_options = [np.float32, np.float16]
        is_test_options = [False, True]
        for dtype, is_test in itertools.product(dtype_options,
                                                is_test_options):
            self.attrs['bit_length'] = 8 if is_test else 5
            with self.subTest(dtype=dtype, is_test=is_test):
                self._fake_quantize_range_abs_max(
                    dtype, (8, 16, 7, 7),
                    lambda shape: (np.random.random(shape) - 0.5) * 10,
                    is_test=is_test)


class TestMovingAverageAbsMaxScaleOp(OpTest):
    def setUp(self):
        self.op_type = 'moving_average_abs_max_scale'
        self.attrs = {'moving_rate': 0.9, 'is_test': False}

    def _moving_average_abs_max_scale(self, dtype, input_shape, distribution):
        input_data = distribution(input_shape).astype(dtype)
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
        out_accum = self.attrs['moving_rate'] * in_accum[0] + np.max(
            np.abs(input_data))
        out_state = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state
        self.inputs = {
            'X': input_data,
            'InAccum': in_accum,
            'InState': in_state
        }
        self.outputs = {
            'Out': input_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale
        }
        self.dtype = dtype
        self.check_output()

    def test_moving_average_abs_max(self):
        self._moving_average_abs_max_scale(np.float32, (8, 16, 7, 7),
                                           np.random.random)
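
# Illustrative sketch (hypothetical helper, exposition only): the
# moving-average scale is a running mean of per-batch abs-max values,
#     accum_t = rate * accum_{t-1} + max(|x_t|)
#     state_t = rate * state_{t-1} + 1
#     scale_t = accum_t / state_t
# which matches the expected outputs computed in the test above.
def _reference_moving_average_scale(x, accum, state, moving_rate=0.9):
    accum = moving_rate * accum + np.max(np.abs(x))
    state = moving_rate * state + 1.0
    return accum, state, accum / state
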
class TestFakeQuantizeMovingAverageAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_moving_average_abs_max'
        self.attrs = {'bit_length': 5, 'moving_rate': 0.9, 'is_test': False}

    def _fake_quantize_moving_average_abs_max(self,
                                              dtype,
                                              input_shape,
                                              distribution,
                                              dequantize=False,
                                              with_gradient=False):
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        in_accum = np.ones(1).astype(dtype)
        in_state = np.ones(1).astype(dtype)
        in_scale = np.array([0.001]).astype(dtype)
        out_accum = np.zeros(1).astype(dtype)
        out_state = np.zeros(1).astype(dtype)
        out_scale = np.zeros(1).astype(dtype)
        out_accum[0] = self.attrs['moving_rate'] * in_accum[0] + np.max(
            np.abs(input_data))
        out_state[0] = self.attrs['moving_rate'] * in_state[0] + 1.0
        out_scale = out_accum / out_state
        round_data = round_c(
            input_data.astype(compute_type) / out_scale * bnt)
        if dequantize:
            output_data = (round_data * out_scale / bnt).astype(dtype)
            self.op_type = 'fake_quantize_dequantize_moving_average_abs_max'
        else:
            output_data = round_data.astype(dtype)
        self.inputs = {
            'X': input_data,
            'InScale': in_scale,
            'InAccum': in_accum,
            'InState': in_state
        }
        self.outputs = {
            'Out': output_data,
            'OutAccum': out_accum,
            'OutState': out_state,
            'OutScale': out_scale
        }
        self.dtype = dtype
        self.check_output()
        if with_gradient:
            # Straight-through estimator: d(Out)/d(X) is treated as identity,
            # so the expected gradient of mean(Out) is 1/numel everywhere.
            gradient = [
                np.ones(input_data.shape) / np.prod(input_data.shape)
            ]
            self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(np.float32, (8, 16, 7, 7),
                                                   np.random.random)

    def test_fake_quantize_moving_average_abs_max_float16(self):
        self._fake_quantize_moving_average_abs_max(np.float16, (8, 16, 7, 7),
                                                   np.random.random)

    def test_fake_quantize_dequantize_moving_average_abs_max(self):
        self._fake_quantize_moving_average_abs_max(
            np.float32, (8, 16, 7, 7),
            np.random.random,
            dequantize=True,
            with_gradient=True)


class TestFakeQuantizeDequantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_quantize_dequantize_abs_max(self, dtype, input_shape,
                                          distribution):
        input_data = distribution(input_shape).astype(dtype)
        scale = np.max(np.abs(input_data)).astype(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        output_data = round_c(input_data / scale * bnt) * scale / bnt
        self.inputs = {'X': input_data}
        self.outputs = {
            'Out': output_data,
            'OutScale': np.array(scale).astype(dtype)
        }
        self.dtype = dtype
        self.check_output()
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_fake_quantize_dequantize_abs_max(self):
        self._fake_quantize_dequantize_abs_max(np.float32, (124, 240),
                                               np.random.random)
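
# Illustrative sketch (hypothetical helper, exposition only): a fake
# quantize-dequantize pass snaps each element to its nearest representable
# level and maps it back to the input's scale, so the output approximates
# the input with error bounded by scale / (2 * bnt).
def _reference_quant_dequant(x, bit_length=8):
    bnt = (1 << (bit_length - 1)) - 1
    scale = np.max(np.abs(x))
    return round_c(x / scale * bnt) * scale / bnt
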
class TestChannelWiseFakeQuantizeDequantizeAbsMaxOp(OpTest):
    def setUp(self):
        self.op_type = 'fake_channel_wise_quantize_dequantize_abs_max'
        self.attrs = {'bit_length': 8}

    def _fake_channel_wise_quantize_dequantize_abs_max(
            self, dtype, input_shape, quant_axis, distribution):
        assert quant_axis in [0, 1], 'quant_axis should be 0 or 1.'
        input_data = distribution(input_shape).astype(dtype)
        compute_type = get_compute_type(dtype)
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        output_data = input_data.copy().astype(compute_type)
        compute_axis = tuple(
            i for i in range(len(input_shape)) if i != quant_axis)
        scale_broadcast = np.amax(input_data, axis=compute_axis,
                                  keepdims=True)
        output_data = round_c(
            bnt * output_data / scale_broadcast) * scale_broadcast / bnt
        if quant_axis == 1:
            scale_broadcast = np.transpose(scale_broadcast,
                                           (1, ) + compute_axis)
        scale = scale_broadcast.reshape(input_shape[quant_axis], -1)[:, 0]
        self.inputs = {'X': input_data}
        self.outputs = {'Out': output_data, 'OutScale': scale}
        self.dtype = dtype
        self.attrs['quant_axis'] = quant_axis
        self.check_output()
        gradient = [np.ones(input_data.shape) / np.prod(input_data.shape)]
        self.check_grad(['X'], 'Out', user_defined_grads=gradient)

    def test_channel_wise_fake_quant_dequant_abs_max(self):
        input_shape_quant_axis_options = [[(3, 4, 64, 64), 0],
                                          [(15, 20, 5, 5), 1], [(30, 15), 0],
                                          [(30, 15), 1]]
        for input_shape, quant_axis in input_shape_quant_axis_options:
            with self.subTest(input_shape=input_shape, quant_axis=quant_axis):
                self._fake_channel_wise_quantize_dequantize_abs_max(
                    np.float32, input_shape, quant_axis, np.random.random)


def quantize_max_abs(x, max_range):
    scale = np.max(np.abs(x).flatten())
    y = np.round(x / scale * max_range)
    return y, scale


def channel_wise_quantize_max_abs(x, quant_bit=8, quant_axis=0):
    assert quant_axis in [0, 1], "The quant_axis should be 0 or 1."
    scales = []
    y = x.copy()
    max_range = math.pow(2, quant_bit - 1) - 1
    if quant_axis == 0:
        for i in range(x.shape[0]):
            scale = np.max(np.abs(x[i])).astype("float32")
            scales.append(scale)
            y[i] = np.round(x[i] * max_range / scale)
    elif quant_axis == 1:
        for i in range(x.shape[1]):
            scale = np.max(np.abs(x[:, i])).astype("float32")
            scales.append(scale)
            y[:, i] = np.round(x[:, i] * max_range / scale)
    return y, scales


class TestChannelWiseQuantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
                                                  self.quant_axis)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")
        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output()


class TestChannelWiseQuantizeOp1(TestChannelWiseQuantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 1


class TestChannelWiseQuantizeOpTrain(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.data_type = "float32"
        self.quant_axis = 0
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(4, 3, 64, 64).astype(self.data_type)
        yq, scale = channel_wise_quantize_max_abs(x, self.bit_length,
                                                  self.quant_axis)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")
        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': yq, 'OutScale': scale}

    def test_check_output(self):
        self.check_output()
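
# Illustrative sketch (hypothetical, exposition only): the loop-based
# channel_wise_quantize_max_abs above can be expressed as one vectorized
# reduction over every axis except quant_axis.
def _reference_channel_wise_scales(x, quant_axis=0):
    reduce_axes = tuple(i for i in range(x.ndim) if i != quant_axis)
    return np.max(np.abs(x), axis=reduce_axes)
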
class TestquantizeOp(OpTest):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        yq, scale = quantize_max_abs(x, self.max_range)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")
        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
        }
        self.outputs = {'Y': yq}

    def test_check_output(self):
        self.check_output()


class TestquantizeOpTrain(TestquantizeOp):
    def set_args(self):
        self.bit_length = 8
        self.quant_axis = -1
        self.max_range = math.pow(2, self.bit_length - 1) - 1
        self.data_type = "float32"
        self.is_test = False

    def setUp(self):
        self.set_args()
        self.op_type = "quantize_linear"
        x = np.random.randn(31, 65).astype(self.data_type)
        yq, scale = quantize_max_abs(x, self.max_range)
        scale = np.array(scale).astype(self.data_type)
        zero_point = np.zeros(scale.shape, dtype="int32")
        self.inputs = {'X': x, 'Scale': scale, 'ZeroPoint': zero_point}
        self.attrs = {
            'bit_length': self.bit_length,
            'quant_axis': self.quant_axis,
            'is_test': self.is_test
        }
        self.outputs = {'Y': yq, 'OutScale': scale}

    def test_check_output(self):
        self.check_output()


if __name__ == '__main__':
    unittest.main()