提交 e6a3a3a3 编写于 作者: P peizhilin

fix pr 15313

test=develop
上级 59ab98c9
...@@ -21,20 +21,20 @@ namespace operators { ...@@ -21,20 +21,20 @@ namespace operators {
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 }; enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
#define CHECK_CASE(i, flags, kernel_name, args...) \ #define CHECK_CASE(i, flags, kernel_name, ...) \
if (i == flags) { \ if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(args); \ kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
} }
// 0 for no scale, no bias // 0 for no scale, no bias
// 1 for has scale, no bias // 1 for has scale, no bias
// 2 for no scale, has bias // 2 for no scale, has bias
// 3 for has scale, has bias // 3 for has scale, has bias
#define UNROLL_ALL_CASES(flags, kernel_name, args...) \ #define UNROLL_ALL_CASES(flags, kernel_name, ...) \
CHECK_CASE(0, flags, kernel_name, args) \ CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(1, flags, kernel_name, args) \ CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(2, flags, kernel_name, args) \ CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(3, flags, kernel_name, args) CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
template <typename T> template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) { __device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册