Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6e40fc1d
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6e40fc1d
编写于
8月 14, 2023
作者:
S
Sonder
提交者:
GitHub
8月 14, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Fluid] Move fused_softmax_mask_upper_triangle to phi (#55769)
上级
f54078e8
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
759 addition
and
660 deletion
+759
-660
paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cc
...e/fluid/operators/fused_softmax_mask_upper_triangle_op.cc
+0
-8
paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
...e/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
+0
-621
paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h
...le/fluid/operators/fused_softmax_mask_upper_triangle_op.h
+0
-31
paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h
...le/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h
+32
-0
paddle/phi/kernels/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc
...ls/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc
+41
-0
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu
...sion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu
+266
-0
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu
...ls/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu
+261
-0
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h
...nels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h
+117
-0
paddle/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc
...e/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc
+40
-0
test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py
test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py
+2
-0
未找到文件。
paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cc
浏览文件 @
6e40fc1d
...
...
@@ -10,7 +10,6 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/core/generator.h"
...
...
@@ -102,10 +101,3 @@ REGISTER_OPERATOR(
ops
::
SoftmaxMaskFuseUpperTriangleGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
fused_softmax_mask_upper_triangle_grad
,
ops
::
SoftmaxMaskFuseUpperTriangleOpGrad
);
PD_REGISTER_STRUCT_KERNEL
(
fused_softmax_mask_upper_triangle
,
CPU
,
ALL_LAYOUT
,
ops
::
SoftmaxMaskFuseUpperTriangleCPUKernel
,
float
,
double
)
{}
paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.cu
已删除
100644 → 0
浏览文件 @
f54078e8
此差异已折叠。
点击以展开。
paddle/fluid/operators/fused_softmax_mask_upper_triangle_op.h
已删除
100644 → 0
浏览文件 @
f54078e8
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
,
typename
DeviceContext
>
class
SoftmaxMaskFuseUpperTriangleCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
Unimplemented
(
"Softmax mask fuse op only supports GPU now."
));
}
};
}
// namespace operators
}
// namespace paddle
paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h
0 → 100644
浏览文件 @
6e40fc1d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
FusedSoftmaxMaskFuseUpperTriangleKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
);
template
<
typename
T
,
typename
Context
>
void
FusedSoftmaxMaskFuseUpperTriangleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
);
}
// namespace phi
paddle/phi/kernels/fusion/cpu/fused_softmax_mask_upper_triangle_kernel.cc
0 → 100644
浏览文件 @
6e40fc1d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
FusedSoftmaxMaskFuseUpperTriangleKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
)
{
bool
is_gpu_place
=
dev_ctx
.
GetPlace
().
GetType
()
==
phi
::
AllocationType
::
GPU
;
PADDLE_ENFORCE_EQ
(
is_gpu_place
,
true
,
phi
::
errors
::
Unimplemented
(
"Softmax mask fuse op only supports GPU now."
));
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fused_softmax_mask_upper_triangle
,
CPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FusedSoftmaxMaskFuseUpperTriangleKernel
,
float
,
double
)
{}
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_grad_kernel.cu
0 → 100644
浏览文件 @
6e40fc1d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
int
pow2_index
>
__global__
void
SoftmaxMaskFuseUpperTriangleGradGPUKernel
(
const
T
*
grad_input
,
T
*
grad_output
,
const
T
*
softmax_rst
,
int64_t
batch_count
,
int64_t
key_seq_len
)
{
constexpr
int
next_pow2
=
1
<<
pow2_index
;
constexpr
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
constexpr
int
kLocalIterations
=
std
::
max
(
next_pow2
/
warp_size
,
4
);
constexpr
int
kLocalBatchSize
=
(
next_pow2
<=
128
)
?
2
:
1
;
constexpr
int
kOneLoadingCounts
=
4
;
int64_t
key_seq_len_pow_2
=
key_seq_len
*
key_seq_len
;
int64_t
first_idx
=
(
static_cast
<
int64_t
>
(
blockDim
.
y
)
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
kLocalBatchSize
+
blockIdx
.
x
;
int64_t
local_block_idx
=
blockIdx
.
x
+
1
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int64_t
local_batches
=
batch_count
-
first_idx
;
if
(
local_batches
>
kLocalBatchSize
)
local_batches
=
kLocalBatchSize
;
// there might be multiple batches per warp. compute the index within the
// batch
int64_t
local_idx
=
threadIdx
.
x
;
// the first element to process by the current thread
int64_t
offset
=
first_idx
*
key_seq_len
+
kOneLoadingCounts
*
local_idx
;
grad_input
+=
offset
;
grad_output
+=
offset
;
softmax_rst
+=
offset
;
// load data from global memory
float
grad_input_reg
[
kLocalBatchSize
][
kLocalIterations
]{
0.0
f
};
float
softmax_rst_reg
[
kLocalBatchSize
][
kLocalIterations
]{
0.0
f
};
T
temp_grad_input
[
kOneLoadingCounts
];
T
temp_softmax_rst
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
auto
batch_total_number
=
(
i
>=
local_batches
)
?
0
:
local_block_idx
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
auto
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
batch_total_number
)
{
load_data_upper_tri
(
temp_grad_input
,
grad_input
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
load_data_upper_tri
(
temp_softmax_rst
,
softmax_rst
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
(
element_index
+
counter
<
batch_total_number
)
{
softmax_rst_reg
[
i
][
ii
+
counter
]
=
static_cast
<
float
>
(
temp_softmax_rst
[
counter
]);
}
}
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
(
element_index
+
counter
<
batch_total_number
)
{
grad_input_reg
[
i
][
ii
+
counter
]
=
static_cast
<
float
>
(
temp_grad_input
[
counter
])
*
softmax_rst_reg
[
i
][
ii
+
counter
];
}
}
}
}
}
float
sum
[
kLocalBatchSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
sum
[
i
]
=
grad_input_reg
[
i
][
0
];
#pragma unroll
for
(
int
ii
=
1
;
ii
<
kLocalIterations
;
++
ii
)
{
sum
[
i
]
+=
grad_input_reg
[
i
][
ii
];
}
}
warp_reduce_upper_tri
<
float
,
kLocalBatchSize
,
warp_size
,
AddOP_upper_tri
>
(
sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
auto
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
key_seq_len
)
{
// compute gradients
T
samples_out
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
samples_out
[
counter
]
=
grad_input_reg
[
i
][
ii
+
counter
]
-
softmax_rst_reg
[
i
][
ii
+
counter
]
*
sum
[
i
];
}
load_data_upper_tri
(
grad_output
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
,
samples_out
);
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
FusedSoftmaxMaskFuseUpperTriangleGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
out
,
const
DenseTensor
&
out_grad
,
DenseTensor
*
x_grad
)
{
auto
*
grad_y
=
&
out_grad
;
auto
*
softmax_rst
=
&
out
;
auto
*
x_grad_data
=
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
auto
*
grad_y_data
=
grad_y
->
data
<
T
>
();
auto
*
softmax_rst_data
=
softmax_rst
->
data
<
T
>
();
auto
y_dim
=
grad_y
->
dims
();
auto
batches
=
y_dim
[
0
];
auto
attn_heads
=
y_dim
[
1
];
auto
attn_mul_batch
=
batches
*
attn_heads
;
auto
query_seq_len
=
y_dim
[
2
];
auto
key_seq_len
=
y_dim
[
3
];
auto
stream
=
dev_ctx
.
stream
();
int
pow2_index
=
get_pow2_index_value
(
key_seq_len
);
const
int
next_pow2
=
1
<<
pow2_index
;
int64_t
batch_count
=
attn_mul_batch
*
query_seq_len
;
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
int
batches_per_warp
=
(
next_pow2
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximum gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
dim3
blocks
(
query_seq_len
,
(
attn_mul_batch
+
batches_per_block
)
/
batches_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
5
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
6
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
7
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
8
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
9
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
10
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
11
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
12
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
13
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
14
:
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
14
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
x_grad_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
default:
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"Too large sequence length."
));
break
;
}
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fused_softmax_mask_upper_triangle_grad
,
GPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FusedSoftmaxMaskFuseUpperTriangleGradKernel
,
float
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_kernel.cu
0 → 100644
浏览文件 @
6e40fc1d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/errors.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/fused_softmax_mask_upper_triangle_kernel.h"
#include "paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
int
pow2_index
>
__global__
void
SoftmaxMaskFuseUpperTriangleGPUKernel
(
const
T
*
src
,
T
*
dst
,
int64_t
batch_count
,
int64_t
key_seq_len
)
{
constexpr
int
next_pow2
=
1
<<
pow2_index
;
constexpr
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
constexpr
int
kLocalIterations
=
std
::
max
(
next_pow2
/
warp_size
,
4
);
constexpr
int
kLocalBatchSize
=
(
next_pow2
<=
128
)
?
2
:
1
;
constexpr
int
kOneLoadingCounts
=
4
;
int64_t
key_seq_len_pow_2
=
key_seq_len
*
key_seq_len
;
int64_t
first_idx
=
(
static_cast
<
int64_t
>
(
blockDim
.
y
)
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
kLocalBatchSize
+
blockIdx
.
x
;
int64_t
local_block_idx
=
blockIdx
.
x
+
1
;
int64_t
warp_iter_upper_bound
=
(
local_block_idx
+
kOneLoadingCounts
*
warp_size
-
1
)
/
warp_size
;
int64_t
local_batches
=
batch_count
-
first_idx
;
if
(
local_batches
>
kLocalBatchSize
)
local_batches
=
kLocalBatchSize
;
int64_t
local_idx
=
threadIdx
.
x
;
src
+=
first_idx
*
key_seq_len
+
kOneLoadingCounts
*
local_idx
;
dst
+=
first_idx
*
key_seq_len
+
kOneLoadingCounts
*
local_idx
;
float
data
[
kLocalBatchSize
][
kLocalIterations
];
T
temp_in
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
auto
batch_total_number
=
(
i
>=
local_batches
)
?
0
:
local_block_idx
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
auto
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
batch_total_number
)
{
load_data_upper_tri
(
temp_in
,
src
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
((
element_index
+
counter
)
<
batch_total_number
)
{
data
[
i
][
ii
+
counter
]
=
static_cast
<
float
>
(
temp_in
[
counter
]);
}
else
{
data
[
i
][
ii
+
counter
]
=
-
std
::
numeric_limits
<
float
>::
infinity
();
}
}
}
else
{
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
data
[
i
][
ii
+
counter
]
=
-
std
::
numeric_limits
<
float
>::
infinity
();
}
}
}
}
float
max_value
[
kLocalBatchSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
max_value
[
i
]
=
data
[
i
][
0
];
#pragma unroll
for
(
int
ii
=
1
;
ii
<
kLocalIterations
;
++
ii
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
data
[
i
][
ii
])
?
max_value
[
i
]
:
data
[
i
][
ii
];
}
}
warp_reduce_upper_tri
<
float
,
kLocalBatchSize
,
warp_size
,
MaxOP_upper_tri
>
(
max_value
);
float
sum
[
kLocalBatchSize
]{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
++
ii
)
{
if
(
ii
<
warp_iter_upper_bound
)
{
data
[
i
][
ii
]
=
std
::
exp
((
data
[
i
][
ii
]
-
max_value
[
i
]));
sum
[
i
]
+=
data
[
i
][
ii
];
}
}
}
warp_reduce_upper_tri
<
float
,
kLocalBatchSize
,
warp_size
,
AddOP_upper_tri
>
(
sum
);
T
out
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
auto
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
local_block_idx
)
{
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
(
element_index
+
counter
<
local_block_idx
)
{
out
[
counter
]
=
data
[
i
][
ii
+
counter
]
/
sum
[
i
];
}
else
{
out
[
counter
]
=
0
;
}
}
load_data_upper_tri
(
dst
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
,
out
);
}
else
if
(
element_index
<
key_seq_len
)
{
load_zero_vector_upper_tri
(
dst
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
}
else
{
break
;
}
}
}
}
template
<
typename
T
,
typename
Context
>
void
FusedSoftmaxMaskFuseUpperTriangleKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
)
{
auto
*
x_ptr
=
&
x
;
auto
*
x_data
=
x_ptr
->
data
<
T
>
();
auto
*
y_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
x_dim
=
x_ptr
->
dims
();
auto
batches
=
x_dim
[
0
];
auto
attn_heads
=
x_dim
[
1
];
auto
attn_mul_batch
=
batches
*
attn_heads
;
auto
query_seq_len
=
x_dim
[
2
];
auto
key_seq_len
=
x_dim
[
3
];
PADDLE_ENFORCE_EQ
(
key_seq_len
,
query_seq_len
,
phi
::
errors
::
InvalidArgument
(
"Key seq len must be equal with query seq len "
"received key len: %d, query len: %d"
,
key_seq_len
,
query_seq_len
));
PADDLE_ENFORCE_EQ
(
key_seq_len
>=
32
&&
key_seq_len
<=
16384
,
true
,
phi
::
errors
::
InvalidArgument
(
"Input x's last dim must be between [32, 16384] "
"received the last dimension of x is %d"
,
key_seq_len
));
auto
stream
=
dev_ctx
.
stream
();
int
pow2_index
=
get_pow2_index_value
(
key_seq_len
);
const
int
next_pow2
=
1
<<
pow2_index
;
int64_t
batch_count
=
attn_mul_batch
*
query_seq_len
;
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
int
batches_per_warp
=
(
next_pow2
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
PADDLE_ENFORCE_EQ
(
query_seq_len
%
batches_per_block
,
0
,
phi
::
errors
::
InvalidArgument
(
"The query seq len (third dim of input X) must can divide the "
"number of batches per block. The query seq len is %d, while "
"the number of batches per block is %d."
,
query_seq_len
,
batches_per_block
));
dim3
blocks
(
query_seq_len
,
(
attn_mul_batch
+
batches_per_block
)
/
batches_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
5
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
6
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
7
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
8
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
9
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
10
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
11
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
12
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
13
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
14
:
// 16384
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
14
>
<<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
default:
PADDLE_THROW
(
phi
::
errors
::
Unimplemented
(
"Too large sequence length."
));
break
;
}
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fused_softmax_mask_upper_triangle
,
GPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FusedSoftmaxMaskFuseUpperTriangleKernel
,
float
,
phi
::
dtype
::
float16
,
phi
::
dtype
::
bfloat16
)
{}
paddle/phi/kernels/fusion/gpu/fused_softmax_mask_upper_triangle_utils.h
0 → 100644
浏览文件 @
6e40fc1d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#endif
#include <stdint.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>
#include <algorithm>
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace
phi
{
namespace
fusion
{
#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define MASK 0xffffffff
__device__
__inline__
void
load_data_upper_tri
(
phi
::
float16
*
dst
,
const
phi
::
float16
*
src
)
{
*
(
reinterpret_cast
<
float2
*>
(
dst
))
=
*
(
reinterpret_cast
<
const
float2
*>
(
src
));
}
__device__
__inline__
void
load_data_upper_tri
(
phi
::
bfloat16
*
dst
,
const
phi
::
bfloat16
*
src
)
{
*
(
reinterpret_cast
<
float2
*>
(
dst
))
=
*
(
reinterpret_cast
<
const
float2
*>
(
src
));
}
__device__
__inline__
void
load_data_upper_tri
(
float
*
dst
,
const
float
*
src
)
{
*
(
reinterpret_cast
<
float4
*>
(
dst
))
=
*
(
reinterpret_cast
<
const
float4
*>
(
src
));
}
__device__
__inline__
void
load_zero_vector_upper_tri
(
phi
::
float16
*
dst
)
{
*
(
reinterpret_cast
<
float2
*>
(
dst
))
=
make_float2
(
0.0
f
,
0.0
f
);
}
__device__
__inline__
void
load_zero_vector_upper_tri
(
phi
::
bfloat16
*
dst
)
{
*
(
reinterpret_cast
<
float2
*>
(
dst
))
=
make_float2
(
0.0
f
,
0.0
f
);
}
__device__
__inline__
void
load_zero_vector_upper_tri
(
float
*
dst
)
{
*
(
reinterpret_cast
<
float4
*>
(
dst
))
=
make_float4
(
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
);
}
__inline__
int
get_pow2_index_value
(
int
value
)
{
int
pow2_index
=
0
;
while
((
1
<<
pow2_index
)
<
value
)
{
++
pow2_index
;
}
return
pow2_index
;
}
template
<
typename
T
>
struct
AddOP_upper_tri
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
MaxOP_upper_tri
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
b
:
a
;
}
};
template
<
typename
T
>
__device__
__forceinline__
T
warp_shfl_xor_upper_tri
(
T
value
,
int
laneMask
,
int
width
,
unsigned
int
mask
=
MASK
)
{
#if CUDA_VERSION >= 9000
return
__shfl_xor_sync
(
mask
,
value
,
laneMask
,
width
);
#else
return
__shfl_xor
(
value
,
laneMask
,
width
);
#endif
}
template
<
typename
T
,
int
batch
,
int
width
,
template
<
typename
>
class
ReduceOp
>
__device__
__forceinline__
void
warp_reduce_upper_tri
(
T
*
sum
)
{
ReduceOp
<
T
>
r
;
#pragma unroll
for
(
int
offset
=
width
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
T
b
=
warp_shfl_xor_upper_tri
(
sum
[
i
],
offset
,
width
);
sum
[
i
]
=
r
(
sum
[
i
],
b
);
}
}
}
}
// namespace fusion
}
// namespace phi
paddle/phi/ops/compat/fused_softmax_mask_upper_triangle_sig.cc
0 → 100644
浏览文件 @
6e40fc1d
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace
phi
{
KernelSignature
FusedSoftmaxMaskUpperTriangleOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"fused_softmax_mask_upper_triangle"
,
{
"X"
},
{},
{
"Out"
});
}
KernelSignature
FusedSoftmaxMaskUpperTriangleGradOpArgumentMapping
(
const
ArgumentMappingContext
&
ctx
UNUSED
)
{
return
KernelSignature
(
"fused_softmax_mask_upper_triangle_grad"
,
{
"Out"
,
"Out@GRAD"
},
{},
{
"X@GRAD"
});
}
}
// namespace phi
PD_REGISTER_ARG_MAPPING_FN
(
fused_softmax_mask_upper_triangle
,
phi
::
FusedSoftmaxMaskUpperTriangleOpArgumentMapping
);
PD_REGISTER_ARG_MAPPING_FN
(
fused_softmax_mask_upper_triangle_grad
,
phi
::
FusedSoftmaxMaskUpperTriangleGradOpArgumentMapping
);
test/legacy_test/test_softmax_mask_fuse_upper_triangle_op.py
浏览文件 @
6e40fc1d
...
...
@@ -43,6 +43,7 @@ def _get_softmax_upper(x, fp16=True):
class
TestSoftmaxMaskFuseOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fused_softmax_mask_upper_triangle"
self
.
python_api
=
paddle
.
incubate
.
softmax_mask_fuse_upper_triangle
x
=
np
.
random
.
random
((
1
,
4
,
32
,
32
)).
astype
(
"float16"
)
self
.
inputs
=
{
'X'
:
x
}
rst
=
_get_softmax_upper
(
x
)
...
...
@@ -61,6 +62,7 @@ class TestSoftmaxMaskFuseOp(OpTest):
class
TestSoftmaxMaskFuseOp1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"fused_softmax_mask_upper_triangle"
self
.
python_api
=
paddle
.
incubate
.
softmax_mask_fuse_upper_triangle
x
=
np
.
random
.
random
((
1
,
4
,
32
,
32
))
self
.
inputs
=
{
'X'
:
x
}
rst
=
_get_softmax_upper
(
x
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录