Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !22771

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

add sum op support for fusion group !22771

  • Report abuse
!22771 已合并 2月 26, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007fed5cef7440>
  • 概览 32
  • 提交 9
  • 变更 7

Created by: wangchaochaohu

fusion group now support simple elementwise OP. In this PR, we add sum op support for fusion group.

  1. add codegen for sum op in a special way
  2. add support of sum fusion in fusion group detection pattern
  3. add some unittest for sum fusion
W0228 04:12:07.934985 35030 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0228 04:12:07.940497 35030 device_context.cc:245] device: 0, cuDNN Version: 7.5.
I0228 04:12:09.940064 35030 fusion_group_pass.cc:56] origin: Graph: {
    Node(sum_1.tmp_0{32x128}), inputs:{sum}, outputs:{}
    Node(elementwise_add_0{32x128}), inputs:{elementwise_add}, outputs:{sum}
    Node(data0{32x128}), inputs:{}, outputs:{elementwise_add}
    Node(data1{32x128}), inputs:{}, outputs:{elementwise_add}
  Node(Op(elementwise_add), inputs:{X[data0], Y[data1]}, outputs:{Out[elementwise_add_0]}), inputs:{data0, data1}, outputs:{elementwise_add_0}.
    Node(sum_0.tmp_0{32x128}), inputs:{sum}, outputs:{sum}
    Node(data3{32x128}), inputs:{}, outputs:{sum}
    Node(data2{32x128}), inputs:{}, outputs:{sum}
  Node(Op(sum), inputs:{X[elementwise_add_0, data2, data3]}, outputs:{Out[sum_0.tmp_0]}), inputs:{elementwise_add_0, data2, data3}, outputs:{sum_0.tmp_0}.
  Node(Op(sum), inputs:{X[sum_0.tmp_0, data4]}, outputs:{Out[sum_1.tmp_0]}), inputs:{sum_0.tmp_0, data4}, outputs:{sum_1.tmp_0}.
    Node(data4{32x128}), inputs:{}, outputs:{sum}
}
I0228 04:12:10.145560 35030 fusion_group_pass.cc:60] after code gen: Graph: {
    Node(sum_1.tmp_0{32x128}), inputs:{sum}, outputs:{}
    Node(elementwise_add_0{32x128}), inputs:{elementwise_add}, outputs:{sum}
    Node(data0{32x128}), inputs:{}, outputs:{elementwise_add}
    Node(data1{32x128}), inputs:{}, outputs:{elementwise_add}
  Node(Op(elementwise_add), inputs:{X[data0], Y[data1]}, outputs:{Out[elementwise_add_0]}), inputs:{data0, data1}, outputs:{elementwise_add_0}.
    Node(sum_0.tmp_0{32x128}), inputs:{sum}, outputs:{sum}
    Node(data3{32x128}), inputs:{}, outputs:{sum}
    Node(data2{32x128}), inputs:{}, outputs:{sum}
  Node(Op(sum), inputs:{X[elementwise_add_0, data2, data3]}, outputs:{Out[sum_0.tmp_0]}), inputs:{elementwise_add_0, data2, data3}, outputs:{sum_0.tmp_0}.
  Node(Op(sum), inputs:{X[sum_0.tmp_0, data4]}, outputs:{Out[sum_1.tmp_0]}), inputs:{sum_0.tmp_0, data4}, outputs:{sum_1.tmp_0}.
    Node(data4{32x128}), inputs:{}, outputs:{sum}
}
I0228 04:12:10.145742 35030 fusion_group_pass.cc:63] fusion group:Graph: {
  Node(Op(fusion_group), inputs:{Inputs[data2, data3, data4, data0, data1]}, outputs:{Outs[elementwise_add_0, sum_0.tmp_0, sum_1.tmp_0]}), inputs:{data2, data3, data4, data0, data1}, outputs:{elementwise_add_0, sum_0.tmp_0, sum_1.tmp_0}.
    Node(sum_1.tmp_0{32x128}), inputs:{fusion_group}, outputs:{}
    Node(elementwise_add_0{32x128}), inputs:{fusion_group}, outputs:{}
    Node(data0{32x128}), inputs:{}, outputs:{fusion_group}
    Node(data1{32x128}), inputs:{}, outputs:{fusion_group}
    Node(sum_0.tmp_0{32x128}), inputs:{fusion_group}, outputs:{}
    Node(data3{32x128}), inputs:{}, outputs:{fusion_group}
    Node(data2{32x128}), inputs:{}, outputs:{fusion_group}
    Node(data4{32x128}), inputs:{}, outputs:{fusion_group}
}

Gen Code

__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }


extern "C" __global__ void fused_elementwise_0(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5, float* arg6, float* arg7) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
      idx < N;
      idx += gridDim.x * blockDim.x) {
    float tmp0 = arg0[idx];
    float tmp1 = arg1[idx];
    float tmp2 = arg2[idx];
    float tmp3 = arg3[idx];
    float tmp4 = arg4[idx];
    float tmp5 = tmp3 + tmp4;
    float tmp6 = tmp5 + tmp0 + tmp1;
    float tmp7 = tmp6 + tmp2;
    arg5[idx] = tmp5;
    arg6[idx] = tmp6;
    arg7[idx] = tmp7;
  }
}
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!22771
Source branch: github/fork/wangchaochaohu/sum_fusion
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7