提交 e2ff300b 编写于 作者: W WangZhen

add UT for quantization.

上级 451896fc
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/details/sequential_execution_pass.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_to_program_pass.h"
#include "paddle/fluid/framework/ir/graph_viz_pass.h"
namespace paddle {
......
......@@ -15,7 +15,9 @@
#include "paddle/fluid/pybind/ir.h"
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/var_desc.h"
......@@ -24,6 +26,7 @@
namespace py = pybind11;
using paddle::framework::ir::Graph;
using paddle::framework::ir::Node;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc;
......@@ -32,6 +35,7 @@ using pybind11::return_value_policy;
namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
......@@ -43,6 +47,7 @@ void BindGraph(py::module *m) {
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
.def("set",
......@@ -63,6 +68,12 @@ void BindGraph(py::module *m) {
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
return self.Set(attr_name,
new std::unordered_set<const Node *>(attr));
})
.def("erase", &Graph::Erase)
.def("nodes", &Graph::Nodes, return_value_policy::reference)
.def("create_var_node",
......@@ -91,12 +102,52 @@ void BindNode(py::module *m) {
py::class_<Node> node(*m, "Node");
node.def("name", &Node::Name)
.def("node_type", &Node::NodeType)
.def("var", &Node::Var)
.def("op", &Node::Op)
.def("var", &Node::Var, return_value_policy::reference)
.def("op", &Node::Op, return_value_policy::reference)
.def("id", &Node::id)
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("inputs_remove",
     [](Node &self, int node_id) {
       // Drop every input whose id equals node_id. erase() invalidates
       // the iterator passed to it, so resume from the iterator it
       // returns instead of incrementing the stale one.
       for (auto it = self.inputs.begin(); it != self.inputs.end();) {
         if ((*it)->id() == node_id) {
           it = self.inputs.erase(it);
         } else {
           ++it;
         }
       }
     })
.def("inputs_remove",
     [](Node &self, Node &node) {
       // Drop every input that is this exact node (compared by address).
       for (auto it = self.inputs.begin(); it != self.inputs.end();) {
         if (*it == &node) {
           it = self.inputs.erase(it);
         } else {
           ++it;
         }
       }
     })
.def("inputs_append",
     [](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("outputs_remove",
     [](Node &self, int node_id) {
       // Drop every output whose id equals node_id; continue from the
       // iterator returned by erase() to avoid using an invalidated one.
       for (auto it = self.outputs.begin(); it != self.outputs.end();) {
         if ((*it)->id() == node_id) {
           it = self.outputs.erase(it);
         } else {
           ++it;
         }
       }
     })
.def("outputs_remove",
     [](Node &self, Node &node) {
       // Drop every output that is this exact node (compared by address).
       for (auto it = self.outputs.begin(); it != self.outputs.end();) {
         if (*it == &node) {
           it = self.outputs.erase(it);
         } else {
           ++it;
         }
       }
     })
.def("outputs_append",
     [](Node &self, Node &node) { self.outputs.push_back(&node); })
.def_readwrite("inputs", &Node::inputs)
.def_readwrite("outputs", &Node::outputs);
......
......@@ -228,13 +228,7 @@ void BindBlockDesc(pybind11::module *m) {
void BindVarDsec(pybind11::module *m) {
pybind11::class_<pd::VarDesc> var_desc(*m, "VarDesc", "");
var_desc
.def("__init__",
[](pd::VarDesc &self, const pybind11::bytes &binary_str) {
std::string str(binary_str);
new (&self) pd::VarDesc(str);
},
pybind11::return_value_policy::reference)
var_desc.def(pybind11::init<const std::string &>())
.def("name", &pd::VarDesc::Name, pybind11::return_value_policy::reference)
.def("set_name", &pd::VarDesc::SetName)
.def("set_shape", &pd::VarDesc::SetShape)
......
......@@ -11,12 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import os
import subprocess
from ....framework import Program
from ....framework import Block
from .... import core
__all__ = ['Graph', 'ImitationGraph', 'PyGraph']
__all__ = ['Graph', 'ImitationGraph', 'IRGraph', 'PyGraph']
class PyGraph(object):
......@@ -30,17 +32,18 @@ class PyGraph(object):
self.graph = graph
def all_parameters(self):
params = []
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var().persistable():
params.append(node)
return params
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_vars(self):
return [node for node in self.graph.nodes() if node.is_var()]
return {node for node in self.graph.nodes() if node.is_var()}
def all_ops(self):
return [node for node in self.graph.nodes() if node.is_op()]
return {node for node in self.graph.nodes() if node.is_op()}
def create_param_node(self, name, var_type, shape, var_dtype):
var_desc = core.VarDesc(name)
......@@ -65,10 +68,16 @@ class PyGraph(object):
op_desc.set_type(op_type)
for attr, value in attrs.iteritems():
self._update_desc_attr(op_desc, attr, value)
for input_name, var_node in inputs.iteritems():
op_desc.set_input(input_name, [var_node.name()])
for output_name, var_node in outputs.iteritems():
op_desc.set_output(output_name, [var_node.name()])
for input_name, var_nodes in inputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems():
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
[var_node.name() for var_node in var_nodes])
return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc):
......@@ -89,6 +98,49 @@ class PyGraph(object):
else:
desc._set_attr(name, val)
def safe_remove_nodes(self, remove_nodes):
    """
    Remove the given nodes from the wrapped graph.

    Args:
        remove_nodes: a set of nodes, or any iterable of nodes; anything
            that is not already a set is converted to one before being
            handed to core.graph_safe_remove_nodes.
    """
    nodes_to_drop = (remove_nodes
                     if isinstance(remove_nodes, set) else set(remove_nodes))
    core.graph_safe_remove_nodes(self.graph, nodes_to_drop)
def draw_graph(self, save_path, name, marked_nodes=None):
    """
    Dump the graph to ``<save_path>/<name>.dot`` and try to render a PDF.

    Control-variable nodes are removed from the graph before drawing, so
    this call modifies ``self.graph``. The number of op nodes is printed.

    Args:
        save_path: directory joined with ``name`` to build the output path.
        name: output file name without extension.
        marked_nodes: optional set (or iterable) of nodes stored on the
            graph under the '__graphviz__marked_node__' attribute; removed
            control-variable nodes are excluded from it.
    """

    def _convert_to_pdf(dot_file_path):
        # Render the .dot file with Graphviz; on failure keep the .dot file
        # and tell the user why no PDF was produced.
        pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
        try:
            # Pass an argument list (no shell) so that paths containing
            # spaces or shell metacharacters are handled literally.
            exited_code = subprocess.call(
                ['dot', '-Tpdf', dot_file_path, '-o', pdf_save_path])
        except OSError:
            # Without a shell, a missing `dot` binary raises instead of
            # returning a non-zero status; treat it as the same failure.
            exited_code = -1
        if exited_code != 0:
            print('The dot command is needed for creating pdf files.')
            print('The {} is saved as the dot filetype.'.format(
                dot_file_path))

    remove_ctr_vars = set()
    ops_num = 0
    for node in self.graph.nodes():
        if node.is_ctrl_var():
            remove_ctr_vars.add(node)
        elif node.is_op():
            ops_num += 1
    print('Total ops num = {}.'.format(ops_num))
    self.safe_remove_nodes(remove_ctr_vars)
    if marked_nodes is not None:
        if not isinstance(marked_nodes, set):
            marked_nodes = set(marked_nodes)
        # Nodes that were just removed must not be referenced by the pass.
        marked_nodes = marked_nodes - remove_ctr_vars
        self.graph.set('__graphviz__marked_node__', marked_nodes)
    viz_dot_path = os.path.join(save_path, name) + '.dot'
    viz_pass = core.get_pass('graph_viz_pass')
    viz_pass.set_str('graph_viz_path', viz_dot_path)
    viz_pass.apply(self.graph)
    _convert_to_pdf(viz_dot_path)
def to_program(self):
    """
    Convert the wrapped graph into a Program.

    Applies 'graph_to_program_pass' to the graph, using the desc of a
    fresh Program as the pass's 'program' attribute, and returns a new
    Program whose desc is the one read back from the pass.
    """
    to_program_pass = core.get_pass('graph_to_program_pass')
    to_program_pass.set_program('program', Program().desc)
    to_program_pass.apply(self.graph)
    result = Program()
    result.desc = to_program_pass.get_program('program')
    return result
class Graph(object):
"""
......@@ -112,3 +164,7 @@ class ImitationGraph(Graph):
def all_parameters(self):
return self.program.global_block().all_parameters()
class IRGraph(Graph):
    # Placeholder subclass: it adds no members of its own to Graph.
    pass
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
from . import quantization_performer
from .quantization_performer import *
__all__ = quantization_performer.__all__
......@@ -19,6 +19,8 @@ from ....initializer import Constant
from .... import unique_name
from ..graph import PyGraph
__all__ = ['QuantizationPerformer']
class QuantizationPerformer(object):
def __init__(self,
......@@ -108,19 +110,62 @@ class QuantizationPerformer(object):
graph, quant_var_node, scale_var_node, quant_bits)
dequantized_vars[var_node.name()] = dequant_var_node
self._update_input(var_node, dequant_var_node, op)
op.op()._rename_input(var_node.name(), dequant_var_node.name())
def _transform_backward(graph, op):
no_dequanted_input_vars = True
for var_node in op.inputs:
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
self._update_input(var_node, dequant_var_node, op)
op.op()._rename_input(var_node.name(),
dequant_var_node.name())
no_dequanted_input_vars = False
if no_dequanted_input_vars:
raise ValueError("There is no dequanted inputs for op %s." %
(op.name()))
if not self.is_test:
self._create_global_step(graph)
ops = graph.all_ops()
# The process of _transform_forward and _transform_backward is needed in two for loops.
# The loop for transforming the forward graph:
for op in ops:
# transform the forward graph
if op.name() in self.quantizable_ops:
_transform_forward(graph, op)
# rename the inputs of backward op
# The loop for renaming the inputs of backward op.
for op in ops:
if op.name() in self.quantizable_grad_ops:
_transform_backward(graph, op)
return self.need_inited_outer
def _create_global_step(self, graph):
    """
    Make sure a step-counter variable exists when 'range_abs_max' is used.

    Does nothing unless the weight or activation quantize type is
    'range_abs_max'. Otherwise an existing variable node named
    '@STEP_COUNTER@' is reused as self.global_step; if self.global_step is
    still None afterwards, an INT64 counter parameter of shape [1] is
    created, registered in self.need_inited_outer with a zero Constant
    initializer, and wired through an 'increment' op whose output node
    becomes self.global_step.
    """
    if self.weight_quantize_type != 'range_abs_max' and \
            self.activation_quantize_type != 'range_abs_max':
        return

    step_counter_name = '@STEP_COUNTER@'
    for var_node in graph.all_vars():
        if var_node.name() == step_counter_name:
            self.global_step = var_node
    if self.global_step is not None:
        return

    step_in_node = graph.create_param_node(
        name=step_counter_name,
        var_type=core.VarDesc.VarType.LOD_TENSOR,
        shape=[1],
        var_dtype=core.VarDesc.VarType.INT64)
    # The counter has to be initialised by the caller, starting from zero.
    self.need_inited_outer[step_in_node.var()] = \
        Constant(value=0, force_cpu=True)
    step_out_node = graph.create_var_node_from_desc(step_in_node.var())
    step_increment_op = graph.create_op_node(
        op_type='increment',
        attrs={'step': 1.0},
        inputs={'X': step_in_node},
        outputs={'Out': step_out_node})
    self._link_to(step_in_node, step_increment_op)
    self._link_to(step_increment_op, step_out_node)
    self.global_step = step_out_node
def _insert_quant_op(self, graph, var_node, quant_bits, quant_type):
"""
Insert fake_quantize_op in the graph.
......@@ -128,7 +173,7 @@ class QuantizationPerformer(object):
if quant_type == 'abs_max':
return self._insert_quant_abs_max_op(graph, var_node, quant_bits)
elif quant_type == 'range_abs_max':
return self._inser_quant_range_abs_max_op(graph, var_node,
return self._insert_quant_range_abs_max_op(graph, var_node,
quant_bits)
def _insert_quant_abs_max_op(self, graph, var_node, quant_bits):
......@@ -237,14 +282,14 @@ class QuantizationPerformer(object):
return dequant_var_node
def _update_input(self, old_input_node, new_input_node, op_node):
old_input_node.outputs.remove(op_node)
op_node.inputs.remove(old_input_node)
new_input_node.outputs.append(op_node)
op_node.inputs.append(new_input_node)
def _link_to(node_in, node_out):
node_in.outputs.append(node_out)
node_out.inputs.append(node_in)
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
op_node.inputs_append(new_input_node)
def _link_to(self, node_in, node_out):
    """
    Connect node_in to node_out: node_out is appended to node_in's outputs
    and node_in is appended to node_out's inputs.
    """
    node_in.outputs_append(node_out)
    node_out.inputs_append(node_in)
def _quantized_var_name(self, var_name):
"""
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import random
import numpy as np
import paddle
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
from paddle.fluid.contrib.slim.quantization import QuantizationPerformer
from paddle.fluid.contrib.slim.graph import PyGraph
from paddle.fluid import core
def linear_fc(num):
    """
    Build a stack of `num` fully connected layers (size 128, relu) on an
    'image' input and return the mean cross-entropy loss against 'label'.
    """
    image = fluid.layers.data(
        name='image', shape=[1, 32, 32], dtype='float32')
    target = fluid.layers.data(name='label', shape=[1], dtype='int64')
    features = image
    for _ in six.moves.xrange(num):
        features = fluid.layers.fc(features, size=128, act='relu')
    cross_entropy = fluid.layers.cross_entropy(input=features, label=target)
    return fluid.layers.mean(cross_entropy)
def residual_block(num):
    """
    Build `num` residual blocks (two conv+batch_norm branches joined by an
    elementwise add with relu), followed by a 10-way fc layer, and return
    the mean cross-entropy loss against 'label'.
    """

    def conv_bn_layer(input,
                      ch_out,
                      filter_size,
                      stride,
                      padding,
                      act='relu',
                      bias_attr=False):
        # Convolution without activation, then batch norm carrying `act`.
        conv_out = fluid.layers.conv2d(
            input=input,
            filter_size=filter_size,
            num_filters=ch_out,
            stride=stride,
            padding=padding,
            act=None,
            bias_attr=bias_attr)
        return fluid.layers.batch_norm(input=conv_out, act=act)

    image = fluid.layers.data(
        name='image', shape=[1, 32, 32], dtype='float32')
    target = fluid.layers.data(name='label', shape=[1], dtype='int64')
    features = image
    for _ in six.moves.xrange(num):
        main_branch = conv_bn_layer(
            features, 16, 3, 1, 1, act=None, bias_attr=True)
        shortcut = conv_bn_layer(features, 16, 1, 1, 0, act=None)
        features = fluid.layers.elementwise_add(
            x=main_branch, y=shortcut, act='relu')
    logits = fluid.layers.fc(input=features, size=10)
    cross_entropy = fluid.layers.cross_entropy(input=logits, label=target)
    return fluid.layers.mean(cross_entropy)
class TestQuantizationPerformer(unittest.TestCase):
    """
    Smoke tests for QuantizationPerformer: build a small network, apply the
    quantization transform to its graph, and draw the result with the
    quantize ops marked. The tests pass when no step raises.
    """

    def setUp(self):
        # Names of the fake quantize/dequantize op types and the input
        # slots of the quantizable ops and their grad ops.
        # NOTE(review): nothing in this class reads these attributes yet.
        self.weight_quant_op_type = 'fake_quantize_abs_max'
        self.dequant_op_type = 'fake_dequantize_max_abs'
        self.quantizable_op_and_inputs = {
            'conv2d': ['Input', 'Filter'],
            'depthwise_conv2d': ['Input', 'Filter'],
            'mul': ['X', 'Y']
        }
        self.quantizable_op_grad_and_inputs = {
            'conv2d_grad': ['Input', 'Filter'],
            'depthwise_conv2d_grad': ['Input', 'Filter'],
            'mul_grad': ['X', 'Y']
        }

    def _quantize_and_draw(self, build_loss, graph_name_prefix, quant_type):
        """
        Shared body of the *_quant helpers.

        Args:
            build_loss: zero-argument callable that builds the network in
                the current program and returns its loss.
            graph_name_prefix: prefix of the drawn graph's file name; the
                quant_type is appended to it.
            quant_type: activation quantize type given to the performer.
        """
        main = fluid.Program()
        startup = fluid.Program()
        with fluid.program_guard(main, startup):
            loss = build_loss()
            opt = fluid.optimizer.Adam(learning_rate=0.001)
            opt.minimize(loss)
        graph = PyGraph(core.Graph(main.desc))
        performer = QuantizationPerformer(activation_quantize_type=quant_type)
        performer.quantize_transform(graph, False)
        # Highlight every op whose name contains 'quantize'.
        marked_nodes = set()
        for op in graph.all_ops():
            if op.name().find('quantize') > -1:
                marked_nodes.add(op)
        graph.draw_graph('.', graph_name_prefix + quant_type, marked_nodes)

    def linear_fc_quant(self, quant_type):
        self._quantize_and_draw(lambda: linear_fc(3), 'quantize_fc_',
                                quant_type)

    def test_linear_fc_quant_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_abs_max'
        self.linear_fc_quant('abs_max')

    def test_linear_fc_quant_range_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_range_abs_max'
        self.linear_fc_quant('range_abs_max')

    def residual_block_quant(self, quant_type):
        self._quantize_and_draw(lambda: residual_block(2),
                                'quantize_residual_', quant_type)

    def test_residual_block_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_abs_max'
        self.residual_block_quant('abs_max')

    def test_residual_block_range_abs_max(self):
        self.act_quant_op_type = 'fake_quantize_range_abs_max'
        self.residual_block_quant('range_abs_max')
# Run the test cases when this file is executed directly.
if __name__ == '__main__':
    unittest.main()
......@@ -113,6 +113,7 @@ packages=['paddle',
'paddle.fluid.contrib.slim.core',
'paddle.fluid.contrib.slim.graph',
'paddle.fluid.contrib.slim.prune',
'paddle.fluid.contrib.slim.quantization',
'paddle.fluid.contrib.utils',
'paddle.fluid.transpiler',
'paddle.fluid.transpiler.details']
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册