提交 dde19a0f 编写于 作者: W WangZhen

add quantization freeze pass.

上级 f4dec5cd
......@@ -17,6 +17,7 @@
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
......@@ -27,6 +28,10 @@ namespace py = pybind11;
using paddle::framework::ir::Graph;
using paddle::framework::ir::Node;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::ir::HasCircle;
using paddle::framework::ir::GraphNum;
using paddle::framework::ir::TopologySortOperations;
using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc;
......@@ -36,6 +41,12 @@ namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def("topology_sort", TopologySortOperations,
return_value_policy::reference);
m->def("build_adjacency_list", BuildOperationAdjList,
return_value_policy::reference);
py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
......
......@@ -64,6 +64,7 @@ if (WITH_TESTING)
add_subdirectory(paddle/dataset/tests)
add_subdirectory(paddle/fluid/tests)
add_subdirectory(paddle/fluid/contrib/tests)
add_subdirectory(paddle/fluid/contrib/slim/tests)
endif()
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
DESTINATION opt/paddle/share/wheels
......
......@@ -13,6 +13,7 @@
# limitations under the License.
import collections
import numpy as np
from .... import core
from ....framework import IrGraph
from ....framework import Program
......@@ -88,10 +89,6 @@ class QuantizationTransformPass(object):
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
self._fake_quant_op_types = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self._fake_dequant_op_types = ['fake_dequantize_max_abs']
self._is_test = None
self._global_step = None
......@@ -102,17 +99,17 @@ class QuantizationTransformPass(object):
self._is_test = graph.is_test()
# marked the variable which has been dequantized.
dequantized_vars = collections.OrderedDict()
params = [p.name() for p in graph.all_parameters()]
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
def _transform_forward(graph, op):
for var_node in op.inputs:
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
quant_bits = self._weight_bits if var_node.name() in params \
quant_bits = self._weight_bits if var_node.name() in persistable_vars \
else self._activation_bits
quant_type = self._weight_quantize_type if var_node.name() \
in params else self._activation_quantize_type
in persistable_vars else self._activation_quantize_type
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, quant_type)
dequant_var_node = self._insert_dequant_op(
......@@ -316,3 +313,179 @@ class QuantizationTransformPass(object):
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
class QuantizationFreezePass(object):
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max'):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
self._scope = scope
self._place = place
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict()
def apply(self, graph):
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_quant_op_names:
input_arg_name = op_node.op().input('X')[0]
if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param))
else:
scale_v = self._load_var(op_node.op().output('OutScale')
[0])[0]
self._var_scale_map[input_arg_name] = scale_v
else:
scale_v = graph.var_node(op_node.op().output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore
param_v = self._load_var(input_arg_name)
quantized_param_v = self._quant(param_v, scale_v,
self.weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node)
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
self._insert_post_dequant_op(graph, op_node)
for op_node in ops:
# insert dequant_op after fc/conv, need to rename inputs of the followed ops
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_output_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_output_rename_map[name])
graph.update_input_link(old_in, new_in, op_node)
# remove the unused var node in the graph
self._remove_unused_var_nodes(graph)
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.op().output('Out')[0]
v = op_node.op().input('X')[0]
if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v
else:
self._op_input_rename_map[k] = self._op_input_rename_map[v]
graph.save_remove_nodes(op_node)
def _insert_post_dequant_op(self, graph, op_node):
max_range = None
scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
for var_node in op_node.op().inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
if original_var_name in persistable_vars:
param_range = (1 << (self._weight_bits - 1)) - 1
act_range = (1 << (self._activation_bits - 1)) - 1
scale_v = self._var_scale_map[original_var_name]
assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % (
original_var_name)
max_range = param_range * act_range / scale_v
else:
assert isinstance(scale_v, core.Node)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.op().outputs) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = op_node.op().outputs[0]
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.var().type(),
shape=output_var_node.var().shape(),
var_dtype=output_var_node.var().dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)},
inputs={'X': output_var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
graph.link_to(output_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name(
)] = dequant_var_node.name()
return dequant_var_node
def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())
def _restore_var(self, name, arr):
t = self._scope.find_var(name).get_tensor()
t.set(arr, self._place)
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
graph.safe_remove_nodes(all_unused_vars)
def _original_var_name(self, var_name):
"""
Return the original variable name.
"""
if var_name.endswith('.quantized.dequantized'):
return var_name[:-len('.quantized.dequantized')]
if var_name.endswith('.quantized'):
return var_name[:-len('.quantized')]
if var_name.endswith('.dequantized'):
return var_name[:-len('.dequantized')]
if var_name.endswith('.scale'):
return var_name[:-len('.scale')]
else:
return var_name
def _dequantized_var_name(self, var_name):
"""
Return dequantized variable name for the input `var_name`.
"""
return "%s.dequantized" % (var_name)
def _is_float(v):
return isinstance(v, float) or isinstance(v, np.float32) \
or isinstance(v, np.float64)
def _quant(x, scale, num_bits):
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
version: 1.0
include: ["./unitest/configs/pruners.yaml", "./unitest/configs/pruners_0.yaml"]
include: ["./configs/pruners.yaml", "./configs/pruners_0.yaml"]
pruners:
pruner_1:
class: 'RatioPruner'
......
......@@ -18,7 +18,7 @@ import unittest
class TestFactory(unittest.TestCase):
def test_parse(self):
factory = ConfigFactory('./unitest/configs/config.yaml')
factory = ConfigFactory('./configs/config.yaml')
pruner = factory.instance('pruner_1')
self.assertEquals(pruner.ratios['conv1_1.w'], 0.3)
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import six
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestGraph(unittest.TestCase):
def test_graph_functions(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
self.assertEqual(len(nodes), len(graph.all_ops()))
nodes_map = graph.build_adjacency_list()
self.assertEqual(len(nodes_map), len(graph.all_ops()))
nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
if __name__ == '__main__':
unittest.main()
......@@ -65,6 +65,28 @@ def residual_block(num):
return loss
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self):
self.quantizable_op_and_inputs = {
......@@ -171,5 +193,103 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.residual_block_quant('range_abs_max')
class TestQuantizeTranspiler(unittest.TestCase):
def freeze_graph(self, use_cuda, seed):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
random.seed(0)
np.random.seed(0)
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_graph.desc), for_test=True)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
transform_pass = QuantizationTransformPass(
scope=fluid.global_scope(), program_exe=exe)
iters = 5
batch_size = 8
class_num = 10
exe.run(startup)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.program_guard(main):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(program=main,
feed=feeder.feed(data),
fetch_list=[loss])
with fluid.program_guard(test_program):
test_data = next(test_reader())
w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
test_program)
# Testing during training
test_loss1, w_quant = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, w_var])
# Freeze program for inference, but the weight of fc/conv is still float type.
quant_transpiler.freeze_program(test_program, place)
test_loss2, = exe.run(program=test_program,
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
w_freeze = np.array(fluid.global_scope().find_var('conv2d_1.w_0')
.get_tensor())
# fail: -432.0 != -433.0, this is due to the calculation precision
#self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
# Convert parameter to 8-bit.
quant_transpiler.convert_to_int8(test_program, place)
# Save the 8-bit parameter and model file.
fluid.io.save_inference_model('model_8bit', ['image', 'label'],
[loss], exe, test_program)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model('model_8bit',
exe)
# Check the loaded 8-bit weight.
w_8bit = np.array(fluid.global_scope().find_var('conv2d_1.w_0.int8')
.get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
def not_test_freeze_program_cuda(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_program(True, seed=1)
def not_test_freeze_program_cpu(self):
with fluid.unique_name.guard():
self.freeze_program(False, seed=2)
if __name__ == '__main__':
unittest.main()
......@@ -1533,20 +1533,47 @@ class IrGraph(object):
def is_test(self):
return self._for_test
def all_parameters(self):
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_nodes(self):
return {node for node in self.graph.nodes()}
def all_vars(self):
return {node for node in self.graph.nodes() if node.is_var()}
def all_persistable_vars(self):
persistable_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
persistable_nodes.add(node)
return persistable_nodes
def all_ops(self):
return {node for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
"""
Get a variable node by name from this graph.
Args:
name(str): the name of the variable node.
Raises:
ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name.
Returns:
Node: the variable node with the giving name.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_vars()
for var_node in var_nodes:
if var_node.name() == name:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name)
return target_var_node
def create_param_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
......@@ -1586,8 +1613,9 @@ class IrGraph(object):
return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node):
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.'
assert old_input_node in self.graph.nodes() and new_input_node in \
self.graph.nodes() and op_node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
......@@ -1596,7 +1624,7 @@ class IrGraph(object):
def link_to(self, node_in, node_out):
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'Th two arguments must be in the graph nodes.'
'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
......@@ -1605,6 +1633,18 @@ class IrGraph(object):
remove_nodes = set(remove_nodes)
core.graph_safe_remove_nodes(self.graph, remove_nodes)
def has_circle(self):
return core.has_circle(self.graph)
def graph_num(self):
return core.graph_num(self.graph)
def topology_sort(self):
return core.topology_sort(self.graph)
def build_adjacency_list(self):
return core.build_adjacency_list(self.graph)
def draw(self, save_path, name, marked_nodes=None):
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册