Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • 合并请求
  • !22876

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板

Cast fusion for fusion group !22876

  • Report abuse
!22876 已合并 3月 05, 2020 由 saxon_zh@saxon_zh 创建
#<User:0x00007fede20052a8>
  • 概览 38
  • 提交 12
  • 变更 15

Created by: wangchaochaohu

1添加fusion group 中的cast实现方式(通过添加fusion group多type 类型进行实现) 生成代码如下

I0305 09:32:55.299682 65282 subgraph_detector.cc:314] 6 node not a trt candidate.
    Node(mul_2.tmp_0{32x128}), inputs:{mul}, outputs:{cast}    {
    Node(cast_2.tmp_0{32x128}), inputs:{cast}, outputs:{elementwise_add}
  Node(Op(cast), inputs:{X[mul_2.tmp_0]}, outputs:{Out[cast_3.tmp_0]}), inputs:{mul_2.tmp_0}, outputs:{cast_3.tmp_0}.
    Node(cast_3.tmp_0{32x128}), inputs:{cast}, outputs:{elementwise_add}
  Node(Op(elementwise_add), inputs:{X[cast_2.tmp_0], Y[cast_3.tmp_0]}, outputs:{Out[tmp_9]}), inputs:{cast_2.tmp_0, cast_3.tmp_0}, outputs:{tmp_9}.
    Node(tmp_9{32x128}), inputs:{elementwise_add}, outputs:{relu}
  Node(Op(relu), inputs:{X[tmp_9]}, outputs:{Out[relu_3.tmp_0]}), inputs:{tmp_9}, outputs:{relu_3.tmp_0}.
    Node(relu_3.tmp_0{32x128}), inputs:{relu}, outputs:{cast}
  Node(Op(cast), inputs:{X[relu_3.tmp_0]}, outputs:{Out[cast_4.tmp_0]}), inputs:{relu_3.tmp_0}, outputs:{cast_4.tmp_0}.
    Node(cast_4.tmp_0{32x128}), inputs:{cast}, outputs:{}
}
}0305 09:32:55.299979 65282 code_generator.cc:278] Encoding input names:mul_2.tmp_0, id:0
I0305 09:32:55.300010 65282 code_generator.cc:278] Encoding input names:cast_2.tmp_0, id:1
I0305 09:32:55.300045 65282 code_generator.cc:311] Ecoding output names:cast_3.tmp_0, id:2
I0305 09:32:55.300074 65282 code_generator.cc:311] Ecoding output names:tmp_9, id:3
I0305 09:32:55.300101 65282 code_generator.cc:311] Ecoding output names:relu_3.tmp_0, id:4
I0305 09:32:55.300135 65282 code_generator.cc:311] Ecoding output names:cast_4.tmp_0, id:5
I0305 09:32:55.300205 65282 code_generator.cc:263] Op(cast), inputs:{0}, outputs:{2}
I0305 09:32:55.300249 65282 code_generator.cc:263] Op(elementwise_add), inputs:{1,2}, outputs:{3}
I0305 09:32:55.300294 65282 code_generator.cc:263] Op(relu), inputs:{3}, outputs:{4}
I0305 09:32:55.300338 65282 code_generator.cc:263] Op(cast), inputs:{4}, outputs:{5}
I0305 09:32:55.300415 65282 fusion_group_pass.cc:70]
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }

#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))

struct __align__(2) __half {
  __device__ __half() { }

 protected:
  unsigned short __x;
};

__device__ __half __float2half(const float f) {
  __half val;
  asm("{ cvt.rn.f16.f32 %0, %1; }\n" : "=h"(__HALF_TO_US(val)

) : "f"(f));
  return val;
}

__device__ float __half2float(const __half h) {
  float val;
  asm("{ cvt.f32.f16 %0, %1; }\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
  return val;
}

#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half float16;


extern "C" __global__ void fused_elementwise_6(int N, float16* arg0, float* arg1, float16* arg2, float16* arg3, float16* arg4, float* arg5) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
      idx < N;
      idx += gridDim.x * blockDim.x) {
    float16 tmp0 = arg0[idx];
    float tmp1 = arg1[idx];
    float16 tmp2 = __float2half(tmp1);
    float half2fp32_tmp0 = __half2float(tmp0);
    float half2fp32_tmp2 = __half2float(tmp2);
    float16 tmp3 = __float2half(half2fp32_tmp0 + half2fp32_tmp2);
    float half2fp32_tmp3 = __half2float(tmp3);
    float16 tmp4 = __float2half(half2fp32_tmp3 > 0 ? half2fp32_tmp3 : 0);
    float half2fp32_tmp4 = __half2float(tmp4);
    float tmp5 = half2fp32_tmp4;
    arg2[idx] = tmp2;
    arg3[idx] = tmp3;
    arg4[idx] = tmp4;
    arg5[idx] = tmp5;
  }
}


