#   Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.framework import core
from paddle.fluid import dygraph_utils
from paddle.utils import unique_name
from paddle.framework import ParamAttr
from paddle.fluid.framework import _varbase_creator
from paddle.nn.initializer import Constant
from paddle.fluid.data_feeder import check_variable_and_dtype
from paddle.nn import functional as F
import logging
from paddle.fluid.log_helper import get_logger
from paddle import _C_ops, _legacy_C_ops
from paddle import in_dynamic_mode
from paddle.nn import Layer

__all__ = [
    'FakeQuantAbsMax',
    'FakeQuantMovingAverageAbsMax',
    'FakeQuantChannelWiseAbsMax',
    'QuantizedConv2D',
    'QuantizedConv2DTranspose',
    'QuantizedLinear',
    'MovingAverageAbsMaxScale',
    'MAOutputScaleLayer',
    'FakeQuantMAOutputScaleLayer',
    'QuantStub',
    'QuantizedRowParallelLinear',
    'QuantizedColumnParallelLinear',
]

_logger = get_logger(__name__,
                     logging.INFO,
                     fmt='%(asctime)s-%(levelname)s: %(message)s')


class FakeQuantAbsMax(Layer):
    r"""
    FakeQuantAbsMax layer does the abs_max quant and then dequant.
    Its computational formula is described as below:

    :math:`scale = max(abs(X))`
    :math:`range = 2^{bit\_length - 1} - 1`
    :math:`Out = round(X / scale * range) * scale / range`
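
    Examples (a minimal usage sketch; the tensor shape and layer name are illustrative):
       .. code-block:: python

          import paddle
          from paddle.nn.quant.quant_layers import FakeQuantAbsMax

          x = paddle.uniform((2, 3), dtype='float32', min=-1., max=1.)
          fake_quant = FakeQuantAbsMax(name='fq_abs_max', quant_bits=8)
          # out has the same shape as x; values are quantized to 8 bits
          # and dequantized back in one step
          out = fake_quant(x)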
    """

    def __init__(self,
                 name=None,
                 quant_bits=8,
                 dtype='float32',
                 quant_on_weight=False,
                 reduce_type=None):
        super(FakeQuantAbsMax, self).__init__()
        self._quant_bits = quant_bits
        self._name = name
        self._reduce_type = reduce_type
        scale_prefix = "{}.scale".format(
            name) if name else 'quant_dequant.scale'
        self._scale_name = unique_name.generate(scale_prefix)
        if quant_on_weight:
            scale_attr = ParamAttr(name=self._scale_name,
                                   initializer=Constant(0.001),
                                   trainable=False)
            self._scale = self.create_parameter(shape=[1],
                                                attr=scale_attr,
                                                dtype=self._dtype)
            self._scale.stop_gradient = True
        else:
            self._scale = None

    def forward(self, input):
        if in_dynamic_mode():
            attrs = ('bit_length', self._quant_bits)
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.quantized.dequantized".format(
                                             input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)
            out_scale = self._scale
            if self._reduce_type == "max":
                paddle.distributed.all_reduce(
                    out_scale, op=paddle.distributed.ReduceOp.MAX)

            if not out_scale:
                out_scale = _varbase_creator(
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    name=self._scale_name,
                    shape=[1],
                    dtype=self._dtype,
                    persistable=False)
                out_scale.stop_gradient = True
            out, _, = _legacy_C_ops.fake_quantize_dequantize_abs_max(
                input, quant_out, out_scale, *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32'], "FakeQuantAbsMax")
        attrs = {'bit_length': self._quant_bits}
        inputs = {"X": [input]}
        quant_out = self._helper.create_variable(
            name="{}.quantized.dequantized".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)
        out_scale = self._scale
        if not out_scale:
            out_scale = self._helper.create_variable(
                name=self._scale_name,
                dtype=self._dtype,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=True)
        outputs = {"Out": [quant_out], "OutScale": [out_scale]}

        self._helper.append_op(type="fake_quantize_dequantize_abs_max",
                               inputs=inputs,
                               outputs=outputs,
                               attrs=attrs)

        return quant_out


class FakeQuantMovingAverageAbsMax(Layer):
    r"""
    FakeQuantMovingAverageAbsMax layer does the moving_average_abs_max quant and then dequant.
    Its computational formula is described as below:

    :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
    :math:`range = 2^{bit\_length - 1} - 1`
    :math:`Out = round(X / scale * range) * scale / range`
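
    Examples (a minimal usage sketch; the tensor shape and layer name are illustrative):
       .. code-block:: python

          import paddle
          from paddle.nn.quant.quant_layers import FakeQuantMovingAverageAbsMax

          x = paddle.uniform((2, 3), dtype='float32', min=-1., max=1.)
          fake_quant = FakeQuantMovingAverageAbsMax(name='fq_ma_abs_max')
          # out has the same shape as x; the scale is tracked as a moving
          # average over the calls made in training mode
          out = fake_quant(x)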
    """

    def __init__(self,
                 name=None,
                 moving_rate=0.9,
                 quant_bits=8,
                 dtype='float32',
                 reduce_type=None):
        super(FakeQuantMovingAverageAbsMax, self).__init__()
        self._moving_rate = moving_rate
        self._quant_bits = quant_bits
        self._reduce_type = reduce_type
        scale_prefix = "{}.scale".format(
            name) if name else 'quant_dequant.scale'
        scale_attr = ParamAttr(name=unique_name.generate(scale_prefix),
                               initializer=Constant(0.001),
                               trainable=False)
        self._scale = self.create_parameter(shape=[1],
                                            attr=scale_attr,
                                            dtype=dtype)
        self._scale.stop_gradient = True

        state_prefix = "{}.state".format(
            name) if name else 'quant_dequant.state'
        state_attr = ParamAttr(name=unique_name.generate(state_prefix),
                               initializer=Constant(1),
                               trainable=False)
        self._state = self.create_parameter(shape=[1],
                                            attr=state_attr,
                                            dtype=dtype)
        self._state.stop_gradient = True

        accum_prefix = "{}.accum".format(
            name) if name else 'quant_dequant.accum'
        accum_attr = ParamAttr(name=unique_name.generate(accum_prefix),
                               initializer=Constant(1),
                               trainable=False)
        self._accum = self.create_parameter(shape=[1],
                                            attr=accum_attr,
                                            dtype=dtype)
        self._accum.stop_gradient = True

    def forward(self, input):
        if in_dynamic_mode():
            attrs = ('moving_rate', self._moving_rate, 'bit_length',
                     self._quant_bits, 'is_test', not self.training)
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.quantized.dequantized".format(
                                             input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)
            if self._reduce_type == "max":
                paddle.distributed.all_reduce(
                    self._scale, op=paddle.distributed.ReduceOp.MAX)

            state = self._state if self.training else None
            accum = self._accum if self.training else None

            out, _, _, _ = _legacy_C_ops.fake_quantize_dequantize_moving_average_abs_max(
                input, self._scale, accum, state, quant_out, self._scale, state,
                accum, *attrs)

            return out

        check_variable_and_dtype(input, 'input', ['float32'],
                                 "FakeQuantMovingAverageAbsMax")
        attrs = {
            'moving_rate': self._moving_rate,
            'bit_length': self._quant_bits,
            'is_test': not self.training
        }
        inputs = {"X": [input], "InScale": [self._scale]}
        quant_out = self._helper.create_variable(
            name="{}.quantized.dequantized".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)
        outputs = {"Out": [quant_out], "OutScale": [self._scale]}

        if self.training:
            inputs['InState'] = [self._state]
            inputs['InAccum'] = [self._accum]
            outputs['OutState'] = [self._state]
            outputs['OutAccum'] = [self._accum]

        self._helper.append_op(
            type="fake_quantize_dequantize_moving_average_abs_max",
            inputs=inputs,
            outputs=outputs,
            attrs=attrs)

        return quant_out


class FakeQuantChannelWiseAbsMax(Layer):
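    r"""
    FakeQuantChannelWiseAbsMax layer does channel-wise abs_max quant and then dequant,
    computing one scale per channel along ``quant_axis``. It can only be used for
    weight quantization (``quant_on_weight`` must be True).
    """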

    def __init__(self,
                 name=None,
                 channel_num=None,
                 quant_bits=8,
                 quant_axis=0,
                 dtype='float32',
                 quant_on_weight=False,
                 reduce_type=None):
        assert quant_on_weight == True, "Channel-wise quantization can only be used on weights."
        super(FakeQuantChannelWiseAbsMax, self).__init__()
        self._quant_bits = quant_bits
        self._quant_axis = quant_axis
        self._dtype = dtype
        self._name = name
        self._channel_num = channel_num
        self._reduce_type = reduce_type
        scale_prefix = "{}.scale".format(
            name) if name else 'quant_dequant.scale'
        self._scale_name = unique_name.generate(scale_prefix)
        if quant_on_weight:
            scale_attr = ParamAttr(name=self._scale_name,
                                   initializer=Constant(0.0),
                                   trainable=False)
            self._scale = self.create_parameter(shape=[self._channel_num],
                                                attr=scale_attr,
                                                dtype=self._dtype)
            self._scale.stop_gradient = True
        else:
            self._scale = None

    def forward(self, input):
        if in_dynamic_mode():
            attrs = ('bit_length', self._quant_bits, 'quant_axis',
                     self._quant_axis)
            quant_out = _varbase_creator(type=input.type,
                                         name="{}.quantized.dequantized".format(
                                             input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)

            out_scale = self._scale
            if self._reduce_type == "max":
                paddle.distributed.all_reduce(
                    out_scale, op=paddle.distributed.ReduceOp.MAX)
            if out_scale is None:
                out_scale = _varbase_creator(
                    type=core.VarDesc.VarType.LOD_TENSOR,
                    name=self._scale_name,
                    shape=[self._channel_num],
                    dtype=self._dtype,
                    persistable=False)
                out_scale.stop_gradient = True

            out, _, = _legacy_C_ops.fake_channel_wise_quantize_dequantize_abs_max(
                input, quant_out, out_scale, *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32'],
                                 "FakeQuantChannelWiseAbsMax")
        attrs = {'bit_length': self._quant_bits, 'quant_axis': self._quant_axis}
        inputs = {"X": [input]}
        quant_out = self._helper.create_variable(
            name="{}.quantized.dequantized".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)
        out_scale = self._scale
        if not out_scale:
            out_scale = self._helper.create_variable(
                name=self._scale_name,
                dtype=self._dtype,
                type=core.VarDesc.VarType.LOD_TENSOR,
                persistable=False,
                stop_gradient=True)
        outputs = {"Out": [quant_out], "OutScale": [out_scale]}

        self._helper.append_op(
            type="fake_channel_wise_quantize_dequantize_abs_max",
            inputs=inputs,
            outputs=outputs,
            attrs=attrs)

        return quant_out


class MovingAverageAbsMaxScale(Layer):

    def __init__(self,
                 name=None,
                 moving_rate=0.9,
                 dtype='float32',
                 reduce_type=None):
        r"""
        MovingAverageAbsMaxScale layer is used to calculate the output quantization
        scale of a Layer. Its computational formula is described as below:

        :math:`scale = (moving\_rate*accum+max(abs(x)))/(moving\_rate*state+1)`
        :math:`Out = X`
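
        Examples (a minimal usage sketch; the tensor and layer name are illustrative):
           .. code-block:: python

              import paddle
              from paddle.nn.quant.quant_layers import MovingAverageAbsMaxScale

              x = paddle.uniform((2, 3), dtype='float32', min=-1., max=1.)
              out_scale = MovingAverageAbsMaxScale(name='out_scale')
              # y equals x; the layer only records a moving average of max(abs(x))
              y = out_scale(x)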
        """
        super(MovingAverageAbsMaxScale, self).__init__()
        self._moving_rate = moving_rate
        self._reduce_type = reduce_type
        scale_prefix = '{}.scale'.format(name) if name else 'outscale.scale'
        scale_name = unique_name.generate(scale_prefix)
        scale_attr = ParamAttr(name=scale_name,
                               initializer=Constant(0),
                               trainable=False)
        self._scale = self.create_parameter(shape=[1],
                                            attr=scale_attr,
                                            dtype=dtype)
        self._scale.stop_gradient = True

        state_prefix = "{}.state".format(name) if name else 'outscale.state'
        state_attr = ParamAttr(name=unique_name.generate(state_prefix),
                               initializer=Constant(0),
                               trainable=False)
        self._state = self.create_parameter(shape=[1],
                                            attr=state_attr,
                                            dtype=dtype)
        self._state.stop_gradient = True

        accum_prefix = "{}.accum".format(name) if name else 'outscale.accum'
        accum_attr = ParamAttr(name=unique_name.generate(accum_prefix),
                               initializer=Constant(0),
                               trainable=False)
        self._accum = self.create_parameter(shape=[1],
                                            attr=accum_attr,
                                            dtype=dtype)
        self._accum.stop_gradient = True

    def forward(self, input):
        if in_dynamic_mode():
            attrs = ('moving_rate', self._moving_rate, 'is_test',
                     not self.training)

            quant_out = _varbase_creator(type=input.type,
                                         name="{}.tmp".format(input.name),
                                         shape=input.shape,
                                         dtype=input.dtype,
                                         persistable=False)
            if self._reduce_type == "max":
                paddle.distributed.all_reduce(
                    self._scale, op=paddle.distributed.ReduceOp.MAX)

            state = self._state if self.training else None
            accum = self._accum if self.training else None

            out, _, _, _ = _legacy_C_ops.moving_average_abs_max_scale(
                input, accum, state, quant_out, self._scale, state, accum,
                *attrs)
            return out

        check_variable_and_dtype(input, 'input', ['float32', 'float64'],
                                 'MovingAverageAbsMaxScale')

        attrs = {'moving_rate': self._moving_rate, 'is_test': not self.training}
        inputs = {"X": [input]}
        quant_out = self._helper.create_variable(
            name="{}.tmp".format(input.name),
            dtype=input.dtype,
            type=core.VarDesc.VarType.LOD_TENSOR,
            persistable=False,
            stop_gradient=False)
        outputs = {"Out": [quant_out], "OutScale": [self._scale]}

        if self.training:
            inputs['InState'] = [self._state]
            inputs['InAccum'] = [self._accum]
            outputs['OutState'] = [self._state]
            outputs['OutAccum'] = [self._accum]

        self._helper.append_op(type="moving_average_abs_max_scale",
                               inputs=inputs,
                               outputs=outputs,
                               attrs=attrs)

        return quant_out


QuantStub = MovingAverageAbsMaxScale


class QuantizedConv2D(Layer):
    """
    The computational logic of QuantizedConv2D is the same as that of Conv2D.
    The only difference is that its inputs are all fake quantized.
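
    Examples (a minimal usage sketch mirroring the Conv2DTranspose example below;
    shapes are illustrative):
       .. code-block:: python

          import paddle
          import paddle.nn as nn
          from paddle.nn.quant.quant_layers import QuantizedConv2D

          x = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
          conv = nn.Conv2D(4, 6, (3, 3))
          quant_conv = QuantizedConv2D(conv)
          # same shape as conv(x), i.e. [2, 6, 6, 6], but computed on
          # fake-quantized input and weight
          y = quant_conv(x)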
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        super(QuantizedConv2D, self).__init__()
        # For Conv2D
        self._groups = getattr(layer, '_groups')
        self._stride = getattr(layer, '_stride')
        self._padding = getattr(layer, '_padding')
        self._padding_mode = getattr(layer, '_padding_mode')
        if self._padding_mode != 'zeros':
            self._reversed_padding_repeated_twice = getattr(
                layer, '_reversed_padding_repeated_twice')
        self._dilation = getattr(layer, '_dilation')
        self._data_format = getattr(layer, '_data_format')
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')

        # For FakeQuant
        self._conv2d_quant_axis = 0
        if weight_quant_layer is not None:
            self._fake_quant_weight = weight_quant_layer()
        else:
            self._fake_quant_weight = _get_fake_quant_type(
                weight_quantize_type,
                name=self.weight.name,
                moving_rate=moving_rate,
                quant_bits=weight_bits,
                dtype=self._dtype,
                quant_on_weight=True,
                channel_num=self.weight.shape[self._conv2d_quant_axis],
                quant_axis=self._conv2d_quant_axis)
        if act_quant_layer is not None:
            self._fake_quant_input = act_quant_layer()
        else:
            self._fake_quant_input = _get_fake_quant_type(
                activation_quantize_type,
                name=layer.full_name(),
                moving_rate=moving_rate,
                quant_bits=activation_bits,
                dtype=self._dtype,
                quant_on_weight=False)

        self._act_preprocess = act_pre_layer(
        ) if act_pre_layer is not None else None
        self._weight_preprocess = weight_pre_layer(
        ) if weight_pre_layer is not None else None

    def forward(self, input):
        if self._act_preprocess is not None:
            input = self._act_preprocess(input)
        quant_input = self._fake_quant_input(input)

        weight = self.weight
        if self._weight_preprocess is not None:
            weight = self._weight_preprocess(self.weight)
        quant_weight = self._fake_quant_weight(weight)

        if self._padding_mode != 'zeros':
            quant_input = F.pad(quant_input,
                                self._reversed_padding_repeated_twice,
                                mode=self._padding_mode,
                                data_format=self._data_format)
            self._padding = 0

        return F.conv2d(quant_input,
                        quant_weight,
                        bias=self.bias,
                        padding=self._padding,
                        stride=self._stride,
                        dilation=self._dilation,
                        groups=self._groups,
                        data_format=self._data_format)


class QuantizedConv2DTranspose(Layer):
    """

    The computational logic of QuantizedConv2DTranspose is the same as that of Conv2DTranspose.
    The only difference is that its inputs are all fake quantized.
    
    Examples:
       .. code-block:: python

          import paddle
          import paddle.nn as nn
          from paddle.nn.quant.quant_layers import QuantizedConv2DTranspose

          x_var = paddle.uniform((2, 4, 8, 8), dtype='float32', min=-1., max=1.)
          conv = nn.Conv2DTranspose(4, 6, (3, 3))
          conv_quantized = QuantizedConv2DTranspose(conv)
          y_quantized = conv_quantized(x_var)
          y_var = conv(x_var)
          y_quantized_np = y_quantized.numpy()
          y_np = y_var.numpy()
          print(y_np.shape, y_quantized_np.shape)
          # (2, 6, 10, 10), (2, 6, 10, 10)

    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        r"""
        Constructor.

        The arguments are the same as ImperativeQuantAware.
        """
        super(QuantizedConv2DTranspose, self).__init__()
        # For Conv2DTranspose
        self._groups = getattr(layer, '_groups')
        self._stride = getattr(layer, '_stride')
        self._padding = getattr(layer, '_padding')
        self._output_padding = getattr(layer, 'output_padding')
        self._dilation = getattr(layer, '_dilation')
        self._data_format = getattr(layer, '_data_format')
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')
        # For FakeQuant
        self._conv2d_transpose_quant_axis = 1
        if weight_quant_layer is not None:
            self._fake_quant_weight = weight_quant_layer()
        else:
            self._fake_quant_weight = _get_fake_quant_type(
                weight_quantize_type,
                name=self.weight.name,
                moving_rate=moving_rate,
                quant_bits=weight_bits,
                dtype=self._dtype,
                quant_on_weight=True,
                channel_num=self.weight.shape[
                    self._conv2d_transpose_quant_axis],
                quant_axis=self._conv2d_transpose_quant_axis)
        if act_quant_layer is not None:
            self._fake_quant_input = act_quant_layer()
        else:
            self._fake_quant_input = _get_fake_quant_type(
                activation_quantize_type,
                name=layer.full_name(),
                moving_rate=moving_rate,
                quant_bits=activation_bits,
                dtype=self._dtype,
                quant_on_weight=False)

        self._act_preprocess = act_pre_layer(
        ) if act_pre_layer is not None else None
        self._weight_preprocess = weight_pre_layer(
        ) if weight_pre_layer is not None else None

    def forward(self, input, output_size=None):
        if self._act_preprocess is not None:
            input = self._act_preprocess(input)
        quant_input = self._fake_quant_input(input)

        weight = self.weight
        if self._weight_preprocess is not None:
            weight = self._weight_preprocess(self.weight)
        quant_weight = self._fake_quant_weight(weight)

        if output_size is None:
            output_padding = self._output_padding
        else:
            output_padding = 0

        return F.conv2d_transpose(quant_input,
                                  quant_weight,
                                  bias=self.bias,
                                  padding=self._padding,
                                  output_padding=output_padding,
                                  stride=self._stride,
                                  dilation=self._dilation,
                                  groups=self._groups,
                                  output_size=output_size,
                                  data_format=self._data_format)


class QuantizedLinear(Layer):
    """
    The computational logic of QuantizedLinear is the same as that of Linear.
    The only difference is that its inputs are all fake quantized.
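
    Examples (a minimal usage sketch; layer sizes are illustrative):
       .. code-block:: python

          import paddle
          import paddle.nn as nn
          from paddle.nn.quant.quant_layers import QuantizedLinear

          x = paddle.uniform((2, 4), dtype='float32', min=-1., max=1.)
          fc = nn.Linear(4, 6)
          quant_fc = QuantizedLinear(fc)
          # same shape as fc(x), i.e. [2, 6], computed with fake-quantized
          # input and weight
          y = quant_fc(x)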
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        super(QuantizedLinear, self).__init__()
        # For Linear
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')
        self.name = getattr(layer, 'name')
        # For FakeQuant
        self._linear_quant_axis = 1

        if weight_quant_layer is not None:
            self._fake_quant_weight = weight_quant_layer()
        else:
            self._fake_quant_weight = _get_fake_quant_type(
                weight_quantize_type,
                name=self.weight.name,
                moving_rate=moving_rate,
                quant_bits=weight_bits,
                dtype=self._dtype,
                quant_on_weight=True,
                channel_num=self.weight.shape[self._linear_quant_axis],
                quant_axis=self._linear_quant_axis)

        if act_quant_layer is not None:
            self._fake_quant_input = act_quant_layer()
        else:
            self._fake_quant_input = _get_fake_quant_type(
                activation_quantize_type,
                name=layer.full_name(),
                moving_rate=moving_rate,
                quant_bits=activation_bits,
                dtype=self._dtype,
                quant_on_weight=False)

        self._act_preprocess = act_pre_layer(
        ) if act_pre_layer is not None else None
        self._weight_preprocess = weight_pre_layer(
        ) if weight_pre_layer is not None else None

    def forward(self, input):
        if self._act_preprocess is not None:
            input = self._act_preprocess(input)
        quant_input = self._fake_quant_input(input)

        weight = self.weight
        if self._weight_preprocess is not None:
            weight = self._weight_preprocess(self.weight)
        quant_weight = self._fake_quant_weight(weight)

        out = F.linear(x=quant_input,
                       weight=quant_weight,
                       bias=self.bias,
                       name=self.name)
        return out


class QuantizedColumnParallelLinear(Layer):

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        super(QuantizedColumnParallelLinear, self).__init__()
        assert weight_quant_layer is None, "When quantizing ColumnParallelLinear, weight_quant_layer should be None."
        assert act_quant_layer is None, "When quantizing ColumnParallelLinear, act_quant_layer should be None."

        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')
        self.name = getattr(layer, '_name')
        # For FakeQuant
        self._linear_quant_axis = 1

        self.is_mp = getattr(layer, 'is_mp')
        self.model_parallel_group = getattr(layer, 'model_parallel_group')
        self.gather_output = getattr(layer, 'gather_output')

        self._fake_quant_weight = _get_fake_quant_type(
            weight_quantize_type,
            name=self.weight.name,
            moving_rate=moving_rate,
            quant_bits=weight_bits,
            dtype=self._dtype,
            quant_on_weight=True,
            channel_num=self.weight.shape[self._linear_quant_axis],
            quant_axis=self._linear_quant_axis,
            reduce_type='max'
            if paddle.distributed.get_world_size() > 1 else None)

        self._fake_quant_input = _get_fake_quant_type(
            activation_quantize_type,
            name=layer.full_name(),
            moving_rate=moving_rate,
            quant_bits=activation_bits,
            dtype=self._dtype,
            quant_on_weight=False,
            reduce_type=None)

        self._act_preprocess = act_pre_layer(
        ) if act_pre_layer is not None else None
        self._weight_preprocess = weight_pre_layer(
        ) if weight_pre_layer is not None else None

    def forward(self, input):
        if self.is_mp:
            input_parallel = paddle.distributed.collective._c_identity(
                input, group=self.model_parallel_group)
        else:
            input_parallel = input

        if self._act_preprocess is not None:
            input_parallel = self._act_preprocess(input_parallel)
        quant_input = self._fake_quant_input(input_parallel)

        weight = self.weight
        if self._weight_preprocess is not None:
            weight = self._weight_preprocess(self.weight)
        quant_weight = self._fake_quant_weight(weight)

        output_parallel = F.linear(x=quant_input,
                                   weight=quant_weight,
                                   bias=self.bias,
                                   name=self.name)

        if self.gather_output and self.is_mp:
            output = paddle.distributed.collective._c_concat(
                output_parallel, group=self.model_parallel_group)
        else:
            output = output_parallel
        return output


class QuantizedRowParallelLinear(Layer):

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 weight_pre_layer=None,
                 act_pre_layer=None,
                 weight_quant_layer=None,
                 act_quant_layer=None):
        super(QuantizedRowParallelLinear, self).__init__()
        assert weight_quant_layer is None, "When quantizing RowParallelLinear, weight_quant_layer should be None."
        assert act_quant_layer is None, "When quantizing RowParallelLinear, act_quant_layer should be None."

        # For Linear
        self.weight = getattr(layer, 'weight')
        self.bias = getattr(layer, 'bias')
        self.name = getattr(layer, '_name')
        # For FakeQuant
        self._linear_quant_axis = 1

        self.input_is_parallel = getattr(layer, 'input_is_parallel')
        self.is_mp = getattr(layer, 'is_mp')
        self.model_parallel_group = getattr(layer, 'model_parallel_group')

        self._fake_quant_weight = _get_fake_quant_type(
            weight_quantize_type,
            name=self.weight.name,
            moving_rate=moving_rate,
            quant_bits=weight_bits,
            dtype=self._dtype,
            quant_on_weight=True,
            channel_num=self.weight.shape[self._linear_quant_axis],
            quant_axis=self._linear_quant_axis,
            reduce_type='max'
            if paddle.distributed.get_world_size() > 1 else None)

        self._fake_quant_input = _get_fake_quant_type(
            activation_quantize_type,
            name=layer.full_name(),
            moving_rate=moving_rate,
            quant_bits=activation_bits,
            dtype=self._dtype,
            quant_on_weight=False,
            reduce_type='max'
            if paddle.distributed.get_world_size() > 1 else None)

        self._act_preprocess = act_pre_layer(
        ) if act_pre_layer is not None else None
        self._weight_preprocess = weight_pre_layer(
        ) if weight_pre_layer is not None else None

    def forward(self, input):
        if self.input_is_parallel or (not self.is_mp):
            input_parallel = input
        else:
            # split last dim
            input_parallel = paddle.distributed.collective._c_split(
                input, group=self.model_parallel_group)

        if self._act_preprocess is not None:
            input_parallel = self._act_preprocess(input_parallel)
        quant_input = self._fake_quant_input(input_parallel)

        weight = self.weight
        if self._weight_preprocess is not None:
            weight = self._weight_preprocess(self.weight)
        quant_weight = self._fake_quant_weight(weight)

        output_parallel = F.linear(x=quant_input,
                                   weight=quant_weight,
                                   name=self.name)
        if self.is_mp:
            output_ = paddle.distributed.collective._mp_allreduce(
                output_parallel,
                group=self.model_parallel_group,
                use_calc_stream=True,
                use_model_parallel=True)
        else:
            output_ = output_parallel
        output = output_ + self.bias if self.bias is not None else output_
        return output


class MAOutputScaleLayer(Layer):
    """
    Add a MovingAverageAbsMaxScale layer behind the wrapped layer.
    Calculate the scale (moving average abs max) for the output of the input layer.
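
    Examples (a minimal usage sketch; the wrapped layer is illustrative):
       .. code-block:: python

          import paddle
          import paddle.nn as nn
          from paddle.nn.quant.quant_layers import MAOutputScaleLayer

          x = paddle.uniform((2, 4), dtype='float32', min=-1., max=1.)
          fc_with_scale = MAOutputScaleLayer(nn.Linear(4, 6))
          # same result as the wrapped Linear; the output scale is tracked
          # by an internal MovingAverageAbsMaxScale layer
          y = fc_with_scale(x)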
    """

    def __init__(self,
                 layer=None,
                 moving_rate=0.9,
                 name=None,
                 dtype='float32',
                 reduce_type=None):
        r"""
        Constructor.
        """
        super(MAOutputScaleLayer, self).__init__()
        self._layer = layer
        if name is None:
            name = layer.full_name()
        self._ma_output_scale = \
            MovingAverageAbsMaxScale(name, moving_rate, dtype, reduce_type)

    def forward(self, *inputs, **kwargs):
        out = self._layer(*inputs, **kwargs)
        # TODO (jc): support the ops of several outputs
        if (isinstance(out, list) or isinstance(out, tuple)
                or isinstance(out, dict)):
            return out
        else:
            return self._ma_output_scale(out)


class FakeQuantMAOutputScaleLayer(Layer):
    """
    Add a FakeQuantMovingAverageAbsMax layer behind the wrapped layer to fake quantize its output.
    """

    def __init__(self,
                 layer,
                 weight_bits=8,
                 activation_bits=8,
                 moving_rate=0.9,
                 name=None,
                 reduce_type=None,
                 *args,
                 **kwargs):

        super(FakeQuantMAOutputScaleLayer, self).__init__()
        self._layer = layer
        self._fake_quant_output = _get_fake_quant_type(
            'moving_average_abs_max',
            name=layer.full_name() if name is None else name,
            moving_rate=moving_rate,
            quant_bits=activation_bits,
            dtype=self._dtype,
            quant_on_weight=False,
            reduce_type=reduce_type)

    def forward(self, *inputs, **kwargs):
        out = self._layer(*inputs, **kwargs)
        # TODO (jc): support the ops of several outputs
        if (isinstance(out, list) or isinstance(out, tuple)) and len(out) > 1:
            return out
        else:
            return self._fake_quant_output(out)


def _get_fake_quant_type(quant_type, **kwargs):
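    # Dispatch helper: build the fake-quant layer that corresponds to ``quant_type``
    # ('abs_max', 'moving_average_abs_max' or 'channel_wise_abs_max'), forwarding
    # only the keyword arguments that the selected layer accepts.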
    call_args = {
        "name": kwargs.get("name", None),
        "quant_bits": kwargs.get("quant_bits", 8),
        "dtype": kwargs.get("dtype", "float32"),
        "reduce_type": kwargs.get("reduce_type", None)
940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959
    }

    if quant_type == 'abs_max':
        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
    elif quant_type == 'moving_average_abs_max':
        call_args["moving_rate"] = kwargs.get("moving_rate", 0.9)
    elif quant_type == 'channel_wise_abs_max':
        call_args["quant_on_weight"] = kwargs.get("quant_on_weight", False)
        call_args["channel_num"] = kwargs.get("channel_num", None)
        call_args["quant_axis"] = kwargs.get("quant_axis", 0)
        assert call_args["channel_num"] is not None, (
            "You need to input channel_num"
            "when you use channel_wise_abs_max strategy.")
    fake_quant_map = {
        'abs_max': FakeQuantAbsMax,
        'moving_average_abs_max': FakeQuantMovingAverageAbsMax,
        'channel_wise_abs_max': FakeQuantChannelWiseAbsMax
    }

    return fake_quant_map[quant_type](**call_args)