提交 621ae0a1 编写于 作者: M Megvii Engine Team

fix(dnn): replace kernel launch syntax with macro for hcc

GitOrigin-RevId: f9e69d48257e7cd90f96ee75b8c0428cff470025
上级 78fff72a
......@@ -204,7 +204,7 @@ namespace megdnn {
DEF_KERN_FLOAT(ATAN2, atan2f(x, y));
DEF_KERN_FLOAT(H_SWISH_GRAD,
x < -3.f ? 0.f : (x > 3.f ? y : (2.f * x + 3.f) / 6.f * y));
x < -3.f ? (ctype)0.f : (ctype)(x > 3.f ? (ctype)y : (ctype)((2.f * x + 3.f) / 6.f * y)));
DEF_KERN_FLOAT(FUSE_ADD_H_SWISH, fuse_add_hswish(x, y));
#undef KERN_SIG
......
......@@ -147,7 +147,7 @@ void chanwise::run_bwd_data(T* src_grad, const T* dst_grad, const T* flt,
dim3 nr_block(param.src_chl,
std::min(512, max(nr_out_dimx / (nr_thread * 4), 1)));
uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T);
kern<<<nr_block, nr_thread, shared, stream>>>(src_grad, dst_grad, flt,
hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, src_grad, dst_grad, flt,
param);
after_kernel_launch();
}
......
......@@ -105,7 +105,7 @@ void chanwise::run_fwd(T* dst, const T* src, const T* flt, const Param& param,
dim3 nr_block(param.src_chl,
std::min(512, max(nr_out_dimx / (nr_thread * 4), 1)));
uint32_t shared = param.chl_mul * param.flt_h * param.flt_w * sizeof(T);
kern<<<nr_block, nr_thread, shared, stream>>>(dst, src, flt, param);
hipLaunchKernelGGL(kern, nr_block, nr_thread, shared, stream, dst, src, flt, param);
after_kernel_launch();
}
......
......@@ -314,7 +314,7 @@ void convolution::exec_inplace_matmul_fwd(
} else { \
kptr = conv_kernel<BY, BX, false, BufferFetcherTexture>; \
} \
kptr<<<blocks, threads, 0, stream>>>( \
hipLaunchKernelGGL(kptr, blocks, threads, 0, stream, \
src_tex.val, filter_tex.val, dst, INP_BS, OUT_BS, IC, IH, \
IW, OC, OH, OW, FH, FW, SH, SW, PH, PW); \
} else { \
......@@ -324,7 +324,7 @@ void convolution::exec_inplace_matmul_fwd(
} else { \
kptr = conv_kernel<BY, BX, false, BufferFetcherRaw>; \
} \
kptr<<<blocks, threads, 0, stream>>>( \
hipLaunchKernelGGL(kptr, blocks, threads, 0, stream, \
src_buf, filter_buf, dst, INP_BS, OUT_BS, IC, IH, IW, OC, \
OH, OW, FH, FW, SH, SW, PH, PW); \
} \
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册