Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleClas
提交
c5cf3c15
P
PaddleClas
项目概览
PaddlePaddle
/
PaddleClas
大约 1 年 前同步成功
通知
115
Star
4999
Fork
1114
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
19
列表
看板
标记
里程碑
合并请求
6
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleClas
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
19
Issue
19
列表
看板
标记
里程碑
合并请求
6
合并请求
6
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
c5cf3c15
编写于
9月 22, 2020
作者:
G
gaotingquan
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix export_model to support dygraph
上级
6589b2a8
变更
1
隐藏空白更改
内联
并排
Showing
1 changed file
with
25 addition
and
38 deletion
+25
-38
tools/export_model.py
tools/export_model.py
+25
-38
未找到文件。
tools/export_model.py
浏览文件 @
c5cf3c15
...
...
@@ -15,7 +15,10 @@
import
argparse
from
ppcls.modeling
import
architectures
import
paddle.fluid
as
fluid
from
ppcls.utils.save_load
import
load_dygraph_pretrain
import
paddle
import
paddle.nn.functional
as
F
from
paddle.jit
import
to_static
def
parse_args
():
...
...
@@ -24,54 +27,38 @@ def parse_args():
parser
.
add_argument
(
"-p"
,
"--pretrained_model"
,
type
=
str
)
parser
.
add_argument
(
"-o"
,
"--output_path"
,
type
=
str
)
parser
.
add_argument
(
"--class_dim"
,
type
=
int
,
default
=
1000
)
parser
.
add_argument
(
"--img_size"
,
type
=
int
,
default
=
224
)
#
parser.add_argument("--img_size", type=int, default=224)
return
parser
.
parse_args
()
def
create_input
(
img_size
=
224
):
image
=
fluid
.
data
(
name
=
'image'
,
shape
=
[
None
,
3
,
img_size
,
img_size
],
dtype
=
'float32'
)
return
image
class
Net
(
paddle
.
nn
.
Layer
):
def
__init__
(
self
,
net
,
to_static
,
class_dim
):
super
(
Net
,
self
).
__init__
()
self
.
pre_net
=
net
(
class_dim
=
class_dim
)
self
.
to_static
=
to_static
def
create_model
(
args
,
model
,
input
,
class_dim
=
1000
):
if
args
.
model
==
"GoogLeNet"
:
out
,
_
,
_
=
model
.
net
(
input
=
input
,
class_dim
=
class_dim
)
else
:
out
=
model
.
net
(
input
=
input
,
class_dim
=
class_dim
)
out
=
fluid
.
layers
.
softmax
(
out
)
return
out
# 请根据实际需求修改shape
@
to_static
(
input_spec
=
[
paddle
.
static
.
InputSpec
(
shape
=
[
None
,
3
,
224
,
224
],
dtype
=
'float32'
)
])
def
forward
(
self
,
inputs
):
x
=
self
.
pre_net
(
inputs
)
x
=
F
.
softmax
(
x
)
return
x
def
main
():
args
=
parse_args
()
model
=
architectures
.
__dict__
[
args
.
model
]()
place
=
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
startup_prog
=
fluid
.
Program
()
infer_prog
=
fluid
.
Program
()
with
fluid
.
program_guard
(
infer_prog
,
startup_prog
):
with
fluid
.
unique_name
.
guard
():
image
=
create_input
(
args
.
img_size
)
out
=
create_model
(
args
,
model
,
image
,
class_dim
=
args
.
class_dim
)
infer_prog
=
infer_prog
.
clone
(
for_test
=
True
)
fluid
.
load
(
program
=
infer_prog
,
model_path
=
args
.
pretrained_model
,
executor
=
exe
)
paddle
.
disable_static
()
net
=
architectures
.
__dict__
[
args
.
model
]
fluid
.
io
.
save_inference_model
(
dirname
=
args
.
output_path
,
feeded_var_names
=
[
image
.
name
],
main_program
=
infer_prog
,
target_vars
=
out
,
executor
=
exe
,
model_filename
=
'model'
,
params_filename
=
'params'
)
model
=
Net
(
net
,
to_static
,
args
.
class_dim
)
para_state_dict
=
paddle
.
io
.
load_program_state
(
args
.
pretrained_model
)
load_dygraph_pretrain
(
model
,
args
.
pretrained_model
,
True
)
paddle
.
jit
.
save
(
model
,
args
.
output_path
)
if
__name__
==
"__main__"
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录