Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
88ea8e6f
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
88ea8e6f
编写于
9月 23, 2021
作者:
L
Li Min
提交者:
GitHub
9月 23, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Add fused_attention_op: add impl wrappers. (#35903)
上级
7bf84e2d
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
487 addition
and
8 deletion
+487
-8
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
+2
-1
paddle/fluid/operators/fused/attention_layer_norm.h
paddle/fluid/operators/fused/attention_layer_norm.h
+1
-1
paddle/fluid/operators/fused/attn_bias_add.cu.h
paddle/fluid/operators/fused/attn_bias_add.cu.h
+1
-5
paddle/fluid/operators/fused/attn_gemm.h
paddle/fluid/operators/fused/attn_gemm.h
+159
-0
paddle/fluid/operators/fused/fmha_ref.h
paddle/fluid/operators/fused/fmha_ref.h
+324
-0
paddle/fluid/operators/layer_norm_kernel.cu.h
paddle/fluid/operators/layer_norm_kernel.cu.h
+0
-1
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h
浏览文件 @
88ea8e6f
...
...
@@ -108,7 +108,8 @@ struct ElementwisePrimitiveCaller<InT, OutT, VecSize, Functor, 2, false> {
template
<
typename
InT
,
typename
OutT
,
int
VecSize
,
typename
Functor
>
struct
ElementwisePrimitiveCaller
<
InT
,
OutT
,
VecSize
,
Functor
,
3
,
false
>
{
__device__
inline
void
operator
()(
Functor
func
,
InT
**
args
,
OutT
*
result
)
{
__device__
inline
void
operator
()(
Functor
func
,
InT
(
*
args
)[
VecSize
],
OutT
*
result
)
{
kps
::
ElementwiseTernary
<
InT
,
OutT
,
VecSize
,
1
,
1
,
Functor
>
(
result
,
args
[
0
],
args
[
1
],
args
[
2
],
func
);
}
...
...
paddle/fluid/operators/fused/attention_layer_norm.h
浏览文件 @
88ea8e6f
...
...
@@ -50,7 +50,7 @@ class AttnLayerNorm {
}
}
void
ComputeBackward
(
const
T
*
x_data
,
const
T
*
y_data
,
void
ComputeBackward
(
const
T
*
x_data
,
const
T
*
d_
y_data
,
const
LayerNormParamType
<
T
>*
scale_data
,
const
LayerNormParamType
<
T
>*
mean_data
,
const
LayerNormParamType
<
T
>*
var_data
,
T
*
d_x_data
,
...
...
paddle/fluid/operators/fused/attn_bias_add.cu.h
浏览文件 @
88ea8e6f
...
...
@@ -34,6 +34,7 @@ namespace cub = hipcub;
#define LAUNCH_BOUNDS(BlockDim)
#endif
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_functor_op.h"
...
...
@@ -51,11 +52,6 @@ using CudnnDataType = platform::CudnnDataType<T>;
template
<
typename
T
>
using
ReduceParamType
=
typename
CudnnDataType
<
T
>::
BatchNormParamType
;
template
<
typename
T
>
struct
AddFunctor
{
inline
HOSTDEVICE
T
operator
()(
const
T
&
a
,
const
T
&
b
)
const
{
return
a
+
b
;
}
};
template
<
typename
InT
,
typename
OutT
,
int
ShapeSize
,
int
VecSize
,
int
DATA_PER_THREAD
,
typename
Functor
>
__global__
void
BroadcastKernelBinary
(
...
...
paddle/fluid/operators/fused/attn_gemm.h
0 → 100644
浏览文件 @
88ea8e6f
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/attn_bias_add.cu.h"
#include "paddle/fluid/operators/math/blas.h"
#include "paddle/fluid/platform/float16.h"
namespace
paddle
{
namespace
operators
{
// support gemm-nt and gemm-nn, which is used in fused_attention_op.
template
<
typename
T
>
class
AttnMatMul
{
public:
// (m, n, k) = bsz_seq, output_size, input_size
AttnMatMul
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
bool
transA
,
bool
transB
,
int
bsz_seq
,
int
output_size
,
int
input_size
,
bool
compute_bias
)
:
dev_ctx_
(
dev_ctx
),
transA_
(
transA
),
transB_
(
transB
),
bsz_seq_
(
bsz_seq
),
output_size_
(
output_size
),
input_size_
(
input_size
),
compute_bias_
(
compute_bias
)
{}
~
AttnMatMul
()
{}
void
ComputeForward
(
const
T
*
weight_data
,
const
T
*
input_data
,
const
T
*
bias_data
,
T
*
output_data
,
T
*
bias_out_data
)
{
// Note: for blas.GEMM API in Paddle, it treats all inputs as row-major.
// here: (transa, transb): nt, input * weight.
CBLAS_TRANSPOSE
transA
=
CblasNoTrans
;
CBLAS_TRANSPOSE
transB
=
CblasNoTrans
;
if
(
transA_
)
{
transA
=
CblasTrans
;
}
if
(
transB_
)
{
transB
=
CblasTrans
;
}
T
alpha
=
static_cast
<
T
>
(
1.0
);
T
beta
=
static_cast
<
T
>
(
0.0
);
// here: (m, n, k) = bsz_seq, output_size, input_size, (input, weight, out)
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
blas
.
GEMM
(
transA
,
transB
,
bsz_seq_
,
output_size_
,
input_size_
,
alpha
,
input_data
,
weight_data
,
beta
,
output_data
);
if
(
compute_bias_
)
{
// compute output + bias
LaunchBiasAddFwKernel
(
dev_ctx_
,
bsz_seq_
,
output_size_
,
output_data
,
bias_data
,
bias_out_data
);
}
}
void
ComputeBackward
(
const
T
*
input
,
const
T
*
weight
,
const
T
*
d_output
,
T
*
d_input
,
T
*
d_weight
,
T
*
d_bias
)
{
T
alpha
=
static_cast
<
T
>
(
1.0
);
T
beta
=
static_cast
<
T
>
(
0.0
);
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
CBLAS_TRANSPOSE
dB_transA
=
CblasNoTrans
;
CBLAS_TRANSPOSE
dB_transB
=
CblasNoTrans
;
CBLAS_TRANSPOSE
dA_transA
=
CblasNoTrans
;
CBLAS_TRANSPOSE
dA_transB
=
CblasNoTrans
;
int
dB_m
=
1
;
int
dB_n
=
1
;
int
dB_k
=
1
;
int
dA_m
=
1
;
int
dA_n
=
1
;
int
dA_k
=
1
;
T
*
dB_input_1_ptr
=
nullptr
;
T
*
dB_input_2_ptr
=
nullptr
;
T
*
dB_output_ptr
=
d_weight
;
T
*
dA_input_1_ptr
=
nullptr
;
T
*
dA_input_2_ptr
=
nullptr
;
T
*
dA_output_ptr
=
d_input
;
if
(
!
transA_
)
{
// fw: gemm-nt
if
(
transB_
)
{
// bw: gemm-tn, dB = (dC)^t * A
dB_transA
=
CblasTrans
;
dB_transB
=
CblasNoTrans
;
dB_m
=
output_size_
;
dB_n
=
input_size_
;
dB_k
=
bsz_seq_
;
// bw: gemm-nn, dA = dC * B
dA_transA
=
CblasNoTrans
;
dA_transB
=
CblasNoTrans
;
dA_m
=
bsz_seq_
;
dA_n
=
input_size_
;
dA_k
=
output_size_
;
blas
.
GEMM
(
dB_transA
,
dB_transB
,
dB_m
,
dB_n
,
dB_k
,
alpha
,
d_output
,
input
,
beta
,
dB_output_ptr
);
blas
.
GEMM
(
dA_transA
,
dA_transB
,
dA_m
,
dA_n
,
dA_k
,
alpha
,
d_output
,
weight
,
beta
,
dA_output_ptr
);
}
else
{
// fw: gemm-nn
// bw: gemm-tn, dB = A^t * dC
dB_transA
=
CblasTrans
;
dB_transB
=
CblasNoTrans
;
dB_m
=
input_size_
;
dB_n
=
output_size_
;
dB_k
=
bsz_seq_
;
// bw: gemm-nt, dA = dC * B^t
dA_transA
=
CblasNoTrans
;
dA_transB
=
CblasTrans
;
dA_m
=
bsz_seq_
;
dA_n
=
input_size_
;
dA_k
=
output_size_
;
blas
.
GEMM
(
dB_transA
,
dB_transB
,
dB_m
,
dB_n
,
dB_k
,
alpha
,
input
,
d_output
,
beta
,
dB_output_ptr
);
blas
.
GEMM
(
dA_transA
,
dA_transB
,
dA_m
,
dA_n
,
dA_k
,
alpha
,
d_output
,
weight
,
beta
,
dA_output_ptr
);
}
}
else
if
(
transB_
)
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"AttnMatMul wrapper do not support (transA=T, transB=T)"
"parameters."
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"AttnMatMul wrapper do not support (transA=T, transB=N)"
"parameters."
));
}
if
(
compute_bias_
)
{
LaunchBiasAddBwKernel
(
dev_ctx_
,
bsz_seq_
,
output_size_
,
d_output
,
d_bias
);
}
}
private:
const
platform
::
CUDADeviceContext
&
dev_ctx_
;
bool
transA_
;
bool
transB_
;
int
bsz_seq_
;
int
output_size_
;
int
input_size_
;
int
compute_bias_
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/fused/fmha_ref.h
0 → 100644
浏览文件 @
88ea8e6f
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/operators/dropout_impl.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_add_op.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/softmax_cudnn_op.cu.h"
#include "paddle/fluid/operators/transpose_op.cu.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
class
AttnDropoutParam
{
public:
AttnDropoutParam
()
{
is_test_
=
false
;
dropout_implementation_
=
"downgrade_in_infer"
;
dropout_prob_
=
0.5
;
is_upscale_in_train_
=
false
;
is_fix_seed_
=
false
;
seed_val_
=
0
;
seed_
=
nullptr
;
}
AttnDropoutParam
(
bool
is_test
,
const
std
::
string
dropout_implementation
,
float
dropout_prob
,
bool
is_upscale_in_train
,
bool
is_fix_seed
,
int
seed_val
,
const
Tensor
*
seed
)
{
is_test_
=
is_test
;
dropout_implementation_
=
dropout_implementation
;
dropout_prob_
=
dropout_prob
;
is_upscale_in_train_
=
is_upscale_in_train
;
is_fix_seed_
=
is_fix_seed
;
seed_val_
=
seed_val
;
seed_
=
seed
;
}
bool
is_test_
;
std
::
string
dropout_implementation_
;
float
dropout_prob_
;
bool
is_upscale_in_train_
;
bool
is_fix_seed_
;
int
seed_val_
;
const
Tensor
*
seed_
;
};
template
<
typename
T
>
class
FMHARef
{
public:
FMHARef
(
const
platform
::
CUDADeviceContext
&
dev_ctx
,
int64_t
batch_size
,
int64_t
seq_len
,
int64_t
num_head
,
int64_t
head_dim
,
AttnDropoutParam
param
)
:
dev_ctx_
(
dev_ctx
),
batch_size_
(
batch_size
),
seq_len_
(
seq_len
),
num_head_
(
num_head
),
head_dim_
(
head_dim
),
dropout_param_
(
param
)
{}
~
FMHARef
()
{}
void
ComputeForward
(
const
Tensor
&
qkv_input_tensor
,
const
Tensor
&
src_mask_tensor
,
Tensor
*
transpose_2_out_tensor
,
Tensor
*
qk_out_tensor
,
Tensor
*
src_mask_out_tensor
,
Tensor
*
softmax_out_tensor
,
Tensor
*
dropout_mask_out_tensor
,
Tensor
*
dropout_out_tensor
,
Tensor
*
qktv_out_tensor
,
Tensor
*
fmha_out_tensor
)
{
// input shape: [bs, seq_len, 3, num_head, head_dim]
// transpose with perm [2, 0, 1, 3, 4],
// output_shape: [3, bs, num_head, seq_len, head_dim]
int
ndims
=
5
;
std
::
vector
<
int
>
perm_1
=
{
2
,
0
,
3
,
1
,
4
};
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
qkv_input_tensor
,
perm_1
,
transpose_2_out_tensor
);
T
*
qkv_data
=
transpose_2_out_tensor
->
data
<
T
>
();
T
*
qk_out_data
=
qk_out_tensor
->
data
<
T
>
();
T
*
qktv_out_data
=
qktv_out_tensor
->
data
<
T
>
();
T
*
softmax_out_data
=
softmax_out_tensor
->
data
<
T
>
();
T
*
dropout_out_data
=
dropout_out_tensor
->
data
<
T
>
();
T
*
fmha_out_data
=
fmha_out_tensor
->
data
<
T
>
();
int
q_size
=
batch_size_
*
seq_len_
*
num_head_
*
head_dim_
;
int
k_size
=
q_size
;
T
*
q_ptr
=
qkv_data
;
T
*
k_ptr
=
q_ptr
+
q_size
;
T
*
v_ptr
=
k_ptr
+
k_size
;
// q*k^t, batched_gemm
CBLAS_TRANSPOSE
transA
=
CblasNoTrans
;
CBLAS_TRANSPOSE
transB
=
CblasTrans
;
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
int
gemm_batch_size
=
batch_size_
*
num_head_
;
int
gemm_m
=
seq_len_
;
int
gemm_n
=
seq_len_
;
int
gemm_k
=
head_dim_
;
T
alpha
=
static_cast
<
T
>
(
1.0
/
sqrt
(
head_dim_
));
T
beta
=
static_cast
<
T
>
(
0.0
);
int64_t
stride_a
=
gemm_m
*
gemm_k
;
int64_t
stride_b
=
gemm_k
*
gemm_n
;
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
q_ptr
,
k_ptr
,
beta
,
qk_out_data
,
gemm_batch_size
,
stride_a
,
stride_b
);
std
::
vector
<
const
Tensor
*>
ins
;
std
::
vector
<
Tensor
*>
outs
;
ins
.
emplace_back
(
qk_out_tensor
);
ins
.
emplace_back
(
&
src_mask_tensor
);
outs
.
emplace_back
(
src_mask_out_tensor
);
int
elewise_add_axis
=
-
1
;
int
softmax_axis
=
-
1
;
if
(
&
src_mask_tensor
!=
nullptr
)
{
LaunchElementwiseCudaKernel
<
ElementwiseType
::
kBinary
,
T
,
T
>
(
dev_ctx_
,
ins
,
&
outs
,
elewise_add_axis
,
AddFunctor
<
T
>
());
SoftmaxForwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
*
src_mask_out_tensor
,
softmax_axis
,
softmax_out_tensor
);
}
else
{
SoftmaxForwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
*
qk_out_tensor
,
softmax_axis
,
softmax_out_tensor
);
}
transB
=
CblasNoTrans
;
gemm_m
=
seq_len_
;
gemm_n
=
head_dim_
;
gemm_k
=
seq_len_
;
alpha
=
static_cast
<
T
>
(
1.0
);
stride_a
=
gemm_m
*
gemm_k
;
stride_b
=
gemm_k
*
gemm_n
;
if
(
dropout_param_
.
dropout_prob_
)
{
DropoutFwGPUKernelDriver
<
T
>
(
dev_ctx_
,
dropout_param_
.
is_test_
,
static_cast
<
const
std
::
string
>
(
dropout_param_
.
dropout_implementation_
),
dropout_param_
.
dropout_prob_
,
dropout_param_
.
is_upscale_in_train_
,
dropout_param_
.
is_fix_seed_
,
dropout_param_
.
seed_val_
,
static_cast
<
const
Tensor
&>
(
*
softmax_out_tensor
),
dropout_param_
.
seed_
,
dropout_mask_out_tensor
,
dropout_out_tensor
);
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
dropout_out_data
,
v_ptr
,
beta
,
qktv_out_data
,
gemm_batch_size
,
stride_a
,
stride_b
);
}
else
{
// softmax_out * v, batched_gemm
// output shape: [batch_size, num_heads, seq_len, head_dim]
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
softmax_out_data
,
v_ptr
,
beta
,
qktv_out_data
,
gemm_batch_size
,
stride_a
,
stride_b
);
}
// transpose: [0, 2, 1, 3]
// output shape: [batch_size, seq_len, num_heads, head_dim]
std
::
vector
<
int
>
perm_3
=
{
0
,
2
,
1
,
3
};
ndims
=
4
;
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
*
qktv_out_tensor
,
perm_3
,
fmha_out_tensor
);
}
void
ComputeBackward
(
const
Tensor
&
transpose_2_out_tensor
,
const
Tensor
&
src_mask_tensor
,
const
Tensor
&
softmax_out_tensor
,
const
Tensor
&
dropout_mask_out_tensor
,
const
Tensor
&
dropout_out_tensor
,
const
Tensor
&
qk_out_tensor
,
const
Tensor
&
src_mask_out_tensor
,
const
Tensor
&
fmha_out_grad_tensor
,
Tensor
*
qktv_out_grad_tensor
,
Tensor
*
dropout_out_grad_tensor
,
Tensor
*
softmax_out_grad_tensor
,
Tensor
*
src_mask_out_grad_tensor
,
Tensor
*
qk_out_grad_tensor
,
Tensor
*
transpose_2_out_grad_tensor
,
Tensor
*
src_mask_grad_tensor
,
Tensor
*
qkv_input_grad_tensor
)
{
auto
blas
=
math
::
GetBlas
<
platform
::
CUDADeviceContext
,
T
>
(
dev_ctx_
);
int
q_size
=
batch_size_
*
seq_len_
*
num_head_
*
head_dim_
;
int
k_size
=
q_size
;
int
softmax_axis
=
-
1
;
T
*
qkv_grad_data
=
transpose_2_out_grad_tensor
->
data
<
T
>
();
T
*
q_grad_ptr
=
qkv_grad_data
;
T
*
k_grad_ptr
=
q_grad_ptr
+
q_size
;
T
*
v_grad_ptr
=
k_grad_ptr
+
k_size
;
const
T
*
qkv_data
=
transpose_2_out_tensor
.
data
<
T
>
();
const
T
*
q_ptr
=
qkv_data
;
const
T
*
k_ptr
=
q_ptr
+
q_size
;
const
T
*
v_ptr
=
k_ptr
+
k_size
;
const
T
*
softmax_out_data
=
softmax_out_tensor
.
data
<
T
>
();
T
*
softmax_out_grad_data
=
softmax_out_grad_tensor
->
data
<
T
>
();
const
T
*
dropout_out_data
=
dropout_out_tensor
.
data
<
T
>
();
T
*
dropout_out_grad_data
=
dropout_out_grad_tensor
->
data
<
T
>
();
T
*
qktv_out_grad_data
=
qktv_out_grad_tensor
->
data
<
T
>
();
// transpose bw
int
ndims
=
4
;
std
::
vector
<
int
>
perm_3
=
{
0
,
2
,
1
,
3
};
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
fmha_out_grad_tensor
,
perm_3
,
qktv_out_grad_tensor
);
// recall batchedgemm(nn) fw: softmax_out_data(x) * v_ptr(y) =
// qktv_out_data(out)
CBLAS_TRANSPOSE
transA
=
CblasTrans
;
CBLAS_TRANSPOSE
transB
=
CblasNoTrans
;
int
gemm_batch_size
=
batch_size_
*
num_head_
;
int
gemm_m
=
seq_len_
;
int
gemm_n
=
head_dim_
;
int
gemm_k
=
seq_len_
;
T
alpha
=
static_cast
<
T
>
(
1.0
);
T
beta
=
static_cast
<
T
>
(
0.0
);
int64_t
stride_a
=
gemm_m
*
gemm_k
;
int64_t
stride_b
=
gemm_k
*
gemm_n
;
// bw: dy = x^t * dout
if
(
dropout_param_
.
dropout_prob_
)
{
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
dropout_out_data
,
qktv_out_grad_data
,
beta
,
v_grad_ptr
,
gemm_batch_size
,
stride_a
,
stride_b
);
}
else
{
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
softmax_out_data
,
qktv_out_grad_data
,
beta
,
v_grad_ptr
,
gemm_batch_size
,
stride_a
,
stride_b
);
}
// bw: dx = dout * y^t
transA
=
CblasNoTrans
;
transB
=
CblasTrans
;
gemm_m
=
seq_len_
;
gemm_n
=
seq_len_
;
gemm_k
=
head_dim_
;
stride_a
=
gemm_m
*
gemm_k
;
stride_b
=
gemm_k
*
gemm_n
;
if
(
dropout_param_
.
dropout_prob_
)
{
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
qktv_out_grad_data
,
v_ptr
,
beta
,
dropout_out_grad_data
,
gemm_batch_size
,
stride_a
,
stride_b
);
}
else
{
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
qktv_out_grad_data
,
v_ptr
,
beta
,
softmax_out_grad_data
,
gemm_batch_size
,
stride_a
,
stride_b
);
}
// dropout bw
if
(
dropout_param_
.
dropout_prob_
)
{
DropoutGradGPUKernelDriver
<
T
>
(
dev_ctx_
,
static_cast
<
const
std
::
string
>
(
dropout_param_
.
dropout_implementation_
),
dropout_param_
.
dropout_prob_
,
static_cast
<
const
Tensor
&>
(
*
dropout_out_grad_tensor
),
dropout_mask_out_tensor
,
softmax_out_grad_tensor
->
numel
(),
softmax_out_grad_tensor
);
}
if
(
&
src_mask_tensor
!=
nullptr
)
{
SoftmaxBackwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
softmax_out_tensor
,
*
softmax_out_grad_tensor
,
softmax_axis
,
src_mask_out_grad_tensor
);
// recall LaunchElementwiseCudaKernel fw: src_mask_out = qk_out +
// src_mask
// Special case when dy is not needed and dx doesn't reduce
if
(
qk_out_grad_tensor
!=
nullptr
&&
src_mask_grad_tensor
==
nullptr
&&
qk_out_tensor
.
dims
()
==
src_mask_out_tensor
.
dims
())
{
VLOG
(
4
)
<<
"Special case when dy is not needed and dx doesn't "
"reduce"
;
framework
::
TensorCopy
(
*
src_mask_out_grad_tensor
,
dev_ctx_
.
GetPlace
(),
dev_ctx_
,
qk_out_grad_tensor
);
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"Only used for the backward elementwise_add op when"
"dy is not needed and dx is not reduce"
));
return
;
}
}
else
{
SoftmaxBackwardCUDAKernelDriver
<
T
>
(
dev_ctx_
,
softmax_out_tensor
,
*
softmax_out_grad_tensor
,
softmax_axis
,
qk_out_grad_tensor
);
}
T
*
qk_out_grad_data
=
qk_out_grad_tensor
->
data
<
T
>
();
alpha
=
static_cast
<
T
>
(
1.0
/
sqrt
(
head_dim_
));
// recall batchedgemm(nt) fw: q_ptr * (k_ptr)^t = qk_out
// bw: dy (seq_len * head_dim) = (dout)^t * x
transA
=
CblasTrans
;
transB
=
CblasNoTrans
;
gemm_m
=
seq_len_
;
gemm_n
=
head_dim_
;
gemm_k
=
seq_len_
;
stride_a
=
gemm_m
*
gemm_k
;
stride_b
=
gemm_k
*
gemm_n
;
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
qk_out_grad_data
,
q_ptr
,
beta
,
k_grad_ptr
,
gemm_batch_size
,
stride_a
,
stride_b
);
// dx (seq_len * head_dim) = dout * y
transA
=
CblasNoTrans
;
transB
=
CblasNoTrans
;
gemm_m
=
seq_len_
;
gemm_n
=
head_dim_
;
gemm_k
=
seq_len_
;
stride_a
=
gemm_m
*
gemm_k
;
stride_b
=
gemm_k
*
gemm_n
;
blas
.
BatchedGEMM
(
transA
,
transB
,
gemm_m
,
gemm_n
,
gemm_k
,
alpha
,
qk_out_grad_data
,
k_ptr
,
beta
,
q_grad_ptr
,
gemm_batch_size
,
stride_a
,
stride_b
);
// transpose bw
ndims
=
5
;
std
::
vector
<
int
>
perm_1
=
{
1
,
3
,
0
,
2
,
4
};
TransposeGPUKernelDriver
<
T
>
(
dev_ctx_
,
ndims
,
*
transpose_2_out_grad_tensor
,
perm_1
,
qkv_input_grad_tensor
);
}
private:
const
platform
::
CUDADeviceContext
&
dev_ctx_
;
int64_t
batch_size_
;
int64_t
seq_len_
;
int64_t
num_head_
;
int64_t
head_dim_
;
AttnDropoutParam
dropout_param_
;
};
}
// namespace operators
}
// namespace paddle
paddle/fluid/operators/layer_norm_kernel.cu.h
浏览文件 @
88ea8e6f
...
...
@@ -35,7 +35,6 @@ namespace paddle {
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
DataLayout
=
framework
::
DataLayout
;
template
<
typename
T
>
using
CudnnDataType
=
platform
::
CudnnDataType
<
T
>
;
template
<
typename
T
>
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录