test_fake_quantize_op.py 12.7 KB
Newer Older
视言's avatar
视言 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

15 16
from __future__ import print_function

视言's avatar
视言 已提交
17 18
import unittest
import numpy as np
19
from op_test import OpTest
20
import paddle.fluid.core as core
视言's avatar
视言 已提交
21 22 23 24


class TestFakeQuantizeOp(OpTest):
    """Check ``fake_quantize_abs_max`` on a random 2-D float32 tensor.

    The reference output divides X by its global absolute maximum and
    scales the result to the signed ``bit_length`` integer range, rounding
    to the nearest integer.
    """

    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        x = np.random.random((124, 240)).astype("float32")
        self.inputs = {'X': x, }
        # Largest magnitude representable by a signed bit_length integer.
        bin_cnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        scale = np.max(np.abs(x)).astype("float32")
        self.outputs = {
            'Out': np.round(x / scale * bin_cnt),
            'OutScale': np.array(scale).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()
Z
Zhen Wang 已提交
37 38


39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72
class TestFakeQuantizeOp1(OpTest):
    """Check ``fake_quantize_abs_max`` on an all-zero input (abs-max == 0)."""

    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        x = np.zeros((10, 10)).astype("float32")
        self.inputs = {'X': x, }
        scale = np.max(np.abs(x)).astype("float32")
        # Guarded reciprocal: a small epsilon keeps the reference finite when
        # the abs-max is (numerically) zero.
        if scale < 1e-30:
            inv_scale = 1.0 / (scale + 1e-6)
        else:
            inv_scale = 1.0 / scale
        bin_cnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(x * inv_scale * bin_cnt),
            'OutScale': np.array(scale).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


class TestFakeQuantizeOp2(OpTest):
    """Check ``fake_quantize_abs_max`` on a tiny (1e-40) constant input."""

    def setUp(self):
        self.op_type = "fake_quantize_abs_max"
        self.attrs = {'bit_length': 8}
        x = np.full((10, 10), 1e-40).astype("float32")
        self.inputs = {'X': x, }
        scale = np.max(np.abs(x)).astype("float32")
        # Same guarded reciprocal as the zero-input case: abs-max below
        # 1e-30 gets an epsilon added before inverting.
        if scale < 1e-30:
            inv_scale = 1.0 / (scale + 1e-6)
        else:
            inv_scale = 1.0 / scale
        bin_cnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(x * inv_scale * bin_cnt),
            'OutScale': np.array(scale).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()


Z
Zhen Wang 已提交
73 74
class TestFakeChannelWiseQuantizeOp(OpTest):
    """Check ``fake_channel_wise_quantize_abs_max``.

    Every slice along ``quant_axis`` is quantized with its own abs-max
    scale. Subclasses override ``set_arg`` to choose ``quant_axis`` (0 or 1)
    and the input tensor ``X``.
    """

    def setUp(self):
        self.set_arg()
        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."

        self.op_type = "fake_channel_wise_quantize_abs_max"
        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}

        x = self.inputs['X']
        bnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        expected = x.copy()
        channel_scales = []
        for c in range(x.shape[self.quant_axis]):
            # (c,) selects x[c]; (slice(None), c) selects x[:, c].
            index = (slice(None), ) * self.quant_axis + (c, )
            scale_v = np.max(np.abs(x[index])).astype("float32")
            channel_scales.append(scale_v)
            expected[index] = np.round(expected[index] / scale_v * bnt)

        self.outputs = {
            'Out': expected,
            'OutScale': np.array(channel_scales).astype("float32"),
        }

    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {
            'X': np.random.random((20, 15, 6, 6)).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()
109 110


111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130
class TestFakeChannelWiseQuantizeOp1(TestFakeChannelWiseQuantizeOp):
    """Channel-wise quantization along axis 1 of a 4-D input."""

    # The base setUp calls ``set_arg``. This override used to be named
    # ``set_quant_axis``, so it was never invoked and this case silently
    # re-ran the base (axis 0) configuration.
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {
            'X': np.random.random((15, 20, 5, 5)).astype("float32"),
        }

    # Old name kept as an alias for backward compatibility.
    set_quant_axis = set_arg


class TestFakeChannelWiseQuantizeOp2(TestFakeChannelWiseQuantizeOp):
    """Channel-wise quantization along axis 0 of a 2-D input."""

    # The base setUp calls ``set_arg``. This override used to be named
    # ``set_quant_axis``, so it was never invoked and the 2-D input was
    # never tested.
    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }

    # Old name kept as an alias for backward compatibility.
    set_quant_axis = set_arg


