未验证 提交 5f06d5d2 编写于 作者: J Jason 提交者: GitHub

Update README.md

上级 896bd6b0
......@@ -75,3 +75,49 @@ save_var.list|模型载入过程中的变量list
> 3. 模型需要加载的参数列表为save_var.list
仍然以上面转换后的vgg_16为例,下面通过示例展示如何加载模型,并进行预测
```
#coding:utf-8
# paddle_vgg为转换后模型存储路径
from paddle_vgg.mymodel import KitModel
import paddle.fluid as fluid
import numpy
def model_initialize():
# 构建模型结构,并初始化参数
result = KitModel()
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
# 根据save_var.list列表,加载模型参数
var_list = list()
global_block = fluid.default_main_program().global_block()
with open('paddle_vgg/save_var.list') as f:
for line in f:
try:
# 过滤部分不需要加载的参数(OP配置参数)
var = global_block.var(line.strip())
var_list.append(var)
except:
pass
fluid.io.load_vars(exe, 'paddle_vgg', vars=var_list)
prog = fluid.default_main_program()
return exe, prog, result
def test_case(exe, prog, result):
# 测试随机数据输入
numpy.random.seed(13)
img_data = numpy.random.rand(1, 224, 224, 3)
# tf中输入为NHWC,PaddlePaddle则为NCHW,需transpose
img_data = numpy.transpose(img_data, (0, 3, 1, 2))
# input_0为输入数据的张量名,张量名和数据类型须与my_model.py中定义一致
r, = exe.run(fluid.default_main_program(),
feed={'input_0':numpy.array(img_data, dtype='float32')},
fetch_list=[result])
if __name__ == "__main__":
exe, prog, result = model_initialize()
test_case(exe, prog, result)
```
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册