Commit fdcdbec5 in PaddlePaddle/Paddle
Authored May 30, 2022 by crystal; committed via GitHub on May 30, 2022.
Implement fused_gate_attention operator for AlphaFold. (#42018)
Parent: 17b8446d
Showing 11 changed files with 1,821 additions and 109 deletions (+1821 / -109).
paddle/fluid/operators/fused/CMakeLists.txt                          +3    -1
paddle/fluid/operators/fused/attn_gemm.h                             +80   -84
paddle/fluid/operators/fused/fmha_ref.h                              +3    -1
paddle/fluid/operators/fused/fused_gate_attention.h                  +647  -0
paddle/fluid/operators/fused/fused_gate_attention_op.cc              +317  -0
paddle/fluid/operators/fused/fused_gate_attention_op.cu              +488  -0
paddle/fluid/platform/device/gpu/gpu_info.cc                         +2    -2
paddle/fluid/pybind/op_function_generator.h                          +7    -0
paddle/phi/kernels/gpudnn/softmax_gpudnn.h                           +21   -21
python/paddle/fluid/tests/unittests/CMakeLists.txt                   +1    -0
python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py  +252  -0
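For orientation (this note and the sketch below are not part of the commit): the new fused_gate_attention operator fuses AlphaFold-style gated attention into a single forward/grad operator pair. A minimal NumPy sketch of the computation being fused follows; it uses the same einsum equations and tensor layouts as the reference path in the new unit test, and the function and variable names here are illustrative only.

    # Minimal NumPy sketch of gated attention (illustrative, not from the diff).
    import numpy as np

    def softmax(x, axis=-1):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    def gate_attention_reference(q_data, m_data, mask, nonbatched_bias,
                                 query_w, key_w, value_w,
                                 gating_w, gating_b, output_w, output_b):
        c = query_w.shape[-1] ** -0.5                          # 1 / sqrt(key_dim)
        q = np.einsum('nbqa,ahc->nbqhc', q_data, query_w) * c
        k = np.einsum('nbka,ahc->nbkhc', m_data, key_w)
        v = np.einsum('nbka,ahc->nbkhc', m_data, value_w)
        logits = np.einsum('nbqhc,nbkhc->nbhqk', q, k) + mask  # qk_out + SrcMask
        if nonbatched_bias is not None:
            logits = logits + nonbatched_bias
        weights = softmax(logits)                              # SoftmaxOut
        weighted_avg = np.einsum('nbhqk,nbkhc->nbqhc', weights, v)  # FMHAOut
        gate = 1.0 / (1.0 + np.exp(-(np.einsum('nbqc,chv->nbqhv', q_data, gating_w)
                                     + gating_b)))
        weighted_avg = weighted_avg * gate                     # GateOut
        return np.einsum('nbqhc,hco->nbqo', weighted_avg, output_w) + output_b

    # Tiny smoke test with the same dimension names as the unit test.
    n, b, q_len, m_size, q_dim, heads, key_dim = 1, 3, 5, 5, 6, 2, 4
    rng = np.random.default_rng(0)
    out = gate_attention_reference(
        rng.random((n, b, q_len, q_dim)), rng.random((n, b, m_size, q_dim)),
        rng.random((n, b, 1, 1, m_size)), rng.random((n, 1, heads, q_len, m_size)),
        rng.random((q_dim, heads, key_dim)), rng.random((q_dim, heads, key_dim)),
        rng.random((q_dim, heads, key_dim)), rng.random((q_dim, heads, key_dim)),
        rng.random((heads, key_dim)), rng.random((heads, key_dim, q_dim)),
        rng.random((q_dim,)))
    assert out.shape == (n, b, q_len, q_dim)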
paddle/fluid/operators/fused/CMakeLists.txt

@@ -23,7 +23,8 @@ register_operators(EXCLUDES
     fused_feedforward_op
     fused_multi_transformer_op
     resnet_unit_op
-    fused_gemm_epilogue_op)
+    fused_gemm_epilogue_op
+    fused_gate_attention_op)

 # fusion_gru_op does not have CUDA kernel
 op_library(fusion_gru_op)

@@ -58,6 +59,7 @@ if (WITH_GPU OR WITH_ROCM)
     op_library(yolo_box_head_op)
     op_library(yolo_box_post_op)
     op_library(fused_embedding_eltwise_layernorm_op)
+    op_library(fused_gate_attention_op)
     # fusion_group
     if(NOT APPLE AND NOT WIN32)
         op_library(fusion_group_op DEPS device_code)
paddle/fluid/operators/fused/attn_gemm.h

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */

@@ -13,6 +16,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
 #include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"

@@ -21,6 +25,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {

 using Tensor = framework::Tensor;

 // support gemm-nt and gemm-nn, which is used in fused_attention_op.
 template <typename T>
 class AttnMatMul {
...
...
@@ -45,31 +51,21 @@ class AttnMatMul {
                        framework::Tensor* bias_out) {
     // Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
     // here: (transa, transb): nt, input * weight.
-    CBLAS_TRANSPOSE transA = CblasNoTrans;
-    CBLAS_TRANSPOSE transB = CblasNoTrans;
-    if (transA_) {
-      transA = CblasTrans;
-    }
-    if (transB_) {
-      transB = CblasTrans;
-    }
+    CBLAS_TRANSPOSE transA = transA_ ? CblasTrans : CblasNoTrans;
+    CBLAS_TRANSPOSE transB = transB_ ? CblasTrans : CblasNoTrans;
     T alpha = static_cast<T>(1.0);
     T beta = static_cast<T>(0.0);

-    // here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
+    // (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
     auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     blas.GEMM(transA, transB, bsz_seq_, output_size_, input_size_, alpha,
               input->data<T>(), weight->data<T>(), beta, output->data<T>());
     if (compute_bias_) {
-      // compute output + bias
-      std::vector<const Tensor*> ins;
-      std::vector<Tensor*> outs;
-      ins.emplace_back(output);
-      ins.emplace_back(bias);
-      outs.emplace_back(bias_out);
-      int elewise_add_axis = -1;
+      // bias_out = output + bias
+      std::vector<const Tensor*> ins = {output, bias};
+      std::vector<Tensor*> outs = {bias_out};
       phi::funcs::BroadcastKernel<phi::ElementwiseType::kBinary, T, T>(
-          dev_ctx_, ins, &outs, elewise_add_axis, phi::funcs::AddFunctor<T>());
+          dev_ctx_, ins, &outs, -1, phi::funcs::AddFunctor<T>());
     }
   }
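A small NumPy illustration (not from the diff) of what the ComputeForward path above amounts to, assuming the row-major convention stated in the comment: one GEMM of input against weight (transposed when transB_ is set), followed by a broadcast bias add along the last axis.

    import numpy as np

    bsz_seq, input_size, output_size = 4, 3, 5
    rng = np.random.default_rng(0)
    inp = rng.random((bsz_seq, input_size))
    weight_nn = rng.random((input_size, output_size))   # transB_ == false
    weight_nt = rng.random((output_size, input_size))   # transB_ == true
    bias = rng.random((output_size,))

    out_nn = inp @ weight_nn        # blas.GEMM(NoTrans, NoTrans, ...)
    out_nt = inp @ weight_nt.T      # blas.GEMM(NoTrans, Trans, ...)
    bias_out = out_nn + bias        # BroadcastKernel + AddFunctor, axis = -1
    assert out_nt.shape == (bsz_seq, output_size)
    assert bias_out.shape == (bsz_seq, output_size)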
...
...
@@ -77,82 +73,71 @@ class AttnMatMul {
                        const framework::Tensor* weight,
                        const framework::Tensor* d_output,
                        framework::Tensor* d_input, framework::Tensor* d_weight,
-                       framework::Tensor* d_bias) {
+                       framework::Tensor* d_bias, bool use_addto = false) {
     T alpha = static_cast<T>(1.0);
-    T beta = static_cast<T>(0.0);
-
-    CBLAS_TRANSPOSE dB_transA = CblasNoTrans;
-    CBLAS_TRANSPOSE dB_transB = CblasNoTrans;
-    CBLAS_TRANSPOSE dA_transA = CblasNoTrans;
-    CBLAS_TRANSPOSE dA_transB = CblasNoTrans;
-    int dB_m = 1;
-    int dB_n = 1;
-    int dB_k = 1;
-    int dA_m = 1;
-    int dA_n = 1;
-    int dA_k = 1;
-    T* dB_input_1_ptr = nullptr;
-    T* dB_input_2_ptr = nullptr;
-    T* dB_output_ptr = d_weight->data<T>();
-    T* dA_input_1_ptr = nullptr;
-    T* dA_input_2_ptr = nullptr;
-    T* dA_output_ptr = d_input->data<T>();
+    T beta_dA = use_addto ? static_cast<T>(1.0) : static_cast<T>(0.0);
+    T beta_dB = static_cast<T>(0.0);

     auto blas = phi::funcs::GetBlas<platform::CUDADeviceContext, T>(dev_ctx_);
     if (!transA_) {
-      // fw: gemm-nt
+      // forward: gemm-nt
       if (transB_) {
-        // bw: gemm-tn, dB = (dC)^t * A
-        dB_transA = CblasTrans;
-        dB_transB = CblasNoTrans;
-        dB_m = output_size_;
-        dB_n = input_size_;
-        dB_k = bsz_seq_;
-
-        // bw: gemm-nn, dA = dC * B
-        dA_transA = CblasNoTrans;
-        dA_transB = CblasNoTrans;
-        dA_m = bsz_seq_;
-        dA_n = input_size_;
-        dA_k = output_size_;
-
-        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
-                  d_output->data<T>(), input->data<T>(), beta, dB_output_ptr);
-        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
-                  d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
+        // backward: gemm-tn, dB = (dC)^T * A
+        if (d_weight) {
+          int dB_m = output_size_;
+          int dB_n = input_size_;
+          int dB_k = bsz_seq_;
+
+          T* dB_output_ptr = d_weight->data<T>();
+          blas.GEMM(CblasTrans, CblasNoTrans, dB_m, dB_n, dB_k, alpha,
+                    d_output->data<T>(), input->data<T>(), beta_dB,
+                    dB_output_ptr);
+        }
+
+        // backward: gemm-nn, dA = dC * B
+        if (d_input) {
+          int dA_m = bsz_seq_;
+          int dA_n = input_size_;
+          int dA_k = output_size_;
+
+          T* dA_output_ptr = d_input->data<T>();
+          blas.GEMM(CblasNoTrans, CblasNoTrans, dA_m, dA_n, dA_k, alpha,
+                    d_output->data<T>(), weight->data<T>(), beta_dA,
+                    dA_output_ptr);
+        }
       } else {  // fw: gemm-nn
-        // bw: gemm-tn, dB = A^t * dC
-        dB_transA = CblasTrans;
-        dB_transB = CblasNoTrans;
-        dB_m = input_size_;
-        dB_n = output_size_;
-        dB_k = bsz_seq_;
-
-        // bw: gemm-nt, dA = dC * B^t
-        dA_transA = CblasNoTrans;
-        dA_transB = CblasTrans;
-        dA_m = bsz_seq_;
-        dA_n = input_size_;
-        dA_k = output_size_;
-
-        blas.GEMM(dB_transA, dB_transB, dB_m, dB_n, dB_k, alpha,
-                  input->data<T>(), d_output->data<T>(), beta, dB_output_ptr);
-        blas.GEMM(dA_transA, dA_transB, dA_m, dA_n, dA_k, alpha,
-                  d_output->data<T>(), weight->data<T>(), beta, dA_output_ptr);
+        // backward: gemm-tn, dB = A^T * dC
+        if (d_weight) {
+          int dB_m = input_size_;
+          int dB_n = output_size_;
+          int dB_k = bsz_seq_;
+
+          T* dB_output_ptr = d_weight->data<T>();
+          blas.GEMM(CblasTrans, CblasNoTrans, dB_m, dB_n, dB_k, alpha,
+                    input->data<T>(), d_output->data<T>(), beta_dB,
+                    dB_output_ptr);
+        }
+
+        // backward: gemm-nt, dA = dC * B^T
+        if (d_input) {
+          int dA_m = bsz_seq_;
+          int dA_n = input_size_;
+          int dA_k = output_size_;
+
+          T* dA_output_ptr = d_input->data<T>();
+          blas.GEMM(CblasNoTrans, CblasTrans, dA_m, dA_n, dA_k, alpha,
+                    d_output->data<T>(), weight->data<T>(), beta_dA,
+                    dA_output_ptr);
+        }
       }
     } else if (transB_) {
       PADDLE_THROW(platform::errors::InvalidArgument(
           "AttnMatMul wrapper do not support (transA=T, transB=T)"
           "parameters."));
     } else {
       PADDLE_THROW(platform::errors::InvalidArgument(
-          "AttnMatMul wrapper do not support (transA=T, transB=N)"
+          "AttnMatMul wrapper do not support (transA=T, transB=T/N)"
           "parameters."));
     }
-    if (compute_bias_) {
-      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2}
+    if (compute_bias_ && d_bias) {
+      // reduce: {0, 1, 2, 3, 4} -> {2, 3, 4} or {0, 1, 2} -> {2} or {0,1,2,3}
+      // -> {3} or {0,1,2,3,4} -> {3,4}
       const auto input_dims = d_output->dims();
       const auto output_dims = d_bias->dims();
       bool support_case_1 =
...
...
@@ -163,11 +148,22 @@ class AttnMatMul {
       bool support_case_2 =
           (input_dims.size() == 3 && output_dims.size() == 1 &&
            (input_dims[2] == output_dims[0]));
-      if (support_case_1 || support_case_2) {
+      bool support_case_3 =
+          (input_dims.size() == 4 && output_dims.size() == 1 &&
+           input_dims[3] == output_dims[0]);
+      bool support_case_4 =
+          (input_dims.size() == 5 && output_dims.size() == 2 &&
+           input_dims[3] == output_dims[0] && input_dims[4] == output_dims[1]);
+
+      gpuStream_t stream = dev_ctx_.stream();
+      if (support_case_1 || support_case_2) {
         TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
             dev_ctx_, *d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1},
             stream);
+      } else if (support_case_3 || support_case_4) {
+        TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
+            dev_ctx_, *d_output, d_bias, kps::IdentityFunctor<T>(), {0, 1, 2},
+            stream);
       } else {
         PADDLE_THROW(platform::errors::InvalidArgument(
             "Only support reduce when the input dims are [0,1,2,3,4] and "
...
...
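A NumPy sanity check (not from the diff) of the gradient GEMMs used in ComputeBackward for the gemm-nn forward case: with C = A @ B, the code computes dB = A^T @ dC and dA = dC @ B^T, and the new use_addto flag only switches beta for dA from 0 to 1 so that the gradient accumulates into an existing buffer instead of overwriting it.

    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.random((4, 3))     # input:  [bsz_seq, input_size]
    B = rng.random((3, 5))     # weight: [input_size, output_size]
    dC = rng.random((4, 5))    # d_output

    dB = A.T @ dC              # gemm-tn, beta_dB = 0
    dA = dC @ B.T              # gemm-nt, beta_dA = 0 (overwrite)

    d_input_buffer = rng.random((4, 3))
    dA_addto = d_input_buffer + dC @ B.T   # beta_dA = 1 when use_addto == true

    assert dB.shape == B.shape and dA.shape == A.shape
    assert dA_addto.shape == A.shape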
paddle/fluid/operators/fused/fmha_ref.h

/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. */

@@ -297,7 +300,6 @@ class FMHARef {
     phi::SoftmaxBackwardCUDAKernelDriver<T>(
         dev_ctx_, softmax_out_tensor, *softmax_out_grad_tensor, softmax_axis,
         src_mask_out_grad_tensor);
     // recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out +
     // src_mask
     // Special case when dy is not needed and dx doesn't reduce
...
...
paddle/fluid/operators/fused/fused_gate_attention.h
new file (mode 100644)
(This diff is collapsed in the page view and its 647 added lines are not shown here.)
paddle/fluid/operators/fused/fused_gate_attention_op.cc
new file (mode 100644)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using DDim = framework::DDim;

class FusedGateAttentionOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Query"), "Input", "Query",
                   "fused_gate_attention");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearWeight"), "Input", "OutLinearWeight",
                   "fused_gate_attention");
    OP_INOUT_CHECK(ctx->HasInput("OutLinearBias"), "Input", "OutLinearBias",
                   "fused_gate_attention");

    OP_INOUT_CHECK(ctx->HasOutput("SoftmaxOut"), "Output", "SoftmaxOut",
                   "fused_gate_attention");
    OP_INOUT_CHECK(ctx->HasOutput("FMHAOut"), "Output", "FMHAOut",
                   "fused_gate_attention");
    OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out",
                   "fused_gate_attention");

    auto input_q_dims = ctx->GetInputDim("Query");
    int batch_size = input_q_dims[0];
    int seq_len_m = input_q_dims[1];
    int seq_len_r = input_q_dims[2];

    int num_head, m_size, key_dim;
    if (ctx->Attrs().Get<bool>("merge_qkv")) {
      // QKV's input: [batch_size, seq_len_m, seq_len_r, qkv_dim]
      // QKV's weight: [3, num_head, key_dim, qkv_dim]
      OP_INOUT_CHECK(ctx->HasInput("QKVWeight"), "Input", "QKVWeight",
                     "fused_gate_attention");
      OP_INOUT_CHECK(ctx->HasOutput("QKVTransposeOut"), "Output",
                     "QKVTransposeOut", "fused_gate_attention");

      auto qkv_w_dims = ctx->GetInputDim("QKVWeight");

      num_head = qkv_w_dims[1];
      key_dim = qkv_w_dims[2];
      m_size = seq_len_r;

      ctx->SetOutputDim("QKVTransposeOut", {3, batch_size, seq_len_m, num_head,
                                            seq_len_r, key_dim});
    } else {
      OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), "Input", "QueryWeight",
                     "fused_gate_attention");
      OP_INOUT_CHECK(ctx->HasInput("KeyWeight"), "Input", "KeyWeight",
                     "fused_gate_attention");
      OP_INOUT_CHECK(ctx->HasInput("ValueWeight"), "Input", "ValueWeight",
                     "fused_gate_attention");

      auto input_k_dims = ctx->GetInputDim("Key");
      auto q_w_dims = ctx->GetInputDim("QueryWeight");

      num_head = q_w_dims[1];
      key_dim = q_w_dims[2];
      m_size = input_k_dims[2];

      ctx->SetOutputDim("QueryTransposeOut",
                        {batch_size, seq_len_m, num_head, seq_len_r, key_dim});
      ctx->SetOutputDim("KeyTransposeOut",
                        {batch_size, seq_len_m, num_head, m_size, key_dim});
      ctx->SetOutputDim("ValueTransposeOut",
                        {batch_size, seq_len_m, num_head, m_size, key_dim});
    }

    ctx->SetOutputDim("SoftmaxOut",
                      {batch_size, seq_len_m, num_head, seq_len_r, m_size});
    ctx->SetOutputDim("FMHAOut",
                      {batch_size, seq_len_m, seq_len_r, num_head, key_dim});

    if (ctx->Attrs().Get<bool>("has_gating")) {
      OP_INOUT_CHECK(ctx->HasInput("GateWeight"), "Input", "GateWeight",
                     "fused_gate_attention");
      OP_INOUT_CHECK(ctx->HasInput("GateBias"), "Input", "GateBias",
                     "fused_gate_attention");
      ctx->SetOutputDim("GateOut",
                        {batch_size, seq_len_m, seq_len_r, num_head, key_dim});
    }

    ctx->SetOutputDim("Out", ctx->GetInputDim("Query"));
  }
};

class FusedGateAttentionOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override {
    AddInput("Query", "The query tensor.");
    AddInput("Key", "The key tensor.").AsDispensable();
    AddInput("QueryWeight", "(optional) The query weight tensor.")
        .AsDispensable();
    AddInput("KeyWeight", "(optional) The key weight tensor.").AsDispensable();
    AddInput("ValueWeight", "(optional) The value weight tensor.")
        .AsDispensable();
    AddInput("QKVWeight", "(optional) The qkv weight tensor.").AsDispensable();
    AddInput("NonbatchedBias", "(optional) The nonbatchedBias tensor.")
        .AsDispensable();
    AddInput("SrcMask", "The attention mask tensor in fmha.");
    AddInput("GateWeight", "(optional) The gate weight tensor.")
        .AsDispensable();
    AddInput("GateBias", "(optional) The gate bias tensor.").AsDispensable();
    AddInput("OutLinearWeight", "The out_linear weight tensor.");
    AddInput("OutLinearBias", "The out_linear bias tensor.");
    AddOutput("QueryTransposeOut", "The transposed result of query matmul.")
        .AsIntermediate()
        .AsDispensable();
    AddOutput("KeyTransposeOut", "The transposed result of key matmul.")
        .AsIntermediate()
        .AsDispensable();
    AddOutput("ValueTransposeOut", "The transposed result of value matmul.")
        .AsIntermediate()
        .AsDispensable();
    AddOutput("QKVTransposeOut", "The transposed result of merged QKV matmul.")
        .AsIntermediate()
        .AsDispensable();
    AddOutput("SoftmaxOut", "Result in fmha.").AsIntermediate();
    AddOutput("FMHAOut", "Result in fmha.").AsIntermediate();
    AddOutput("GateOut", "Result of the gating module.")
        .AsIntermediate()
        .AsDispensable();
    AddOutput("Out", "Result after attention.");

    AddAttr<bool>("has_gating",
                  "if true, the attention op uses gate architecure, "
                  "[default true].")
        .SetDefault(true);
    AddAttr<bool>("merge_qkv",
                  "if true, calculation with merged qkv, "
                  "[default true].")
        .SetDefault(true);

    AddComment(R"DOC(
  Add fused attention op whose logic is as follows:
  {
    q = paddle.einsum('nbqa,ahc->nbqhc', q_data, self.query_w)
    k = paddle.einsum('nbka,ahc->nbkhc', m_data, self.key_w)
    v = paddle.einsum('nbka,ahc->nbkhc', m_data, self.value_w)

    logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q * c , k) + bias
    weights = nn.functional.softmax(logits)
    weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)
    if nonbatched_bias is not None:
      logits += paddle.unsqueeze(nonbatched_bias, axis=1)

    if self.gating:
        gate_values = paddle.einsum('nbqc,chv->nbqhv', q_data,
                                    self.gating_w) + self.gating_b
        gate_values_1 = nn.functional.sigmoid(gate_values)
        weighted_avg *= gate_values_1

    output = paddle.einsum('nbqhc,hco->nbqo', weighted_avg,
                           self.output_w) + self.output_b
  }
    )DOC");
  }
};

class FusedGateAttentionGradOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  void InferShape(framework::InferShapeContext* ctx) const override {
    OP_INOUT_CHECK(ctx->HasInput("Query"), "Input", "Query",
                   "fused_gate_attention_grad");
    if (ctx->HasOutput(framework::GradVarName("Query"))) {
      ctx->SetOutputDim(framework::GradVarName("Query"),
                        ctx->GetInputDim("Query"));
    }
    if (ctx->HasOutput(framework::GradVarName("Key"))) {
      ctx->SetOutputDim(framework::GradVarName("Key"), ctx->GetInputDim("Key"));
    }

    if (ctx->Attrs().Get<bool>("merge_qkv")) {
      OP_INOUT_CHECK(ctx->HasInput("QKVWeight"), "Input", "QKVWeight",
                     "fused_gate_attention_arad");
      ctx->SetOutputDim(framework::GradVarName("QKVWeight"),
                        ctx->GetInputDim("QKVWeight"));
    } else {
      OP_INOUT_CHECK(ctx->HasInput("QueryWeight"), "Input", "QueryWeight",
                     "fused_aate_attention_arad");
      OP_INOUT_CHECK(ctx->HasInput("KeyWeight"), "Input", "KeyWeight",
                     "fused_aate_attention_arad");
      OP_INOUT_CHECK(ctx->HasInput("ValueWeight"), "Input", "ValueWeight",
                     "fused_aate_attention_arad");

      for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) {
        ctx->SetOutputDim(framework::GradVarName(name),
                          ctx->GetInputDim(name));
      }
    }

    OP_INOUT_CHECK(ctx->HasInput("OutLinearWeight"), "Input", "OutLinearWeight",
                   "fused_aate_attention_arad");

    if (ctx->Attrs().Get<bool>("has_gating")) {
      for (auto& name : {"GateWeight", "GateBias", "GateOut"}) {
        ctx->SetOutputDim(framework::GradVarName(name),
                          ctx->GetInputDim(name));
      }
    }

    if (ctx->HasOutput(framework::GradVarName("NonbatchedBias"))) {
      ctx->SetOutputDim(framework::GradVarName("NonbatchedBias"),
                        ctx->GetInputDim("NonbatchedBias"));
    }

    ctx->SetOutputDim(framework::GradVarName("FMHAOut"),
                      ctx->GetInputDim("FMHAOut"));

    ctx->SetOutputDim(framework::GradVarName("OutLinearWeight"),
                      ctx->GetInputDim("OutLinearWeight"));
    ctx->SetOutputDim(framework::GradVarName("OutLinearBias"),
                      ctx->GetInputDim("OutLinearBias"));
  }
};

template <typename T>
class FusedGateAttentionGradOpMaker : public framework::SingleGradOpMaker<T> {
 public:
  using framework::SingleGradOpMaker<T>::SingleGradOpMaker;

 protected:
  void Apply(GradOpPtr<T> op) const override {
    op->SetType("fused_gate_attention_grad");
    op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));

    op->SetInput("Query", this->Input("Query"));
    op->SetOutput(framework::GradVarName("Query"), this->InputGrad("Query"));

    op->SetAttrMap(this->Attrs());
    bool merge_qkv = BOOST_GET_CONST(bool, op->GetAttr("merge_qkv"));
    if (merge_qkv) {
      op->SetInput("QKVWeight", this->Input("QKVWeight"));
      op->SetOutput(framework::GradVarName("QKVWeight"),
                    this->InputGrad("QKVWeight"));
      op->SetInput("QKVTransposeOut", this->Output("QKVTransposeOut"));
    } else {
      op->SetInput("Key", this->Input("Key"));
      op->SetOutput(framework::GradVarName("Key"), this->InputGrad("Key"));

      for (auto& name : {"QueryWeight", "KeyWeight", "ValueWeight"}) {
        op->SetInput(name, this->Input(name));
        op->SetOutput(framework::GradVarName(name), this->InputGrad(name));
      }

      for (auto& name :
           {"QueryTransposeOut", "KeyTransposeOut", "ValueTransposeOut"}) {
        op->SetInput(name, this->Output(name));
      }
    }

    op->SetInput("FMHAOut", this->Output("FMHAOut"));
    op->SetOutput(framework::GradVarName("FMHAOut"),
                  this->OutputGrad("FMHAOut"));

    if (this->HasInput("NonbatchedBias")) {
      op->SetInput("NonbatchedBias", this->Input("NonbatchedBias"));
      op->SetOutput(framework::GradVarName("NonbatchedBias"),
                    this->InputGrad("NonbatchedBias"));
    }

    op->SetInput("SoftmaxOut", this->Output("SoftmaxOut"));

    bool has_gating = BOOST_GET_CONST(bool, op->GetAttr("has_gating"));
    if (has_gating) {
      op->SetInput("GateWeight", this->Input("GateWeight"));
      op->SetOutput(framework::GradVarName("GateWeight"),
                    this->InputGrad("GateWeight"));

      op->SetInput("GateBias", this->Input("GateBias"));
      op->SetOutput(framework::GradVarName("GateBias"),
                    this->InputGrad("GateBias"));

      op->SetInput("GateOut", this->Output("GateOut"));
      op->SetOutput(framework::GradVarName("GateOut"),
                    this->OutputGrad("GateOut"));
    }

    op->SetInput("OutLinearWeight", this->Input("OutLinearWeight"));
    op->SetOutput(framework::GradVarName("OutLinearWeight"),
                  this->InputGrad("OutLinearWeight"));

    op->SetInput("OutLinearBias", this->Input("OutLinearBias"));
    op->SetOutput(framework::GradVarName("OutLinearBias"),
                  this->InputGrad("OutLinearBias"));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OPERATOR(
    fused_gate_attention, ops::FusedGateAttentionOp,
    ops::FusedGateAttentionOpMaker,
    ops::FusedGateAttentionGradOpMaker<paddle::framework::OpDesc>,
    ops::FusedGateAttentionGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(fused_gate_attention_grad, ops::FusedGateAttentionGradOp);
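For reference (not part of the commit), the output shapes that FusedGateAttentionOp::InferShape derives for the default merge_qkv=True path, restated as a small Python helper with illustrative dimension names.

    # Illustrative shape bookkeeping matching InferShape for merge_qkv=True.
    def infer_shapes_merge_qkv(query_shape, qkv_weight_shape):
        batch_size, seq_len_m, seq_len_r, _ = query_shape   # Query
        _, num_head, key_dim, _ = qkv_weight_shape          # QKVWeight: [3, num_head, key_dim, qkv_dim]
        m_size = seq_len_r
        return {
            "QKVTransposeOut": [3, batch_size, seq_len_m, num_head, seq_len_r, key_dim],
            "SoftmaxOut": [batch_size, seq_len_m, num_head, seq_len_r, m_size],
            "FMHAOut": [batch_size, seq_len_m, seq_len_r, num_head, key_dim],
            "GateOut": [batch_size, seq_len_m, seq_len_r, num_head, key_dim],  # only when has_gating=True
            "Out": list(query_shape),
        }

    print(infer_shapes_merge_qkv([1, 3, 5, 6], [3, 2, 4, 6]))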
paddle/fluid/operators/fused/fused_gate_attention_op.cu
new file (mode 100644)
(This diff is collapsed in the page view and its 488 added lines are not shown here.)
paddle/fluid/platform/device/gpu/gpu_info.cc

@@ -225,9 +225,9 @@ class RecordedGpuMallocHelper {
     if (UNLIKELY(malloc_managed_memory)) {
       result = cudaMallocManaged(ptr, size);
     } else {
-      VLOG(10) << "[cudaMalloc] size=" << static_cast<double>(size) / (1 << 20)
-               << " MB";
       result = cudaMalloc(ptr, size);
+      VLOG(10) << "[cudaMalloc] size=" << static_cast<double>(size) / (1 << 20)
+               << " MB, result=" << result;
     }
 #endif
     if (result == gpuSuccess) {
...
...
paddle/fluid/pybind/op_function_generator.h

@@ -32,6 +32,10 @@ std::map<std::string, std::set<std::string>> op_ins_map = {
     {"fused_attention",
      {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "SrcMask",
       "OutLinearW", "OutLinearBias", "Ln2Scale", "Ln2Bias"}},
+    {"fused_gate_attention",
+     {"Query", "Key", "QueryWeight", "KeyWeight", "ValueWeight", "QKVWeight",
+      "NonbatchedBias", "SrcMask", "GateWeight", "GateBias", "OutLinearWeight",
+      "OutLinearBias"}},
     {"fused_multi_transformer",
      {"X", "LnScale", "LnBias", "QKVW", "QKVBias", "CacheKV", "TimeStep",
       "SrcMask", "OutLinearW", "OutLinearBias", "FFNLnScale", "FFNLnBias",
       ...

@@ -148,6 +152,9 @@ std::map<std::string, std::set<std::string>> op_outs_map = {
       "DropoutMaskOut", "Ln2Mean", "Ln2Variance", "BiasDropoutResidualOut",
       "CacheKVOut", "Y"}},
+    {"fused_gate_attention",
+     {"QueryTransposeOut", "KeyTransposeOut", "ValueTransposeOut",
+      "QKVTransposeOut", "SoftmaxOut", "FMHAOut", "GateOut", "Out"}},
     {"sync_batch_norm",
      {"Y", "MeanOut", "VarianceOut", "SavedMean", "SavedVariance",
       "ReserveSpace"}},
...
...
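For reference (not part of the diff): the input and output names registered above for fused_gate_attention, kept in the order they are listed. The new unit test passes its tensors to _C_ops.fused_gate_attention in this same input order and unpacks eight return values, the last of which is Out.

    # Name lists copied from the two map entries above (illustrative only).
    FUSED_GATE_ATTENTION_INS = [
        "Query", "Key", "QueryWeight", "KeyWeight", "ValueWeight", "QKVWeight",
        "NonbatchedBias", "SrcMask", "GateWeight", "GateBias",
        "OutLinearWeight", "OutLinearBias",
    ]
    FUSED_GATE_ATTENTION_OUTS = [
        "QueryTransposeOut", "KeyTransposeOut", "ValueTransposeOut",
        "QKVTransposeOut", "SoftmaxOut", "FMHAOut", "GateOut", "Out",
    ]
    assert len(FUSED_GATE_ATTENTION_INS) == 12
    assert FUSED_GATE_ATTENTION_OUTS.index("Out") == 7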
paddle/phi/kernels/gpudnn/softmax_gpudnn.h

@@ -888,19 +888,6 @@ void SoftmaxBackwardCudnnKernel(const GPUContext& dev_ctx,
 #endif
 }

-template <typename T>
-static bool CanUseCudnnSoftmax(const GPUContext& dev_ctx) {
-  if (dev_ctx.cudnn_handle() != nullptr) {
-    if (std::is_same<T, phi::dtype::bfloat16>::value) {
-#if CUDNN_VERSION < 8100
-      return false;
-#endif
-    }
-    return true;
-  }
-  return false;
-}
-
 #if CUDNN_VERSION < 8100
 template <>
 inline void SoftmaxForwardCudnnKernel<phi::dtype::bfloat16>(

@@ -927,6 +914,25 @@ inline void SoftmaxBackwardCudnnKernel<phi::dtype::bfloat16>(
 }
 #endif

+template <typename T>
+bool UseCudnnSoftmax(const GPUContext& ctx, int softmax_dim, bool last_dim) {
+  bool cudnn_available = ctx.cudnn_handle();
+  if (!ctx.cudnn_handle()) {
+    if (std::is_same<T, phi::dtype::bfloat16>::value) {
+#if CUDNN_VERSION < 8100
+      cudnn_available = false;
+#endif
+    }
+  }
+  constexpr int max_dim = 512;
+  if (!cudnn_available || !last_dim ||
+      (softmax_dim <= max_dim && sizeof(T) <= 4)) {
+    return false;
+  } else {
+    return true;
+  }
+}
+
 template <typename T, bool LogMode = false>
 void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
                                     const DenseTensor& x,

@@ -941,10 +947,7 @@ void SoftmaxForwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];

-  constexpr int max_dim = 512;
-
-  if (D == 1 &&
-      (!CanUseCudnnSoftmax<T>(dev_ctx) || (dim <= max_dim && sizeof(T) <= 4))) {
+  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
     int dim_log2 = static_cast<int>(Log2Ceil(dim));
     int dim_ceil = 1 << dim_log2;
     int warp_size = (dim_ceil < 32) ? dim_ceil : 32;

@@ -1016,10 +1019,7 @@ void SoftmaxBackwardCUDAKernelDriver(const GPUContext& dev_ctx,
   int dim = tensor_dims[1];
   int D = tensor_dims[2];

-  constexpr int max_dim = 512;
-
-  if (D == 1 &&
-      (!CanUseCudnnSoftmax<T>(dev_ctx) || (dim <= max_dim && sizeof(T) <= 4))) {
+  if (D == 1 && !UseCudnnSoftmax<T>(dev_ctx, dim, true)) {
     int dim_log2 = Log2Ceil(dim);
     int dim_ceil = 1 << dim_log2;
     int warp_size = (dim_ceil < 32) ? dim_ceil : 32;
...
...
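A plain-Python restatement (not from the diff) of the dispatch rule introduced by UseCudnnSoftmax<T> above: the cuDNN softmax path is taken only when a cuDNN handle is available, the softmax axis is the last dimension, and the problem does not fit the warp-level kernel (dim > 512 or sizeof(T) > 4).

    # Mirrors the C++ condition; parameters are illustrative.
    def use_cudnn_softmax(cudnn_available, softmax_dim, last_dim, sizeof_t, max_dim=512):
        if not cudnn_available or not last_dim or (softmax_dim <= max_dim and sizeof_t <= 4):
            return False
        return True

    assert use_cudnn_softmax(True, 1024, True, 4) is True    # large dim -> cuDNN
    assert use_cudnn_softmax(True, 128, True, 4) is False    # small dim -> warp kernel
    assert use_cudnn_softmax(True, 1024, False, 4) is False  # not last dim -> no cuDNN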
python/paddle/fluid/tests/unittests/CMakeLists.txt

@@ -327,6 +327,7 @@ if ((NOT WITH_NCCL) AND (NOT WITH_RCCL))
 endif()

 if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
+    LIST(REMOVE_ITEM TEST_OPS test_fused_gate_attention_op)
     LIST(REMOVE_ITEM TEST_OPS test_boxps)
 endif()
 list(REMOVE_ITEM TEST_OPS test_seq_concat_op) # FIXME(helin): https://github.com/PaddlePaddle/Paddle/issues/8290
...
...
python/paddle/fluid/tests/unittests/test_fused_gate_attention_op.py
new file (mode 100644)

# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import numpy as np

import paddle
import paddle.nn as nn
from paddle import tensor
import unittest
from op_test import OpTest, convert_float_to_uint16
from test_sparse_attention_op import get_cuda_version
from paddle import _C_ops
from paddle.fluid.framework import default_main_program
from paddle.fluid import core


@unittest.skipIf(not core.is_compiled_with_cuda(),
                 "Paddle is not compiled with CUDA")
class TestFusedGateAttentionOp(OpTest):
    def setUp(self):
        self.__class__.op_type = "fused_gate_attention"
        # use autograd to check grad in this unittest.
        self.__class__.no_need_check_grad = True
        self.config()
        self.merge_qkv = self.q_dim == self.kv_dim
        self.generate_input_data()

    def config(self):
        self.dtype = "float32"
        self.has_gating = True
        self.batch_size = 1
        self.msa_len = 3
        self.res_len = 5
        self.q_dim = 6
        self.num_heads = 2
        self.key_dim = 4
        self.m_size = self.res_len
        self.kv_dim = self.q_dim
        self.out_dim = self.q_dim
        self.bias_attr = True

    def generate_input_data(self):
        def _random(shape):
            if self.dtype == "bfloat16":
                data = np.random.random(shape).astype("float32")
                return convert_float_to_uint16(data)
            else:
                return np.random.random(shape).astype(self.dtype)

        np.random.seed(123)
        self.query = _random(
            (self.batch_size, self.msa_len, self.res_len, self.q_dim))
        self.q_weight = _random((self.q_dim, self.num_heads, self.key_dim))
        self.k_weight = _random((self.kv_dim, self.num_heads, self.key_dim))
        self.v_weight = _random((self.kv_dim, self.num_heads, self.key_dim))
        if self.merge_qkv:
            self.key = None
            # (3, self.num_heads, self.key_dim, self.q_dim)
            q_weight_t = np.transpose(self.q_weight, axes=[1, 2, 0])
            k_weight_t = np.transpose(self.k_weight, axes=[1, 2, 0])
            v_weight_t = np.transpose(self.v_weight, axes=[1, 2, 0])
            self.qkv_weight = np.stack([q_weight_t, k_weight_t, v_weight_t])
        else:
            self.key = _random(
                (self.batch_size, self.msa_len, self.m_size, self.kv_dim))
            self.qkv_weight = None

        self.attn_mask = _random(
            (self.batch_size, self.msa_len, 1, 1, self.m_size))

        if self.bias_attr:
            self.nonbatched_bias = _random(
                (self.batch_size, 1, self.num_heads, self.res_len, self.m_size))

        if self.has_gating:
            self.gating_w = _random((self.q_dim, self.num_heads, self.key_dim))
            self.gating_b = _random((self.num_heads, self.key_dim))

        self.output_w = _random((self.num_heads, self.key_dim, self.out_dim))
        self.output_b = _random((self.out_dim))

        self.dout = _random(
            (self.batch_size, self.msa_len, self.res_len, self.q_dim))

    def get_reference_out(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))

        query = paddle.to_tensor(self.query, stop_gradient=False)
        key = query if self.merge_qkv else paddle.to_tensor(
            self.key, stop_gradient=False)
        q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
        k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
        v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
        src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)

        c = self.key_dim**(-0.5)
        # [batch_size, msa_len, num_heads, res_len, key_dim]
        q = paddle.einsum('nbqa,ahc->nbqhc', query, q_weight) * c
        # [batch_size, msa_len, num_heads, m_size, key_dim]
        k = paddle.einsum('nbka,ahc->nbkhc', key, k_weight)
        # [batch_size, msa_len, num_heads, m_size, key_dim]
        v = paddle.einsum('nbka,ahc->nbkhc', key, v_weight)

        # [batch_size, msa_len, num_heads, res_len, m_size]
        logits = paddle.einsum('nbqhc,nbkhc->nbhqk', q, k)  # qk_out
        logits = logits + src_mask
        if self.bias_attr:
            nonbatched_bias = paddle.to_tensor(
                self.nonbatched_bias, stop_gradient=False)
            logits = logits + nonbatched_bias

        weights = nn.functional.softmax(logits)  # softmax_out
        weighted_avg = paddle.einsum('nbhqk,nbkhc->nbqhc', weights, v)

        if self.has_gating:
            gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False)
            gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False)
            gate_values = paddle.einsum('nbqc,chv->nbqhv', query,
                                        gating_w) + gating_b
            gate_values = nn.functional.sigmoid(gate_values)
            weighted_avg = weighted_avg * gate_values

        output_b = paddle.to_tensor(self.output_b, stop_gradient=False)
        output_w = paddle.to_tensor(self.output_w, stop_gradient=False)

        out = paddle.einsum('nbqhc,hco->nbqo', weighted_avg,
                            output_w) + output_b
        paddle.autograd.backward(
            [out], [paddle.to_tensor(self.dout)], retain_graph=True)
        if self.merge_qkv:
            return out, query.grad, None
        else:
            return out, query.grad, key.grad

    def get_fused_gate_attention_out(self):
        paddle.disable_static(place=paddle.CUDAPlace(0))

        query = paddle.to_tensor(self.query, stop_gradient=False)
        if self.merge_qkv:
            key = None
            q_weight = None
            k_weight = None
            v_weight = None
            qkv_weight = paddle.to_tensor(self.qkv_weight, stop_gradient=False)
        else:
            key = paddle.to_tensor(self.key, stop_gradient=False)
            q_weight = paddle.to_tensor(self.q_weight, stop_gradient=False)
            k_weight = paddle.to_tensor(self.k_weight, stop_gradient=False)
            v_weight = paddle.to_tensor(self.v_weight, stop_gradient=False)
            qkv_weight = None

        src_mask = paddle.to_tensor(self.attn_mask, stop_gradient=True)

        if self.bias_attr:
            nonbatched_bias = paddle.to_tensor(
                self.nonbatched_bias, stop_gradient=False)
        else:
            nonbatched_bias = None
        if self.has_gating:
            gating_w = paddle.to_tensor(self.gating_w, stop_gradient=False)
            gating_b = paddle.to_tensor(self.gating_b, stop_gradient=False)
        else:
            gating_w = None
            gating_b = None

        output_w = paddle.to_tensor(self.output_w, stop_gradient=False)
        output_b = paddle.to_tensor(self.output_b, stop_gradient=False)

        _, _, _, _, _, _, _, out = _C_ops.fused_gate_attention(
            query, key, q_weight, k_weight, v_weight, qkv_weight,
            nonbatched_bias, src_mask, gating_w, gating_b, output_w, output_b,
            'has_gating', self.has_gating, 'merge_qkv', self.merge_qkv)

        paddle.autograd.backward(
            [out], [paddle.to_tensor(self.dout)], retain_graph=True)
        if key is not None:
            return out, query.grad, key.grad
        else:
            return out, query.grad, None

    def check_output_and_grad(self, atol, rtol):
        out_ref, query_grad_ref, key_grad_ref = self.get_reference_out()
        out, query_grad, key_grad = self.get_fused_gate_attention_out()
        np.testing.assert_allclose(out_ref, out.numpy(), atol=atol, rtol=rtol)
        np.testing.assert_allclose(
            query_grad_ref, query_grad.numpy(), atol=atol, rtol=rtol)
        if key_grad_ref is not None and key_grad is not None:
            np.testing.assert_allclose(
                key_grad_ref, key_grad.numpy(), atol=atol, rtol=rtol)

    def test_output_and_grad(self):
        self.check_output_and_grad(atol=1e-5, rtol=1e-5)


class TestSeparatedQKVCase(TestFusedGateAttentionOp):
    def config(self):
        self.dtype = "float32"
        self.has_gating = False
        self.batch_size = 1
        self.msa_len = 3
        self.res_len = 5
        self.q_dim = 6
        self.num_heads = 2
        self.key_dim = 4
        self.m_size = 4
        self.kv_dim = 2
        self.out_dim = self.q_dim
        self.bias_attr = False


class TestMergeQKVNoBiasGatingCase(TestFusedGateAttentionOp):
    def config(self):
        super().config()
        self.has_gating = False
        self.bias_attr = False


class TestMergeQKVFp16Case(TestFusedGateAttentionOp):
    def config(self):
        super().config()
        self.dtype = "float16"

    def test_output_and_grad(self):
        self.check_output_and_grad(atol=1e-1, rtol=1e-5)


@unittest.skipIf(
    not core.is_compiled_with_cuda() or get_cuda_version() < 11000,
    "core is not compiled with CUDA and cuda version need larger than or equal to 11.3")
class TestMergeQKVBF16Case(TestFusedGateAttentionOp):
    def config(self):
        super().config()
        self.dtype = "bfloat16"

    def test_output_and_grad(self):
        self.check_output_and_grad(atol=1e-1, rtol=1e-3)


if __name__ == "__main__":
    unittest.main()