未验证 提交 a8aa7913 编写于 作者: W wopeizl 提交者: GitHub

Merge pull request #15453 from wopeizl/fix15313

fix pr 15313
......@@ -21,20 +21,20 @@ namespace operators {
enum GroupNormKernelFlags { kHasScale = 1, kHasBias = 2 };
#define CHECK_CASE(i, flags, kernel_name, args...) \
if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(args); \
#define CHECK_CASE(i, flags, kernel_name, ...) \
if (i == flags) { \
kernel_name<T, i><<<grid, threads, 0, dev_ctx.stream()>>>(__VA_ARGS__); \
}
// 0 for no scale, no bias
// 1 for has scale, no bias
// 2 for no scale, has bias
// 3 for has scale, has bias
#define UNROLL_ALL_CASES(flags, kernel_name, args...) \
CHECK_CASE(0, flags, kernel_name, args) \
CHECK_CASE(1, flags, kernel_name, args) \
CHECK_CASE(2, flags, kernel_name, args) \
CHECK_CASE(3, flags, kernel_name, args)
#define UNROLL_ALL_CASES(flags, kernel_name, ...) \
CHECK_CASE(0, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(1, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(2, flags, kernel_name, __VA_ARGS__) \
CHECK_CASE(3, flags, kernel_name, __VA_ARGS__)
template <typename T>
__device__ __inline__ void CudaAtomicAddWithWarp(T* sum, T value) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册