Commit 6992fd6c in PaddlePaddle/PaddleSlim
Authored July 31, 2020 by baiyfbupt
Parent commit: 271fbf44

Commit message: Add automatic calculation of pact clip threshold
Showing 6 changed files with 164 additions and 117 deletions (+164, -117):
- demo/quant/pact_quant_aware/pact.py   +0 -30
- demo/quant/pact_quant_aware/train.py  +60 -24
- paddleslim/common/__init__.py         +2 -2
- paddleslim/common/analyze_helper.py   +74 -61
- paddleslim/quant/__init__.py          +1 -0
- paddleslim/quant/utility.py           +27 -0
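Taken together: the commit removes the hard-coded PACT clip threshold (init_thres = 20 in the deleted pact.py) and instead derives a per-activation initial threshold from data. get_distribution samples each activation's values over a calibration batch, pdf draws the histograms for inspection, and pact_thres converts each activation's absolute-value distribution into a percentile-based clip threshold that seeds the learnable PACT parameter in train.py.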
demo/quant/pact_quant_aware/pact.py (deleted, 100644 → 0)

The standalone PACT helper module is removed; updated copies of its two functions move into train.py (below). The deleted file was:

```python
import sys
import paddle
import paddle.fluid as fluid
from paddleslim.quant import quant_aware, convert
import numpy as np
from paddle.fluid.layer_helper import LayerHelper


def pact(x, name=None):
    helper = LayerHelper("pact", **locals())
    dtype = 'float32'
    init_thres = 20
    u_param_attr = fluid.ParamAttr(
        name=x.name + '_pact',
        initializer=fluid.initializer.ConstantInitializer(value=init_thres),
        regularizer=fluid.regularizer.L2Decay(0.0001),
        learning_rate=1)
    u_param = helper.create_parameter(attr=u_param_attr, shape=[1], dtype=dtype)
    x = fluid.layers.elementwise_sub(
        x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
    x = fluid.layers.elementwise_add(
        x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
    return x


def get_optimizer():
    return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
```
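The two ReLU compositions above implement a symmetric clip of x into [-u, u], where u is a learnable scalar (initialized to 20 and L2-regularized). A minimal NumPy sketch of the same arithmetic (pact_reference is an illustrative name, not part of the repo):

```python
import numpy as np

def pact_reference(x, alpha):
    # x - relu(x - alpha): values above alpha are pulled down to alpha
    x = x - np.maximum(x - alpha, 0.0)
    # x + relu(-alpha - x): values below -alpha are pushed up to -alpha
    x = x + np.maximum(-alpha - x, 0.0)
    return x

x = np.array([-30.0, -5.0, 0.0, 5.0, 30.0])
print(pact_reference(x, alpha=20.0))  # [-20.  -5.   0.   5.  20.]
print(np.clip(x, -20.0, 20.0))        # same values
```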
demo/quant/pact_quant_aware/train.py

train.py no longer star-imports pact.py; it imports the new helpers directly:

```diff
@@ -10,13 +10,14 @@ import numpy as np
 import paddle.fluid as fluid
 sys.path[0] = os.path.join(
     os.path.dirname("__file__"), os.path.pardir, os.path.pardir)
-from paddleslim.common import get_logger
+from paddleslim.common import get_logger, get_distribution, pdf
 from paddleslim.analysis import flops
 from paddleslim.quant import quant_aware, quant_post, convert
+from paddleslim.quant import pact_thres
 import models
 from utility import add_arguments, print_arguments
-sys.path.append('./')
-from pact import *
+from paddle.fluid.layer_helper import LayerHelper
 
 quantization_model_save_dir = './quantization_models/'
 _logger = get_logger(__name__, level=logging.INFO)
```
Before the quantization transform, train.py now builds the data loaders early, collects activation distributions, and computes per-activation PACT thresholds; the pact and get_optimizer functions from the deleted pact.py are re-defined inline so that pact can read the computed thresholds:

```diff
@@ -158,11 +159,63 @@ def compress(args):
     exe = fluid.Executor(place)
     exe.run(fluid.default_startup_program())
 
+    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
+    train_reader = paddle.fluid.io.batch(
+        train_reader, batch_size=args.batch_size, drop_last=True)
+    train_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[image, label],
+        capacity=512,
+        use_double_buffer=True,
+        iterable=True)
+    valid_loader = fluid.io.DataLoader.from_generator(
+        feed_list=[image, label],
+        capacity=512,
+        use_double_buffer=True,
+        iterable=True)
+    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
+    train_loader.set_sample_list_generator(train_reader, place)
+    valid_loader.set_sample_list_generator(val_reader, place)
+
+    # get all activations distribution
+    act_names = [
+        var.name for var in list(train_prog.list_vars())
+        if not var.persistable and 'generated_var' not in var.name and
+        '@GRAD' not in var.name
+    ]
+    var_dist = get_distribution(train_prog, act_names, exe, train_loader)
+    train_loader.set_sample_list_generator(train_reader, places)
+    # draw histogram
+    pdf(var_dist, pdf_save_dir='var_dist_pdf')
+    # calculate appropriate pact clip threshold
+    pact_alphas = pact_thres(var_dist)
+
     # 2. quantization transform programs (training aware)
     #    Make some quantization transforms in the graph before training and testing.
     #    According to the weight and activation quantization type, the graph will be added
     #    some fake quantize operators and fake dequantize operators.
+    def pact(x):
+        helper = LayerHelper("pact", **locals())
+        dtype = 'float32'
+        init_thres = pact_alphas[x.name.split('_tmp_input')[0]]
+        u_param_attr = fluid.ParamAttr(
+            name=x.name + '_pact',
+            initializer=fluid.initializer.ConstantInitializer(value=init_thres),
+            regularizer=fluid.regularizer.L2Decay(0.0001),
+            learning_rate=1)
+        u_param = helper.create_parameter(
+            attr=u_param_attr, shape=[1], dtype=dtype)
+        x = fluid.layers.elementwise_sub(
+            x, fluid.layers.relu(fluid.layers.elementwise_sub(x, u_param)))
+        x = fluid.layers.elementwise_add(
+            x, fluid.layers.relu(fluid.layers.elementwise_sub(-u_param, x)))
+        return x
+
+    def get_optimizer():
+        return fluid.optimizer.MomentumOptimizer(0.0001, 0.9)
+
     if args.use_pact:
         act_preprocess_func = pact
         optimizer_func = get_optimizer
```
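One subtlety in the inline pact above: the activation handed to it inside the transformed graph carries a derived name with a `_tmp_input` suffix, so the lookup strips that suffix to recover the key that pact_thres produced. A hedged sketch of that lookup (all names below are hypothetical):

```python
# Hypothetical illustration of the init_thres lookup in pact(x) above.
pact_alphas = {'conv2d_0.tmp_0': 7.3}          # as returned by pact_thres(var_dist)
quantized_name = 'conv2d_0.tmp_0_tmp_input_0'  # name seen inside the transformed graph
key = quantized_name.split('_tmp_input')[0]    # -> 'conv2d_0.tmp_0'
init_thres = pact_alphas[key]                  # this activation's initial clip value
```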
The loader construction is removed here because it now happens earlier, before the distribution pass:

```diff
@@ -201,25 +254,6 @@ def compress(args):
     fluid.io.load_vars(exe, args.pretrained_model, predicate=if_exist)
 
-    val_reader = paddle.fluid.io.batch(val_reader, batch_size=args.batch_size)
-    train_reader = paddle.fluid.io.batch(
-        train_reader, batch_size=args.batch_size, drop_last=True)
-    train_loader = fluid.io.DataLoader.from_generator(
-        feed_list=[image, label],
-        capacity=512,
-        use_double_buffer=True,
-        iterable=True)
-    valid_loader = fluid.io.DataLoader.from_generator(
-        feed_list=[image, label],
-        capacity=512,
-        use_double_buffer=True,
-        iterable=True)
-    places = fluid.cuda_places() if args.use_gpu else fluid.cpu_places()
-    train_loader.set_sample_list_generator(train_reader, places)
-    valid_loader.set_sample_list_generator(val_reader, place)
-
     def test(epoch, program):
         batch_id = 0
         acc_top1_ns = []
```
```diff
@@ -270,8 +304,7 @@ def compress(args):
                     array = np.array(fluid.global_scope().find_var(var.name)
                                      .get_tensor())
                     threshold[var.name] = array[0]
-            print(threshold)
+            _logger.info(threshold)
             batch_id += 1
 
         build_strategy = fluid.BuildStrategy()
```
```diff
@@ -307,6 +340,7 @@ def compress(args):
                 exe,
                 dirname=os.path.join(args.checkpoint_dir, 'best_model'),
                 main_program=val_program)
+
     # 3. Freeze the graph after training by adjusting the quantize
     #    operators' order for the inference.
     #    The dtype of float_program's weights is float32, but in int8 range.
@@ -315,6 +349,8 @@ def compress(args):
         save_int8=True)
     print("eval best_model after convert")
     final_acc1 = test(best_epoch, float_program)
+    _logger.info("final acc:{}".format(final_acc1))
+
     # 4. Save inference model
     model_path = os.path.join(quantization_model_save_dir, args.model,
                               'act_' + quant_config['activation_quantize_type']
```
paddleslim/common/__init__.py

```diff
@@ -21,10 +21,10 @@ from .cached_reader import cached_reader
 from .server import Server
 from .client import Client
 from .meter import AvgrageMeter
-from .analyze_helper import pdf
+from .analyze_helper import pdf, get_distribution
 
 __all__ = [
     'EvolutionaryController', 'SAController', 'get_logger', 'ControllerServer',
     'ControllerClient', 'lock', 'unlock', 'cached_reader', 'AvgrageMeter',
-    'Server', 'Client', 'RLBaseController', 'pdf'
+    'Server', 'Client', 'RLBaseController', 'pdf', 'get_distribution'
 ]
```
paddleslim/common/analyze_helper.py

analyze_helper.py splits the old all-in-one pdf() into get_distribution(), which samples variable values, and a new pdf(), which only plots them:

```diff
@@ -12,55 +12,51 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import matplotlib
+matplotlib.use('Agg')
+import logging
+import numpy as np
+from matplotlib.backends.backend_pdf import PdfPages
+import matplotlib.pyplot as plt
 import os
+import types
 import paddle
 import paddle.fluid as fluid
-import numpy as np
-import matplotlib
-matplotlib.use('Agg')
-import matplotlib.pyplot as plt
-from matplotlib.backends.backend_pdf import PdfPages
-import logging
 
 from ..common import get_logger
 
 _logger = get_logger(__name__, level=logging.INFO)
 
 
-def pdf(program,
-        var_names,
-        executor=None,
-        batch_generator=None,
-        data_loader=None,
-        feed_vars=None,
-        fetch_list=None,
-        scope=None,
-        pdf_save_dir='tmp_pdf'):
+def get_distribution(program,
+                     var_names,
+                     executor,
+                     reader=None,
+                     feed_vars=None,
+                     scope=None):
     """
-    Draw hist for distributtion of variables in that name is in var_names
+    Get the variables distribution in the var_names list
     Args:
         program(fluid.Program): program to analyze.
         var_names(list): name of variables to analyze. When there is activation name in var_names,
-                        you should set executor, one of batch_generator and data_loader, feed_list.
+                        you should set executor.
         executor(fluid.Executor, optional): The executor to run program. Default is None.
-        batch_generator(Python Generator, optional): The batch generator provides calibrate data for DataLoader,
-                        and it returns a batch every time. For data_loader and batch_generator,
-                        only one can be set. Default is None.
-        data_loader(fluid.io.DataLoader, optional): The data_loader provides calibrate data to run program.
-                        Default is None.
-        feed_vars(list): feed variables for program. When you use batch_generator to provide data,
+        reader(Python Generator, fluid.io.DataLoader, optional): If you only want to get the distribution of weight parameters,
+                        you do not need to provide a reader. Otherwise, a reader must be provided. The reader provides calibrate data,
+                        and it returns a batch every time. It must be either a python generator or a iterable fluid dataloader.
+                        When you use a python generator, please ensure that its behavior is consistent with `batch_generator`.
+                        You can get more detail about batch_generator at https://www.paddlepaddle.org.cn/documentation/docs/zh/api_cn/io_cn/DataLoader_cn.html#id1
+        feed_vars(list): feed variables for program. When you use python generator reader to provide data,
                         you should set feed_vars. Default is None.
-        fetch_list(list): fetch list for program. Default is None.
         scope(fluid.Scope, optional): The scope to run program, use it to load variables.
                         If scope is None, will use fluid.global_scope().
-        pdf_save_dir(str): dirname to save pdf. Default is 'tmp_pdf'
     Returns:
-        dict: numpy array of variables that name in var_names
+        dict: numpy array of variables distribution that name in var_names
     """
     scope = fluid.global_scope() if scope is None else scope
     assert isinstance(var_names, list), 'var_names is a list of variable name'
+    var_changed = []
     real_names = []
     weight_only = True
     for var in program.list_vars():
@@ -68,52 +64,70 @@ def pdf(program,
         if var.persistable == False:
             weight_only = False
             var.persistable = True
+            var_changed.append(var)
         real_names.append(var.name)
 
-    if weight_only == False:
-        if batch_generator is not None:
+    def update_var_dist(var_dist):
+        for name in real_names:
+            var = scope.find_var(name)
+            if var is not None:
+                var_array = np.array(var.get_tensor())
+                var_dist[name] = var_array
+            else:
+                _logger.info("can't find var {} in scope.".format(name))
+        return var_dist
+
+    var_dist = {}
+    if weight_only:
+        var_dist = update_var_dist(var_dist)
+    else:
+        assert isinstance(reader, types.GeneratorType) or isinstance(
+            reader, fluid.reader.DataLoaderBase
+        ), "when var_names include activations'name, reader must be either a python generator or a fluid dataloader."
+        assert executor is not None, "when var_names include activations'name, executor must be set"
+        if isinstance(reader, types.GeneratorType):
             assert feed_vars is not None, "When using batch_generator, feed_vars must be set"
             dataloader = fluid.io.DataLoader.from_generator(
-                feed_list=feed_vars, capacity=512, iterable=True)
-            dataloader.set_batch_generator(batch_generator, executor.place)
-        elif data_loader is not None:
-            dataloader = data_loader
+                feed_list=feed_vars, capacity=128, iterable=True)
+            dataloader.set_batch_generator(reader, executor.place)
+        elif isinstance(reader, fluid.reader.DataLoaderBase):
+            dataloader = reader
         else:
             _logger.info(
                 "When both batch_generator and data_loader is None, var_names can only include weight names"
            )
             return
-        assert executor is not None, "when var_names include activations'name, executor must be set"
-        assert fetch_list is not None, "when var_names include activations'name,, executor must be set"
         for data in dataloader:
             executor.run(program=program,
-                         feed=data,
-                         fetch_list=fetch_list,
-                         return_numpy=False)
+                         feed=data)
+            var_dist = update_var_dist(var_dist)
             break
 
-    res_np = {}
-    for name in real_names:
-        var = fluid.global_scope().find_var(name)
-        if var is not None:
-            res_np[name] = np.array(var.get_tensor())
-        else:
-            _logger.info(
-                "can't find var {}. Maybe you should set one of batch_generator and data_loader".
-                format(name))
-    numbers = len(real_names)
+    for var in var_changed:
+        var.persistable = False
+    return var_dist
+
+
+def pdf(var_dist, pdf_save_dir='var_dist_pdf'):
+    """
+    Draw hist for distributtion of variables in that in var_dist.
+    Args:
+        var_dist(dict): numpy array of variables distribution.
+        pdf_save_dir(str): dirname to save pdf. Default is 'var_dist_pdf'
+    """
+    numbers = len(var_dist)
     if pdf_save_dir is not None:
         if not os.path.exists(pdf_save_dir):
             os.mkdir(pdf_save_dir)
         pdf_path = os.path.join(pdf_save_dir, 'result.pdf')
         with PdfPages(pdf_path) as pdf:
-            idx = 1
-            for name in res_np.keys():
-                if idx % 10 == 0:
-                    _logger.info("plt {}/{}".format(idx, numbers))
-                arr = res_np[name]
+            for i, name in enumerate(var_dist.keys()):
+                if i % 10 == 0:
+                    _logger.info("plt {}/{}".format(i, numbers))
+                arr = var_dist[name]
                 arr = arr.flatten()
                 weights = np.ones_like(arr) / len(arr)
                 plt.hist(arr, bins=1000, weights=weights)
@@ -123,5 +137,4 @@ def pdf(program,
                 plt.show()
                 pdf.savefig()
                 plt.close()
-                idx += 1
     _logger.info("variables histogram have been saved as {}".format(pdf_path))
-    return res_np
```
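A minimal usage sketch of the refactored pair (train_prog, act_names, exe, and train_loader are placeholders for objects built as in the train.py demo above):

```python
# Sample each listed variable over one calibration batch, then plot
# one histogram per variable into var_dist_pdf/result.pdf.
var_dist = get_distribution(
    train_prog,           # fluid.Program to analyze
    act_names,            # names of weights/activations to sample
    exe,                  # fluid.Executor that runs the program
    reader=train_loader)  # iterable fluid DataLoader (or python generator plus feed_vars)
pdf(var_dist, pdf_save_dir='var_dist_pdf')
```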
paddleslim/quant/__init__.py

```diff
@@ -29,3 +29,4 @@ except Exception as e:
         "please use Paddle >= 2.0.0 or develop version")
 from .quant_embedding import quant_embedding
+from .utility import pact_thres
\ No newline at end of file
```
paddleslim/quant/utility.py (new file, mode 100755)

```python
import logging

import numpy as np

from ..common import get_logger

_logger = get_logger(__name__, level=logging.INFO)


def pact_thres(var_dist, q=100):
    """
    Compute the qth percentile threshold of the data in var_dist.
    Args:
        var_dist(dict): numpy array of variables distribution.
        q(float): Percentile to compute which must be between 0 and 100 inclusive. Default is 100.
    Returns:
        dict: the qth percentile of the array element in var_dist.
    """
    var_percentile = {}
    for var_name in var_dist.keys():
        var = var_dist[var_name]
        var = var.flatten()
        var = np.abs(var)
        try:
            var_percentile[var_name] = np.percentile(var, q)
        except:
            _logger.info('{} is empty in this program'.format(var_name))
    return var_percentile
```
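For example (a sketch with made-up calibration values): with the default q=100 the threshold is each tensor's maximum absolute value, while a lower percentile ignores outliers:

```python
import numpy as np
from paddleslim.quant import pact_thres

# made-up activation samples, shaped like get_distribution's output
var_dist = {
    'conv2d_0.tmp_0': np.array([-0.5, 0.1, 3.2, 45.0]),  # one large outlier
    'conv2d_1.tmp_0': np.array([-2.0, -1.0, 1.0, 2.0]),
}

print(pact_thres(var_dist))        # {'conv2d_0.tmp_0': 45.0, 'conv2d_1.tmp_0': 2.0}
print(pact_thres(var_dist, q=75))  # the outlier-driven 45.0 drops to about 13.7
```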