Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !21126

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Enable generating code for a given subgraph. !21126

  • Report abuse
!21126 已合并 11月 12, 2019 由 saxon_zh@saxon_zh 创建
#<User:0x00007ff81b45d998>
  • 概览 1
  • 提交 10
  • 变更 15

Created by: Xreki

Support generating code for a given subgraph.

Examples of generated codes

  • example 1, generating code for a given list of expressions, each expression represents a forward operation.
extern "C" __global__ void elementwise_kernel_0(int N, float* arg0, float* arg1, float* arg3, float* arg5, float* arg2, float* arg4, float* arg6, float* arg7, float* arg8) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp0 = arg0[idx];
     float tmp1 = arg1[idx];
     float tmp3 = arg3[idx];
     float tmp5 = arg5[idx];
     float tmp2 = tmp0 * tmp1;
     float tmp4 = tmp2 + tmp3;
     float tmp6 = tmp4 - tmp5;
     float tmp7 = real_max(tmp6, 0);
     float tmp8 = 1.0 / (1.0 + real_exp(- tmp7));
     arg2[idx] = tmp2;
     arg4[idx] = tmp4;
     arg6[idx] = tmp6;
     arg7[idx] = tmp7;
     arg8[idx] = tmp8;
   }
 }
  • example 2, generating code for a given list of expressions, each expression represents a backward operation.
extern "C" __global__ void elementwise_grad_kernel_0(int N, float* arg0, float* arg1, float* arg2, float* arg7, float* arg4, float* arg5, float* arg6) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp0 = arg0[idx];
     float tmp1 = arg1[idx];
     float tmp2 = arg2[idx];
     float tmp7 = arg7[idx];
     float tmp6 = tmp2 > 0 ? tmp7 : 0;
     float tmp4 = tmp6 * tmp1;
     float tmp5 = tmp6 * tmp0;
     arg4[idx] = tmp4;
     arg5[idx] = tmp5;
     arg6[idx] = tmp6;
   }
 }
  • example 3, generating code for a given subgraph, each expression represents a forward operation.
extern "C" __global__ void elementwise_kernel_1(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5, float* arg6, float* arg7, float* arg8) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp0 = arg0[idx];
     float tmp1 = arg1[idx];
     float tmp2 = arg2[idx];
     float tmp3 = arg3[idx];
     float tmp4 = 1.0 / (1.0 + real_exp(- tmp0));
     float tmp7 = tmp4 * tmp1;
     float tmp5 = 2.0 / (1.0 + real_exp(-2.0 * tmp2)) - 1.0;
     float tmp6 = tmp3 * tmp5;
     float tmp8 = tmp7 + tmp6;
     arg4[idx] = tmp4;
     arg5[idx] = tmp5;
     arg6[idx] = tmp6;
     arg7[idx] = tmp7;
     arg8[idx] = tmp8;
   }
 }
  • example 4, generating code for a given subgraph, each expression represents a backward operation.
extern "C" __global__ void elementwise_grad_kernel_1(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5, float* arg6, float* arg7, float* arg8, float* arg9, float* arg10, float* arg11, float* arg12, float* arg13, float* arg14, float* arg15, float* arg16, float* arg17) {
   for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
       idx < N;
       idx += gridDim.x * blockDim.x) {
     float tmp2 = arg2[idx];
     float tmp4 = arg4[idx];
     float tmp5 = arg5[idx];
     float tmp6 = arg6[idx];
     float tmp7 = arg7[idx];
     float tmp11 = tmp2;
     float tmp10 = tmp2;
     float tmp16 = tmp10 * tmp5;
     float tmp12 = tmp10 * tmp4;
     float tmp13 = tmp11 * tmp6;
     float tmp17 = tmp11 * tmp7;
     float tmp15 = tmp12 * (1.0 - tmp5 * tmp5);
     float tmp14 = tmp13 * tmp7 * (1.0 - tmp7);
     arg10[idx] = tmp10;
     arg11[idx] = tmp11;
     arg12[idx] = tmp12;
     arg13[idx] = tmp13;
     arg14[idx] = tmp14;
     arg15[idx] = tmp15;
     arg16[idx] = tmp16;
     arg17[idx] = tmp17;
   }
 }
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!21126
Source branch: github/fork/Xreki/pass_subgraph_generate
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7