adaround.py 12.7 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19
#   Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np
import time
import sys
import logging

20
import paddle
21
import paddle.fluid as fluid
22
import paddle
23 24

from ....log_helper import get_logger
25 26 27 28 29 30 31 32 33 34 35 36 37 38
from .utils import (
    load_variable_data,
    set_variable_data,
    stable_sigmoid,
    quant_tensor,
    dequant_tensor,
    _channelwise_quant_axis1_ops,
    calculate_quant_cos_error,
    bias_correction_w,
)

_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s'
)

# Stretch bounds of the rectified sigmoid used by ``compute_soft_rounding`` /
# ``compute_soft_rounding_np``: sigmoid(alpha) in (0, 1) is mapped linearly to
# (GAMMA, ZETA) and then clipped to [0, 1], so exact 0s and 1s are reachable.
GAMMA = -0.1
ZETA = 1.1


def compute_soft_rounding(alpha_v):
    """Map the AdaRound parameter to soft rounding offsets (Paddle graph op).

    Computes ``clip(sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0, 1)`` with
    Paddle ops, so the result stays part of the static graph and can be
    optimized. ``compute_soft_rounding_np`` is the NumPy counterpart.

    Args:
        alpha_v: Paddle variable holding the learnable rounding parameter.

    Returns:
        An element-wise transformed Paddle variable with values in [0, 1].
    """
    # Stretch sigmoid's open range (0, 1) to (GAMMA, ZETA) so that the clip
    # below can produce exactly 0 or exactly 1.
    return fluid.layers.clip(
        paddle.nn.functional.sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA,
        min=0,
        max=1,
    )


def compute_soft_rounding_np(alpha_v):
    """NumPy counterpart of ``compute_soft_rounding``.

    Computes ``clip(stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA, 0, 1)``
    using ``stable_sigmoid`` from ``.utils``.

    Args:
        alpha_v: array of alpha values (e.g. loaded from the scope).

    Returns:
        Array of soft rounding offsets with values in [0, 1].
    """
    stretched = stable_sigmoid(alpha_v) * (ZETA - GAMMA) + GAMMA
    return np.clip(stretched, a_min=0, a_max=1)


58
class AdaRoundLoss:
    """Loss terms and beta schedule used to optimize the AdaRound parameter."""

    def __init__(self, reg_param=0.01, default_beta_range=(20, 2)):
        """
        Args:
            reg_param: multiplier applied to the rounding regularization term.
            default_beta_range: ``(start_beta, end_beta)``; ``compute_beta``
                anneals the regularizer exponent from start to end.
        """
        self.default_reg_param = reg_param
        self.default_beta_range = default_beta_range

    def compute_recon_loss(self, ada_quantized_output, orig_output):
        """Reconstruction loss between adarounded and original outputs.

        The squared error is summed over the last axis and then averaged.

        Args:
            ada_quantized_output: output computed with adarounded weights.
            orig_output: reference output.

        Returns:
            Scalar Paddle variable.
        """
        square_cost = fluid.layers.square_error_cost(
            ada_quantized_output, orig_output
        )
        recon_loss = fluid.layers.reduce_mean(
            fluid.layers.reduce_sum(square_cost, dim=-1)
        )
        return recon_loss

    def compute_round_loss(self, alpha_v, warm_start, beta):
        """Rounding regularization loss, disabled during warm start.

        Args:
            alpha_v: the learnable rounding parameter.
            warm_start: boolean Paddle variable; when true the loss is 0.
            beta: exponent of the regularizer (see ``compute_beta``).

        Returns:
            Paddle variable: ``0`` while warm-starting, otherwise
            ``reg_param * sum(1 - |2 * h(alpha_v) - 1| ** beta)``.
        """

        def round_loss_fn():
            # rectified sigmoid of 'alpha', mapped into [0, 1]
            h_v = compute_soft_rounding(alpha_v)

            # 1 - |2h - 1|^beta is zero only when h is exactly 0 or 1 and
            # positive in between, so minimizing it drives h to 0 or 1.
            reg_term = fluid.layers.reduce_sum(
                -paddle.pow(paddle.abs(2 * h_v - 1), beta) + 1
            )
            return self.default_reg_param * reg_term

        return fluid.layers.cond(
            warm_start,
            lambda: fluid.layers.fill_constant(
                shape=[1], dtype='float32', value=0.0
            ),
            round_loss_fn,
        )

    def compute_beta(self, max_iter, cur_iter, warm_start):
        """Cosine-annealed regularizer exponent for iteration ``cur_iter``.

        Args:
            max_iter: total number of iterations.
            cur_iter: current iteration index.
            warm_start: fraction of ``max_iter`` spent in warm start.

        Returns:
            ``start_beta`` at the end of warm start, decaying along a cosine
            to ``end_beta`` at ``max_iter``. When warm start covers every
            iteration there is no annealing phase and ``end_beta`` is used.
        """
        # Start and stop beta for annealing of rounding loss
        start_beta, end_beta = self.default_beta_range

        # iteration at which warm start ends (warm_start is a fraction of
        # max_iter)
        warm_start_end_iter = warm_start * max_iter
        anneal_span = max_iter - warm_start_end_iter

        if anneal_span > 0:
            # relative position of cur_iter within the annealing phase
            rel_iter = (cur_iter - warm_start_end_iter) / anneal_span
        else:
            # No annealing phase: the relative-iteration formula would divide
            # by zero here, so jump to the end of the schedule (end_beta).
            rel_iter = 1.0

        beta = end_beta + 0.5 * (start_beta - end_beta) * (
            1 + np.cos(rel_iter * np.pi)
        )
        return beta


