diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index 04a8824ce90a755c7b5f52c8a66b7c96204e0b94..474630946b20e40bfa03926b62dc47d6180da718 100755
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -184,6 +184,14 @@ message TensorParallelConfig {
   optional int32 tensor_init_seed = 2 [ default = -1 ];
 }
 
+message QatConfig {
+  optional bool channel_wise_abs_max = 1 [default = true];
+  optional int32 weight_bits = 2 [default = 8];
+  optional int32 activation_bits = 3 [default = 8];
+  repeated string not_quant_pattern = 4;
+  optional string algo = 5;
+}
+
 enum TableType {
   PS_SPARSE_TABLE = 0;
   PS_DENSE_TABLE = 1;
@@ -327,6 +335,7 @@ message DistributedStrategy {
   optional bool heter_ccl_mode = 38 [ default = false ];
   optional bool is_fl_ps_mode = 39 [ default = false ];
   optional bool with_coordinator = 40 [ default = false ];
+  optional bool qat = 41 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
@@ -344,6 +353,7 @@ message DistributedStrategy {
   optional TrainerDescConfig trainer_desc_configs = 114;
   repeated TableParameter downpour_table_param = 115;
   optional FsClientParameter fs_client_param = 116;
+  optional QatConfig qat_configs = 117;
 
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
diff --git a/python/paddle/distributed/auto_parallel/completion.py b/python/paddle/distributed/auto_parallel/completion.py
index cbdbd7b80b4d2a0d54fade0050c229824c40d205..341f4baf572071865f71a4d6ea4435aa11b629a3 100644
--- a/python/paddle/distributed/auto_parallel/completion.py
+++ b/python/paddle/distributed/auto_parallel/completion.py
@@ -940,6 +940,12 @@ class Completer:
                             core.op_proto_and_checker_maker.OpRole.Forward):
                 appended_grad_times += 1
 
+            if int(op.attr('op_role')) == int(
+                    int(core.op_proto_and_checker_maker.OpRole.Backward)
+                    | int(core.op_proto_and_checker_maker.OpRole.Loss)):
+                assert op.type == "fill_constant"
+                break
+
             # complete the annotation of grad op (xxx_grad op or sum op)
             # xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
             grad_op = ops[idx]
diff --git a/python/paddle/distributed/auto_parallel/operators/common.py b/python/paddle/distributed/auto_parallel/operators/common.py
index 3305660c1aa65a64a2c284872c065396aa0c8374..b34749b09dfdd3be17d5039ec193d7e226b9da25 100644
--- a/python/paddle/distributed/auto_parallel/operators/common.py
+++ b/python/paddle/distributed/auto_parallel/operators/common.py
@@ -245,6 +245,8 @@ def is_parameter_related(varname, block):
         varname = varname[:varname.index(".subprog_")]
     if ".cast_fp" in varname:
         varname = varname[:varname.index(".cast_fp")]
+    if ".quantized" in varname:
+        varname = varname[:varname.index(".quantized")]
     assert block.has_var(varname)
     var = block.var(varname)
     return var.is_parameter
diff --git a/python/paddle/distributed/auto_parallel/parallelizer_v2.py b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
index 96dcce6105921581e69b9f011c01f96c395e02ce..93c684eecc164c23e9b2538e9a8a5555032bc90a 100644
--- a/python/paddle/distributed/auto_parallel/parallelizer_v2.py
+++ b/python/paddle/distributed/auto_parallel/parallelizer_v2.py
@@ -66,9 +66,9 @@ class Parallelizer:
                 serial_loss)
             # Apply pre optimization passes
             time0 = time.time()
-            self._apply_pre_optimization(serial_main_program,
-                                         serial_startup_program, serial_loss,
-                                         serial_optimizer, params_grads)
+            serial_main_program, serial_startup_program, params_grads = self._apply_pre_optimization(
+                serial_main_program, serial_startup_program, serial_loss,
+                serial_optimizer, params_grads)
             self._logger.info(
                 "within parallel apply_pre_optimization time: {}, mode {}".
                 format(time.time() - time0, self._mode))
@@ -162,6 +162,22 @@ class Parallelizer:
                                 optimizer, params_grads):
         if self._strategy is None:
             return
+
+        # apply quantization pass
+        # The pass can only be applied when the mode is 'train'
+        if self._mode == 'train' and self._strategy.qat:
+            config = copy.deepcopy(self._strategy.qat_configs)
+            config["dist_context"] = self._dist_context
+            config["params_grads"] = params_grads
+            auto_parallel_quantization_pass = new_pass(
+                "auto_parallel_quantization", config)
+            auto_parallel_quantization_pass.apply([main_program],
+                                                  [startup_program],
+                                                  self._pass_context)
+            main_program = self._pass_context.get_attr("main_program")
+            startup_program = self._pass_context.get_attr("startup_program")
+            params_grads = self._pass_context.get_attr("params_grads")
+
         # apply amp pass
         # FIXME we disenable amp for eval since it has a little bug with
         # eval program and which will be fixed in future
@@ -195,6 +211,8 @@ class Parallelizer:
                                                [startup_program],
                                                self._pass_context)
 
+        return main_program, startup_program, params_grads
+
     def _apply_post_optimization(self, main_program, startup_program, rank,
                                  params_grads):
         if self._strategy is None:
diff --git a/python/paddle/distributed/auto_parallel/reshard.py b/python/paddle/distributed/auto_parallel/reshard.py
index 577d61f7670ade9fe8d8662b2ee30bca3131b3f3..4f1f02f815bce7cff50409f37c5ba92b1b165a2e 100644
--- a/python/paddle/distributed/auto_parallel/reshard.py
+++ b/python/paddle/distributed/auto_parallel/reshard.py
@@ -685,7 +685,8 @@ class Remover:
                     block._remove_op(idx)
 
     @staticmethod
-    def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
+    def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads,
+                            feed_var_names):
         """Remove no need vars in the main program"""
         for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
             remove_vars = set()
@@ -731,7 +732,7 @@ class Remover:
                     idx += 1
 
         for var in remove_vars:
-            if block.vars[var].is_data:
+            if var in feed_var_names:
                 continue
             block._remove_var(var)
 
@@ -743,7 +744,12 @@ class Remover:
                                              rank_id)
         Resharder.change_while_op_input_and_output(auto_parallel_main_prog,
                                                    dist_context)
-        Remover.remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
+        # vars in 'feed_var_names' must not be removed from auto_parallel_main_prog
+        feed_var_names = []
+        for var in sum(list(dist_context.serial_feed_vars.values()), []):
+            feed_var_names.append(var.name)
+        Remover.remove_no_need_vars(auto_parallel_main_prog, dist_params_grads,
+                                    feed_var_names)
 
     @staticmethod
     def remove_no_need_in_startup(auto_parallel_main_prog,
diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py
index d58770dd714ff3b35660d38008cfff3624dfed5b..765fca275d7b0eedb600b498c82069f5766e47b8 100755
--- a/python/paddle/distributed/fleet/base/distributed_strategy.py
+++ b/python/paddle/distributed/fleet/base/distributed_strategy.py
@@ -1991,6 +1991,60 @@ class DistributedStrategy(object):
         else:
             print("WARNING: auto-search should have value of bool type")
 
+    @property
+    def qat(self):
+        """
+        Indicates whether to use quantization-aware training (QAT).
+        Default Value: False
+        """
+        return self.strategy.qat
+
+    @qat.setter
+    def qat(self, flag):
+        if isinstance(flag, bool):
+            self.strategy.qat = flag
+        else:
+            print("WARNING: qat should have value of bool type")
+
+    @property
+    def qat_configs(self):
+        """
+        Set quantization training configurations. In general, qat has several configurable
+        settings that can be configured through a dict.
+
+        **Notes**:
+            channel_wise_abs_max(bool): Whether to use `per_channel` quantization training. Default is True.
+
+            weight_bits(int): quantization bit number for weight. Default is 8.
+
+            activation_bits(int): quantization bit number for activation. Default is 8.
+
+            not_quant_pattern(list[str]): When the skip pattern is detected in an op's name scope,
+                the corresponding op will not be quantized.
+
+            algo(str): Other quantization training algorithm.
+
+        Examples:
+
+          .. code-block:: python
+
+            import paddle.distributed.fleet as fleet
+            strategy = fleet.DistributedStrategy()
+            strategy.qat = True
+            strategy.qat_configs = {
+                "channel_wise_abs_max": True,
+                "weight_bits": 8,
+                "activation_bits": 8,
+                "not_quant_pattern": ['skip_quant']}
+
+        """
+        return get_msg_dict(self.strategy.qat_configs)
+
+    @qat_configs.setter
+    def qat_configs(self, configs):
+        check_configs_key(self.strategy.qat_configs, configs, "qat_configs")
+        assign_configs_value(self.strategy.qat_configs, configs)
+
     @property
     def heter_ccl_mode(self):
         """
diff --git a/python/paddle/distributed/passes/__init__.py b/python/paddle/distributed/passes/__init__.py
index 670e7f003d7a4828d95caac5cd3e7230c799b4d6..03dd31fb9b2ae4a752a518f98735fc1d32594017 100644
--- a/python/paddle/distributed/passes/__init__.py
+++ b/python/paddle/distributed/passes/__init__.py
@@ -19,6 +19,7 @@ from .auto_parallel_sharding import *
 from .auto_parallel_amp import *
 from .auto_parallel_fp16 import *
 from .auto_parallel_recompute import *
+from .auto_parallel_quantization import *
 from .auto_parallel_data_parallel_optimization import *
 from .cpp_pass import *
 import os
diff --git a/python/paddle/distributed/passes/auto_parallel_quantization.py b/python/paddle/distributed/passes/auto_parallel_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0ac93d83939dedd99f1a6720771367e175d0947
--- /dev/null
+++ b/python/paddle/distributed/passes/auto_parallel_quantization.py
@@ -0,0 +1,258 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+
+from paddle.fluid import core, framework
+from paddle.fluid.dygraph.parallel import ParallelEnv
+from paddle.fluid.contrib.slim.quantization import utils
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
+from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
+from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
+from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
+
+from .pass_base import PassBase, register_pass
+
+TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
+QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
+
+
+def _node_id(node):
+    return (node.node.graph_id(), node.node.id())
+
+
+@register_pass("auto_parallel_quantization")
+class QuantizationPass(PassBase):
+
+    def __init__(self):
+        super(QuantizationPass, self).__init__()
+        self.set_attr("dist_context", None)
+        self.set_attr("params_grads", None)
+
+    def _check_self(self):
+        if self.get_attr("dist_context") is None:
+            return False
+        if self.get_attr("params_grads") is None:
+            return False
+        return True
+
+    def _check_conflict(self, other_pass):
+        return True
+
+    def _apply_single_impl(self, main_program, startup_program, context):
+
+        dist_context = self.get_attr("dist_context")
+        params_grads = self.get_attr("params_grads")
+
+        # TODO: scope and place will be removed, because params should be
+        # initialized by the engine module.
+        scope = paddle.static.global_scope()
+        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
+
+        # 1. Convert the Program to an IrGraph; this pass is only for train mode
+        main_graph = framework.IrGraph(core.Graph(main_program.desc),
+                                       for_test=False)
+
+        # 2. Prepare inputs
+        transform_pass_ops = []
+        quant_dequant_ops = []
+        quantize_op_types = [
+            'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
+        ]
+        for op_type in quantize_op_types:
+            if op_type in TRANSFORM_PASS_OP_TYPES:
+                transform_pass_ops.append(op_type)
+            elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
+                quant_dequant_ops.append(op_type)
+
+        weight_quantize_type = "channel_wise_abs_max" if self.get_attr(
+            'channel_wise_abs_max') else "abs_max"
+
+        # 3. Add quant ops for ops which have parameters
+        transform_pass = QuantizationTransformPassV2(
+            scope=scope,
+            place=place,
+            weight_bits=self.get_attr('weight_bits'),
+            activation_bits=self.get_attr('activation_bits'),
+            skip_pattern=self.get_attr('not_quant_pattern'),
+            activation_quantize_type="moving_average_abs_max",
+            quantizable_op_type=transform_pass_ops,
+            weight_quantize_type=weight_quantize_type,
+            weight_quantize_func=None,
+            act_quantize_func=None,
+            weight_preprocess_func=None,
+            act_preprocess_func=None,
+            optimizer_func=None,
+            executor=None)
+        transform_pass.apply(main_graph)
+
+        # 4. Add quant ops for ops which don't have parameters
+        quant_dequant_pass = AddQuantDequantPassV2(
+            scope=scope,
+            place=place,
+            quant_bits=self.get_attr('activation_bits'),
+            skip_pattern=self.get_attr('not_quant_pattern'),
+            quantizable_op_type=quant_dequant_ops)
+        quant_dequant_pass.apply(main_graph)
+
+        # 5. Gather output quantization scales for training
+        out_scale_training_pass = OutScaleForTrainingPass(scope=scope,
+                                                          place=place)
+        out_scale_training_pass.apply(main_graph)
+
+        # 6. Convert the Graph back to a Program
+        quant_program = main_graph.to_program()
+
+        # 7. Get new params_grads from quant_program
+        new_params_grads = []
+        for param, grad in params_grads:
+            if param.name not in quant_program.global_block().vars:
+                continue
+
+            new_param = quant_program.global_block().vars[param.name]
+            new_grad = quant_program.global_block().vars[grad.name]
+            new_params_grads.append((new_param, new_grad))
+
+        # 8. Complete the distributed attributes
+        # NOTE: hacky implementation, to be upgraded soon
+        for ib, block in enumerate(quant_program.blocks):
+            # recover origin ops' dist_attr and set quant ops' dist_attr
+            qat_offset = 0
+            for ip, quant_op in enumerate(block.ops):
+                quant_op_dist_attr = OperatorDistributedAttribute()
+
+                if "quantize" in quant_op.type or \
+                        quant_op.type == "moving_average_abs_max_scale":
+
+                    input_name = quant_op.desc.input('X')[0]
+                    if "quantize" in input_name:
+                        input_name = input_name[:input_name.index(".quantized")]
+
+                    if quant_op.type == "moving_average_abs_max_scale":
+                        consume_op = main_program.blocks[ib].vars[input_name].op
+                    else:
+                        consume_op = main_program.blocks[ib].ops[ip -
+                                                                 qat_offset]
+                    consume_op_dist_attr = dist_context.get_dist_op_for_program(
+                        consume_op).dist_attr
+                    ref_process_mesh = consume_op_dist_attr.process_mesh
+
+                    if input_name in consume_op_dist_attr.outputs_dist_attrs:
+                        consume_input_dist_attr = consume_op_dist_attr.outputs_dist_attrs[
+                            input_name]
+                    else:
+                        consume_input_dist_attr = consume_op_dist_attr.inputs_dist_attrs[
+                            input_name]
+
+                    quant_op_dist_attr.impl_idx = 0
+                    quant_op_dist_attr.impl_type = "default"
+                    quant_op_dist_attr.process_mesh = ref_process_mesh
+                    quant_op_dist_attr.set_input_dist_attr(
+                        quant_op.desc.input('X')[0], consume_input_dist_attr)
+
+                    for slot_name in quant_op.desc.input_names():
+                        if slot_name == "X":
+                            continue
+                        for in_name in quant_op.desc.input(slot_name):
+                            input_var = block.vars[in_name]
+                            tensor_dist_attr = TensorDistributedAttribute()
+                            tensor_dist_attr.process_mesh = ref_process_mesh
+                            tensor_dist_attr.dims_mapping = [-1]
+                            dist_context.set_tensor_dist_attr_for_program(
+                                input_var, tensor_dist_attr)
+                            quant_op_dist_attr.set_input_dist_attr(
+                                in_name, tensor_dist_attr)
+
+                    for slot_name in quant_op.desc.output_names():
+                        output_name = quant_op.desc.output(slot_name)[0]
+                        output_var = block.vars[output_name]
+                        if slot_name == "Y":
+                            dist_context.set_tensor_dist_attr_for_program(
+                                output_var, consume_input_dist_attr)
+                            quant_op_dist_attr.set_output_dist_attr(
+                                output_name, consume_input_dist_attr)
+                        else:
+                            tensor_dist_attr = TensorDistributedAttribute()
+                            tensor_dist_attr.process_mesh = ref_process_mesh
+                            tensor_dist_attr.dims_mapping = [-1]
+                            dist_context.set_tensor_dist_attr_for_program(
+                                output_var, tensor_dist_attr)
+                            quant_op_dist_attr.set_output_dist_attr(
+                                output_name, tensor_dist_attr)
+
+                    quant_op._set_attr("op_device", "")
+                    qat_offset += 1
+
+                else:
+
+                    origin_op = main_program.blocks[ib].ops[ip - qat_offset]
+                    quant_op.desc.set_original_id(origin_op.desc.original_id())
+                    dist_origin_op = dist_context.get_dist_op_for_program(
+                        origin_op)
+                    assert dist_origin_op is not None, "origin op must have dist attr."
+
+                    origin_op_dist_attr = dist_origin_op.dist_attr
+                    quant_op_dist_attr.impl_idx = origin_op_dist_attr.impl_idx
+                    quant_op_dist_attr.impl_type = origin_op_dist_attr.impl_type
+                    quant_op_dist_attr.process_mesh = origin_op_dist_attr.process_mesh
+                    for idx, input_name in enumerate(quant_op.input_arg_names):
+                        origin_input_name = origin_op.input_arg_names[idx]
+                        origin_input_dist_attr = origin_op_dist_attr.inputs_dist_attrs[
+                            origin_input_name]
+                        quant_op_dist_attr.set_input_dist_attr(
+                            input_name, origin_input_dist_attr)
+
+                        if input_name not in main_program.blocks[ib].vars:
+                            origin_input_var = main_program.blocks[ib].vars[
+                                origin_input_name]
+                            origin_in_tensor_dist_attr = dist_context.get_dist_tensor_for_program(
+                                origin_input_var).dist_attr
+                            quant_input_var = block.vars[input_name]
+                            dist_context.set_tensor_dist_attr_for_program(
+                                quant_input_var, origin_in_tensor_dist_attr)
+
+                    for idx, output_name in enumerate(
+                            quant_op.output_arg_names):
+                        origin_output_name = origin_op.output_arg_names[idx]
+                        origin_output_dist_attr = origin_op_dist_attr.outputs_dist_attrs[
+                            origin_output_name]
+                        quant_op_dist_attr.set_output_dist_attr(
+                            output_name, origin_output_dist_attr)
+
+                        if output_name not in main_program.blocks[ib].vars:
+                            origin_output_var = main_program.blocks[ib].vars[
+                                origin_output_name]
+                            origin_out_tensor_dist_attr = dist_context.get_dist_tensor_for_program(
+                                origin_output_var).dist_attr
+                            quant_output_var = block.vars[output_name]
+                            dist_context.set_tensor_dist_attr_for_program(
+                                quant_output_var, origin_out_tensor_dist_attr)
+
+                dist_context.set_op_dist_attr_for_program(
+                    quant_op, quant_op_dist_attr)
+
+            # recover vars' dist_attr
+            for name, dst_var in block.vars.items():
+                if name in main_program.blocks[ib].vars:
+                    src_var = main_program.blocks[ib].vars[name]
+                    dist_tensor = dist_context.get_dist_tensor_for_program(
+                        src_var)
+                    if not dist_tensor:
+                        continue
+                    dist_context.set_tensor_dist_attr_for_program(
+                        dst_var, dist_tensor.dist_attr)
+
+        context.set_attr("main_program", quant_program)
+        context.set_attr("startup_program", startup_program)
+        context.set_attr("params_grads", new_params_grads)
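[Editor's note, not part of the patch] The core of QuantizationPass._apply_single_impl above is the Program -> IrGraph -> Program round trip around the existing slim passes (steps 1, 3 and 6); everything from step 7 on only restores distributed attributes. Below is a minimal standalone sketch of that round trip, assuming a toy static program on CPU; the network and names are illustrative and not taken from the patch.

    import paddle
    from paddle.fluid import core, framework
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2

    paddle.enable_static()

    # Toy program with one quantizable layer (illustrative only).
    main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
    with paddle.static.program_guard(main_prog, startup_prog):
        x = paddle.static.data(name="x", shape=[-1, 16], dtype="float32")
        y = paddle.static.nn.fc(x, size=4)

    place = paddle.CPUPlace()
    paddle.static.Executor(place).run(startup_prog)  # weights must exist in the scope

    # Program -> IrGraph, insert fake quant/dequant ops, IrGraph -> Program.
    graph = framework.IrGraph(core.Graph(main_prog.desc), for_test=False)
    QuantizationTransformPassV2(scope=paddle.static.global_scope(),
                                place=place).apply(graph)
    quant_prog = graph.to_program()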
diff --git a/python/paddle/distributed/passes/auto_parallel_recompute.py b/python/paddle/distributed/passes/auto_parallel_recompute.py
index fcd7d24377117e9f0b868e8091fe59e416bdbdcd..80edec82fd7de231b50a16395fe1d16737ab43ee 100644
--- a/python/paddle/distributed/passes/auto_parallel_recompute.py
+++ b/python/paddle/distributed/passes/auto_parallel_recompute.py
@@ -236,14 +236,14 @@ class RecomputePass(PassBase):
     def _check_conflict(self, other_pass):
         return True
 
-    def _apply_single_impl(self, main_programs, startup_programs, context):
+    def _apply_single_impl(self, main_program, startup_program, context):
         checkpoints = self.get_attr("checkpoints")
         loss = self.get_attr("loss")
         no_grad_set = self.get_attr("no_grad_set")
         self._dist_context = self.get_attr("dist_context")
 
-        main_block = main_programs.global_block()
-        no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set)
+        main_block = main_program.global_block()
+        no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
         # get op_path which is related to loss
         op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
 
@@ -373,7 +373,7 @@ class RecomputePass(PassBase):
             ckpt_ops_dict[fwd_op_id][0] = False
 
         main_block._sync_with_cpp()
-        main_programs._sync_with_cpp()
+        main_program._sync_with_cpp()
 
     def reset_op_dist_attr(self, op, var_name_dict):
         op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
diff --git a/python/paddle/fluid/contrib/slim/quantization/__init__.py b/python/paddle/fluid/contrib/slim/quantization/__init__.py
index 4860871d8619524c91976f1fb1b5cdfc2899a0a9..2c5be249b0eb73eec7cc6e22d3c5111fac5fb8e6 100644
--- a/python/paddle/fluid/contrib/slim/quantization/__init__.py
+++ b/python/paddle/fluid/contrib/slim/quantization/__init__.py
@@ -25,7 +25,8 @@ from .post_training_quantization import *
 from . import imperative
 from .imperative import *
 
-__all__ = quantization_pass.__all__
+__all__ = []
+__all__ += quantization_pass.__all__
 __all__ += quant_int8_mkldnn_pass.__all__
 __all__ += quant2_int8_mkldnn_pass.__all__
 __all__ += post_training_quantization.__all__
diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
index 27b634e2ddbdd3bcaaa2a01fb0dc312c8fee18e0..84359f711532c0c1a62363573b9cad432e6b65b4 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/qat.py
@@ -42,6 +42,17 @@ _logger = get_logger(__name__,
                      fmt='%(asctime)s-%(levelname)s: %(message)s')
 
 
+def lazy_import_fleet(layer_name_map, fake_quant_input_layers):
+    from paddle.distributed import fleet
+    layer_name_map[
+        'ColumnParallelLinear'] = fleet.meta_parallel.parallel_layers.mp_layers.ColumnParallelLinear
+    layer_name_map[
+        'RowParallelLinear'] = fleet.meta_parallel.parallel_layers.mp_layers.RowParallelLinear
+    fake_quant_input_layers.append(fleet.meta_parallel.RowParallelLinear)
+    fake_quant_input_layers.append(fleet.meta_parallel.ColumnParallelLinear)
+    return layer_name_map, fake_quant_input_layers
+
+
 class ImperativeQuantAware(object):
     """
     Applying quantization aware training (QAT) to the dgraph model.
@@ -300,13 +311,15 @@ class ImperativeQuantizeInputs(object):
             Please refer to the args of ImperativeQuantAware.
         """
         super(ImperativeQuantizeInputs, self).__init__()
+        self.layer_name_map, self.fake_quant_input_layers = lazy_import_fleet(
+            utils.layer_name_map, utils.fake_quant_input_layers)
 
         self._quantizable_layer_type = tuple(
-            utils.layer_name_map[layer] if layer in
-            utils.layer_name_map else layer for layer in quantizable_layer_type)
+            self.layer_name_map[layer] if layer in
+            self.layer_name_map else layer for layer in quantizable_layer_type)
         for layer in self._quantizable_layer_type:
             assert not isinstance(layer, str) \
-                and layer in utils.fake_quant_input_layers, \
+                and layer in self.fake_quant_input_layers, \
                 "%s is unspported to be quantized." % layer
 
         quantize_type = {
@@ -383,7 +396,7 @@ class ImperativeQuantizeInputs(object):
 
     def _get_input_quantized_layer(self, layer):
         quant_layer_name = None
-        for key, value in utils.layer_name_map.items():
+        for key, value in self.layer_name_map.items():
             if isinstance(layer, value):
                 quant_layer_name = 'Quantized' + key
                 break
diff --git a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
index fafd8d70c800f8968a27e7c7e08411e8fb1129e5..a30d775165e186212961df9ae3d1e85b3728f016 100644
--- a/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
+++ b/python/paddle/fluid/contrib/slim/quantization/imperative/utils.py
@@ -16,63 +16,38 @@ import math
 import numpy as np
 
 import paddle
-from paddle.distributed import fleet
 import paddle.nn.quant.quant_layers as quant_layers
 
 from ..utils import _get_op_input_var_names, _get_op_output_var_names, _get_output_name_index, _get_input_name_index
 
 layer_name_map = {
-    'Conv2DTranspose':
-    paddle.nn.Conv2DTranspose,
-    'Conv2D':
-    paddle.nn.Conv2D,
-    'Linear':
-    paddle.nn.Linear,
-    'AdaptiveAvgPool2D':
-    paddle.nn.AdaptiveAvgPool2D,
-    'AdaptiveMaxPool2D':
-    paddle.nn.AdaptiveMaxPool2D,
-    'AvgPool2D':
-    paddle.nn.AvgPool2D,
-    'MaxPool2D':
-    paddle.nn.MaxPool2D,
-    'Hardswish':
-    paddle.nn.Hardswish,
-    'LeakyReLU':
-    paddle.nn.LeakyReLU,
-    'PReLU':
-    paddle.nn.PReLU,
-    'ReLU':
-    paddle.nn.ReLU,
-    'ReLU6':
-    paddle.nn.ReLU6,
-    'Sigmoid':
-    paddle.nn.Sigmoid,
-    'Softmax':
-    paddle.nn.Softmax,
-    'Swish':
-    paddle.nn.Swish,
-    'Tanh':
-    paddle.nn.Tanh,
-    'Hardswish':
-    paddle.nn.Hardswish,
-    'BatchNorm':
-    paddle.nn.BatchNorm,
-    'GroupNorm':
-    paddle.nn.GroupNorm,
-    'LayerNorm':
-    paddle.nn.LayerNorm,
-    'ColumnParallelLinear':
-    fleet.meta_parallel.parallel_layers.mp_layers.ColumnParallelLinear,
-    'RowParallelLinear':
-    fleet.meta_parallel.parallel_layers.mp_layers.RowParallelLinear
+    'Conv2DTranspose': paddle.nn.Conv2DTranspose,
+    'Conv2D': paddle.nn.Conv2D,
+    'Linear': paddle.nn.Linear,
+    'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
+    'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
+    'AvgPool2D': paddle.nn.AvgPool2D,
+    'MaxPool2D': paddle.nn.MaxPool2D,
+    'Hardswish': paddle.nn.Hardswish,
+    'LeakyReLU': paddle.nn.LeakyReLU,
+    'PReLU': paddle.nn.PReLU,
+    'ReLU': paddle.nn.ReLU,
+    'ReLU6': paddle.nn.ReLU6,
+    'Sigmoid': paddle.nn.Sigmoid,
+    'Softmax': paddle.nn.Softmax,
+    'Swish': paddle.nn.Swish,
+    'Tanh': paddle.nn.Tanh,
+    'Hardswish': paddle.nn.Hardswish,
+    'BatchNorm': paddle.nn.BatchNorm,
+    'GroupNorm': paddle.nn.GroupNorm,
+    'LayerNorm': paddle.nn.LayerNorm,
 }
 
 # Apply fake quant for the inputs of these layers
 fake_quant_input_layers = [
-    paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose,
-    fleet.meta_parallel.RowParallelLinear,
-    fleet.meta_parallel.ColumnParallelLinear
+    paddle.nn.Conv2D,
+    paddle.nn.Linear,
+    paddle.nn.Conv2DTranspose,
 ]
 
 # Apply fake quant for the output of these layers
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
index 9c543870ac668e90bb15616b00acb8a36a8089ab..8566186d76c5ba74c9a3faf5798dc9da8701e51b 100644
--- a/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/CMakeLists.txt
@@ -65,4 +65,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
   py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
   py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
   py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
+  py_test_modules(test_quantization MODULES test_quantization)
 endif()
diff --git a/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py b/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py
new file mode 100644
index 0000000000000000000000000000000000000000..f84ee03e0c9401e6c5bb369b9b0e72749e7325d8
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/auto_parallel/test_quantization.py
@@ -0,0 +1,180 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import sys
+import numpy as np
+import paddle
+
+import paddle.distributed.fleet as fleet
+import paddle.distributed.auto_parallel as auto
+
+from paddle.distributed.auto_parallel.engine import Engine
+from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
+
+sys.path.append("..")
+import auto_parallel_gpt_model as modeling
+from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
+
+paddle.enable_static()
+
+
+class FakeDataset:
+
+    def __init__(self, num_samples, sequence_len, vocab_size):
+        self.num_samples = num_samples
+        self.sequence_len = sequence_len
+        self.vocab_size = vocab_size
+
+    def __getitem__(self, idx):
+        tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
+        position_ids = np.arange(self.sequence_len)
+        attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
+            (1, self.sequence_len, self.sequence_len)).astype(np.float32)
+        labels = np.random.randint(self.vocab_size, size=self.sequence_len)
+        loss_mask = np.ones(self.sequence_len).astype(np.float32)
+        return tokens, position_ids, attention_mask, labels, loss_mask
+
+    def __len__(self):
+        return self.num_samples
+
+
+def apply_pass():
+    dist_strategy = fleet.DistributedStrategy()
+    dist_strategy.semi_auto = True
+    dist_strategy.qat = True
+    dist_strategy.qat_configs = {
+        'channel_wise_abs_max': True,
+        'weight_bits': 8,
+        'activation_bits': 8,
+        'not_quant_pattern': ['skip_quant'],
+    }
+    return dist_strategy
+
+
+def create_data_holder(batch_size, sequence_len):
+    tokens = paddle.static.InputSpec(name="tokens",
+                                     shape=[batch_size, sequence_len],
+                                     dtype='int64')
+    position_ids = paddle.static.InputSpec(name="position_ids",
+                                           shape=[batch_size, sequence_len],
+                                           dtype='int64')
+    attention_mask = paddle.static.InputSpec(
+        name="attention_mask",
+        shape=[batch_size, 1, sequence_len, sequence_len],
+        dtype='float32')
+    labels = paddle.static.InputSpec(name="labels",
+                                     shape=[batch_size, sequence_len],
+                                     dtype='int64')
+    loss_mask = paddle.static.InputSpec(name="loss_mask",
+                                        shape=[batch_size, sequence_len],
+                                        dtype='float32')
+    return [tokens, position_ids, attention_mask], [labels, loss_mask]
+
+
+def get_gpt_model():
+    modeling.init_global()
+    modeling._global_parallel_strategy = "serial"
+    modeling._global_process_mesh = auto.ProcessMesh(mesh=[0])
+
+    gpt = GPTModel(vocab_size=1000,
+                   hidden_size=64,
+                   num_hidden_layers=2,
+                   num_attention_heads=8,
+                   intermediate_size=256,
+                   hidden_act="gelu",
+                   hidden_dropout_prob=0.0,
+                   attention_probs_dropout_prob=0.0,
+                   max_position_embeddings=1024,
+                   type_vocab_size=1,
+                   initializer_range=0.02,
+                   pad_token_id=0,
+                   eos_token_id=7,
+                   bos_token_id=0,
+                   eol_token_id=3)
+    model = GPTForPretraining(gpt,
+                              vocab_size=1000,
+                              hidden_size=64,
+                              initializer_range=0.02)
+    criterion = GPTPretrainingCriterion()
+    return model, criterion
+
+
+class TestQuantizationPass(unittest.TestCase):
+
+    def test_qat_pass(self):
+
+        batch_size = 8
+        batch_num = 10
+        sequence_len = 512
+        vocab_size = 1000
+
+        strategy = apply_pass()
+        model, loss = get_gpt_model()
+        opt = paddle.optimizer.AdamW(learning_rate=0.00001)
+        inputs_spec, labels_spec = create_data_holder(batch_size=batch_size,
+                                                      sequence_len=sequence_len)
+
+        engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
+        engine.prepare(optimizer=opt, loss=loss)
+
+        dataset = FakeDataset(batch_size * batch_num, sequence_len, vocab_size)
+        engine.fit(train_data=dataset, batch_size=batch_size)
+
+        self.check_program(engine.main_program)
+
+    def check_program(self, program):
+
+        quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']}
+        quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']}
+
+        quantized_ops = set()
+        for block in program.blocks:
+            for op in block.ops:
+                is_quantized = False
+                if op.type in quantizable_op_and_inputs:
+                    for arg_name in op.input_arg_names:
+                        if ".quantized" in arg_name:
+                            is_quantized = True
+
+                if not is_quantized:
+                    continue
+
+                # check forward
+                if op.type in quantizable_op_and_inputs:
+                    for arg_name in op.input_arg_names:
+                        assert arg_name.endswith('.quantized.dequantized')
+                        quantized_ops.add(arg_name)
+
+            for op in block.ops:
+                is_quantized = False
+                if op.type in quantizable_grad_op_inputs:
+                    for pname in quantizable_grad_op_inputs[op.type]:
+                        arg_name = op.input(pname)[0]
+                        if ".quantized" in arg_name:
+                            is_quantized = True
+
+                if not is_quantized:
+                    continue
+
+                # check backward
+                if op.type in quantizable_grad_op_inputs:
+                    for pname in quantizable_grad_op_inputs[op.type]:
+                        arg_name = op.input(pname)[0]
+                        assert arg_name.endswith('.quantized.dequantized')
+                        assert arg_name in quantized_ops
+
+
+if __name__ == "__main__":
+    unittest.main()
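[Editor's note, not part of the patch] For reviewers who want to exercise the new strategy end to end, the sketch below condenses the qat_configs docstring example and test_quantization.py into one snippet; model, criterion, inputs_spec, labels_spec and dataset are placeholders assumed to be built exactly as in the unit test above.

    import paddle
    import paddle.distributed.fleet as fleet
    from paddle.distributed.auto_parallel.engine import Engine

    paddle.enable_static()

    strategy = fleet.DistributedStrategy()
    strategy.semi_auto = True
    strategy.qat = True
    strategy.qat_configs = {
        "channel_wise_abs_max": True,
        "weight_bits": 8,
        "activation_bits": 8,
        "not_quant_pattern": ["skip_quant"],
    }

    # model, criterion, inputs_spec, labels_spec and dataset come from the
    # helpers defined in test_quantization.py (placeholders here).
    engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
    engine.prepare(optimizer=paddle.optimizer.AdamW(learning_rate=1e-5),
                   loss=criterion)
    engine.fit(train_data=dataset, batch_size=8)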