Commit 59e5cc51 authored by WangZhen

Add quantization transform pass and UT.

Parent e2ff300b
@@ -148,8 +148,8 @@ void BindNode(py::module *m) {
})
.def("outputs_append",
[](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
.def_readonly("inputs", &Node::inputs)
.def_readonly("outputs", &Node::outputs);
py::enum_<Node::Type>(node, "Type")
.value("Operation", Node::Type::kOperation)
......
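Since `inputs` and `outputs` are now bound with `def_readonly`, Python code can still read and iterate the lists but can no longer rebind them; mutation goes through the explicit append bindings. A minimal sketch of the Python-side consequence, assuming `graph` is a `PyGraph` over a populated `core.Graph`:

```python
# Sketch only: `graph` is assumed to be a PyGraph wrapping an existing core.Graph.
for op_node in graph.all_ops():
    ins = op_node.inputs            # reading the readonly attribute still works
    # op_node.inputs = []           # would now raise AttributeError (readonly binding)
    # op_node.inputs_append(node)   # mutation uses the explicit helper instead
```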
@@ -26,10 +26,20 @@ class PyGraph(object):
PyGraph uses core.Graph as the delegation to accomplish the manipulation.
"""
-    def __init__(self, graph):
+    def __init__(self, graph, for_test=False):
+        """
+        Construct the PyGraph using core.Graph.
+        Args:
+            graph(core.Graph): C++ Graph.
+            for_test(bool): True for a test graph, False for a train graph.
+        """
assert isinstance(
graph, core.Graph), 'graph must be the instance of core.Graph.'
self.graph = graph
+        self.for_test = for_test

+    def is_test(self):
+        return self.for_test
def all_parameters(self):
param_nodes = set()
@@ -103,7 +113,7 @@ class PyGraph(object):
remove_nodes = set(remove_nodes)
core.graph_safe_remove_nodes(self.graph, remove_nodes)
-    def draw_graph(self, save_path, name, marked_nodes=None):
+    def draw(self, save_path, name, marked_nodes=None):
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
@@ -126,6 +136,8 @@ class PyGraph(object):
if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes)
marked_nodes = marked_nodes - remove_ctr_vars
+        if self.graph.has('__graphviz__marked_node__'):
+            self.graph.erase('__graphviz__marked_node__')
self.graph.set('__graphviz__marked_node__', marked_nodes)
viz_dot_path = os.path.join(save_path, name) + '.dot'
viz_pass = core.get_pass('graph_viz_pass')
@@ -137,8 +149,8 @@ class PyGraph(object):
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set_program('program', Program().desc)
convert_pass.apply(self.graph)
-        program = Program()
-        program.desc = convert_pass.get_program('program')
+        desc = convert_pass.get_program('program')
+        program = Program.construct_from_desc(desc)
return program
......
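For reference, a minimal round trip with the renamed API: build a `PyGraph` from a program desc, draw it, and convert it back into a `Program`. The fc network is only an illustration; any fluid program works.

```python
import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.contrib.slim.graph import PyGraph

main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
    data = fluid.layers.data(name='image', shape=[784], dtype='float32')
    fc = fluid.layers.fc(input=data, size=10)

graph = PyGraph(core.Graph(main.desc), for_test=False)
graph.draw('.', 'fc_graph')        # writes fc_graph.dot; a PDF too if `dot` is installed
program = graph.to_program()       # runs graph_to_program_pass under the hood
```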
@@ -14,7 +14,7 @@
from __future__ import print_function
-from . import quantization_performer
-from .quantization_performer import *
-__all__ = quantization_performer.__all__
+from . import quantization_pass
+from .quantization_pass import *
+__all__ = quantization_pass.__all__
@@ -15,22 +15,26 @@
import collections
import numpy as np
from .... import core
+from ....framework import Program
+from ....framework import Variable
from ....initializer import Constant
from .... import unique_name
from ..graph import PyGraph
-__all__ = ['QuantizationPerformer']
+__all__ = ['QuantizationTransformPass']

-class QuantizationPerformer(object):
+class QuantizationTransformPass(object):
def __init__(self,
+                 scope=None,
+                 program_exe=None,
weight_bits=8,
activation_bits=8,
activation_quantize_type='abs_max',
weight_quantize_type='abs_max',
window_size=10000):
"""
-        Convert and rewrite the IRGraph according to weight and
+        Convert and rewrite the PyGraph according to weight and
activation quantization type.
Args:
weight_bits (int): quantization bit number for weights,
@@ -48,15 +52,21 @@ class QuantizationPerformer(object):
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
.. code-block:: python
-            # the original graph will be rewrite, if you don't want to
-            # change it, please clone at first.
-            # graph = graph.clone()
-            from paddle.fluid.contrib.slim import *
-            from paddle.fluid.contrib.quantize import *
-            graph = IRGraph(program)
-            performer = QuantizationPerformer()
-            performer.quantize_transform(graph)
+            # The original graph will be rewritten.
+            import paddle.fluid as fluid
+            from paddle.fluid.contrib.slim.quantization \
+                import QuantizationTransformPass
+            from paddle.fluid.contrib.slim.graph import PyGraph
+            from paddle.fluid import core
+            graph = PyGraph(core.Graph(program.desc), for_test=False)
+            exe = fluid.Executor(fluid.CPUPlace())
+            transform_pass = QuantizationTransformPass(fluid.global_scope(),
+                                                       exe)
+            transform_pass.apply(graph)
"""
+        self.scope = scope
+        self.program_exe = program_exe
self.weight_bits = weight_bits
self.activation_bits = activation_bits
@@ -74,7 +84,7 @@ class QuantizationPerformer(object):
self.weight_quantize_type = weight_quantize_type
self.window_size = window_size
-        self.need_inited_outer = collections.OrderedDict()
+        self.need_initialized = collections.OrderedDict()
self.quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self.quantizable_grad_ops = [
'%s_grad' % (op) for op in self.quantizable_ops
@@ -86,11 +96,11 @@ class QuantizationPerformer(object):
self.is_test = None
self.global_step = None
-    def quantize_transform(self, graph, is_test):
-        self.need_inited_outer.clear()
-        self.is_test = is_test
+    def apply(self, graph):
+        assert isinstance(graph,
+                          PyGraph), 'graph must be the instance of PyGraph.'
+        self.need_initialized.clear()
+        self.is_test = graph.is_test()
        # mark the variables that have been dequantized.
dequantized_vars = collections.OrderedDict()
params = [p.name() for p in graph.all_parameters()]
@@ -138,7 +148,19 @@ class QuantizationPerformer(object):
if op.name() in self.quantizable_grad_ops:
_transform_backward(graph, op)
-        return self.need_inited_outer
+        if len(self.need_initialized) > 0:
+            assert self.scope is not None, \
+                'The scope cannot be None when activation_quantize_type is range_abs_max.'
+            assert self.program_exe is not None, \
+                'The program_exe cannot be None when activation_quantize_type is range_abs_max.'
+            init_program = Program()
+            for var_desc, initializer in self.need_initialized.items():
+                var = Variable.construct_from_desc(init_program.global_block(),
+                                                   var_desc)
+                initializer(var, init_program.global_block())
+            self.program_exe.run(program=init_program, scope=self.scope)
+        return graph
def _create_global_step(self, graph):
if self.weight_quantize_type == 'range_abs_max' or \
@@ -153,7 +175,7 @@ class QuantizationPerformer(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=core.VarDesc.VarType.INT64)
-        self.need_inited_outer[global_step_in.var()] = \
+        self.need_initialized[global_step_in.var()] = \
Constant(value=0, force_cpu=True)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
@@ -220,7 +242,7 @@ class QuantizationPerformer(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[1],
var_dtype=var_node.var().dtype())
-        self.need_inited_outer[scale_in_node.var()] = Constant(value=0.001)
+        self.need_initialized[scale_in_node.var()] = Constant(value=0.001)
scale_out_node = graph.create_var_node_from_desc(scale_in_node.var())
inputs = {'X': var_node, 'InScale': scale_in_node}
@@ -233,7 +255,7 @@ class QuantizationPerformer(object):
var_type=core.VarDesc.VarType.LOD_TENSOR,
shape=[self.window_size],
var_dtype=var_node.var().dtype())
-        self.need_inited_outer[scales_node.var()] = Constant(value=0)
+        self.need_initialized[scales_node.var()] = Constant(value=0)
inputs['Iter'] = self.global_step
outputs['OutScales'] = scales_node
attrs = {
......
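The `need_initialized` map replaces the old `need_inited_outer` return value: instead of handing pending initializers back to the caller, `apply()` now materializes them itself through `scope` and `program_exe`. A simplified sketch of that deferred-initialization pattern (not the pass itself; the names mirror the diff above):

```python
from paddle.fluid.framework import Program, Variable

def run_pending_initializers(need_initialized, exe, scope):
    """need_initialized maps core.VarDesc -> initializer, as built by apply()."""
    init_program = Program()
    for var_desc, initializer in need_initialized.items():
        # Recreate a Variable handle for each desc inside the init program,
        # then let the initializer append its init op to the global block.
        var = Variable.construct_from_desc(init_program.global_block(), var_desc)
        initializer(var, init_program.global_block())
    # One executor run initializes every collected variable in `scope`.
    exe.run(program=init_program, scope=scope)
```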
@@ -15,11 +15,10 @@
import unittest
import random
import numpy as np
import paddle
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
-from paddle.fluid.contrib.slim.quantization import QuantizationPerformer
+from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.graph import PyGraph
from paddle.fluid import core
@@ -66,22 +65,39 @@ def residual_block(num):
return loss
-class TestQuantizationPerformer(unittest.TestCase):
+class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self):
        # since quant_op and dequant_op are not ready, use cos and sin for testing
self.weight_quant_op_type = 'fake_quantize_abs_max'
self.dequant_op_type = 'fake_dequantize_max_abs'
self.quantizable_op_and_inputs = {
'conv2d': ['Input', 'Filter'],
'depthwise_conv2d': ['Input', 'Filter'],
'mul': ['X', 'Y']
}
-        self.quantizable_op_grad_and_inputs = {
+        self.quantizable_grad_op_inputs = {
'conv2d_grad': ['Input', 'Filter'],
'depthwise_conv2d_grad': ['Input', 'Filter'],
'mul_grad': ['X', 'Y']
}
+    def check_program(self, transform_pass, program):
+        quantized_ops = set()
+        for block in program.blocks:
+            for op in block.ops:
+                # check forward
+                if op.type in self.quantizable_op_and_inputs:
+                    for arg_name in op.input_arg_names:
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        quantized_ops.add(arg_name)
+            for op in block.ops:
+                # check backward
+                if op.type in self.quantizable_grad_op_inputs:
+                    for pname in self.quantizable_grad_op_inputs[op.type]:
+                        arg_name = op.input(pname)[0]
+                        self.assertTrue(
+                            arg_name.endswith('.quantized.dequantized'))
+                        self.assertTrue(arg_name in quantized_ops)
def linear_fc_quant(self, quant_type):
main = fluid.Program()
startup = fluid.Program()
@@ -89,14 +105,26 @@ class TestQuantizationPerformer(unittest.TestCase):
loss = linear_fc(3)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
-        graph = PyGraph(core.Graph(main.desc))
-        performer = QuantizationPerformer(activation_quantize_type=quant_type)
-        performer.quantize_transform(graph, False)
+        exe = fluid.Executor(fluid.CPUPlace())
+        graph = PyGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            program_exe=exe,
+            activation_quantize_type=quant_type)
+        transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
-        graph.draw_graph('.', 'quantize_fc_' + quant_type, marked_nodes)
+        graph.draw('.', 'quantize_fc_' + quant_type, marked_nodes)
+        program = graph.to_program()
+        self.check_program(transform_pass, program)
+        val_graph = PyGraph(core.Graph(program.desc), for_test=False)
+        val_marked_nodes = set()
+        for op in val_graph.all_ops():
+            if op.name().find('quantize') > -1:
+                val_marked_nodes.add(op)
+        val_graph.draw('.', 'val_fc_' + quant_type, val_marked_nodes)
def test_linear_fc_quant_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
@@ -113,14 +141,26 @@ class TestQuantizationPerformer(unittest.TestCase):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
-        graph = PyGraph(core.Graph(main.desc))
-        performer = QuantizationPerformer(activation_quantize_type=quant_type)
-        performer.quantize_transform(graph, False)
+        exe = fluid.Executor(fluid.CPUPlace())
+        graph = PyGraph(core.Graph(main.desc), for_test=False)
+        transform_pass = QuantizationTransformPass(
+            scope=fluid.global_scope(),
+            program_exe=exe,
+            activation_quantize_type=quant_type)
+        transform_pass.apply(graph)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
-        graph.draw_graph('.', 'quantize_residual_' + quant_type, marked_nodes)
+        graph.draw('.', 'quantize_residual_' + quant_type, marked_nodes)
+        program = graph.to_program()
+        self.check_program(transform_pass, program)
+        val_graph = PyGraph(core.Graph(program.desc), for_test=False)
+        val_marked_nodes = set()
+        for op in val_graph.all_ops():
+            if op.name().find('quantize') > -1:
+                val_marked_nodes.add(op)
+        val_graph.draw('.', 'val_residual_' + quant_type, val_marked_nodes)
def test_residual_block_abs_max(self):
self.act_quant_op_type = 'fake_quantize_abs_max'
......
@@ -378,6 +378,27 @@ class Variable(object):
self._ivar.desc = self.desc
self._ivar.stop_gradient = stop_gradient
+    @staticmethod
+    def construct_from_desc(block, desc):
+        """
+        Construct a Variable from a variable desc.
+        Args:
+            block(Block): The block that the variable belongs to.
+            desc(core.VarDesc): The variable desc for constructing.
+        Returns:
+            Variable: A variable.
+        """
+        v = Variable(
+            block=block,
+            type=desc.type(),
+            name=desc.name(),
+            shape=desc.shape(),
+            dtype=desc.dtype(),
+            lod_level=desc.lod_level(),
+            persistable=desc.persistable())
+        v.desc = desc
+        return v
def _numpy(self):
tensor = self._ivar.value().get_tensor()
return np.array(tensor)
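A minimal sketch of the new `Variable.construct_from_desc` helper: rebuild an equivalent Variable handle in another program's block from a raw `core.VarDesc` (the source variable here is only an illustration):

```python
from paddle.fluid.framework import Program, Variable

src = Program()
x = src.global_block().create_var(
    name='x', shape=[2, 3], dtype='float32')   # source variable to copy

dst = Program()
# Rebuild an equivalent Variable in another program from the raw VarDesc.
v = Variable.construct_from_desc(dst.global_block(), x.desc)
assert v.name == 'x' and v.shape == x.shape    # the handle wraps the same desc
```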
@@ -1925,6 +1946,25 @@ class Program(object):
p._sync_with_cpp()
return p
+    @staticmethod
+    def construct_from_desc(desc):
+        """
+        Construct a program from a program desc.
+        Notes: All information about parameters will be lost.
+        Args:
+            desc(core.ProgramDesc): The program desc for constructing.
+        Returns:
+            Program: A program.
+        """
+        p = Program()
+        p.desc = desc
+        p.blocks = [Block(p, i) for i in six.moves.range(p.desc.num_blocks())]
+        p._sync_with_cpp()
+        return p
@property
def random_seed(self):
"""
......
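And the matching sketch for `Program.construct_from_desc`, which is what `PyGraph.to_program()` now relies on; note the docstring's caveat that parameter information is not recovered:

```python
from paddle.fluid import core
from paddle.fluid.framework import Program

desc = core.ProgramDesc()                  # stands in for a desc produced by a pass
program = Program.construct_from_desc(desc)
print(len(program.blocks))                 # blocks rebuilt via _sync_with_cpp()
```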