Created by: Xreki
PR types
Others
PR changes
Others
Describe
支持子图中,有stop_gradient=True
的情况。目前已经可以成功匹配如下子图:
I0417 09:27:16.836542 8037 fusion_group_pass.cc:56] subgraph: {
Node(mean_0.tmp_0{1}), inputs:{mean}, outputs:{elementwise_mul, elementwise_mul_grad}
Node(loss_scaling_0{1}), inputs:{}, outputs:{elementwise_mul, elementwise_mul_grad, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, elementwise_div, conditional_block}
Node(Op(fill_constant), inputs:{}, outputs:{Out[tmp_3@GRAD]}), inputs:{}, outputs:{tmp_3@GRAD}.
Node(tmp_3@GRAD{1}), inputs:{fill_constant}, outputs:{elementwise_mul_grad}
Node(Op(elementwise_mul_grad), inputs:{Out@GRAD[tmp_3@GRAD], X[mean_0.tmp_0], Y[loss_scaling_0]}, outputs:{X@GRAD[mean_0.tmp_0@GRAD], Y@GRAD[]}), inputs:{tmp_3@GRAD, mean_0.tmp_0, loss_scaling_0}, outputs:{mean_0.tmp_0@GRAD, __control_var@4160}.
Node(mean_0.tmp_0@GRAD{1}), inputs:{elementwise_mul_grad}, outputs:{mean_grad}
Node(__control_var@4160), inputs:{elementwise_mul_grad}, outputs:{conditional_block}
}
但发现fill_constant生成的代码不对:
__device__ inline float Max(float x, float y) { return fmaxf(x, y); }
__device__ inline float Exp(float x) { return expf(x); }
__device__ inline float Log(float x) { return logf(x); }
__device__ inline float Sqrt(float x) { return sqrtf(x); }
extern "C" __global__ void FusedElementwise106(int N, float* arg0, float* arg1, float* arg2, float* arg3) {
for(int idx = blockIdx.x * blockDim.x + threadIdx.x;
idx < N;
idx += gridDim.x * blockDim.x) {
float tmp1 = arg1[idx];
float tmp2 = static_cast<float>();
float tmp3 = tmp2 * tmp1;
arg2[idx] = tmp2;
arg3[idx] = tmp3;
}
}
TODO:stop_gradient=True
时单测中计算结果对不上,需要进一步排查。