Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !3566

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
接近 2 年 前同步成功

通知 2323
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Python wrapper follows tf and pytorch's concept !3566

  • Report abuse
!3566 已关闭 8月 18, 2017 由 saxon_zh@saxon_zh 创建
#<User:0x00007f0ef90d63a8>
  • 概览 11
  • 提交 2
  • 变更 11

Created by: Superjomn

一个简单的wrapper实现,但只实现了核心的部分,整理下一些想法:

基于的大前提

  • python wrapper 不需要以 caffe2 为参考对象,而应该以更流行的 tf 或 pytorch 为参考的基础
    • caffe2 的python语法并不流行;c++ 部分参考之,因为实现简单加速开发;python wrapper 工作量不大,参考 caffe2 没有必要的理由
    • 我们最终面向的是用户,当下 TF 和 pytorch 占主流,不可否认这两个平台的受众占大多数(更流行)
      • 一些基本的思想已经潜移默化成了行业的潮流,比如输入输出必然是 tensor(mxnet也是,流行度前三的平台的选择)
      • 拓扑通过 input/output 的argument 自动创建,而非 caffe2 中的 model.add_op(xxx) 或 net.add_op(xxx)
    • v2 反而比较类似 tf, 每个 layer 类似一个 function,通过 cost 自动推导拓扑等,类似 caffe2 显式有个net,感觉并不一定需要

具体实现细节:

  • 以 op 的inputs 和 outputs 来自动推断拓扑,而非基于 op, 比如 net.add(op) 或者 model.add(op)
  • inputs 和 outputs 必须为 Var , 而非字符串, typeof(input 应该是 Var 而非 str,实现可以是 str 但提供给用户的是概念,Var 里面有更多一些逻辑
  • 根据 target 自动推断涉及的子图,DFS抽取出对应的最小子图, 动态创建 NetOp 来run,除了支持原始的 paddle.trainer.train 之外,提供 paddle.trainer.run(targets, ...) 接口来执行类似 tf.Session.run 的逻辑,子图的执行更自然
  • 所有的 layer 和 op 等 sub-module 全部折叠到 paddle 下(类似tf),比如 paddle.layer.fc 也可以用
import  xxx as pd
fc_out = pd.layer.fc(xxx)
fc_out = pd.fc(xxx)

模块

  • var.py variable 的封装,提供 Var, Var 会统一所有 op, layer 的 inputs 和 outputs 格式
  • op.py 包含所有 op 的实现
    • pybind 的所有 op 都会有 python 的封装以支持更 user-friendly 的语法
  • layer.py layer实现
  • topology.py DFS 子图推断相关
  • session.py 提供Session 存储所有创建的 Var 和 Op, 提供类似 tf.run 的接口
    • 类似 tf.Session, Session 全局只需要一份,可以用 g_session 隐藏掉

只增加了 Var 新概念

由于所有的 op 和 layer 的 inputs 和 outputs 都统一成了 Var

MINIST 使用示例

兼容 v2 的方式

# all the layer, op namespace are imported into pd for short
# the old absolute path should also work
# for example, pd.data -> paddle.layer.data
import paddle.v2 as pd

images = pd.data(name='pixe', type=pd.dense_vector(784))

label = pd.data(name='label', type=pd.integer_value(10))

prediction = pd.fc(input=images, size=10, act=...)
cost = pd.cross_entropy(prediction, label)

optimizer = pd.Momentum(...)

# following are different styles should supported

# v2 style, seems that hard to support multiple sub-model
parameters = pd.parameters.create(cost)
trainer = pd.SGD(cost=[cost], parameters=parameters, update_equation=optimizer)
trainer.train(reader=..., num_passes=5)

v2 的方式貌似很难支持多个sub-model 的执行(不支持子图),必须往 paddle.trainer 里 添加新的接口,这里是此demo核心概念应该支持的另外一种写法

# all the layer, op namespace are imported into pd for short
# the old absolute path should also work
# for example, pd.data -> paddle.layer.data
import paddle.v2 as pd

images = pd.data(name='pixe', type=pd.dense_vector(784))

label = pd.data(name='label', type=pd.integer_value(10))

prediction = pd.fc(input=images, size=10, act=...)
cost = pd.cross_entropy(prediction, label)

optimizer = pd.Momentum(...)

# style with new features borrowed from tf and pytorch
trainer = pd.SGD()
# same as global_variables_initializer
trainer.init_parameters()
# train a sub-model if there has more than one sub-models has different costs like gan
trainer.train(targets=[cost], update_equation=optimizer, reader=...)

infer

# just forward run is supported
# borrowed from tf, forward run a sub-model whose end point is targets
trainer.run(targets=[cost], reader=...)

最后的想法

  • python wrapper 最好能多参考写model同学的建议,特别是比较熟悉/用过 tf, pytorch, mxnet, caffe2 等
  • 写 python wrapper 的 developer, 和用 python wrapper 的user,想法可能差异很大
  • 我们最终面向的用户是同一群人,客观现实,他们在使用的习惯,概念上已经被非常主流的框架同化了
  • 如果提供非主流的用法,可能没有很多的小白用户接受再教育,小白是从其他平台或者了解其他平台用法的过来的,不免会带入其他平台的一些使用习惯
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!3566
Source branch: github/fork/Superjomn/python-wrapper
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7