Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !25960

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

[Dy2stat] Support InputSpec and Return callable class instance in @declarative !25960

  • Report abuse
!25960 已合并 8月 05, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007f0e4c198190>
  • 概览 56
  • 提交 42
  • 变更 15

Created by: Aurelius84

PR types

New features

PR changes

APIs

Describe

  • 动转静@declarative升级为返回callable的类对象
  • 新增InputSpec概念,支持设置input_spec进行动转静

What's New?

1. declarative返回callable类对象

1.1 背景介绍

此前版本的@declarative装饰函数后,仅返回callable的另一个函数。其接收传递过来的*args,**kwargs参数,执行内部cache的program,返回结果。基本逻辑如下:

def declarative(func):
    def __impl__(*args, **kwargs):
        out = program_translator.get_output(*args, **kwargs)
        return out
   
    return __impl__
之前的方案

此方案可以满足大部分场景,但具有一定的局限性:

  • @declarative无法支持其他参数,接口功能扩展性较差
  • 返回__impl__函数而非类对象,不能友好地暴露属性访问的相关接口
  • 需要搭配ProgramTranslator进行使用,增加了熟悉框架新概念的成本

因此,此PR升级了declarative接口,以类对象作为返回。

升级后的方案

1.2 新功能使用

  • 保持使用方式不变
# 方式一:
@declarative
def foo(x, y):
    return x + y

# 方式二:
foo = declarative(foo)

z = foo(x_var, y_var)
  • 直接通过foo访问属性
# (接上面)

# 1. function相关信息
foo.dygraph_function    # 返回 被装饰函数
foo.to_code()     # 返回 转静态图后的代码
foo._function_spec  # 返回 被装饰函数相关的FunctionSpec对象
foo._function_spec.code  # 返回 被装饰函数的代码
foo._function_spec.args_name    # 返回 被装饰函数的参数列表

# 2. Program信息
foo.program_cache  # 返回 管理被装饰函数转写结果的ProgramCache对象
foo.program_cache.concrete_programs()  # 返回 被装饰函数对应的所有cache的program列表
foo.get_trace_count()  # 返回 已转写的Program的数量
foo.concrete_program  # 返回 最新转写的ConcreteProgram对象
  • 与ProgramTranslator解耦 如上示例,获取被装饰Function的code、program信息,均不依赖ProgramTranslator。 接口使用方式更加简洁、易用。且每个函数单独持有一份ProgramCache,避免交叉。
out = foo(x_var, y_var)  # 返回 执行结果

2. 新增支持InputSpec

2.1 背景介绍

此前@declarative装饰器不支持任何额外的的参数。若要获取转换后的Program,则需要显式的执行一次前向:

# 方式一:
out = foo(fake_x, fake_y)
program_translator.get_program_cache().last()
# 方式二:
program_translator.get_program(foo, fake_x, fake_y)

此方案具有如下局限,或不易用之处:

  • 强依赖显式地fake数据,接口不易用
  • 无法指定input tensor的某些维度为None,无法指定feed layer的name
  • 不支持被装饰函数编译式转换

此PR新增了InputSpec类,并重构了@declarative逻辑,支持input_spec参数来指定feed layer的shape、name信息。

2.2 InputSpec类

InputSpec类似于C++端的VarDesc概念,用于表示一个Tensor的元信息:shape、dtype、name。用户可以通过指定被装饰函数输入参数对应的InputSpec信息,来进行后续Program的推导。

  • 支持从Tensor和numpy中推导
np_var = np.zeros([2, 3]).astype('float32')
var_spec  = InputSpec.from_numpy(np_var, name='x')

# 或者
var = to_variable(np_var)
var_spec = InputSpec.from_variable(var)
  • 支持batch和unbatch操作
batch_var_spec = var_spec.batch(64)  # shape= [64, 2, 3]

unbatch_var_spec = var_spec.unbatch()  # shape= [3]

2.3 @declarative支持input_spec

重构@declarative逻辑,支持input_spec参数,同时也支持后续的横向功能扩展

  • @declarative移除了无法指定额外参数的限制,可根据功能扩展参数。
  • 支持通过input_spec指定Tensor shape为None
# 方式一:
@declarative(input_spec=[InputSpec([None, 10], dtype='float32'),  InputSpec([10], name='y')])
def foo(x, y):
    return x + y

# 方式二:
foo = declarative(foo, input_spec=[InputSpec([None, 10], dtype='float32'),  InputSpec([10], name='y')])
  • 支持编译式静态转换,不依赖fake_input
# x.shape = [None, 10], y.shape=[10]
concrete_program_1 = foo.get_concrete_program(InputSpec([None, 10]), InputSpec([10]))

# build  new program with x.shape = [10], y.shape = [10]
concrete_program_2 = foo.get_concrete_program(InputSpec([10]), InputSpec([10]))
  • 对于非Tensor类型参数的友好支持
# 如果被装饰函数参数中,包含非Tensor类型参数,也支持hashKey计算,得到不同的program
# 由于每个非Tensor值的改变都会触发新Program的构建,框架内部会trace记录已缓存的program
# 超过MAX_TRACE_COUNT=5,则会warning提示用户。
concrete_program_3 = foo.get_concrete_program(InputSpec([None, 10]), InputSpec([10]), c=2)
  • 友好支持嵌套类型的input
# case 1: 输入`l`为一个list
@declarative(input_spec=[[InputSpec([None, 10]), InputSpec([None, 10])]])
    def func_with_list(self, l):
        x, y, int_val = l
        z = x + y
        z = z + int_val
        return z

# case2: 输入`d` 为一个dict
    @declarative(input_spec=[{
        'x': InputSpec([None, 10]),
        'y': InputSpec([None, 10])
    }])
    def func_with_dict(self, d):
        x = d['x']
        y = d['y']
        int_val = d['int_val']

        z = x + y
        z = z + int_val

        return z

 # case 3: 输入为list嵌套dict
    @declarative(input_spec=[[
        InputSpec([None]), {
            'x': InputSpec([None, 10]),
            'y': InputSpec([None, 10])
        }
    ]])
    def func_with_list_dict(self, dl):
        bias = dl[0]
        x = dl[1]['x']
        y = dl[1]['y']

        z = x + y
        z = z + bias

        return z
  • 支持对同一个class的不同函数进行单独装饰,分开存储
class SimpleNet(Layer):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.linear = fluid.dygraph.Linear(10, 3)

    @declarative(input_spec=[InputSpec(shape=[None, 10], dtype='float32')])
    def forward(self, x, a=1, b=2):
        y = self.inner_function(x)
        return y

    # `declarative` is not essential, add it to test for robustness.
    @declarative
    def inner_function(self, x):
        y = self.linear(x)
        return y

    def add_func(self, x, y):
        z = x + y
        return z

# net.forward已在类定义中装饰
# 这里可以单独对net.add_func单独装饰,单独保存
net.add_func = declarative(net.add_func)

TODO

  • 优化InputSpec与 jit.save 接口的搭配使用,提升易用性

文档: image image image image

指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!25960
Source branch: github/fork/Aurelius84/input_spec
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7