Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
effebd41
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
effebd41
编写于
6月 01, 2023
作者:
R
ronnywang
提交者:
GitHub
6月 01, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[ROCM] fix multihead_matmul (#54108)
* [ROCM] fix multihead_matmul * skip bf16 uts * update
上级
2186fe16
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
135 addition
and
91 deletion
+135
-91
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+65
-54
paddle/fluid/pybind/place.cc
paddle/fluid/pybind/place.cc
+10
-2
paddle/phi/kernels/funcs/math_cuda_utils.h
paddle/phi/kernels/funcs/math_cuda_utils.h
+47
-30
test/legacy_test/test_activation_op.py
test/legacy_test/test_activation_op.py
+8
-4
test/legacy_test/test_scale_op.py
test/legacy_test/test_scale_op.py
+3
-0
test/legacy_test/test_softmax_op.py
test/legacy_test/test_softmax_op.py
+2
-1
未找到文件。
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
effebd41
...
...
@@ -261,9 +261,9 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
const
phi
::
funcs
::
warp_mask_t
mask
)
{
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float
tmp
=
threadIdx
.
x
<
seq_len
?
static_cast
<
float
>
(
qk_buf_
[
threadIdx
.
x
+
qk_offset
]
+
...
...
@@ -281,15 +281,16 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd
template
<
>
__global__
void
SoftmaxKernelWithEltadd
<
half
>
(
half
*
qk_buf_
,
const
half
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
__global__
void
SoftmaxKernelWithEltadd
<
half
>
(
half
*
qk_buf_
,
const
half
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
phi
::
funcs
::
warp_mask_t
mask
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float
tmp
=
threadIdx
.
x
<
seq_len
?
static_cast
<
float
>
(
qk_buf_
[
threadIdx
.
x
+
qk_offset
]
+
...
...
@@ -312,10 +313,10 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
const
phi
::
funcs
::
warp_mask_t
mask
)
{
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
int
idx
=
threadIdx
.
x
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float2
tmp
=
idx
<
seq_len
?
phi
::
funcs
::
ToFloat2
<
T
>
(
qk_buf_
[
idx
+
qk_offset
]
+
...
...
@@ -335,19 +336,20 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
}
template
<
>
__global__
void
SoftmaxKernelWithEltadd2
<
half2
>
(
half2
*
qk_buf_
,
const
half2
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
__global__
void
SoftmaxKernelWithEltadd2
<
half2
>
(
half2
*
qk_buf_
,
const
half2
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
phi
::
funcs
::
warp_mask_t
mask
)
{
// operator "+" of half only suppotted after cuda version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
int
idx
=
threadIdx
.
x
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float2
tmp
=
idx
<
seq_len
?
phi
::
funcs
::
ToFloat2
<
half2
>
(
qk_buf_
[
idx
+
qk_offset
]
+
...
...
@@ -368,14 +370,15 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
}
template
<
typename
T
>
__global__
void
SoftmaxKernelWithEltaddForLarge
(
T
*
qk_buf
,
const
T
*
bias_qk
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
__global__
void
SoftmaxKernelWithEltaddForLarge
(
T
*
qk_buf
,
const
T
*
bias_qk
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
phi
::
funcs
::
warp_mask_t
mask
)
{
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
T
stride_max
=
-
1e20
f
;
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
...
...
@@ -406,15 +409,16 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
// HIP defined __HIP_NO_HALF_CONVERSIONS__
#ifndef __HIPCC__ // @{ Half kernel: SoftmaxKernelWithEltadd
template
<
>
__global__
void
SoftmaxKernelWithEltaddForLarge
(
half
*
qk_buf
,
const
half
*
bias_qk
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
__global__
void
SoftmaxKernelWithEltaddForLarge
(
half
*
qk_buf
,
const
half
*
bias_qk
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
phi
::
funcs
::
warp_mask_t
mask
)
{
#if CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__)
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float
stride_max
=
-
1e20
f
;
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
...
...
@@ -444,14 +448,15 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
#endif // @} End Half kernel: SoftmaxKernelWithEltadd
template
<
typename
T
>
__global__
void
SoftmaxKernelWithEltaddForLarge2
(
T
*
qk_buf_
,
const
T
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
__global__
void
SoftmaxKernelWithEltaddForLarge2
(
T
*
qk_buf_
,
const
T
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
phi
::
funcs
::
warp_mask_t
mask
)
{
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
...
...
@@ -484,19 +489,20 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
}
template
<
>
__global__
void
SoftmaxKernelWithEltaddForLarge2
(
half2
*
qk_buf_
,
const
half2
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
unsigned
mask
)
{
__global__
void
SoftmaxKernelWithEltaddForLarge2
(
half2
*
qk_buf_
,
const
half2
*
bias_qk_
,
const
int
batch_size
,
const
int
head_num
,
const
int
seq_len
,
const
phi
::
funcs
::
warp_mask_t
mask
)
{
// operator "+" of half only suppotted after cuda version 10.0
// HIP defined __HIP_NO_HALF_CONVERSIONS__ in hip.cmake
#if defined(PADDLE_WITH_CUDA) && \
(CUDA_ARCH_FP16_SUPPORTED(__CUDA_ARCH__) && CUDA_VERSION >= 10000)
int
qk_offset
=
blockIdx
.
x
*
seq_len
;
assert
(
blockDim
.
x
%
32
==
0
);
assert
(
blockDim
.
x
%
WARP_SIZE
==
0
);
float2
stride_max
=
make_float2
(
-
1e20
f
,
-
1e20
f
);
for
(
int
i
=
0
;
threadIdx
.
x
+
i
<
seq_len
;
i
+=
blockDim
.
x
)
{
...
...
@@ -637,7 +643,7 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
}
}
if
(
blockDim
.
x
<=
32
)
{
if
(
blockDim
.
x
<=
WARP_SIZE
)
{
phi
::
funcs
::
WarpReduceMaxV2
<
float
,
NUM
>
(
local_max
);
}
else
{
phi
::
funcs
::
BlockReduceMaxV2
<
float
,
NUM
>
(
local_max
);
...
...
@@ -672,7 +678,7 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
}
}
if
(
blockDim
.
x
<=
32
)
{
if
(
blockDim
.
x
<=
WARP_SIZE
)
{
phi
::
funcs
::
WarpReduceSumV2
<
float
,
NUM
>
(
local_sum
);
}
else
{
phi
::
funcs
::
BlockReduceSumV2
<
float
,
NUM
>
(
local_sum
);
...
...
@@ -761,7 +767,10 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
// Align block to 32, also limit seq_len to max block size.
if
(
seq_len
%
2
==
0
)
{
block
=
(
seq_len
<=
64
)
?
32
:
((
seq_len
+
63
)
/
64
)
*
32
;
block
=
(
seq_len
<=
(
2
*
WARP_SIZE
))
?
WARP_SIZE
:
((
seq_len
+
(
2
*
WARP_SIZE
-
1
))
/
(
2
*
WARP_SIZE
))
*
WARP_SIZE
;
if
(
std
::
is_same
<
T
,
float
>::
value
)
{
SoftmaxKernelWithEltadd2
<
float2
><<<
grid
,
block
,
0
,
stream
>>>
(
reinterpret_cast
<
float2
*>
(
qk_buf_
),
...
...
@@ -780,7 +789,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
"cuda_arch<700"
));
#else
dim3
grid
(
seq_len
,
batch_size
,
head_num
);
dim3
block
((
seq_len
/
2
+
31
)
/
32
*
32
);
dim3
block
((
seq_len
/
2
+
WARP_SIZE
-
1
)
/
WARP_SIZE
*
WARP_SIZE
);
SOFTMAX_KERNEL_WITH_MASK
(
1
);
#endif
}
else
{
...
...
@@ -794,7 +803,9 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
}
}
}
else
{
block
=
(
seq_len
<=
32
)
?
32
:
((
seq_len
+
31
)
/
32
)
*
32
;
block
=
(
seq_len
<=
WARP_SIZE
)
?
WARP_SIZE
:
((
seq_len
+
WARP_SIZE
-
1
)
/
WARP_SIZE
)
*
WARP_SIZE
;
SoftmaxKernelWithEltadd
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
qk_buf_
,
bias_qk
,
batch_size
,
head_num
,
seq_len
,
FINAL_MASK
);
}
...
...
@@ -820,7 +831,7 @@ inline void MatMulWithHeadQK(const phi::GPUContext &context,
"cuda_arch<700"
));
#else
dim3
grid
(
seq_len
,
batch_size
,
head_num
);
dim3
block
((
seq_len
/
2
+
31
)
/
32
*
32
);
dim3
block
((
seq_len
/
2
+
WARP_SIZE
-
1
)
/
WARP_SIZE
*
WARP_SIZE
);
if
(
block
.
x
>
0
&&
block
.
x
<=
1024
)
{
SOFTMAX_KERNEL_WITH_MASK
(
1
);
}
else
if
(
block
.
x
<=
2048
)
{
...
...
@@ -1176,8 +1187,8 @@ void SkipLayerNormFunctor<T>::operator()(const int num,
float
eps
,
gpuStream_t
stream
)
{
int
block
=
num
/
hidden
;
if
(
hidden
<=
32
)
{
const
int
threads
=
32
;
if
(
hidden
<=
WARP_SIZE
)
{
const
int
threads
=
WARP_SIZE
;
SkipLayerNormSmallKernel
<
T
,
threads
><<<
block
,
threads
,
0
,
stream
>>>
(
num
,
hidden
,
input1
,
input2
,
output
,
scale
,
bias
,
eps
);
}
else
if
(
hidden
<=
128
)
{
...
...
paddle/fluid/pybind/place.cc
浏览文件 @
effebd41
...
...
@@ -374,12 +374,20 @@ void BindPlace(pybind11::module &m) { // NOLINT
.
def
(
"__str__"
,
string
::
to_string
<
const
platform
::
CUDAPlace
&>
);
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
m
.
def
(
"is_float16_supported"
,
[](
const
platform
::
CUDAPlace
&
place
)
->
bool
{
// Only GPUs with Compute Capability >= 53 support float16
// Only GPUs with Compute Capability >= 53 support float16
#ifdef PADDLE_WITH_HIP
return
true
;
#else
return
platform
::
GetGPUComputeCapability
(
place
.
device
)
>=
53
;
#endif
});
m
.
def
(
"is_bfloat16_supported"
,
[](
const
platform
::
CUDAPlace
&
place
)
->
bool
{
// Only GPUs with Compute Capability >= 80 support bfloat16
// Only GPUs with Compute Capability >= 80 support bfloat16
#ifdef PADDLE_WITH_HIP
return
false
;
#else
return
platform
::
GetGPUComputeCapability
(
place
.
device
)
>=
80
;
#endif
});
#endif
py
::
class_
<
platform
::
XPUPlace
>
xpuplace
(
m
,
"XPUPlace"
,
R"DOC(
...
...
paddle/phi/kernels/funcs/math_cuda_utils.h
浏览文件 @
effebd41
...
...
@@ -163,12 +163,28 @@ struct KeyValuePair<half> {
}
};
// NOTE(wangran16): The warpSize variable is of type int and contains the warp
// size (in threads) for the target device. Note that all current NVIDIA devices
// return 32 for this variable, and all current AMD devices return 64. Device
// code should use the warpSize built-in to develop portable wave-aware code.
#ifdef PADDLE_WITH_HIP
#define FINAL_MASK 0xffffffffffffffffUL
#define HALF_WARP 32
#define WARP_SIZE 64
#define WARP_SIZE_WIDTH 6
#define WARP_SIZE_WIDTH_MASK 0x3f
typedef
u_int64_t
warp_mask_t
;
#else
#define FINAL_MASK 0xffffffff
#define HALF_WARP 16
#define WARP_SIZE 32
#define WARP_SIZE_WIDTH 5
#define WARP_SIZE_WIDTH_MASK 0x1f
typedef
unsigned
warp_mask_t
;
#endif
template
<
typename
T
>
__inline__
__device__
T
WarpReduceSum
(
T
val
,
unsigned
lane_mask
)
{
__inline__
__device__
T
WarpReduceSum
(
T
val
,
warp_mask_t
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
+=
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
);
...
...
@@ -180,10 +196,10 @@ __inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
/* Calculate the sum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
BlockReduceSum
(
T
val
,
unsigned
mask
)
{
__inline__
__device__
T
BlockReduceSum
(
T
val
,
warp_mask_t
mask
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
WARP_SIZE_WIDTH_MASK
;
int
wid
=
threadIdx
.
x
>>
WARP_SIZE_WIDTH
;
val
=
WarpReduceSum
<
T
>
(
val
,
mask
);
...
...
@@ -193,7 +209,7 @@ __inline__ __device__ T BlockReduceSum(T val, unsigned mask) {
__syncthreads
();
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
WARP_SIZE_WIDTH
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
static_cast
<
T
>
(
0.0
f
);
val
=
WarpReduceSum
<
T
>
(
val
,
mask
);
...
...
@@ -208,8 +224,8 @@ __inline__ __device__ T WarpReduceSumV2(T *val) {
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
val
[
i
]
+=
__shfl_xor_sync
(
FINAL_MASK
,
val
[
i
],
mask
,
32
);
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
val
[
i
]
+=
__shfl_xor_sync
(
FINAL_MASK
,
val
[
i
],
mask
,
WARP_SIZE
);
}
return
(
T
)(
0.0
f
);
}
...
...
@@ -217,8 +233,8 @@ __inline__ __device__ T WarpReduceSumV2(T *val) {
template
<
typename
T
,
int
NUM
>
__inline__
__device__
T
BlockReduceSumV2
(
T
*
val
)
{
static
__shared__
T
shared
[
NUM
][
33
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
WARP_SIZE_WIDTH_MASK
;
int
wid
=
threadIdx
.
x
>>
WARP_SIZE_WIDTH
;
WarpReduceSumV2
<
T
,
NUM
>
(
val
);
...
...
@@ -231,7 +247,7 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
__syncthreads
();
bool
is_mask
=
threadIdx
.
x
<
(
blockDim
.
x
/
32.
f
);
bool
is_mask
=
threadIdx
.
x
<
(
blockDim
.
x
/
static_cast
<
float
>
(
WARP_SIZE
)
);
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
val
[
i
]
=
is_mask
?
shared
[
i
][
lane
]
:
(
T
)(
0.0
f
);
...
...
@@ -241,7 +257,7 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
}
template
<
typename
T
>
__inline__
__device__
T
WarpReduceMax
(
T
val
,
unsigned
lane_mask
)
{
__inline__
__device__
T
WarpReduceMax
(
T
val
,
warp_mask_t
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
=
max
(
val
,
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
));
...
...
@@ -256,14 +272,15 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) {
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
#pragma unroll
for
(
int
mask
=
16
;
mask
>
0
;
mask
>>=
1
)
val
[
i
]
=
max
(
val
[
i
],
__shfl_xor_sync
(
FINAL_MASK
,
val
[
i
],
mask
,
32
));
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
val
[
i
]
=
max
(
val
[
i
],
__shfl_xor_sync
(
FINAL_MASK
,
val
[
i
],
mask
,
WARP_SIZE
));
}
return
(
T
)(
0.0
f
);
}
template
<
typename
T
>
__inline__
__device__
T
WarpReduceMin
(
T
val
,
unsigned
lane_mask
)
{
__inline__
__device__
T
WarpReduceMin
(
T
val
,
warp_mask_t
lane_mask
)
{
for
(
int
mask
=
HALF_WARP
;
mask
>
0
;
mask
>>=
1
)
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val
=
min
(
val
,
__shfl_xor_sync
(
lane_mask
,
val
,
mask
,
warpSize
));
...
...
@@ -276,7 +293,7 @@ __inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
/* Calculate the minimum of all elements in a warp when actual quantity of
* threads are less than warpSize.*/
template
<
typename
T
>
__inline__
__device__
T
PartialWarpReduceMin
(
T
val
,
unsigned
lane_mask
)
{
__inline__
__device__
T
PartialWarpReduceMin
(
T
val
,
warp_mask_t
lane_mask
)
{
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
T
warp_val
=
__shfl_sync
(
lane_mask
,
val
,
0
,
warpSize
);
#else
...
...
@@ -297,10 +314,10 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
/* Calculate the maximum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
BlockReduceMax
(
T
val
,
unsigned
mask
)
{
__inline__
__device__
T
BlockReduceMax
(
T
val
,
warp_mask_t
mask
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
WARP_SIZE_WIDTH_MASK
;
int
wid
=
threadIdx
.
x
>>
WARP_SIZE_WIDTH
;
val
=
WarpReduceMax
(
val
,
mask
);
...
...
@@ -309,7 +326,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
__syncthreads
();
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
WARP_SIZE_WIDTH
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
-
1e10
f
;
val
=
WarpReduceMax
(
val
,
mask
);
...
...
@@ -318,9 +335,9 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
template
<
typename
T
,
int
NUM
>
__inline__
__device__
T
BlockReduceMaxV2
(
T
*
val
)
{
static
__shared__
T
shared
[
32
][
NUM
];
int
lane
=
threadIdx
.
x
&
0x1f
;
// in-warp idx
int
wid
=
threadIdx
.
x
>>
5
;
// warp idx
static
__shared__
T
shared
[
WARP_SIZE
][
NUM
];
int
lane
=
threadIdx
.
x
&
WARP_SIZE_WIDTH_MASK
;
// in-warp idx
int
wid
=
threadIdx
.
x
>>
WARP_SIZE_WIDTH
;
// warp idx
WarpReduceMaxV2
<
T
,
NUM
>
(
val
);
// get maxx in each warp
...
...
@@ -335,7 +352,7 @@ __inline__ __device__ T BlockReduceMaxV2(T *val) {
// Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
// blockDim.x is not divided by 32
bool
is_mask
=
threadIdx
.
x
<
(
blockDim
.
x
/
32.
f
);
bool
is_mask
=
threadIdx
.
x
<
(
blockDim
.
x
/
static_cast
<
float
>
(
WARP_SIZE
)
);
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM
;
i
++
)
{
val
[
i
]
=
is_mask
?
shared
[
lane
][
i
]
:
(
T
)
-
1e20
f
;
...
...
@@ -347,17 +364,17 @@ __inline__ __device__ T BlockReduceMaxV2(T *val) {
/* Calculate the minimum of all elements in a block */
template
<
typename
T
>
__inline__
__device__
T
BlockReduceMin
(
T
val
,
unsigned
mask
)
{
__inline__
__device__
T
BlockReduceMin
(
T
val
,
warp_mask_t
mask
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
WARP_SIZE_WIDTH_MASK
;
int
wid
=
threadIdx
.
x
>>
WARP_SIZE_WIDTH
;
val
=
WarpReduceMin
(
val
,
mask
);
if
(
lane
==
0
)
shared
[
wid
]
=
val
;
__syncthreads
();
// align block_span to warpSize
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
5
;
int
block_span
=
(
blockDim
.
x
+
warpSize
-
1
)
>>
WARP_SIZE_WIDTH
;
val
=
(
lane
<
block_span
)
?
shared
[
lane
]
:
1e10
f
;
val
=
WarpReduceMin
(
val
,
mask
);
...
...
@@ -367,11 +384,11 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
/* Calculate the minimum of all elements in a warp when actual quantity of
* threads are less than warpSize.*/
template
<
typename
T
>
__inline__
__device__
T
PartialBlockReduceMin
(
T
val
,
unsigned
mask
)
{
__inline__
__device__
T
PartialBlockReduceMin
(
T
val
,
warp_mask_t
mask
)
{
static
__shared__
T
shared
[
WARP_SIZE
];
static
__shared__
T
min_value
;
int
lane
=
threadIdx
.
x
&
0x1f
;
int
wid
=
threadIdx
.
x
>>
5
;
int
lane
=
threadIdx
.
x
&
WARP_SIZE_WIDTH_MASK
;
int
wid
=
threadIdx
.
x
>>
WARP_SIZE_WIDTH
;
val
=
PartialWarpReduceMin
(
val
,
mask
);
if
(
lane
==
0
)
shared
[
wid
]
=
val
;
...
...
test/legacy_test/test_activation_op.py
浏览文件 @
effebd41
...
...
@@ -278,7 +278,8 @@ class TestSigmoid_ZeroDim(TestSigmoid):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
not
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_rocm
(),
"core is not compiled with CUDA"
,
)
class
TestSigmoidBF16
(
OpTest
):
def
setUp
(
self
):
...
...
@@ -1237,7 +1238,8 @@ class TestSqrt_ZeroDim(TestSqrt):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
not
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_rocm
(),
"core is not compiled with CUDA"
,
)
class
TestSqrtBF16
(
OpTest
):
def
setUp
(
self
):
...
...
@@ -3060,7 +3062,8 @@ class TestSquare_ZeroDim(TestSquare):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
not
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_rocm
(),
"core is not compiled with CUDA"
,
)
class
TestSquareBF16
(
OpTest
):
def
setUp
(
self
):
...
...
@@ -3350,7 +3353,8 @@ class TestSoftplus_ZeroDim(TestSoftplus):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
not
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_rocm
(),
"core is not compiled with CUDA"
,
)
class
TestSoftplusBF16
(
OpTest
):
def
setUp
(
self
):
...
...
test/legacy_test/test_scale_op.py
浏览文件 @
effebd41
...
...
@@ -154,6 +154,9 @@ class TestScaleFp16Op(TestScaleOp):
self
.
check_grad
([
"X"
],
"Out"
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_rocm
(),
"core is not compiled with CUDA"
)
class
TestScaleBF16Op
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"scale"
...
...
test/legacy_test/test_softmax_op.py
浏览文件 @
effebd41
...
...
@@ -392,7 +392,8 @@ class TestSoftmaxFP16CUDNNOp2(TestSoftmaxFP16CUDNNOp):
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
not
core
.
is_compiled_with_cuda
()
or
core
.
is_compiled_with_rocm
(),
"core is not compiled with CUDA"
,
)
class
TestSoftmaxBF16Op
(
OpTest
):
def
setUp
(
self
):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录