Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
e2e1c57b
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
e2e1c57b
编写于
7月 12, 2021
作者:
Y
Yuang Liu
提交者:
GitHub
7月 12, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
softmax mask fuse upper triangle (#33981)
* softmax mask fuse upper triangle * cover not implemented cpu code
上级
bfbea8fd
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
860 addition
and
1 deletion
+860
-1
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc
...le/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc
+107
-0
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu
...le/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu
+546
-0
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h
+30
-0
python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py
...sts/unittests/test_softmax_mask_fuse_upper_triangle_op.py
+117
-0
python/paddle/incubate/__init__.py
python/paddle/incubate/__init__.py
+2
-1
python/paddle/incubate/operators/__init__.py
python/paddle/incubate/operators/__init__.py
+15
-0
python/paddle/incubate/operators/softmax_mask_fuse_upper_triangle.py
...le/incubate/operators/softmax_mask_fuse_upper_triangle.py
+42
-0
python/setup.py.in
python/setup.py.in
+1
-0
未找到文件。
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cc
0 → 100644
浏览文件 @
e2e1c57b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
class
SoftmaxMaskFuseUpperTriangleOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"X"
),
"Input"
,
"X"
,
"SoftmaxMaskFuseUpperTriangle"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"SoftmaxMaskFuseUpperTriangle"
);
auto
x_dims
=
ctx
->
GetInputDim
(
"X"
);
PADDLE_ENFORCE_EQ
(
x_dims
.
size
(),
4
,
platform
::
errors
::
InvalidArgument
(
"Input x must be in 4D dimension but "
"received the dimension of X is %d"
,
x_dims
.
size
()));
ctx
->
SetOutputDim
(
"Out"
,
x_dims
);
ctx
->
ShareLoD
(
"X"
,
"Out"
);
}
};
class
SoftmaxMaskFuseUpperTriangleOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input of softmax_mask_fuse_upper_triangle op, "
"which is the result of matmul(QK)/sqrt(dk)."
);
AddOutput
(
"Out"
,
"The result of softmax_mask_fuse_upper_triangle op."
);
AddComment
(
R"DOC(
Softmax Mask Fuse Operator.
product = matmul(QK)/sqrt(dk)
output = softmax_mask_fuse_upper_triangle(product)
to get the final output.
)DOC"
);
}
};
class
SoftmaxMaskFuseUpperTriangleOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
framework
::
GradVarName
(
"Out"
),
"SoftmaxMaskFuseUpperTriangleGrad"
);
auto
out_dims
=
ctx
->
GetInputDim
(
framework
::
GradVarName
(
"Out"
));
ctx
->
SetOutputDim
(
framework
::
GradVarName
(
"X"
),
out_dims
);
ctx
->
ShareLoD
(
framework
::
GradVarName
(
"Out"
),
framework
::
GradVarName
(
"X"
));
}
};
template
<
typename
T
>
class
SoftmaxMaskFuseUpperTriangleGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"softmax_mask_fuse_upper_triangle_grad"
);
op
->
SetInput
(
"Softmax"
,
this
->
Output
(
"Out"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"X"
),
this
->
InputGrad
(
"X"
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
softmax_mask_fuse_upper_triangle
,
ops
::
SoftmaxMaskFuseUpperTriangleOp
,
ops
::
SoftmaxMaskFuseUpperTriangleOpMaker
,
ops
::
SoftmaxMaskFuseUpperTriangleGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
SoftmaxMaskFuseUpperTriangleGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
softmax_mask_fuse_upper_triangle_grad
,
ops
::
SoftmaxMaskFuseUpperTriangleOpGrad
);
REGISTER_OP_CPU_KERNEL
(
softmax_mask_fuse_upper_triangle
,
ops
::
SoftmaxMaskFuseUpperTriangleCPUKernel
<
paddle
::
platform
::
CPUDeviceContext
,
float
>
,
ops
::
SoftmaxMaskFuseUpperTriangleCPUKernel
<
paddle
::
platform
::
CPUDeviceContext
,
double
>
);
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.cu
0 → 100644
浏览文件 @
e2e1c57b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
// this file is inspired by:
// https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/fused_kernels/scaled_upper_triang_masked_softmax.h
#ifdef PADDLE_WITH_CUDA
#include <cuda.h>
#include <curand_kernel.h>
#endif
#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
#include <hiprand_kernel.h>
#endif
#include <stdint.h>
#include <thrust/device_ptr.h>
#include <thrust/iterator/counting_iterator.h>
#include <thrust/transform.h>
#include <algorithm>
#include <string>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
using
framework
::
Tensor
;
#ifdef PADDLE_WITH_HIP
#define WARP_SIZE 64
#else
#define WARP_SIZE 32
#endif
#define MASK 0xffffffff
namespace
plat
=
paddle
::
platform
;
__device__
__inline__
void
load_data_upper_tri
(
plat
::
float16
*
dst
,
const
plat
::
float16
*
src
)
{
*
(
reinterpret_cast
<
float2
*>
(
dst
))
=
*
(
reinterpret_cast
<
const
float2
*>
(
src
));
}
__device__
__inline__
void
load_data_upper_tri
(
float
*
dst
,
const
float
*
src
)
{
*
(
reinterpret_cast
<
float4
*>
(
dst
))
=
*
(
reinterpret_cast
<
const
float4
*>
(
src
));
}
__device__
__inline__
void
load_zero_vector_upper_tri
(
plat
::
float16
*
dst
)
{
*
(
reinterpret_cast
<
float2
*>
(
dst
))
=
make_float2
(
0.0
f
,
0.0
f
);
}
__device__
__inline__
void
load_zero_vector_upper_tri
(
float
*
dst
)
{
*
(
reinterpret_cast
<
float4
*>
(
dst
))
=
make_float4
(
0.0
f
,
0.0
f
,
0.0
f
,
0.0
f
);
}
int
get_pow2_index_value
(
int
value
)
{
int
pow2_index
=
0
;
while
((
1
<<
pow2_index
)
<
value
)
{
++
pow2_index
;
}
return
pow2_index
;
}
template
<
typename
T
>
struct
AddOP_upper_tri
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
T
>
struct
MaxOP_upper_tri
{
__device__
__forceinline__
T
operator
()(
T
a
,
T
b
)
const
{
return
a
<
b
?
b
:
a
;
}
};
template
<
typename
T
>
__device__
__forceinline__
T
warp_shfl_xor_upper_tri
(
T
value
,
int
laneMask
,
int
width
,
unsigned
int
mask
=
MASK
)
{
#if CUDA_VERSION >= 9000
return
__shfl_xor_sync
(
mask
,
value
,
laneMask
,
width
);
#else
return
__shfl_xor
(
value
,
laneMask
,
width
);
#endif
}
template
<
typename
T
,
int
batch
,
int
width
,
template
<
typename
>
class
ReduceOp
>
__device__
__forceinline__
void
warp_reduce_upper_tri
(
T
*
sum
)
{
ReduceOp
<
T
>
r
;
#pragma unroll
for
(
int
offset
=
width
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
batch
;
++
i
)
{
T
b
=
warp_shfl_xor_upper_tri
(
sum
[
i
],
offset
,
width
);
sum
[
i
]
=
r
(
sum
[
i
],
b
);
}
}
}
template
<
typename
T
,
int
pow2_index
>
__global__
void
SoftmaxMaskFuseUpperTriangleGPUKernel
(
const
T
*
src
,
T
*
dst
,
int
batch_count
,
int
key_seq_len
)
{
constexpr
int
next_pow2
=
1
<<
pow2_index
;
constexpr
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
constexpr
int
kLocalIterations
=
std
::
max
(
next_pow2
/
warp_size
,
4
);
constexpr
int
kLocalBatchSize
=
(
next_pow2
<=
128
)
?
2
:
1
;
constexpr
int
kOneLoadingCounts
=
4
;
int
key_seq_len_pow_2
=
key_seq_len
*
key_seq_len
;
int
first_idx
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
kLocalBatchSize
+
blockIdx
.
x
;
int
local_block_idx
=
blockIdx
.
x
+
1
;
int
warp_iter_upper_bound
=
(
local_block_idx
+
kOneLoadingCounts
*
warp_size
-
1
)
/
warp_size
;
int
local_batches
=
batch_count
-
first_idx
;
if
(
local_batches
>
kLocalBatchSize
)
local_batches
=
kLocalBatchSize
;
int
local_idx
=
threadIdx
.
x
;
src
+=
first_idx
*
key_seq_len
+
kOneLoadingCounts
*
local_idx
;
dst
+=
first_idx
*
key_seq_len
+
kOneLoadingCounts
*
local_idx
;
float
data
[
kLocalBatchSize
][
kLocalIterations
];
T
temp_in
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
int
batch_total_number
=
(
i
>=
local_batches
)
?
0
:
local_block_idx
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
int
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
batch_total_number
)
{
load_data_upper_tri
(
temp_in
,
src
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
((
element_index
+
counter
)
<
batch_total_number
)
{
data
[
i
][
ii
+
counter
]
=
static_cast
<
float
>
(
temp_in
[
counter
]);
}
else
{
data
[
i
][
ii
+
counter
]
=
-
std
::
numeric_limits
<
float
>::
infinity
();
}
}
}
else
{
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
data
[
i
][
ii
+
counter
]
=
-
std
::
numeric_limits
<
float
>::
infinity
();
}
}
}
}
float
max_value
[
kLocalBatchSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
max_value
[
i
]
=
data
[
i
][
0
];
#pragma unroll
for
(
int
ii
=
1
;
ii
<
kLocalIterations
;
++
ii
)
{
max_value
[
i
]
=
(
max_value
[
i
]
>
data
[
i
][
ii
])
?
max_value
[
i
]
:
data
[
i
][
ii
];
}
}
warp_reduce_upper_tri
<
float
,
kLocalBatchSize
,
warp_size
,
MaxOP_upper_tri
>
(
max_value
);
float
sum
[
kLocalBatchSize
]{
0.0
f
};
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
++
ii
)
{
if
(
ii
<
warp_iter_upper_bound
)
{
data
[
i
][
ii
]
=
std
::
exp
((
data
[
i
][
ii
]
-
max_value
[
i
]));
sum
[
i
]
+=
data
[
i
][
ii
];
}
}
}
warp_reduce_upper_tri
<
float
,
kLocalBatchSize
,
warp_size
,
AddOP_upper_tri
>
(
sum
);
T
out
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
int
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
local_block_idx
)
{
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
(
element_index
+
counter
<
local_block_idx
)
{
out
[
counter
]
=
data
[
i
][
ii
+
counter
]
/
sum
[
i
];
}
else
{
out
[
counter
]
=
0
;
}
}
load_data_upper_tri
(
dst
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
,
out
);
}
else
if
(
element_index
<
key_seq_len
)
{
load_zero_vector_upper_tri
(
dst
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
}
else
{
break
;
}
}
}
}
template
<
typename
T
,
int
pow2_index
>
__global__
void
SoftmaxMaskFuseUpperTriangleGradGPUKernel
(
const
T
*
grad_input
,
T
*
grad_output
,
const
T
*
softmax_rst
,
int
batch_count
,
int
key_seq_len
)
{
constexpr
int
next_pow2
=
1
<<
pow2_index
;
constexpr
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
constexpr
int
kLocalIterations
=
std
::
max
(
next_pow2
/
warp_size
,
4
);
constexpr
int
kLocalBatchSize
=
(
next_pow2
<=
128
)
?
2
:
1
;
constexpr
int
kOneLoadingCounts
=
4
;
int
key_seq_len_pow_2
=
key_seq_len
*
key_seq_len
;
int
first_idx
=
(
blockDim
.
y
*
blockIdx
.
y
+
threadIdx
.
y
)
*
gridDim
.
x
*
kLocalBatchSize
+
blockIdx
.
x
;
int
local_block_idx
=
blockIdx
.
x
+
1
;
// micro_batch_size might not be a multiple of WARP_BATCH. Check how
// many batches have to computed within this WARP.
int
local_batches
=
batch_count
-
first_idx
;
if
(
local_batches
>
kLocalBatchSize
)
local_batches
=
kLocalBatchSize
;
// there might be multiple batches per warp. compute the index within the
// batch
int
local_idx
=
threadIdx
.
x
;
// the first element to process by the current thread
int
offset
=
first_idx
*
key_seq_len
+
kOneLoadingCounts
*
local_idx
;
grad_input
+=
offset
;
grad_output
+=
offset
;
softmax_rst
+=
offset
;
// load data from global memory
float
grad_input_reg
[
kLocalBatchSize
][
kLocalIterations
]{
0.0
f
};
float
softmax_rst_reg
[
kLocalBatchSize
][
kLocalIterations
]{
0.0
f
};
T
temp_grad_input
[
kOneLoadingCounts
];
T
temp_softmax_rst
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
int
batch_total_number
=
(
i
>=
local_batches
)
?
0
:
local_block_idx
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
int
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
batch_total_number
)
{
load_data_upper_tri
(
temp_grad_input
,
grad_input
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
load_data_upper_tri
(
temp_softmax_rst
,
softmax_rst
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
);
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
(
element_index
+
counter
<
batch_total_number
)
{
softmax_rst_reg
[
i
][
ii
+
counter
]
=
static_cast
<
float
>
(
temp_softmax_rst
[
counter
]);
}
}
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
if
(
element_index
+
counter
<
batch_total_number
)
{
grad_input_reg
[
i
][
ii
+
counter
]
=
static_cast
<
float
>
(
temp_grad_input
[
counter
])
*
softmax_rst_reg
[
i
][
ii
+
counter
];
}
}
}
}
}
float
sum
[
kLocalBatchSize
];
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
sum
[
i
]
=
grad_input_reg
[
i
][
0
];
#pragma unroll
for
(
int
ii
=
1
;
ii
<
kLocalIterations
;
++
ii
)
{
sum
[
i
]
+=
grad_input_reg
[
i
][
ii
];
}
}
warp_reduce_upper_tri
<
float
,
kLocalBatchSize
,
warp_size
,
AddOP_upper_tri
>
(
sum
);
#pragma unroll
for
(
int
i
=
0
;
i
<
kLocalBatchSize
;
++
i
)
{
if
(
i
>=
local_batches
)
break
;
#pragma unroll
for
(
int
ii
=
0
;
ii
<
kLocalIterations
;
ii
+=
kOneLoadingCounts
)
{
int
element_index
=
kOneLoadingCounts
*
local_idx
+
ii
*
warp_size
;
if
(
element_index
<
key_seq_len
)
{
// compute gradients
T
samples_out
[
kOneLoadingCounts
];
#pragma unroll
for
(
int
counter
=
0
;
counter
<
kOneLoadingCounts
;
++
counter
)
{
samples_out
[
counter
]
=
grad_input_reg
[
i
][
ii
+
counter
]
-
softmax_rst_reg
[
i
][
ii
+
counter
]
*
sum
[
i
];
}
load_data_upper_tri
(
grad_output
+
i
*
key_seq_len_pow_2
+
ii
*
warp_size
,
samples_out
);
}
}
}
}
template
<
typename
Place
,
typename
T
>
class
SoftmaxMaskFuseUpperTriangleKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
x
=
context
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
*
x_data
=
x
->
data
<
T
>
();
auto
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
x_dim
=
x
->
dims
();
auto
batches
=
x_dim
[
0
];
auto
attn_heads
=
x_dim
[
1
];
auto
attn_mul_batch
=
batches
*
attn_heads
;
auto
query_seq_len
=
x_dim
[
2
];
auto
key_seq_len
=
x_dim
[
3
];
PADDLE_ENFORCE_EQ
(
key_seq_len
,
query_seq_len
,
platform
::
errors
::
InvalidArgument
(
"Key seq len must be equal with query seq len "
"received key len: %d, query len: %d"
,
key_seq_len
,
query_seq_len
));
PADDLE_ENFORCE_EQ
(
key_seq_len
>=
32
&&
key_seq_len
<
8192
,
true
,
platform
::
errors
::
InvalidArgument
(
"Input x's last dim must be between [32, 8192) "
"received the last dimension of x is %d"
,
key_seq_len
));
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
int
pow2_index
=
get_pow2_index_value
(
key_seq_len
);
const
int
next_pow2
=
1
<<
pow2_index
;
int
batch_count
=
attn_mul_batch
*
query_seq_len
;
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
int
batches_per_warp
=
(
next_pow2
<=
128
)
?
2
:
1
;
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
PADDLE_ENFORCE_EQ
(
query_seq_len
%
batches_per_block
,
0
,
platform
::
errors
::
InvalidArgument
(
"The query seq len (third dim of input X) must can divide the "
"number of batches per block. The query seq len is %d, while "
"the number of batches per block is %d."
,
query_seq_len
,
batches_per_block
));
dim3
blocks
(
query_seq_len
,
(
attn_mul_batch
+
batches_per_block
)
/
batches_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
5
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
6
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
7
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
8
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
9
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
10
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
11
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
12
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseUpperTriangleGPUKernel
<
T
,
13
><<<
blocks
,
threads
,
0
,
stream
>>>
(
x_data
,
y_data
,
batch_count
,
key_seq_len
);
break
;
default:
break
;
}
}
};
template
<
typename
Place
,
typename
T
>
class
SoftmaxMaskFuseUpperTriangleGradKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
grad_x
=
context
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
grad_y
=
context
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
softmax_rst
=
context
.
Input
<
Tensor
>
(
"Softmax"
);
auto
*
grad_x_data
=
grad_x
->
mutable_data
<
T
>
(
context
.
GetPlace
());
auto
*
grad_y_data
=
grad_y
->
data
<
T
>
();
auto
*
softmax_rst_data
=
softmax_rst
->
data
<
T
>
();
auto
y_dim
=
grad_y
->
dims
();
auto
batches
=
y_dim
[
0
];
auto
attn_heads
=
y_dim
[
1
];
auto
attn_mul_batch
=
batches
*
attn_heads
;
auto
query_seq_len
=
y_dim
[
2
];
auto
key_seq_len
=
y_dim
[
3
];
auto
&
place
=
*
context
.
template
device_context
<
Place
>().
eigen_device
();
auto
stream
=
context
.
cuda_device_context
().
stream
();
int
pow2_index
=
get_pow2_index_value
(
key_seq_len
);
const
int
next_pow2
=
1
<<
pow2_index
;
int
batch_count
=
attn_mul_batch
*
query_seq_len
;
int
warp_size
=
(
next_pow2
<
WARP_SIZE
)
?
next_pow2
:
WARP_SIZE
;
int
batches_per_warp
=
(
next_pow2
<=
128
)
?
2
:
1
;
// use 128 threads per block to maximum gpu utilization
constexpr
int
threads_per_block
=
128
;
int
warps_per_block
=
(
threads_per_block
/
warp_size
);
int
batches_per_block
=
warps_per_block
*
batches_per_warp
;
dim3
blocks
(
query_seq_len
,
(
attn_mul_batch
+
batches_per_block
)
/
batches_per_block
,
1
);
dim3
threads
(
warp_size
,
warps_per_block
,
1
);
switch
(
pow2_index
)
{
case
5
:
// 32
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
5
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
6
:
// 64
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
6
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
7
:
// 128
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
7
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
8
:
// 256
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
8
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
9
:
// 512
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
9
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
10
:
// 1024
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
10
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
11
:
// 2048
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
11
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
12
:
// 4096
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
12
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
case
13
:
// 8192
SoftmaxMaskFuseUpperTriangleGradGPUKernel
<
T
,
13
><<<
blocks
,
threads
,
0
,
stream
>>>
(
grad_y_data
,
grad_x_data
,
softmax_rst_data
,
batch_count
,
key_seq_len
);
break
;
default:
break
;
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_CUDA_KERNEL
(
softmax_mask_fuse_upper_triangle
,
ops
::
SoftmaxMaskFuseUpperTriangleKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
SoftmaxMaskFuseUpperTriangleKernel
<
plat
::
CUDADeviceContext
,
float
>
);
REGISTER_OP_CUDA_KERNEL
(
softmax_mask_fuse_upper_triangle_grad
,
ops
::
SoftmaxMaskFuseUpperTriangleGradKernel
<
plat
::
CUDADeviceContext
,
plat
::
float16
>
,
ops
::
SoftmaxMaskFuseUpperTriangleGradKernel
<
plat
::
CUDADeviceContext
,
float
>
);
paddle/fluid/operators/softmax_mask_fuse_upper_triangle_op.h
0 → 100644
浏览文件 @
e2e1c57b
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
SoftmaxMaskFuseUpperTriangleCPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
Unimplemented
(
"Softmax mask fuse op only supports GPU now."
));
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/tests/unittests/test_softmax_mask_fuse_upper_triangle_op.py
0 → 100644
浏览文件 @
e2e1c57b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
unittest
import
numpy
as
np
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
import
paddle
import
paddle.fluid
as
fluid
import
paddle.incubate
as
incubate
paddle
.
enable_static
()
def
_get_softmax_upper
(
x
,
fp16
=
True
):
x_lower
=
np
.
tril
(
x
)
masked_x
=
np
.
where
(
x_lower
==
0
,
-
10000.0
,
x_lower
).
astype
(
"float32"
)
max_value
=
np
.
max
(
masked_x
,
axis
=-
1
,
keepdims
=
True
)
before_exp
=
masked_x
-
max_value
exp
=
np
.
exp
(
before_exp
)
exp_sum
=
np
.
sum
(
exp
,
axis
=-
1
,
keepdims
=
True
)
rst
=
exp
/
exp_sum
if
fp16
:
rst
=
rst
.
astype
(
"float16"
)
return
rst
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestSoftmaxMaskFuseOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"softmax_mask_fuse_upper_triangle"
x
=
np
.
random
.
random
((
1
,
1
,
32
,
32
)).
astype
(
"float16"
)
self
.
inputs
=
{
'X'
:
x
}
rst
=
_get_softmax_upper
(
x
)
self
.
outputs
=
{
'Out'
:
rst
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
core
.
CUDAPlace
(
0
))
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
core
.
CUDAPlace
(
0
),
[
"X"
],
"Out"
)
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestSoftmaxMaskFuseOp1
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
"softmax_mask_fuse_upper_triangle"
x
=
np
.
random
.
random
((
1
,
1
,
32
,
32
))
self
.
inputs
=
{
'X'
:
x
}
rst
=
_get_softmax_upper
(
x
)
self
.
outputs
=
{
'Out'
:
rst
}
def
test_check_output
(
self
):
try
:
self
.
check_output_with_place
(
core
.
CPUPlace
())
except
NotImplementedError
:
pass
def
test_check_grad
(
self
):
try
:
self
.
check_grad_with_place
(
core
.
CPUPlace
(),
[
"X"
],
"Out"
)
except
NotImplementedError
:
pass
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
(),
"core is not compiled with CUDA"
)
class
TestDropoutBiasFuseOp2
(
unittest
.
TestCase
):
# test the python side API for softmax_mask_fuse op
def
setUp
(
self
):
np
.
random
.
seed
(
123
)
self
.
dtypes
=
[
'float16'
,
'float32'
]
def
test_static
(
self
):
for
dtype
in
self
.
dtypes
:
with
fluid
.
program_guard
(
fluid
.
Program
(),
fluid
.
Program
()):
input_x
=
fluid
.
data
(
name
=
"x"
,
shape
=
[
1
,
1
,
32
,
32
],
dtype
=
dtype
)
rst
=
incubate
.
softmax_mask_fuse_upper_triangle
(
input_x
)
x_in_np
=
np
.
random
.
random
((
1
,
1
,
32
,
32
)).
astype
(
dtype
)
rst_np
=
_get_softmax_upper
(
x_in_np
,
dtype
==
'float16'
)
exe
=
fluid
.
Executor
(
fluid
.
CUDAPlace
(
0
))
fetches
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
"x"
:
x_in_np
},
fetch_list
=
[
rst
])
self
.
assertTrue
(
np
.
allclose
(
fetches
[
0
],
rst_np
))
def
test_dygraph
(
self
):
for
dtype
in
self
.
dtypes
:
with
fluid
.
dygraph
.
guard
(
fluid
.
CUDAPlace
(
0
)):
x_in_np
=
np
.
random
.
random
((
1
,
1
,
32
,
32
)).
astype
(
dtype
)
rst_np
=
_get_softmax_upper
(
x_in_np
,
dtype
==
'float16'
)
input_x
=
fluid
.
dygraph
.
to_variable
(
x_in_np
)
rst
=
incubate
.
softmax_mask_fuse_upper_triangle
(
input_x
)
self
.
assertTrue
(
np
.
allclose
(
rst
,
rst_np
))
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/incubate/__init__.py
浏览文件 @
e2e1c57b
...
...
@@ -16,7 +16,8 @@ from .optimizer import LookAhead # noqa: F401
from
.optimizer
import
ModelAverage
# noqa: F401
from
.checkpoint
import
auto_checkpoint
# noqa: F401
from
..fluid.layer_helper
import
LayerHelper
# noqa: F401
from
.operators
import
softmax_mask_fuse_upper_triangle
# noqa: F401
__all__
=
[
# noqa
'LookAhead'
,
'ModelAverage'
'LookAhead'
,
'ModelAverage'
,
'softmax_mask_fuse_upper_triangle'
]
python/paddle/incubate/operators/__init__.py
0 → 100644
浏览文件 @
e2e1c57b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
.softmax_mask_fuse_upper_triangle
import
softmax_mask_fuse_upper_triangle
# noqa: F401
python/paddle/incubate/operators/softmax_mask_fuse_upper_triangle.py
0 → 100644
浏览文件 @
e2e1c57b
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
from
paddle.fluid.layer_helper
import
LayerHelper
from
paddle.fluid.framework
import
in_dygraph_mode
from
paddle.fluid
import
core
def
softmax_mask_fuse_upper_triangle
(
x
):
"""
Fuse softmax mask together without even give a mask.
Under GPT model, the mask is always be a upper triangle
so we can simply mask the upper triangle part of x to get the mask result
:param x: the input x (rst of QK)
:return: the result of softmax mask fuse (upper triangle)
"""
if
in_dygraph_mode
():
out
=
core
.
ops
.
softmax_mask_fuse_upper_triangle
(
x
)
return
out
helper
=
LayerHelper
(
'softmax_mask_fuse_upper_triangle'
,
**
locals
())
out
=
helper
.
create_variable_for_type_inference
(
dtype
=
x
.
dtype
)
helper
.
append_op
(
type
=
'softmax_mask_fuse_upper_triangle'
,
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Out'
:
[
out
]})
return
out
python/setup.py.in
浏览文件 @
e2e1c57b
...
...
@@ -146,6 +146,7 @@ packages=['paddle',
'paddle.incubate',
'paddle.incubate.optimizer',
'paddle.incubate.checkpoint',
'paddle.incubate.operators',
'paddle.distributed.fleet',
'paddle.distributed.fleet.base',
'paddle.distributed.fleet.meta_optimizers',
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录