quant_layers.py 28.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14
#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Z
zhiboniu 已提交
15
from paddle.framework import core
16
from paddle.fluid import dygraph_utils
Z
zhiboniu 已提交
17 18
from paddle.utils import unique_name
from paddle.framework import ParamAttr
19
from paddle.fluid.framework import _varbase_creator
Z
zhiboniu 已提交
20
from paddle.nn.initializer import Constant
21
from paddle.fluid.data_feeder import check_variable_and_dtype
H
huangxu96 已提交
22
from paddle.nn import functional as F
23 24
import logging
from paddle.fluid.log_helper import get_logger
W
wanghuancoder 已提交
25
from paddle import _C_ops
Z
zhiboniu 已提交
26 27
from paddle import in_dynamic_mode
from paddle.nn import Layer
28 29

# Public API of this module: the fake quant-dequant layers, the quantized
# wrappers for Conv2D / Conv2DTranspose / Linear, and the output-scale layers.
__all__ = [
    # Fake quant-dequant layers.
    'FakeQuantAbsMax',
    'FakeQuantMovingAverageAbsMax',
    'FakeQuantChannelWiseAbsMax',
    # Quantized wrappers of regular layers.
    'QuantizedConv2D',
    'QuantizedConv2DTranspose',
    'QuantizedLinear',
    # Output-scale layers (QuantStub is an alias of MovingAverageAbsMaxScale).
    'MovingAverageAbsMaxScale',
    'MAOutputScaleLayer',
    'FakeQuantMAOutputScaleLayer',
    'QuantStub',
]

# Module-level logger at INFO level with a timestamped message format.
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')
45

46