I0305 09:32:55.480564 65282 fusion_group_pass.cc:54] subgraph: {
    Node(data0{32x128}), inputs:{}, outputs:{elementwise_mul}
    Node(data1{32x128}), inputs:{}, outputs:{elementwise_mul}
  Node(Op(elementwise_mul), inputs:{X[data0], Y[data1]}, outputs:{Out[tmp_8]}), inputs:{data0, data1}, outputs:{tmp_8}.
    Node(tmp_8{32x128}), inputs:{elementwise_mul}, outputs:{mul, cast}
  Node(Op(cast), inputs:{X[tmp_8]}, outputs:{Out[cast_2.tmp_0]}), inputs:{tmp_8}, outputs:{cast_2.tmp_0}.
    Node(cast_2.tmp_0{32x128}), inputs:{cast}, outputs:{fusion_group}
}


I0305 09:32:55.480564 65282 fusion_group_pass.cc:54] subgraph: {
    Node(data0{32x128}), inputs:{}, outputs:{elementwise_mul}
    Node(data1{32x128}), inputs:{}, outputs:{elementwise_mul}
  Node(Op(elementwise_mul), inputs:{X[data0], Y[data1]}, outputs:{Out[tmp_8]}), inputs:{data0, data1}, outputs:{tmp_8}.
    Node(tmp_8{32x128}), inputs:{elementwise_mul}, outputs:{mul, cast}
  Node(Op(cast), inputs:{X[tmp_8]}, outputs:{Out[cast_2.tmp_0]}), inputs:{tmp_8}, outputs:{cast_2.tmp_0}.
    Node(cast_2.tmp_0{32x128}), inputs:{cast}, outputs:{fusion_group}
}
I0305 09:32:55.480883 65282 code_generator.cc:278] Encoding input names:data0, id:0
I0305 09:32:55.480938 65282 code_generator.cc:278] Encoding input names:data1, id:1
I0305 09:32:55.480973 65282 code_generator.cc:311] Ecoding output names:tmp_8, id:2
I0305 09:32:55.481007 65282 code_generator.cc:311] Ecoding output names:cast_2.tmp_0, id:3
I0305 09:32:55.481086 65282 code_generator.cc:263] Op(elementwise_mul), inputs:{0,1}, outputs:{2}
I0305 09:32:55.481150 65282 code_generator.cc:263] Op(cast), inputs:{2}, outputs:{3}
I0305 09:32:55.481212 65282 fusion_group_pass.cc:70]
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }

#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))

struct __align__(2) __half {
  __device__ __half() { }

 protected:
  unsigned short __x;
};

__device__ __half __float2half(const float f) {
  __half val;
  asm("{ cvt.rn.f16.f32 %0, %1; }\n" : "=h"(__HALF_TO_US(val)

) : "f"(f));
  return val;
}

__device__ float __half2float(const __half h) {
  float val;
  asm("{ cvt.f32.f16 %0, %1; }\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
  return val;
}

#undef __HALF_TO_US
#undef __HALF_TO_CUS

typedef __half float16;


extern "C" __global__ void fused_elementwise_7(int N, float* arg0, float* arg1, float* arg2, float16* arg3) {
  for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
      idx < N;
      idx += gridDim.x * blockDim.x) {
    float tmp0 = arg0[idx];
    float tmp1 = arg1[idx];
    float tmp2 = tmp0 * tmp1;
    float16 tmp3 = __float2half(tmp2);
    arg2[idx] = tmp2;
    arg3[idx] = tmp3;
  }
}
指派人
分配到
审核者
Request review from
无
里程碑
无
分配里程碑
工时统计
标识: paddlepaddle/Paddle!22876
Source branch: github/fork/wangchaochaohu/cast_fusion
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7