未验证 提交 2fe896df 编写于 作者: W wenbin 提交者: GitHub

Compile fix (#49690)

* compile fix

* fix compile

* compile fix
上级 6578da51
......@@ -155,7 +155,14 @@ __global__ void prelnGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
// int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
__half2 y = *reinterpret_cast<__half2 const *>(&params.srcY[offset]);
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
#if __CUDA_ARCH__ >= 530
h2 = __hadd2(h2, y);
#else
float2 out{};
out.x = __half2float(h2.x) + __half2float(y.x);
out.y = __half2float(h2.y) + __half2float(y.y);
h2 = __float22half2_rn(out);
#endif
// elementwise_add
*reinterpret_cast<__half2 *>(&params.eleOut[offset]) = h2;
}
......
......@@ -167,7 +167,14 @@ __global__ void skipGroupNormNHWCSumKernel(GroupNormNHWCParams params) {
int64_t offsetY = static_cast<int64_t>(ni) * params.c + ci;
__half2 y = *reinterpret_cast<__half2 const *>(&params.srcY[offsetY]);
h2 = *reinterpret_cast<__half2 const *>(&params.srcX[offset]);
#if __CUDA_ARCH__ >= 530
h2 = __hadd2(h2, y);
#else
float2 out{};
out.x = __half2float(h2.x) + __half2float(y.x);
out.y = __half2float(h2.y) + __half2float(y.y);
h2 = __float22half2_rn(out);
#endif
// elementwise_add
*reinterpret_cast<__half2 *>(&params.dst[offset]) = h2;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册