# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.distributed.auto_parallel.dist_attribute import (
    OperatorDistributedAttribute,
    TensorDistributedAttribute,
)
from paddle.fluid import core, framework
from paddle.fluid.contrib.slim.quantization import (
    AddQuantDequantPassV2,
    OutScaleForTrainingPass,
    QuantizationTransformPassV2,
    utils,
)
from paddle.fluid.dygraph.parallel import ParallelEnv

from .pass_base import PassBase, register_pass

TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type

# Op types this pass tries to quantize. Each one is routed either to the
# weight pass (QuantizationTransformPassV2) or to the activation pass
# (AddQuantDequantPassV2), depending on which supported-type list holds it.
_QUANTIZE_OP_TYPES = (
    'conv2d',
    'depthwise_conv2d',
    'mul',
    'matmul',
    'matmul_v2',
)

# Type of the op inserted by OutScaleForTrainingPass.
_OUT_SCALE_OP_TYPE = "moving_average_abs_max_scale"

# Suffix the quantization passes append to the name of a quantized var.
_QUANTIZED_SUFFIX = ".quantized"


def _node_id(node):
    """Return a ``(graph_id, node_id)`` pair identifying an IR graph node.

    ``node`` is presumably an ``IrNode``-style wrapper that exposes the raw
    ``core`` node as ``node.node`` -- TODO confirm against callers.
    """
    return (node.node.graph_id(), node.node.id())


def _is_quant_op(op):
    """Return True if ``op`` is treated as an op added by quantization.

    That is any op whose type contains "quantize", or the out-scale op.
    """
    return "quantize" in op.type or op.type == _OUT_SCALE_OP_TYPE


def _unsharded_1d_dist_attr(process_mesh):
    """Return a fresh tensor dist_attr on ``process_mesh`` with mapping [-1].

    Used for the auxiliary (non-``X`` input / non-``Y`` output) tensors of
    quantization ops. The single-entry dims_mapping assumes those tensors
    are 1-D.
    """
    dist_attr = TensorDistributedAttribute()
    dist_attr.process_mesh = process_mesh
    dist_attr.dims_mapping = [-1]
    return dist_attr


def _remap_params_grads(params_grads, program):
    """Re-resolve each ``(param, grad)`` pair by name in ``program``.

    Lookups use ``program``'s global block. Pairs whose parameter no longer
    exists there are dropped. A missing grad raises ``KeyError``.
    """
    block_vars = program.global_block().vars
    return [
        (block_vars[param.name], block_vars[grad.name])
        for param, grad in params_grads
        if param.name in block_vars
    ]


def _build_quant_op_dist_attr(
    quant_op, origin_idx, quant_block, origin_block, dist_context
):
    """Build the dist_attr of a quantization op from the op it serves.

    Also registers tensor dist_attrs for the op's auxiliary inputs/outputs.

    Args:
        quant_op: the quantization op in the quantized program.
        origin_idx: index into ``origin_block.ops`` of the original op that
            follows ``quant_op`` (used as its consumer).
        quant_block: block of the quantized program containing ``quant_op``.
        origin_block: the matching block of the original program.
        dist_context: distributed context to read from and write into.

    Returns:
        The new ``OperatorDistributedAttribute`` (not yet registered).
    """
    x_name = quant_op.desc.input('X')[0]
    # Dequantize ops read "<var>.quantized". Strip the suffix to get the
    # name known to the original program. partition() leaves names without
    # the suffix untouched instead of raising.
    input_name = x_name.partition(_QUANTIZED_SUFFIX)[0]

    if quant_op.type == _OUT_SCALE_OP_TYPE or origin_idx >= len(
        origin_block.ops
    ):
        # The out-scale op observes an output, so its reference op is the
        # var's producer. This is also the fallback when no original op
        # follows the quantization op.
        consume_op = origin_block._var_recursive(input_name).op
    else:
        consume_op = origin_block.ops[origin_idx]

    dist_consume_op = dist_context.get_dist_op_for_program(consume_op)
    assert dist_consume_op is not None, "consume op must have dist attr."
    consume_op_dist_attr = dist_consume_op.dist_attr
    ref_process_mesh = consume_op_dist_attr.process_mesh

    consume_outputs = consume_op_dist_attr.outputs_dist_attrs
    if input_name in consume_outputs:
        consume_input_dist_attr = consume_outputs[input_name]
    else:
        consume_input_dist_attr = consume_op_dist_attr.inputs_dist_attrs[
            input_name
        ]

    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.impl_idx = 0
    op_dist_attr.impl_type = "default"
    op_dist_attr.process_mesh = ref_process_mesh
    # X shares the dist_attr of the tensor it quantizes.
    op_dist_attr.set_input_dist_attr(x_name, consume_input_dist_attr)

    for slot_name in quant_op.desc.input_names():
        if slot_name == "X":
            continue
        for in_name in quant_op.desc.input(slot_name):
            input_var = quant_block.vars[in_name]
            tensor_dist_attr = _unsharded_1d_dist_attr(ref_process_mesh)
            dist_context.set_tensor_dist_attr_for_program(
                input_var, tensor_dist_attr
            )
            op_dist_attr.set_input_dist_attr(in_name, tensor_dist_attr)

    for slot_name in quant_op.desc.output_names():
        output_name = quant_op.desc.output(slot_name)[0]
        output_var = quant_block.vars[output_name]
        if slot_name == "Y":
            # Y replaces X downstream, so it inherits X's dist_attr.
            tensor_dist_attr = consume_input_dist_attr
        else:
            tensor_dist_attr = _unsharded_1d_dist_attr(ref_process_mesh)
        dist_context.set_tensor_dist_attr_for_program(
            output_var, tensor_dist_attr
        )
        op_dist_attr.set_output_dist_attr(output_name, tensor_dist_attr)

    quant_op._set_attr("op_device", "")
    return op_dist_attr