class TestFakeChannelWiseQuantizeOp3(TestFakeChannelWiseQuantizeOp):
    """Channel-wise quantization along axis 1 of a 2-D input."""

    # The base setUp calls ``set_arg``. This override used to be named
    # ``set_quant_axis``, so it was never invoked and this case silently
    # re-ran the base (axis 0, 4-D) configuration.
    def set_arg(self):
        self.quant_axis = 1
        self.inputs = {'X': np.random.random((30, 15)).astype("float32"), }

    # Old name kept as an alias for backward compatibility.
    set_quant_axis = set_arg


131
class TestFakeQuantizeRangeAbsMaxOp(OpTest):
    """Training-mode ``fake_quantize_range_abs_max`` with a window of one.

    With ``InScale`` and ``Iter`` starting at zero, the expected scale is
    the current abs-max of X, stored in the first window slot.
    """

    def setUp(self):
        self.op_type = "fake_quantize_range_abs_max"
        self.attrs = {
            'bit_length': int(5),
            'window_size': int(1),
            'is_test': False
        }
        # Zero-centred random data in [-5, 5).
        x = ((np.random.random((8, 16, 7, 7)) - 0.5) * 10).astype("float32")
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype("int64"),
            'InScale': np.zeros(1).astype("float32")
        }
        scale = np.max(np.abs(x)).astype("float32")
        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
        out_scales[0] = scale
        bin_cnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        self.outputs = {
            'Out': np.round(x / scale * bin_cnt),
            'OutScale': scale,
            'OutScales': out_scales,
        }

    def test_check_output(self):
        self.check_output()


Z
Zhen Wang 已提交
161 162 163 164 165 166 167 168
class TestMovingAverageAbsMaxScaleOp(OpTest):
    """Check ``moving_average_abs_max_scale``.

    The expected ``Out`` is X itself. The accumulator and state are
    updated as ``rate * prev + abs_max(X)`` and ``rate * prev + 1``, and
    ``OutScale`` is their ratio.
    """

    def setUp(self):
        self.op_type = "moving_average_abs_max_scale"
        self.attrs = {'moving_rate': float(0.9), 'is_test': False}
        accum = np.ones(1).astype("float32")
        state = np.ones(1).astype("float32")
        x = np.random.random((8, 16, 7, 7)).astype("float32")
        self.inputs = {
            'X': x,
            'InAccum': accum,
            'InState': state,
        }

        rate = self.attrs['moving_rate']
        new_accum = np.zeros(1).astype("float32")
        new_accum[0] = rate * accum[0] + np.max(np.abs(x)).astype("float32")
        new_state = np.zeros(1).astype("float32")
        new_state[0] = rate * state[0] + 1
        self.outputs = {
            'Out': x,
            'OutAccum': new_accum,
            'OutState': new_state,
            'OutScale': new_accum / new_state,
        }

    def test_check_output(self):
        self.check_output()