117
class AdaRound:
    """AdaRound state and helpers for a single quantized weight variable.

    Holds the learnable rounding parameter ``alpha_v`` (same shape as the
    weight) and converts it into quantized / dequantized weights. The
    quantized weight is ``floor(quant_tensor(weight)) + h(alpha)``, where
    ``h`` is the rectified sigmoid of ``compute_soft_rounding_np``.
    """

    def __init__(
        self,
        scale,
        weight_tensor,
        scope=None,
        weight_var_name=None,
        weight_op_type=None,
        is_train=True,
        num_iterations=1000,
    ):
        """Create the alpha parameter for ``weight_var_name``.

        ``initialize_alpha`` calls ``fluid.layers.create_parameter``;
        ``run_adaround`` constructs this object inside ``fluid.program_guard``.

        Args:
            scale: quantization scale passed to ``quant_tensor`` /
                ``dequant_tensor``.
            weight_tensor: array with the original weight values
                (``.copy()`` is taken before it is transformed).
            scope: Paddle scope holding the weight and alpha variables.
            weight_var_name: name of the weight variable; the alpha
                parameter is named ``<weight_var_name>.alpha``. Must not be
                None (string concatenation below).
            weight_op_type: type of the op consuming the weight; ops in
                ``_channelwise_quant_axis1_ops`` are quantized along axis 1.
            is_train: stored; not read by this class.
            num_iterations: total iterations, used for the beta schedule.
        """
        self.is_train = is_train
        self.num_iterations = num_iterations
        # fraction of num_iterations during which the rounding loss is off
        self.warm_start = 0.1
        self.weight_bits = 8
        self.offset = 0.0  # zero-point offset
        self.adaround_loss = AdaRoundLoss()
        self.ori_weight_tensor = weight_tensor
        self.scale = scale
        self.scope = scope
        self.quant_axis = 0
        if weight_op_type in _channelwise_quant_axis1_ops:
            self.quant_axis = 1
        self.weight_var_name = weight_var_name
        self.alpha_name = weight_var_name + ".alpha"
        self.initialize_alpha(weight_tensor.copy(), scale, weight_var_name)

    def initialize_alpha(self, tensor, scale, var_name):
        """
        Initializes alpha parameter, same shape as the weight tensor.

        alpha is chosen as the inverse of the rectified sigmoid applied to
        the fractional part of the scaled weight, so that initially
        ``floor(scaled) + h(alpha)`` reproduces the unrounded scaled weight.
        """
        tensor_scale = quant_tensor(tensor, scale, quant_axis=self.quant_axis)
        tensor_floor = np.floor(tensor_scale)
        # fractional part in [0, 1)
        tensor = tensor_scale - tensor_floor
        # inverse of h(a) = sigmoid(a) * (ZETA - GAMMA) + GAMMA
        alpha = -np.log((ZETA - GAMMA) / (tensor - GAMMA) - 1)
        self.alpha_v = fluid.layers.create_parameter(
            shape=alpha.shape,
            dtype="float32",
            name=var_name + ".alpha",
            default_initializer=fluid.initializer.NumpyArrayInitializer(alpha),
        )

    def _calculate_output_with_adarounded_weights(
        self, program, place, exe, data, fp32_fetch_list, weight_tensor_dequant
    ):
        """Run ``program`` with ``weight_tensor_dequant`` as the weight.

        Side effect: the weight variable in ``self.scope`` keeps the value
        ``weight_tensor_dequant`` after this call.

        Returns:
            The list returned by ``exe.run`` for ``fp32_fetch_list``.
        """
        set_variable_data(
            self.scope, place, self.weight_var_name, weight_tensor_dequant
        )

        adaround_out_tensor = exe.run(
            program=program,
            feed=data,
            fetch_list=[fp32_fetch_list],
            return_numpy=True,
            scope=self.scope,
        )
        return adaround_out_tensor

    def _calculate_quant_weight(self):
        """Return ``floor(scaled weight) + h(alpha)`` using alpha from the scope.

        The result is in quantized units and is not hard-rounded: each
        element is offset by a soft value in [0, 1].
        """
        np_alpha = load_variable_data(self.scope, self.alpha_name)
        h_alpha = compute_soft_rounding_np(np_alpha)

        # Scale the tensor
        tensor_scale = quant_tensor(
            self.ori_weight_tensor.copy(),
            self.scale,
            quant_axis=self.quant_axis,
        )

        weight_tensor = np.floor(tensor_scale)

        # Adaround the tensor
        weight_tensor_quant = np.add(weight_tensor, h_alpha)
        return weight_tensor_quant

    def _calculate_adarounded_weights(self):
        """Dequantize the current adarounded weight back to float values."""
        weight_tensor_quant = self._calculate_quant_weight()

        # Dequantize the tensor
        weight_tensor_dequant = dequant_tensor(
            weight_tensor_quant + self.offset,
            self.scale,
            quant_axis=self.quant_axis,
        )
        return weight_tensor_dequant

    def update_final_weights(self):
        """Return the final quantized weight (see ``_calculate_quant_weight``).

        NOTE(review): the values are ``floor + h(alpha)`` without
        thresholding h at 0.5, so they are exact integers only once h has
        converged to 0/1; confirm downstream code rounds them.
        """
        weight_tensor_quant = self._calculate_quant_weight()
        return weight_tensor_quant

    def get_loss(self, beta, warm_start, adaround_out_tensor, orig_out_tensor):
        """Build the AdaRound losses.

        Returns:
            dict with keys ``'loss'`` (sum of the other two),
            ``'round_loss'`` and ``'recon_loss'``.
        """
        round_loss = self.adaround_loss.compute_round_loss(
            self.alpha_v, warm_start, beta
        )
        recon_loss = self.adaround_loss.compute_recon_loss(
            adaround_out_tensor, orig_out_tensor
        )
        loss = round_loss + recon_loss
        losses = {
            'loss': loss,
            'round_loss': round_loss,
            'recon_loss': recon_loss,
        }
        return losses

    def update_beta_warm(self, cur_iteration):
        """Return ``(beta, warm_start)`` for iteration ``cur_iteration``.

        ``warm_start`` is True while ``cur_iteration`` is below
        ``num_iterations * self.warm_start``.
        """
        warm_start = cur_iteration < self.num_iterations * self.warm_start
        beta = self.adaround_loss.compute_beta(
            self.num_iterations, cur_iteration, self.warm_start
        )
        return beta, warm_start


