Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
6b587e93
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
6b587e93
编写于
9月 28, 2021
作者:
L
Liu-xiandong
提交者:
GitHub
9月 28, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add sparse_attention api, test=develop (#35676)
Add sparse_attention OPs, python api will be added in next pr
上级
bc7e2b92
变更
6
显示空白变更内容
内联
并排
Showing
6 changed file
with
942 addition
and
2 deletion
+942
-2
cmake/operators.cmake
cmake/operators.cmake
+1
-1
paddle/fluid/operators/CMakeLists.txt
paddle/fluid/operators/CMakeLists.txt
+5
-1
paddle/fluid/operators/sparse_attention_op.cc
paddle/fluid/operators/sparse_attention_op.cc
+193
-0
paddle/fluid/operators/sparse_attention_op.cu
paddle/fluid/operators/sparse_attention_op.cu
+537
-0
python/paddle/fluid/tests/unittests/test_sparse_attention_op.py
.../paddle/fluid/tests/unittests/test_sparse_attention_op.py
+205
-0
python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
...uid/tests/unittests/white_list/op_threshold_white_list.py
+1
-0
未找到文件。
cmake/operators.cmake
浏览文件 @
6b587e93
...
@@ -214,7 +214,7 @@ function(op_library TARGET)
...
@@ -214,7 +214,7 @@ function(op_library TARGET)
foreach
(
manual_pybind_op
"compare_all_op"
"compare_op"
"logical_op"
"bitwise_op"
"nccl_op"
foreach
(
manual_pybind_op
"compare_all_op"
"compare_op"
"logical_op"
"bitwise_op"
"nccl_op"
"tensor_array_read_write_op"
"tensorrt_engine_op"
"conv_fusion_op"
"tensor_array_read_write_op"
"tensorrt_engine_op"
"conv_fusion_op"
"fusion_transpose_flatten_concat_op"
"fusion_conv_inception_op"
"fusion_transpose_flatten_concat_op"
"fusion_conv_inception_op"
"sync_batch_norm_op"
"dgc_op"
"fused_fc_elementwise_layernorm_op"
"sync_batch_norm_op"
"
sparse_attention_op"
"
dgc_op"
"fused_fc_elementwise_layernorm_op"
"skip_layernorm_op"
"multihead_matmul_op"
"fusion_group_op"
"fused_bn_activation_op"
"fused_embedding_eltwise_layernorm_op"
"fusion_gru_op"
"fusion_lstm_op"
"skip_layernorm_op"
"multihead_matmul_op"
"fusion_group_op"
"fused_bn_activation_op"
"fused_embedding_eltwise_layernorm_op"
"fusion_gru_op"
"fusion_lstm_op"
"fused_bn_add_activation_op"
)
"fused_bn_add_activation_op"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
if
(
"
${
TARGET
}
"
STREQUAL
"
${
manual_pybind_op
}
"
)
...
...
paddle/fluid/operators/CMakeLists.txt
浏览文件 @
6b587e93
...
@@ -78,7 +78,7 @@ if(WITH_UNITY_BUILD)
...
@@ -78,7 +78,7 @@ if(WITH_UNITY_BUILD)
include
(
unity_build_rule.cmake
)
include
(
unity_build_rule.cmake
)
endif
()
endif
()
register_operators
(
EXCLUDES py_layer_op py_func_op warpctc_op dgc_op lstm_op run_program_op eye_op recurrent_op
register_operators
(
EXCLUDES py_layer_op py_func_op warpctc_op dgc_op
sparse_attention_op
lstm_op run_program_op eye_op recurrent_op
sync_batch_norm_op spectral_op
${
OP_MKL_DEPS
}
DEPS
${
OP_HEADER_DEPS
}
)
sync_batch_norm_op spectral_op
${
OP_MKL_DEPS
}
DEPS
${
OP_HEADER_DEPS
}
)
op_library
(
run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache
${
OP_HEADER_DEPS
}
)
op_library
(
run_program_op SRCS run_program_op.cc run_program_op.cu.cc DEPS executor_cache
${
OP_HEADER_DEPS
}
)
...
@@ -94,6 +94,10 @@ if (WITH_GPU OR WITH_ROCM)
...
@@ -94,6 +94,10 @@ if (WITH_GPU OR WITH_ROCM)
endif
()
endif
()
op_library
(
sync_batch_norm_op
)
op_library
(
sync_batch_norm_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(sync_batch_norm);
\n
"
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(sync_batch_norm);
\n
"
)
if
((
NOT WIN32
)
AND
(
NOT WITH_ROCM
)
AND
(
NOT
${
CMAKE_CUDA_COMPILER_VERSION
}
VERSION_LESS 11.2
)
)
op_library
(
sparse_attention_op
)
file
(
APPEND
${
pybind_file
}
"USE_CUDA_ONLY_OP(sparse_attention);
\n
"
)
endif
()
else
()
else
()
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
op_library
(
warpctc_op DEPS dynload_warpctc sequence_padding sequence_scale
)
endif
()
endif
()
...
...
paddle/fluid/operators/sparse_attention_op.cc
0 → 100644
浏览文件 @
6b587e93
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
class
SparseAttentionOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Q"
,
"(Tensor), The input tensor of query in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`."
);
AddInput
(
"K"
,
"(Tensor), The input tensor of key in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`."
);
AddInput
(
"V"
,
"(Tensor), The input tensor of value in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`."
);
AddInput
(
"Offset"
,
"(Tensor, default: Tensor<int32>), The input tensor of offset in "
"CSR sparse format, "
"whose dimension : `[batch_size, num_heads, target_len + 1]`."
);
AddInput
(
"Columns"
,
"(Tensor, default: Tensor<int32>), The input tensor of columns in "
"CSR sparse format, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_num]`."
);
AddOutput
(
"Out"
,
"(Tensor), The output tensor of result in attention, "
"whose dimension : `[batch_size, num_heads, target_len, head_dim]`."
);
AddOutput
(
"SparseDotSdd"
,
"(Tensor), The output tensor of result in SparseDotSdd step, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`."
)
.
AsIntermediate
();
AddOutput
(
"Softmax"
,
"(Tensor), The output tensor of result in Softmax step, "
"whose dimension : `[batch_size, num_heads, sparse_nnz_dim]`."
)
.
AsIntermediate
();
AddComment
(
R"DOC(
Compute the value of the sparse attention module. Its input value includes five tensors.
Q, K, and V represent query, key, and value in the Attention module, respectively.
The CSR format is used to represent the sparsity feature in the Attention module.
The CSR format contains two tensors, offset and columns.
)DOC"
);
}
};
class
SparseAttentionOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Q"
),
"Input"
,
"Q"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"K"
),
"Input"
,
"K"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"V"
),
"Input"
,
"V"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Offset"
),
"Input"
,
"Offset"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Columns"
),
"Input"
,
"Columns"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Out"
),
"Output"
,
"Out"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"SparseDotSdd"
),
"Output"
,
"SparseDotSdd"
,
"sparse_attention"
);
OP_INOUT_CHECK
(
ctx
->
HasOutput
(
"Softmax"
),
"Output"
,
"Softmax"
,
"sparse_attention"
);
auto
dims_q
=
ctx
->
GetInputDim
(
"Q"
);
auto
dims_k
=
ctx
->
GetInputDim
(
"K"
);
auto
dims_v
=
ctx
->
GetInputDim
(
"V"
);
auto
dims_columns
=
ctx
->
GetInputDim
(
"Columns"
);
PADDLE_ENFORCE_EQ
(
dims_q
.
size
(),
static_cast
<
size_t
>
(
4
),
platform
::
errors
::
InvalidArgument
(
"Dimension in query' shapes should be 4."
));
PADDLE_ENFORCE_EQ
(
dims_k
.
size
(),
static_cast
<
size_t
>
(
4
),
platform
::
errors
::
InvalidArgument
(
"Dimension in key' shapes should be 4."
));
PADDLE_ENFORCE_EQ
(
dims_v
.
size
(),
static_cast
<
size_t
>
(
4
),
platform
::
errors
::
InvalidArgument
(
"Dimension in value' shapes should be 4."
));
auto
batch_size
=
dims_q
[
0
];
auto
num_heads
=
dims_q
[
1
];
auto
M
=
dims_q
[
2
];
auto
N
=
dims_q
[
3
];
auto
sparse_nnz
=
dims_columns
[
2
];
ctx
->
SetOutputDim
(
"Out"
,
{
batch_size
,
num_heads
,
M
,
N
});
ctx
->
SetOutputDim
(
"SparseDotSdd"
,
{
batch_size
,
num_heads
,
sparse_nnz
});
ctx
->
SetOutputDim
(
"Softmax"
,
{
batch_size
,
num_heads
,
sparse_nnz
});
ctx
->
ShareLoD
(
"Q"
,
"Out"
);
}
protected:
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
input_data_type
=
OperatorWithKernel
::
IndicateOrPromoteVarDataTypes
(
ctx
,
"Q"
,
"K"
);
return
framework
::
OpKernelType
(
input_data_type
,
ctx
.
GetPlace
());
}
};
class
SparseAttentionOpGrad
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Q"
),
"Input"
,
"Q"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"K"
),
"Input"
,
"K"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"V"
),
"Input"
,
"V"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Offset"
),
"Input"
,
"Offset"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Columns"
),
"Input"
,
"Columns"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"SparseDotSdd"
),
"Input"
,
"SparseDotSdd"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
"Softmax"
),
"Input"
,
"Softmax"
,
"sparse_attention_grad"
);
OP_INOUT_CHECK
(
ctx
->
HasInput
(
framework
::
GradVarName
(
"Out"
)),
"Input"
,
"Out@GRAD"
,
"sparse_attention_grad"
);
auto
x_grad_name
=
framework
::
GradVarName
(
"Q"
);
auto
y_grad_name
=
framework
::
GradVarName
(
"K"
);
auto
z_grad_name
=
framework
::
GradVarName
(
"V"
);
if
(
ctx
->
HasOutput
(
x_grad_name
))
{
ctx
->
SetOutputDim
(
x_grad_name
,
ctx
->
GetInputDim
(
"Q"
));
}
if
(
ctx
->
HasOutput
(
y_grad_name
))
{
ctx
->
SetOutputDim
(
y_grad_name
,
ctx
->
GetInputDim
(
"K"
));
}
if
(
ctx
->
HasOutput
(
z_grad_name
))
{
ctx
->
SetOutputDim
(
z_grad_name
,
ctx
->
GetInputDim
(
"V"
));
}
}
framework
::
OpKernelType
GetExpectedKernelType
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
return
framework
::
OpKernelType
(
OperatorWithKernel
::
IndicateVarDataType
(
ctx
,
framework
::
GradVarName
(
"Out"
)),
ctx
.
GetPlace
());
}
};
template
<
typename
T
>
class
SparseAttentionGradOpMaker
:
public
framework
::
SingleGradOpMaker
<
T
>
{
public:
using
framework
::
SingleGradOpMaker
<
T
>::
SingleGradOpMaker
;
protected:
void
Apply
(
GradOpPtr
<
T
>
op
)
const
override
{
op
->
SetType
(
"sparse_attention_grad"
);
op
->
SetInput
(
"Q"
,
this
->
Input
(
"Q"
));
op
->
SetInput
(
"K"
,
this
->
Input
(
"K"
));
op
->
SetInput
(
"V"
,
this
->
Input
(
"V"
));
op
->
SetInput
(
"Offset"
,
this
->
Input
(
"Offset"
));
op
->
SetInput
(
"Columns"
,
this
->
Input
(
"Columns"
));
op
->
SetInput
(
"SparseDotSdd"
,
this
->
Output
(
"SparseDotSdd"
));
op
->
SetInput
(
"Softmax"
,
this
->
Output
(
"Softmax"
));
op
->
SetInput
(
framework
::
GradVarName
(
"Out"
),
this
->
OutputGrad
(
"Out"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"Q"
),
this
->
InputGrad
(
"Q"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"K"
),
this
->
InputGrad
(
"K"
));
op
->
SetOutput
(
framework
::
GradVarName
(
"V"
),
this
->
InputGrad
(
"V"
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OPERATOR
(
sparse_attention
,
ops
::
SparseAttentionOp
,
ops
::
SparseAttentionOpMaker
,
ops
::
SparseAttentionGradOpMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
SparseAttentionGradOpMaker
<
paddle
::
imperative
::
OpBase
>
);
REGISTER_OPERATOR
(
sparse_attention_grad
,
ops
::
SparseAttentionOpGrad
);
paddle/fluid/operators/sparse_attention_op.cu
0 → 100644
浏览文件 @
6b587e93
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <math.h>
#include <limits>
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#if defined(PADDLE_WITH_CUDA)
#include "paddle/fluid/platform/dynload/cusparse.h"
#endif
namespace
ops
=
paddle
::
operators
;
namespace
plf
=
paddle
::
platform
;
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
__forceinline__
__device__
T
CudaShuffleXorSync
(
unsigned
mask
,
T
val
,
int
width
=
warpSize
)
{
return
__shfl_xor_sync
(
mask
,
val
,
width
);
}
template
<
typename
T
,
int
batch_size
,
int
warp_size
>
__device__
__forceinline__
void
WarpReduceSum
(
T
*
sum
)
{
#pragma unroll
for
(
int
offset
=
warp_size
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
T
sum_val
=
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
sum
[
i
]
+
sum_val
;
}
}
}
template
<
typename
T
,
int
batch_size
,
int
warp_size
>
__device__
__forceinline__
void
WarpReduceMax
(
T
*
sum
)
{
#pragma unroll
for
(
int
offset
=
warp_size
/
2
;
offset
>
0
;
offset
/=
2
)
{
#pragma unroll
for
(
int
i
=
0
;
i
<
batch_size
;
++
i
)
{
T
max_val
=
CudaShuffleXorSync
(
0xFFFFFFFF
,
sum
[
i
],
offset
);
sum
[
i
]
=
max
(
sum
[
i
],
max_val
);
}
}
}
template
<
typename
T
,
int
BlockSize
,
int
BlockNnzMax
>
__global__
void
BlockSparseSoftmaxForward
(
T
*
softmax
,
const
T
*
src
,
T
scale
,
const
T
*
kp_mask
,
const
T
*
attn_mask
,
const
int
*
layout_rowptr
,
const
int
*
layout_colindex
,
int
num_rows
)
{
// current thread related info
const
int
WarpSize
=
32
;
const
int
cur_row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
cur_row
<
num_rows
)
{
const
int
cur_block_row
=
cur_row
/
BlockSize
;
const
int
cur_block_nnz
=
layout_rowptr
[
cur_block_row
+
1
]
-
layout_rowptr
[
cur_block_row
];
T
srcdata
[(
BlockSize
*
BlockNnzMax
+
WarpSize
-
1
)
/
WarpSize
];
T
attndata
[(
BlockSize
*
BlockNnzMax
+
WarpSize
-
1
)
/
WarpSize
];
// read kp mask
T
cur_kp_mask
=
(
kp_mask
==
nullptr
)
?
0
:
kp_mask
[
cur_row
];
// read tensor data, attn mask
const
int
iter
=
(
cur_block_nnz
+
WarpSize
-
1
)
/
WarpSize
;
const
T
*
srcptr
=
src
+
layout_rowptr
[
cur_block_row
];
T
*
attnptr
=
nullptr
;
if
(
attn_mask
!=
nullptr
)
{
const
T
*
attnptr
=
attn_mask
+
cur_block_row
*
num_rows
;
}
const
int
*
colindex
=
layout_colindex
+
layout_rowptr
[
cur_block_row
];
for
(
int
j
=
0
;
j
<
iter
;
j
++
)
{
int
cur_block_col
=
j
*
WarpSize
+
threadIdx
.
x
;
int
cur_reg_index
=
j
;
if
(
cur_block_col
<
cur_block_nnz
)
{
if
((
attnptr
!=
nullptr
)
&&
std
::
abs
(
attnptr
[
colindex
[
cur_block_col
]])
<
std
::
numeric_limits
<
T
>::
epsilon
())
{
srcdata
[
cur_reg_index
]
=
-
std
::
numeric_limits
<
T
>::
infinity
()
*
scale
+
cur_kp_mask
;
}
else
{
srcdata
[
cur_reg_index
]
=
scale
*
srcptr
[
cur_block_col
]
+
cur_kp_mask
;
}
}
else
{
srcdata
[
cur_reg_index
]
=
-
std
::
numeric_limits
<
T
>::
infinity
();
}
}
// max value
T
max_value
=
srcdata
[
0
];
const
int
kIteration
=
(
cur_block_nnz
*
BlockSize
+
WarpSize
-
1
)
/
WarpSize
;
#pragma unroll
for
(
int
it
=
1
;
it
<
kIteration
;
++
it
)
{
max_value
=
(
max_value
>
srcdata
[
it
])
?
max_value
:
srcdata
[
it
];
}
WarpReduceMax
<
T
,
1
,
WarpSize
>
(
&
max_value
);
// exp sum
T
sum
=
0
;
#pragma unroll
for
(
int
it
=
0
;
it
<
kIteration
;
++
it
)
{
srcdata
[
it
]
=
std
::
exp
(
srcdata
[
it
]
-
max_value
);
sum
+=
srcdata
[
it
];
}
WarpReduceSum
<
T
,
1
,
WarpSize
>
(
&
sum
);
// compute softmax and write out
T
*
softmaxptr
=
softmax
+
layout_rowptr
[
cur_block_row
];
for
(
int
j
=
0
;
j
<
iter
;
j
++
)
{
int
cur_block_col
=
j
*
WarpSize
+
threadIdx
.
x
;
int
cur_reg_index
=
j
;
if
(
cur_block_col
<
cur_block_nnz
)
{
softmaxptr
[
cur_block_col
]
=
srcdata
[
cur_reg_index
]
/
sum
;
}
}
}
}
template
<
typename
T
,
int
BlockSize
,
int
BlockNnzMax
>
__global__
void
BlockSparseSoftmaxBackward
(
T
*
dst
,
const
T
*
grad
,
const
T
*
src
,
T
scale
,
const
int
*
layout_rowptr
,
const
int
*
layout_colindex
,
int
num_rows
)
{
// current thread related info
const
int
WarpSize
=
32
;
const
int
cur_row
=
blockIdx
.
x
*
blockDim
.
y
+
threadIdx
.
y
;
if
(
cur_row
<
num_rows
)
{
const
int
cur_block_row
=
cur_row
/
BlockSize
;
const
int
cur_block_nnz
=
layout_rowptr
[
cur_block_row
+
1
]
-
layout_rowptr
[
cur_block_row
];
T
srcdata
[(
BlockSize
*
BlockNnzMax
+
WarpSize
-
1
)
/
WarpSize
];
T
graddata
[(
BlockSize
*
BlockNnzMax
+
WarpSize
-
1
)
/
WarpSize
];
// read tensor data, attn mask
const
int
iter
=
(
cur_block_nnz
+
WarpSize
-
1
)
/
WarpSize
;
const
T
*
srcptr
=
src
+
layout_rowptr
[
cur_block_row
];
const
T
*
gradptr
=
grad
+
layout_rowptr
[
cur_block_row
];
for
(
int
j
=
0
;
j
<
iter
;
j
++
)
{
int
cur_block_col
=
j
*
WarpSize
+
threadIdx
.
x
;
int
cur_reg_index
=
j
;
if
(
cur_block_col
<
cur_block_nnz
)
{
srcdata
[
cur_reg_index
]
=
srcptr
[
cur_block_col
];
graddata
[
cur_reg_index
]
=
gradptr
[
cur_block_col
];
}
else
{
srcdata
[
cur_reg_index
]
=
0
;
graddata
[
cur_reg_index
]
=
0
;
}
}
T
sum
=
0
;
const
int
kIteration
=
(
cur_block_nnz
*
BlockSize
+
WarpSize
-
1
)
/
WarpSize
;
#pragma unroll
for
(
int
it
=
0
;
it
<
kIteration
;
++
it
)
{
sum
+=
srcdata
[
it
]
*
graddata
[
it
];
}
WarpReduceSum
<
T
,
1
,
WarpSize
>
(
&
sum
);
// compute softmax and write out
T
*
dstptr
=
dst
+
layout_rowptr
[
cur_block_row
];
for
(
int
j
=
0
;
j
<
iter
;
j
++
)
{
int
cur_block_col
=
j
*
WarpSize
+
threadIdx
.
x
;
int
cur_reg_index
=
j
;
if
(
cur_block_col
<
cur_block_nnz
)
{
dstptr
[
cur_block_col
]
=
scale
*
srcdata
[
cur_reg_index
]
*
(
graddata
[
cur_reg_index
]
-
sum
);
}
}
}
}
using
Tensor
=
framework
::
Tensor
;
/*
input: sparse C in CSR format (num_rows,num_rows)
output: sparse C after softmax operation
*/
template
<
typename
DeviceContext
,
typename
T
>
void
SparseSoftmaxForward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
*
offset
,
const
Tensor
*
columns
,
Tensor
*
input
,
Tensor
*
output
,
const
int
blocksize
,
const
int
num_rows
,
const
int
num_cols
)
{
const
int
*
offset_data
=
offset
->
data
<
int
>
();
const
int
*
columns_data
=
columns
->
data
<
int
>
();
T
*
input_data
=
input
->
data
<
T
>
();
T
*
output_data
=
output
->
data
<
T
>
();
const
int
block_size
=
1
;
dim3
blocks
(
32
,
4
,
1
);
int
grid
=
(
num_rows
*
block_size
+
3
)
/
4
;
T
scaling
=
static_cast
<
T
>
(
1.0
)
/
sqrt
(
static_cast
<
T
>
(
num_cols
));
const
int
block_nnz_max
=
256
;
BlockSparseSoftmaxForward
<
T
,
block_size
,
block_nnz_max
><<<
grid
,
blocks
>>>
(
output_data
,
input_data
,
scaling
,
nullptr
,
nullptr
,
offset_data
,
columns_data
,
num_rows
);
}
template
<
typename
DeviceContext
,
typename
T
>
void
SparseSoftmaxBackward
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
*
offset
,
const
Tensor
*
columns
,
Tensor
*
dx
,
const
Tensor
*
dout
,
const
Tensor
*
out
,
const
int
blocksize
,
const
int
num_rows
,
const
int
num_cols
)
{
const
int
*
offset_data
=
offset
->
data
<
int
>
();
const
int
*
columns_data
=
columns
->
data
<
int
>
();
T
*
dx_data
=
dx
->
data
<
T
>
();
const
T
*
dout_data
=
dout
->
data
<
T
>
();
const
T
*
out_data
=
out
->
data
<
T
>
();
const
int
block_size
=
1
;
dim3
blocks
(
32
,
4
,
1
);
int
grid
=
(
num_rows
*
block_size
+
3
)
/
4
;
T
scaling
=
static_cast
<
T
>
(
1.0
)
/
sqrt
(
static_cast
<
T
>
(
num_cols
));
const
int
block_nnz_max
=
256
;
BlockSparseSoftmaxBackward
<
T
,
block_size
,
block_nnz_max
><<<
grid
,
blocks
>>>
(
dx_data
,
dout_data
,
out_data
,
scaling
,
offset_data
,
columns_data
,
num_rows
);
}
using
VarType
=
framework
::
proto
::
VarType
;
inline
cudaDataType_t
GetGpuType
(
const
VarType
::
Type
data_type
)
{
if
(
data_type
==
VarType
::
FP32
)
{
return
CUDA_R_32F
;
}
else
if
(
data_type
==
VarType
::
FP64
)
{
return
CUDA_R_64F
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Not support tensor type in sparse_attention OP: %s"
,
framework
::
DataTypeToString
(
data_type
)));
}
}
inline
cusparseOperation_t
GetTransposeOperation
(
const
bool
transpose
)
{
if
(
transpose
)
{
return
CUSPARSE_OPERATION_TRANSPOSE
;
}
else
{
return
CUSPARSE_OPERATION_NON_TRANSPOSE
;
}
}
void
CusparseDestroy
(
cusparseDnMatDescr_t
*
dn_mat_first
,
cusparseDnMatDescr_t
*
dn_mat_second
,
cusparseSpMatDescr_t
*
sp_mat
)
{
platform
::
dynload
::
cusparseDestroyDnMat
(
*
dn_mat_first
);
platform
::
dynload
::
cusparseDestroyDnMat
(
*
dn_mat_second
);
platform
::
dynload
::
cusparseDestroySpMat
(
*
sp_mat
);
}
/*
input: dense A (num_rows,num_cols), dense B (num_rows,num_cols)
output: sparse C in CSR format (num_rows,num_rows)
*/
template
<
typename
DeviceContext
,
typename
T
>
void
DotSdd
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
*
a
,
const
Tensor
*
b
,
const
Tensor
*
c_offset
,
const
Tensor
*
c_columns
,
Tensor
*
c_value
,
const
int
num_rows
,
const
int
num_cols
,
const
bool
a_transpose
,
const
bool
b_transpose
)
{
const
T
*
a_data
=
a
->
data
<
T
>
();
const
T
*
b_data
=
b
->
data
<
T
>
();
const
int
*
c_offset_data
=
c_offset
->
data
<
int
>
();
const
int
*
c_columns_data
=
c_columns
->
data
<
int
>
();
T
*
c_value_data
=
c_value
->
data
<
T
>
();
cudaDataType_t
gpu_type
=
GetGpuType
(
c_value
->
type
());
cusparseHandle_t
handle
=
nullptr
;
cusparseDnMatDescr_t
mat_a
,
mat_b
;
cusparseSpMatDescr_t
mat_c
;
platform
::
dynload
::
cusparseCreate
(
&
handle
);
// Create dense matrix A
platform
::
dynload
::
cusparseCreateDnMat
(
&
mat_a
,
num_rows
,
num_cols
,
num_cols
,
const_cast
<
T
*>
(
a_data
),
gpu_type
,
CUSPARSE_ORDER_ROW
);
// Create dense matrix B
platform
::
dynload
::
cusparseCreateDnMat
(
&
mat_b
,
num_rows
,
num_cols
,
num_cols
,
const_cast
<
T
*>
(
b_data
),
gpu_type
,
CUSPARSE_ORDER_ROW
);
// Create sparse matrix C in CSR format
int
c_nnz
=
c_columns
->
dims
()[
1
];
platform
::
dynload
::
cusparseCreateCsr
(
&
mat_c
,
num_rows
,
num_rows
,
c_nnz
,
const_cast
<
int
*>
(
c_offset_data
),
const_cast
<
int
*>
(
c_columns_data
),
c_value_data
,
CUSPARSE_INDEX_32I
,
CUSPARSE_INDEX_32I
,
CUSPARSE_INDEX_BASE_ZERO
,
gpu_type
);
T
alpha
=
1
;
T
beta
=
0
;
size_t
buffer_size
=
0
;
platform
::
dynload
::
cusparseSDDMM_bufferSize
(
handle
,
GetTransposeOperation
(
a_transpose
),
GetTransposeOperation
(
b_transpose
),
&
alpha
,
mat_a
,
mat_b
,
&
beta
,
mat_c
,
gpu_type
,
CUSPARSE_SDDMM_ALG_DEFAULT
,
&
buffer_size
);
auto
d_buffer_ptr
=
paddle
::
memory
::
Alloc
(
ctx
,
buffer_size
);
void
*
d_buffer
=
static_cast
<
void
*>
(
d_buffer_ptr
->
ptr
());
platform
::
dynload
::
cusparseSDDMM
(
handle
,
GetTransposeOperation
(
a_transpose
),
GetTransposeOperation
(
b_transpose
),
&
alpha
,
mat_a
,
mat_b
,
&
beta
,
mat_c
,
gpu_type
,
CUSPARSE_SDDMM_ALG_DEFAULT
,
d_buffer
);
CusparseDestroy
(
&
mat_a
,
&
mat_b
,
&
mat_c
);
platform
::
dynload
::
cusparseDestroy
(
handle
);
}
/*
input: sparse A in CSR format (num_rows,num_rows), dense B (num_rows,num_cols)
output: dense C (num_rows,num_cols)
*/
template
<
typename
DeviceContext
,
typename
T
>
void
DotDsd
(
const
platform
::
CUDADeviceContext
&
ctx
,
const
Tensor
*
a_offset
,
const
Tensor
*
a_columns
,
const
Tensor
*
a_value
,
const
Tensor
*
b
,
Tensor
*
c
,
const
int
num_rows
,
const
int
num_cols
,
const
bool
a_transpose
,
const
bool
b_transpose
)
{
const
int
*
a_offset_data
=
a_offset
->
data
<
int
>
();
const
int
*
a_columns_data
=
a_columns
->
data
<
int
>
();
const
T
*
a_value_data
=
a_value
->
data
<
T
>
();
const
T
*
b_data
=
b
->
data
<
T
>
();
T
*
c_data
=
c
->
data
<
T
>
();
cudaDataType_t
gpu_type
=
GetGpuType
(
c
->
type
());
cusparseHandle_t
handle
=
nullptr
;
cusparseSpMatDescr_t
mat_a
;
cusparseDnMatDescr_t
mat_b
,
mat_c
;
platform
::
dynload
::
cusparseCreate
(
&
handle
);
// Create sparse matrix A in CSR format
int
a_nnz
=
a_columns
->
dims
()[
1
];
platform
::
dynload
::
cusparseCreateCsr
(
&
mat_a
,
num_rows
,
num_rows
,
a_nnz
,
const_cast
<
int
*>
(
a_offset_data
),
const_cast
<
int
*>
(
a_columns_data
),
const_cast
<
T
*>
(
a_value_data
),
CUSPARSE_INDEX_32I
,
CUSPARSE_INDEX_32I
,
CUSPARSE_INDEX_BASE_ZERO
,
gpu_type
);
// Create dense matrix B
platform
::
dynload
::
cusparseCreateDnMat
(
&
mat_b
,
num_rows
,
num_cols
,
num_cols
,
const_cast
<
T
*>
(
b_data
),
gpu_type
,
CUSPARSE_ORDER_ROW
);
// Create dense matrix C
platform
::
dynload
::
cusparseCreateDnMat
(
&
mat_c
,
num_rows
,
num_cols
,
num_cols
,
c_data
,
gpu_type
,
CUSPARSE_ORDER_ROW
);
T
alpha
=
1
;
T
beta
=
0
;
size_t
buffer_size
=
0
;
// allocate an external buffer if needed
platform
::
dynload
::
cusparseSpMM_bufferSize
(
handle
,
GetTransposeOperation
(
a_transpose
),
GetTransposeOperation
(
b_transpose
),
&
alpha
,
mat_a
,
mat_b
,
&
beta
,
mat_c
,
gpu_type
,
CUSPARSE_SPMM_ALG_DEFAULT
,
&
buffer_size
);
auto
d_buffer_ptr
=
paddle
::
memory
::
Alloc
(
ctx
,
buffer_size
);
void
*
d_buffer
=
static_cast
<
void
*>
(
d_buffer_ptr
->
ptr
());
platform
::
dynload
::
cusparseSpMM
(
handle
,
GetTransposeOperation
(
a_transpose
),
GetTransposeOperation
(
b_transpose
),
&
alpha
,
mat_a
,
mat_b
,
&
beta
,
mat_c
,
gpu_type
,
CUSPARSE_SPMM_ALG_DEFAULT
,
d_buffer
);
CusparseDestroy
(
&
mat_b
,
&
mat_c
,
&
mat_a
);
platform
::
dynload
::
cusparseDestroy
(
handle
);
}
std
::
vector
<
Tensor
>
GetSplitTensor
(
Tensor
*
input
)
{
auto
dims
=
input
->
dims
();
int
batch_size
=
dims
[
0
];
int
num_heads
=
dims
[
1
];
std
::
vector
<
int
>
new_dims
(
dims
.
size
()
-
1
);
new_dims
[
0
]
=
batch_size
*
num_heads
;
for
(
int
i
=
1
;
i
<
new_dims
.
size
();
i
++
)
{
new_dims
[
i
]
=
dims
[
i
+
1
];
}
input
->
Resize
(
framework
::
make_ddim
(
new_dims
));
return
input
->
Split
(
1
,
0
);
}
template
<
typename
DeviceContext
,
typename
T
>
class
SparseAttentionCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
query
=
*
ctx
.
Input
<
Tensor
>
(
"Q"
);
auto
key
=
*
ctx
.
Input
<
Tensor
>
(
"K"
);
auto
value
=
*
ctx
.
Input
<
Tensor
>
(
"V"
);
auto
offset
=
*
ctx
.
Input
<
Tensor
>
(
"Offset"
);
auto
columns
=
*
ctx
.
Input
<
Tensor
>
(
"Columns"
);
auto
output_ptr
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
output_ptr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
sparse_dot_sdd_ptr
=
ctx
.
Output
<
Tensor
>
(
"SparseDotSdd"
);
sparse_dot_sdd_ptr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
softmax_ptr
=
ctx
.
Output
<
Tensor
>
(
"Softmax"
);
softmax_ptr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
output
=
*
output_ptr
;
auto
result_sdd
=
*
sparse_dot_sdd_ptr
;
auto
result_softmax
=
*
softmax_ptr
;
auto
query_dims
=
query
.
dims
();
int
batch_size
=
query_dims
[
0
];
int
num_heads
=
query_dims
[
1
];
int
M
=
query_dims
[
2
];
int
N
=
query_dims
[
3
];
std
::
vector
<
Tensor
>
query_lists
=
GetSplitTensor
(
&
query
);
std
::
vector
<
Tensor
>
key_lists
=
GetSplitTensor
(
&
key
);
std
::
vector
<
Tensor
>
value_lists
=
GetSplitTensor
(
&
value
);
std
::
vector
<
Tensor
>
offset_lists
=
GetSplitTensor
(
&
offset
);
std
::
vector
<
Tensor
>
columns_lists
=
GetSplitTensor
(
&
columns
);
std
::
vector
<
Tensor
>
result_sdd_lists
=
GetSplitTensor
(
&
result_sdd
);
std
::
vector
<
Tensor
>
result_softmax_lists
=
GetSplitTensor
(
&
result_softmax
);
std
::
vector
<
Tensor
>
output_lists
=
GetSplitTensor
(
&
output
);
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
const
int
iter_num
=
batch_size
*
num_heads
;
for
(
int
i
=
0
;
i
<
iter_num
;
i
++
)
{
DotSdd
<
DeviceContext
,
T
>
(
dev_ctx
,
&
query_lists
[
i
],
&
key_lists
[
i
],
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
result_sdd_lists
[
i
],
M
,
N
,
false
,
true
);
SparseSoftmaxForward
<
DeviceContext
,
T
>
(
dev_ctx
,
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
result_sdd_lists
[
i
],
&
result_softmax_lists
[
i
],
1
,
M
,
N
);
DotDsd
<
DeviceContext
,
T
>
(
dev_ctx
,
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
result_softmax_lists
[
i
],
&
value_lists
[
i
],
&
output_lists
[
i
],
M
,
N
,
false
,
false
);
}
}
};
template
<
typename
DeviceContext
,
typename
T
>
class
SparseAttentionGradCUDAKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
query
=
*
ctx
.
Input
<
Tensor
>
(
"Q"
);
auto
key
=
*
ctx
.
Input
<
Tensor
>
(
"K"
);
auto
value
=
*
ctx
.
Input
<
Tensor
>
(
"V"
);
auto
offset
=
*
ctx
.
Input
<
Tensor
>
(
"Offset"
);
auto
columns
=
*
ctx
.
Input
<
Tensor
>
(
"Columns"
);
auto
sparse_dot_sdd
=
*
ctx
.
Input
<
Tensor
>
(
"SparseDotSdd"
);
auto
softmax
=
*
ctx
.
Input
<
Tensor
>
(
"Softmax"
);
auto
dout
=
*
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
dquery_ptr
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"Q"
));
auto
*
dkey_ptr
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"K"
));
auto
*
dvalue_ptr
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"V"
));
dquery_ptr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dkey_ptr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
dvalue_ptr
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
dquery
=
*
dquery_ptr
;
auto
dkey
=
*
dkey_ptr
;
auto
dvalue
=
*
dvalue_ptr
;
auto
query_dims
=
query
.
dims
();
int
batch_size
=
query_dims
[
0
];
int
num_heads
=
query_dims
[
1
];
int
M
=
query_dims
[
2
];
int
N
=
query_dims
[
3
];
std
::
vector
<
Tensor
>
query_lists
=
GetSplitTensor
(
&
query
);
std
::
vector
<
Tensor
>
key_lists
=
GetSplitTensor
(
&
key
);
std
::
vector
<
Tensor
>
value_lists
=
GetSplitTensor
(
&
value
);
std
::
vector
<
Tensor
>
offset_lists
=
GetSplitTensor
(
&
offset
);
std
::
vector
<
Tensor
>
columns_lists
=
GetSplitTensor
(
&
columns
);
std
::
vector
<
Tensor
>
sparse_dot_sdd_lists
=
GetSplitTensor
(
&
sparse_dot_sdd
);
std
::
vector
<
Tensor
>
softmax_lists
=
GetSplitTensor
(
&
softmax
);
std
::
vector
<
Tensor
>
dout_lists
=
GetSplitTensor
(
&
dout
);
std
::
vector
<
Tensor
>
dquery_lists
=
GetSplitTensor
(
&
dquery
);
std
::
vector
<
Tensor
>
dkey_lists
=
GetSplitTensor
(
&
dkey
);
std
::
vector
<
Tensor
>
dvalue_lists
=
GetSplitTensor
(
&
dvalue
);
const
int
iter_num
=
batch_size
*
num_heads
;
const
auto
&
dev_ctx
=
ctx
.
cuda_device_context
();
for
(
int
i
=
0
;
i
<
iter_num
;
i
++
)
{
// dValue = transpose(result_softmax) * dOut
DotDsd
<
DeviceContext
,
T
>
(
dev_ctx
,
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
softmax_lists
[
i
],
&
dout_lists
[
i
],
&
dvalue_lists
[
i
],
M
,
N
,
true
,
false
);
// dSoftmax = dOut * transpose(Value)
int
nnz_num
=
columns
.
dims
()[
0
];
Tensor
dsoftmax
;
dsoftmax
.
Resize
({
nnz_num
});
dsoftmax
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
DotSdd
<
DeviceContext
,
T
>
(
dev_ctx
,
&
dout_lists
[
i
],
&
value_lists
[
i
],
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
dsoftmax
,
M
,
N
,
false
,
true
);
// dSparseDotSdd = dSoftmax * softmax'(SparseDotSdd)
Tensor
dsparse_dot_sdd
;
dsparse_dot_sdd
.
Resize
({
nnz_num
});
dsparse_dot_sdd
.
mutable_data
<
T
>
(
ctx
.
GetPlace
());
SparseSoftmaxBackward
<
DeviceContext
,
T
>
(
dev_ctx
,
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
dsparse_dot_sdd
,
&
dsoftmax
,
&
softmax_lists
[
i
],
1
,
M
,
N
);
// dQuery = dSparseDotSdd * Key
DotDsd
<
DeviceContext
,
T
>
(
dev_ctx
,
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
dsparse_dot_sdd
,
&
key_lists
[
i
],
&
dquery_lists
[
i
],
M
,
N
,
false
,
false
);
// dKey = transpose(dSparseDotSdd) * Query
DotDsd
<
DeviceContext
,
T
>
(
dev_ctx
,
&
offset_lists
[
i
],
&
columns_lists
[
i
],
&
dsparse_dot_sdd
,
&
query_lists
[
i
],
&
dkey_lists
[
i
],
M
,
N
,
true
,
false
);
}
}
};
}
// namespace operators
}
// namespace paddle
REGISTER_OP_CUDA_KERNEL
(
sparse_attention
,
ops
::
SparseAttentionCUDAKernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
SparseAttentionCUDAKernel
<
plf
::
CUDADeviceContext
,
double
>
);
REGISTER_OP_CUDA_KERNEL
(
sparse_attention_grad
,
ops
::
SparseAttentionGradCUDAKernel
<
plf
::
CUDADeviceContext
,
float
>
,
ops
::
SparseAttentionGradCUDAKernel
<
plf
::
CUDADeviceContext
,
double
>
);
python/paddle/fluid/tests/unittests/test_sparse_attention_op.py
0 → 100644
浏览文件 @
6b587e93
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
import
numpy
as
np
from
op_test
import
OpTest
import
paddle.fluid.core
as
core
import
paddle
import
os
import
re
import
platform
def
get_cuda_version
():
result
=
os
.
popen
(
"nvcc --version"
).
read
()
regex
=
r
'release (\S+),'
match
=
re
.
search
(
regex
,
result
)
if
match
:
num
=
str
(
match
.
group
(
1
))
integer
,
decimal
=
num
.
split
(
'.'
)
return
int
(
integer
)
*
1000
+
int
(
float
(
decimal
)
*
10
)
else
:
return
-
1
def
get_linux_platform
():
if
platform
.
system
().
lower
()
==
'windows'
:
return
0
elif
platform
.
system
().
lower
()
==
'linux'
:
return
1
else
:
return
-
1
def
get_suitable_env
():
if
get_cuda_version
()
>=
11020
and
get_linux_platform
()
==
1
:
return
True
else
:
return
False
def
softmax
(
x
):
max
=
np
.
max
(
x
,
axis
=
1
,
keepdims
=
True
)
e_x
=
np
.
exp
(
x
-
max
)
sum
=
np
.
sum
(
e_x
,
axis
=
1
,
keepdims
=
True
)
f_x
=
e_x
/
sum
return
f_x
def
get_csr_value
(
mat
,
layout
,
nnz
):
row
,
col
=
mat
.
shape
[
0
],
mat
.
shape
[
1
]
value
=
np
.
zeros
(
nnz
)
ptr
=
0
for
i
in
range
(
row
):
for
j
in
range
(
col
):
if
layout
[
i
][
j
]
==
1
:
value
[
ptr
]
=
mat
[
i
][
j
]
ptr
+=
1
return
value
def
ref_sparse_attention
(
q
,
k
,
v
,
offset
,
columns
):
row
,
col
,
nnz
=
q
.
shape
[
0
],
q
.
shape
[
1
],
columns
.
shape
[
0
]
mat
=
np
.
zeros
((
row
,
row
))
for
cur_row
in
range
(
row
):
start_ptr
=
int
(
offset
[
cur_row
])
end_ptr
=
int
(
offset
[
cur_row
+
1
])
for
ptr
in
range
(
start_ptr
,
end_ptr
):
cur_col
=
int
(
columns
[
ptr
])
mat
[
cur_row
][
cur_col
]
=
1
a
=
np
.
dot
(
q
,
k
.
T
)
*
mat
a_value
=
get_csr_value
(
a
,
mat
,
nnz
)
scaling
=
float
(
col
)
**-
0.5
a
=
scaling
*
a
for
i
in
range
(
row
):
for
j
in
range
(
row
):
if
mat
[
i
][
j
]
==
0
:
a
[
i
][
j
]
=
float
(
'-inf'
)
b
=
softmax
(
a
)
b_value
=
get_csr_value
(
b
,
mat
,
nnz
)
result
=
np
.
dot
(
b
,
v
)
return
result
,
a_value
,
b_value
def
ref_batch_sparse_attention
(
q
,
k
,
v
,
offset
,
columns
):
batch_size
,
num_heads
,
row
,
col
=
q
.
shape
nnz
=
columns
.
shape
[
2
]
result
=
np
.
zeros
((
batch_size
,
num_heads
,
row
,
col
))
result_sdd
=
np
.
zeros
((
batch_size
,
num_heads
,
nnz
))
result_softmax
=
np
.
zeros
((
batch_size
,
num_heads
,
nnz
))
for
i
in
range
(
batch_size
):
for
j
in
range
(
num_heads
):
cur_q
,
cur_k
,
cur_v
,
=
q
[
i
][
j
],
k
[
i
][
j
],
v
[
i
][
j
]
cur_offset
,
cur_columns
=
offset
[
i
][
j
],
columns
[
i
][
j
]
cur_result
,
cur_sdd
,
cur_softmax
=
ref_sparse_attention
(
cur_q
,
cur_k
,
cur_v
,
cur_offset
,
cur_columns
)
result
[
i
][
j
]
=
cur_result
result_sdd
[
i
][
j
],
result_softmax
[
i
][
j
]
=
cur_sdd
,
cur_softmax
return
result
,
result_sdd
,
result_softmax
def
init_csr_format
(
batch_size
,
num_heads
,
rows
,
blocksize
):
block_num
,
block_last
=
rows
/
blocksize
,
rows
%
blocksize
nnz_num
=
block_num
*
blocksize
*
blocksize
+
block_last
*
block_last
offset
=
np
.
zeros
(
rows
+
1
)
columns
=
np
.
zeros
(
int
(
nnz_num
))
mat
=
np
.
zeros
((
rows
,
rows
))
for
i
in
range
(
0
,
rows
,
blocksize
):
for
x
in
range
(
blocksize
):
for
y
in
range
(
blocksize
):
p_x
,
p_y
=
i
+
x
,
i
+
y
if
(
p_x
<
rows
)
and
(
p_y
<
rows
):
mat
[
p_x
][
p_y
]
=
1
p_offset
,
p_column
,
count
=
0
,
0
,
0
for
i
in
range
(
rows
):
for
j
in
range
(
rows
):
if
mat
[
i
][
j
]
!=
0
:
count
+=
1
columns
[
p_column
]
=
j
p_column
+=
1
p_offset
+=
1
offset
[
p_offset
]
=
count
offset
=
np
.
expand_dims
(
np
.
expand_dims
(
offset
,
0
),
0
)
offset
=
offset
.
repeat
(
num_heads
,
axis
=
1
)
offset
=
offset
.
repeat
(
batch_size
,
axis
=
0
)
columns
=
np
.
expand_dims
(
np
.
expand_dims
(
columns
,
0
),
0
)
columns
=
columns
.
repeat
(
num_heads
,
axis
=
1
)
columns
=
columns
.
repeat
(
batch_size
,
axis
=
0
)
return
offset
,
columns
@
unittest
.
skipIf
(
not
core
.
is_compiled_with_cuda
()
or
get_suitable_env
()
==
False
,
"core is not compiled with CUDA and cuda version need >= 11.2 in windows"
)
class
TestSparseAttentionOp
(
OpTest
):
def
config
(
self
):
self
.
shape
=
(
1
,
1
,
16
,
8
)
self
.
blocksize
=
2
self
.
dtype
=
"float64"
def
setUp
(
self
):
paddle
.
enable_static
()
self
.
config
()
self
.
op_type
=
"sparse_attention"
self
.
place
=
paddle
.
CUDAPlace
(
0
)
self
.
q
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
self
.
k
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
self
.
v
=
np
.
random
.
random
(
self
.
shape
).
astype
(
self
.
dtype
)
offset
,
columns
=
init_csr_format
(
self
.
shape
[
0
],
self
.
shape
[
1
],
self
.
shape
[
2
],
self
.
blocksize
)
self
.
offset
=
offset
.
astype
(
'int32'
)
self
.
columns
=
columns
.
astype
(
'int32'
)
result
,
result_sdd
,
result_softmax
=
ref_batch_sparse_attention
(
self
.
q
,
self
.
k
,
self
.
v
,
self
.
offset
,
self
.
columns
)
self
.
inputs
=
{
'Q'
:
self
.
q
,
'K'
:
self
.
k
,
'V'
:
self
.
v
,
'offset'
:
self
.
offset
,
'columns'
:
self
.
columns
}
self
.
outputs
=
{
'Out'
:
result
.
astype
(
self
.
dtype
),
'ResultSdd'
:
result_sdd
.
astype
(
self
.
dtype
),
'ResultSoftmax'
:
result_softmax
.
astype
(
self
.
dtype
)
}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
)
def
test_check_grad
(
self
):
self
.
check_grad_with_place
(
self
.
place
,
[
'Q'
],
'Out'
)
self
.
check_grad_with_place
(
self
.
place
,
[
'K'
],
'Out'
)
self
.
check_grad_with_place
(
self
.
place
,
[
'V'
],
'Out'
)
class
TestSparseAttentionOpFp32Test
(
TestSparseAttentionOp
):
def
config
(
self
):
self
.
shape
=
(
1
,
1
,
8
,
16
)
self
.
blocksize
=
2
self
.
dtype
=
"float32"
class
TestSparseAttentionOpShapeTest
(
TestSparseAttentionOp
):
def
config
(
self
):
self
.
shape
=
(
2
,
2
,
32
,
8
)
self
.
blocksize
=
8
self
.
dtype
=
"float64"
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/white_list/op_threshold_white_list.py
浏览文件 @
6b587e93
...
@@ -46,6 +46,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
...
@@ -46,6 +46,7 @@ NEED_FIX_FP64_CHECK_GRAD_THRESHOLD_OP_LIST = [
'cudnn_lstm'
,
\
'cudnn_lstm'
,
\
'rnn'
,
\
'rnn'
,
\
'lgamma'
,
\
'lgamma'
,
\
'sparse_attention'
,
\
'svd'
,
\
'svd'
,
\
'matrix_power'
,
\
'matrix_power'
,
\
'solve'
,
\
'solve'
,
\
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录