195 196 197 198 199 200 201 202 203 204
class TestFakeQuantizeRangeAbsMaxOp2(OpTest):
    """Inference-mode (``is_test``) ``fake_quantize_range_abs_max``.

    ``InScale`` is set 1.0 below the true abs-max of X. The reference
    output therefore clips X to ``[-scale, scale]`` before quantizing.
    """

    def setUp(self):
        self.op_type = "fake_quantize_range_abs_max"
        self.attrs = {
            'bit_length': int(8),
            'window_size': int(1),
            'is_test': True
        }
        x = (np.random.random((8, 16, 7, 7)) - 0.5) * 10
        x = x.astype("float32")
        # Build the scale directly as float32. Under NumPy 1.x casting
        # rules, float32-scalar minus Python float promotes to float64,
        # which made the reference below run in a different precision from
        # the op's float32 inputs.
        scale = np.array([np.max(np.abs(x)) - 1.0], dtype="float32")
        out_scales = np.zeros(self.attrs['window_size']).astype("float32")
        # Assign the element rather than the 1-element array: implicitly
        # converting ndim>0 arrays to scalars is deprecated (NumPy 1.25).
        out_scales[0] = scale[0]
        self.inputs = {
            'X': x,
            'Iter': np.zeros(1).astype("int64"),
            'InScale': scale,
        }
        bin_cnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        xs = np.clip(x, -scale, scale)
        qs = np.round(xs / scale * bin_cnt)
        self.outputs = {
            'Out': qs,
            'OutScale': scale,
            'OutScales': out_scales,
        }

    def test_check_output(self):
        # Only Out is compared; the scale outputs are excluded.
        self.check_output(no_check_set=set(['OutScale', 'OutScales']))


225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280
class TestMovingOpBase(OpTest):
    """Base test for moving-average abs-max fake-quantization ops.

    Subclasses override ``init_type`` to pick the operator and
    ``calc_output`` to derive the expected ``Out`` from the updated scale.
    """

    def setUp(self):
        self.init_type()
        self.attrs = {
            'bit_length': int(5),
            'moving_rate': float(0.9),
            'is_test': False
        }
        prev_accum = np.ones(1).astype("float32")
        prev_state = np.ones(1).astype("float32")
        in_scale = np.full(1, 0.001).astype("float32")
        self.inputs = {
            'X': np.random.random((8, 16, 7, 7)).astype("float32"),
            'InScale': in_scale,
            'InAccum': prev_accum,
            'InState': prev_state,
        }

        rate = self.attrs['moving_rate']
        abs_max = np.max(np.abs(self.inputs['X'])).astype("float32")
        new_accum = np.zeros(1).astype("float32")
        new_accum[0] = rate * prev_accum[0] + abs_max
        new_state = np.zeros(1).astype("float32")
        new_state[0] = rate * prev_state[0] + 1
        new_scale = new_accum / new_state
        self.outputs = {
            'Out': self.calc_output(new_scale),
            'OutAccum': new_accum,
            'OutState': new_state,
            'OutScale': new_scale,
        }

    def init_type(self):
        self.op_type = "fake_quantize_moving_average_abs_max"

    def calc_output(self, out_scale):
        """Quantize X with ``out_scale`` onto the signed bit_length grid."""
        bin_cnt = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / out_scale * bin_cnt)

    def test_check_output(self):
        self.check_output()


class TestFakeQuantDequantMovingOp(TestMovingOpBase):
    """Check ``fake_quantize_dequantize_moving_average_abs_max``.

    Covers both the forward output and a user-supplied gradient check.
    """

    def init_type(self):
        self.op_type = "fake_quantize_dequantize_moving_average_abs_max"

    def calc_output(self, out_scale):
        """Quantize X to the signed bit_length grid, then map it back."""
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / out_scale *
                        range_v) * out_scale / range_v

    def test_check_grad(self):
        x = self.inputs["X"]
        # Expected dX is uniform 1/numel(X). Presumably the op passes
        # gradients straight through and OpTest reduces Out by mean;
        # confirm against OpTest.check_grad.
        # np.prod replaces np.product, which was deprecated in NumPy 1.25
        # and removed in NumPy 2.0.
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)


