diff --git a/tensorflow2fluid/README.md b/tensorflow2fluid/README.md index 51883eb3d38ea0dd3030dc7531d8407f68f7a7de..7b93fdd9eae460c56b684b8a4d41c7c5bea83cfe 100644 --- a/tensorflow2fluid/README.md +++ b/tensorflow2fluid/README.md @@ -75,3 +75,49 @@ save_var.list|模型载入过程中的变量list > 3. 模型需要加载的参数列表为save_var.list 仍然以上面转换后的vgg_16为例,下面通过示例展示如何加载模型,并进行预测 + +``` +#coding:utf-8 +# paddle_vgg为转换后模型存储路径 +from paddle_vgg.mymodel import KitModel +import paddle.fluid as fluid +import numpy + +def model_initialize(): + # 构建模型结构,并初始化参数 + result = KitModel() + exe = fluid.Executor(fluid.CPUPlace()) + exe.run(fluid.default_startup_program()) + + # 根据save_var.list列表,加载模型参数 + var_list = list() + global_block = fluid.default_main_program().global_block() + with open('paddle_vgg/save_var.list') as f: + for line in f: + try: + # 过滤部分不需要加载的参数(OP配置参数) + var = global_block.var(line.strip()) + var_list.append(var) + except: + pass + fluid.io.load_vars(exe, 'paddle_vgg', vars=var_list) + + prog = fluid.default_main_program() + return exe, prog, result + +def test_case(exe, prog, result): + # 测试随机数据输入 + numpy.random.seed(13) + img_data = numpy.random.rand(1, 224, 224, 3) + # tf中输入为NHWC,PaddlePaddle则为NCHW,需transpose + img_data = numpy.transpose(img_data, (0, 3, 1, 2)) + + # input_0为输入数据的张量名,张量名和数据类型须与my_model.py中定义一致 + r, = exe.run(fluid.default_main_program(), + feed={'input_0':numpy.array(img_data, dtype='float32')}, + fetch_list=[result]) + +if __name__ == "__main__": + exe, prog, result = model_initialize() + test_case(exe, prog, result) +```