Unverified · Commit 832bd720 authored by Zhen Wang and committed by GitHub

Merge pull request #15610 from wzzju/quantization_inference_passes

Quantization inference passes
......@@ -13,10 +13,12 @@
// limitations under the License.
#include "paddle/fluid/pybind/ir.h"
#include <algorithm>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/node.h"
#include "paddle/fluid/framework/op_desc.h"
......@@ -27,6 +29,10 @@ namespace py = pybind11;
using paddle::framework::ir::Graph;
using paddle::framework::ir::Node;
using paddle::framework::ir::GraphSafeRemoveNodes;
using paddle::framework::ir::HasCircle;
using paddle::framework::ir::GraphNum;
using paddle::framework::ir::TopologySortOperations;
using paddle::framework::ir::BuildOperationAdjList;
using paddle::framework::OpDesc;
using paddle::framework::ProgramDesc;
using paddle::framework::VarDesc;
......@@ -36,6 +42,12 @@ namespace paddle {
namespace pybind {
void BindGraph(py::module *m) {
m->def("graph_safe_remove_nodes", GraphSafeRemoveNodes);
m->def("has_circle", HasCircle);
m->def("graph_num", GraphNum);
m->def("topology_sort", TopologySortOperations,
return_value_policy::reference);
m->def("build_adjacency_list", BuildOperationAdjList,
return_value_policy::reference);
py::class_<Graph, std::shared_ptr<Graph>>(
*m, "Graph",
"The graph is a Directed Acyclic Single Static Assignment Graph, see "
......@@ -46,7 +58,6 @@ void BindGraph(py::module *m) {
.def("get_float", &Graph::Get<float>)
.def("get_double", &Graph::Get<double>)
.def("get_string", &Graph::Get<std::string>)
.def("get_program", &Graph::Get<ProgramDesc>)
.def("get_marked_nodes", &Graph::Get<std::unordered_set<const Node *>>)
.def("set", [](Graph &self, const std::string &attr_name,
int attr) { return self.Set(attr_name, new int(attr)); })
......@@ -63,11 +74,6 @@ void BindGraph(py::module *m) {
[](Graph &self, const std::string &attr_name, double attr) {
return self.Set(attr_name, new double(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
})
.def("set",
[](Graph &self, const std::string &attr_name,
const std::unordered_set<const Node *> &attr) {
......@@ -108,42 +114,42 @@ void BindNode(py::module *m) {
.def("is_op", &Node::IsOp)
.def("is_var", &Node::IsVar)
.def("is_ctrl_var", &Node::IsCtrlVar)
.def("clear_inputs", [](Node &self) { self.inputs.clear(); })
.def("inputs_remove",
[](Node &self, int node_id) {
for (auto it = self.inputs.begin(); it != self.inputs.end();
it++) {
if ((*it)->id() == node_id) {
self.inputs.erase(it);
}
auto pos = std::find_if(
self.inputs.begin(), self.inputs.end(),
[&node_id](const Node *n) { return n->id() == node_id; });
if (pos != self.inputs.end()) {
self.inputs.erase(pos);
}
})
.def("inputs_remove",
[](Node &self, Node &node) {
for (auto it = self.inputs.begin(); it != self.inputs.end();
it++) {
if (*it == &node) {
self.inputs.erase(it);
}
auto pos =
std::find(self.inputs.begin(), self.inputs.end(), &node);
if (pos != self.inputs.end()) {
self.inputs.erase(pos);
}
})
.def("inputs_append",
[](Node &self, Node &node) { self.inputs.push_back(&node); })
.def("clear_outputs", [](Node &self) { self.outputs.clear(); })
.def("outputs_remove",
[](Node &self, int node_id) {
for (auto it = self.outputs.begin(); it != self.outputs.end();
it++) {
if ((*it)->id() == node_id) {
self.outputs.erase(it);
}
auto pos = std::find_if(
self.outputs.begin(), self.outputs.end(),
[&node_id](const Node *n) { return n->id() == node_id; });
if (pos != self.outputs.end()) {
self.outputs.erase(pos);
}
})
.def("outputs_remove",
[](Node &self, Node &node) {
for (auto it = self.outputs.begin(); it != self.outputs.end();
it++) {
if (*it == &node) {
self.outputs.erase(it);
}
auto pos =
std::find(self.outputs.begin(), self.outputs.end(), &node);
if (pos != self.outputs.end()) {
self.outputs.erase(pos);
}
})
.def("outputs_append",
......
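
These graph analysis helpers become callable from Python through paddle.fluid.core once the bindings above are registered. A minimal sketch, assuming a built Program (the empty stand-in below is only for illustration):

import paddle.fluid as fluid
from paddle.fluid import core

main = fluid.Program()                  # stand-in; assume a real network is built here
graph = core.Graph(main.desc)
print(core.has_circle(graph))           # False for a well-formed DAG
print(core.graph_num(graph))            # number of disconnected subgraphs
op_nodes = core.topology_sort(graph)    # op nodes in topological order
adj = core.build_adjacency_list(graph)  # op node -> neighboring op nodes
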
......@@ -829,8 +829,7 @@ All parameter, weight, gradient are variables in Paddle.
m.def("disable_profiler", platform::DisableProfiler);
m.def("is_profiler_enabled", platform::IsProfileEnabled);
m.def("reset_profiler", platform::ResetProfiler);
m.def("get_pass", [](const py::bytes &binary_str) {
std::string pass_type(binary_str);
m.def("get_pass", [](const std::string &pass_type) {
auto pass = framework::ir::PassRegistry::Instance().Get(pass_type);
return std::shared_ptr<framework::ir::Pass>(std::move(pass));
});
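
Since pybind11 converts a Python str to std::string automatically, the explicit py::bytes decoding step above is no longer needed. A usage sketch, relying on the graph_to_program_pass that is registered elsewhere in the codebase:

from paddle.fluid import core

convert_pass = core.get_pass('graph_to_program_pass')
print(convert_pass.type())  # 'graph_to_program_pass'
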
......@@ -838,10 +837,9 @@ All parameter, weight, gradient are variables in Paddle.
py::class_<ir::Pass, std::shared_ptr<ir::Pass>> pass(m, "Pass");
pass.def(py::init())
.def("has", &ir::Pass::Has)
.def("set",
[](ir::Pass &self, const std::string &attr_name,
const ProgramDesc &attr) {
return self.Set(attr_name, new ProgramDesc(attr));
.def("set_not_owned",
[](ir::Pass &self, const std::string &attr_name, ProgramDesc &attr) {
self.SetNotOwned<ProgramDesc>(attr_name, &attr);
})
.def(
"set",
......@@ -850,7 +848,6 @@ All parameter, weight, gradient are variables in Paddle.
})
.def("set", [](ir::Pass &self, const std::string &name,
int val) { self.Set<const int>(name, new int(val)); })
.def("get_program", &ir::Pass::Get<ProgramDesc>)
.def("type", &ir::Pass::Type)
.def("apply", [](ir::Pass &self, std::shared_ptr<ir::Graph> graph) {
std::unique_ptr<ir::Graph> origin_graph(graph.get());
......
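
Replacing the owning set('program', ...) / get_program() pair with set_not_owned changes the ownership contract: the ProgramDesc now stays owned by the Python caller and the pass merely borrows it. A sketch of the new pattern (mirrored by IrGraph.to_program further down), using an empty stand-in graph:

import paddle.fluid as fluid
from paddle.fluid import core

graph = core.Graph(fluid.Program().desc)     # stand-in graph for the sketch
desc = core.ProgramDesc()                    # owned by the Python side
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set_not_owned('program', desc)  # the pass borrows, never frees
convert_pass.apply(graph)                    # fills `desc` in place
# No get_program() round-trip: `desc` already holds the converted program.
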
......@@ -64,6 +64,7 @@ if (WITH_TESTING)
add_subdirectory(paddle/dataset/tests)
add_subdirectory(paddle/fluid/tests)
add_subdirectory(paddle/fluid/contrib/tests)
add_subdirectory(paddle/fluid/contrib/slim/tests)
endif()
install(DIRECTORY ${PADDLE_PYTHON_PACKAGE_DIR}
DESTINATION opt/paddle/share/wheels
......
......@@ -13,14 +13,19 @@
# limitations under the License.
import collections
import numpy as np
import six
from ..... import compat as cpt
from .... import core
from ....framework import IrGraph
from ....framework import Program
from ....framework import Variable
from ....initializer import Constant
from .... import unique_name
__all__ = ['QuantizationTransformPass']
__all__ = [
'QuantizationTransformPass', 'QuantizationFreezePass', 'ConvertToInt8Pass',
'TransformForMobilePass'
]
class QuantizationTransformPass(object):
......@@ -35,7 +40,13 @@ class QuantizationTransformPass(object):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
Args:
scope(fluid.Scope): When the activation uses 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
program_exe(fluid.Executor): program_exe is used to initialize new
parameters described above.
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
......@@ -49,6 +60,7 @@ class QuantizationTransformPass(object):
support 'abs_max'. 'range_abs_max' is usually not used for
weights, since weights are fixed once the model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
.. code-block:: python
# The original graph will be rewritten.
......@@ -88,31 +100,35 @@ class QuantizationTransformPass(object):
self._quantizable_grad_ops = [
'%s_grad' % (op) for op in self._quantizable_ops
]
self._fake_quant_op_types = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self._fake_dequant_op_types = ['fake_dequantize_max_abs']
self._is_test = None
self._global_step = None
def apply(self, graph):
"""
Quantize the graph for the training process. According to the weight and
activation quantization type, some fake quantize and fake dequantize
operators will be added to the graph.
Args:
graph(IrGraph): the applied graph.
"""
assert isinstance(graph,
IrGraph), 'graph must be the instance of IrGraph.'
self._need_initialized.clear()
self._is_test = graph.is_test()
# record the variables which have been dequantized.
dequantized_vars = collections.OrderedDict()
params = [p.name() for p in graph.all_parameters()]
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
def _transform_forward(graph, op):
for var_node in op.inputs:
if var_node.name() in dequantized_vars:
dequant_var_node = dequantized_vars[var_node.name()]
else:
quant_bits = self._weight_bits if var_node.name() in params \
quant_bits = self._weight_bits if var_node.name() in persistable_vars \
else self._activation_bits
quant_type = self._weight_quantize_type if var_node.name() \
in params else self._activation_quantize_type
in persistable_vars else self._activation_quantize_type
quant_var_node, scale_var_node = self._insert_quant_op(
graph, var_node, quant_bits, quant_type)
dequant_var_node = self._insert_dequant_op(
......@@ -150,9 +166,14 @@ class QuantizationTransformPass(object):
assert self._program_exe is not None, \
'The program_exe cannot be set None when activation_quantize_type equals to range_abs_max.'
init_program = Program()
for var_desc, initializer in self._need_initialized.iteritems():
var = Variable(init_program.global_block())
var._set_desc(var_desc)
for var_desc, initializer in six.iteritems(self._need_initialized):
var = init_program.global_block().create_var(
name=var_desc.name(),
shape=var_desc.shape(),
dtype=var_desc.dtype(),
type=var_desc.type(),
lod_level=var_desc.lod_level(),
persistable=var_desc.persistable())
initializer(var, init_program.global_block())
self._program_exe.run(program=init_program, scope=self._scope)
......@@ -161,7 +182,7 @@ class QuantizationTransformPass(object):
def _create_global_step(self, graph):
if self._weight_quantize_type == 'range_abs_max' or \
self._activation_quantize_type == 'range_abs_max':
counter_name = '@STEP_COUNTER@'
counter_name = cpt.to_text('@STEP_COUNTER@')
for node in graph.all_vars():
if node.name() == counter_name:
self._global_step = node
......@@ -175,9 +196,14 @@ class QuantizationTransformPass(object):
Constant(value=0, force_cpu=True)
global_step_out = graph.create_var_node_from_desc(
global_step_in.var())
# The attribute of `op_role` is needed by ParallelExecutor.
increment_op = graph.create_op_node(
op_type='increment',
attrs={'step': 1.0},
attrs={
'step': 1.0,
'op_role':
core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': global_step_in},
outputs={'Out': global_step_out})
graph.link_to(global_step_in, increment_op)
......@@ -212,7 +238,10 @@ class QuantizationTransformPass(object):
var_dtype=var_node.var().dtype())
quant_op_node = graph.create_op_node(
op_type='fake_quantize_abs_max',
attrs={'bit_length': quant_bits},
attrs={
'bit_length': quant_bits,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node},
outputs={'Out': quant_var_node,
'OutScale': scale_var_node})
......@@ -257,7 +286,8 @@ class QuantizationTransformPass(object):
attrs = {
'window_size': self._window_size,
'bit_length': quant_bits,
'is_test': self._is_test
'is_test': self._is_test,
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
}
quant_op_node = graph.create_op_node(
op_type='fake_quantize_range_abs_max',
......@@ -290,7 +320,10 @@ class QuantizationTransformPass(object):
max_range = (1 << (quant_bits - 1)) - 1
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={'max_range': float(max_range)},
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
......@@ -316,3 +349,330 @@ class QuantizationTransformPass(object):
Return the scale name of quantized variable for the input `var_name`.
"""
return "%s.scale" % (var_name)
class QuantizationFreezePass(object):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be frozen into
`activation -> quant -> conv2d -> dequant`
2) `weight -> quant -> dequant -> conv2d` will be frozen into `weight -> conv2d`,
and the weight will be scaled offline.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, supporting 'abs_max'.
'range_abs_max' is usually not used for weights, since weights are fixed once the
model is well trained.
"""
def __init__(self,
scope,
place,
weight_bits=8,
activation_bits=8,
weight_quantize_type='abs_max'):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
self._scope = scope
self._place = place
self._weight_bits = weight_bits
self._activation_bits = activation_bits
self._weight_quantize_type = weight_quantize_type
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
self._op_input_rename_map = collections.OrderedDict()
self._op_output_rename_map = collections.OrderedDict()
self._var_scale_map = collections.OrderedDict()
def apply(self, graph):
"""
Adjust the order of quantize/dequantize operators for the inference process.
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_quant_op_names:
input_arg_name = op_node.op().input('X')[0]
if input_arg_name in persistable_vars:
if self._weight_quantize_type == 'abs_max':
param = self._load_var(input_arg_name)
scale_v = np.max(np.abs(param))
else:
scale_v = self._load_var(op_node.op().output('OutScale')
[0])[0]
self._var_scale_map[input_arg_name] = scale_v
else:
scale_v = graph.var_node(op_node.op().output('OutScale')[0])
self._var_scale_map[input_arg_name] = scale_v
if input_arg_name in persistable_vars:
self._remove_fake_quant_and_dequant_op(graph, op_node)
# quantize weight and restore
param_v = self._load_var(input_arg_name)
quantized_param_v = self._quant(param_v, scale_v,
self._weight_bits)
self._restore_var(input_arg_name, quantized_param_v)
ops = graph.all_ops()
for op_node in ops:
op_name = op_node.name()
if op_name in self._fake_dequant_op_names:
self._remove_fake_quant_and_dequant_op(graph, op_node)
ops = graph.all_ops()
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
self._insert_post_dequant_op(graph, op_node)
for op_node in ops:
# insert dequant_op after fc/conv, so the inputs of the following ops need renaming
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_output_rename_map:
old_in = graph.var_node(name)
new_in = self._op_output_rename_map[name]
graph.update_input_link(old_in, new_in, op_node)
# remove the unused var nodes in the graph
self._remove_unused_var_nodes(graph)
return graph
def _remove_fake_quant_and_dequant_op(self, graph, op_node):
k = op_node.op().output('Out')[0]
v = op_node.op().input('X')[0]
if v not in self._op_input_rename_map:
self._op_input_rename_map[k] = v
else:
self._op_input_rename_map[k] = self._op_input_rename_map[v]
graph.safe_remove_nodes(op_node)
def _insert_post_dequant_op(self, graph, op_node):
max_range = None
scale_var_node = None
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
for var_node in op_node.inputs:
name = var_node.name()
if name in self._op_input_rename_map:
old_in = graph.var_node(name)
new_in = graph.var_node(self._op_input_rename_map[name])
new_in.clear_outputs()
graph.update_input_link(old_in, new_in, op_node)
original_var_name = self._original_var_name(name)
scale_v = self._var_scale_map[original_var_name]
if original_var_name in persistable_vars:
param_range = (1 << (self._weight_bits - 1)) - 1
act_range = (1 << (self._activation_bits - 1)) - 1
assert self._is_float(
scale_v), 'The scale of parameter %s is not a float.' % (
original_var_name)
max_range = param_range * act_range / scale_v
else:
assert isinstance(scale_v, core.Node)
scale_var_node = self._var_scale_map[original_var_name]
if len(op_node.outputs) != 1:
raise ValueError("Only support one output, but op %s has"
" more than one output." % (op_node.name()))
output_var_node = op_node.outputs[0]
dequant_var_node = graph.create_var_node(
name=self._dequantized_var_name(output_var_node.name()),
var_type=output_var_node.var().type(),
shape=output_var_node.var().shape(),
var_dtype=output_var_node.var().dtype())
dequant_op_node = graph.create_op_node(
op_type='fake_dequantize_max_abs',
attrs={
'max_range': float(max_range),
'op_role': core.op_proto_and_checker_maker.OpRole.Forward
},
inputs={'X': output_var_node,
'Scale': scale_var_node},
outputs={'Out': dequant_var_node})
graph.link_to(output_var_node, dequant_op_node)
graph.link_to(scale_var_node, dequant_op_node)
graph.link_to(dequant_op_node, dequant_var_node)
self._op_output_rename_map[output_var_node.name()] = dequant_var_node
return dequant_var_node
def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())
def _restore_var(self, name, array):
tensor = self._scope.find_var(name).get_tensor()
tensor.set(array, self._place)
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
graph.safe_remove_nodes(all_unused_vars)
def _original_var_name(self, var_name):
"""
Return the original variable name.
"""
if var_name.endswith('.quantized.dequantized'):
return var_name[:-len('.quantized.dequantized')]
if var_name.endswith('.quantized'):
return var_name[:-len('.quantized')]
if var_name.endswith('.dequantized'):
return var_name[:-len('.dequantized')]
if var_name.endswith('.scale'):
return var_name[:-len('.scale')]
else:
return var_name
def _dequantized_var_name(self, var_name):
"""
Return dequantized variable name for the input `var_name`.
"""
return "%s.dequantized" % (var_name)
def _is_float(self, v):
return isinstance(v, float) or isinstance(v, np.float32) \
or isinstance(v, np.float64)
def _quant(self, x, scale, num_bits):
return np.round(x / scale * ((1 << (num_bits - 1)) - 1))
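
The scale arithmetic used by this pass can be checked in isolation. A numpy sketch with 8-bit weights and activations, following the same formulas as _quant and _insert_post_dequant_op:

import numpy as np

weight_bits, activation_bits = 8, 8
w = np.array([0.5, -1.2, 0.03], dtype=np.float32)
scale = np.max(np.abs(w))                                 # abs_max weight scale
q = np.round(w / scale * ((1 << (weight_bits - 1)) - 1))  # same math as _quant
param_range = (1 << (weight_bits - 1)) - 1                # 127
act_range = (1 << (activation_bits - 1)) - 1              # 127
max_range = param_range * act_range / scale               # for fake_dequantize_max_abs
# Dequantization recovers the weights up to quantization error.
print(np.allclose(q * scale / param_range, w, atol=scale / param_range))
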
class ConvertToInt8Pass(object):
"""
Convert the weights into int8_t type.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
"""
def __init__(self, scope, place):
assert scope is not None, \
'The scope cannot be set None.'
assert place is not None, \
'The place cannot be set None.'
self._scope = scope
self._place = place
self._quantizable_ops = ['conv2d', 'depthwise_conv2d', 'mul']
def apply(self, graph):
"""
Convert the weights' type in the graph. After that, the data type of the
graph's weights is int8_t.
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars = [p.name() for p in graph.all_persistable_vars()]
ops = graph.all_ops()
input_map = {}
for op_node in ops:
op_name = op_node.name()
if op_name in self._quantizable_ops:
for var_node in op_node.inputs:
name = var_node.name()
if name in persistable_vars:
if name not in input_map:
int8_var_node = self._convert_to_int8(graph,
var_node)
input_map[name] = int8_var_node
graph.update_input_link(var_node, input_map[name],
op_node)
# remove the unused var nodes in the graph
self._remove_unused_var_nodes(graph)
return graph
def _convert_to_int8(self, graph, var_node):
int8_var_node_name = var_node.name() + ".int8"
int8_var_node = graph.create_param_node(
name=cpt.to_text(int8_var_node_name),
var_type=var_node.var().type(),
shape=var_node.var().shape(),
var_dtype=core.VarDesc.VarType.INT8)
array = self._load_var(var_node.name())
self._scope.var(int8_var_node_name)
self._store_var(int8_var_node_name, array, np.int8)
return int8_var_node
def _load_var(self, name):
return np.array(self._scope.find_var(name).get_tensor())
def _store_var(self, name, array, dtype):
tensor = self._scope.find_var(name).get_tensor()
tensor.set(array.astype(dtype), self._place)
def _remove_unused_var_nodes(self, graph):
all_used_vars = set()
ops = graph.all_ops()
for op_node in ops:
for input_node in op_node.inputs:
all_used_vars.add(input_node)
for output_node in op_node.outputs:
all_used_vars.add(output_node)
all_unused_vars = graph.all_vars() - all_used_vars
graph.safe_remove_nodes(all_unused_vars)
class TransformForMobilePass(object):
"""
This pass is used to convert the frozen graph for paddle-mobile execution.
"""
def __init__(self):
self._fake_quant_op_names = [
'fake_quantize_abs_max', 'fake_quantize_range_abs_max'
]
self._fake_dequant_op_names = ['fake_dequantize_max_abs']
def apply(self, graph):
"""
Because paddle-mobile uses `quantize` and `dequantize` as the names of the
quantize and dequantize operators, the `apply` function just implements
this renaming.
Args:
graph(IrGraph): the graph will be transformed.
"""
ops = graph.all_ops()
for op_node in ops:
name = op_node.name()
if name in self._fake_quant_op_names:
op_node.op().set_type('quantize')
quant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs:
graph.link_to(input_node, quant_node)
for output_node in op_node.outputs:
graph.link_to(quant_node, output_node)
graph.safe_remove_nodes(op_node)
if name in self._fake_dequant_op_names:
op_node.op().set_type('dequantize')
dequant_node = graph.create_op_node_from_desc(op_node.op())
for input_node in op_node.inputs:
graph.link_to(input_node, dequant_node)
for output_node in op_node.outputs:
graph.link_to(dequant_node, output_node)
graph.safe_remove_nodes(op_node)
return graph
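
The three inference passes are designed to be chained on a quantized test graph. A condensed sketch of the flow; the empty graph below is a stand-in for a graph already processed by QuantizationTransformPass:

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass

scope = fluid.Scope()
place = fluid.CPUPlace()
test_graph = IrGraph(core.Graph(fluid.Program().desc), for_test=True)

QuantizationFreezePass(scope=scope, place=place).apply(test_graph)
# -> reorders quant/dequant ops and scales the weights offline
ConvertToInt8Pass(scope=scope, place=place).apply(test_graph)
# -> stores the persistable weights as int8 tensors
TransformForMobilePass().apply(test_graph)
# -> renames the fake ops to quantize/dequantize for paddle-mobile
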
file(GLOB TEST_OPS RELATIVE "${CMAKE_CURRENT_SOURCE_DIR}" "test_*.py")
string(REPLACE ".py" "" TEST_OPS "${TEST_OPS}")
foreach(src ${TEST_OPS})
py_test(${src} SRCS ${src}.py)
endforeach()
version: 1.0
include: ["./unitest/configs/pruners.yaml", "./unitest/configs/pruners_0.yaml"]
include: ["./configs/pruners.yaml", "./configs/pruners_0.yaml"]
pruners:
pruner_1:
class: 'RatioPruner'
......
......@@ -18,7 +18,7 @@ import unittest
class TestFactory(unittest.TestCase):
def test_parse(self):
factory = ConfigFactory('./unitest/configs/config.yaml')
factory = ConfigFactory('./configs/config.yaml')
pruner = factory.instance('pruner_1')
self.assertEquals(pruner.ratios['conv1_1.w'], 0.3)
......
# copyright (c) 2018 paddlepaddle authors. all rights reserved.
#
# licensed under the apache license, version 2.0 (the "license");
# you may not use this file except in compliance with the license.
# you may obtain a copy of the license at
#
# http://www.apache.org/licenses/license-2.0
#
# unless required by applicable law or agreed to in writing, software
# distributed under the license is distributed on an "as is" basis,
# without warranties or conditions of any kind, either express or implied.
# see the license for the specific language governing permissions and
# limitations under the license.
from __future__ import print_function
import unittest
import paddle.fluid as fluid
import six
from paddle.fluid.framework import IrGraph
from paddle.fluid import core
def residual_block(num):
def conv_bn_layer(input,
ch_out,
filter_size,
stride,
padding,
act='relu',
bias_attr=False):
tmp = fluid.layers.conv2d(
input=input,
filter_size=filter_size,
num_filters=ch_out,
stride=stride,
padding=padding,
act=None,
bias_attr=bias_attr)
return fluid.layers.batch_norm(input=tmp, act=act)
data = fluid.layers.data(name='image', shape=[1, 32, 32], dtype='float32')
label = fluid.layers.data(name='label', shape=[1], dtype='int64')
hidden = data
for _ in six.moves.xrange(num):
conv = conv_bn_layer(hidden, 16, 3, 1, 1, act=None, bias_attr=True)
short = conv_bn_layer(hidden, 16, 1, 1, 0, act=None)
hidden = fluid.layers.elementwise_add(x=conv, y=short, act='relu')
fc = fluid.layers.fc(input=hidden, size=10)
loss = fluid.layers.cross_entropy(input=fc, label=label)
loss = fluid.layers.mean(loss)
return loss
class TestGraph(unittest.TestCase):
def test_graph_functions(self):
main = fluid.Program()
startup = fluid.Program()
with fluid.program_guard(main, startup):
loss = residual_block(2)
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
graph = IrGraph(core.Graph(main.desc), for_test=False)
marked_nodes = set()
for op in graph.all_ops():
if op.name().find('conv2d') > -1:
marked_nodes.add(op)
graph.draw('.', 'residual', marked_nodes)
self.assertFalse(graph.has_circle())
self.assertEqual(graph.graph_num(), 1)
nodes = graph.topology_sort()
self.assertEqual(len(nodes), len(graph.all_ops()))
nodes_map = graph.build_adjacency_list()
self.assertEqual(len(nodes_map), len(graph.all_ops()))
nodes_num = len(graph.all_nodes())
graph.safe_remove_nodes(marked_nodes)
self.assertEqual(len(graph.all_nodes()), nodes_num - len(marked_nodes))
if __name__ == '__main__':
unittest.main()
......@@ -17,9 +17,12 @@ import random
import numpy as np
import paddle.fluid as fluid
import six
from paddle.fluid.framework import Program
import paddle
from paddle.fluid.framework import IrGraph
from paddle.fluid.contrib.slim.quantization import QuantizationTransformPass
from paddle.fluid.contrib.slim.quantization import QuantizationFreezePass
from paddle.fluid.contrib.slim.quantization import ConvertToInt8Pass
from paddle.fluid.contrib.slim.quantization import TransformForMobilePass
from paddle.fluid import core
......@@ -65,6 +68,28 @@ def residual_block(num):
return loss
def conv_net(img, label):
conv_pool_1 = fluid.nets.simple_img_conv_pool(
input=img,
filter_size=5,
num_filters=20,
pool_size=2,
pool_stride=2,
act="relu")
conv_pool_1 = fluid.layers.batch_norm(conv_pool_1)
conv_pool_2 = fluid.nets.simple_img_conv_pool(
input=conv_pool_1,
filter_size=5,
num_filters=50,
pool_size=2,
pool_stride=2,
act="relu")
prediction = fluid.layers.fc(input=conv_pool_2, size=10, act='softmax')
loss = fluid.layers.cross_entropy(input=prediction, label=label)
avg_loss = fluid.layers.mean(loss)
return avg_loss
class TestQuantizationTransformPass(unittest.TestCase):
def setUp(self):
self.quantizable_op_and_inputs = {
......@@ -171,5 +196,177 @@ class TestQuantizationTransformPass(unittest.TestCase):
self.residual_block_quant('range_abs_max')
class TestQuantizationFreezePass(unittest.TestCase):
def freeze_graph(self, use_cuda, seed, quant_type):
def build_program(main, startup, is_test):
main.random_seed = seed
startup.random_seed = seed
with fluid.unique_name.guard():
with fluid.program_guard(main, startup):
img = fluid.layers.data(
name='image', shape=[1, 28, 28], dtype='float32')
label = fluid.layers.data(
name='label', shape=[1], dtype='int64')
loss = conv_net(img, label)
if not is_test:
opt = fluid.optimizer.Adam(learning_rate=0.001)
opt.minimize(loss)
return [img, label], loss
random.seed(0)
np.random.seed(0)
main = fluid.Program()
startup = fluid.Program()
test_program = fluid.Program()
feeds, loss = build_program(main, startup, False)
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
main_graph = IrGraph(core.Graph(main.desc), for_test=False)
test_graph = IrGraph(core.Graph(test_program.desc), for_test=True)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
scope = fluid.Scope()
with fluid.scope_guard(scope):
exe.run(startup)
transform_pass = QuantizationTransformPass(
scope=scope, program_exe=exe, activation_quantize_type=quant_type)
transform_pass.apply(main_graph)
transform_pass.apply(test_graph)
dev_name = '_gpu_' if use_cuda else '_cpu_'
marked_nodes = set()
for op in main_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
main_graph.draw('.', 'main' + dev_name + quant_type, marked_nodes)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test' + dev_name + quant_type, marked_nodes)
quantized_main_program = main_graph.to_program()
quantized_test_program = test_graph.to_program()
iters = 5
batch_size = 8
#train_exe = fluid.ParallelExecutor(
# main_program=quantized_main_program,
# use_cuda=bool(use_cuda),
# loss_name=loss.name,
# scope=scope)
train_reader = paddle.batch(
paddle.reader.shuffle(
paddle.dataset.mnist.train(), buf_size=500),
batch_size=batch_size)
test_reader = paddle.batch(
paddle.dataset.mnist.test(), batch_size=batch_size)
feeder = fluid.DataFeeder(feed_list=feeds, place=place)
with fluid.scope_guard(scope):
for _ in range(iters):
data = next(train_reader())
loss_v = exe.run(program=quantized_main_program,
feed=feeder.feed(data),
fetch_list=[loss])
#loss_v = train_exe.run(feed=feeder.feed(data),
# fetch_list=[loss.name])
#print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data = next(test_reader())
with fluid.program_guard(quantized_test_program):
w_var = fluid.framework._get_var('conv2d_1.w_0.quantized',
quantized_test_program)
# Testing
with fluid.scope_guard(scope):
test_loss1, w_quant = exe.run(program=quantized_test_program,
feed=feeder.feed(test_data),
fetch_list=[loss, w_var])
# Freeze graph for inference, but the weights of fc/conv are still of float type.
freeze_pass = QuantizationFreezePass(scope=scope, place=place)
freeze_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_freeze' + dev_name + quant_type,
marked_nodes)
server_program = test_graph.to_program()
with fluid.scope_guard(scope):
test_loss2, = exe.run(program=server_program,
feed=feeder.feed(test_data),
fetch_list=[loss])
self.assertAlmostEqual(test_loss1, test_loss2, delta=5e-3)
#print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
#print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze = np.array(scope.find_var('conv2d_1.w_0').get_tensor())
# This may fail due to limited calculation precision
# self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
#print('{}: {}'.format('w_quant' + dev_name + quant_type,
# np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass = ConvertToInt8Pass(scope=scope, place=place)
convert_int8_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_int8' + dev_name + quant_type, marked_nodes)
server_program_int8 = test_graph.to_program()
# Save the 8-bit parameter and model file.
with fluid.scope_guard(scope):
fluid.io.save_inference_model('server_int8' + dev_name + quant_type,
['image', 'label'], [loss], exe,
server_program_int8)
# Test whether the 8-bit parameter and model file can be loaded successfully.
[infer, feed, fetch] = fluid.io.load_inference_model(
'server_int8' + dev_name + quant_type, exe)
# Check the loaded 8-bit weight.
w_8bit = np.array(scope.find_var('conv2d_1.w_0.int8').get_tensor())
self.assertEqual(w_8bit.dtype, np.int8)
self.assertEqual(np.sum(w_8bit), np.sum(w_freeze))
#print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
#print('{}: {}'.format('w_freeze' + dev_name + quant_type,
# np.sum(w_freeze)))
mobile_pass = TransformForMobilePass()
mobile_pass.apply(test_graph)
marked_nodes = set()
for op in test_graph.all_ops():
if op.name().find('quantize') > -1:
marked_nodes.add(op)
test_graph.draw('.', 'test_mobile' + dev_name + quant_type,
marked_nodes)
mobile_program = test_graph.to_program()
with fluid.scope_guard(scope):
fluid.io.save_inference_model('mobile_int8' + dev_name + quant_type,
['image', 'label'], [loss], exe,
mobile_program)
def test_freeze_graph_cuda_dynamic(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='abs_max')
def test_freeze_graph_cpu_dynamic(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='abs_max')
def test_freeze_graph_cuda_static(self):
if fluid.core.is_compiled_with_cuda():
with fluid.unique_name.guard():
self.freeze_graph(True, seed=1, quant_type='range_abs_max')
def test_freeze_graph_cpu_static(self):
with fluid.unique_name.guard():
self.freeze_graph(False, seed=2, quant_type='range_abs_max')
if __name__ == '__main__':
unittest.main()
......@@ -204,9 +204,11 @@ class TestQuantizeTranspiler(unittest.TestCase):
build_program(test_program, startup, True)
test_program = test_program.clone(for_test=True)
quant_transpiler = QuantizeTranspiler()
quant_transpiler.training_transpile(main)
quant_transpiler.training_transpile(test_program)
quant_type = 'range_abs_max' # 'range_abs_max' or 'abs_max'
quant_transpiler = QuantizeTranspiler(
activation_quantize_type=quant_type)
quant_transpiler.training_transpile(main, startup)
quant_transpiler.training_transpile(test_program, startup)
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
......
......@@ -16,6 +16,8 @@ from __future__ import print_function
import collections
from collections import defaultdict
from collections import Iterable
import contextlib
from .wrapped_decorator import signature_safe_contextmanager
import os
import re
......@@ -1529,12 +1531,16 @@ class Block(object):
class IrGraph(object):
"""
IrGraph uses core.Graph as the delegation to accomplish the manipulation.
Python IrGraph. Beneath it is a core.Graph, which is used to
create a C++ IR pass graph. An IrGraph is just a graph view of
a Program. In an IrGraph, both Variables and Operators are graph
nodes.
"""
def __init__(self, graph, for_test=False):
"""
Construct the IrGraph using core.Graph.
Construct an IrGraph using core.Graph.
Args:
graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph.
......@@ -1545,23 +1551,81 @@ class IrGraph(object):
self._for_test = for_test
def is_test(self):
"""
If the graph is used for testing, the function returns true. Otherwise, returns false.
"""
return self._for_test
def all_parameters(self):
param_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
param_nodes.add(node)
return param_nodes
def all_nodes(self):
"""
Return all nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes()}
def all_vars(self):
"""
Return all variable nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_var()}
def all_persistable_vars(self):
"""
Return all persistable variable nodes included in the graph as a set.
"""
persistable_nodes = set()
for node in self.graph.nodes():
if node.is_var() and node.var() is not None and node.var(
).persistable():
persistable_nodes.add(node)
return persistable_nodes
def all_ops(self):
"""
Return all operator nodes included in the graph as a set.
"""
return {node for node in self.graph.nodes() if node.is_op()}
def var_node(self, name):
"""
Get a variable node by name from the graph.
Args:
name(str): the name of the variable node.
Raises:
ValueError: If the input's type is not str, or this graph
doesn't have a variable with the given name.
Returns:
core.Node: the variable node with the given name.
"""
if not isinstance(name, six.string_types):
raise TypeError(
"var require string as parameter, but get %s instead." %
(type(name)))
target_var_node = None
var_nodes = self.all_vars()
for var_node in var_nodes:
if var_node.name() == name:
target_var_node = var_node
if target_var_node is None:
raise ValueError("var_node %s not in this graph" % name)
return target_var_node
def create_param_node(self, name, var_type, shape, var_dtype):
"""
Create a persistable variable node in the graph. Note that IrGraph
does not distinguish between persistable variables and parameters.
Args:
name(str): the name of the persistable variable node.
var_type(core.VarDesc.VarType): the type of the persistable variable node.
shape(list): the shape of the persistable variable node.
var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.
Returns:
core.Node: the created persistable variable node.
"""
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
......@@ -1570,6 +1634,20 @@ class IrGraph(object):
return self.graph.create_var_node(var_desc)
def create_var_node(self, name, var_type, shape, var_dtype):
"""
Create a variable node in the graph. The created variable node is
not persistable.
Args:
name(str): the name of the variable node.
var_type(core.VarDesc.VarType): the type of the variable node.
shape(list): the shape of the variable node.
var_dtype(core.VarDesc.VarType): the data type of the variable node.
Returns:
core.Node: the created variable node.
"""
var_desc = core.VarDesc(name)
var_desc.set_type(var_type)
var_desc.set_shape(shape)
......@@ -1577,19 +1655,41 @@ class IrGraph(object):
return self.graph.create_var_node(var_desc)
def create_var_node_from_desc(self, var_desc):
"""
Create a variable node by using an existing VarDesc in the graph.
Depending on the given VarDesc, the created variable node may be persistable.
Args:
var_desc(core.VarDesc): the given variable description.
Returns:
core.Node: the created variable node.
"""
return self.graph.create_var_node(var_desc)
def create_op_node(self, op_type, attrs, inputs, outputs):
"""
Create an operator node in the graph.
Args:
op_type(str): the type of the operator node.
attrs(dict): the attributes of the operator node.
inputs(dict): the inputs of the operator node.
outputs(dict): the outputs of the operator node.
Returns:
core.Node: the created operator node.
"""
op_desc = core.OpDesc()
op_desc.set_type(op_type)
for attr, value in attrs.iteritems():
for attr, value in six.iteritems(attrs):
self._update_desc_attr(op_desc, attr, value)
for input_name, var_nodes in inputs.iteritems():
for input_name, var_nodes in six.iteritems(inputs):
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_input(input_name,
[var_node.name() for var_node in var_nodes])
for output_name, var_nodes in outputs.iteritems():
for output_name, var_nodes in six.iteritems(outputs):
if not isinstance(var_nodes, list):
var_nodes = [var_nodes]
op_desc.set_output(output_name,
......@@ -1597,11 +1697,29 @@ class IrGraph(object):
return self.graph.create_op_node(op_desc)
def create_op_node_from_desc(self, op_desc):
"""
Create an operator node by using an existing OpDesc in the graph.
Args:
op_desc(core.OpDesc): the given operator description.
Returns:
core.Node: the created operator node.
"""
return self.graph.create_op_node(op_desc)
def update_input_link(self, old_input_node, new_input_node, op_node):
assert old_input_node in self.graph.nodes() and new_input_node in self.graph.nodes() and \
op_node in self.graph.nodes(), 'Th three arguments must be in the graph nodes.'
"""
Update the input link of an operator node.
Args:
old_input_node(core.Node): the old input node of the given op_node.
new_input_node(core.Node): the new input node of the given op_node.
op_node(core.Node): the operator node whose input link needs to be updated.
"""
assert old_input_node in self.graph.nodes() and new_input_node in \
self.graph.nodes() and op_node in self.graph.nodes(), \
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
old_input_node.outputs_remove(op_node)
op_node.inputs_remove(old_input_node)
new_input_node.outputs_append(op_node)
......@@ -1609,17 +1727,85 @@ class IrGraph(object):
op_node.op()._rename_input(old_input_node.name(), new_input_node.name())
def link_to(self, node_in, node_out):
"""
Connect two nodes.
Args:
node_in(core.Node): the input node.
node_out(core.Node): the output node.
"""
assert node_in in self.graph.nodes() and node_out in self.graph.nodes(), \
'Th two arguments must be in the graph nodes.'
'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in.outputs_append(node_out)
node_out.inputs_append(node_in)
def safe_remove_nodes(self, remove_nodes):
"""
Remove nodes safely since links connected to these removed nodes are
also removed.
Args:
remove_nodes(set): the nodes prepared to be removed.
"""
if not isinstance(remove_nodes, set):
remove_nodes = set(remove_nodes)
if isinstance(remove_nodes, Iterable):
remove_nodes = set(remove_nodes)
else:
remove_nodes = {remove_nodes}
core.graph_safe_remove_nodes(self.graph, remove_nodes)
def draw(self, save_path, name, marked_nodes=None):
def has_circle(self):
"""
Check if the graph has a circle.
Returns:
bool: True if the graph has a circle else False.
"""
return core.has_circle(self.graph)
def graph_num(self):
"""
Count the number of disconnected subgraphs within this graph.
Returns:
int: the number of disconnected subgraphs.
"""
return core.graph_num(self.graph)
def topology_sort(self):
"""
Perform a topological sort on the graph.
Notes: the `graph` cannot contain a circle.
Returns:
set(core.Node): nodes in topological order.
"""
return core.topology_sort(self.graph)
def build_adjacency_list(self):
"""
Build an adjacency list of operations for the `graph`.
Returns:
dict{core.Node: set(core.Node)}: the adjacency list.
"""
return core.build_adjacency_list(self.graph)
def draw(self, save_path, name, marked_nodes=None, remove_ctr_var=True):
"""
Draw the graph. If the `dot` command is installed, the drawn graph
is saved as a PDF file; otherwise a dot file is generated.
Args:
save_path(str): the save path of drawn graph.
name(str): the name of drawn graph.
marked_nodes(set(core.Node)): nodes that are needed to be marked.
Default value is None.
remove_ctr_var(bool): If it is set True, all control variable nodes
in the graph will be removed. Default value is True.
"""
def _convert_to_pdf(dot_file_path):
pdf_save_path = os.path.splitext(dot_file_path)[0] + '.pdf'
exited_code = subprocess.call('dot -Tpdf ' + dot_file_path \
......@@ -1629,15 +1815,17 @@ class IrGraph(object):
print('The {} is saved as the dot filetype.'.format(
dot_file_path))
remove_ctr_vars = set()
if remove_ctr_var:
remove_ctr_vars = set()
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
self.safe_remove_nodes(remove_ctr_vars)
ops_num = 0
for node in self.graph.nodes():
if node.is_ctrl_var():
remove_ctr_vars.add(node)
elif node.is_op():
if node.is_op():
ops_num += 1
print('Total ops num = {}.'.format(ops_num))
self.safe_remove_nodes(remove_ctr_vars)
if marked_nodes is not None:
if not isinstance(marked_nodes, set):
marked_nodes = set(marked_nodes)
......@@ -1652,10 +1840,20 @@ class IrGraph(object):
_convert_to_pdf(viz_dot_path)
def to_program(self):
"""
Convert the graph into a Program.
Notes: When the graph includes backward operator nodes, the
conversion process may fail. Usually, this function is only
used to convert a test graph.
Returns:
Program: a program converted from the graph.
"""
convert_pass = core.get_pass('graph_to_program_pass')
convert_pass.set('program', Program().desc)
desc = core.ProgramDesc()
convert_pass.set_not_owned('program', desc)
convert_pass.apply(self.graph)
desc = convert_pass.get_program('program')
program = Program._construct_from_desc(desc)
return program
......
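
Taken together, the new IrGraph helpers support an analyze-then-convert workflow. A condensed sketch on a stand-in test graph:

import paddle.fluid as fluid
from paddle.fluid import core
from paddle.fluid.framework import IrGraph

graph = IrGraph(core.Graph(fluid.Program().desc), for_test=True)
assert not graph.has_circle()           # the graph must stay a DAG
print(graph.graph_num())                # disconnected subgraph count
op_nodes = graph.topology_sort()        # op nodes in dependency order
adj = graph.build_adjacency_list()      # op node -> neighboring op nodes
inference_program = graph.to_program()  # via graph_to_program_pass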