Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
79bfb184
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
79bfb184
编写于
9月 04, 2023
作者:
Y
Yuanle Liu
提交者:
GitHub
9月 04, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
multihead_matmul op support codegen and kernel remove to phi (#56846)
上级
7fd6ffb8
变更
13
展开全部
隐藏空白更改
内联
并排
Showing
13 changed file
with
1035 addition
and
1026 deletion
+1035
-1026
paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu
...rence/tensorrt/plugin/multihead_matmul_roformer_plugin.cu
+3
-3
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
.../fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
+3
-3
paddle/fluid/operators/fused/CMakeLists.txt
paddle/fluid/operators/fused/CMakeLists.txt
+0
-3
paddle/fluid/operators/fused/multihead_matmul_op.cc
paddle/fluid/operators/fused/multihead_matmul_op.cc
+0
-116
paddle/fluid/operators/math/bert_encoder_functor.cu
paddle/fluid/operators/math/bert_encoder_functor.cu
+0
-698
paddle/fluid/operators/math/bert_encoder_functor.h
paddle/fluid/operators/math/bert_encoder_functor.h
+0
-29
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+10
-0
paddle/phi/api/yaml/op_compat.yaml
paddle/phi/api/yaml/op_compat.yaml
+8
-0
paddle/phi/infermeta/multiary.cc
paddle/phi/infermeta/multiary.cc
+33
-0
paddle/phi/infermeta/multiary.h
paddle/phi/infermeta/multiary.h
+11
-0
paddle/phi/kernels/funcs/multihead_matmul_functor.cu
paddle/phi/kernels/funcs/multihead_matmul_functor.cu
+745
-0
paddle/phi/kernels/funcs/multihead_matmul_functor.h
paddle/phi/kernels/funcs/multihead_matmul_functor.h
+51
-0
paddle/phi/kernels/fusion/gpu/multihead_matmul_kernel.cu
paddle/phi/kernels/fusion/gpu/multihead_matmul_kernel.cu
+171
-174
未找到文件。
paddle/fluid/inference/tensorrt/plugin/multihead_matmul_roformer_plugin.cu
浏览文件 @
79bfb184
...
...
@@ -22,9 +22,9 @@
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -254,7 +254,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
platform
::
CUDAPlace
(
device_id
)));
const
phi
::
GPUContext
&
dev_ctx
=
*
device_ctx
;
operators
::
math
::
MultiH
eadGPUComputeFunctor
<
float
>
multihead_compute_func
;
phi
::
funcs
::
Multih
eadGPUComputeFunctor
<
float
>
multihead_compute_func
;
multihead_compute_func
(
dev_ctx
,
batch
,
seq_len
,
...
...
@@ -341,7 +341,7 @@ int MultiheadMatmulRoformerPlugin::enqueue(
tptr
,
static_cast
<
half
>
(
scale_
),
n_q
);
const
phi
::
GPUContext
&
dev_ctx
=
*
device_ctx
;
operators
::
math
::
MultiH
eadGPUComputeFunctor
<
half
>
multihead_compute_func
;
phi
::
funcs
::
Multih
eadGPUComputeFunctor
<
half
>
multihead_compute_func
;
multihead_compute_func
(
dev_ctx
,
batch
,
seq_len
,
...
...
paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.cu
浏览文件 @
79bfb184
...
...
@@ -24,9 +24,9 @@
#include "paddle/fluid/inference/tensorrt/plugin/common/common.cuh"
#include "paddle/fluid/inference/tensorrt/plugin/qkv_to_context_plugin.h"
#include "paddle/fluid/inference/tensorrt/plugin/trt_plugin_utils.h"
#include "paddle/fluid/operators/math/bert_encoder_functor.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace
paddle
{
namespace
inference
{
...
...
@@ -396,7 +396,7 @@ int QkvToContextPluginDynamic::enqueue(
platform
::
CUDAPlace
(
device_id
)));
const
phi
::
GPUContext
&
dev_ctx
=
*
device_ctx
;
operators
::
math
::
MultiH
eadGPUComputeFunctor
<
float
>
multihead_compute_func
;
phi
::
funcs
::
Multih
eadGPUComputeFunctor
<
float
>
multihead_compute_func
;
multihead_compute_func
(
dev_ctx
,
batch
,
seq_len
,
...
...
@@ -506,7 +506,7 @@ int QkvToContextPluginDynamic::enqueue(
tptr
,
static_cast
<
half
>
(
scale_
),
n_q
);
const
phi
::
GPUContext
&
dev_ctx
=
*
device_ctx
;
operators
::
math
::
MultiH
eadGPUComputeFunctor
<
half
>
multihead_compute_func
;
phi
::
funcs
::
Multih
eadGPUComputeFunctor
<
half
>
multihead_compute_func
;
multihead_compute_func
(
dev_ctx
,
batch
,
seq_len
,
...
...
paddle/fluid/operators/fused/CMakeLists.txt
浏览文件 @
79bfb184
...
...
@@ -10,7 +10,6 @@ register_operators(
fusion_transpose_flatten_concat_op
fusion_conv_inception_op
fused_fc_elementwise_layernorm_op
multihead_matmul_op
self_dp_attention_op
skip_layernorm_op
yolo_box_head_op
...
...
@@ -74,8 +73,6 @@ if(WITH_GPU OR WITH_ROCM)
endif
()
# fused_fc_elementwise_layernorm_op
op_library
(
fused_fc_elementwise_layernorm_op
)
# multihead_matmul_op
op_library
(
multihead_matmul_op
)
op_library
(
skip_layernorm_op
)
op_library
(
yolo_box_head_op
)
op_library
(
yolo_box_post_op
)
...
...
paddle/fluid/operators/fused/multihead_matmul_op.cc
已删除
100644 → 0
浏览文件 @
7fd6ffb8
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/errors.h"
namespace
paddle
{
namespace
operators
{
class
MultiHeadMatMulV2Op
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
protected:
void
InferShape
(
framework
::
InferShapeContext
*
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
context
->
HasInput
(
"Input"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Input) of MultiHeadMatMul should not be null."
));
PADDLE_ENFORCE_EQ
(
context
->
HasInput
(
"W"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(W) of MultiHeadMatMul should not be null."
));
PADDLE_ENFORCE_EQ
(
context
->
HasInput
(
"Bias"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Input(Bias) of MultiHeadMatMul should not be null."
));
PADDLE_ENFORCE_EQ
(
context
->
HasOutput
(
"Out"
),
true
,
platform
::
errors
::
InvalidArgument
(
"Output(Out) of MultiHeadMatMul should not be null."
));
auto
dim_w
=
context
->
GetInputDim
(
"W"
);
PADDLE_ENFORCE_GT
(
dim_w
.
size
(),
2
,
platform
::
errors
::
InvalidArgument
(
"Multihead input is expected at least a 3-D tensor, but "
"it's %d-D tensor now."
,
dim_w
.
size
()));
auto
dim_bias_q
=
context
->
GetInputDim
(
"Bias"
);
PADDLE_ENFORCE_GT
(
dim_bias_q
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"Multihead input should be at least 2-D tensor, but it's "
"%d-D tensor now."
,
dim_bias_q
.
size
()));
auto
dim_input
=
context
->
GetInputDim
(
"Input"
);
context
->
SetOutputDim
(
"Out"
,
dim_input
);
context
->
ShareLoD
(
"Input"
,
/*->*/
"Out"
);
}
};
class
MultiHeadMatMulV2OpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"Input"
,
"The input of MultiHeadMatMul op"
);
AddInput
(
"W"
,
"The weight input of MultiHeadMatMul op"
);
AddInput
(
"Bias"
,
"The bias input of MultiHeadMatMul op"
);
AddInput
(
"BiasQK"
,
"The QK bias input of MultiHeadMatMul op"
)
.
AsDispensable
();
AddOutput
(
"Out"
,
"The output of MultiHeadMatMul op"
);
AddAttr
<
bool
>
(
"transpose_Q"
,
R"DOC(If true, use the transpose of `Q`.
)DOC"
)
.
SetDefault
(
false
);
AddAttr
<
bool
>
(
"transpose_K"
,
R"DOC(If true, use the transpose of `K`.
)DOC"
)
.
SetDefault
(
true
);
AddAttr
<
bool
>
(
"transpose_V"
,
R"DOC(If true, use the transpose of `V`.
)DOC"
)
.
SetDefault
(
false
);
AddAttr
<
float
>
(
"alpha"
,
"The scale of Out"
).
SetDefault
(
1.0
f
);
AddAttr
<
int
>
(
"head_number"
,
"The number of heads of the matrix"
)
.
SetDefault
(
1
);
AddComment
(
R"DOC(
MultiHeadMatMul Operator.
This op is used for optimize multi head calculation in ernie model.
Not suggest to use in other case except has same structure as ernie.
Example of matrix multiplication with head_number of B
- X: [B, M, K], Y: [B, K, N] => Out: [B, M, N]
)DOC"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_WITHOUT_GRADIENT
(
multihead_matmul
,
ops
::
MultiHeadMatMulV2Op
,
ops
::
MultiHeadMatMulV2OpMaker
);
paddle/fluid/operators/math/bert_encoder_functor.cu
浏览文件 @
79bfb184
此差异已折叠。
点击以展开。
paddle/fluid/operators/math/bert_encoder_functor.h
浏览文件 @
79bfb184
...
...
@@ -77,35 +77,6 @@ class EmbEltwiseLayerNormFunctor {
gpuStream_t
stream
);
};
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// | |
// matmul
// |
// eltwise_add
// |
// softmax /
// \ /
// matmul
// |
template
<
typename
T
>
class
MultiHeadGPUComputeFunctor
{
public:
void
operator
()(
const
phi
::
GPUContext
&
dev_ctx
,
int
batch
,
int
seq_len
,
int
head_num
,
int
head_size
,
T
*
qkptr
,
const
T
*
bias_qk_ptr
,
bool
bias_is_mask
,
T
*
tptr
,
T
alpha
,
T
beta
);
};
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
79bfb184
...
...
@@ -189,6 +189,16 @@
data_type
:
x
optional
:
mask, seq_lod, max_seq_len, x_fp16, out_fp16
-
op
:
multihead_matmul
args
:
(Tensor input, Tensor w, Tensor bias, Tensor bias_qk, bool transpose_q =
false
, bool transpose_k =
true
, bool transpose_v =
false
, float alpha = 1.0f, int head_number = 1)
output
:
Tensor(out)
infer_meta
:
func
:
MultiheadMatmulInferMeta
kernel
:
func
:
multihead_matmul
data_type
:
input
optional
:
bias_qk
-
op
:
yolo_box_xpu
args
:
(Tensor x, Tensor x_max, Tensor grid, Tensor stride, Tensor anchor_grid, float offset)
output
:
Tensor(out), Tensor(out_max)
...
...
paddle/phi/api/yaml/op_compat.yaml
浏览文件 @
79bfb184
...
...
@@ -1935,6 +1935,14 @@
outputs
:
{
out
:
Out
,
index
:
Index
,
nms_rois_num
:
NmsRoisNum
}
-
op
:
multihead_matmul
inputs
:
{
input
:
Input
,
w
:
W
,
bias
:
Bias
,
bias_qk
:
BiasQK
}
outputs
:
out
:
Out
attrs
:
{
transpose_q
:
transpose_Q
,
transpose_k
:
transpose_K
,
transpose_v
:
transpose_V
}
-
op
:
multinomial
inputs
:
{
x
:
X
}
...
...
paddle/phi/infermeta/multiary.cc
浏览文件 @
79bfb184
...
...
@@ -4126,6 +4126,39 @@ void WeightedSampleNeighborsInferMeta(const MetaTensor& row,
out_count
->
set_dtype
(
DataType
::
INT32
);
}
void
MultiheadMatmulInferMeta
(
const
MetaTensor
&
input
,
const
MetaTensor
&
w
,
const
MetaTensor
&
bias
,
const
MetaTensor
&
bias_qk
,
const
bool
transpose_q
,
const
bool
transpose_k
,
const
bool
transpose_v
,
const
float
alpha
,
const
int
head_number
,
MetaTensor
*
out
)
{
auto
w_dims
=
w
.
dims
();
PADDLE_ENFORCE_GT
(
w_dims
.
size
(),
2
,
errors
::
InvalidArgument
(
"MultiheadMatmul's w is expected at least a 3-D tensor, but "
"it's %d-D tensor now."
,
w_dims
.
size
()));
auto
bias_dims
=
bias
.
dims
();
PADDLE_ENFORCE_GT
(
bias_dims
.
size
(),
1
,
errors
::
InvalidArgument
(
"MultiheadMatmul's bias should be at least 2-D tensor, but it's "
"%d-D tensor now."
,
bias_dims
.
size
()));
out
->
set_dims
(
input
.
dims
());
out
->
set_dtype
(
input
.
dtype
());
out
->
share_lod
(
input
);
}
void
MaskedMultiheadAttentionInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
cache_kv
,
const
MetaTensor
&
bias
,
...
...
paddle/phi/infermeta/multiary.h
浏览文件 @
79bfb184
...
...
@@ -811,6 +811,17 @@ void FusedRopeInferMeta(const MetaTensor& q,
MetaTensor
*
out_k
,
MetaTensor
*
out_v
);
void
MultiheadMatmulInferMeta
(
const
MetaTensor
&
input
,
const
MetaTensor
&
w
,
const
MetaTensor
&
bias
,
const
MetaTensor
&
bias_qk
,
const
bool
transpose_q
,
const
bool
transpose_k
,
const
bool
transpose_v
,
const
float
alpha
,
const
int
head_number
,
MetaTensor
*
out
);
void
MaskedMultiheadAttentionInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
cache_kv
,
const
MetaTensor
&
bias
,
...
...
paddle/phi/kernels/funcs/multihead_matmul_functor.cu
0 → 100644
浏览文件 @
79bfb184
此差异已折叠。
点击以展开。
paddle/phi/kernels/funcs/multihead_matmul_functor.h
0 → 100644
浏览文件 @
79bfb184
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/backends/gpu/gpu_context.h"
namespace
phi
{
namespace
funcs
{
// This functor involves a fusion calculation in Ernie or Bert.
// The fusion mode is as follows:
//
// | |
// matmul
// |
// eltwise_add
// |
// softmax /
// \ /
// matmul
// |
template
<
typename
T
>
class
MultiheadGPUComputeFunctor
{
public:
void
operator
()(
const
phi
::
GPUContext
&
dev_ctx
,
int
batch
,
int
seq_len
,
int
head_num
,
int
head_size
,
T
*
qkptr
,
const
T
*
bias_qk_ptr
,
bool
bias_is_mask
,
T
*
tptr
,
T
alpha
,
T
beta
);
};
}
// namespace funcs
}
// namespace phi
paddle/
fluid/operators/fused/multihead_matmul_op
.cu
→
paddle/
phi/kernels/fusion/gpu/multihead_matmul_kernel
.cu
浏览文件 @
79bfb184
// Copyright (c) 20
19
PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 20
23
PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
...
...
@@ -12,20 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <paddle/fluid/platform/device_context.h>
#include <algorithm>
#include <type_traits>
#include "paddle/
fluid/framework/op_registry
.h"
#include "paddle/
fluid/memory/malloc
.h"
#include "paddle/
fluid/operators/math/bert_encoder_functor
.h"
#include "paddle/
fluid/platform/float16
.h"
#include "paddle/
phi/common/float16
.h"
#include "paddle/
phi/core/enforce
.h"
#include "paddle/
phi/core/errors
.h"
#include "paddle/
phi/core/kernel_registry
.h"
#include "paddle/phi/core/tensor_utils.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/multihead_matmul_functor.h"
namespace
p
addle
{
namespace
operators
{
namespace
p
hi
{
namespace
fusion
{
template
<
typename
T
>
__global__
void
transpose
(
T
*
src
,
...
...
@@ -149,7 +148,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE
(
h
*
head_num
,
1024
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"head_num (%d) * head_size (%d) should <= %d"
,
head_num
,
head_size
,
...
...
@@ -165,7 +164,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE
(
h
*
head_num
,
1024
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"head_num (%d) * head_size (%d) should <= %d"
,
head_num
,
head_size
,
...
...
@@ -177,7 +176,7 @@ void TransQKVWithBias(const int batch,
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE
(
head_size
*
head_num
,
1024
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"head_num (%d) * head_size (%d) should <= %d"
,
head_num
,
head_size
,
...
...
@@ -193,9 +192,9 @@ void TransQKVWithBias(const int batch,
const
int
seq_len
,
const
int
head_size
,
const
int
head_num
,
const
p
latform
::
float16
*
input
,
const
p
latform
::
float16
*
bias
,
p
latform
::
float16
*
output
,
const
p
hi
::
dtype
::
float16
*
input
,
const
p
hi
::
dtype
::
float16
*
bias
,
p
hi
::
dtype
::
float16
*
output
,
gpuStream_t
stream
)
{
// BxSx3xNxH + 3xNxH -> 3xBxNxSxH
int
scratch_size
=
batch
*
head_num
*
seq_len
*
seq_len
;
...
...
@@ -209,7 +208,7 @@ void TransQKVWithBias(const int batch,
// limit h * head_num to max block size(1024).
PADDLE_ENFORCE_LE
(
h
*
head_num
,
1024
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"head_num (%d) * head_size (%d) should <= %d"
,
head_num
,
head_size
,
...
...
@@ -225,7 +224,7 @@ void TransQKVWithBias(const int batch,
// limit head_size * head_num to max block size(1024).
PADDLE_ENFORCE_LE
(
head_size
*
head_num
,
1024
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"head_num (%d) * head_size (%d) should <= %d"
,
head_num
,
head_size
,
...
...
@@ -240,7 +239,7 @@ inline int round_up(int seq_len, int multiple = 32) {
PADDLE_ENFORCE_GT
(
multiple
,
0
,
p
latform
::
errors
::
InvalidArgument
(
p
hi
::
errors
::
InvalidArgument
(
"multiple should be a positive number, but it's (%d)"
,
multiple
));
return
((
seq_len
+
multiple
-
1
)
/
multiple
)
*
multiple
;
}
...
...
@@ -270,168 +269,166 @@ __global__ void broadcast_batch_head_number(const T *src,
}
}
template
<
typename
T
,
typename
DeviceContext
>
class
MultiHeadMatMulV2Kernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
input
=
context
.
Input
<
phi
::
DenseTensor
>
(
"Input"
);
auto
*
w
=
context
.
Input
<
phi
::
DenseTensor
>
(
"W"
);
auto
*
bias
=
context
.
Input
<
phi
::
DenseTensor
>
(
"Bias"
);
auto
*
bias_qk
=
context
.
Input
<
phi
::
DenseTensor
>
(
"BiasQK"
);
auto
*
input_d
=
input
->
data
<
T
>
();
auto
*
w_d
=
w
->
data
<
T
>
();
auto
*
bias_d
=
bias
->
data
<
T
>
();
auto
*
bias_qk_d
=
bias_qk
?
bias_qk
->
data
<
T
>
()
:
nullptr
;
T
scale
=
static_cast
<
T
>
(
context
.
Attr
<
float
>
(
"alpha"
));
int
head_number
=
context
.
Attr
<
int
>
(
"head_number"
);
// compute q*k with eltadd
auto
&
device_ctx
=
context
.
template
device_context
<
DeviceContext
>();
auto
stream
=
device_ctx
.
stream
();
// should be (B * S * hidden)
auto
input_dims
=
input
->
dims
();
// shouble be (hidden * 3 * all_head_size)
auto
w_dims
=
w
->
dims
();
int
batch
=
input_dims
[
0
];
int
seq_len
=
input_dims
[
1
];
int
hidden
=
input_dims
[
2
];
phi
::
DenseTensor
temp_bias_tensor
;
// if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
if
(
bias_qk
&&
bias_qk
->
numel
()
==
(
batch
*
seq_len
))
{
VLOG
(
4
)
<<
"Do broadcasted bias_qk from [batch, 1, 1, seq_len]"
;
temp_bias_tensor
.
Resize
({
batch
*
head_number
*
seq_len
*
seq_len
});
auto
*
temp_qk_bias
=
device_ctx
.
template
Alloc
<
T
>(
&
temp_bias_tensor
,
temp_bias_tensor
.
numel
()
*
sizeof
(
T
));
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
bias_qk_d
,
temp_qk_bias
,
seq_len
,
head_number
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
// if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
// broadcasted
if
(
bias_qk
&&
bias_qk
->
numel
()
==
(
1
*
seq_len
*
seq_len
))
{
VLOG
(
4
)
<<
"do broadcasted bias_qk from [1, 1, seq_len, seq_len]"
;
temp_bias_tensor
.
Resize
({
batch
*
head_number
*
seq_len
*
seq_len
});
auto
*
temp_qk_bias
=
device_ctx
.
template
Alloc
<
T
>(
&
temp_bias_tensor
,
temp_bias_tensor
.
numel
()
*
sizeof
(
T
));
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast_batch_head_number
<<<
grid
,
block
,
0
,
stream
>>>
(
bias_qk_d
,
temp_qk_bias
,
batch
,
seq_len
,
head_number
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
if
(
!
bias_qk
)
{
int
size
=
batch
*
head_number
*
seq_len
*
seq_len
;
temp_bias_tensor
.
Resize
({
size
});
auto
*
temp_qk_bias
=
device_ctx
.
template
Alloc
<
T
>(
&
temp_bias_tensor
,
temp_bias_tensor
.
numel
()
*
sizeof
(
T
));
template
<
typename
T
,
typename
Context
>
void
MultiheadMatmulKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
input
,
const
DenseTensor
&
w
,
const
DenseTensor
&
bias
,
const
paddle
::
optional
<
DenseTensor
>
&
bias_qk
,
const
bool
transpose_q
,
const
bool
transpose_k
,
const
bool
transpose_v
,
const
float
alpha
,
const
int
head_number
,
DenseTensor
*
out
)
{
auto
*
input_d
=
input
.
data
<
T
>
();
auto
*
w_d
=
w
.
data
<
T
>
();
auto
*
bias_d
=
bias
.
data
<
T
>
();
auto
*
bias_qk_d
=
bias_qk
?
bias_qk
->
data
<
T
>
()
:
nullptr
;
T
scale
=
static_cast
<
T
>
(
alpha
);
// compute q*k with eltadd
auto
stream
=
dev_ctx
.
stream
();
// should be (B * S * hidden)
auto
input_dims
=
input
.
dims
();
// shouble be (hidden * 3 * all_head_size)
auto
w_dims
=
w
.
dims
();
int
batch
=
input_dims
[
0
];
int
seq_len
=
input_dims
[
1
];
int
hidden
=
input_dims
[
2
];
phi
::
DenseTensor
temp_bias_tensor
;
// if bias_qk is[batch, 1, 1, seq_len], the bias_qk_d need to be broadcasted
if
(
bias_qk
&&
bias_qk
->
numel
()
==
(
batch
*
seq_len
))
{
VLOG
(
4
)
<<
"Do broadcasted bias_qk from [batch, 1, 1, seq_len]"
;
temp_bias_tensor
.
Resize
({
batch
*
head_number
*
seq_len
*
seq_len
});
auto
*
temp_qk_bias
=
dev_ctx
.
template
Alloc
<
T
>(
&
temp_bias_tensor
,
temp_bias_tensor
.
numel
()
*
sizeof
(
T
));
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast
<<<
grid
,
block
,
0
,
stream
>>>
(
bias_qk_d
,
temp_qk_bias
,
seq_len
,
head_number
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
// if bias_qk is[1, 1, seq_len, seq_len], the bias_qk_d need to be
// broadcasted
if
(
bias_qk
&&
bias_qk
->
numel
()
==
(
1
*
seq_len
*
seq_len
))
{
VLOG
(
4
)
<<
"do broadcasted bias_qk from [1, 1, seq_len, seq_len]"
;
temp_bias_tensor
.
Resize
({
batch
*
head_number
*
seq_len
*
seq_len
});
auto
*
temp_qk_bias
=
dev_ctx
.
template
Alloc
<
T
>(
&
temp_bias_tensor
,
temp_bias_tensor
.
numel
()
*
sizeof
(
T
));
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
round_up
(
seq_len
);
broadcast_batch_head_number
<<<
grid
,
block
,
0
,
stream
>>>
(
bias_qk_d
,
temp_qk_bias
,
batch
,
seq_len
,
head_number
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
if
(
!
bias_qk
)
{
int
size
=
batch
*
head_number
*
seq_len
*
seq_len
;
temp_bias_tensor
.
Resize
({
size
});
auto
*
temp_qk_bias
=
dev_ctx
.
template
Alloc
<
T
>(
&
temp_bias_tensor
,
temp_bias_tensor
.
numel
()
*
sizeof
(
T
));
#ifdef PADDLE_WITH_HIP
hipMemset
(
temp_qk_bias
,
0
,
sizeof
(
float
)
*
size
);
hipMemset
(
temp_qk_bias
,
0
,
sizeof
(
float
)
*
size
);
#else
cudaMemset
(
temp_qk_bias
,
0
,
sizeof
(
float
)
*
size
);
cudaMemset
(
temp_qk_bias
,
0
,
sizeof
(
float
)
*
size
);
#endif
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
int
all_head_size
=
w_dims
[
2
];
int
head_size
=
all_head_size
/
head_number
;
auto
*
out
=
context
.
Output
<
phi
::
DenseTensor
>
(
"Out"
);
out
->
Resize
({
batch
,
seq_len
,
all_head_size
});
auto
*
output_d
=
device_ctx
.
template
Alloc
<
T
>(
out
,
out
->
numel
()
*
sizeof
(
T
));
// (B*S, hidden)
const
phi
::
DenseTensor
input_matrix
=
phi
::
ReshapeToMatrix
(
*
input
,
2
/*x_num_col_dims */
);
// (hidden, 3 * all_head_size)
const
phi
::
DenseTensor
w_matrix
=
phi
::
ReshapeToMatrix
(
*
w
,
1
/*y_num_col_dims*/
);
phi
::
DenseTensor
temp_out_tensor
;
auto
temp_out_dims
=
phi
::
make_ddim
({
batch
,
seq_len
,
3
,
head_number
,
head_size
});
temp_out_tensor
.
Resize
(
{
batch
*
seq_len
,
phi
::
product
(
temp_out_dims
)
/
(
batch
*
seq_len
)});
auto
*
temp_out_data
=
device_ctx
.
template
Alloc
<
T
>(
&
temp_out_tensor
,
temp_out_tensor
.
numel
()
*
sizeof
(
T
));
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto
blas
=
phi
::
funcs
::
GetBlas
<
phi
::
GPUContext
,
T
>
(
device_ctx
);
blas
.
MatMul
(
input_matrix
,
w_matrix
,
&
temp_out_tensor
);
VLOG
(
2
)
<<
"(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)"
;
VLOG
(
2
)
<<
temp_out_tensor
;
// temp_out_tensor.Resize(temp_out_dims);
phi
::
DenseTensor
multihead_temp_tensor
;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int
scratch_size
=
batch
*
head_number
*
seq_len
*
seq_len
*
1
;
multihead_temp_tensor
.
Resize
({
scratch_size
+
temp_out_tensor
.
numel
()});
auto
*
multihead_temp_data
=
device_ctx
.
template
Alloc
<
T
>(
&
multihead_temp_tensor
,
multihead_temp_tensor
.
numel
()
*
sizeof
(
T
));
auto
*
qkptr
=
multihead_temp_data
;
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias
(
batch
,
seq_len
,
head_size
,
head_number
,
temp_out_data
,
bias_d
,
tptr
,
stream
);
if
(
std
::
is_same
<
T
,
platform
::
float16
>::
value
)
{
math
::
MultiHeadGPUComputeFunctor
<
half
>
multihead_compute_func
;
multihead_compute_func
(
device_ctx
,
batch
,
seq_len
,
head_number
,
head_size
,
reinterpret_cast
<
half
*>
(
qkptr
),
reinterpret_cast
<
const
half
*>
(
bias_qk_d
),
false
,
reinterpret_cast
<
half
*>
(
tptr
),
__float2half
(
static_cast
<
float
>
(
scale
)),
__float2half
(
0.0
));
}
else
{
math
::
MultiHeadGPUComputeFunctor
<
T
>
multihead_compute_func
;
multihead_compute_func
(
device_ctx
,
batch
,
seq_len
,
head_number
,
head_size
,
qkptr
,
bias_qk_d
,
false
,
tptr
,
scale
,
T
(
0.0
));
}
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
head_size
;
transpose
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
tptr
,
output_d
,
batch
,
seq_len
,
head_number
,
head_size
);
bias_qk_d
=
static_cast
<
const
T
*>
(
temp_qk_bias
);
}
int
all_head_size
=
w_dims
[
2
];
int
head_size
=
all_head_size
/
head_number
;
out
->
Resize
({
batch
,
seq_len
,
all_head_size
});
auto
*
output_d
=
dev_ctx
.
template
Alloc
<
T
>(
out
,
out
->
numel
()
*
sizeof
(
T
));
// (B*S, hidden)
const
phi
::
DenseTensor
input_matrix
=
phi
::
ReshapeToMatrix
(
input
,
2
/*x_num_col_dims */
);
// (hidden, 3 * all_head_size)
const
phi
::
DenseTensor
w_matrix
=
phi
::
ReshapeToMatrix
(
w
,
1
/*y_num_col_dims*/
);
phi
::
DenseTensor
temp_out_tensor
;
auto
temp_out_dims
=
phi
::
make_ddim
({
batch
,
seq_len
,
3
,
head_number
,
head_size
});
temp_out_tensor
.
Resize
(
{
batch
*
seq_len
,
phi
::
product
(
temp_out_dims
)
/
(
batch
*
seq_len
)});
auto
*
temp_out_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
temp_out_tensor
,
temp_out_tensor
.
numel
()
*
sizeof
(
T
));
// (B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)
auto
blas
=
phi
::
funcs
::
GetBlas
<
phi
::
GPUContext
,
T
>
(
dev_ctx
);
blas
.
MatMul
(
input_matrix
,
w_matrix
,
&
temp_out_tensor
);
VLOG
(
2
)
<<
"(B * S, hidden) * (hidden, 3 * N * H) -> (B * S * 3 * N * H)"
;
// temp_out_tensor.Resize(temp_out_dims);
phi
::
DenseTensor
multihead_temp_tensor
;
// B * head_number * S * S * 1 + B * S * 3 * N * H
int
scratch_size
=
batch
*
head_number
*
seq_len
*
seq_len
*
1
;
multihead_temp_tensor
.
Resize
({
scratch_size
+
temp_out_tensor
.
numel
()});
auto
*
multihead_temp_data
=
dev_ctx
.
template
Alloc
<
T
>(
&
multihead_temp_tensor
,
multihead_temp_tensor
.
numel
()
*
sizeof
(
T
));
auto
*
qkptr
=
multihead_temp_data
;
auto
*
tptr
=
multihead_temp_data
+
scratch_size
;
// Do the transpose with bias.
// BxSx3xNxH => tptr: 3xBxNxSxH.
TransQKVWithBias
(
batch
,
seq_len
,
head_size
,
head_number
,
temp_out_data
,
bias_d
,
tptr
,
stream
);
if
(
std
::
is_same
<
T
,
phi
::
dtype
::
float16
>::
value
)
{
phi
::
funcs
::
MultiheadGPUComputeFunctor
<
half
>
multihead_compute_func
;
multihead_compute_func
(
dev_ctx
,
batch
,
seq_len
,
head_number
,
head_size
,
reinterpret_cast
<
half
*>
(
qkptr
),
reinterpret_cast
<
const
half
*>
(
bias_qk_d
),
false
,
reinterpret_cast
<
half
*>
(
tptr
),
__float2half
(
static_cast
<
float
>
(
scale
)),
__float2half
(
0.0
));
}
else
{
phi
::
funcs
::
MultiheadGPUComputeFunctor
<
T
>
multihead_compute_func
;
multihead_compute_func
(
dev_ctx
,
batch
,
seq_len
,
head_number
,
head_size
,
qkptr
,
bias_qk_d
,
false
,
tptr
,
scale
,
T
(
0.0
));
}
};
}
// namespace operators
}
// namespace paddle
int
grid
=
batch
*
head_number
*
seq_len
;
int
block
=
head_size
;
transpose
<
T
><<<
grid
,
block
,
0
,
stream
>>>
(
tptr
,
output_d
,
batch
,
seq_len
,
head_number
,
head_size
);
}
}
// namespace fusion
}
// namespace phi
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
#if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
PD_REGISTER_
STRUCT_
KERNEL
(
multihead_matmul
,
GPU
,
ALL_LAYOUT
,
ops
::
MultiHeadMatMulV2
Kernel
,
float
,
plat
::
float16
)
{}
PD_REGISTER_KERNEL
(
multihead_matmul
,
GPU
,
ALL_LAYOUT
,
phi
::
fusion
::
MultiheadMatmul
Kernel
,
float
,
phi
::
dtype
::
float16
)
{}
#else
PD_REGISTER_STRUCT_KERNEL
(
multihead_matmul
,
GPU
,
ALL_LAYOUT
,
ops
::
MultiHeadMatMulV2Kernel
,
float
)
{}
PD_REGISTER_KERNEL
(
multihead_matmul
,
GPU
,
ALL_LAYOUT
,
phi
::
fusion
::
MultiheadMatmulKernel
,
float
)
{}
#endif
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录