class TestFakeQuantDequantAbsOp(OpTest):
    """Check ``fake_quantize_dequantize_abs_max`` on a random 2-D tensor."""

    def setUp(self):
        self.op_type = "fake_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8}
        self.inputs = {'X': np.random.random((124, 240)).astype("float32"), }
        scale = np.max(np.abs(self.inputs['X'])).astype("float32")
        out_data = self.calc_output(scale)
        self.outputs = {
            'Out': out_data,
            'OutScale': np.array(scale).astype("float32"),
        }

    def calc_output(self, scale):
        """Quantize X to the signed bit_length grid, then dequantize."""
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        return np.round(self.inputs['X'] / scale * range_v) * scale / range_v

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        # Expected dX is uniform 1/numel(X). Presumably this mirrors a
        # straight-through gradient under OpTest's mean reduction; confirm.
        # np.prod replaces np.product, which was deprecated in NumPy 1.25
        # and removed in NumPy 2.0.
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)

311

H
huangxu96 已提交
312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376
class TestChannelWiseFakeQuantDequantOp(OpTest):
    """Check ``fake_channel_wise_quantize_dequantize_abs_max``.

    Every slice along ``quant_axis`` is quantized and dequantized with its
    own abs-max scale. Subclasses override ``set_arg`` to choose
    ``quant_axis`` (0 or 1) and the input tensor ``X``.
    """

    def setUp(self):
        self.set_arg()
        assert self.quant_axis in [0, 1], "quant_axis should be 0 or 1."

        self.op_type = "fake_channel_wise_quantize_dequantize_abs_max"
        self.attrs = {'bit_length': 8, 'quant_axis': self.quant_axis}

        scales = []
        outputs = self.inputs['X'].copy()
        range_v = (1 << (self.attrs['bit_length'] - 1)) - 1
        if self.quant_axis == 0:
            for i in range(self.inputs['X'].shape[0]):
                scale_v = np.max(np.abs(self.inputs['X'][i])).astype("float32")
                scales.append(scale_v)
                outputs[i] = np.round(outputs[i] * range_v /
                                      scale_v) * scale_v / range_v
        elif self.quant_axis == 1:
            for i in range(self.inputs['X'].shape[1]):
                scale_v = np.max(np.abs(self.inputs['X'][:, i])).astype(
                    "float32")
                scales.append(scale_v)
                outputs[:, i] = np.round(outputs[:, i] * range_v /
                                         scale_v) * scale_v / range_v

        self.outputs = {
            'Out': outputs,
            'OutScale': np.array(scales).astype("float32"),
        }

    def set_arg(self):
        self.quant_axis = 0
        self.inputs = {
            'X': np.random.random((3, 4, 64, 64)).astype("float32"),
        }

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        x = self.inputs["X"]
        # Expected dX is uniform 1/numel(X); presumably a straight-through
        # gradient under OpTest's mean reduction — confirm.
        # np.prod replaces np.product, which was deprecated in NumPy 1.25
        # and removed in NumPy 2.0.
        gradient = [np.ones(x.shape) / np.prod(x.shape)]
        self.check_grad(["X"], "Out", user_defined_grads=gradient)


class TestChannelWiseFakeQuantDequantOp1(TestChannelWiseFakeQuantDequantOp):
    """Channel-wise quant-dequant along axis 1 of a 4-D input."""

    def set_arg(self):
        x = np.random.random((15, 20, 5, 5)).astype("float32")
        self.quant_axis = 1
        self.inputs = {'X': x}


class TestChannelWiseFakeQuantDequantOp2(TestChannelWiseFakeQuantDequantOp):
    """Channel-wise quant-dequant along axis 0 of a 2-D input."""

    def set_arg(self):
        x = np.random.random((30, 15)).astype("float32")
        self.quant_axis = 0
        self.inputs = {'X': x}


class TestChannelWiseFakeQuantDequantOp3(TestChannelWiseFakeQuantDequantOp):
    """Channel-wise quant-dequant along axis 1 of a 2-D input."""

    def set_arg(self):
        x = np.random.random((30, 15)).astype("float32")
        self.quant_axis = 1
        self.inputs = {'X': x}


视言's avatar
视言 已提交
377 378
if __name__ == "__main__":
    # Run the TestCase classes defined in this module via the unittest CLI.
    unittest.main(verbosity=1)