# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle

from paddle.fluid import core, framework
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.contrib.slim.quantization import utils
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
    TensorDistributedAttribute,
)

from .pass_base import PassBase, register_pass

TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type


def _node_id(node):
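    """Uniquely identify an IrGraph node by its (graph_id, node_id) pair."""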
    return (node.node.graph_id(), node.node.id())


@register_pass("auto_parallel_quantization")
class QuantizationPass(PassBase):
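    """Rewrite a distributed training program into a quantization-aware
    training (QAT) program and complete the distributed attributes of the
    newly inserted quant ops and tensors.

    A minimal usage sketch (assuming the ``new_pass`` helper and
    ``PassContext`` exported by ``paddle.distributed.passes``, plus an
    auto-parallel ``dist_context`` and ``params_grads`` built elsewhere;
    the config keys below mirror the attrs read by this pass)::

        from paddle.distributed.passes import new_pass, PassContext

        config = {
            "dist_context": dist_context,
            "params_grads": params_grads,
            "weight_bits": 8,
            "activation_bits": 8,
            "not_quant_pattern": ["skip_quant"],
            "channel_wise_abs_max": True,
        }
        quant_pass = new_pass("auto_parallel_quantization", config)
        quant_pass.apply([main_program], [startup_program], PassContext())
    """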
    def __init__(self):
        super(QuantizationPass, self).__init__()
        self.set_attr("dist_context", None)
        self.set_attr("params_grads", None)

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if self.get_attr("params_grads") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):

        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        # TODO: scope and place will be removed,
        # because params should be initialized by the engine module.
        scope = paddle.static.global_scope()
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
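        # NOTE: the pass currently assumes a CUDA device; dev_id is the local
        # device index of this process in the parallel environment.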

        # 1. Convert the Program to an IrGraph; this pass is only for train mode
        main_graph = framework.IrGraph(
            core.Graph(main_program.desc), for_test=False
        )

        # 2. Prepare inputs
        transform_pass_ops = []
        quant_dequant_ops = []
        quantize_op_types = [
            'conv2d',
            'depthwise_conv2d',
            'mul',
            'matmul',
            'matmul_v2',
        ]
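        # Split the candidates by which pass supports them: op types with
        # weights go through QuantizationTransformPassV2, activation-only op
        # types through AddQuantDequantPassV2.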
        for op_type in quantize_op_types:
            if op_type in TRANSFORM_PASS_OP_TYPES:
                transform_pass_ops.append(op_type)
            elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
                quant_dequant_ops.append(op_type)

        weight_quantize_type = (
            "channel_wise_abs_max"
            if self.get_attr('channel_wise_abs_max')
            else "abs_max"
        )
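        # Channel-wise quantization keeps one scale per output channel of a
        # weight; "abs_max" uses a single scale for the whole tensor.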

        # 3. Add quant ops for ops which have parameters
        transform_pass = QuantizationTransformPassV2(
            scope=scope,
            place=place,
            weight_bits=self.get_attr('weight_bits'),
            activation_bits=self.get_attr('activation_bits'),
            skip_pattern=self.get_attr('not_quant_pattern'),
            activation_quantize_type="moving_average_abs_max",
            quantizable_op_type=transform_pass_ops,
            weight_quantize_type=weight_quantize_type,
            weight_quantize_func=None,
            act_quantize_func=None,
            weight_preprocess_func=None,
            act_preprocess_func=None,
            optimizer_func=None,
            executor=None,
        )
        transform_pass.apply(main_graph)

        # 4. Add quant ops for ops which don't have parameters
        quant_dequant_pass = AddQuantDequantPassV2(
            scope=scope,
            place=place,
            quant_bits=self.get_attr('activation_bits'),
            skip_pattern=self.get_attr('not_quant_pattern'),
            quantizable_op_type=quant_dequant_ops,
        )
        quant_dequant_pass.apply(main_graph)

        # 5. Gather quantization scale information for outputs
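        # OutScaleForTrainingPass inserts moving_average_abs_max_scale ops to
        # track the dynamic range of op outputs during training.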
        out_scale_training_pass = OutScaleForTrainingPass(
            scope=scope, place=place
        )
        out_scale_training_pass.apply(main_graph)

        # 6. Convert Graph back to Program
        quant_program = main_graph.to_program()

        # 7. Get new params_grads from quant_program
        new_params_grads = []
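        # Parameters removed by the quantization passes no longer appear in
        # the new program's global block; drop their (param, grad) pairs.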
        for param, grad in params_grads:
            if param.name not in quant_program.global_block().vars:
                continue

            new_param = quant_program.global_block().vars[param.name]
            new_grad = quant_program.global_block().vars[grad.name]
            new_params_grads.append((new_param, new_grad))

        # 8. Complete the distributed attributes
        # NOTE: hacky implementation, to be upgraded soon
        for ib, block in enumerate(quant_program.blocks):
            # recover origin ops' dist_attr and set quant ops' dist_attr
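            # The QAT passes insert extra quant ops, shifting op indices
            # relative to main_program; `qat_offset` counts those insertions
            # so `ip - qat_offset` points at the matching original op.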
            qat_offset = 0
            for ip, quant_op in enumerate(block.ops):
                quant_op_dist_attr = OperatorDistributedAttribute()

                if (
                    "quantize" in quant_op.type
                    or quant_op.type == "moving_average_abs_max_scale"
                ):
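                    # Newly inserted quant/dequant and scale ops have no
                    # counterpart in main_program, so derive their dist_attr
                    # from the op that produces or consumes their input.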

                    input_name = quant_op.desc.input('X')[0]
                    if "quantize" in input_name:
                        input_name = input_name[
                            : input_name.index(".quantized")
                        ]

                    if quant_op.type == "moving_average_abs_max_scale":
                        consume_op = main_program.blocks[ib].vars[input_name].op
                    else:
                        consume_op = main_program.blocks[ib].ops[
                            ip - qat_offset
                        ]
                    consume_op_dist_attr = dist_context.get_dist_op_for_program(
                        consume_op
                    ).dist_attr
                    ref_process_mesh = consume_op_dist_attr.process_mesh

                    if input_name in consume_op_dist_attr.outputs_dist_attrs:
                        consume_input_dist_attr = (
                            consume_op_dist_attr.outputs_dist_attrs[input_name]
                        )
                    else:
                        consume_input_dist_attr = (
                            consume_op_dist_attr.inputs_dist_attrs[input_name]
                        )

                    quant_op_dist_attr.impl_idx = 0
                    quant_op_dist_attr.impl_type = "default"
                    quant_op_dist_attr.process_mesh = ref_process_mesh
                    quant_op_dist_attr.set_input_dist_attr(
                        quant_op.desc.input('X')[0], consume_input_dist_attr
                    )

                    for slot_name in quant_op.desc.input_names():
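                        # Inputs other than "X" (scale/state/accum tensors)
                        # are small 1-D tensors; mark them as replicated
                        # (dims_mapping = [-1]) on the reference mesh.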
                        if slot_name == "X":
                            continue
                        for in_name in quant_op.desc.input(slot_name):
                            input_var = block.vars[in_name]
                            tensor_dist_attr = TensorDistributedAttribute()
                            tensor_dist_attr.process_mesh = ref_process_mesh
                            tensor_dist_attr.dims_mapping = [-1]
                            dist_context.set_tensor_dist_attr_for_program(
                                input_var, tensor_dist_attr
                            )
                            quant_op_dist_attr.set_input_dist_attr(
                                in_name, tensor_dist_attr
                            )

                    for slot_name in quant_op.desc.output_names():
                        output_name = quant_op.desc.output(slot_name)[0]
                        output_var = block.vars[output_name]
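                        # "Y" carries the (de)quantized data and inherits the
                        # input's dist_attr; other outputs (e.g. scale/state)
                        # are 1-D tensors marked as replicated.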
                        if slot_name == "Y":
                            dist_context.set_tensor_dist_attr_for_program(
                                output_var, consume_input_dist_attr
                            )
                            quant_op_dist_attr.set_output_dist_attr(
                                output_name, consume_input_dist_attr
                            )
                        else:
                            tensor_dist_attr = TensorDistributedAttribute()
                            tensor_dist_attr.process_mesh = ref_process_mesh
                            tensor_dist_attr.dims_mapping = [-1]
                            dist_context.set_tensor_dist_attr_for_program(
                                output_var, tensor_dist_attr
                            )
                            quant_op_dist_attr.set_output_dist_attr(
                                output_name, tensor_dist_attr
                            )

                    quant_op._set_attr("op_device", "")
                    qat_offset += 1

                else:
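                    # Ops that already existed in main_program: restore the
                    # original op id and copy the recorded dist_attr onto the
                    # rewritten op and its renamed in/out tensors.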

                    origin_op = main_program.blocks[ib].ops[ip - qat_offset]
                    quant_op.desc.set_original_id(origin_op.desc.original_id())
                    dist_origin_op = dist_context.get_dist_op_for_program(
                        origin_op
                    )
                    assert (
                        dist_origin_op is not None
                    ), "origin op must have dist attr."

                    origin_op_dist_attr = dist_origin_op.dist_attr
                    quant_op_dist_attr.impl_idx = origin_op_dist_attr.impl_idx
                    quant_op_dist_attr.impl_type = origin_op_dist_attr.impl_type
                    quant_op_dist_attr.process_mesh = (
                        origin_op_dist_attr.process_mesh
                    )
                    for idx, input_name in enumerate(quant_op.input_arg_names):
                        origin_input_name = origin_op.input_arg_names[idx]
                        origin_input_dist_attr = (
                            origin_op_dist_attr.inputs_dist_attrs[
                                origin_input_name
                            ]
                        )
                        quant_op_dist_attr.set_input_dist_attr(
                            input_name, origin_input_dist_attr
                        )

                        if input_name not in main_program.blocks[ib].vars:
                            origin_input_var = main_program.blocks[ib].vars[
                                origin_input_name
                            ]
                            origin_in_tensor_dist_attr = (
                                dist_context.get_dist_tensor_for_program(
                                    origin_input_var
                                ).dist_attr
                            )
                            quant_input_var = block.vars[input_name]
                            dist_context.set_tensor_dist_attr_for_program(
                                quant_input_var, origin_in_tensor_dist_attr
                            )

                    for idx, output_name in enumerate(
                        quant_op.output_arg_names
                    ):
                        origin_output_name = origin_op.output_arg_names[idx]
                        origin_output_dist_attr = (
                            origin_op_dist_attr.outputs_dist_attrs[
                                origin_output_name
                            ]
                        )
                        quant_op_dist_attr.set_output_dist_attr(
                            output_name, origin_output_dist_attr
                        )

                        if output_name not in main_program.blocks[ib].vars:
                            origin_output_var = main_program.blocks[ib].vars[
                                origin_output_name
                            ]
                            origin_out_tensor_dist_attr = (
                                dist_context.get_dist_tensor_for_program(
                                    origin_output_var
                                ).dist_attr
                            )
                            quant_output_var = block.vars[output_name]
                            dist_context.set_tensor_dist_attr_for_program(
                                quant_output_var, origin_out_tensor_dist_attr
                            )

                dist_context.set_op_dist_attr_for_program(
                    quant_op, quant_op_dist_attr
                )

            # recover vars' dist_attr
            for name, dst_var in block.vars.items():
                if name in main_program.blocks[ib].vars:
                    src_var = main_program.blocks[ib].vars[name]
                    dist_tensor = dist_context.get_dist_tensor_for_program(
                        src_var
                    )
                    if not dist_tensor:
                        continue
                    dist_context.set_tensor_dist_attr_for_program(
                        dst_var, dist_tensor.dist_attr
                    )

        context.set_attr("main_program", quant_program)
        context.set_attr("startup_program", startup_program)
        context.set_attr("params_grads", new_params_grads)