Created by: wangchaochaohu
fusion group now support simple elementwise OP. In this PR, we add sum op support for fusion group.
- add codegen for sum op in a special way
- add support of sum fusion in fusion group detection pattern
- add some unittest for sum fusion
W0228 04:12:07.934985 35030 device_context.cc:237] Please NOTE: device: 0, CUDA Capability: 70, Driver API Version: 10.1, Runtime API Version: 9.0
W0228 04:12:07.940497 35030 device_context.cc:245] device: 0, cuDNN Version: 7.5.
I0228 04:12:09.940064 35030 fusion_group_pass.cc:56] origin: Graph: {
Node(sum_1.tmp_0{32x128}), inputs:{sum}, outputs:{}
Node(elementwise_add_0{32x128}), inputs:{elementwise_add}, outputs:{sum}
Node(data0{32x128}), inputs:{}, outputs:{elementwise_add}
Node(data1{32x128}), inputs:{}, outputs:{elementwise_add}
Node(Op(elementwise_add), inputs:{X[data0], Y[data1]}, outputs:{Out[elementwise_add_0]}), inputs:{data0, data1}, outputs:{elementwise_add_0}.
Node(sum_0.tmp_0{32x128}), inputs:{sum}, outputs:{sum}
Node(data3{32x128}), inputs:{}, outputs:{sum}
Node(data2{32x128}), inputs:{}, outputs:{sum}
Node(Op(sum), inputs:{X[elementwise_add_0, data2, data3]}, outputs:{Out[sum_0.tmp_0]}), inputs:{elementwise_add_0, data2, data3}, outputs:{sum_0.tmp_0}.
Node(Op(sum), inputs:{X[sum_0.tmp_0, data4]}, outputs:{Out[sum_1.tmp_0]}), inputs:{sum_0.tmp_0, data4}, outputs:{sum_1.tmp_0}.
Node(data4{32x128}), inputs:{}, outputs:{sum}
}
I0228 04:12:10.145560 35030 fusion_group_pass.cc:60] after code gen: Graph: {
Node(sum_1.tmp_0{32x128}), inputs:{sum}, outputs:{}
Node(elementwise_add_0{32x128}), inputs:{elementwise_add}, outputs:{sum}
Node(data0{32x128}), inputs:{}, outputs:{elementwise_add}
Node(data1{32x128}), inputs:{}, outputs:{elementwise_add}
Node(Op(elementwise_add), inputs:{X[data0], Y[data1]}, outputs:{Out[elementwise_add_0]}), inputs:{data0, data1}, outputs:{elementwise_add_0}.
Node(sum_0.tmp_0{32x128}), inputs:{sum}, outputs:{sum}
Node(data3{32x128}), inputs:{}, outputs:{sum}
Node(data2{32x128}), inputs:{}, outputs:{sum}
Node(Op(sum), inputs:{X[elementwise_add_0, data2, data3]}, outputs:{Out[sum_0.tmp_0]}), inputs:{elementwise_add_0, data2, data3}, outputs:{sum_0.tmp_0}.
Node(Op(sum), inputs:{X[sum_0.tmp_0, data4]}, outputs:{Out[sum_1.tmp_0]}), inputs:{sum_0.tmp_0, data4}, outputs:{sum_1.tmp_0}.
Node(data4{32x128}), inputs:{}, outputs:{sum}
}
I0228 04:12:10.145742 35030 fusion_group_pass.cc:63] fusion group:Graph: {
Node(Op(fusion_group), inputs:{Inputs[data2, data3, data4, data0, data1]}, outputs:{Outs[elementwise_add_0, sum_0.tmp_0, sum_1.tmp_0]}), inputs:{data2, data3, data4, data0, data1}, outputs:{elementwise_add_0, sum_0.tmp_0, sum_1.tmp_0}.
Node(sum_1.tmp_0{32x128}), inputs:{fusion_group}, outputs:{}
Node(elementwise_add_0{32x128}), inputs:{fusion_group}, outputs:{}
Node(data0{32x128}), inputs:{}, outputs:{fusion_group}
Node(data1{32x128}), inputs:{}, outputs:{fusion_group}
Node(sum_0.tmp_0{32x128}), inputs:{fusion_group}, outputs:{}
Node(data3{32x128}), inputs:{}, outputs:{fusion_group}
Node(data2{32x128}), inputs:{}, outputs:{fusion_group}
Node(data4{32x128}), inputs:{}, outputs:{fusion_group}
}
Gen Code
__device__ inline float real_exp(float x) { return ::expf(x); }
__device__ inline float real_log(float x) { return ::logf(x); }
extern "C" __global__ void fused_elementwise_0(int N, float* arg0, float* arg1, float* arg2, float* arg3, float* arg4, float* arg5, float* arg6, float* arg7) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
float tmp0 = arg0[idx];
float tmp1 = arg1[idx];
float tmp2 = arg2[idx];
float tmp3 = arg3[idx];
float tmp4 = arg4[idx];
float tmp5 = tmp3 + tmp4;
float tmp6 = tmp5 + tmp0 + tmp1;
float tmp7 = tmp6 + tmp2;
arg5[idx] = tmp5;
arg6[idx] = tmp6;
arg7[idx] = tmp7;
}
}