BaiXuePrincess / Paddle (forked from PaddlePaddle / Paddle)
Commit 39210ed0 (unverified)
Authored Jan 10, 2023 by MarDino; committed via GitHub on Jan 10, 2023.
Refine name style and MoeKernel (#49432)
Parent commit: c0d6ec63
Showing 15 changed files with 338 additions and 497 deletions (+338, −497).
paddle/fluid/framework/details/nan_inf_utils_detail.cu                    +3   −3
paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu       +2   −2
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu      +2   −36
paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu  +2   −2
paddle/fluid/operators/math/bert_encoder_functor.cu                       +20  −97
paddle/fluid/operators/optimizers/lars_momentum_op.cu                     +6   −6
paddle/phi/kernels/funcs/math_cuda_utils.h                                +92  −12
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu                           +13  −36
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h                       +183 −288
paddle/phi/kernels/gpu/dist_kernel.cu                                     +3   −3
paddle/phi/kernels/gpu/interpolate_grad_kernel.cu                         +6   −6
paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu              +1   −1
paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu                   +2   −2
paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu                      +1   −1
paddle/phi/kernels/sparse/gpu/softmax_kernel.cu                           +2   −2
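The bulk of this commit renames the warp/block reduction helpers in phi::funcs from lowerCamelCase to UpperCamelCase (for example blockReduceSum becomes BlockReduceSum) and updates every call site; the MoE gating softmax is additionally simplified to operate on a flat (batch_size, seq_len) score matrix. A minimal sketch of what a call site looks like after the rename; the kernel and variable names here are hypothetical, only the header path and the phi::funcs helper come from the diff below:

```cpp
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

// Hypothetical kernel: sums one row per block using the renamed helper.
template <typename T>
__global__ void RowSumKernel(const T* x, T* out, int n) {
  T partial = 0;
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    partial += x[blockIdx.x * n + i];
  }
  // Before this commit the call was phi::funcs::blockReduceSum<T>(partial, mask).
  T total = phi::funcs::BlockReduceSum<T>(partial, 0xffffffff);
  if (threadIdx.x == 0) {
    out[blockIdx.x] = total;  // thread 0 writes the block-wide sum
  }
}
```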
paddle/fluid/framework/details/nan_inf_utils_detail.cu

@@ -216,9 +216,9 @@ __device__ void BlockReduceMaxMinAndWrite(const T max_value,
   if (max_ptr && min_ptr && mean_ptr) {
     __syncthreads();
-    T block_max_value = phi::funcs::blockReduceMax<T>(max_value, FINAL_MASK);
-    T block_min_value = phi::funcs::blockReduceMin<T>(min_value, FINAL_MASK);
-    T block_mean_value = phi::funcs::blockReduceSum<T>(mean_value, FINAL_MASK);
+    T block_max_value = phi::funcs::BlockReduceMax<T>(max_value, FINAL_MASK);
+    T block_min_value = phi::funcs::BlockReduceMin<T>(min_value, FINAL_MASK);
+    T block_mean_value = phi::funcs::BlockReduceSum<T>(mean_value, FINAL_MASK);
     if (threadIdx.x == 0) {
       max_ptr[offset] = block_max_value;
paddle/fluid/inference/tensorrt/plugin/merge_layernorm_op_plugin.cu

@@ -68,7 +68,7 @@ __global__ void merge_layernorm_v2(T *out,
     }
   }
-  mean = phi::funcs::blockReduceSum<float>(sum, FINAL_MASK);
+  mean = phi::funcs::BlockReduceSum<float>(sum, FINAL_MASK);
   if (tid == 0) {
     s_mean = mean / n;
   }
@@ -84,7 +84,7 @@ __global__ void merge_layernorm_v2(T *out,
     }
   }
-  variance = phi::funcs::blockReduceSum<float>(var, FINAL_MASK);
+  variance = phi::funcs::BlockReduceSum<float>(var, FINAL_MASK);
   if (tid == 0) {
     s_variance = rsqrtf(variance / n + layernorm_eps);
   }
paddle/fluid/inference/tensorrt/plugin/preln_residual_bias_plugin.cu

@@ -26,6 +26,7 @@
 #include "paddle/fluid/operators/fused/fused_layernorm_residual_dropout_bias.h"
 #include "paddle/fluid/operators/layer_norm_kernel.cu.h"
 #include "paddle/fluid/operators/math/bert_encoder_functor.h"
+#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
 namespace paddle {
 namespace inference {
@@ -33,41 +34,6 @@ namespace tensorrt {
 namespace plugin {
 #ifdef TRT_PLUGIN_FP16_AVALIABLE
 #define FINAL_MASK 0xffffffff
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceSumV2(T *val) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
-  }
-  return (T)(0.0f);
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceSumV2(T *val) {
-  static __shared__ T shared[NUM][33];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
-
-  warpReduceSumV2<T, NUM>(val);
-
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
-  __syncthreads();
-
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
-  warpReduceSumV2<T, NUM>(val);
-  return (T)0.0f;
-}
-
 template <int UNROLL_FACTOR>
 __global__ void generalAddBiasResidualLayerNormOpt2(
@@ -119,7 +85,7 @@ __global__ void generalAddBiasResidualLayerNormOpt2(
   float sums[2];
   sums[0] = x_sum;
   sums[1] = x2_sum;
-  blockReduceSumV2<float, 2>(sums);
+  phi::funcs::BlockReduceSumV2<float, 2>(sums);
   if (threadIdx.x == 0) {
     s_mean = sums[0] / n / 2;
paddle/fluid/inference/tensorrt/plugin/skip_merge_layernorm_op_plugin.cu

@@ -70,7 +70,7 @@ __global__ void merge_layernorm_v2(T *out,
     }
   }
-  mean = phi::funcs::blockReduceSum<float>(sum, FINAL_MASK);
+  mean = phi::funcs::BlockReduceSum<float>(sum, FINAL_MASK);
   if (tid == 0) {
     s_mean = mean / n;
   }
@@ -86,7 +86,7 @@ __global__ void merge_layernorm_v2(T *out,
     }
   }
-  variance = phi::funcs::blockReduceSum<float>(var, FINAL_MASK);
+  variance = phi::funcs::BlockReduceSum<float>(var, FINAL_MASK);
   if (tid == 0) {
     s_variance = rsqrtf(variance / n + layernorm_eps);
   }
paddle/fluid/operators/math/bert_encoder_functor.cu

@@ -269,10 +269,10 @@ __global__ void SoftmaxKernelWithEltadd(T *qk_buf_,
          ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
                               bias_qk_[threadIdx.x + qk_offset])
          : -1e20f;
-  float max_val = phi::funcs::blockReduceMax<float>(tmp, mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
   float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
-  float sum_val = phi::funcs::blockReduceSum<float>(qk_tmp, mask);
+  float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
   if (threadIdx.x < seq_len)
     qk_buf_[threadIdx.x + qk_offset] = (T)(qk_tmp / sum_val);
@@ -295,10 +295,10 @@ __global__ void SoftmaxKernelWithEltadd<half>(half *qk_buf_,
          ? static_cast<float>(qk_buf_[threadIdx.x + qk_offset] +
                               bias_qk_[threadIdx.x + qk_offset])
          : -1e20f;
-  float max_val = phi::funcs::blockReduceMax<float>(tmp, mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(tmp, mask);
   float qk_tmp = threadIdx.x < seq_len ? __expf(tmp - max_val) : 0.0f;
-  float sum_val = phi::funcs::blockReduceSum<float>(qk_tmp, mask);
+  float sum_val = phi::funcs::BlockReduceSum<float>(qk_tmp, mask);
   if (threadIdx.x < seq_len)
     qk_buf_[threadIdx.x + qk_offset] = (half)(qk_tmp / sum_val);
@@ -321,12 +321,12 @@ __global__ void SoftmaxKernelWithEltadd2(T *qk_buf_,
          ? phi::funcs::ToFloat2<T>(qk_buf_[idx + qk_offset] +
                                    bias_qk_[idx + qk_offset])
          : make_float2(-1e20f, -1e20f);
-  float max_val = phi::funcs::blockReduceMax<float>(max(tmp.x, tmp.y), mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
   float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
                                               __expf(tmp.y - max_val))
                                 : make_float2(0.f, 0.f);
-  float sum_val =
-      phi::funcs::blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
+  float sum_val =
+      phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
   if (idx < seq_len) {
     qk_buf_[idx + qk_offset] =
@@ -353,12 +353,12 @@ __global__ void SoftmaxKernelWithEltadd2<half2>(half2 *qk_buf_,
          ? phi::funcs::ToFloat2<half2>(qk_buf_[idx + qk_offset] +
                                        bias_qk_[idx + qk_offset])
          : make_float2(-1e20f, -1e20f);
-  float max_val = phi::funcs::blockReduceMax<float>(max(tmp.x, tmp.y), mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(max(tmp.x, tmp.y), mask);
   float2 qk_tmp = idx < seq_len ? make_float2(__expf(tmp.x - max_val),
                                               __expf(tmp.y - max_val))
                                 : make_float2(0.f, 0.f);
-  float sum_val =
-      phi::funcs::blockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
+  float sum_val =
+      phi::funcs::BlockReduceSum<float>(qk_tmp.x + qk_tmp.y, mask) + 1e-6f;
   if (idx < seq_len) {
     qk_buf_[idx + qk_offset] =
@@ -386,14 +386,14 @@ __global__ void SoftmaxKernelWithEltaddForLarge(T *qk_buf,
                      bias_qk[threadIdx.x + i + qk_offset]
               : stride_max;
   }
-  T max_val = phi::funcs::blockReduceMax<T>(stride_max, mask);
+  T max_val = phi::funcs::BlockReduceMax<T>(stride_max, mask);
   T stride_sum = 0.f;
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     stride_sum += __expf(qk_buf[threadIdx.x + i + qk_offset] +
                          bias_qk[threadIdx.x + i + qk_offset] - max_val);
   }
-  T sum_val = phi::funcs::blockReduceSum<T>(stride_sum, mask);
+  T sum_val = phi::funcs::BlockReduceSum<T>(stride_sum, mask);
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     qk_buf[threadIdx.x + i + qk_offset] =
@@ -422,7 +422,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
                             bias_qk[threadIdx.x + i + qk_offset]);
     stride_max = tmp > stride_max ? tmp : stride_max;
   }
-  float max_val = phi::funcs::blockReduceMax<float>(stride_max, mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(stride_max, mask);
   float stride_sum = 0.f;
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -430,7 +430,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge(half *qk_buf,
                             bias_qk[threadIdx.x + i + qk_offset]);
     stride_sum += __expf(tmp - max_val);
   }
-  float sum_val = phi::funcs::blockReduceSum<float>(stride_sum, mask);
+  float sum_val = phi::funcs::BlockReduceSum<float>(stride_sum, mask);
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
     float tmp =
@@ -461,7 +461,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
     stride_max.y = max(stride_max.y, cur.y);
   }
-  float max_val = phi::funcs::blockReduceMax<float>(
-      max(stride_max.x, stride_max.y), mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(
+      max(stride_max.x, stride_max.y), mask);
   float2 stride_sum = make_float2(0.f, 0.f);
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -472,7 +472,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(T *qk_buf_,
   }
-  float sum_val = phi::funcs::blockReduceSum<float>(
-                      stride_sum.x + stride_sum.y, mask) + 1e-6f;
+  float sum_val = phi::funcs::BlockReduceSum<float>(
+                      stride_sum.x + stride_sum.y, mask) + 1e-6f;
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -507,7 +507,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
     stride_max.y = max(stride_max.y, cur.y);
   }
-  float max_val = phi::funcs::blockReduceMax<float>(
-      max(stride_max.x, stride_max.y), mask);
+  float max_val = phi::funcs::BlockReduceMax<float>(
+      max(stride_max.x, stride_max.y), mask);
   float2 stride_sum = make_float2(0.f, 0.f);
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -519,7 +519,7 @@ __global__ void SoftmaxKernelWithEltaddForLarge2(half2 *qk_buf_,
   }
-  float sum_val = phi::funcs::blockReduceSum<float>(
-                      stride_sum.x + stride_sum.y, mask) + 1e-6f;
+  float sum_val = phi::funcs::BlockReduceSum<float>(
+                      stride_sum.x + stride_sum.y, mask) + 1e-6f;
   for (int i = 0; threadIdx.x + i < seq_len; i += blockDim.x) {
@@ -573,83 +573,6 @@ inline __device__ T hadd2(T a, T b) {
   return __hadd2(a, b);
 }
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceSumV2(T *val) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
-  }
-  return (T)(0.0f);
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceSumV2(T *val) {
-  static __shared__ T shared[NUM][33];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
-  warpReduceSumV2<T, NUM>(val);
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
-  __syncthreads();
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
-  warpReduceSumV2<T, NUM>(val);
-  return (T)0.0f;
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceMaxV2(T *val) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
-  }
-  return (T)(0.0f);
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceMaxV2(T *val) {
-  static __shared__ T shared[32][NUM];
-  int lane = threadIdx.x & 0x1f;  // in-warp idx
-  int wid = threadIdx.x >> 5;     // warp idx
-  warpReduceMaxV2<T, NUM>(val);   // get maxx in each warp
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[wid][i] = val[i];
-    }
-  }
-  __syncthreads();
-  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-  // blockDim.x is not divided by 32
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
-  }
-  warpReduceMaxV2<T, NUM>(val);
-  return (T)0.0f;
-}
-
 template <typename T, int ITEMS_PER_THREAD, int NUM>
 __global__ void softmax_kernel_with_mask(T *qk_buf_,
                                          const T *attr_mask,
@@ -715,9 +638,9 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
     }
     if (blockDim.x <= 32) {
-      warpReduceMaxV2<float, NUM>(local_max);
+      phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
     } else {
-      blockReduceMaxV2<float, NUM>(local_max);
+      phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
     }
     if (threadIdx.x == 0) {
@@ -750,9 +673,9 @@ __global__ void softmax_kernel_with_mask(T *qk_buf_,
     }
     if (blockDim.x <= 32) {
-      warpReduceSumV2<float, NUM>(local_sum);
+      phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
     } else {
-      blockReduceSumV2<float, NUM>(local_sum);
+      phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
     }
     if (threadIdx.x == 0) {
paddle/fluid/operators/optimizers/lars_momentum_op.cu

@@ -187,8 +187,8 @@ __global__ void L2NormKernel(
     g_tmp += (tmp1 * tmp1);
     tid += grid_stride;
   }
-  p_tmp = phi::funcs::blockReduceSum<MT>(p_tmp, FINAL_MASK);
-  g_tmp = phi::funcs::blockReduceSum<MT>(g_tmp, FINAL_MASK);
+  p_tmp = phi::funcs::BlockReduceSum<MT>(p_tmp, FINAL_MASK);
+  g_tmp = phi::funcs::BlockReduceSum<MT>(g_tmp, FINAL_MASK);
   if (threadIdx.x == 0) {
     p_buffer[blockIdx.x] = p_tmp;
@@ -198,8 +198,8 @@ __global__ void L2NormKernel(
   cg->sync();  // Grid sync for writring partial result to gloabl memory
   MT p_part_sum = threadIdx.x < gridDim.x ? p_buffer[threadIdx.x] : 0;
   MT g_part_sum = threadIdx.x < gridDim.x ? g_buffer[threadIdx.x] : 0;
-  MT tmp0 = phi::funcs::blockReduceSum<MT>(p_part_sum, FINAL_MASK);
-  MT tmp1 = phi::funcs::blockReduceSum<MT>(g_part_sum, FINAL_MASK);
+  MT tmp0 = phi::funcs::BlockReduceSum<MT>(p_part_sum, FINAL_MASK);
+  MT tmp1 = phi::funcs::BlockReduceSum<MT>(g_part_sum, FINAL_MASK);
   if (threadIdx.x == 0) {
     s_buffer[0] = tmp0;
     s_buffer[1] = tmp1;
@@ -393,8 +393,8 @@ __global__ void MomentumLarsKernel(const T* param,
   MT grad_part_norm = threadIdx.x < thresh ? g_buffer[threadIdx.x] : 0;
   __syncthreads();
-  MT param_norm =
-      Sqrt(phi::funcs::blockReduceSum<MT>(param_part_norm, FINAL_MASK));
-  MT grad_norm = Sqrt(rescale_grad_pow * phi::funcs::blockReduceSum<MT>(
-                                             grad_part_norm, FINAL_MASK));
+  MT param_norm =
+      Sqrt(phi::funcs::BlockReduceSum<MT>(param_part_norm, FINAL_MASK));
+  MT grad_norm = Sqrt(rescale_grad_pow * phi::funcs::BlockReduceSum<MT>(
+                                             grad_part_norm, FINAL_MASK));
 #endif
   MomentumUpdate<T, MT>(param,
paddle/phi/kernels/funcs/math_cuda_utils.h

@@ -168,7 +168,7 @@ struct KeyValuePair<half> {
 #define WARP_SIZE 32

 template <typename T>
-__inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
+__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
@@ -180,12 +180,12 @@ __inline__ __device__ T warpReduceSum(T val, unsigned lane_mask) {
 /* Calculate the sum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceSum(T val, unsigned mask) {
+__inline__ __device__ T BlockReduceSum(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
-  val = warpReduceSum<T>(val, mask);
+  val = WarpReduceSum<T>(val, mask);
   __syncthreads();
   if (lane == 0) shared[wid] = val;
@@ -195,13 +195,53 @@ __inline__ __device__ T blockReduceSum(T val, unsigned mask) {
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
   val = (lane < block_span) ? shared[lane] : static_cast<T>(0.0f);
-  val = warpReduceSum<T>(val, mask);
+  val = WarpReduceSum<T>(val, mask);
   return val;
 }

+/*
+WarpReduce multi values.
+*/
+template <typename T, int NUM>
+__inline__ __device__ T WarpReduceSumV2(T *val) {
+#pragma unroll
+  for (int i = 0; i < NUM; i++) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1)
+      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
+  }
+  return (T)(0.0f);
+}
+
+template <typename T, int NUM>
+__inline__ __device__ T BlockReduceSumV2(T *val) {
+  static __shared__ T shared[NUM][33];
+  int lane = threadIdx.x & 0x1f;
+  int wid = threadIdx.x >> 5;
+
+  WarpReduceSumV2<T, NUM>(val);
+
+  if (lane == 0) {
+#pragma unroll
+    for (int i = 0; i < NUM; i++) {
+      shared[i][wid] = val[i];
+    }
+  }
+  __syncthreads();
+
+  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
+#pragma unroll
+  for (int i = 0; i < NUM; i++) {
+    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
+  }
+  WarpReduceSumV2<T, NUM>(val);
+  return (T)0.0f;
+}
+
 template <typename T>
-__inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
+__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
@@ -211,8 +251,19 @@ __inline__ __device__ T warpReduceMax(T val, unsigned lane_mask) {
   return val;
 }

+template <typename T, int NUM>
+__inline__ __device__ T WarpReduceMaxV2(T *val) {
+#pragma unroll
+  for (int i = 0; i < NUM; i++) {
+#pragma unroll
+    for (int mask = 16; mask > 0; mask >>= 1)
+      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
+  }
+  return (T)(0.0f);
+}
+
 template <typename T>
-__inline__ __device__ T warpReduceMin(T val, unsigned lane_mask) {
+__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
   for (int mask = HALF_WARP; mask > 0; mask >>= 1)
 #if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
     val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
@@ -246,12 +297,12 @@ __inline__ __device__ T PartialWarpReduceMin(T val, unsigned lane_mask) {
 /* Calculate the maximum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMax(T val, unsigned mask) {
+__inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
-  val = warpReduceMax(val, mask);
+  val = WarpReduceMax(val, mask);
   if (lane == 0) shared[wid] = val;
@@ -260,26 +311,55 @@ __inline__ __device__ T blockReduceMax(T val, unsigned mask) {
   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
   val = (lane < block_span) ? shared[lane] : -1e10f;
-  val = warpReduceMax(val, mask);
+  val = WarpReduceMax(val, mask);
   return val;
 }

+template <typename T, int NUM>
+__inline__ __device__ T BlockReduceMaxV2(T *val) {
+  static __shared__ T shared[32][NUM];
+  int lane = threadIdx.x & 0x1f;  // in-warp idx
+  int wid = threadIdx.x >> 5;     // warp idx
+
+  WarpReduceMaxV2<T, NUM>(val);  // get maxx in each warp
+
+  if (lane == 0) {  // record in-warp maxx by warp Idx
+#pragma unroll
+    for (int i = 0; i < NUM; i++) {
+      shared[wid][i] = val[i];
+    }
+  }
+  __syncthreads();
+
+  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
+  // blockDim.x is not divided by 32
+  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
+#pragma unroll
+  for (int i = 0; i < NUM; i++) {
+    val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
+  }
+  WarpReduceMaxV2<T, NUM>(val);
+  return (T)0.0f;
+}
+
 /* Calculate the minimum of all elements in a block */
 template <typename T>
-__inline__ __device__ T blockReduceMin(T val, unsigned mask) {
+__inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
   static __shared__ T shared[WARP_SIZE];
   int lane = threadIdx.x & 0x1f;
   int wid = threadIdx.x >> 5;
-  val = warpReduceMin(val, mask);
+  val = WarpReduceMin(val, mask);
   if (lane == 0) shared[wid] = val;
   __syncthreads();

   // align block_span to warpSize
   int block_span = (blockDim.x + warpSize - 1) >> 5;
   val = (lane < block_span) ? shared[lane] : 1e10f;
-  val = warpReduceMin(val, mask);
+  val = WarpReduceMin(val, mask);
   return val;
 }
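With this change the multi-value reducers (WarpReduceSumV2, BlockReduceSumV2, WarpReduceMaxV2, BlockReduceMaxV2) live once in phi::funcs instead of being re-declared in each .cu file. They reduce NUM per-thread values in one pass and leave the results in the val array; the call sites in this commit read the result from thread 0. A minimal usage sketch, assuming a hypothetical kernel that needs both a row sum and a row sum of squares (only the phi::funcs API and header path are taken from the diff above):

```cpp
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

// Hypothetical kernel: one row per block, reduces two partial sums with a
// single BlockReduceSumV2 call, the same pattern used by
// generalAddBiasResidualLayerNormOpt2 in this commit.
__global__ void RowMeanAndMeanSq(const float* x, float* mean, float* mean_sq,
                                 int n) {
  float sums[2] = {0.f, 0.f};
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    float v = x[blockIdx.x * n + i];
    sums[0] += v;
    sums[1] += v * v;
  }
  phi::funcs::BlockReduceSumV2<float, 2>(sums);  // both entries reduced across the block
  if (threadIdx.x == 0) {
    mean[blockIdx.x] = sums[0] / n;
    mean_sq[blockIdx.x] = sums[1] / n;
  }
}
```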
paddle/phi/kernels/fusion/cutlass/moe_kernel.cu

@@ -160,30 +160,17 @@ void InitExpertChoiceRouteKernelLauncher(
           <<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer),   \
                                        (const half*)attr_mask,            \
                                        batch_size,                        \
-                                       head_num,                          \
-                                       seq_len_1,                         \
-                                       seq_len_2,                         \
-                                       (const half)scalar);               \
+                                       seq_len);                          \
     } else {                                                              \
       softmax_kernel_v4_half2<__half, ITEMS_PER_THREAD>                   \
           <<<grid, block, 0, stream>>>(reinterpret_cast<half*>(buffer),   \
                                        (const half*)attr_mask,            \
                                        batch_size,                        \
-                                       head_num,                          \
-                                       seq_len_1,                         \
-                                       seq_len_2,                         \
-                                       (const half)scalar);               \
+                                       seq_len);                          \
     }                                                                     \
   } else {                                                                \
-    softmax_kernel_v4<ITEMS_PER_THREAD, T>                                \
-        <<<grid, block, 0, stream>>>(buffer,                              \
-                                     buffer_src,                          \
-                                     attr_mask,                           \
-                                     batch_size,                          \
-                                     head_num,                            \
-                                     seq_len_1,                           \
-                                     seq_len_2,                           \
-                                     scalar);                             \
+    softmax_kernel_v4<ITEMS_PER_THREAD, T><<<grid, block, 0, stream>>>(   \
+        buffer, buffer_src, attr_mask, batch_size, seq_len);              \
   }

 template <typename T>
@@ -191,19 +178,16 @@ void invokeMaskedSoftMax(T* buffer,
                          const T* buffer_src,
                          const T* attr_mask,
                          const int batch_size,
-                         const int seq_len_1,
-                         const int seq_len_2,
-                         const int head_num,
-                         const T scalar,
+                         const int seq_len,
                          cudaStream_t stream) {
-  // NOTE: attention scores shape (batch_size, head_num, seq_len_1, seq_len_2)
-  dim3 grid(seq_len_1, batch_size, head_num);
-  if (batch_size * head_num > 360) {
-    grid.x = ceil(static_cast<float>(seq_len_1) / 32.0f);
+  // NOTE: attention scores shape (batch_size, seq_len)
+  dim3 grid(1, batch_size, 1);
+  if (batch_size > 360) {
+    grid.x = ceil(static_cast<float>(1) / 32.0f);
   }
-  bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len_2 % 2 == 0;
-  dim3 block((seq_len_2 / (is_half2 ? 2 : 1) + 31) / 32 * 32);
+  bool is_half2 = sizeof(T) == 2 && sizeof(T) == 2 && seq_len % 2 == 0;
+  dim3 block((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32);
   if (block.x > 2048 && block.x <= 4096) {
     SOFTMAX_KERNEL(4)
@@ -766,26 +750,19 @@ void MoeKernel(const Context& ctx,
                             k,
                             batch_size,
                             ctx.stream());
-  T scalar = (T)1.0f;
   if (IS_FP16) {
     invokeMaskedSoftMax<__half>(
         reinterpret_cast<__half*>(gating_output),
         reinterpret_cast<const __half*>(gating_output),
        reinterpret_cast<const __half*>(attr_mask),
         /*batch_size=*/num_rows,
-        /*seq_len_1=*/1,
-        /*seq_len_2=*/num_experts,
-        /*head_num=*/1,
-        *reinterpret_cast<const __half*>(&scalar),
+        /*seq_len=*/num_experts,
         ctx.stream());
   } else {
     invokeMaskedSoftMax<float>(
         reinterpret_cast<float*>(gating_output),
         reinterpret_cast<const float*>(gating_output),
         reinterpret_cast<const float*>(attr_mask),
         /*batch_size=*/num_rows,
-        /*seq_len_1=*/1,
-        /*seq_len_2=*/num_experts,
-        /*head_num=*/1,
-        *reinterpret_cast<const float*>(&scalar),
+        /*seq_len=*/num_experts,
         ctx.stream());
   }
   InvokeTransposeAxis01(
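After this change the gating softmax is launched over a flat (batch_size, seq_len) score matrix, so the launch shape depends only on seq_len (the number of experts) and the row count. A small sketch of that shape calculation with worked numbers; the helper name and the example values (num_experts = 64, FP16 scores) are illustrative, only the rounding formula comes from the diff above:

```cpp
#include <cuda_runtime.h>

// Hypothetical helper mirroring the new invokeMaskedSoftMax launch logic:
// block width rounded up to a multiple of 32, halved when half2 applies.
dim3 GatingSoftmaxBlock(int seq_len, bool is_half2) {
  return dim3((seq_len / (is_half2 ? 2 : 1) + 31) / 32 * 32);
}

// Example: seq_len = num_experts = 64 with FP16 scores
//   is_half2 = true, block.x = (32 + 31) / 32 * 32 = 32
//   grid = (1, num_rows, 1): one block per gating row, and with
//   blockDim.x <= 32 the kernels take the warp-reduce path.
```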
paddle/phi/kernels/fusion/cutlass/moe_kernel_impl.h

@@ -26,87 +26,6 @@ static inline size_t AlignTo16(const size_t& input) {
   return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT);
 }
-/*
-WarpReduce multi values.
-TODO(zhengzekang): Add blocksize templates to reduce shared memory usage.
-*/
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceSumV2(T *val) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] += __shfl_xor_sync(FINAL_MASK, val[i], mask, 32);
-  }
-  return (T)(0.0f);
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceSumV2(T *val) {
-  static __shared__ T shared[NUM][33];
-  int lane = threadIdx.x & 0x1f;
-  int wid = threadIdx.x >> 5;
-  warpReduceSumV2<T, NUM>(val);
-  if (lane == 0) {
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[i][wid] = val[i];
-    }
-  }
-  __syncthreads();
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[i][lane] : (T)(0.0f);
-  }
-  warpReduceSumV2<T, NUM>(val);
-  return (T)0.0f;
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T warpReduceMaxV2(T *val) {
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-#pragma unroll
-    for (int mask = 16; mask > 0; mask >>= 1)
-      val[i] = max(val[i], __shfl_xor_sync(FINAL_MASK, val[i], mask, 32));
-  }
-  return (T)(0.0f);
-}
-
-template <typename T, int NUM>
-__inline__ __device__ T blockReduceMaxV2(T *val) {
-  static __shared__ T shared[32][NUM];
-  int lane = threadIdx.x & 0x1f;  // in-warp idx
-  int wid = threadIdx.x >> 5;     // warp idx
-  warpReduceMaxV2<T, NUM>(val);   // get maxx in each warp
-  if (lane == 0) {  // record in-warp maxx by warp Idx
-#pragma unroll
-    for (int i = 0; i < NUM; i++) {
-      shared[wid][i] = val[i];
-    }
-  }
-  __syncthreads();
-  // Modify from blockDim.x << 5 to blockDim.x / 32. to prevent
-  // blockDim.x is not divided by 32
-  bool is_mask = threadIdx.x < (blockDim.x / 32.f);
-#pragma unroll
-  for (int i = 0; i < NUM; i++) {
-    val[i] = is_mask ? shared[lane][i] : (T)-1e20f;
-  }
-  warpReduceMaxV2<T, NUM>(val);
-  return (T)0.0f;
-}
-
 class CubKeyValueSorter {
  public:
   CubKeyValueSorter();
@@ -311,65 +230,57 @@ __global__ void initialize_expert_choice_route_kernel(
 template <int ITEMS_PER_THREAD, typename T>
 __global__ void softmax_kernel_v4(
     T* qk_buf_,
-    const T* qk_buf_src,  // shape [batch_size, head_num, seq_len_1, seq_len_2]
-    const T* attr_mask,   // shape [batch_size, seq_len_1, seq_len_2]
+    const T* qk_buf_src,  // shape [batch_size, seq_len]
+    const T* attr_mask,   // shape [batch_size, seq_len]
     const int batch_size,
-    const int head_num,
-    const int seq_len_1,
-    const int seq_len_2,
-    const T scalar) {
+    const int seq_len) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
-  for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
   float data[ITEMS_PER_THREAD];
   int qk_offset;
   __shared__ float s_mean, s_max;
   float local_max = -1e20f;
-  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
-    qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
-                    seq_len_2 +
-                blockDim.x * i + threadIdx.x;
-    int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * seq_len_2 +
-                      blockDim.x * i + threadIdx.x;
+  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
+    qk_offset =
+        ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x;
+    int mask_offset = (blockIdx.y) * seq_len + blockDim.x * i + threadIdx.x;

     float qk = static_cast<float>(qk_buf_src[qk_offset]);
     float mask_val = static_cast<float>(__ldg(&attr_mask[mask_offset]));
     mask_val = (1.0f - mask_val) * -10000.0f;
-    data[i] = qk * static_cast<float>(scalar) + mask_val;
+    data[i] = qk + mask_val;
     local_max = fmax(local_max, data[i]);
   }

   float max_val =
       blockDim.x <= 32
-          ? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
-          : phi::funcs::blockReduceMax<float>(local_max, 0xffffffff);
+          ? phi::funcs::WarpReduceMax<float>(local_max, 0xFFFFFFFF)
+          : phi::funcs::BlockReduceMax<float>(local_max, 0xFFFFFFFF);
   if (threadIdx.x == 0) {
     s_max = max_val;
   }
   __syncthreads();

   float local_sum = 0;
-  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
+  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
     data[i] = __expf(data[i] - s_max);
     local_sum += data[i];
   }
   float sum_val =
       blockDim.x <= 32
-          ? phi::funcs::warpReduceSum<float>(local_sum, 0xffffffff)
-          : phi::funcs::blockReduceSum<float>(local_sum, 0xffffffff);
+          ? phi::funcs::WarpReduceSum<float>(local_sum, 0xFFFFFFFF)
+          : phi::funcs::BlockReduceSum<float>(local_sum, 0xFFFFFFFF);
   if (threadIdx.x == 0) {
     s_mean = sum_val + 1e-6f;
     s_mean = __fdividef(1.0f, s_mean);
   }
   __syncthreads();

-  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len_2; i++) {
-    qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
-                    seq_len_2 +
-                blockDim.x * i + threadIdx.x;
+  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
+    qk_offset =
+        ((blockIdx.y + blockIdx.z)) * seq_len + blockDim.x * i + threadIdx.x;
     qk_buf_[qk_offset] = (T)(data[i] * s_mean);
   }
-  }
 #endif
 }
@@ -378,77 +289,69 @@ template <typename T, int ITEMS_PER_THREAD>
 __global__ void softmax_kernel_v4_half2(T* qk_buf_,
                                         const T* attr_mask,
                                         const int batch_size,
-                                        const int head_num,
-                                        const int seq_len_1,
-                                        const int seq_len_2,
-                                        const T scalar) {
+                                        const int seq_len) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
   using T2 = half2;
   T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
   const T2* attr_mask_half2 = (const T2*)attr_mask;

-  for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x) {
   T2 data[ITEMS_PER_THREAD];
   int qk_offset;
   __shared__ float s_mean, s_max;
   float local_max = -1e20f;
-  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
-                  i < ITEMS_PER_THREAD;
+  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) &&
+                  i < ITEMS_PER_THREAD;
        i++) {
-    qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
-                    (seq_len_2 / 2) +
-                blockDim.x * i + threadIdx.x;
-    int mask_offset = (blockIdx.y * seq_len_1 + seq_id) * (seq_len_2 / 2) +
-                      blockDim.x * i + threadIdx.x;
+    qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) +
+                blockDim.x * i + threadIdx.x;
+    int mask_offset =
+        blockIdx.y * (seq_len / 2) + blockDim.x * i + threadIdx.x;

     T2 qk = qk_buf_half2[qk_offset];
     T2 mask_val = __ldg(&attr_mask_half2[mask_offset]);
     mask_val = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val),
                        __float2half2_rn(-10000.0f));
-    data[i] = __hadd2(__hmul2(qk, __half2half2(scalar)), mask_val);
+    data[i] = __hadd2(qk, mask_val);
     local_max = fmax(local_max,
                      fmax(static_cast<float>(data[i].x),
                           static_cast<float>(data[i].y)));
   }

   float max_val =
       blockDim.x <= 32
-          ? phi::funcs::warpReduceMax<float>(local_max, 0xFFFFFFFF)
-          : phi::funcs::blockReduceMax<float>(local_max, 0xFFFFFFFF);
+          ? phi::funcs::WarpReduceMax<float>(local_max, 0xFFFFFFFF)
+          : phi::funcs::BlockReduceMax<float>(local_max, 0xFFFFFFFF);
   if (threadIdx.x == 0) {
     s_max = max_val;
   }
   __syncthreads();

   float local_sum = 0;
-  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
+  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) &&
                   i < ITEMS_PER_THREAD;
        i++) {
     data[i] = h2exp(__hsub2(data[i], __float2half2_rn(s_max)));
     local_sum += static_cast<float>(data[i].x + data[i].y);
   }
   float sum_val =
       blockDim.x <= 32
-          ? phi::funcs::warpReduceSum<float>(local_sum, 0xFFFFFFFF)
-          : phi::funcs::blockReduceSum<float>(local_sum, 0xFFFFFFFF);
+          ? phi::funcs::WarpReduceSum<float>(local_sum, 0xFFFFFFFF)
+          : phi::funcs::BlockReduceSum<float>(local_sum, 0xFFFFFFFF);
   if (threadIdx.x == 0) {
     s_mean = sum_val + 1e-6f;
     s_mean = __fdividef(1.0f, s_mean);
   }
   __syncthreads();

-  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
+  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) &&
                   i < ITEMS_PER_THREAD;
        i++) {
-    qk_offset = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 + seq_id) *
-                    (seq_len_2 / 2) +
-                blockDim.x * i + threadIdx.x;
+    qk_offset = ((blockIdx.y + blockIdx.z)) * (seq_len / 2) +
+                blockDim.x * i + threadIdx.x;
     qk_buf_half2[qk_offset] = __hmul2(data[i], __float2half2_rn(s_mean));
   }
-  }
 #endif
 }
@@ -457,131 +360,123 @@ template <typename T, int ITEMS_PER_THREAD, int NUM>
 __global__ void softmax_kernel_v5_half2(T* qk_buf_,
                                         const T* attr_mask,
                                         const int batch_size,
-                                        const int head_num,
-                                        const int seq_len_1,
-                                        const int seq_len_2,
-                                        const T scalar) {
+                                        const int seq_len) {
 #if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700
   using T2 = half2;
   T2* qk_buf_half2 = reinterpret_cast<T2*>(qk_buf_);
   const T2* attr_mask_half2 = (const T2*)attr_mask;

-  for (int seq_id = blockIdx.x; seq_id < seq_len_1; seq_id += gridDim.x * NUM) {
   T2 data[NUM][ITEMS_PER_THREAD];
   int qk_offset[NUM];
   __shared__ float s_sum[NUM], s_max[NUM];
   float local_max[NUM];
 #pragma unroll
   for (int j = 0; j < NUM; j++) {
     local_max[j] = -1e20f;
   }

-  const int MAX_NUM =
-      min((seq_len_1 - seq_id + gridDim.x - 1) / gridDim.x, NUM);
-  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
+  const int MAX_NUM = min((1 + gridDim.x - 1) / gridDim.x, NUM);
+  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) &&
                   i < ITEMS_PER_THREAD;
        i++) {
     int mask_offset[NUM];
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
-      qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 +
-                      seq_id + j * gridDim.x) * (seq_len_2 / 2) +
-                     blockDim.x * i + threadIdx.x;
-      mask_offset[j] = (blockIdx.y * seq_len_1 + seq_id + j * gridDim.x) *
-                           (seq_len_2 / 2) +
-                       blockDim.x * i + threadIdx.x;
+      qk_offset[j] = ((blockIdx.y + blockIdx.z) + j * gridDim.x) *
+                         (seq_len / 2) +
+                     blockDim.x * i + threadIdx.x;
+      mask_offset[j] = (blockIdx.y + j * gridDim.x) * (seq_len / 2) +
+                       blockDim.x * i + threadIdx.x;
     }

     T2 mask_val[NUM];
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
       mask_val[j] = __ldg(&attr_mask_half2[mask_offset[j]]);
     }
     T2 qk[NUM];
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
       qk[j] = qk_buf_half2[qk_offset[j]];
     }
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
       mask_val[j] = __hmul2(__hsub2(__float2half2_rn(1.0f), mask_val[j]),
                             __float2half2_rn(-10000.0f));
     }
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
-      data[j][i] = __hadd2(__hmul2(qk[j], __half2half2(scalar)), mask_val[j]);
+      data[j][i] = __hadd2(qk[j], mask_val[j]);
       local_max[j] = fmax(local_max[j],
                           fmax(static_cast<float>(data[j][i].x),
                                static_cast<float>(data[j][i].y)));
     }
   }

   if (blockDim.x <= 32) {
-    warpReduceMaxV2<float, NUM>(local_max);
+    phi::funcs::WarpReduceMaxV2<float, NUM>(local_max);
   } else {
-    blockReduceMaxV2<float, NUM>(local_max);
+    phi::funcs::BlockReduceMaxV2<float, NUM>(local_max);
   }

   if (threadIdx.x == 0) {
 #pragma unroll
     for (int j = 0; j < NUM; j++) {
       s_max[j] = local_max[j];
     }
   }
   __syncthreads();

   float local_sum[NUM];
 #pragma unroll
   for (int j = 0; j < NUM; j++) {
     local_sum[j] = {0.f};
   }

-  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
+  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) &&
                   i < ITEMS_PER_THREAD;
        i++) {
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
       data[j][i] = h2exp(__hsub2(data[j][i], __float2half2_rn(s_max[j])));
     }
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
       local_sum[j] += static_cast<float>(data[j][i].x + data[j][i].y);
     }
   }

   if (blockDim.x <= 32) {
-    warpReduceSumV2<float, NUM>(local_sum);
+    phi::funcs::WarpReduceSumV2<float, NUM>(local_sum);
   } else {
-    blockReduceSumV2<float, NUM>(local_sum);
+    phi::funcs::BlockReduceSumV2<float, NUM>(local_sum);
   }

   if (threadIdx.x == 0) {
 #pragma unroll
     for (int j = 0; j < NUM; j++) {
       s_sum[j] = __fdividef(1.0f, local_sum[j] + 1e-6f);
     }
   }
   __syncthreads();

-  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len_2 / 2) &&
+  for (int i = 0; blockDim.x * i + threadIdx.x < (seq_len / 2) &&
                   i < ITEMS_PER_THREAD;
        i++) {
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
-      qk_offset[j] = ((blockIdx.y * head_num + blockIdx.z) * seq_len_1 +
-                      seq_id + j * gridDim.x) * (seq_len_2 / 2) +
-                     blockDim.x * i + threadIdx.x;
+      qk_offset[j] = ((blockIdx.y + blockIdx.z) + j * gridDim.x) *
+                         (seq_len / 2) +
+                     blockDim.x * i + threadIdx.x;
     }
 #pragma unroll
     for (int j = 0; j < MAX_NUM; j++) {
       qk_buf_half2[qk_offset[j]] =
           __hmul2(data[j][i], __float2half2_rn(s_sum[j]));
     }
   }
-  }
 #endif
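Structurally, the rewritten softmax_kernel_v4 family is the standard numerically stable row softmax: subtract the row max, exponentiate, divide by the row sum, with the max and sum computed by the phi::funcs block/warp reducers. A condensed float-only sketch of that pattern, as a hypothetical standalone kernel without the mask handling and half2 vectorization of the real code, assuming the launcher guarantees blockDim.x * ITEMS_PER_THREAD >= seq_len:

```cpp
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"

// One block per row of a [batch_size, seq_len] score matrix (grid = (1, batch, 1)).
template <int ITEMS_PER_THREAD>
__global__ void RowSoftmax(float* scores, int seq_len) {
  float data[ITEMS_PER_THREAD];
  __shared__ float s_max, s_inv_sum;

  // Pass 1: load this thread's elements and track their maximum.
  float local_max = -1e20f;
  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
    data[i] = scores[blockIdx.y * seq_len + blockDim.x * i + threadIdx.x];
    local_max = fmax(local_max, data[i]);
  }
  float max_val = phi::funcs::BlockReduceMax<float>(local_max, 0xffffffff);
  if (threadIdx.x == 0) s_max = max_val;
  __syncthreads();

  // Pass 2: exponentiate shifted values and accumulate the row sum.
  float local_sum = 0.f;
  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
    data[i] = __expf(data[i] - s_max);
    local_sum += data[i];
  }
  float sum_val = phi::funcs::BlockReduceSum<float>(local_sum, 0xffffffff);
  if (threadIdx.x == 0) s_inv_sum = __fdividef(1.0f, sum_val + 1e-6f);
  __syncthreads();

  // Pass 3: normalize and write back.
  for (int i = 0; blockDim.x * i + threadIdx.x < seq_len; i++) {
    scores[blockIdx.y * seq_len + blockDim.x * i + threadIdx.x] =
        data[i] * s_inv_sum;
  }
}
```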
paddle/phi/kernels/gpu/dist_kernel.cu

@@ -62,7 +62,7 @@ __global__ void ReduceSumWithSubtract(
   }
   __syncthreads();
-  sum_val = phi::funcs::blockReduceSum<T>(sum_val, FULL_MASK);
+  sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
   if (threadIdx.x == 0) {
     out[blockIdx.x] = sum_val;
   }
@@ -80,7 +80,7 @@ __global__ void ReduceMaxWithSubtract(const T* x,
   }
   __syncthreads();
-  max_val = phi::funcs::blockReduceMax<T>(max_val, FULL_MASK);
+  max_val = phi::funcs::BlockReduceMax<T>(max_val, FULL_MASK);
   if (threadIdx.x == 0) {
     out[blockIdx.x] = max_val;
   }
@@ -98,7 +98,7 @@ __global__ void ReduceMinWithSubtract(const T* x,
   }
   __syncthreads();
-  min_val = phi::funcs::blockReduceMin(min_val, FULL_MASK);
+  min_val = phi::funcs::BlockReduceMin(min_val, FULL_MASK);
   if (threadIdx.x == 0) {
     out[blockIdx.x] = min_val;
   }
paddle/phi/kernels/gpu/interpolate_grad_kernel.cu

@@ -211,7 +211,7 @@ __inline__ __device__ T PartialBlockMin(T val,
   if (threadIdx.x < threshold) {
     shared_last_idx = (threshold >> 5) - 1;
-    val = phi::funcs::warpReduceMin(val, mask);
+    val = phi::funcs::WarpReduceMin(val, mask);
     if (lane == 0) {
       shared[wid] = val;
     }
@@ -226,7 +226,7 @@ __inline__ __device__ T PartialBlockMin(T val,
   if (threadIdx.x < threshold) {
     val = (lane <= shared_last_idx) ? shared[lane]
                                     : std::numeric_limits<T>::max();
-    val = phi::funcs::warpReduceMin(val, mask);
+    val = phi::funcs::WarpReduceMin(val, mask);
     shared_last_val = val;
   }
   __syncthreads();
@@ -292,13 +292,13 @@ __global__ void KeBilinearInterpBwShareMemory(T* in,
     s_data[1][threadIdx.x] = static_cast<MT>(0);
     int remain = nthreads - (tid & (-blockDim.x));
     int in_top_max_index =
-        phi::funcs::blockReduceMax(top_right_index, FINAL_MASK);
+        phi::funcs::BlockReduceMax(top_right_index, FINAL_MASK);
     int in_bot_max_index =
-        phi::funcs::blockReduceMax(bot_right_index, FINAL_MASK);
+        phi::funcs::BlockReduceMax(bot_right_index, FINAL_MASK);
     if (remain > blockDim.x) {
-      in_top_min_index = phi::funcs::blockReduceMin(input_index, FINAL_MASK);
-      in_bot_min_index = phi::funcs::blockReduceMin(bot_left_index, FINAL_MASK);
+      in_top_min_index = phi::funcs::BlockReduceMin(input_index, FINAL_MASK);
+      in_bot_min_index = phi::funcs::BlockReduceMin(bot_left_index, FINAL_MASK);
     } else {
       in_top_min_index = PartialBlockMin(input_index, remain, FINAL_MASK);
       in_bot_min_index = PartialBlockMin(bot_left_index, remain, FINAL_MASK);
paddle/phi/kernels/sparse/gpu/fused_attention_grad_kernel.cu

@@ -47,7 +47,7 @@ __global__ void AttnSoftmaxGpuGradKernel(const int64_t* out_crows,
   for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
     mul += out_values[row_first + idx] * dout_values[row_first + idx];
   }
-  T mul_sum = phi::funcs::warpReduceSum<T>(mul, 0xFFFFFFFF);
+  T mul_sum = phi::funcs::WarpReduceSum<T>(mul, 0xFFFFFFFF);
   for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
     dx_values[row_first + idx] = (dout_values[row_first + idx] - mul_sum) *
paddle/phi/kernels/sparse/gpu/fused_attention_kernel.cu

@@ -72,7 +72,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
       out_values[row_first + idx] = -std::numeric_limits<T>::infinity();
     }
   }
-  T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
+  T row_max_val = phi::funcs::WarpReduceMax<T>(max_val, 0xFFFFFFFF);
   T exp_sum = 0;
   for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
@@ -81,7 +81,7 @@ __global__ void AttnSoftmaxGpuKernel(const int64_t* x_crows,
     exp_sum += exp;
     out_values[row_first + idx] = exp;
   }
-  T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
+  T row_exp_sum = phi::funcs::WarpReduceSum<T>(exp_sum, 0xFFFFFFFF);
   for (int idx = threadIdx.x; idx < row_nnz; idx += blockDim.x) {
     out_values[row_first + idx] = out_values[row_first + idx] / row_exp_sum;
paddle/phi/kernels/sparse/gpu/softmax_grad_kernel.cu

@@ -53,7 +53,7 @@ __global__ void SoftmaxGradGpuKernel(const IntT* out_crows,
     mul_result += out_values[row_first + idx] * dout_values[row_first + idx];
   }
-  T sum = phi::funcs::warpReduceSum<T>(mul_result, 0xFFFFFFFF);
+  T sum = phi::funcs::WarpReduceSum<T>(mul_result, 0xFFFFFFFF);
   for (int i = 0; i < kIteration; ++i) {
     int idx = non_zero_idx + i * warpSize;
paddle/phi/kernels/sparse/gpu/softmax_kernel.cu

@@ -57,7 +57,7 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
       max_val = val;
     }
   }
-  T row_max_val = phi::funcs::warpReduceMax<T>(max_val, 0xFFFFFFFF);
+  T row_max_val = phi::funcs::WarpReduceMax<T>(max_val, 0xFFFFFFFF);
   T exp_sum = 0;
   for (int i = 0; i < kIteration; ++i) {
@@ -69,7 +69,7 @@ __global__ void SoftmaxGpuKernel(const IntT* x_crows,
     exp_sum += exp;
     out_values[row_first + idx] = exp;
   }
-  T row_exp_sum = phi::funcs::warpReduceSum<T>(exp_sum, 0xFFFFFFFF);
+  T row_exp_sum = phi::funcs::WarpReduceSum<T>(exp_sum, 0xFFFFFFFF);
   for (int i = 0; i < kIteration; ++i) {
     int idx = non_zero_idx + i * warpSize;