Z
zhiboniu 已提交
47
class FakeQuantAbsMax(Layer):
    r"""
    FakeQuantAbsMax layer does the abs_max quant and then dequant.
    Its computational formula is described as below:

    :math:`scale = max(abs(X))`
    :math:`range = 2^{bit\_length - 1} - 1`
    :math:`Out = round(X / scale * range) * scale / range`

    Args:
        name (str, optional): Prefix used to build the unique scale name.
            When not given, ``'quant_dequant.scale'`` is used as the prefix.
        quant_bits (int): Bit width, passed to the op as ``bit_length``.
        dtype (str): Data type of the scale tensor.
        quant_on_weight (bool): If True, the scale lives in a non-trainable
            ``[1]``-shaped parameter (initialized to 0.001) that is passed to
            the op as its ``OutScale`` output. Otherwise a temporary,
            non-persistable scale variable is created on every forward call.
    """

    def __init__(self,
                 name=None,
                 quant_bits=8,
                 dtype='float32',
                 quant_on_weight=False):
        super(FakeQuantAbsMax, self).__init__()
        self._quant_bits = quant_bits
        self._name = name
        # Honor the requested dtype (previously the argument was ignored and
        # the scale always used the Layer's default dtype). This matches what
        # FakeQuantChannelWiseAbsMax does.
        self._dtype = dtype
        scale_prefix = "{}.scale".format(
            name) if name else 'quant_dequant.scale'
        self._scale_name = unique_name.generate(scale_prefix)
        if quant_on_weight:
            scale_attr = ParamAttr(name=self._scale_name,
                                   initializer=Constant(0.001),
                                   trainable=False)
            self._scale = self.create_parameter(shape=[1],
                                                attr=scale_attr,
                                                dtype=self._dtype)
            self._scale.stop_gradient = True
        else:
            self._scale = None

    def forward(self, input):
        if in_dynamic_mode():
            attrs = ('bit_length', self._quant_bits)
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.quantized.dequantized".format(
                                             input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)
            out_scale = self._scale
            # Check for None explicitly: truth-testing a tensor looks at its
            # value, so a scale parameter that happens to hold 0 would be
            # mistaken for "no scale" and would not be updated.
            if out_scale is None:
                out_scale = _varbase_creator(
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    name=self._scale_name,
                    shape=[1],
                    dtype=self._dtype,
                    persistable=False)
                out_scale.stop_gradient = True
            out, _ = _C_ops.fake_quantize_dequantize_abs_max(
                input, quant_out, out_scale, *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
        attrs = {'bit_length': self._quant_bits}
        inputs = {"X": [input]}
        quant_out = self._helper.create_variable(
            name="{}.quantized.dequantized".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)
        out_scale = self._scale
        # Same explicit None check as the dygraph branch.
        if out_scale is None:
            out_scale = self._helper.create_variable(
                name=self._scale_name,
                dtype=self._dtype,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=True)
        outputs = {"Out": [quant_out], "OutScale": [out_scale]}

        self._helper.append_op(type="fake_quantize_dequantize_abs_max",
                               inputs=inputs,
                               outputs=outputs,
                               attrs=attrs)

        return quant_out


Z
zhiboniu 已提交
128
class FakeQuantMovingAverageAbsMax(Layer):
    r"""
    FakeQuantMovingAverageAbsMax layer does the moving_average_abs_max quant and then dequant.
    Its computational formula is described as below:

    :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
    :math:`range = 2^{bit\_length - 1} - 1`
    :math:`Out = round(X / scale * range) * scale / range`

    Three non-trainable ``[1]``-shaped statistics are kept: ``scale``
    (initialized to 0.001), ``state`` and ``accum`` (both initialized to 1).
    ``state`` and ``accum`` are only fed to the op while training.
    """

    def __init__(self,
                 name=None,
                 moving_rate=0.9,
                 quant_bits=8,
                 dtype='float32'):
        super(FakeQuantMovingAverageAbsMax, self).__init__()
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits

        def _make_stat(prefix, init_value):
            # Build one non-trainable, gradient-free [1]-shaped statistic
            # whose name is made unique from `prefix`.
            attr = ParamAttr(name=unique_name.generate(prefix),
                             initializer=Constant(init_value),
                             trainable=False)
            param = self.create_parameter(shape=[1], attr=attr, dtype=dtype)
            param.stop_gradient = True
            return param

        # Creation order (scale, state, accum) is kept stable so generated
        # unique names and parameter registration order do not change.
        self._scale = _make_stat(
            "{}.scale".format(name) if name else 'quant_dequant.scale', 0.001)
        self._state = _make_stat(
            "{}.state".format(name) if name else 'quant_dequant.state', 1)
        self._accum = _make_stat(
            "{}.accum".format(name) if name else 'quant_dequant.accum', 1)

    def forward(self, input):
        is_test = not self.training

        if in_dynamic_mode():
            attrs = ('moving_rate', self._moving_rate, 'bit_length',
                     self._quant_bits, 'is_test', is_test)
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.quantized.dequantized".format(
                                             input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)
            # The running statistics are only updated during training.
            if self.training:
                state, accum = self._state, self._accum
            else:
                state = accum = None

            out, _, _, _ = _C_ops.fake_quantize_dequantize_moving_average_abs_max(
                input, self._scale, accum, state, quant_out, self._scale, state,
                accum, *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32'],
                                 "FakeQuantMovingAverageAbsMax")
        quant_out = self._helper.create_variable(
            name="{}.quantized.dequantized".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)

        inputs = {"X": [input], "InScale": [self._scale]}
        outputs = {"Out": [quant_out], "OutScale": [self._scale]}
        if self.training:
            inputs['InState'] = [self._state]
            inputs['InAccum'] = [self._accum]
            outputs['OutState'] = [self._state]
            outputs['OutAccum'] = [self._accum]

        op_attrs = {
            'moving_rate': self._moving_rate,
            'bit_length': self._quant_bits,
            'is_test': is_test
        }
        self._helper.append_op(
            type="fake_quantize_dequantize_moving_average_abs_max",
            inputs=inputs,
            outputs=outputs,
            attrs=op_attrs)

        return quant_out


