# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
    TensorDistributedAttribute,
)
from paddle.fluid import core, framework
from paddle.fluid.contrib.slim.quantization import (
    AddQuantDequantPassV2,
    OutScaleForTrainingPass,
    QuantizationTransformPassV2,
    utils,
)
from paddle.fluid.dygraph.parallel import ParallelEnv

from .pass_base import PassBase, register_pass

TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
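# The two lists above decide which pass quantizes an op type: weighted ops are
# handled by QuantizationTransformPassV2, activation-only ops by
# AddQuantDequantPassV2 (see steps 3 and 4 in _apply_single_impl below).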


def _node_id(node):
    return (node.node.graph_id(), node.node.id())


@register_pass("auto_parallel_quantization")
class QuantizationPass(PassBase):
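    # Inserts fake quant/dequant ops into the auto-parallel program for QAT and
    # completes the distributed attributes of the newly added ops and tensors.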
    def __init__(self):
        super().__init__()
        self.set_attr("dist_context", None)
        self.set_attr("params_grads", None)

    def _check_self(self):
        if self.get_attr("dist_context") is None:
            return False
        if self.get_attr("params_grads") is None:
            return False
        return True

    def _check_conflict(self, other_pass):
        return True

    def _apply_single_impl(self, main_program, startup_program, context):

        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        # TODO: scope and place will be removed,
        # because params should be initialized by the engine module.
        scope = paddle.static.global_scope()
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)

        # 1. Convert the Program to an IrGraph; this pass is only for train mode
        main_graph = framework.IrGraph(
            core.Graph(main_program.desc), for_test=False
        )

        # 2. Prepare inputs
        transform_pass_ops = []
        quant_dequant_ops = []
        quantize_op_types = [
            'conv2d',
            'depthwise_conv2d',
            'mul',
            'matmul',
            'matmul_v2',
        ]
        for op_type in quantize_op_types:
            if op_type in TRANSFORM_PASS_OP_TYPES:
                transform_pass_ops.append(op_type)
            elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
                quant_dequant_ops.append(op_type)

        weight_quantize_type = (
            "channel_wise_abs_max"
            if self.get_attr('channel_wise_abs_max')
            else "abs_max"
        )

        # 3. Add quant ops for ops which have parameters
        transform_pass = QuantizationTransformPassV2(
            scope=scope,
            place=place,
            weight_bits=self.get_attr('weight_bits'),
            activation_bits=self.get_attr('activation_bits'),
            skip_pattern=self.get_attr('not_quant_pattern'),
            activation_quantize_type="moving_average_abs_max",
            quantizable_op_type=transform_pass_ops,
            weight_quantize_type=weight_quantize_type,
            weight_quantize_func=None,
            act_quantize_func=None,
            weight_preprocess_func=None,
            act_preprocess_func=None,
            optimizer_func=None,
            executor=None,
        )
        transform_pass.apply(main_graph)

        # 4. Add quant ops for ops which don't have parameters
        quant_dequant_pass = AddQuantDequantPassV2(
            scope=scope,
            place=place,
            quant_bits=self.get_attr('activation_bits'),
            skip_pattern=self.get_attr('not_quant_pattern'),
            quantizable_op_type=quant_dequant_ops,
        )
        quant_dequant_pass.apply(main_graph)

        # 5. Gather quantization scale information for the outputs
        out_scale_training_pass = OutScaleForTrainingPass(
            scope=scope, place=place
        )
        out_scale_training_pass.apply(main_graph)

        # 6. Convert Graph back to Program
        quant_program = main_graph.to_program()

        # 7. Get the new params_grads from quant_program
        new_params_grads = []
        for param, grad in params_grads:
            if param.name not in quant_program.global_block().vars:
                continue

            new_param = quant_program.global_block().vars[param.name]
            new_grad = quant_program.global_block().vars[grad.name]
            new_params_grads.append((new_param, new_grad))

        # 8. Complete the distributed attributes
        # NOTE: hacky implementation, to be upgraded soon
        for ib, block in enumerate(quant_program.blocks):
            # recover origin ops' dist_attr and set quant ops' dist_attr
            qat_offset = 0
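            # `qat_offset` counts the quant/dequant ops inserted so far, so
            # `ip - qat_offset` indexes the matching op in the original program.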
            for ip, quant_op in enumerate(block.ops):
                quant_op_dist_attr = OperatorDistributedAttribute()

                if (
                    "quantize" in quant_op.type
                    or quant_op.type == "moving_average_abs_max_scale"
                ):
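                    # Newly inserted quant/dequant ops (and
                    # moving_average_abs_max_scale) inherit their dist_attr
                    # from the tensor they consume.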

                    input_name = quant_op.desc.input('X')[0]
                    if "quantize" in input_name:
                        input_name = input_name[
                            : input_name.index(".quantized")
                        ]

                    if quant_op.type == "moving_average_abs_max_scale":
                        consume_op = main_program.blocks[ib].vars[input_name].op
                    else:
                        consume_op = main_program.blocks[ib].ops[
                            ip - qat_offset
                        ]
                    consume_op_dist_attr = dist_context.get_dist_op_for_program(
                        consume_op
                    ).dist_attr
                    ref_process_mesh = consume_op_dist_attr.process_mesh

                    if input_name in consume_op_dist_attr.outputs_dist_attrs:
                        consume_input_dist_attr = (
                            consume_op_dist_attr.outputs_dist_attrs[input_name]
                        )
                    else:
                        consume_input_dist_attr = (
                            consume_op_dist_attr.inputs_dist_attrs[input_name]
                        )

                    quant_op_dist_attr.impl_idx = 0
                    quant_op_dist_attr.impl_type = "default"
                    quant_op_dist_attr.process_mesh = ref_process_mesh
                    quant_op_dist_attr.set_input_dist_attr(
                        quant_op.desc.input('X')[0], consume_input_dist_attr
                    )

                    for slot_name in quant_op.desc.input_names():
                        if slot_name == "X":
                            continue
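                        # Non-X inputs (scales, states, accumulators) are small
                        # 1-D tensors replicated on every rank, hence
                        # dims_mapping = [-1].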
                        for in_name in quant_op.desc.input(slot_name):
                            input_var = block.vars[in_name]
                            tensor_dist_attr = TensorDistributedAttribute()
                            tensor_dist_attr.process_mesh = ref_process_mesh
                            tensor_dist_attr.dims_mapping = [-1]
                            dist_context.set_tensor_dist_attr_for_program(
                                input_var, tensor_dist_attr
                            )
                            quant_op_dist_attr.set_input_dist_attr(
                                in_name, tensor_dist_attr
                            )

                    for slot_name in quant_op.desc.output_names():
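                        # Output "Y" carries the (de)quantized data and reuses
                        # the input's dist_attr; the remaining outputs (scales,
                        # states) are replicated 1-D tensors.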
                        output_name = quant_op.desc.output(slot_name)[0]
                        output_var = block.vars[output_name]
                        if slot_name == "Y":
                            dist_context.set_tensor_dist_attr_for_program(
                                output_var, consume_input_dist_attr
                            )
                            quant_op_dist_attr.set_output_dist_attr(
                                output_name, consume_input_dist_attr
                            )
                        else:
                            tensor_dist_attr = TensorDistributedAttribute()
                            tensor_dist_attr.process_mesh = ref_process_mesh
                            tensor_dist_attr.dims_mapping = [-1]
                            dist_context.set_tensor_dist_attr_for_program(
                                output_var, tensor_dist_attr
                            )
                            quant_op_dist_attr.set_output_dist_attr(
                                output_name, tensor_dist_attr
                            )

                    quant_op._set_attr("op_device", "")
                    qat_offset += 1

                else:
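                    # Ops copied from the original program: restore their
                    # original ids and reuse their original dist_attrs.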

                    origin_op = main_program.blocks[ib].ops[ip - qat_offset]
                    quant_op.desc.set_original_id(origin_op.desc.original_id())
                    dist_origin_op = dist_context.get_dist_op_for_program(
                        origin_op
                    )
                    assert (
                        dist_origin_op is not None
                    ), "origin op must have dist attr."

                    origin_op_dist_attr = dist_origin_op.dist_attr
                    quant_op_dist_attr.impl_idx = origin_op_dist_attr.impl_idx
                    quant_op_dist_attr.impl_type = origin_op_dist_attr.impl_type
                    quant_op_dist_attr.process_mesh = (
                        origin_op_dist_attr.process_mesh
                    )
                    for idx, input_name in enumerate(quant_op.input_arg_names):
                        origin_input_name = origin_op.input_arg_names[idx]
                        origin_input_dist_attr = (
                            origin_op_dist_attr.inputs_dist_attrs[
                                origin_input_name
                            ]
                        )
                        quant_op_dist_attr.set_input_dist_attr(
                            input_name, origin_input_dist_attr
                        )

                        if input_name not in main_program.blocks[ib].vars:
                            origin_input_var = main_program.blocks[ib].vars[
                                origin_input_name
                            ]
                            origin_in_tensor_dist_attr = (
                                dist_context.get_dist_tensor_for_program(
                                    origin_input_var
                                ).dist_attr
                            )
                            quant_input_var = block.vars[input_name]
                            dist_context.set_tensor_dist_attr_for_program(
                                quant_input_var, origin_in_tensor_dist_attr
                            )

                    for idx, output_name in enumerate(
                        quant_op.output_arg_names
                    ):
                        origin_output_name = origin_op.output_arg_names[idx]
                        origin_output_dist_attr = (
                            origin_op_dist_attr.outputs_dist_attrs[
                                origin_output_name
                            ]
                        )
                        quant_op_dist_attr.set_output_dist_attr(
                            output_name, origin_output_dist_attr
                        )

                        if output_name not in main_program.blocks[ib].vars:
                            origin_output_var = main_program.blocks[ib].vars[
                                origin_output_name
                            ]
                            origin_out_tensor_dist_attr = (
                                dist_context.get_dist_tensor_for_program(
                                    origin_output_var
                                ).dist_attr
                            )
                            quant_output_var = block.vars[output_name]
                            dist_context.set_tensor_dist_attr_for_program(
                                quant_output_var, origin_out_tensor_dist_attr
                            )

                dist_context.set_op_dist_attr_for_program(
                    quant_op, quant_op_dist_attr
                )

            # recover vars' dist_attr
            for name, dst_var in block.vars.items():
                if name in main_program.blocks[ib].vars:
                    src_var = main_program.blocks[ib].vars[name]
                    dist_tensor = dist_context.get_dist_tensor_for_program(
                        src_var
                    )
                    if not dist_tensor:
                        continue
                    dist_context.set_tensor_dist_attr_for_program(
                        dst_var, dist_tensor.dist_attr
                    )

        context.set_attr("main_program", quant_program)
        context.set_attr("startup_program", startup_program)
        context.set_attr("params_grads", new_params_grads)