Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !27331

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Add new paddle.save/load APIs !27331

  • Report abuse
!27331 已合并 9月 15, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007f0ef9464420>
  • 概览 0
  • 提交 11
  • 变更 15

Created by: chenwhql

PR types

New features

PR changes

APIs

Describe

该PR添加新的paddle.save/load APIs.

一、背景

在paddle 2.0中,一级接口paddle.save/load复用了原先动态图接口fluid.dygraph.save_dygraph/load_dygraph的实现,虽然这两个旧接口原先的实现能够满足基础的动态图存储模型及优化器参数的需求,但是如果作为paddle 2.0的一级接口,在概念范围上需要更加普适,更加便于扩展。

paddle.save/load是框架的基础存储载入接口,它们的语义范围应该是save/load一切框架可序列化的对象,而目前复用旧实现的paddle.save/load接口有以下几点问题:

1. 接口语义范围窄,不便于扩展

  • 目前的paddle.save仅支持存储模型或优化器的state_dict,目前的paddle.load仅支持载入模型和优化器的state_dict
  • 我们2.0虽然主推动态图,但静态图仍然是需要支持的,而这两个接口还是仅能在动态图下使用,这是不太合理的,既然是一级接口,静态图下也应该能使用
  • 一级接口的save很可能需要扩展功能,现在仅支持save state_dict,但将来如果我们要扩展支持save tensor, Layer甚至program呢?目前的设计将无法支持,这会给将来留下不兼容隐患或者导致我们需要添加额外的save/load接口

2. save会主动帮用户添加文件后缀

  • save会主动为用户给的文件名添加.pdparams或者.pdopt后缀,这个行为对用户来说比较隐晦,其实不太有必要,也限制了load接口的实现,load实现的时候也不能让用户关注这两个后缀

3. load默认返回两个结果

  • load默认返回模型和优化器的state_dict,如果其中一个找不到,就返回None,这个设计限制了扩展性
  • 一个载入接口,好的设计就是输入一个目标,返回一个载入结果
  • 有一些场景可能用户都只需要载入一个state_dict,比如优化器没有参数需要存储或者仅想从预测模型中载入参数fine-tune

二、主要更改

  1. 添加新实现接口paddle.save/load,而不是以alias的方式使用旧接口fluid.dygraph.save_dygraph/load_dygraph,旧接口仍然保留以原方式使用
    • 新接口概要:
      • paddle.save(obj, path)
      • paddle.load(path, config=None)
      • 仅和原接口在使用行为上有所差别,存储格式完全继承
  2. 移动了以下接口的位置
    • paddle.io.save -> paddle.static.save
    • paddle.io.load -> paddle.static.load
    • paddle.io.save_inference_model -> paddle.static.save_inference_model
    • paddle.io.load_inference_model -> paddle.static.load_inference_model
    • paddle.io.load_program_state -> paddle.static.load_program_state
    • paddle.io.set_program_state -> paddle.static.set_program_state
  3. 移除了以下规划中接口
    • paddle.tensor.save
    • paddle.tensor.load

第1点具体改动及示例:

相比原先的fluid.save_dygraph/load_dygraph,新的paddle.save/load仅有两点变化,以下示例同时也是编码时需要做的改动:

  1. save接口不再为存储对象自动添加后缀,而直接使用用户传入的文件名

    fluid.save_dygraph行为:自动添加后缀

    emb = fluid.dygraph.Embedding([10, 10])
    state_dict = emb.state_dict()
    fluid.save_dygraph( state_dict, "paddle_dy")
    # 存储结果:paddle_dy.pdparams
    
    adam = fluid.optimizer.Adam( learning_rate = fluid.layers.noam_decay( 100, 10000),
                       parameter_list = emb.parameters() )
    state_dict = adam.state_dict()
    fluid.save_dygraph( state_dict, "paddle_dy")
    # 存储结果:paddle_dy.pdopt

    新paddle.save行为:不再添加后缀,用户根据自己需求添加

    emb = paddle.nn.Embedding([10, 10])
    layer_state_dict = emb.state_dict()
    paddle.save(layer_state_dict, "emb.pdparams")
    # 存储结果:emb.params
    
    adam = paddle.optimizer.Adam(
        learning_rate=0.001,
        parameters=emb.parameters())
    opt_state_dict = adam.state_dict()
    paddle.save(opt_state_dict, "adam.pdopt")
    # 存储结果:adam.opt
  2. load接口每次调用仅返回一个object,而不是tuple

    fluid.load_dygraph行为:默认返回tuple

    para_state_dict, opti_state_dict = fluid.load_dygraph("paddle_dy")

    新paddle.save行为:默认一次load仅返回一个结果

    load_layer_state_dict = paddle.load("emb.pdparams")
    load_opt_state_dict = paddle.load("adam.pdopt")

除以上两点外,其他行为均和原接口一致,原先的存储格式保持不变

三、文档

相关中文文档PR:https://github.com/PaddlePaddle/FluidDoc/pull/2669

image

image

image

指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!27331
Source branch: github/fork/chenwhql/saveload/add_new_save_load
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7