Z
zhiboniu 已提交
226
class FakeQuantChannelWiseAbsMax(Layer):
    r"""
    FakeQuantChannelWiseAbsMax layer does the channel-wise abs_max quant and
    then dequant via the ``fake_channel_wise_quantize_dequantize_abs_max``
    op.

    The scale tensor has shape ``[channel_num]``, and the op is told which
    input axis holds the channels through ``quant_axis``. Presumably one scale
    is kept per channel along that axis.

    Only weight quantization is supported: ``quant_on_weight`` must be True.

    Args:
        name (str, optional): Prefix used to build the unique scale name.
            When not given, ``'quant_dequant.scale'`` is used as the prefix.
        channel_num (int): Length of the scale tensor.
        quant_bits (int): Bit width, passed to the op as ``bit_length``.
        quant_axis (int): Channel axis, passed to the op as ``quant_axis``.
        dtype (str): Data type of the scale tensor.
        quant_on_weight (bool): Must be True (asserted).
    """

    def __init__(self,
                 name=None,
                 channel_num=None,
                 quant_bits=8,
                 quant_axis=0,
                 dtype='float32',
                 quant_on_weight=False):
        assert quant_on_weight == True, "Channel_wise only can be used on weight quantization."
        super(FakeQuantChannelWiseAbsMax, self).__init__()
        self._quant_bits = quant_bits
        self._quant_axis = quant_axis
        self._dtype = dtype
        self._name = name
        self._channel_num = channel_num
        scale_prefix = "{}.scale".format(
            name) if name else 'quant_dequant.scale'
        self._scale_name = unique_name.generate(scale_prefix)
        if quant_on_weight:
            scale_attr = ParamAttr(name=self._scale_name,
                                   initializer=Constant(0.0),
                                   trainable=False)
            self._scale = self.create_parameter(shape=[self._channel_num],
                                                attr=scale_attr,
                                                dtype=self._dtype)
            self._scale.stop_gradient = True
        else:
            # Only reachable when asserts are disabled (python -O).
            self._scale = None

    def forward(self, input):
        if in_dynamic_mode():
            attrs = ('bit_length', self._quant_bits, 'quant_axis',
                     self._quant_axis)
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.quantized.dequantized".format(
                                             input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)

            out_scale = self._scale
            if out_scale is None:
                out_scale = _varbase_creator(
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    name=self._scale_name,
                    shape=[self._channel_num],
                    dtype=self._dtype,
                    persistable=False)
                out_scale.stop_gradient = True

            out, _ = _C_ops.fake_channel_wise_quantize_dequantize_abs_max(
                input, quant_out, out_scale, *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32'],
                                 "FakeQuantChannelWiseAbsMax")
        attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
        inputs = {"X": [input]}
        quant_out = self._helper.create_variable(
            name="{}.quantized.dequantized".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)
        out_scale = self._scale
        # Check for None explicitly, as the dygraph branch does. Truth-testing
        # a multi-element scale tensor is not a reliable "is it missing" test.
        if out_scale is None:
            out_scale = self._helper.create_variable(
                name=self._scale_name,
                dtype=self._dtype,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=True)
        outputs = {"Out": [quant_out], "OutScale": [out_scale]}

        self._helper.append_op(
            type="fake_channel_wise_quantize_dequantize_abs_max",
            inputs=inputs,
            outputs=outputs,
            attrs=attrs)

        return quant_out


Z
zhiboniu 已提交
310
class MovingAverageAbsMaxScale(Layer):
    r"""
    MovingAverageAbsMaxScale layer tracks the output quantization scale of a
    layer as a moving average of abs-max values. The input passes through
    unchanged. Its computational formula is described as below:

    :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
    :math:`Out = X`
    """

    def __init__(self, name=None, moving_rate=0.9, dtype='float32'):
        r"""
        Create the non-trainable ``scale``, ``state`` and ``accum``
        statistics. Each has shape ``[1]``, is initialized to 0, and is named
        from ``name`` (or the ``'outscale'`` prefix when ``name`` is empty).
        """
        super(MovingAverageAbsMaxScale, self).__init__()
        self._moving_rate = moving_rate

        def _make_stat(prefix):
            # One zero-initialized, gradient-free [1]-shaped statistic.
            attr = ParamAttr(name=unique_name.generate(prefix),
                             initializer=Constant(0),
                             trainable=False)
            param = self.create_parameter(shape=[1], attr=attr, dtype=dtype)
            param.stop_gradient = True
            return param

        # Keep the creation order (scale, state, accum) so unique names and
        # parameter registration order stay the same.
        self._scale = _make_stat(
            '{}.scale'.format(name) if name else 'outscale.scale')
        self._state = _make_stat(
            "{}.state".format(name) if name else 'outscale.state')
        self._accum = _make_stat(
            "{}.accum".format(name) if name else 'outscale.accum')

    def forward(self, input):
        is_test = not self.training

        if in_dynamic_mode():
            attrs = ('moving_rate', self._moving_rate, 'is_test', is_test)
            # The running statistics are only fed/updated while training.
            if self.training:
                state, accum = self._state, self._accum
            else:
                state = accum = None
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.tmp".format(input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)

            out, _, _, _ = _C_ops.moving_average_abs_max_scale(
                input, accum, state, quant_out, self._scale, state, accum,
                *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                 'MovingAverageAbsMaxScale')

        quant_out = self._helper.create_variable(
            name="{}.tmp".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)

        inputs = {"X": [input]}
        outputs = {"Out": [quant_out], "OutScale": [self._scale]}
        if self.training:
            inputs['InState'] = [self._state]
            inputs['InAccum'] = [self._accum]
            outputs['OutState'] = [self._state]
            outputs['OutAccum'] = [self._accum]

        op_attrs = {'moving_rate': self._moving_rate, 'is_test': is_test}
        self._helper.append_op(type="moving_average_abs_max_scale",
                               inputs=inputs,
                               outputs=outputs,
                               attrs=op_attrs)

        return quant_out