def _copy_tensor_dist_attr_if_renamed(
    name, origin_name, quant_block, origin_block, dist_context
):
    """Give var ``name`` the tensor dist_attr of original var ``origin_name``.

    This only happens when ``name`` does not exist in ``origin_block``. For
    example, an input the quantization passes replaced with a new var.
    """
    if name in origin_block.vars:
        return
    origin_var = origin_block.vars[origin_name]
    origin_dist_attr = dist_context.get_dist_tensor_for_program(
        origin_var
    ).dist_attr
    dist_context.set_tensor_dist_attr_for_program(
        quant_block.vars[name], origin_dist_attr
    )


def _build_origin_op_dist_attr(
    quant_op, origin_op, quant_block, origin_block, dist_context
):
    """Copy ``origin_op``'s dist_attr onto its counterpart ``quant_op``.

    Input/output arguments are matched by position.

    Returns:
        The new ``OperatorDistributedAttribute`` (not yet registered).

    Raises:
        AssertionError: if ``origin_op`` has no dist attr.
    """
    quant_op.desc.set_original_id(origin_op.desc.original_id())
    dist_origin_op = dist_context.get_dist_op_for_program(origin_op)
    assert dist_origin_op is not None, "origin op must have dist attr."
    origin_op_dist_attr = dist_origin_op.dist_attr

    op_dist_attr = OperatorDistributedAttribute()
    op_dist_attr.impl_idx = origin_op_dist_attr.impl_idx
    op_dist_attr.impl_type = origin_op_dist_attr.impl_type
    op_dist_attr.process_mesh = origin_op_dist_attr.process_mesh

    origin_input_names = origin_op.input_arg_names
    for idx, input_name in enumerate(quant_op.input_arg_names):
        origin_input_name = origin_input_names[idx]
        op_dist_attr.set_input_dist_attr(
            input_name,
            origin_op_dist_attr.inputs_dist_attrs[origin_input_name],
        )
        _copy_tensor_dist_attr_if_renamed(
            input_name,
            origin_input_name,
            quant_block,
            origin_block,
            dist_context,
        )

    origin_output_names = origin_op.output_arg_names
    for idx, output_name in enumerate(quant_op.output_arg_names):
        origin_output_name = origin_output_names[idx]
        op_dist_attr.set_output_dist_attr(
            output_name,
            origin_op_dist_attr.outputs_dist_attrs[origin_output_name],
        )
        _copy_tensor_dist_attr_if_renamed(
            output_name,
            origin_output_name,
            quant_block,
            origin_block,
            dist_context,
        )

    return op_dist_attr


def _recover_var_dist_attrs(quant_block, origin_block, dist_context):
    """Copy tensor dist_attrs for vars present in both blocks.

    Vars whose original has no dist tensor are skipped.
    """
    for name, dst_var in quant_block.vars.items():
        if name not in origin_block.vars:
            continue
        dist_tensor = dist_context.get_dist_tensor_for_program(
            origin_block.vars[name]
        )
        if not dist_tensor:
            continue
        dist_context.set_tensor_dist_attr_for_program(
            dst_var, dist_tensor.dist_attr
        )


