Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
s920243400
PaddleDetection
提交
a7efab7e
P
PaddleDetection
项目概览
s920243400
/
PaddleDetection
与 Fork 源项目一致
Fork自
PaddlePaddle / PaddleDetection
通知
2
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
a7efab7e
编写于
1月 30, 2019
作者:
W
WangZhen
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add comments for public API. test=develop
上级
0db41a9c
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
242 addition
and
23 deletion
+242
-23
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
...ddle/fluid/contrib/slim/quantization/quantization_pass.py
+66
-0
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
...paddle/fluid/contrib/slim/tests/test_quantization_pass.py
+13
-13
python/paddle/fluid/framework.py
python/paddle/fluid/framework.py
+163
-10
未找到文件。
python/paddle/fluid/contrib/slim/quantization/quantization_pass.py
浏览文件 @
a7efab7e
...
...
@@ -39,7 +39,13 @@ class QuantizationTransformPass(object):
"""
Convert and rewrite the IrGraph according to weight and
activation quantization type.
Args:
scope(fluid.Scope): When activation use 'range_abs_max' as the quantize
type, this pass will create some new parameters. The scope is used to
initialize these new parameters.
program_exe(fluid.Executor): program_exe is used to initialize new
parameters described above.
weight_bits (int): quantization bit number for weights,
the bias is not quantized.
activation_bits (int): quantization bit number for activation.
...
...
@@ -53,6 +59,7 @@ class QuantizationTransformPass(object):
support 'abs_max'. The 'range_abs_max' usually is not used for
weight, since weights are fixed once the model is well trained.
window_size (int): the window size for 'range_abs_max' quantization.
Examples:
.. code-block:: python
# The original graph will be rewrite.
...
...
@@ -96,6 +103,14 @@ class QuantizationTransformPass(object):
self
.
_global_step
=
None
def
apply
(
self
,
graph
):
"""
Quantize the graph for training process. According to weight and
activation quantization type, the graph will be added some fake
quantize operators and fake dequantize operators.
Args:
graph(IrGraph): the applied graph.
"""
assert
isinstance
(
graph
,
IrGraph
),
'graph must be the instance of IrGraph.'
self
.
_need_initialized
.
clear
()
...
...
@@ -336,6 +351,23 @@ class QuantizationTransformPass(object):
class
QuantizationFreezePass
(
object
):
"""
The freeze pass is used to adjust the quantize operator order, for example:
1) `activation -> quant -> dequant -> conv2d` will be freezed into
`activation -> quant -> conv2d -> dequant`
2) `weight -> quant -> dequant -> conv2d` will be freezed into `weight -> conv2d`,
and weight will be sacled offline.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the weight tensors.
weight_bits (int): quantization bit number for weights.
activation_bits (int): quantization bit number for activation.
weight_quantize_type (str): quantization type for weights, support 'abs_max'.
The 'range_abs_max' usually is not used for weight, since weights are fixed once the
model is well trained.
"""
def
__init__
(
self
,
scope
,
place
,
...
...
@@ -361,6 +393,12 @@ class QuantizationFreezePass(object):
self
.
_var_scale_map
=
collections
.
OrderedDict
()
def
apply
(
self
,
graph
):
"""
Adjust quantize/dequantize operators order for the inference process.
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_vars
()]
ops
=
graph
.
all_ops
()
for
op_node
in
ops
:
...
...
@@ -518,6 +556,15 @@ class QuantizationFreezePass(object):
class
ConvertToInt8Pass
(
object
):
"""
Convert the weights into int8_t type.
Args:
scope(fluid.Scope): scope is used to get the weight tensor values.
place(fluid.CPUPlace|fluid.CUDAPlace): place is used to restore the
8bits weight tensors.
"""
def
__init__
(
self
,
scope
,
place
):
assert
scope
is
not
None
,
\
'The scope cannot be set None.'
...
...
@@ -528,6 +575,13 @@ class ConvertToInt8Pass(object):
self
.
_quantizable_ops
=
[
'conv2d'
,
'depthwise_conv2d'
,
'mul'
]
def
apply
(
self
,
graph
):
"""
Convert weights' tpye of the graph. After that, the data type of the
graph weigths is int8_t.
Args:
graph(IrGraph): the applied graph.
"""
persistable_vars
=
[
p
.
name
()
for
p
in
graph
.
all_persistable_vars
()]
ops
=
graph
.
all_ops
()
input_map
=
{}
...
...
@@ -581,6 +635,10 @@ class ConvertToInt8Pass(object):
class
TransformForMobilePass
(
object
):
"""
This pass is used to convert the freezed graph for paddle-mobile execution.
"""
def
__init__
(
self
):
self
.
_fake_quant_op_names
=
[
'fake_quantize_abs_max'
,
'fake_quantize_range_abs_max'
...
...
@@ -588,6 +646,14 @@ class TransformForMobilePass(object):
self
.
_fake_dequant_op_names
=
[
'fake_dequantize_max_abs'
]
def
apply
(
self
,
graph
):
"""
Because paddle-mobile use `quantize` an `dequantize` as the names of
quantize operator and dequantize operator, the `apply` function just
realize this logic.
Args:
graph(IrGraph): the graph will be transformed.
"""
ops
=
graph
.
all_ops
()
for
op_node
in
ops
:
name
=
op_node
.
name
()
...
...
python/paddle/fluid/contrib/slim/tests/test_quantization_pass.py
浏览文件 @
a7efab7e
...
...
@@ -248,8 +248,8 @@ class TestQuantizationFreezePass(unittest.TestCase):
quantized_main_program
=
main_graph
.
to_program
()
quantized_test_program
=
test_graph
.
to_program
()
iters
=
10
batch_size
=
1
28
iters
=
5
batch_size
=
1
6
train_exe
=
fluid
.
ParallelExecutor
(
main_program
=
quantized_main_program
,
...
...
@@ -271,7 +271,7 @@ class TestQuantizationFreezePass(unittest.TestCase):
# fetch_list=[loss])
loss_v
=
train_exe
.
run
(
feed
=
feeder
.
feed
(
data
),
fetch_list
=
[
loss
.
name
])
print
(
'{}: {}'
.
format
(
'loss'
+
dev_name
+
quant_type
,
loss_v
))
#
print('{}: {}'.format('loss' + dev_name + quant_type, loss_v))
test_data
=
next
(
test_reader
())
with
fluid
.
program_guard
(
quantized_test_program
):
...
...
@@ -299,15 +299,15 @@ class TestQuantizationFreezePass(unittest.TestCase):
feed
=
feeder
.
feed
(
test_data
),
fetch_list
=
[
loss
])
self
.
assertAlmostEqual
(
test_loss1
,
test_loss2
,
delta
=
5e-3
)
print
(
'{}: {}'
.
format
(
'test_loss1'
+
dev_name
+
quant_type
,
test_loss1
))
print
(
'{}: {}'
.
format
(
'test_loss2'
+
dev_name
+
quant_type
,
test_loss2
))
#
print('{}: {}'.format('test_loss1' + dev_name + quant_type, test_loss1))
#
print('{}: {}'.format('test_loss2' + dev_name + quant_type, test_loss2))
w_freeze
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0'
).
get_tensor
())
# Maybe failed, this is due to the calculation precision
self
.
assertAlmostEqual
(
np
.
sum
(
w_freeze
),
np
.
sum
(
w_quant
))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
np
.
sum
(
w_freeze
)))
print
(
'{}: {}'
.
format
(
'w_quant'
+
dev_name
+
quant_type
,
np
.
sum
(
w_quant
)))
#
self.assertAlmostEqual(np.sum(w_freeze), np.sum(w_quant))
#
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
#
np.sum(w_freeze)))
#
print('{}: {}'.format('w_quant' + dev_name + quant_type,
#
np.sum(w_quant)))
# Convert parameter to 8-bit.
convert_int8_pass
=
ConvertToInt8Pass
(
scope
=
scope
,
place
=
place
)
...
...
@@ -330,9 +330,9 @@ class TestQuantizationFreezePass(unittest.TestCase):
w_8bit
=
np
.
array
(
scope
.
find_var
(
'conv2d_1.w_0.int8'
).
get_tensor
())
self
.
assertEqual
(
w_8bit
.
dtype
,
np
.
int8
)
self
.
assertEqual
(
np
.
sum
(
w_8bit
),
np
.
sum
(
w_freeze
))
print
(
'{}: {}'
.
format
(
'w_8bit'
+
dev_name
+
quant_type
,
np
.
sum
(
w_8bit
)))
print
(
'{}: {}'
.
format
(
'w_freeze'
+
dev_name
+
quant_type
,
np
.
sum
(
w_freeze
)))
#
print('{}: {}'.format('w_8bit' + dev_name + quant_type, np.sum(w_8bit)))
#
print('{}: {}'.format('w_freeze' + dev_name + quant_type,
#
np.sum(w_freeze)))
mobile_pass
=
TransformForMobilePass
()
mobile_pass
.
apply
(
test_graph
)
...
...
python/paddle/fluid/framework.py
浏览文件 @
a7efab7e
...
...
@@ -1516,12 +1516,16 @@ class Block(object):
class
IrGraph
(
object
):
"""
IrGraph uses core.Graph as the delegation to accomplish the manipulation.
Python IrGraph. Beneath it is a core.Graph, which is used for
create a c++ Ir Pass Graph. An IrGraph is just a graph view of
a Program. In an IrGraph, both Variables and Operators are graph
nodes.
"""
def
__init__
(
self
,
graph
,
for_test
=
False
):
"""
Construct the IrGraph using core.Graph.
Construct an IrGraph using core.Graph.
Args:
graph(core.Graph): C++ Graph.
for_test(bool): True for the test graph and false for the train graph.
...
...
@@ -1532,15 +1536,27 @@ class IrGraph(object):
self
.
_for_test
=
for_test
def
is_test
(
self
):
"""
If the graph is used for testing, the function returns true. Otherwise, returns false.
"""
return
self
.
_for_test
def
all_nodes
(
self
):
"""
Return all nodes included in the graph as a set.
"""
return
{
node
for
node
in
self
.
graph
.
nodes
()}
def
all_vars
(
self
):
"""
Return all variable nodes included in the graph as a set.
"""
return
{
node
for
node
in
self
.
graph
.
nodes
()
if
node
.
is_var
()}
def
all_persistable_vars
(
self
):
"""
Return all persistable variable nodes included in the graph as a set.
"""
persistable_nodes
=
set
()
for
node
in
self
.
graph
.
nodes
():
if
node
.
is_var
()
and
node
.
var
()
is
not
None
and
node
.
var
(
...
...
@@ -1549,18 +1565,24 @@ class IrGraph(object):
return
persistable_nodes
def
all_ops
(
self
):
"""
Return all operator nodes included in the graph as a set.
"""
return
{
node
for
node
in
self
.
graph
.
nodes
()
if
node
.
is_op
()}
def
var_node
(
self
,
name
):
"""
Get a variable node by name from this graph.
Get a variable node by name from the graph.
Args:
name(str): the name of the variable node.
Raises:
ValueError: The If input's type is not str, or this graph
doesn't have a variable with the giving name.
Returns:
Node: the variable node with the giving name.
core.
Node: the variable node with the giving name.
"""
if
not
isinstance
(
name
,
six
.
string_types
):
raise
TypeError
(
...
...
@@ -1576,6 +1598,19 @@ class IrGraph(object):
return
target_var_node
def
create_param_node
(
self
,
name
,
var_type
,
shape
,
var_dtype
):
"""
Create a persistable variable node in the graph. In IrGraph,
it can not distinguish between persistable variables and parameters.
Args:
name(str): the name of the persistable variable node.
vart_type(core.VarDesc.VarType): the type of the persistable variable node.
shape(list): the shape of the persistable variable node.
var_dtype(core.VarDesc.VarType): the data type of the persistable variable node.
Returns:
core.Node: the created persistable variable node.
"""
var_desc
=
core
.
VarDesc
(
name
)
var_desc
.
set_type
(
var_type
)
var_desc
.
set_shape
(
shape
)
...
...
@@ -1584,6 +1619,20 @@ class IrGraph(object):
return
self
.
graph
.
create_var_node
(
var_desc
)
def
create_var_node
(
self
,
name
,
var_type
,
shape
,
var_dtype
):
"""
Create a variable node in the graph. The created variable node is
not persistable.
Args:
name(str): the name of the variable node.
vart_type(core.VarDesc.VarType): the type of the variable node.
shape(list): the shape of the variable node.
var_dtype(core.VarDesc.VarType): the data type of the variable node.
Returns:
core.Node: the created variable node.
"""
var_desc
=
core
.
VarDesc
(
name
)
var_desc
.
set_type
(
var_type
)
var_desc
.
set_shape
(
shape
)
...
...
@@ -1591,9 +1640,31 @@ class IrGraph(object):
return
self
.
graph
.
create_var_node
(
var_desc
)
def
create_var_node_from_desc
(
self
,
var_desc
):
"""
Create a variable node by using an existing VarDesc in the graph.
Depend on the giving VarDesc, the created variable node may be persistable.
Args:
var_desc(core.VarDesc): the giving variable description.
Returns:
core.Node: the created variable node.
"""
return
self
.
graph
.
create_var_node
(
var_desc
)
def
create_op_node
(
self
,
op_type
,
attrs
,
inputs
,
outputs
):
"""
Create a operator node in the graph.
Args:
op_type(str): the type of the operator node.
attrs(dict): the attributes of the operator node.
inputs(dict): the inputs of the operator node.
outputs(dict): the outpus of the operator node.
Returns:
core.Node: the created operator node.
"""
op_desc
=
core
.
OpDesc
()
op_desc
.
set_type
(
op_type
)
for
attr
,
value
in
attrs
.
iteritems
():
...
...
@@ -1611,9 +1682,26 @@ class IrGraph(object):
return
self
.
graph
.
create_op_node
(
op_desc
)
def
create_op_node_from_desc
(
self
,
op_desc
):
"""
Create a operator node by using an existing OpDesc in the graph.
Args:
op_desc(core.VarDesc): the giving operator description.
Returns:
core.Node: the created operator node.
"""
return
self
.
graph
.
create_op_node
(
op_desc
)
def
update_input_link
(
self
,
old_input_node
,
new_input_node
,
op_node
):
"""
Update the input's link of a operator node.
Args:
old_input_node(core.Node): the old input node of the giving op_node.
new_input_node(core.Node): the new input node of the giving op_node.
op_node(core.Node): the operator node that is needed to update input's link.
"""
assert
old_input_node
in
self
.
graph
.
nodes
()
and
new_input_node
in
\
self
.
graph
.
nodes
()
and
op_node
in
self
.
graph
.
nodes
(),
\
'The three arguments(old_input_node&new_input_node&op_node) must be in the graph nodes.'
...
...
@@ -1624,12 +1712,26 @@ class IrGraph(object):
op_node
.
op
().
_rename_input
(
old_input_node
.
name
(),
new_input_node
.
name
())
def
link_to
(
self
,
node_in
,
node_out
):
"""
Connect two nodes.
Args:
node_in(core.Node): the input node.
node_out(core.Node): the output node.
"""
assert
node_in
in
self
.
graph
.
nodes
()
and
node_out
in
self
.
graph
.
nodes
(),
\
'The two arguments(node_in&node_out) must be in the graph nodes.'
node_in
.
outputs_append
(
node_out
)
node_out
.
inputs_append
(
node_in
)
def
safe_remove_nodes
(
self
,
remove_nodes
):
"""
Remove nodes safely since links connected to these removed nodes are
also removed.
Args:
remove_nodes(set): the nodes prepared to be removed.
"""
if
not
isinstance
(
remove_nodes
,
set
):
if
isinstance
(
remove_nodes
,
Iterable
):
remove_nodes
=
set
(
remove_nodes
)
...
...
@@ -1638,18 +1740,57 @@ class IrGraph(object):
core
.
graph_safe_remove_nodes
(
self
.
graph
,
remove_nodes
)
def
has_circle
(
self
):
"""
Check if the graph has a circle.
Returns:
bool: True if the graph has a circle else False.
"""
return
core
.
has_circle
(
self
.
graph
)
def
graph_num
(
self
):
"""
Count the number of unconnected graphs in this graph.
Returns:
int: the number of unconnected graphs.
"""
return
core
.
graph_num
(
self
.
graph
)
def
topology_sort
(
self
):
"""
Perform the topology sort operation on the graph.
Notes: the `graph` cannot contain a circle.
Returns:
set(core.Node): nodes in topology order.
"""
return
core
.
topology_sort
(
self
.
graph
)
def
build_adjacency_list
(
self
):
"""
Build an adjacency list of operations for the `graph`.
Returns:
dict{core.Node: set(core.Node)}: the adjacency list.
"""
return
core
.
build_adjacency_list
(
self
.
graph
)
def
draw
(
self
,
save_path
,
name
,
marked_nodes
=
None
):
def
draw
(
self
,
save_path
,
name
,
marked_nodes
=
None
,
remove_ctr_var
=
True
):
"""
Draw the graph. If `dot` command is installed, the drawn graph
will be saved as pdf file type, otherwise dot file type is used.
Args:
save_path(str): the save path of drawn graph.
name(str): the name of drawn graph.
marked_nodes(set(core.Node)): nodes that are needed to be marked.
Default value is None.
remove_ctr_var(bool): If it is set True, all control variable nodes
in the graph will be removed. Default value is True.
"""
def
_convert_to_pdf
(
dot_file_path
):
pdf_save_path
=
os
.
path
.
splitext
(
dot_file_path
)[
0
]
+
'.pdf'
exited_code
=
subprocess
.
call
(
'dot -Tpdf '
+
dot_file_path
\
...
...
@@ -1659,15 +1800,17 @@ class IrGraph(object):
print
(
'The {} is saved as the dot filetype.'
.
format
(
dot_file_path
))
remove_ctr_vars
=
set
()
if
remove_ctr_var
:
remove_ctr_vars
=
set
()
for
node
in
self
.
graph
.
nodes
():
if
node
.
is_ctrl_var
():
remove_ctr_vars
.
add
(
node
)
self
.
safe_remove_nodes
(
remove_ctr_vars
)
ops_num
=
0
for
node
in
self
.
graph
.
nodes
():
if
node
.
is_ctrl_var
():
remove_ctr_vars
.
add
(
node
)
elif
node
.
is_op
():
if
node
.
is_op
():
ops_num
+=
1
print
(
'Total ops num = {}.'
.
format
(
ops_num
))
self
.
safe_remove_nodes
(
remove_ctr_vars
)
if
marked_nodes
is
not
None
:
if
not
isinstance
(
marked_nodes
,
set
):
marked_nodes
=
set
(
marked_nodes
)
...
...
@@ -1682,6 +1825,16 @@ class IrGraph(object):
_convert_to_pdf
(
viz_dot_path
)
def
to_program
(
self
):
"""
Convert the graph into a Program.
Notes: When the graph includes backward operator nodes, the
conversion process may be failed. Usually, this function is
only used to convert a test graph.
Returns:
Program: a program converted from the graph.
"""
convert_pass
=
core
.
get_pass
(
'graph_to_program_pass'
)
desc
=
core
.
ProgramDesc
()
convert_pass
.
set_not_owned
(
'program'
,
desc
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录