395
# QuantStub is a public alias of MovingAverageAbsMaxScale (same class object).
QuantStub = MovingAverageAbsMaxScale
396 397


Z
zhiboniu 已提交
398
class QuantizedConv2D(Layer):
    """
    The computational logic of QuantizedConv2D is the same with Conv2D.
    The only difference is that its inputs are all fake quantized.

    Args:
        layer: The Conv2D layer to wrap. Its convolution settings, ``weight``
            and ``bias`` are reused.
        weight_bits (int): Bit width for the weight fake-quant layer.
        activation_bits (int): Bit width for the input fake-quant layer.
        moving_rate (float): Forwarded to moving-average fake-quant layers.
        weight_quantize_type (str): Fake-quant type for the weight, used when
            ``weight_quant_layer`` is None. Channel-wise quantization uses
            axis 0 of the weight.
        activation_quantize_type (str): Fake-quant type for the input, used
            when ``act_quant_layer`` is None.
        weight_pre_layer / act_pre_layer: Optional callables, invoked with no
            arguments, that build layers applied to the weight / input before
            fake quantization.
        weight_quant_layer / act_quant_layer: Optional callables, invoked with
            no arguments, that build custom fake-quant layers. These take
            precedence over the ``*_quantize_type`` options.
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        super(QuantizedConv2D, self).__init__()
        # For Conv2D
        self._groups = getattr(layer, '_groups')
        self._stride = getattr(layer, '_stride')
        self._padding = getattr(layer, '_padding')
        self._padding_mode = getattr(layer, '_padding_mode')
        if self._padding_mode != 'zeros':
            self._reversed_padding_repeated_twice = getattr(
                layer, '_reversed_padding_repeated_twice')
        self._dilation = getattr(layer, '_dilation')
        self._data_format = getattr(layer, '_data_format')
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')

        # For FakeQuant
        self._conv2d_quant_axis = 0
        if weight_quant_layer is not None:
            self._fake_quant_weight = weight_quant_layer()
        else:
            self._fake_quant_weight = _get_fake_quant_type(
                weight_quantize_type,
                name=self.weight.name,
                moving_rate=moving_rate,
                quant_bits=weight_bits,
                dtype=self._dtype,
                quant_on_weight=True,
                channel_num=self.weight.shape[self._conv2d_quant_axis],
                quant_axis=self._conv2d_quant_axis)
        if act_quant_layer is not None:
            self._fake_quant_input = act_quant_layer()
        else:
            self._fake_quant_input = _get_fake_quant_type(
                activation_quantize_type,
                name=layer.full_name(),
                moving_rate=moving_rate,
                quant_bits=activation_bits,
                dtype=self._dtype,
                quant_on_weight=False)

        self._act_preprocess = act_pre_layer(
        ) if act_pre_layer is not None else None
        self._weight_preprocess = weight_pre_layer(
        ) if weight_pre_layer is not None else None

    def forward(self, input):
        if self._act_preprocess is not None:
            input = self._act_preprocess(input)
        quant_input = self._fake_quant_input(input)

        weight = self.weight
        if self._weight_preprocess is not None:
            weight = self._weight_preprocess(self.weight)
        quant_weight = self._fake_quant_weight(weight)

        padding = self._padding
        if self._padding_mode != 'zeros':
            # Non-'zeros' padding modes are applied explicitly with F.pad, so
            # the convolution itself must not pad again. Use a local value
            # instead of overwriting self._padding, which used to mutate the
            # layer's configuration on every forward call.
            quant_input = F.pad(quant_input,
                                self._reversed_padding_repeated_twice,
                                mode=self._padding_mode,
                                data_format=self._data_format)
            padding = 0

        return F.conv2d(quant_input,
                        quant_weight,
                        bias=self.bias,
                        padding=padding,
                        stride=self._stride,
                        dilation=self._dilation,
                        groups=self._groups,
                        data_format=self._data_format)
484 485


Z
zhiboniu 已提交
486
class QuantizedConv2DTranspose(Layer):
    """
    The computational logic of QuantizedConv2DTranspose is the same with Conv2DTranspose.
    The only difference is that its inputs are all fake quantized.

    Channel-wise weight quantization uses axis 1 of the weight.

    Examples:
        .. code-block:: python

            import paddle
            import paddle.nn as nn
            from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose
            x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
            conv = nn.Conv2DTranspose(4, 6, (3, 3))
            conv_quantized = QuantizedConv2DTranspose(conv)
            y_quantized = conv_quantized(x_var)
            y_var = conv(x_var)
            y_quantized_np = y_quantized.numpy()
            y_np = y_var.numpy()
            print(y_np.shape, y_quantized_np.shape)
            # (2, 6, 10, 10), (2, 6, 10, 10)
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        r"""
        Constructor.

        The arguments are the same as ImperativeQuantAware.
        """
        super(QuantizedConv2DTranspose, self).__init__()
        # Copy the transposed-convolution settings and parameters from the
        # wrapped layer.
        self._groups = layer._groups
        self._stride = layer._stride
        self._padding = layer._padding
        self._output_padding = layer.output_padding
        self._dilation = layer._dilation
        self._data_format = layer._data_format
        self.weight = layer.weight
        self.bias = layer.bias

        # Fake-quant layers: a user-supplied factory wins over the built-in
        # quantize types.
        self._conv2d_transpose_quant_axis = 1
        if weight_quant_layer is None:
            axis = self._conv2d_transpose_quant_axis
            self._fake_quant_weight = _get_fake_quant_type(
                weight_quantize_type,
                name=self.weight.name,
                moving_rate=moving_rate,
                quant_bits=weight_bits,
                dtype=self._dtype,
                quant_on_weight=True,
                channel_num=self.weight.shape[axis],
                quant_axis=axis)
        else:
            self._fake_quant_weight = weight_quant_layer()

        if act_quant_layer is None:
            self._fake_quant_input = _get_fake_quant_type(
                activation_quantize_type,
                name=layer.full_name(),
                moving_rate=moving_rate,
                quant_bits=activation_bits,
                dtype=self._dtype,
                quant_on_weight=False)
        else:
            self._fake_quant_input = act_quant_layer()

        self._act_preprocess = (None if act_pre_layer is None else
                                act_pre_layer())
        self._weight_preprocess = (None if weight_pre_layer is None else
                                   weight_pre_layer())

    def forward(self, input, output_size=None):
        act = (input if self._act_preprocess is None else
               self._act_preprocess(input))
        quant_input = self._fake_quant_input(act)

        weight = (self.weight if self._weight_preprocess is None else
                  self._weight_preprocess(self.weight))
        quant_weight = self._fake_quant_weight(weight)

        # An explicit output_size takes the place of output_padding.
        output_padding = self._output_padding if output_size is None else 0

        return F.conv2d_transpose(quant_input,
                                  quant_weight,
                                  bias=self.bias,
                                  padding=self._padding,
                                  output_padding=output_padding,
                                  stride=self._stride,
                                  dilation=self._dilation,
                                  groups=self._groups,
                                  output_size=output_size,
                                  data_format=self._data_format)
