Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !23163

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
接近 2 年 前同步成功

通知 2323
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Add support for attr type Op and add fill_constant Op and scale Op !23163

  • Report abuse
!23163 已合并 3月 23, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007f2b7ee39990>
  • 概览 10
  • 提交 4
  • 变更 5

Created by: wangchaochaohu

  1. Add support for attr cal in Fusion group
  2. fill_constant OP support
  3. scale Op support

定义了一种新的取值方式,整体思路对于之前实现的${0}这种数字代表输入的方式加强,添加了表达式中可以对Attr 进行操作,目前我们不支持Attr为Tensor 这种方式, 因为Tensor的输入是Runtime的时候决定的,对于Fusion Group 来说状态不可活的。 我们通过${Attr}这种方式来标识Op的Attr,再code_generator中获取Attr 并传递给code_generator_helper中进行处理,得到最终的计算表达式。

I0325 06:55:08.236503 50258 fusion_group_pass.cc:56] subgraph: {
  Node(Op(fill_constant), inputs:{ShapeTensor[], ShapeTensorList[]}, outputs:{Out[fill_constant_0.tmp_0]}), inputs:{}, outputs:{fill_constant_0.tmp_0}.
    Node(data0{2x2}), inputs:{}, outputs:{elementwise_add}
    Node(data1{2x2}), inputs:{}, outputs:{elementwise_add}
    Node(fill_constant_0.tmp_0{2x2}), inputs:{fill_constant}, outputs:{scale}
  Node(Op(elementwise_add), inputs:{X[data0], Y[data1]}, outputs:{Out[elementwise_add_1]}), inputs:{data0, data1}, outputs:{elementwise_add_1}.
  Node(Op(scale), inputs:{ScaleTensor[], X[fill_constant_0.tmp_0]}, outputs:{Out[scale_0.tmp_0]}), inputs:{fill_constant_0.tmp_0}, outputs:{scale_0.tmp_0}.
    Node(elementwise_add_1{2x2}), inputs:{elementwise_add}, outputs:{elementwise_mul, elementwise_mul_grad}
    Node(scale_0.tmp_0{2x2}), inputs:{scale}, outputs:{elementwise_mul, elementwise_mul_grad}
  Node(Op(elementwise_mul), inputs:{X[scale_0.tmp_0], Y[elementwise_add_1]}, outputs:{Out[elementwise_mul_0]}), inputs:{scale_0.tmp_0, elementwise_add_1}, outputs:{elementwise_mul_0}.
    Node(elementwise_mul_0{2x2}), inputs:{elementwise_mul}, outputs:{mean, mean_grad}
}
extern "C" __global__ void FusedElementwise2(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
      idx < N;
      idx += gridDim.x * blockDim.x) {
    float tmp0 = arg0[idx];
    float tmp1 = arg1[idx];
    float tmp2 = static_cast<float>(2.0);
    float tmp3 = tmp0 + tmp1;
    float tmp4 = true ? (3 * tmp2 + 1) : (3 * (tmp2 + 1));
    float tmp5 = tmp4 * tmp3;
    arg2[idx] = tmp2;
    arg3[idx] = tmp3;
    arg4[idx] = tmp4;
    arg5[idx] = tmp5;
  }
}

language_model上测试static large model 没有性能提升

指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!23163
Source branch: github/fork/wangchaochaohu/fill_constant_fusion
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7