Commit e2ff300b authored by: WangZhen

add UT for quantization.

Parent 451896fc
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/details/sequential_execution_pass.h"
 #include "paddle/fluid/framework/ir/graph.h"
 #include "paddle/fluid/framework/ir/graph_helper.h"
+#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
 #include "paddle/fluid/framework/ir/graph_viz_pass.h"
 namespace paddle {
......
@@ -15,7 +15,9 @@
 #include "paddle/fluid/pybind/ir.h"
 #include <string>
 #include <unordered_map>
+#include <unordered_set>
 #include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
 #include "paddle/fluid/framework/ir/node.h"
 #include "paddle/fluid/framework/op_desc.h"
 #include "paddle/fluid/framework/var_desc.h"
@@ -24,6 +26,7 @@
 namespace py = pybind11;
 using paddle::framework::ir::Graph;
 using paddle::framework::ir::Node;
+using paddle::framework::ir::GraphSafeRemoveNodes;
 using paddle::framework::OpDesc;
 using paddle::framework::ProgramDesc;
 using paddle::framework::VarDesc;
@@ -32,6 +35,7 @@ using pybind11::return_value_policy;
 namespace paddle {
 namespace pybind {
 void BindGraph(py::module *m) {
+  m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
   py::class_<Graph, std::shared_ptr<Graph>>(
       *m, "Graph",
       "The graph is a Directed Acyclic Single Static Assignment Graph, see "
@@ -43,6 +47,7 @@ void BindGraph(py::module *m) {
       .def("get_double", &Graph::Get<double>)
       .def("get_string", &Graph::Get<std::string>)
       .def("get_program", &Graph::Get<ProgramDesc>)
+      .def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
       .def("set", [](Graph &self, const std::string &attr_name,
                      int attr) { return self.Set(attr_name, new int(attr)); })
       .def("set",
@@ -63,6 +68,12 @@ void BindGraph(py::module *m) {
              const ProgramDesc &attr) {
             return self.Set(attr_name, new ProgramDesc(attr));
           })
+      .def("set",
+           [](Graph &self, const std::string &attr_name,
+              const std::unordered_set<const Node *> &attr) {
+             return self.Set(attr_name,
+                             new std::unordered_set<const Node *>(attr));
+           })
       .def("erase", &Graph::Erase)
       .def("nodes", &Graph::Nodes, return_value_policy::reference)
       .def("create_var_node",
@@ -91,12 +102,52 @@ void BindNode(py::module *m) {
   py::class_<Node> node(*m, "Node");
   node.def("name", &Node::Name)
       .def("node_type", &Node::NodeType)
-      .def("var", &Node::Var)
-      .def("op", &Node::Op)
+      .def("var", &Node::Var, return_value_policy::reference)
+      .def("op", &Node::Op, return_value_policy::reference)
       .def("id", &Node::id)
       .def("is_op", &Node::IsOp)
       .def("is_var", &Node::IsVar)
       .def("is_ctrl_var", &Node::IsCtrlVar)
+      .def("inputs_remove",
+           [](Node &self, int node_id) {
+             // Erase every input whose id matches. std::vector::erase
+             // invalidates the erased iterator, so reassign it from the
+             // return value instead of incrementing past it.
+             for (auto it = self.inputs.begin(); it != self.inputs.end();) {
+               if ((*it)->id() == node_id) {
+                 it = self.inputs.erase(it);
+               } else {
+                 ++it;
+               }
+             }
+           })
+      .def("inputs_remove",
+           [](Node &self, Node &node) {
+             for (auto it = self.inputs.begin(); it != self.inputs.end();) {
+               if (*it == &node) {
+                 it = self.inputs.erase(it);
+               } else {
+                 ++it;
+               }
+             }
+           })
+      .def("inputs_append",
+           [](Node &self, Node &node) { self.inputs.push_back(&node); })
+      .def("outputs_remove",
+           [](Node &self, int node_id) {
+             for (auto it = self.outputs.begin(); it != self.outputs.end();) {
+               if ((*it)->id() == node_id) {
+                 it = self.outputs.erase(it);
+               } else {
+                 ++it;
+               }
+             }
+           })
+      .def("outputs_remove",
+           [](Node &self, Node &node) {
+             for (auto it = self.outputs.begin(); it != self.outputs.end();) {
+               if (*it == &node) {
+                 it = self.outputs.erase(it);
+               } else {
+                 ++it;
+               }
+             }
+           })
+      .def("outputs_append",
+           [](Node &self, Node &node) { self.outputs.push_back(&node); })
      .def_readwrite("inputs", &Node::inputs)
      .def_readwrite("outputs", &Node::outputs);
......
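A minimal sketch of how these new bindings are meant to be driven from the Python side; `program_desc`, `op_node`, `old_var`, and `new_var` are illustrative names, not part of this change:

    # a sketch, assuming a ProgramDesc-backed graph built elsewhere
    graph = core.Graph(program_desc)

    # remove a set of nodes in one shot; GraphSafeRemoveNodes also cleans
    # up the input/output links that pointed at the removed nodes
    ctrl_vars = {n for n in graph.nodes() if n.is_ctrl_var()}
    core.graph_safe_remove_nodes(graph, ctrl_vars)

    # rewire a single op with the new list-manipulation bindings
    op_node.inputs_remove(old_var)
    op_node.inputs_append(new_var)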
@@ -228,13 +228,7 @@ void BindBlockDesc(pybind11::module *m) {
 void BindVarDsec(pybind11::module *m) {
   pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
-  var_desc
-      .def("__init__",
-           [](pd::VarDesc &self, const pybind11::bytes &binary_str) {
-             std::string str(binary_str);
-             new (&self) pd::VarDesc(str);
-           },
-           pybind11::return_value_policy::reference)
+  var_desc.def(pybind11::init<const std::string &>())
       .def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
       .def("set_name", &pd::VarDesc::SetName)
       .def("set_shape", &pd::VarDesc::SetShape)
......
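With the placement-new `__init__` replaced by `pybind11::init<const std::string &>`, a `VarDesc` is now constructed from a variable name rather than a serialized protobuf string, which is what `PyGraph.create_param_node` below relies on. A short sketch (the name and shape are illustrative):

    var_desc = core.VarDesc('conv2d_0.w_0')  # construct by name, not by bytes
    var_desc.set_shape([32, 3, 3, 3])
    print(var_desc.name())  # 'conv2d_0.w_0'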
@@ -11,12 +11,14 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import print_function
+import os
+import subprocess
 from ....framework import Program
 from ....framework import Block
 from .... import core
-__all__ = ['Graph', 'ImitationGraph', 'PyGraph']
+__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph']
 class PyGraph(object):
@@ -30,17 +32,18 @@ class PyGraph(object):
         self.graph = graph
     def all_parameters(self):
-        params = []
+        param_nodes = set()
         for node in self.graph.nodes():
-            if node.is_var() and node.var().persistable():
-                params.append(node)
-        return params
+            if node.is_var() and node.var() is not None and node.var(
+            ).persistable():
+                param_nodes.add(node)
+        return param_nodes
     def all_vars(self):
-        return [node for node in self.graph.nodes() if node.is_var()]
+        return {node for node in self.graph.nodes() if node.is_var()}
     def all_ops(self):
-        return [node for node in self.graph.nodes() if node.is_op()]
+        return {node for node in self.graph.nodes() if node.is_op()}
     def create_param_node(self, name, var_type, shape, var_dtype):
         var_desc = core.VarDesc(name)
@@ -65,10 +68,16 @@ class PyGraph(object):
         op_desc.set_type(op_type)
         for attr, value in attrs.iteritems():
             self._update_desc_attr(op_desc, attr, value)
-        for input_name, var_node in inputs.iteritems():
-            op_desc.set_input(input_name, [var_node.name()])
-        for output_name, var_node in outputs.iteritems():
-            op_desc.set_output(output_name, [var_node.name()])
+        for input_name, var_nodes in inputs.iteritems():
+            if not isinstance(var_nodes, list):
+                var_nodes = [var_nodes]
+            op_desc.set_input(input_name,
+                              [var_node.name() for var_node in var_nodes])
+        for output_name, var_nodes in outputs.iteritems():
+            if not isinstance(var_nodes, list):
+                var_nodes = [var_nodes]
+            op_desc.set_output(output_name,
+                               [var_node.name() for var_node in var_nodes])
         return self.graph.create_op_node(op_desc)
     def create_op_node_from_desc(self, op_desc):
@@ -89,6 +98,49 @@ class PyGraph(object):
         else:
             desc._set_attr(name, val)
+    def safe_remove_nodes(self, remove_nodes):
+        if not isinstance(remove_nodes, set):
+            remove_nodes = set(remove_nodes)
+        core.graph_safe_remove_nodes(self.graph, remove_nodes)
+
+    def draw_graph(self, save_path, name, marked_nodes=None):
+        def _convert_to_pdf(dot_file_path):
+            pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
+            exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
+                            + ' -o ' + pdf_save_path, shell=True)
+            if exited_code != 0:
+                print('The dot command is needed for creating pdf files.')
+                print('The {} is saved as the dot filetype.'.format(
+                    dot_file_path))
+
+        remove_ctr_vars = set()
+        ops_num = 0
+        for node in self.graph.nodes():
+            if node.is_ctrl_var():
+                remove_ctr_vars.add(node)
+            elif node.is_op():
+                ops_num += 1
+        print('Total ops num = {}.'.format(ops_num))
+        self.safe_remove_nodes(remove_ctr_vars)
+        if marked_nodes is not None:
+            if not isinstance(marked_nodes, set):
+                marked_nodes = set(marked_nodes)
+            marked_nodes = marked_nodes - remove_ctr_vars
+            self.graph.set('__graphviz__marked_node__', marked_nodes)
+        viz_dot_path = os.path.join(save_path, name) + '.dot'
+        viz_pass = core.get_pass('graph_viz_pass')
+        viz_pass.set_str('graph_viz_path', viz_dot_path)
+        viz_pass.apply(self.graph)
+        _convert_to_pdf(viz_dot_path)
+
+    def to_program(self):
+        convert_pass = core.get_pass('graph_to_program_pass')
+        convert_pass.set_program('program', Program().desc)
+        convert_pass.apply(self.graph)
+        program = Program()
+        program.desc = convert_pass.get_program('program')
+        return program
 class Graph(object):
     """
@@ -112,3 +164,7 @@ class ImitationGraph(Graph):
     def all_parameters(self):
         return self.program.global_block().all_parameters()
+
+
+class IRGraph(Graph):
+    pass
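Taken together, the new `PyGraph` helpers support a build / rewrite / render / convert-back round trip. A sketch, assuming `main` is a fully built `fluid.Program` (this mirrors how the unit test below drives the API):

    graph = PyGraph(core.Graph(main.desc))

    # highlight every quantize-related op in the rendered graph;
    # draw_graph needs the graphviz `dot` binary to produce the PDF
    marked = {op for op in graph.all_ops() if 'quantize' in op.name()}
    graph.draw_graph('.', 'fc_quant', marked_nodes=marked)

    # convert the rewritten graph back into a runnable Program
    program = graph.to_program()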
+# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+from . import quantization_performer
+from .quantization_performer import *
+
+__all__ = quantization_performer.__all__
@@ -19,6 +19,8 @@ from ....initializer import Constant
 from .... import unique_name
 from ..graph import PyGraph
+__all__ = ['QuantizationPerformer']
+
 class QuantizationPerformer(object):
     def __init__(self,
@@ -108,19 +110,62 @@ class QuantizationPerformer(object):
                     graph, quant_var_node, scale_var_node, quant_bits)
                 dequantized_vars[var_node.name()] = dequant_var_node
             self._update_input(var_node, dequant_var_node, op)
+            op.op()._rename_input(var_node.name(), dequant_var_node.name())
+
+        def _transform_backward(graph, op):
+            no_dequanted_input_vars = True
+            for var_node in op.inputs:
+                if var_node.name() in dequantized_vars:
+                    dequant_var_node = dequantized_vars[var_node.name()]
+                    self._update_input(var_node, dequant_var_node, op)
+                    op.op()._rename_input(var_node.name(),
+                                          dequant_var_node.name())
+                    no_dequanted_input_vars = False
+            if no_dequanted_input_vars:
+                raise ValueError("There are no dequantized inputs for op %s." %
+                                 (op.name()))
+
         if not self.is_test:
             self._create_global_step(graph)
         ops = graph.all_ops()
+        # _transform_forward and _transform_backward must run in two separate
+        # loops. The first loop transforms the forward graph:
         for op in ops:
-            # transform the forward graph
             if op.name() in self.quantizable_ops:
                 _transform_forward(graph, op)
-        # rename the inputs of backward op
+        # The second loop renames the inputs of the backward ops:
+        for op in ops:
             if op.name() in self.quantizable_grad_ops:
                 _transform_backward(graph, op)
         return self.need_inited_outer
+
+    def _create_global_step(self, graph):
+        if self.weight_quantize_type == 'range_abs_max' or \
+                self.activation_quantize_type == 'range_abs_max':
+            counter_name = '@STEP_COUNTER@'
+            for node in graph.all_vars():
+                if node.name() == counter_name:
+                    self.global_step = node
+            if self.global_step is None:
+                global_step_in = graph.create_param_node(
+                    name=counter_name,
+                    var_type=core.VarDesc.VarType.LOD_TENSOR,
+                    shape=[1],
+                    var_dtype=core.VarDesc.VarType.INT64)
+                self.need_inited_outer[global_step_in.var()] = \
+                    Constant(value=0, force_cpu=True)
+                global_step_out = graph.create_var_node_from_desc(
+                    global_step_in.var())
+                increment_op = graph.create_op_node(
+                    op_type='increment',
+                    attrs={'step': 1.0},
+                    inputs={'X': global_step_in},
+                    outputs={'Out': global_step_out})
+                self._link_to(global_step_in, increment_op)
+                self._link_to(increment_op, global_step_out)
+                self.global_step = global_step_out
+
     def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
         """
         Insert fake_quantize_op in the graph.
@@ -128,7 +173,7 @@ class QuantizationPerformer(object):
         if quant_type == 'abs_max':
             return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
         elif quant_type == 'range_abs_max':
-            return self._inser_quant_range_abs_max_op(graph, var_node,
-                                                      quant_bits)
+            return self._insert_quant_range_abs_max_op(graph, var_node,
+                                                       quant_bits)
     def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
@@ -237,14 +282,14 @@ class QuantizationPerformer(object):
         return dequant_var_node
     def _update_input(self, old_input_node, new_input_node, op_node):
-        old_input_node.outputs.remove(op_node)
-        op_node.inputs.remove(old_input_node)
-        new_input_node.outputs.append(op_node)
-        op_node.inputs.append(new_input_node)
+        old_input_node.outputs_remove(op_node)
+        op_node.inputs_remove(old_input_node)
+        new_input_node.outputs_append(op_node)
+        op_node.inputs_append(new_input_node)
-    def _link_to(node_in, node_out):
-        node_in.outputs.append(node_out)
-        node_out.inputs.append(node_in)
+    def _link_to(self, node_in, node_out):
+        node_in.outputs_append(node_out)
+        node_out.inputs_append(node_in)
     def _quantized_var_name(self, var_name):
         """
......
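The switch from mutating `Node.inputs`/`Node.outputs` directly to the `*_remove`/`*_append` bindings matters because pybind11 converts the plain `def_readwrite` vector members to Python list copies on access, so in-place edits from Python never reached the underlying C++ vectors. Conceptually, `_update_input` splices a dequantized variable in front of an op (a sketch using the Node bindings above; `old_var`, `new_var`, and `op` are illustrative):

    # before: old_var -> op
    old_var.outputs_remove(op)       # cut the edge old_var -> op
    op.inputs_remove(old_var)
    new_var.outputs_append(op)       # add the edge new_var -> op
    op.inputs_append(new_var)
    # after: new_var -> op; op.op()._rename_input(...) keeps the OpDesc in sync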
+# copyright (c) 2018 paddlepaddle authors. all rights reserved.
+#
+# licensed under the apache license, version 2.0 (the "license");
+# you may not use this file except in compliance with the license.
+# you may obtain a copy of the license at
+#
+#     http://www.apache.org/licenses/license-2.0
+#
+# unless required by applicable law or agreed to in writing, software
+# distributed under the license is distributed on an "as is" basis,
+# without warranties or conditions of any kind, either express or implied.
+# see the license for the specific language governing permissions and
+# limitations under the license.
+
+import unittest
+import random
+import numpy as np
+import paddle
+import paddle.fluid as fluid
+import six
+from paddle.fluid.framework import Program
+from paddle.fluid.contrib.slim.quantization import QuantizationPerformer
+from paddle.fluid.contrib.slim.graph import PyGraph
+from paddle.fluid import core
+
+
+def linear_fc(num):
+    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    hidden = data
+    for _ in six.moves.xrange(num):
+        hidden = fluid.layers.fc(hidden, size=128, act='relu')
+    loss = fluid.layers.cross_entropy(input=hidden, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+
+
+def residual_block(num):
+    def conv_bn_layer(input,
+                      ch_out,
+                      filter_size,
+                      stride,
+                      padding,
+                      act='relu',
+                      bias_attr=False):
+        tmp = fluid.layers.conv2d(
+            input=input,
+            filter_size=filter_size,
+            num_filters=ch_out,
+            stride=stride,
+            padding=padding,
+            act=None,
+            bias_attr=bias_attr)
+        return fluid.layers.batch_norm(input=tmp, act=act)
+
+    data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
+    label = fluid.layers.data(name='label', shape=[1], dtype='int64')
+    hidden = data
+    for _ in six.moves.xrange(num):
+        conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
+        short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
+        hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
+    fc = fluid.layers.fc(input=hidden, size=10)
+    loss = fluid.layers.cross_entropy(input=fc, label=label)
+    loss = fluid.layers.mean(loss)
+    return loss
+class TestQuantizationPerformer(unittest.TestCase):
+    def setUp(self):
+        # since quant_op and dequant_op are not ready, use cos and sin for test
+        self.weight_quant_op_type = 'fake_quantize_abs_max'
+        self.dequant_op_type = 'fake_dequantize_max_abs'
+        self.quantizable_op_and_inputs = {
+            'conv2d': ['Input', 'Filter'],
+            'depthwise_conv2d': ['Input', 'Filter'],
+            'mul': ['X', 'Y']
+        }
+        self.quantizable_op_grad_and_inputs = {
+            'conv2d_grad': ['Input', 'Filter'],
+            'depthwise_conv2d_grad': ['Input', 'Filter'],
+            'mul_grad': ['X', 'Y']
+        }
+    def linear_fc_quant(self, quant_type):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = linear_fc(3)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        graph = PyGraph(core.Graph(main.desc))
+        performer = QuantizationPerformer(activation_quantize_type=quant_type)
+        performer.quantize_transform(graph, False)
+        marked_nodes = set()
+        for op in graph.all_ops():
+            if op.name().find('quantize') > -1:
+                marked_nodes.add(op)
+        graph.draw_graph('.', 'quantize_fc_' + quant_type, marked_nodes)
+
+    def test_linear_fc_quant_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_abs_max'
+        self.linear_fc_quant('abs_max')
+
+    def test_linear_fc_quant_range_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_range_abs_max'
+        self.linear_fc_quant('range_abs_max')
+
+    def residual_block_quant(self, quant_type):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.program_guard(main, startup):
+            loss = residual_block(2)
+            opt = fluid.optimizer.Adam(learning_rate=0.001)
+            opt.minimize(loss)
+        graph = PyGraph(core.Graph(main.desc))
+        performer = QuantizationPerformer(activation_quantize_type=quant_type)
+        performer.quantize_transform(graph, False)
+        marked_nodes = set()
+        for op in graph.all_ops():
+            if op.name().find('quantize') > -1:
+                marked_nodes.add(op)
+        graph.draw_graph('.', 'quantize_residual_' + quant_type, marked_nodes)
+
+    def test_residual_block_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_abs_max'
+        self.residual_block_quant('abs_max')
+
+    def test_residual_block_range_abs_max(self):
+        self.act_quant_op_type = 'fake_quantize_range_abs_max'
+        self.residual_block_quant('range_abs_max')
+
+
+if __name__ == '__main__':
+    unittest.main()
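These tests only visualize the transformed graphs. A natural follow-up check, not part of this commit, would be to assert that the quantized graph still converts back into a valid Program:

    # hypothetical extension inside one of the test methods above
    program = graph.to_program()
    self.assertIsNotNone(program.desc)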
@@ -113,6 +113,7 @@ packages=['paddle',
                 'paddle.fluid.contrib.slim.core',
                 'paddle.fluid.contrib.slim.graph',
                 'paddle.fluid.contrib.slim.prune',
+                'paddle.fluid.contrib.slim.quantization',
                 'paddle.fluid.contrib.utils',
                 'paddle.fluid.transpiler',
                 'paddle.fluid.transpiler.details']
......