Created by: wangchaochaohu
1添加fusion group 中的cast实现方式(通过添加fusion group多type 类型进行实现) 生成代码如下
I0305 09:32:55.299682 65282 subgraph_detector.cc:314] 6 node not a trt candidate.
Node(mul_2.tmp_0{32x128}), inputs:{mul}, outputs:{cast} {
Node(cast_2.tmp_0{32x128}), inputs:{cast}, outputs:{elementwise_add}
Node(Op(cast), inputs:{X[mul_2.tmp_0]}, outputs:{Out[cast_3.tmp_0]}), inputs:{mul_2.tmp_0}, outputs:{cast_3.tmp_0}.
Node(cast_3.tmp_0{32x128}), inputs:{cast}, outputs:{elementwise_add}
Node(Op(elementwise_add), inputs:{X[cast_2.tmp_0], Y[cast_3.tmp_0]}, outputs:{Out[tmp_9]}), inputs:{cast_2.tmp_0, cast_3.tmp_0}, outputs:{tmp_9}.
Node(tmp_9{32x128}), inputs:{elementwise_add}, outputs:{relu}
Node(Op(relu), inputs:{X[tmp_9]}, outputs:{Out[relu_3.tmp_0]}), inputs:{tmp_9}, outputs:{relu_3.tmp_0}.
Node(relu_3.tmp_0{32x128}), inputs:{relu}, outputs:{cast}
Node(Op(cast), inputs:{X[relu_3.tmp_0]}, outputs:{Out[cast_4.tmp_0]}), inputs:{relu_3.tmp_0}, outputs:{cast_4.tmp_0}.
Node(cast_4.tmp_0{32x128}), inputs:{cast}, outputs:{}
}
}0305 09:32:55.299979 65282 code_generator.cc:278] Encoding input names:mul_2.tmp_0, id:0
I0305 09:32:55.300010 65282 code_generator.cc:278] Encoding input names:cast_2.tmp_0, id:1
I0305 09:32:55.300045 65282 code_generator.cc:311] Ecoding output names:cast_3.tmp_0, id:2
I0305 09:32:55.300074 65282 code_generator.cc:311] Ecoding output names:tmp_9, id:3
I0305 09:32:55.300101 65282 code_generator.cc:311] Ecoding output names:relu_3.tmp_0, id:4
I0305 09:32:55.300135 65282 code_generator.cc:311] Ecoding output names:cast_4.tmp_0, id:5
I0305 09:32:55.300205 65282 code_generator.cc:263] Op(cast), inputs:{0}, outputs:{2}
I0305 09:32:55.300249 65282 code_generator.cc:263] Op(elementwise_add), inputs:{1,2}, outputs:{3}
I0305 09:32:55.300294 65282 code_generator.cc:263] Op(relu), inputs:{3}, outputs:{4}
I0305 09:32:55.300338 65282 code_generator.cc:263] Op(cast), inputs:{4}, outputs:{5}
I0305 09:32:55.300415 65282 fusion_group_pass.cc:70]
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
struct __align__(2) __half {
__device__ __half() { }
protected:
unsigned short __x;
};
__device__ __half __float2half(const float f) {
__half val;
asm("{ cvt.rn.f16.f32 %0, %1; }\n" : "=h"(__HALF_TO_US(val)
) : "f"(f));
return val;
}
__device__ float __half2float(const __half h) {
float val;
asm("{ cvt.f32.f16 %0, %1; }\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
return val;
}
#undef __HALF_TO_US
#undef __HALF_TO_CUS
typedef __half float16;
extern "C" __global__ void fused_elementwise_6(int N, float16* arg0, float* arg1, float16* arg2, float16* arg3, float16* arg4, float* arg5) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
float16 tmp0 = arg0[idx];
float tmp1 = arg1[idx];
float16 tmp2 = __float2half(tmp1);
float half2fp32_tmp0 = __half2float(tmp0);
float half2fp32_tmp2 = __half2float(tmp2);
float16 tmp3 = __float2half(half2fp32_tmp0 + half2fp32_tmp2);
float half2fp32_tmp3 = __half2float(tmp3);
float16 tmp4 = __float2half(half2fp32_tmp3 > 0 ? half2fp32_tmp3 : 0);
float half2fp32_tmp4 = __half2float(tmp4);
float tmp5 = half2fp32_tmp4;
arg2[idx] = tmp2;
arg3[idx] = tmp3;
arg4[idx] = tmp4;
arg5[idx] = tmp5;
}
}
I0305 09:32:55.480564 65282 fusion_group_pass.cc:54] subgraph: {
Node(data0{32x128}), inputs:{}, outputs:{elementwise_mul}
Node(data1{32x128}), inputs:{}, outputs:{elementwise_mul}
Node(Op(elementwise_mul), inputs:{X[data0], Y[data1]}, outputs:{Out[tmp_8]}), inputs:{data0, data1}, outputs:{tmp_8}.
Node(tmp_8{32x128}), inputs:{elementwise_mul}, outputs:{mul, cast}
Node(Op(cast), inputs:{X[tmp_8]}, outputs:{Out[cast_2.tmp_0]}), inputs:{tmp_8}, outputs:{cast_2.tmp_0}.
Node(cast_2.tmp_0{32x128}), inputs:{cast}, outputs:{fusion_group}
}
I0305 09:32:55.480564 65282 fusion_group_pass.cc:54] subgraph: {
Node(data0{32x128}), inputs:{}, outputs:{elementwise_mul}
Node(data1{32x128}), inputs:{}, outputs:{elementwise_mul}
Node(Op(elementwise_mul), inputs:{X[data0], Y[data1]}, outputs:{Out[tmp_8]}), inputs:{data0, data1}, outputs:{tmp_8}.
Node(tmp_8{32x128}), inputs:{elementwise_mul}, outputs:{mul, cast}
Node(Op(cast), inputs:{X[tmp_8]}, outputs:{Out[cast_2.tmp_0]}), inputs:{tmp_8}, outputs:{cast_2.tmp_0}.
Node(cast_2.tmp_0{32x128}), inputs:{cast}, outputs:{fusion_group}
}
I0305 09:32:55.480883 65282 code_generator.cc:278] Encoding input names:data0, id:0
I0305 09:32:55.480938 65282 code_generator.cc:278] Encoding input names:data1, id:1
I0305 09:32:55.480973 65282 code_generator.cc:311] Ecoding output names:tmp_8, id:2
I0305 09:32:55.481007 65282 code_generator.cc:311] Ecoding output names:cast_2.tmp_0, id:3
I0305 09:32:55.481086 65282 code_generator.cc:263] Op(elementwise_mul), inputs:{0,1}, outputs:{2}
I0305 09:32:55.481150 65282 code_generator.cc:263] Op(cast), inputs:{2}, outputs:{3}
I0305 09:32:55.481212 65282 fusion_group_pass.cc:70]
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
#define __HALF_TO_US(var) *(reinterpret_cast<unsigned short *>(&(var)))
#define __HALF_TO_CUS(var) *(reinterpret_cast<const unsigned short *>(&(var)))
struct __align__(2) __half {
__device__ __half() { }
protected:
unsigned short __x;
};
__device__ __half __float2half(const float f) {
__half val;
asm("{ cvt.rn.f16.f32 %0, %1; }\n" : "=h"(__HALF_TO_US(val)
) : "f"(f));
return val;
}
__device__ float __half2float(const __half h) {
float val;
asm("{ cvt.f32.f16 %0, %1; }\n" : "=f"(val) : "h"(__HALF_TO_CUS(h)));
return val;
}
#undef __HALF_TO_US
#undef __HALF_TO_CUS
typedef __half float16;
extern "C" __global__ void fused_elementwise_7(int N, float* arg0, float* arg1, float* arg2, float16* arg3) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
float tmp0 = arg0[idx];
float tmp1 = arg1[idx];
float tmp2 = tmp0 * tmp1;
float16 tmp3 = __float2half(tmp2);
arg2[idx] = tmp2;
arg3[idx] = tmp3;
}
}