589 590


Z
zhiboniu 已提交
591
class QuantizedLinear(Layer):
    """
    The computational logic of QuantizedLinear is the same with Linear.
    The only difference is that its inputs are all fake quantized.

    Both the input activation and the wrapped layer's ``weight`` pass through
    fake quant-dequant layers before ``F.linear``. Channel-wise weight
    quantization uses axis 1 of the weight.
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        super(QuantizedLinear, self).__init__()
        # Reuse the wrapped Linear's parameters and name.
        self.weight = layer.weight
        self.bias = layer.bias
        self.name = layer.name
        # For FakeQuant
        self._linear_quant_axis = 1

        if weight_quant_layer is None:
            axis = self._linear_quant_axis
            self._fake_quant_weight = _get_fake_quant_type(
                weight_quantize_type,
                name=self.weight.name,
                moving_rate=moving_rate,
                quant_bits=weight_bits,
                dtype=self._dtype,
                quant_on_weight=True,
                channel_num=self.weight.shape[axis],
                quant_axis=axis)
        else:
            self._fake_quant_weight = weight_quant_layer()

        if act_quant_layer is None:
            self._fake_quant_input = _get_fake_quant_type(
                activation_quantize_type,
                name=layer.full_name(),
                moving_rate=moving_rate,
                quant_bits=activation_bits,
                dtype=self._dtype,
                quant_on_weight=False)
        else:
            self._fake_quant_input = act_quant_layer()

        self._act_preprocess = (None if act_pre_layer is None else
                                act_pre_layer())
        self._weight_preprocess = (None if weight_pre_layer is None else
                                   weight_pre_layer())

    def forward(self, input):
        act = (input if self._act_preprocess is None else
               self._act_preprocess(input))
        quant_input = self._fake_quant_input(act)

        weight = (self.weight if self._weight_preprocess is None else
                  self._weight_preprocess(self.weight))
        quant_weight = self._fake_quant_weight(weight)

        return F.linear(x=quant_input,
                        weight=quant_weight,
                        bias=self.bias,
                        name=self.name)
660 661


Z
zhiboniu 已提交
662
class MAOutputScaleLayer(Layer):
    """
    Wrap a layer and pass its output through a MovingAverageAbsMaxScale
    layer, which calculates the scale (moving average abs max) of that output.

    List, tuple and dict outputs are returned unchanged, without scale
    tracking.
    """

    def __init__(self, layer=None, moving_rate=0.9, name=None, dtype='float32'):
        r"""
        Construct the wrapper. When ``name`` is None, ``layer.full_name()`` is
        used to name the scale statistics.
        """
        super(MAOutputScaleLayer, self).__init__()
        self._layer = layer
        scale_name = layer.full_name() if name is None else name
        self._ma_output_scale = MovingAverageAbsMaxScale(
            scale_name, moving_rate, dtype)

    def forward(self, *inputs, **kwargs):
        out = self._layer(*inputs, **kwargs)
        # TODO (jc): support the ops of several outputs
        if isinstance(out, (list, tuple, dict)):
            return out
        return self._ma_output_scale(out)
687

688

Z
zhiboniu 已提交
689
class FakeQuantMAOutputScaleLayer(Layer):
    """
    Add FakeQuantMovingAverageAbsMax layer to the behind of the input layer.

    The wrapped layer's output is fake quant-dequantized with a
    ``'moving_average_abs_max'`` layer using ``activation_bits``.

    ``weight_bits``, ``*args`` and ``**kwargs`` are accepted for interface
    compatibility but are not used.
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 name=None,
                 *args,
                 **kwargs):

        super(FakeQuantMAOutputScaleLayer, self).__init__()
        self._layer = layer
        self._fake_quant_output = _get_fake_quant_type(
            'moving_average_abs_max',
            name=layer.full_name() if name is None else name,
            moving_rate=moving_rate,
            quant_bits=activation_bits,
            dtype=self._dtype,
            quant_on_weight=False)

    def forward(self, *inputs, **kwargs):
        out = self._layer(*inputs, **kwargs)
        # TODO (jc): support the ops of several outputs
        if isinstance(out, (list, tuple)):
            if len(out) != 1:
                # Multi-output (and empty) results are passed through as-is.
                return out
            # A single output wrapped in a sequence: quantize the element
            # instead of handing the container to the fake-quant op (which
            # cannot accept it), and keep the list/tuple wrapping.
            quant = self._fake_quant_output(out[0])
            return [quant] if isinstance(out, list) else (quant,)
        if isinstance(out, dict):
            # Dict outputs are not supported by the fake-quant op; pass them
            # through, as MAOutputScaleLayer does.
            return out
        return self._fake_quant_output(out)
720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746


def _get_fake_quant_type(quant_type, **kwargs):
    call_args = {
        "name": kwargs.get("name", None),
        "quant_bits": kwargs.get("quant_bits", 8),
        "dtype": kwargs.get("dtype", "float32")
    }

    if quant_type == 'abs_max':
        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
    elif quant_type == 'moving_average_abs_max':
        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
    elif quant_type == 'channel_wise_abs_max':
        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
        call_args["channel_num"] = kwargs.get("channel_num", None)
        call_args["quant_axis"] = kwargs.get("quant_axis", 0)
        assert call_args["channel_num"] is not None, (
            "You need to input channel_num"
            "when you use channel_wise_abs_max strategy.")
    fake_quant_map = {
        'abs_max': FakeQuantAbsMax,
        'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
        'channel_wise_abs_max': FakeQuantChannelWiseAbsMax
    }

    return fake_quant_map[quant_type](**call_args)