def _complete_dist_attrs(main_program, quant_program, dist_context):
    """Register dist_attrs for every op and var of ``quant_program``.

    NOTE: hack implement, upgrading soon. Ops are matched to
    ``main_program`` by position. Block ``ib`` of ``quant_program`` is
    assumed to be block ``ib`` of ``main_program`` with quantization ops
    inserted and all other ops kept in the same relative order.
    """
    for ib, quant_block in enumerate(quant_program.blocks):
        origin_block = main_program.blocks[ib]
        # Count of quantization ops seen so far in this block.
        # ``ip - qat_offset`` is therefore the index of the corresponding
        # (or, for a quantization op, the following) original op.
        qat_offset = 0
        for ip, quant_op in enumerate(quant_block.ops):
            if _is_quant_op(quant_op):
                op_dist_attr = _build_quant_op_dist_attr(
                    quant_op,
                    ip - qat_offset,
                    quant_block,
                    origin_block,
                    dist_context,
                )
                qat_offset += 1
            else:
                op_dist_attr = _build_origin_op_dist_attr(
                    quant_op,
                    origin_block.ops[ip - qat_offset],
                    quant_block,
                    origin_block,
                    dist_context,
                )
            dist_context.set_op_dist_attr_for_program(quant_op, op_dist_attr)

        # Recover vars' dist_attr.
        _recover_var_dist_attrs(quant_block, origin_block, dist_context)


@register_pass("auto_parallel_quantization")
class QuantizationPass(PassBase):
    """Insert training-mode quantization ops into an auto-parallel program.

    Distributed attributes are then propagated to the inserted ops and vars.

    Required attrs: ``dist_context`` and ``params_grads``. The pass also
    reads ``weight_bits``, ``activation_bits``, ``not_quant_pattern`` and
    ``channel_wise_abs_max``. On success it stores ``main_program``,
    ``startup_program`` and ``params_grads`` into the pass context.
    """

    def __init__(self):
        super().__init__()
        self.set_attr("dist_context", None)
        self.set_attr("params_grads", None)

    def _check_self(self):
        # Both attrs must be supplied before the pass can run.
        return (
            self.get_attr("dist_context") is not None
            and self.get_attr("params_grads") is not None
        )

    def _check_conflict(self, other_pass):
        # No conflict is declared with any other pass.
        return True

    def _apply_single_impl(self, main_program, startup_program, context):
        dist_context = self.get_attr("dist_context")
        params_grads = self.get_attr("params_grads")

        # TODO: scope and place will be removed, because params should be
        # initialized by the engine module.
        # NOTE(review): the place is hard-coded to CUDA.
        scope = paddle.static.global_scope()
        place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)

        # 1. Convert Program to Graph; this pass is only for train mode.
        main_graph = framework.IrGraph(
            core.Graph(main_program.desc), for_test=False
        )

        # 2-5. Insert quantization ops into the graph.
        self._insert_quant_ops(main_graph, scope, place)

        # 6. Convert Graph back to Program.
        quant_program = main_graph.to_program()

        # 7. Get new params_grads from quant_program.
        new_params_grads = _remap_params_grads(params_grads, quant_program)

        # 8. Complete distributed attributes.
        _complete_dist_attrs(main_program, quant_program, dist_context)

        context.set_attr("main_program", quant_program)
        context.set_attr("startup_program", startup_program)
        context.set_attr("params_grads", new_params_grads)

    def _insert_quant_ops(self, main_graph, scope, place):
        """Apply the weight, activation and out-scale passes to the graph."""
        # 2. Prepare inputs: route each op type to the pass that handles it.
        transform_pass_ops = []
        quant_dequant_ops = []
        for op_type in _QUANTIZE_OP_TYPES:
            if op_type in TRANSFORM_PASS_OP_TYPES:
                transform_pass_ops.append(op_type)
            elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
                quant_dequant_ops.append(op_type)

        weight_quantize_type = (
            "channel_wise_abs_max"
            if self.get_attr('channel_wise_abs_max')
            else "abs_max"
        )

        # 3. Add quant ops for ops which have parameters.
        transform_pass = QuantizationTransformPassV2(
            scope=scope,
            place=place,
            weight_bits=self.get_attr('weight_bits'),
            activation_bits=self.get_attr('activation_bits'),
            skip_pattern=self.get_attr('not_quant_pattern'),
            activation_quantize_type="moving_average_abs_max",
            quantizable_op_type=transform_pass_ops,
            weight_quantize_type=weight_quantize_type,
            weight_quantize_func=None,
            act_quantize_func=None,
            weight_preprocess_func=None,
            act_preprocess_func=None,
            optimizer_func=None,
            executor=None,
        )
        transform_pass.apply(main_graph)

        # 4. Add quant ops for ops which don't have parameters.
        quant_dequant_pass = AddQuantDequantPassV2(
            scope=scope,
            place=place,
            quant_bits=self.get_attr('activation_bits'),
            skip_pattern=self.get_attr('not_quant_pattern'),
            quantizable_op_type=quant_dequant_ops,
        )
        quant_dequant_pass.apply(main_graph)

        # 5. Gather quantization scale information for the outputs.
        out_scale_training_pass = OutScaleForTrainingPass(
            scope=scope, place=place
        )
        out_scale_training_pass.apply(main_graph)