Commit 59e5cc51 authored by WangZhen

Add quantization transform pass and UT.

Parent e2ff300b
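In brief, the new pass is driven like this (a minimal sketch distilled from the docstring and unit test in this diff; `main` stands for any fluid.Program that already contains quantizable ops):

    import paddle.fluid as fluid
    from paddle.fluid import core
    from paddle.fluid.contrib.slim.graph import PyGraph
    from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass

    exe = fluid.Executor(fluid.CPUPlace())
    graph = PyGraph(core.Graph(main.desc), for_test=False)
    transform_pass = QuantizationTransformPass(
        scope=fluid.global_scope(),
        program_exe=exe,
        activation_quantize_type='range_abs_max')
    transform_pass.apply(graph)       # rewrites the graph in place
    program = graph.to_program()      # convert back to a Program for execution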
@@ -148,8 +148,8 @@ void BindNode(py::module *m) {
       })
       .def("outputs_append",
            [](Node &self, Node &node) { self.outputs.push_back(&node); })
-      .def_readwrite("inputs", &Node::inputs)
-      .def_readwrite("outputs", &Node::outputs);
+      .def_readonly("inputs", &Node::inputs)
+      .def_readonly("outputs", &Node::outputs);
   py::enum_<Node::Type>(node, "Type")
       .value("Operation", Node::Type::kOperation)
...
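A note on the `def_readwrite` → `def_readonly` switch: with pybind11's default STL conversion, the list Python sees is a copy of the C++ `std::vector`, so mutating or rebinding it from Python never reliably reached the C++ node anyway; the explicit `inputs_append`/`outputs_append` bindings call the real `push_back`. The likely resulting Python-side behavior, as a sketch (assuming `node` and `other` are bound `Node` objects):

    node.inputs_append(other)   # ok: goes through the bound C++ push_back
    ins = node.inputs           # ok: read access (a converted copy)
    node.inputs = [other]       # now raises AttributeError: read-only attribute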
@@ -26,10 +26,20 @@ class PyGraph(object):
     PyGraph uses core.Graph as the delegation to accomplish the manipulation.
     """

-    def __init__(self, graph):
+    def __init__(self, graph, for_test=False):
+        """
+        Construct the PyGraph using core.Graph.
+        Args:
+            graph(core.Graph): C++ Graph.
+            for_test(bool): True for the test graph and false for the train graph.
+        """
         assert isinstance(
             graph, core.Graph), 'graph must be the instance of core.Graph.'
         self.graph = graph
+        self.for_test = for_test
+
+    def is_test(self):
+        return self.for_test

     def all_parameters(self):
         param_nodes = set()
@@ -103,7 +113,7 @@ class PyGraph(object):
         remove_nodes = set(remove_nodes)
         core.graph_safe_remove_nodes(self.graph, remove_nodes)

-    def draw_graph(self, save_path, name, marked_nodes=None):
+    def draw(self, save_path, name, marked_nodes=None):
         def _convert_to_pdf(dot_file_path):
             pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
             exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
@@ -126,6 +136,8 @@ class PyGraph(object):
             if not isinstance(marked_nodes, set):
                 marked_nodes = set(marked_nodes)
             marked_nodes = marked_nodes - remove_ctr_vars
+            if self.graph.has('__graphviz__marked_node__'):
+                self.graph.erase('__graphviz__marked_node__')
             self.graph.set('__graphviz__marked_node__', marked_nodes)
         viz_dot_path = os.path.join(save_path, name) + '.dot'
         viz_pass = core.get_pass('graph_viz_pass')
@@ -137,8 +149,8 @@ class PyGraph(object):
         convert_pass = core.get_pass('graph_to_program_pass')
         convert_pass.set_program('program', Program().desc)
         convert_pass.apply(self.graph)
-        program = Program()
-        program.desc = convert_pass.get_program('program')
+        desc = convert_pass.get_program('program')
+        program = Program.construct_from_desc(desc)
         return program
...
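The `has`/`erase` guard added to `draw` makes repeated draws of the same graph safe: `core.Graph.set` rejects an attribute that is already present, so a second call previously failed. A sketch of the now-supported usage (`marked_nodes` as built in the unit test below):

    graph.draw('.', 'fc_before', marked_nodes)   # sets __graphviz__marked_node__
    graph.draw('.', 'fc_after', marked_nodes)    # previously failed on the existing
                                                 # attribute; now erased and re-set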
@@ -14,7 +14,7 @@
 from __future__ import print_function

-from . import quantization_performer
-from .quantization_performer import *
+from . import quantization_pass
+from .quantization_pass import *

-__all__ = quantization_performer.__all__
+__all__ = quantization_pass.__all__
@@ -15,22 +15,26 @@
 import collections
 import numpy as np
 from .... import core
+from ....framework import Program
+from ....framework import Variable
 from ....initializer import Constant
 from .... import unique_name
 from ..graph import PyGraph

-__all__ = ['QuantizationPerformer']
+__all__ = ['QuantizationTransformPass']


-class QuantizationPerformer(object):
+class QuantizationTransformPass(object):
     def __init__(self,
+                 scope=None,
+                 program_exe=None,
                  weight_bits=8,
                  activation_bits=8,
                  activation_quantize_type='abs_max',
                  weight_quantize_type='abs_max',
                  window_size=10000):
         """
-        Convert and rewrite the IRGraph according to weight and
+        Convert and rewrite the PyGraph according to weight and
         activation quantization type.
         Args:
             weight_bits (int): quantization bit number for weights,
@@ -48,15 +52,21 @@ class QuantizationPerformer(object):
             window_size (int): the window size for 'range_abs_max' quantization.
         Examples:
         .. code-block:: python
-            # the original graph will be rewrite, if you don't want to
-            # change it, please clone at first.
-            # graph = graph.clone()
-            from paddle.fluid.contrib.slim import *
-            from paddle.fluid.contrib.quantize import *
-            graph = IRGraph(program)
-            performer = QuantizationPerformer()
-            performer.quantize_transform(graph)
+            # The original graph will be rewritten.
+            import paddle.fluid as fluid
+            from paddle.fluid.contrib.slim.quantization \
+                import QuantizationTransformPass
+            from paddle.fluid.contrib.slim.graph import PyGraph
+            from paddle.fluid import core

+            graph = PyGraph(core.Graph(program.desc), for_test=False)
+            exe = fluid.Executor(fluid.CPUPlace())
+            transform_pass = QuantizationTransformPass(fluid.global_scope(),
+                                                       exe)
+            transform_pass.apply(graph)
         """
+        self.scope = scope
+        self.program_exe = program_exe
         self.weight_bits = weight_bits
         self.activation_bits = activation_bits
@@ -74,7 +84,7 @@ class QuantizationPerformer(object):
         self.weight_quantize_type = weight_quantize_type
         self.window_size = window_size

-        self.need_inited_outer = collections.OrderedDict()
+        self.need_initialized = collections.OrderedDict()
         self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
         self.quantizable_grad_ops = [
             '%s_grad' % (op) for op in self.quantizable_ops
@@ -86,11 +96,11 @@ class QuantizationPerformer(object):
         self.is_test = None
         self.global_step = None

-    def quantize_transform(self, graph, is_test):
-        self.need_inited_outer.clear()
-        self.is_test = is_test
+    def apply(self, graph):
         assert isinstance(graph,
                           PyGraph), 'graph must be the instance of PyGraph.'
+        self.need_initialized.clear()
+        self.is_test = graph.is_test()
         # mark the variables that have been dequantized
         dequantized_vars = collections.OrderedDict()
         params = [p.name() for p in graph.all_parameters()]
@@ -138,7 +148,19 @@ class QuantizationPerformer(object):
             if op.name() in self.quantizable_grad_ops:
                 _transform_backward(graph, op)
-        return self.need_inited_outer
+        if len(self.need_initialized) > 0:
+            assert self.scope is not None, \
+                'The scope cannot be None when activation_quantize_type is range_abs_max.'
+            assert self.program_exe is not None, \
+                'The program_exe cannot be None when activation_quantize_type is range_abs_max.'
+            init_program = Program()
+            for var_desc, initializer in self.need_initialized.items():
+                var = Variable.construct_from_desc(init_program.global_block(),
+                                                   var_desc)
+                initializer(var, init_program.global_block())
+            self.program_exe.run(program=init_program, scope=self.scope)
+        return graph

     def _create_global_step(self, graph):
         if self.weight_quantize_type == 'range_abs_max' or \
@@ -153,7 +175,7 @@ class QuantizationPerformer(object):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=core.VarDesc.VarType.INT64)
-        self.need_inited_outer[global_step_in.var()] = \
+        self.need_initialized[global_step_in.var()] = \
             Constant(value=0, force_cpu=True)
         global_step_out = graph.create_var_node_from_desc(
             global_step_in.var())
@@ -220,7 +242,7 @@ class QuantizationPerformer(object):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[1],
             var_dtype=var_node.var().dtype())
-        self.need_inited_outer[scale_in_node.var()] = Constant(value=0.001)
+        self.need_initialized[scale_in_node.var()] = Constant(value=0.001)
         scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
         inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -233,7 +255,7 @@ class QuantizationPerformer(object):
             var_type=core.VarDesc.VarType.LOD_TENSOR,
             shape=[self.window_size],
             var_dtype=var_node.var().dtype())
-        self.need_inited_outer[scales_node.var()] = Constant(value=0)
+        self.need_initialized[scales_node.var()] = Constant(value=0)
         inputs['Iter'] = self.global_step
         outputs['OutScales'] = scales_node
         attrs = {
...
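The new `apply` defers parameter initialization: while rewriting, the pass only records `VarDesc → initializer` pairs in `need_initialized`; at the end it builds a throwaway init program and runs it once through the executor. The pattern in isolation, as a sketch (names follow the diff above; `program_exe` and `scope` are the pass's constructor arguments):

    import collections
    from paddle.fluid.framework import Program, Variable
    from paddle.fluid.initializer import Constant

    need_initialized = collections.OrderedDict()
    # During the rewrite, each new persistable var registers its initializer, e.g.:
    # need_initialized[scale_in_node.var()] = Constant(value=0.001)

    init_program = Program()
    for var_desc, initializer in need_initialized.items():
        # Recreate the variable inside the init program's global block, then
        # let the initializer append the corresponding init op for it.
        var = Variable.construct_from_desc(init_program.global_block(), var_desc)
        initializer(var, init_program.global_block())
    program_exe.run(program=init_program, scope=scope)  # one-shot initialization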
@@ -15,11 +15,10 @@
 import unittest
 import random
 import numpy as np
-import paddle
 import paddle.fluid as fluid
 import six
 from paddle.fluid.framework import Program
-from paddle.fluid.contrib.slim.quantization import QuantizationPerformer
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
 from paddle.fluid.contrib.slim.graph import PyGraph
 from paddle.fluid import core
@@ -66,22 +65,39 @@ def residual_block(num):
     return loss


-class TestQuantizationPerformer(unittest.TestCase):
+class TestQuantizationTransformPass(unittest.TestCase):
     def setUp(self):
-        # since quant_op and dequant_op is not ready, use cos and sin for test
-        self.weight_quant_op_type = 'fake_quantize_abs_max'
-        self.dequant_op_type = 'fake_dequantize_max_abs'
         self.quantizable_op_and_inputs = {
             'conv2d': ['Input', 'Filter'],
             'depthwise_conv2d': ['Input', 'Filter'],
             'mul': ['X', 'Y']
         }
-        self.quantizable_op_grad_and_inputs = {
+        self.quantizable_grad_op_inputs = {
             'conv2d_grad': ['Input', 'Filter'],
             'depthwise_conv2d_grad': ['Input', 'Filter'],
             'mul_grad': ['X', 'Y']
         }

+    def check_program(self, transform_pass, program):
+        quantized_ops = set()
+        for block in program.blocks:
+            for op in block.ops:
+                # check forward
+                if op.type in self.quantizable_op_and_inputs:
+                    for arg_name in op.input_arg_names:
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        quantized_ops.add(arg_name)
+            for op in block.ops:
+                # check backward
+                if op.type in self.quantizable_grad_op_inputs:
+                    for pname in self.quantizable_grad_op_inputs[op.type]:
+                        arg_name = op.input(pname)[0]
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        self.assertTrue(arg_name in quantized_ops)
+
     def linear_fc_quant(self, quant_type):
         main = fluid.Program()
         startup = fluid.Program()
@@ -89,14 +105,26 @@ class TestQuantizationPerformer(unittest.TestCase):
             loss = linear_fc(3)
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
-        graph = PyGraph(core.Graph(main.desc))
-        performer = QuantizationPerformer(activation_quantize_type=quant_type)
-        performer.quantize_transform(graph, False)
+        exe = fluid.Executor(fluid.CPUPlace())
+        graph = PyGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            program_exe=exe,
+            activation_quantize_type=quant_type)
+        transform_pass.apply(graph)
         marked_nodes = set()
         for op in graph.all_ops():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
-        graph.draw_graph('.', 'quantize_fc_' + quant_type, marked_nodes)
+        graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
+        program = graph.to_program()
+        self.check_program(transform_pass, program)
+        val_graph = PyGraph(core.Graph(program.desc), for_test=False)
+        val_marked_nodes = set()
+        for op in val_graph.all_ops():
+            if op.name().find('quantize') > -1:
+                val_marked_nodes.add(op)
+        val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)

     def test_linear_fc_quant_abs_max(self):
         self.act_quant_op_type = 'fake_quantize_abs_max'
@@ -113,14 +141,26 @@ class TestQuantizationPerformer(unittest.TestCase):
             loss = residual_block(2)
             opt = fluid.optimizer.Adam(learning_rate=0.001)
             opt.minimize(loss)
-        graph = PyGraph(core.Graph(main.desc))
-        performer = QuantizationPerformer(activation_quantize_type=quant_type)
-        performer.quantize_transform(graph, False)
+        exe = fluid.Executor(fluid.CPUPlace())
+        graph = PyGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            program_exe=exe,
+            activation_quantize_type=quant_type)
+        transform_pass.apply(graph)
         marked_nodes = set()
         for op in graph.all_ops():
             if op.name().find('quantize') > -1:
                 marked_nodes.add(op)
-        graph.draw_graph('.', 'quantize_residual_' + quant_type, marked_nodes)
+        graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
+        program = graph.to_program()
+        self.check_program(transform_pass, program)
+        val_graph = PyGraph(core.Graph(program.desc), for_test=False)
+        val_marked_nodes = set()
+        for op in val_graph.all_ops():
+            if op.name().find('quantize') > -1:
+                val_marked_nodes.add(op)
+        val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)

     def test_residual_block_abs_max(self):
         self.act_quant_op_type = 'fake_quantize_abs_max'
...
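What `check_program` asserts, in short: after the transform, every input of a quantizable op (and of its grad op) should be a rewired variable ending in `.quantized.dequantized`, and the backward pass must reuse exactly the variables quantized in forward. For example (variable names are hypothetical; only the suffix is checked by the test):

    # Before the pass, a mul op reads its raw inputs:
    #   mul(X=['fc_0.w_0'], Y=['data'])
    # After the pass, both are routed through quant/dequant nodes:
    #   mul(X=['fc_0.w_0.quantized.dequantized'],
    #       Y=['data.quantized.dequantized'])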
@@ -378,6 +378,27 @@ class Variable(object):
             self._ivar.desc = self.desc
             self._ivar.stop_gradient = stop_gradient

+    @staticmethod
+    def construct_from_desc(block, desc):
+        """
+        Construct a Variable from a variable desc.
+        Args:
+            block(Block): The block that the variable belongs to.
+            desc(core.VarDesc): The variable desc for constructing.
+        Returns:
+            Variable: A variable.
+        """
+        v = Variable(
+            block=block,
+            type=desc.type(),
+            name=desc.name(),
+            shape=desc.shape(),
+            dtype=desc.dtype(),
+            lod_level=desc.lod_level(),
+            persistable=desc.persistable())
+        v.desc = desc
+        return v
+
     def _numpy(self):
         tensor = self._ivar.value().get_tensor()
         return np.array(tensor)
@@ -1925,6 +1946,25 @@ class Program(object):
         p._sync_with_cpp()
         return p

+    @staticmethod
+    def construct_from_desc(desc):
+        """
+        Construct a program from a program desc.
+        Notes: All information about parameters will be lost.
+        Args:
+            desc(core.ProgramDesc): The program desc for constructing.
+        Returns:
+            Program: A program.
+        """
+        p = Program()
+        p.desc = desc
+        p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
+        p._sync_with_cpp()
+        return p
+
     @property
     def random_seed(self):
         """
...
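Both helpers are consumed by the changes above: `PyGraph.to_program` rebuilds a Program from the desc a C++ pass hands back, and the transform pass rebuilds Variables inside its init program. In isolation, as a sketch (`var_desc` is a hypothetical core.VarDesc):

    # Rebuild a Program from a raw ProgramDesc produced by graph_to_program_pass.
    desc = convert_pass.get_program('program')   # core.ProgramDesc
    program = Program.construct_from_desc(desc)  # note: parameter info is lost

    # Rebuild a Variable from a VarDesc, attached to a given block.
    var = Variable.construct_from_desc(program.global_block(), var_desc)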