Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleSlim
提交
bf123166
P
PaddleSlim
项目概览
PaddlePaddle
/
PaddleSlim
大约 2 年 前同步成功
通知
51
Star
1434
Fork
344
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
53
列表
看板
标记
里程碑
合并请求
16
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleSlim
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
53
Issue
53
列表
看板
标记
里程碑
合并请求
16
合并请求
16
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bf123166
编写于
12月 27, 2021
作者:
G
Guanghua Yu
提交者:
GitHub
12月 27, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add post_quant CE (#961)
上级
06175aea
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
43 addition
and
7 deletion
+43
-7
demo/dygraph/post_quant/ptq.py
demo/dygraph/post_quant/ptq.py
+20
-6
demo/quant/quant_embedding/train.py
demo/quant/quant_embedding/train.py
+13
-0
demo/quant/quant_post/quant_post.py
demo/quant/quant_post/quant_post.py
+10
-1
未找到文件。
demo/dygraph/post_quant/ptq.py
浏览文件 @
bf123166
...
@@ -20,6 +20,7 @@ import contextlib
...
@@ -20,6 +20,7 @@ import contextlib
import
os
import
os
import
time
import
time
import
math
import
math
import
random
import
numpy
as
np
import
numpy
as
np
from
PIL
import
Image
from
PIL
import
Image
...
@@ -28,7 +29,6 @@ from paddle.io import Dataset
...
@@ -28,7 +29,6 @@ from paddle.io import Dataset
from
paddle.vision.transforms
import
transforms
from
paddle.vision.transforms
import
transforms
import
paddle.vision.models
as
models
import
paddle.vision.models
as
models
import
paddle.nn
as
nn
import
paddle.nn
as
nn
from
paddleslim
import
PTQ
from
paddleslim
import
PTQ
import
sys
import
sys
...
@@ -41,7 +41,6 @@ from models.dygraph.mobilenet_v3 import MobileNetV3_large_x1_0
...
@@ -41,7 +41,6 @@ from models.dygraph.mobilenet_v3 import MobileNetV3_large_x1_0
class
ImageNetValDataset
(
Dataset
):
class
ImageNetValDataset
(
Dataset
):
def
__init__
(
self
,
data_dir
,
image_size
=
224
,
resize_short_size
=
256
):
def
__init__
(
self
,
data_dir
,
image_size
=
224
,
resize_short_size
=
256
):
super
(
ImageNetValDataset
,
self
).
__init__
()
super
(
ImageNetValDataset
,
self
).
__init__
()
train_file_list
=
os
.
path
.
join
(
data_dir
,
'train_list.txt'
)
val_file_list
=
os
.
path
.
join
(
data_dir
,
'val_list.txt'
)
val_file_list
=
os
.
path
.
join
(
data_dir
,
'val_list.txt'
)
test_file_list
=
os
.
path
.
join
(
data_dir
,
'test_list.txt'
)
test_file_list
=
os
.
path
.
join
(
data_dir
,
'test_list.txt'
)
self
.
data_dir
=
data_dir
self
.
data_dir
=
data_dir
...
@@ -68,9 +67,9 @@ class ImageNetValDataset(Dataset):
...
@@ -68,9 +67,9 @@ class ImageNetValDataset(Dataset):
return
len
(
self
.
data
)
return
len
(
self
.
data
)
def
calibrate
(
model
,
dataset
,
batch_num
,
batch_size
):
def
calibrate
(
model
,
dataset
,
batch_num
,
batch_size
,
num_workers
=
5
):
data_loader
=
paddle
.
io
.
DataLoader
(
data_loader
=
paddle
.
io
.
DataLoader
(
dataset
,
batch_size
=
batch_size
,
num_workers
=
5
)
dataset
,
batch_size
=
batch_size
,
num_workers
=
num_workers
)
for
idx
,
data
in
enumerate
(
data_loader
()):
for
idx
,
data
in
enumerate
(
data_loader
()):
img
=
data
[
0
]
img
=
data
[
0
]
...
@@ -85,6 +84,15 @@ def calibrate(model, dataset, batch_num, batch_size):
...
@@ -85,6 +84,15 @@ def calibrate(model, dataset, batch_num, batch_size):
def
main
():
def
main
():
num_workers
=
5
if
FLAGS
.
ce_test
:
# set seed
seed
=
111
paddle
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
random
.
seed
(
seed
)
num_workers
=
0
# 1 load model
# 1 load model
model_list
=
[
x
for
x
in
models
.
__dict__
[
"__all__"
]]
model_list
=
[
x
for
x
in
models
.
__dict__
[
"__all__"
]]
model_list
.
append
(
'mobilenet_v3'
)
model_list
.
append
(
'mobilenet_v3'
)
...
@@ -126,8 +134,12 @@ def main():
...
@@ -126,8 +134,12 @@ def main():
quant_model
=
ptq
.
quantize
(
fp32_model
,
fuse
=
FLAGS
.
fuse
,
fuse_list
=
fuse_list
)
quant_model
=
ptq
.
quantize
(
fp32_model
,
fuse
=
FLAGS
.
fuse
,
fuse_list
=
fuse_list
)
print
(
"Start calibrate..."
)
print
(
"Start calibrate..."
)
calibrate
(
quant_model
,
val_dataset
,
FLAGS
.
quant_batch_num
,
calibrate
(
FLAGS
.
quant_batch_size
)
quant_model
,
val_dataset
,
FLAGS
.
quant_batch_num
,
FLAGS
.
quant_batch_size
,
num_workers
=
num_workers
)
# 3 save
# 3 save
quant_output_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
FLAGS
.
model
,
"int8_infer"
,
quant_output_dir
=
os
.
path
.
join
(
FLAGS
.
output_dir
,
FLAGS
.
model
,
"int8_infer"
,
...
@@ -172,6 +184,8 @@ if __name__ == '__main__':
...
@@ -172,6 +184,8 @@ if __name__ == '__main__':
"--quant_batch_num"
,
default
=
10
,
type
=
int
,
help
=
"batch num for quant"
)
"--quant_batch_num"
,
default
=
10
,
type
=
int
,
help
=
"batch num for quant"
)
parser
.
add_argument
(
parser
.
add_argument
(
"--quant_batch_size"
,
default
=
10
,
type
=
int
,
help
=
"batch size for quant"
)
"--quant_batch_size"
,
default
=
10
,
type
=
int
,
help
=
"batch size for quant"
)
parser
.
add_argument
(
'--ce_test'
,
default
=
False
,
type
=
bool
,
help
=
"Whether to CE test."
)
FLAGS
=
parser
.
parse_args
()
FLAGS
=
parser
.
parse_args
()
assert
FLAGS
.
data
,
"error: must provide data path"
assert
FLAGS
.
data
,
"error: must provide data path"
...
...
demo/quant/quant_embedding/train.py
浏览文件 @
bf123166
...
@@ -10,6 +10,7 @@ import paddle
...
@@ -10,6 +10,7 @@ import paddle
import
six
import
six
import
reader
import
reader
from
net
import
skip_gram_word2vec
from
net
import
skip_gram_word2vec
import
paddle
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
)
logging
.
basicConfig
(
format
=
'%(asctime)s - %(levelname)s - %(message)s'
)
logger
=
logging
.
getLogger
(
"paddle"
)
logger
=
logging
.
getLogger
(
"paddle"
)
...
@@ -77,6 +78,12 @@ def parse_args():
...
@@ -77,6 +78,12 @@ def parse_args():
required
=
False
,
required
=
False
,
default
=
False
,
default
=
False
,
help
=
'print speed or not , (default: False)'
)
help
=
'print speed or not , (default: False)'
)
parser
.
add_argument
(
'--ce_test'
,
required
=
False
,
default
=
False
,
help
=
'Whether to CE test, (default: False)'
)
return
parser
.
parse_args
()
return
parser
.
parse_args
()
...
@@ -185,6 +192,12 @@ def GetFileList(data_path):
...
@@ -185,6 +192,12 @@ def GetFileList(data_path):
def
train
(
args
):
def
train
(
args
):
if
args
.
ce_test
:
# set seed
seed
=
111
paddle
.
seed
(
seed
)
np
.
random
.
seed
(
seed
)
random
.
seed
(
seed
)
if
not
os
.
path
.
isdir
(
args
.
model_output_dir
):
if
not
os
.
path
.
isdir
(
args
.
model_output_dir
):
os
.
mkdir
(
args
.
model_output_dir
)
os
.
mkdir
(
args
.
model_output_dir
)
...
...
demo/quant/quant_post/quant_post.py
浏览文件 @
bf123166
...
@@ -6,7 +6,9 @@ import argparse
...
@@ -6,7 +6,9 @@ import argparse
import
functools
import
functools
import
math
import
math
import
time
import
time
import
random
import
numpy
as
np
import
numpy
as
np
import
paddle
sys
.
path
[
0
]
=
os
.
path
.
join
(
sys
.
path
[
0
]
=
os
.
path
.
join
(
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
)
os
.
path
.
dirname
(
"__file__"
),
os
.
path
.
pardir
,
os
.
path
.
pardir
)
...
@@ -29,12 +31,13 @@ add_arg('params_filename', str, None, "params file name")
...
@@ -29,12 +31,13 @@ add_arg('params_filename', str, None, "params file name")
add_arg
(
'algo'
,
str
,
'hist'
,
"calibration algorithm"
)
add_arg
(
'algo'
,
str
,
'hist'
,
"calibration algorithm"
)
add_arg
(
'hist_percent'
,
float
,
0.9999
,
"The percentile of algo:hist"
)
add_arg
(
'hist_percent'
,
float
,
0.9999
,
"The percentile of algo:hist"
)
add_arg
(
'bias_correction'
,
bool
,
False
,
"Whether to use bias correction"
)
add_arg
(
'bias_correction'
,
bool
,
False
,
"Whether to use bias correction"
)
add_arg
(
'ce_test'
,
bool
,
False
,
"Whether to CE test."
)
# yapf: enable
# yapf: enable
def
quantize
(
args
):
def
quantize
(
args
):
val_reader
=
reader
.
train
()
val_reader
=
reader
.
val
()
place
=
paddle
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
paddle
.
CPUPlace
()
place
=
paddle
.
CUDAPlace
(
0
)
if
args
.
use_gpu
else
paddle
.
CPUPlace
()
...
@@ -59,6 +62,12 @@ def quantize(args):
...
@@ -59,6 +62,12 @@ def quantize(args):
def
main
():
def
main
():
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
print_arguments
(
args
)
print_arguments
(
args
)
if
args
.
ce_test
:
# set seed
seed
=
111
np
.
random
.
seed
(
seed
)
paddle
.
seed
(
seed
)
random
.
seed
(
seed
)
quantize
(
args
)
quantize
(
args
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录