Unverified commit 62f455e0, authored by cc, committed by GitHub

Support quantizing program_desc (#29526)

* Support quantizing program_desc, test=develop

Parent: 47d10c55
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import collections
import logging
import numpy as np
from .... import core
from ....framework import Program, Operator, Variable, program_guard
from .... import unique_name
from ....layer_helper import LayerHelper
from ....param_attr import ParamAttr
from ....initializer import Constant
from ....log_helper import get_logger
_logger = get_logger(
    __name__, logging.INFO, fmt='%(asctime)s-%(levelname)s: %(message)s')

class QuantizeTranspilerV2(object):
    def __init__(self,
                 weight_bits=8,
                 activation_bits=8,
                 weight_quantize_type='abs_max',
                 activation_quantize_type='abs_max',
                 quantizable_op_type=['conv2d', 'depthwise_conv2d', 'mul'],
                 skip_pattern=['skip_quant']):
        """
        Add a quant_dequant op before each quantized op to quantize the fluid
        Program. It is a patch for distributed quantization; support for other
        modules will be added later.

        Args:
            weight_bits(int): the bit width of the quantized weights.
            activation_bits(int): the bit width of the quantized activations.
            weight_quantize_type(str): the quantization type for weights.
                Only 'abs_max' is supported for now.
            activation_quantize_type(str): the quantization type for activations.
                Only 'abs_max' is supported for now.
            quantizable_op_type(list[str]): the op types to be quantized.
            skip_pattern(str|list): the user-defined quantization skip pattern,
                which is looked up in the name scope of an op. When the skip
                pattern is detected in an op's name scope, the corresponding op
                is not quantized.
        """
        self._weight_bits = weight_bits
        self._activation_bits = activation_bits
        assert activation_quantize_type == "abs_max", \
            "activation_quantize_type should be abs_max for now."
        assert weight_quantize_type == "abs_max", \
            "weight_quantize_type should be abs_max for now."
        self._activation_quantize_type = activation_quantize_type
        self._weight_quantize_type = weight_quantize_type
        self._quantizable_ops = quantizable_op_type
        self._quantizable_grad_ops = [
            '%s_grad' % (op) for op in self._quantizable_ops
        ]
        self._skip_pattern = skip_pattern
        self.helper = LayerHelper(self.__class__.__name__)

    def apply(self, program, startup_program):
        """
        Apply quantization to the fluid Program.

        Args:
            program(Program): the train or test program to be quantized.
            startup_program(Program): the corresponding startup_program.
        Returns:
            None
        """
        assert isinstance(program, Program), \
            "program must be the instance of Program"
        assert isinstance(startup_program, Program), \
            "startup_program must be the instance of Program"

        quant_dequant_vars = [
            collections.OrderedDict() for _ in range(len(program.blocks))
        ]
        with program_guard(program, startup_program):
            for block in program.blocks:
                ops = list(block.ops)
                for op in ops:
                    if op.type in self._quantizable_ops and \
                            (not self._is_skip_quant(op)):
                        self._transform_forward(block, op, quant_dequant_vars)

            for block in program.blocks:
                ops = list(block.ops)
                for op in ops:
                    if op.type in self._quantizable_grad_ops and \
                            (not self._is_skip_quant(op)):
                        self._transform_backward(block, op, quant_dequant_vars)

    def _is_skip_quant(self, op):
        """
        Analyse whether the op should skip quantization or not.
        """
        user_skipped = False
        if isinstance(self._skip_pattern, list):
            user_skipped = op.has_attr("op_namescope") and \
                any(pattern in op.attr("op_namescope")
                    for pattern in self._skip_pattern)
        elif isinstance(self._skip_pattern, str):
            user_skipped = op.has_attr("op_namescope") and \
                op.attr("op_namescope").find(self._skip_pattern) != -1
        return user_skipped

    def _transform_forward(self, block, op, quant_dequant_vars):
        op._set_attr("quantization_type", "qat_with_weight")
        idx = block.ops.index(op)
        block_id = block.idx
        for in_name in op.input_arg_names:
            # Reuse the quant_dequant var if this input was already processed.
            if in_name in quant_dequant_vars[block_id]:
                quant_dequant_var = quant_dequant_vars[block_id][in_name]
            else:
                in_var = block.var(in_name)
                quant_bits = self._weight_bits if in_var.persistable \
                    else self._activation_bits
                quant_type = self._weight_quantize_type if in_var.persistable \
                    else self._activation_quantize_type
                if quant_type == "abs_max":
                    quant_dequant_var = self._insert_quant_dequant_abs_max_op(
                        block, idx, in_var, quant_bits)
                else:
                    _logger.error("quant_type only supports abs_max for now.")
                quant_dequant_vars[block_id][in_name] = quant_dequant_var
            op._rename_input(in_name, quant_dequant_var.name)

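    # The grad ops must consume the same quant-dequant variables as the
    # forward ops, so that gradients flow through the inserted fake-quant
    # ops (a straight-through style estimator).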
    def _transform_backward(self, block, op, quant_dequant_vars):
        block_id = block.idx
        no_dequantized_input_vars = True
        for name in op.input_arg_names:
            if name in quant_dequant_vars[block_id]:
                dequant_var = quant_dequant_vars[block_id][name]
                op._rename_input(name, dequant_var.name)
                no_dequantized_input_vars = False
        if no_dequantized_input_vars:
            raise ValueError("There are no dequantized inputs for op %s." %
                             (op.type))

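    # The op inserted below quantizes its input and immediately dequantizes
    # it, so downstream ops see the rounding error while the graph stays in
    # float. Informally (a sketch, not the exact kernel):
    #   scale = max(abs(x))
    #   out = round(x / scale * (2^(bits-1) - 1)) * scale / (2^(bits-1) - 1)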
    def _insert_quant_dequant_abs_max_op(self, block, idx, in_var, quant_bits):
        quant_dequant_var = block.create_var(
            type=in_var.type,
            name="{}.quant_dequant".format(in_var.name),
            shape=in_var.shape,
            dtype=in_var.dtype)
        scale_var = self.helper.create_parameter(
            attr=ParamAttr(
                name="{}.quant_dequant.scale".format(in_var.name),
                initializer=Constant(0.001),
                trainable=False),
            shape=[1],
            dtype=in_var.dtype)
        scale_var.stop_gradient = True

        inputs = {'X': in_var}
        outputs = {'Out': quant_dequant_var, 'OutScale': scale_var}
        attrs = {'bit_length': quant_bits}
        block._insert_op(
            idx,
            type='fake_quantize_dequantize_abs_max',
            attrs=attrs,
            inputs=inputs,
            outputs=outputs)
        return quant_dequant_var
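
A minimal usage sketch for the new transpiler (the variable names main_prog, startup_prog and transpiler are illustrative; the fluid static-graph APIs are the same ones exercised by the test below): build a program, call apply() to insert the quant-dequant ops, then run the startup program as usual.

import paddle
import paddle.fluid as fluid
from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import QuantizeTranspilerV2

paddle.enable_static()

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    img = fluid.layers.data(name='image', shape=[1, 28, 28], dtype='float32')
    with fluid.name_scope("skip_quant"):
        # Ops created in this scope match the default skip_pattern
        # and are left unquantized.
        hidden = fluid.layers.fc(input=img, size=100, act='relu')
    prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')

# Insert fake_quantize_dequantize ops before every quantizable op
# ('conv2d', 'depthwise_conv2d' and 'mul' by default).
transpiler = QuantizeTranspilerV2(weight_bits=8, activation_bits=8)
transpiler.apply(main_prog, startup_prog)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(startup_prog)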
@@ -123,8 +123,9 @@ if(WIN32)
list(REMOVE_ITEM TEST_OPS test_light_nas)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mnist)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_post_training_quantization_resnet50)
list(REMOVE_ITEM TEST_OPS test_weight_quantization_mobilenetv1)
list(REMOVE_ITEM TEST_OPS test_quantize_transpiler_v2)
endif()
if(LINUX AND WITH_MKLDNN)
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
import os
import unittest
import random
import numpy as np
import six
import paddle.fluid as fluid
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization.quantize_transpiler_v2 import QuantizeTranspilerV2
from paddle.fluid import core
paddle.enable_static()
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
os.environ["CPU_NUM"] = "1"

def conv_net(img, label):
    conv_pool_1 = fluid.nets.simple_img_conv_pool(
        input=img,
        filter_size=5,
        num_filters=20,
        pool_size=2,
        pool_stride=2,
        pool_type='max',
        act="relu")
    conv_pool_2 = fluid.nets.simple_img_conv_pool(
        input=conv_pool_1,
        filter_size=5,
        num_filters=50,
        pool_size=2,
        pool_stride=2,
        pool_type='avg',
        act="relu")
    with fluid.name_scope("skip_quant"):
        hidden = fluid.layers.fc(input=conv_pool_1, size=100, act='relu')
    prediction = fluid.layers.fc(input=hidden, size=10, act='softmax')
    loss = fluid.layers.cross_entropy(input=prediction, label=label)
    avg_loss = fluid.layers.mean(loss)
    return avg_loss

class TestQuantizeProgramPass(unittest.TestCase):
    def quantize_program(self,
                         use_cuda,
                         seed,
                         activation_quant_type='abs_max',
                         weight_quant_type='abs_max',
                         for_ci=False):
        def build_program(main, startup, is_test):
            main.random_seed = seed
            startup.random_seed = seed
            with fluid.unique_name.guard():
                with fluid.program_guard(main, startup):
                    img = fluid.layers.data(
                        name='image', shape=[1, 28, 28], dtype='float32')
                    label = fluid.layers.data(
                        name='label', shape=[1], dtype='int64')
                    loss = conv_net(img, label)
                    if not is_test:
                        opt = fluid.optimizer.Adam(learning_rate=0.0001)
                        opt.minimize(loss)
            return [img, label], loss

        random.seed(0)
        np.random.seed(0)

        # Build the train and test programs.
        train_program = fluid.Program()
        startup_program = fluid.Program()
        test_program = fluid.Program()
        feeds, loss = build_program(train_program, startup_program, False)
        build_program(test_program, startup_program, True)
        test_program = test_program.clone(for_test=True)

        if not for_ci:
            train_graph = IrGraph(
                core.Graph(train_program.desc), for_test=False)
            train_graph.draw('.', 'train_program_1')
            test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
            test_graph.draw('.', 'test_program_1')

        # Apply the quantization transpiler to both programs.
        qt = QuantizeTranspilerV2(
            activation_quantize_type=activation_quant_type,
            weight_quantize_type=weight_quant_type,
            quantizable_op_type=[
                'conv2d', 'depthwise_conv2d', 'mul', 'pool2d'
            ])
        qt.apply(train_program, startup_program)
        qt.apply(test_program, startup_program)

        place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        scope = fluid.Scope()
        with fluid.scope_guard(scope):
            exe.run(startup_program)

        if not for_ci:
            train_graph = IrGraph(
                core.Graph(train_program.desc), for_test=False)
            train_graph.draw('.', 'train_program_2')
            test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
            test_graph.draw('.', 'test_program_2')

        build_strategy = fluid.BuildStrategy()
        build_strategy.memory_optimize = False
        build_strategy.enable_inplace = False
        build_strategy.fuse_all_reduce_ops = False
        binary = fluid.CompiledProgram(train_program).with_data_parallel(
            loss_name=loss.name, build_strategy=build_strategy)

        # Train for a few iterations to check the quantized program runs.
        iters = 2
        batch_size = 8
        train_reader = paddle.batch(
            paddle.dataset.mnist.train(), batch_size=batch_size)
        feeder = fluid.DataFeeder(feed_list=feeds, place=place)
        with fluid.scope_guard(scope):
            for _ in range(iters):
                data = next(train_reader())
                loss_v = exe.run(binary,
                                 feed=feeder.feed(data),
                                 fetch_list=[loss])
                if not for_ci:
                    print('{}: {}'.format('loss', loss_v))

        if not for_ci:
            with fluid.scope_guard(scope):
                fluid.io.save_inference_model('./infer_model',
                                              ['image', 'label'], [loss], exe,
                                              test_program)

    def test_quantize_program_gpu(self):
        if fluid.core.is_compiled_with_cuda():
            self.quantize_program(
                use_cuda=True,
                seed=1,
                activation_quant_type='abs_max',
                weight_quant_type='abs_max',
                for_ci=True)

    def test_quantize_program_cpu(self):
        self.quantize_program(
            use_cuda=False,
            seed=2,
            activation_quant_type='abs_max',
            weight_quant_type='abs_max',
            for_ci=True)


if __name__ == '__main__':
    unittest.main()