def _redirect_fetch_ops(program, old_name, new_name):
    """Rename the input of every ``fetch`` op in ``program`` from old to new.

    ``run_adaround`` re-points the program's fetch op at the output being
    reconstructed before running the program with that output in
    ``fetch_list`` (presumably so the existing fetch op matches it).

    Returns:
        True if ``program`` has at least one fetch op.
    """
    has_fetch_op = False
    for op in program.global_block().ops:
        if op.type == "fetch":
            op._rename_input(old_name, new_name)
            has_fetch_op = True
    return has_fetch_op


def _build_adaround_program(
    scale,
    weight_var_tensor,
    scope,
    weight_var_name,
    weight_op_type,
    num_iterations,
    out_shape,
    lr,
):
    """Build the static program that optimizes one weight's alpha.

    Returns:
        ``(adaround, train_program, startup_program, losses)`` where
        ``losses`` is the dict returned by ``AdaRound.get_loss``.
    """
    startup_program = fluid.Program()
    train_program = fluid.Program()
    with fluid.program_guard(train_program, startup_program):
        with fluid.unique_name.guard():
            # initialize adaround (creates the alpha parameter)
            adaround = AdaRound(
                scale,
                weight_var_tensor,
                scope=scope,
                weight_var_name=weight_var_name,
                weight_op_type=weight_op_type,
                num_iterations=num_iterations,
            )
            orig_out_tensor = fluid.data(
                name='orig_out_tensor',
                shape=out_shape,
                dtype='float32',
            )
            adaround_out_tensor = fluid.data(
                name='adaround_out_tensor',
                shape=out_shape,
                dtype='float32',
            )
            beta_tensor = fluid.data(name='beta', shape=[1], dtype='float32')
            warm_start_tensor = fluid.data(
                name='warm_start', shape=[1], dtype='bool'
            )

            losses = adaround.get_loss(
                beta_tensor,
                warm_start_tensor,
                adaround_out_tensor,
                orig_out_tensor,
            )
            # NOTE(review): both outputs are fed in as data, so recon_loss
            # has no gradient path to alpha; only round_loss updates it.
            optimizer = fluid.optimizer.Adam(learning_rate=lr)
            optimizer.minimize(losses['loss'])
    return adaround, train_program, startup_program, losses


