Unverified commit 61bc016c, authored by zhaoyingli, committed by GitHub

[AutoParallel] Add Quant Pass (#44877)

* add quant pass
Parent 9ccdb5fa
......@@ -184,6 +184,14 @@ message TensorParallelConfig {
optional int32 tensor_init_seed = 2 [ default = -1 ];
}
message QatConfig {
optional bool channel_wise_abs_max = 1 [default = true];
optional int32 weight_bits = 2 [default = 8];
optional int32 activation_bits = 3 [default = 8];
repeated string not_quant_pattern = 4;
optional string algo = 5;
}
enum TableType {
PS_SPARSE_TABLE = 0;
PS_DENSE_TABLE = 1;
......@@ -327,6 +335,7 @@ message DistributedStrategy {
optional bool heter_ccl_mode = 38 [ default = false ];
optional bool is_fl_ps_mode = 39 [ default = false ];
optional bool with_coordinator = 40 [ default = false ];
optional bool qat = 41 [ default = false ];
optional RecomputeConfig recompute_configs = 101;
optional AMPConfig amp_configs = 102;
......@@ -344,6 +353,7 @@ message DistributedStrategy {
optional TrainerDescConfig trainer_desc_configs = 114;
repeated TableParameter downpour_table_param = 115;
optional FsClientParameter fs_client_param = 116;
optional QatConfig qat_configs = 117;
optional BuildStrategy build_strategy = 201;
optional ExecutionStrategy execution_strategy = 202;
......
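For reference, a minimal sketch of how the new QatConfig fields above are filled in from user code, using the fleet Python API added later in this change (the dict keys mirror the proto field names; unset keys fall back to the defaults declared above):

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.qat = True
    strategy.qat_configs = {
        "channel_wise_abs_max": True,
        "weight_bits": 8,
        "activation_bits": 8,
        "not_quant_pattern": ["skip_quant"],
    }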
......@@ -940,6 +940,12 @@ class Completer:
core.op_proto_and_checker_maker.OpRole.Forward):
appended_grad_times += 1
if int(op.attr('op_role')) == int(
int(core.op_proto_and_checker_maker.OpRole.Backward)
| int(core.op_proto_and_checker_maker.OpRole.Loss)):
assert op.type == "fill_constant"
break
# complete the annotation of grad op (xxx_grad op or sum op)
# xxx_grad op will have a corresponding forward op in grad_op_id_to_op_id
grad_op = ops[idx]
......
......@@ -245,6 +245,8 @@ def is_parameter_related(varname, block):
varname = varname[:varname.index(".subprog_")]
if ".cast_fp" in varname:
varname = varname[:varname.index(".cast_fp")]
if ".quantized" in varname:
varname = varname[:varname.index(".quantized")]
assert block.has_var(varname)
var = block.var(varname)
return var.is_parameter
......
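A small illustration of the suffix handling added above; the variable name is hypothetical, following the '.quantized.dequantized' naming convention produced by the quant pass:

    # the quant pass rewrites inputs of quantizable ops to names like
    # "<param>.quantized.dequantized"; stripping at ".quantized" recovers the
    # original parameter name so it can be looked up in the block.
    varname = "linear_0.w_0.quantized.dequantized"  # hypothetical variable name
    if ".quantized" in varname:
        varname = varname[:varname.index(".quantized")]
    assert varname == "linear_0.w_0"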
......@@ -66,9 +66,9 @@ class Parallelizer:
serial_loss)
# Apply pre optimization passes
time0 = time.time()
self._apply_pre_optimization(serial_main_program,
serial_startup_program, serial_loss,
serial_optimizer, params_grads)
serial_main_program, serial_startup_program, params_grads = self._apply_pre_optimization(
serial_main_program, serial_startup_program, serial_loss,
serial_optimizer, params_grads)
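# NOTE: _apply_pre_optimization now returns the programs because the
# quantization pass rebuilds the main program via an IrGraph round trip,
# so the caller must rebind to the new objects.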
self._logger.info(
"within parallel apply_pre_optimization time: {}, mode {}".
format(time.time() - time0, self._mode))
......@@ -162,6 +162,22 @@ class Parallelizer:
optimizer, params_grads):
if self._strategy is None:
return
# apply quantization pass
# The pass can only be applied when the mode is 'train'
if self._mode == 'train' and self._strategy.qat:
config = copy.deepcopy(self._strategy.qat_configs)
config["dist_context"] = self._dist_context
config["params_grads"] = params_grads
auto_parallel_quantization_pass = new_pass(
"auto_parallel_quantization", config)
auto_parallel_quantization_pass.apply([main_program],
[startup_program],
self._pass_context)
main_program = self._pass_context.get_attr("main_program")
startup_program = self._pass_context.get_attr("startup_program")
params_grads = self._pass_context.get_attr("params_grads")
# apply amp pass
# FIXME: we disable amp for eval since it has a small bug with the
# eval program, which will be fixed in the future
......@@ -195,6 +211,8 @@ class Parallelizer:
[startup_program],
self._pass_context)
return main_program, startup_program, params_grads
def _apply_post_optimization(self, main_program, startup_program, rank,
params_grads):
if self._strategy is None:
......
......@@ -685,7 +685,8 @@ class Remover:
block._remove_op(idx)
@staticmethod
def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads):
def remove_no_need_vars(auto_parallel_main_prog, dist_params_grads,
feed_var_names):
"""Remove no need vars in the main program"""
for block_idx, block in enumerate(auto_parallel_main_prog.blocks):
remove_vars = set()
......@@ -731,7 +732,7 @@ class Remover:
idx += 1
for var in remove_vars:
if block.vars[var].is_data:
if var in feed_var_names:
continue
block._remove_var(var)
......@@ -743,7 +744,12 @@ class Remover:
rank_id)
Resharder.change_while_op_input_and_output(auto_parallel_main_prog,
dist_context)
Remover.remove_no_need_vars(auto_parallel_main_prog, dist_params_grads)
# vars in 'feed_var_names' must not be removed from auto_parallel_main_prog
feed_var_names = []
for var in sum(list(dist_context.serial_feed_vars.values()), []):
feed_var_names.append(var.name)
Remover.remove_no_need_vars(auto_parallel_main_prog, dist_params_grads,
feed_var_names)
@staticmethod
def remove_no_need_in_startup(auto_parallel_main_prog,
......
......@@ -1991,6 +1991,60 @@ class DistributedStrategy(object):
else:
print("WARNING: auto-search should have value of bool type")
@property
def qat(self):
"""
Indicating whether we are using quantization-aware training (QAT)
Default Value: False
"""
return self.strategy.qat
@qat.setter
def qat(self, flag):
if isinstance(flag, bool):
self.strategy.qat = flag
else:
print("WARNING: qat should have value of bool type")
@property
def qat_configs(self):
"""
Set quantization-aware training configurations. In general, qat has several
settings that can be configured through a dict.
**Notes**:
channel_wise_abs_max(bool): Whether to use `per_channel` quantization training. Default is True.
weight_bits(int): quantization bit number for weight. Default is 8.
activation_bits(int): quantization bit number for activation. Default is 8.
not_quant_pattern(list[str]): When the skip pattern is detected in an op's name scope,
the corresponding op will not be quantized.
algo(str): The quantization training algorithm to use.
Examples:
.. code-block:: python
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.qat = True
strategy.qat_configs = {
"channel_wise_abs_max": True,
"weight_bits": 8,
"activation_bits: 8,
"not_quant_pattern": ['skip_quant']}
"""
return get_msg_dict(self.strategy.qat_configs)
@qat_configs.setter
def qat_configs(self, configs):
check_configs_key(self.strategy.qat_configs, configs, "qat_configs")
assign_configs_value(self.strategy.qat_configs, configs)
@property
def heter_ccl_mode(self):
"""
......
......@@ -19,6 +19,7 @@ from .auto_parallel_sharding import *
from .auto_parallel_amp import *
from .auto_parallel_fp16 import *
from .auto_parallel_recompute import *
from .auto_parallel_quantization import *
from .auto_parallel_data_parallel_optimization import *
from .cpp_pass import *
import os
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
from paddle.fluid import core, framework
from paddle.fluid.dygraph.parallel import ParallelEnv
from paddle.fluid.contrib.slim.quantization import utils
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPassV2
from paddle.fluid.contrib.slim.quantization import AddQuantDequantPassV2
from paddle.fluid.contrib.slim.quantization import OutScaleForTrainingPass
from paddle.distributed.auto_parallel.dist_attribute import OperatorDistributedAttribute, TensorDistributedAttribute
from .pass_base import PassBase, register_pass
TRANSFORM_PASS_OP_TYPES = utils._weight_supported_quantizable_op_type
QUANT_DEQUANT_PASS_OP_TYPES = utils._act_supported_quantizable_op_type
def _node_id(node):
return (node.node.graph_id(), node.node.id())
@register_pass("auto_parallel_quantization")
class QuantizationPass(PassBase):
def __init__(self):
super(QuantizationPass, self).__init__()
self.set_attr("dist_context", None)
self.set_attr("params_grads", None)
def _check_self(self):
if self.get_attr("dist_context") is None:
return False
if self.get_attr("params_grads") is None:
return False
return True
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_program, startup_program, context):
dist_context = self.get_attr("dist_context")
params_grads = self.get_attr("params_grads")
# TODO: scope and place will be removed,
# because params should be initialized by the engine module.
scope = paddle.static.global_scope()
place = paddle.fluid.CUDAPlace(ParallelEnv().dev_id)
# 1. Convert the Program to a Graph; this pass is only for train mode
main_graph = framework.IrGraph(core.Graph(main_program.desc),
for_test=False)
# 2. Prepare inputs
transform_pass_ops = []
quant_dequant_ops = []
quantize_op_types = [
'conv2d', 'depthwise_conv2d', 'mul', 'matmul', 'matmul_v2'
]
for op_type in quantize_op_types:
if op_type in TRANSFORM_PASS_OP_TYPES:
transform_pass_ops.append(op_type)
elif op_type in QUANT_DEQUANT_PASS_OP_TYPES:
quant_dequant_ops.append(op_type)
weight_quantize_type = "channel_wise_abs_max" if self.get_attr(
'channel_wise_abs_max') else "abs_max"
# 3. Add quant op for ops which have parameters
transform_pass = QuantizationTransformPassV2(
scope=scope,
place=place,
weight_bits=self.get_attr('weight_bits'),
activation_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
activation_quantize_type="moving_average_abs_max",
quantizable_op_type=transform_pass_ops,
weight_quantize_type=weight_quantize_type,
weight_quantize_func=None,
act_quantize_func=None,
weight_preprocess_func=None,
act_preprocess_func=None,
optimizer_func=None,
executor=None)
transform_pass.apply(main_graph)
# 4. Add quant op for ops which don't have parameters
quant_dequant_pass = AddQuantDequantPassV2(
scope=scope,
place=place,
quant_bits=self.get_attr('activation_bits'),
skip_pattern=self.get_attr('not_quant_pattern'),
quantizable_op_type=quant_dequant_ops)
quant_dequant_pass.apply(main_graph)
# 5. Gather quantization scale information for the outputs
out_scale_training_pass = OutScaleForTrainingPass(scope=scope,
place=place)
out_scale_training_pass.apply(main_graph)
# 6. Convert Graph back to Program
quant_program = main_graph.to_program()
# 7. get new params_grads from quant_program
new_params_grads = []
for param, grad in params_grads:
if param.name not in quant_program.global_block().vars:
continue
new_param = quant_program.global_block().vars[param.name]
new_grad = quant_program.global_block().vars[grad.name]
new_params_grads.append((new_param, new_grad))
# 8. complete the distributed attributes
# NOTE: hacky implementation, to be improved soon
for ib, block in enumerate(quant_program.blocks):
# recover origin ops' dist_attr and set quant ops' dist_attr
qat_offset = 0
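# qat_offset counts the quant/dequant ops inserted so far in this block,
# so `ip - qat_offset` indexes the corresponding op in the original
# main_program.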
for ip, quant_op in enumerate(block.ops):
quant_op_dist_attr = OperatorDistributedAttribute()
if "quantize" in quant_op.type or \
quant_op.type == "moving_average_abs_max_scale":
input_name = quant_op.desc.input('X')[0]
if "quantize" in input_name:
input_name = input_name[:input_name.index(".quantized")]
if quant_op.type == "moving_average_abs_max_scale":
consume_op = main_program.blocks[ib].vars[input_name].op
else:
consume_op = main_program.blocks[ib].ops[ip -
qat_offset]
consume_op_dist_attr = dist_context.get_dist_op_for_program(
consume_op).dist_attr
ref_process_mesh = consume_op_dist_attr.process_mesh
if input_name in consume_op_dist_attr.outputs_dist_attrs:
consume_input_dist_attr = consume_op_dist_attr.outputs_dist_attrs[
input_name]
else:
consume_input_dist_attr = consume_op_dist_attr.inputs_dist_attrs[
input_name]
quant_op_dist_attr.impl_idx = 0
quant_op_dist_attr.impl_type = "default"
quant_op_dist_attr.process_mesh = ref_process_mesh
quant_op_dist_attr.set_input_dist_attr(
quant_op.desc.input('X')[0], consume_input_dist_attr)
for slot_name in quant_op.desc.input_names():
if slot_name == "X":
continue
for in_name in quant_op.desc.input(slot_name):
input_var = block.vars[in_name]
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
input_var, tensor_dist_attr)
quant_op_dist_attr.set_input_dist_attr(
in_name, tensor_dist_attr)
for slot_name in quant_op.desc.output_names():
output_name = quant_op.desc.output(slot_name)[0]
output_var = block.vars[output_name]
if slot_name == "Y":
dist_context.set_tensor_dist_attr_for_program(
output_var, consume_input_dist_attr)
quant_op_dist_attr.set_output_dist_attr(
output_name, consume_input_dist_attr)
else:
tensor_dist_attr = TensorDistributedAttribute()
tensor_dist_attr.process_mesh = ref_process_mesh
tensor_dist_attr.dims_mapping = [-1]
dist_context.set_tensor_dist_attr_for_program(
output_var, tensor_dist_attr)
quant_op_dist_attr.set_output_dist_attr(
output_name, tensor_dist_attr)
quant_op._set_attr("op_device", "")
qat_offset += 1
else:
origin_op = main_program.blocks[ib].ops[ip - qat_offset]
quant_op.desc.set_original_id(origin_op.desc.original_id())
dist_origin_op = dist_context.get_dist_op_for_program(
origin_op)
assert dist_origin_op is not None, "origin op must have dist attr."
origin_op_dist_attr = dist_origin_op.dist_attr
quant_op_dist_attr.impl_idx = origin_op_dist_attr.impl_idx
quant_op_dist_attr.impl_type = origin_op_dist_attr.impl_type
quant_op_dist_attr.process_mesh = origin_op_dist_attr.process_mesh
for idx, input_name in enumerate(quant_op.input_arg_names):
origin_input_name = origin_op.input_arg_names[idx]
origin_input_dist_attr = origin_op_dist_attr.inputs_dist_attrs[
origin_input_name]
quant_op_dist_attr.set_input_dist_attr(
input_name, origin_input_dist_attr)
if input_name not in main_program.blocks[ib].vars:
origin_input_var = main_program.blocks[ib].vars[
origin_input_name]
origin_in_tensor_dist_attr = dist_context.get_dist_tensor_for_program(
origin_input_var).dist_attr
quant_input_var = block.vars[input_name]
dist_context.set_tensor_dist_attr_for_program(
quant_input_var, origin_in_tensor_dist_attr)
for idx, output_name in enumerate(
quant_op.output_arg_names):
origin_output_name = origin_op.output_arg_names[idx]
origin_output_dist_attr = origin_op_dist_attr.outputs_dist_attrs[
origin_output_name]
quant_op_dist_attr.set_output_dist_attr(
output_name, origin_output_dist_attr)
if output_name not in main_program.blocks[ib].vars:
origin_output_var = main_program.blocks[ib].vars[
origin_output_name]
origin_out_tensor_dist_attr = dist_context.get_dist_tensor_for_program(
origin_output_var).dist_attr
quant_output_var = block.vars[output_name]
dist_context.set_tensor_dist_attr_for_program(
quant_output_var, origin_out_tensor_dist_attr)
dist_context.set_op_dist_attr_for_program(
quant_op, quant_op_dist_attr)
# recover vars' dist_attr
for name, dst_var in block.vars.items():
if name in main_program.blocks[ib].vars:
src_var = main_program.blocks[ib].vars[name]
dist_tensor = dist_context.get_dist_tensor_for_program(
src_var)
if not dist_tensor:
continue
dist_context.set_tensor_dist_attr_for_program(
dst_var, dist_tensor.dist_attr)
context.set_attr("main_program", quant_program)
context.set_attr("startup_program", startup_program)
context.set_attr("params_grads", new_params_grads)
......@@ -236,14 +236,14 @@ class RecomputePass(PassBase):
def _check_conflict(self, other_pass):
return True
def _apply_single_impl(self, main_programs, startup_programs, context):
def _apply_single_impl(self, main_program, startup_program, context):
checkpoints = self.get_attr("checkpoints")
loss = self.get_attr("loss")
no_grad_set = self.get_attr("no_grad_set")
self._dist_context = self.get_attr("dist_context")
main_block = main_programs.global_block()
no_grad_set_name = _get_stop_gradients(main_programs, no_grad_set)
main_block = main_program.global_block()
no_grad_set_name = _get_stop_gradients(main_program, no_grad_set)
# get op_path which is related to loss
op_path = _find_op_path_(main_block, [loss], [], no_grad_set_name)
......@@ -373,7 +373,7 @@ class RecomputePass(PassBase):
ckpt_ops_dict[fwd_op_id][0] = False
main_block._sync_with_cpp()
main_programs._sync_with_cpp()
main_program._sync_with_cpp()
def reset_op_dist_attr(self, op, var_name_dict):
op_dist_attr = self._dist_context.get_op_dist_attr_for_program(op)
......
......@@ -25,7 +25,8 @@ from .post_training_quantization import *
from . import imperative
from .imperative import *
__all__ = quantization_pass.__all__
__all__ = []
__all__ += quantization_pass.__all__
__all__ += quant_int8_mkldnn_pass.__all__
__all__ += quant2_int8_mkldnn_pass.__all__
__all__ += post_training_quantization.__all__
......
......@@ -42,6 +42,17 @@ _logger = get_logger(__name__,
fmt='%(asctime)s-%(levelname)s: %(message)s')
def lazy_import_fleet(layer_name_map, fake_quant_input_layers):
from paddle.distributed import fleet
layer_name_map[
'ColumnParallelLinear'] = fleet.meta_parallel.parallel_layers.mp_layers.ColumnParallelLinear
layer_name_map[
'RowParallelLinear'] = fleet.meta_parallel.parallel_layers.mp_layers.RowParallelLinear
fake_quant_input_layers.append(fleet.meta_parallel.RowParallelLinear)
fake_quant_input_layers.append(fleet.meta_parallel.ColumnParallelLinear)
return layer_name_map, fake_quant_input_layers
class ImperativeQuantAware(object):
"""
Applying quantization aware training (QAT) to the dygraph model.
......@@ -300,13 +311,15 @@ class ImperativeQuantizeInputs(object):
Please refer to the args of ImperativeQuantAware.
"""
super(ImperativeQuantizeInputs, self).__init__()
self.layer_name_map, self.fake_quant_input_layers = lazy_import_fleet(
utils.layer_name_map, utils.fake_quant_input_layers)
self._quantizable_layer_type = tuple(
utils.layer_name_map[layer] if layer in
utils.layer_name_map else layer for layer in quantizable_layer_type)
self.layer_name_map[layer] if layer in
self.layer_name_map else layer for layer in quantizable_layer_type)
for layer in self._quantizable_layer_type:
assert not isinstance(layer, str) \
and layer in utils.fake_quant_input_layers, \
and layer in self.fake_quant_input_layers, \
"%s is unspported to be quantized." % layer
quantize_type = {
......@@ -383,7 +396,7 @@ class ImperativeQuantizeInputs(object):
def _get_input_quantized_layer(self, layer):
quant_layer_name = None
for key, value in utils.layer_name_map.items():
for key, value in self.layer_name_map.items():
if isinstance(layer, value):
quant_layer_name = 'Quantized' + key
break
......
......@@ -16,63 +16,38 @@ import math
import numpy as np
import paddle
from paddle.distributed import fleet
import paddle.nn.quant.quant_layers as quant_layers
from ..utils import _get_op_input_var_names, _get_op_output_var_names, _get_output_name_index, _get_input_name_index
layer_name_map = {
'Conv2DTranspose':
paddle.nn.Conv2DTranspose,
'Conv2D':
paddle.nn.Conv2D,
'Linear':
paddle.nn.Linear,
'AdaptiveAvgPool2D':
paddle.nn.AdaptiveAvgPool2D,
'AdaptiveMaxPool2D':
paddle.nn.AdaptiveMaxPool2D,
'AvgPool2D':
paddle.nn.AvgPool2D,
'MaxPool2D':
paddle.nn.MaxPool2D,
'Hardswish':
paddle.nn.Hardswish,
'LeakyReLU':
paddle.nn.LeakyReLU,
'PReLU':
paddle.nn.PReLU,
'ReLU':
paddle.nn.ReLU,
'ReLU6':
paddle.nn.ReLU6,
'Sigmoid':
paddle.nn.Sigmoid,
'Softmax':
paddle.nn.Softmax,
'Swish':
paddle.nn.Swish,
'Tanh':
paddle.nn.Tanh,
'Hardswish':
paddle.nn.Hardswish,
'BatchNorm':
paddle.nn.BatchNorm,
'GroupNorm':
paddle.nn.GroupNorm,
'LayerNorm':
paddle.nn.LayerNorm,
'ColumnParallelLinear':
fleet.meta_parallel.parallel_layers.mp_layers.ColumnParallelLinear,
'RowParallelLinear':
fleet.meta_parallel.parallel_layers.mp_layers.RowParallelLinear
'Conv2DTranspose': paddle.nn.Conv2DTranspose,
'Conv2D': paddle.nn.Conv2D,
'Linear': paddle.nn.Linear,
'AdaptiveAvgPool2D': paddle.nn.AdaptiveAvgPool2D,
'AdaptiveMaxPool2D': paddle.nn.AdaptiveMaxPool2D,
'AvgPool2D': paddle.nn.AvgPool2D,
'MaxPool2D': paddle.nn.MaxPool2D,
'Hardswish': paddle.nn.Hardswish,
'LeakyReLU': paddle.nn.LeakyReLU,
'PReLU': paddle.nn.PReLU,
'ReLU': paddle.nn.ReLU,
'ReLU6': paddle.nn.ReLU6,
'Sigmoid': paddle.nn.Sigmoid,
'Softmax': paddle.nn.Softmax,
'Swish': paddle.nn.Swish,
'Tanh': paddle.nn.Tanh,
'Hardswish': paddle.nn.Hardswish,
'BatchNorm': paddle.nn.BatchNorm,
'GroupNorm': paddle.nn.GroupNorm,
'LayerNorm': paddle.nn.LayerNorm,
}
# Apply fake quant for the inputs of these layers
fake_quant_input_layers = [
paddle.nn.Conv2D, paddle.nn.Linear, paddle.nn.Conv2DTranspose,
fleet.meta_parallel.RowParallelLinear,
fleet.meta_parallel.ColumnParallelLinear
paddle.nn.Conv2D,
paddle.nn.Linear,
paddle.nn.Conv2DTranspose,
]
# Apply fake quant for the output of these layers
......
......@@ -65,4 +65,5 @@ if(WITH_DISTRIBUTE AND WITH_GPU)
py_test_modules(test_process_mesh_v2 MODULES test_process_mesh_v2)
py_test_modules(test_dist_attr_v2 MODULES test_dist_attr_v2)
py_test_modules(test_lr_grad_clip MODULES test_lr_grad_clip)
py_test_modules(test_quantization MODULES test_quantization)
endif()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import sys
import numpy as np
import paddle
import paddle.distributed.fleet as fleet
import paddle.distributed.auto_parallel as auto
from paddle.distributed.auto_parallel.engine import Engine
from paddle.distributed.auto_parallel.utils import print_program_with_dist_attr
sys.path.append("..")
import auto_parallel_gpt_model as modeling
from auto_parallel_gpt_model import GPTModel, GPTForPretraining, GPTPretrainingCriterion
paddle.enable_static()
class FakeDataset:
def __init__(self, num_samples, sequence_len, vocab_size):
self.num_samples = num_samples
self.sequence_len = sequence_len
self.vocab_size = vocab_size
def __getitem__(self, idx):
tokens = np.random.randint(self.vocab_size, size=self.sequence_len)
position_ids = np.arange(self.sequence_len)
attention_mask = np.tril(np.ones(self.sequence_len)).reshape(
(1, self.sequence_len, self.sequence_len)).astype(np.float32)
labels = np.random.randint(self.vocab_size, size=self.sequence_len)
loss_mask = np.ones(self.sequence_len).astype(np.float32)
return tokens, position_ids, attention_mask, labels, loss_mask
def __len__(self):
return self.num_samples
def apply_pass():
dist_strategy = fleet.DistributedStrategy()
dist_strategy.semi_auto = True
dist_strategy.qat = True
dist_strategy.qat_configs = {
'channel_wise_abs_max': True,
'weight_bits': 8,
'activation_bits': 8,
'not_quant_pattern': ['skip_quant'],
}
return dist_strategy
def create_data_holder(batch_size, sequence_len):
tokens = paddle.static.InputSpec(name="tokens",
shape=[batch_size, sequence_len],
dtype='int64')
position_ids = paddle.static.InputSpec(name="position_ids",
shape=[batch_size, sequence_len],
dtype='int64')
attention_mask = paddle.static.InputSpec(
name="attention_mask",
shape=[batch_size, 1, sequence_len, sequence_len],
dtype='float32')
labels = paddle.static.InputSpec(name="labels",
shape=[batch_size, sequence_len],
dtype='int64')
loss_mask = paddle.static.InputSpec(name="loss_mask",
shape=[batch_size, sequence_len],
dtype='float32')
return [tokens, position_ids, attention_mask], [labels, loss_mask]
def get_gpt_model():
modeling.init_global()
modeling._global_parallel_strategy = "serial"
modeling._global_process_mesh = auto.ProcessMesh(mesh=[0])
gpt = GPTModel(vocab_size=1000,
hidden_size=64,
num_hidden_layers=2,
num_attention_heads=8,
intermediate_size=256,
hidden_act="gelu",
hidden_dropout_prob=0.0,
attention_probs_dropout_prob=0.0,
max_position_embeddings=1024,
type_vocab_size=1,
initializer_range=0.02,
pad_token_id=0,
eos_token_id=7,
bos_token_id=0,
eol_token_id=3)
model = GPTForPretraining(gpt,
vocab_size=1000,
hidden_size=64,
initializer_range=0.02)
criterion = GPTPretrainingCriterion()
return model, criterion
class TestQuantizationPass(unittest.TestCase):
def test_qat_pass(self):
batch_size = 8
batch_num = 10
sequence_len = 512
vocab_size = 1000
strategy = apply_pass()
model, loss = get_gpt_model()
opt = paddle.optimizer.AdamW(learning_rate=0.00001)
inputs_spec, labels_spec = create_data_holder(batch_size=batch_size,
sequence_len=sequence_len)
engine = Engine(model, inputs_spec, labels_spec, strategy=strategy)
engine.prepare(optimizer=opt, loss=loss)
dataset = FakeDataset(batch_size * batch_num, sequence_len, vocab_size)
engine.fit(train_data=dataset, batch_size=batch_size)
self.check_program(engine.main_program)
def check_program(self, program):
quantizable_op_and_inputs = {'matmul_v2': ['X', 'Y']}
quantizable_grad_op_inputs = {'matmul_v2_grad': ['X', 'Y']}
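# the quant pass renames quantized inputs to '<name>.quantized.dequantized',
# which is what the checks below look for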
quantized_ops = set()
for block in program.blocks:
for op in block.ops:
is_quantized = False
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
if ".quantized" in arg_name:
is_quantized = True
if not is_quantized:
continue
# check forward
if op.type in quantizable_op_and_inputs:
for arg_name in op.input_arg_names:
assert arg_name.endswith('.quantized.dequantized')
quantized_ops.add(arg_name)
for op in block.ops:
is_quantized = False
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
if ".quantized" in arg_name:
is_quantized = True
if not is_quantized:
continue
# check backward
if op.type in quantizable_grad_op_inputs:
for pname in quantizable_grad_op_inputs[op.type]:
arg_name = op.input(pname)[0]
assert arg_name.endswith('.quantized.dequantized')
assert arg_name in quantized_ops
if __name__ == "__main__":
unittest.main()