def run_adaround(
    data_loader,
    fp32_program,
    fetch_list,
    exe,
    scope,
    place,
    quantized_op_pairs,
    weight_op_pairs,
    scale_dict,
    num_iterations=1000,
    lr=0.001,
    bias_correction=False,
    fast_mode=True,
):
    """Apply AdaRound to every weight in ``quantized_op_pairs``, one at a time.

    For each weight, calibration batches are run through ``fp32_program``
    with the original and with the adarounded weight, and a small training
    program updates that weight's alpha. The resulting quantized weights
    are written into ``scope`` at the end.

    Args:
        data_loader: callable returning an iterable of batches used as
            ``feed`` for ``fp32_program``.
        fp32_program: float program. Its fetch op is temporarily re-pointed
            at each quantized op's output and restored before returning.
        fetch_list: original fetch targets; ``fetch_list[0].name`` is the
            fetch op input that gets redirected.
        exe: executor used for every run.
        scope: scope holding weights and alpha parameters.
        place: place passed to ``set_variable_data``.
        quantized_op_pairs: ``{weight_var_name: quant_op_out_name}``.
        weight_op_pairs: ``{weight_var_name: op type consuming the weight}``.
        scale_dict: ``{weight_var_name: scale}``.
        num_iterations: the optimization loop for a weight stops after the
            batch with index ``num_iterations`` (at most
            ``num_iterations + 1`` batches).
        lr: Adam learning rate.
        bias_correction: if True, post-process weights with
            ``bias_correction_w``.
        fast_mode: if True, stop optimizing a weight once
            ``calculate_quant_cos_error`` of the two outputs exceeds 0.99.
    """
    orig_fetch_name = fetch_list[0].name
    fetch_op_name = orig_fetch_name
    final_weight_tensor_quant_dict = {}
    try:
        for weight_var_name, quant_op_out_name in quantized_op_pairs.items():
            _logger.info('Start adaround op: {}'.format(weight_var_name))
            weight_op_type = weight_op_pairs[weight_var_name]
            # get scale and the original weight values
            weight_var_tensor = load_variable_data(scope, weight_var_name)
            scale = scale_dict[weight_var_name]

            # fetch this weight's quantized op output from the fp32 program
            fp32_fetch_list = None
            if _redirect_fetch_ops(
                fp32_program, fetch_op_name, quant_op_out_name
            ):
                fp32_fetch_list = fp32_program.global_block().var(
                    quant_op_out_name
                )
                fetch_op_name = quant_op_out_name

            (
                adaround,
                train_program,
                startup_program,
                train_fetches_loss,
            ) = _build_adaround_program(
                scale,
                weight_var_tensor,
                scope,
                weight_var_name,
                weight_op_type,
                num_iterations,
                fp32_fetch_list.shape,
                lr,
            )
            exe.run(startup_program)

            start_time = time.time()
            for i, data in enumerate(data_loader()):
                prev_start_time = start_time
                start_time = time.time()

                # The previous iteration left the adarounded weight in the
                # scope; restore the original so this is really the fp32
                # reference output and not the previous adarounded one.
                set_variable_data(
                    scope, place, weight_var_name, weight_var_tensor
                )
                # run fp32 model
                np_orig_out_tensor = exe.run(
                    program=fp32_program,
                    feed=data,
                    fetch_list=[fp32_fetch_list],
                    return_numpy=True,
                    scope=scope,
                )

                # run with the current adarounded weight (left in the scope)
                adaround_weight_tensor_dequant = (
                    adaround._calculate_adarounded_weights()
                )
                np_adaround_out_tensor = (
                    adaround._calculate_output_with_adarounded_weights(
                        fp32_program,
                        place,
                        exe,
                        data,
                        fp32_fetch_list,
                        adaround_weight_tensor_dequant,
                    )
                )

                # Skip training once the two outputs are close enough.
                # NOTE(review): despite its name, cos_error is compared with
                # "> 0.99", which reads as a cosine similarity; confirm in
                # utils.calculate_quant_cos_error.
                cos_error = calculate_quant_cos_error(
                    np_orig_out_tensor[0], np_adaround_out_tensor[0]
                )
                if fast_mode and cos_error > 0.99:
                    _logger.info("The cosine error is small, skip training.")
                    break
                beta, warm_start = adaround.update_beta_warm(i)
                feed_dict = {
                    'orig_out_tensor': np_orig_out_tensor[0],
                    'adaround_out_tensor': np_adaround_out_tensor[0],
                    'beta': beta,
                    'warm_start': warm_start,
                }
                out = exe.run(
                    train_program,
                    feed=feed_dict,
                    fetch_list=[v.name for v in train_fetches_loss.values()],
                    return_numpy=True,
                )
                _logger.info(
                    "Iter {:d}, lr {:.5f}, loss {:.5f}, loss_round {:.5f}, loss_recon {:.5f}, time {:.5f}s".format(
                        i,
                        lr,
                        np.mean(out[0]),
                        np.mean(out[1]),
                        np.mean(out[2]),
                        start_time - prev_start_time,
                    )
                )
                sys.stdout.flush()
                if i == num_iterations:
                    break

            final_weight_tensor_quant_dict[
                weight_var_name
            ] = adaround.update_final_weights()

            if bias_correction:
                final_weight_tensor_quant_dict[
                    weight_var_name
                ] = bias_correction_w(
                    weight_var_tensor,
                    final_weight_tensor_quant_dict[weight_var_name],
                    scale,
                    adaround.quant_axis,
                    weight_bits=adaround.weight_bits,
                )

            del adaround
    finally:
        # Point the fetch op back at the caller's original fetch target so
        # fp32_program is not left fetching the last quantized op output.
        _redirect_fetch_ops(fp32_program, fetch_op_name, orig_fetch_name)

    # update adarounded calibrated weights
    for weight_var_name in quantized_op_pairs:
        set_variable_data(
            scope,
            place,
            weight_var_name,
            final_weight_tensor_quant_dict[weight_var_name],
        )