PaddlePaddle / Paddle

Commit fc880209 (unverified)
Authored Jun 08, 2023 by cmeng; committed via GitHub on Jun 08, 2023
fuse vit attention for faster-rcnn on BML (#54139)
Parent: 25409dcc
11 changed files with 1015 additions and 0 deletions (+1015 -0)
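In short: the commit adds a new MKLDNN IR pass, self_attention_fuse_pass, which detects the transpose2 / slice x3 / matmul(Q x K^T) / softmax / matmul(x V) / transpose2 self-attention block that ViT-style models produce, and replaces it with a single fused self_dp_attention operator. The operator is implemented by an AVX-512 CPU kernel (scaled_dp_attention.h) that computes scaled dot-product attention over key/value blocks with a streaming softmax, is built only when WITH_MKL and AVX-512 are available, and is verified by a unit test that compares outputs before and after fusion.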
Changed files:

  paddle/fluid/framework/ir/CMakeLists.txt                      +7    -0
  paddle/fluid/framework/ir/graph_pattern_detector.cc           +75   -0
  paddle/fluid/framework/ir/graph_pattern_detector.h            +27   -0
  paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc  +150  -0
  paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h   +41   -0
  paddle/fluid/inference/api/paddle_pass_builder.cc             +1    -0
  paddle/fluid/operators/fused/CMakeLists.txt                   +9    -0
  paddle/fluid/operators/fused/scaled_dp_attention.h            +466  -0
  paddle/fluid/operators/fused/self_dp_attention_op.cc          +124  -0
  paddle/fluid/operators/fused/self_dp_attention_op.h           +41   -0
  test/mkldnn/test_fused_vit_attention.py                       +74   -0
paddle/fluid/framework/ir/CMakeLists.txt

@@ -200,6 +200,13 @@ if(WITH_MKLDNN)
  pass_library(multi_gru_seq_fuse_pass inference DIR mkldnn)
  pass_library(quant_dequant_mkldnn_pass inference DIR mkldnn)
  pass_library(compute_propagate_scales_mkldnn_pass inference DIR mkldnn)
  pass_library(self_attention_fuse_pass inference DIR mkldnn)
  if(WITH_AVX
     AND AVX512F_FOUND
     AND AVX512F_FLAG)
    set_target_properties(self_attention_fuse_pass
                          PROPERTIES COMPILE_FLAGS "-mfma ${AVX512F_FLAG}")
  endif()
endif()

if(WITH_IPU)
...
paddle/fluid/framework/ir/graph_pattern_detector.cc

@@ -2615,6 +2615,81 @@ PDNode *patterns::VitAttention::operator()(PDNode *in) {
  return reshape2_out;
}

PDNode *patterns::SelfAttention::operator()(PDNode *in) {
  in->AsInput();
  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};

  auto transpose2_0_op =
      pattern->NewNode(transpose2_0_op_repr())->assert_is_op("transpose2");
  auto transpose2_0_out = pattern->NewNode(transpose2_0_out_repr())
                              ->assert_is_op_output("transpose2", "Out")
                              ->assert_is_op_input("slice", "Input")
                              ->AsIntermediate();
  auto slice_0_op = pattern->NewNode(slice_0_op_repr())->assert_is_op("slice");
  auto slice_0_out = pattern->NewNode(slice_0_out_repr())
                         ->assert_is_op_output("slice", "Out")
                         ->assert_is_ops_input(matmul_ops, "X")
                         ->AsIntermediate();
  auto slice_1_op = pattern->NewNode(slice_1_op_repr())->assert_is_op("slice");
  auto slice_1_out = pattern->NewNode(slice_1_out_repr())
                         ->assert_is_op_output("slice", "Out")
                         ->assert_is_op_input("transpose2", "X")
                         ->AsIntermediate();
  auto slice_2_op = pattern->NewNode(slice_2_op_repr())->assert_is_op("slice");
  auto slice_2_out = pattern->NewNode(slice_2_out_repr())
                         ->assert_is_op_output("slice", "Out")
                         ->assert_is_ops_input(matmul_ops, "Y")
                         ->AsIntermediate();
  auto matmul_0_op =
      pattern->NewNode(matmul_0_op_repr())->assert_is_ops(matmul_ops);
  auto matmul_0_out = pattern->NewNode(matmul_0_out_repr())
                          ->assert_is_ops_output(matmul_ops, "Out")
                          ->assert_is_op_input("transpose2", "X")
                          ->AsIntermediate();
  auto matmul_1_op =
      pattern->NewNode(matmul_1_op_repr())->assert_is_ops(matmul_ops);
  auto matmul_1_out = pattern->NewNode(matmul_1_out_repr())
                          ->assert_is_ops_output(matmul_ops, "Out")
                          ->assert_is_op_input("softmax", "X")
                          ->AsIntermediate();
  auto transpose2_1_op =
      pattern->NewNode(transpose2_1_op_repr())->assert_is_op("transpose2");
  auto transpose2_1_out = pattern->NewNode(transpose2_1_out_repr())
                              ->assert_is_op_output("transpose2", "Out")
                              ->assert_is_ops_input(matmul_ops, "Y")
                              ->AsIntermediate();
  auto softmax_op =
      pattern->NewNode(softmax_op_repr())->assert_is_op("softmax");
  auto softmax_out = pattern->NewNode(softmax_out_repr())
                         ->assert_is_op_output("softmax", "Out")
                         ->assert_is_ops_input(matmul_ops, "X")
                         ->AsIntermediate();
  auto transpose2_2_op =
      pattern->NewNode(transpose2_2_op_repr())->assert_is_op("transpose2");
  auto transpose2_2_out = pattern->NewNode(transpose2_2_out_repr())
                              ->assert_is_op_output("transpose2", "Out")
                              ->AsOutput();

  transpose2_0_op->LinksFrom({in});
  transpose2_0_out->LinksFrom({transpose2_0_op});
  slice_0_op->LinksFrom({transpose2_0_out});
  slice_0_out->LinksFrom({slice_0_op});
  slice_1_op->LinksFrom({transpose2_0_out});
  slice_1_out->LinksFrom({slice_1_op});
  slice_2_op->LinksFrom({transpose2_0_out});
  slice_2_out->LinksFrom({slice_2_op});
  transpose2_1_op->LinksFrom({slice_1_out});
  transpose2_1_out->LinksFrom({transpose2_1_op});
  matmul_1_op->LinksFrom({slice_0_out, transpose2_1_out});
  matmul_1_out->LinksFrom({matmul_1_op});
  softmax_op->LinksFrom({matmul_1_out});
  softmax_out->LinksFrom({softmax_op});
  matmul_0_op->LinksFrom({softmax_out, slice_2_out});
  matmul_0_out->LinksFrom({matmul_0_op});
  transpose2_2_op->LinksFrom({matmul_0_out});
  transpose2_2_out->LinksFrom({transpose2_2_op});
  return transpose2_2_out;
}

PDNode *patterns::ConvElementwiseadd2Act::operator()(
    PDNode *conv_in, const std::unordered_set<std::string> &conv_act_set) {
  auto conv_op = pattern->NewNode(conv_op_repr())->assert_is_op("conv2d");
...
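Read end-to-end, the matched subgraph is the standard ViT attention block: transpose2_0 rearranges the packed QKV tensor, the three slice ops split out Q, K and V, matmul_1 computes Q x K^T (with K pre-transposed by transpose2_1), softmax normalizes the scores, matmul_0 applies them to V, and transpose2_2 restores the layout. A shape-level NumPy sketch of the matched dataflow (illustrative only; NumPy stands in for the paddle ops, and the alpha scale is the matmul attribute read by the pass below):

import numpy as np

def matched_subgraph(x, alpha=1.0):
    # x: [batch, seq, 3, heads, head_size], the packed QKV tensor
    t = x.transpose(2, 0, 3, 1, 4)              # transpose2_0 -> [3, b, h, s, d]
    q, k, v = t[0], t[1], t[2]                  # slice_0 / slice_1 / slice_2
    kt = k.transpose(0, 1, 3, 2)                # transpose2_1
    scores = alpha * (q @ kt)                   # matmul_1: Q x K^T
    scores -= scores.max(axis=-1, keepdims=True)
    probs = np.exp(scores)
    probs /= probs.sum(axis=-1, keepdims=True)  # softmax
    ctx = probs @ v                             # matmul_0
    return ctx.transpose(0, 2, 1, 3)            # transpose2_2 -> [b, s, h, d]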
paddle/fluid/framework/ir/graph_pattern_detector.h

@@ -1491,6 +1491,33 @@ struct VitAttention : public PatternBase {
  PATTERN_DECL_NODE(reshape2_out);
};

// self_attention in vit
struct SelfAttention : public PatternBase {
  SelfAttention(PDPattern *pattern, const std::string &name_scope)
      : PatternBase(pattern, name_scope, "vit_block") {}

  PDNode *operator()(PDNode *in);

  PATTERN_DECL_NODE(transpose2_0_op);
  PATTERN_DECL_NODE(transpose2_0_out);
  PATTERN_DECL_NODE(transpose2_1_op);
  PATTERN_DECL_NODE(transpose2_1_out);
  PATTERN_DECL_NODE(transpose2_2_op);
  PATTERN_DECL_NODE(transpose2_2_out);
  PATTERN_DECL_NODE(matmul_0_op);
  PATTERN_DECL_NODE(matmul_0_out);
  PATTERN_DECL_NODE(matmul_1_op);
  PATTERN_DECL_NODE(matmul_1_out);
  PATTERN_DECL_NODE(slice_0_op);
  PATTERN_DECL_NODE(slice_0_out);
  PATTERN_DECL_NODE(slice_1_op);
  PATTERN_DECL_NODE(slice_1_out);
  PATTERN_DECL_NODE(slice_2_op);
  PATTERN_DECL_NODE(slice_2_out);
  PATTERN_DECL_NODE(softmax_op);
  PATTERN_DECL_NODE(softmax_out);
};

// Conv + ElementwiseAdd + an activation
// This pattern can further fuse the conv related ops after the conv+bn fusion.
struct ConvElementwiseaddAct : public PatternBase {
...
paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.cc (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h"

#include <string>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/string/pretty_log.h"
#include "paddle/phi/kernels/funcs/blas/blas.h"

#define GET_IR_NODE(node__) GET_IR_NODE_FROM_SUBGRAPH(node__, node__, pattern);
#define GET_NODES                 \
  GET_IR_NODE(transpose2_0_op);   \
  GET_IR_NODE(transpose2_0_out);  \
  GET_IR_NODE(slice_0_op);        \
  GET_IR_NODE(slice_0_out);       \
  GET_IR_NODE(slice_1_op);        \
  GET_IR_NODE(slice_1_out);       \
  GET_IR_NODE(slice_2_op);        \
  GET_IR_NODE(slice_2_out);       \
  GET_IR_NODE(matmul_0_op);       \
  GET_IR_NODE(matmul_0_out);      \
  GET_IR_NODE(matmul_1_op);       \
  GET_IR_NODE(matmul_1_out);      \
  GET_IR_NODE(transpose2_1_op);   \
  GET_IR_NODE(transpose2_1_out);  \
  GET_IR_NODE(softmax_op);        \
  GET_IR_NODE(softmax_out);       \
  GET_IR_NODE(transpose2_2_op);   \
  GET_IR_NODE(transpose2_2_out);

namespace paddle {
namespace framework {
namespace ir {

using string::PrettyLogDetail;

void SelfAttentionFusePass::ApplyImpl(ir::Graph* graph) const {
#if !defined(__AVX512F__) || !defined(PADDLE_WITH_MKLML) || \
    !defined(PADDLE_WITH_MKLDNN)
  LOG(WARNING) << "No-avx512 or MKL supported!";
  return;
#endif
  // do something;
  GraphPatternDetector gpd;
  const std::string pattern_name = "self_attention_fuse";
  FusePassBase::Init(pattern_name, graph);

  // pattern
  PDNode* x = gpd.mutable_pattern()
                  ->NewNode("x")
                  ->assert_is_op_input("transpose2", "X")
                  ->AsInput();
  patterns::SelfAttention pattern(gpd.mutable_pattern(), pattern_name);
  pattern(x);

  int fusion_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* g) {
    GET_NODES;
    // do something;
    OpDesc desc(transpose2_0_op->Op()->Block());
    desc.SetType("self_dp_attention");
    desc.SetInput("X", {subgraph.at(x)->Name()});
    desc.SetOutput("Out", {transpose2_2_out->Name()});

    std::vector<int64_t> in_shape = subgraph.at(x)->Var()->GetShape();
    std::vector<int64_t> shape = transpose2_0_out->Var()->GetShape();
    // in shape should be [batch_size, seq_len, 3, num_heads, head_size]
    if (in_shape.size() != 5 || in_shape[2] != 3 || shape.size() != 5 ||
        shape[0] != 3 || shape[2] != in_shape[3]) {
      LOG(WARNING) << "Self-attention shape mismatch!";
      return;
    }

    desc.SetAttr("head_number", static_cast<int>(shape[2]));
    float alpha = 1.0;
    if (matmul_1_op->Op()->HasAttr("alpha"))
      alpha = PADDLE_GET_CONST(float, matmul_1_op->Op()->GetAttr("alpha"));
    desc.SetAttr("alpha", alpha);

    // Create a new node for the fused op.
    auto self_attention_node = graph->CreateOpNode(&desc);

    // Link inputs and outputs.
    PADDLE_ENFORCE_NE(
        subgraph.count(x),
        0,
        platform::errors::NotFound(
            "Detector did not find input x of self attention."));

    IR_NODE_LINK_TO(subgraph.at(x), self_attention_node);    // Input
    IR_NODE_LINK_TO(self_attention_node, transpose2_2_out);  // Output

    // Delete the unneeded nodes.
    std::unordered_set<const Node*> marked_nodes({transpose2_0_op,
                                                  transpose2_0_out,
                                                  slice_0_op,
                                                  slice_0_out,
                                                  slice_1_op,
                                                  slice_1_out,
                                                  slice_2_op,
                                                  slice_2_out,
                                                  matmul_0_op,
                                                  matmul_0_out,
                                                  matmul_1_op,
                                                  matmul_1_out,
                                                  transpose2_1_op,
                                                  transpose2_1_out,
                                                  softmax_op,
                                                  softmax_out,
                                                  transpose2_2_op});
    GraphSafeRemoveNodes(graph, marked_nodes);
    ++fusion_count;
  };
  gpd(graph, handler);
  AddStatis(fusion_count);
  if (!Has("disable_logs") || !Get<bool>("disable_logs")) {
    PrettyLogDetail(
        "---    fused %d self attention (of scaled_dp_attention) with %s",
        fusion_count,
        pattern_name);
  }
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(self_attention_fuse_pass,
              paddle::framework::ir::SelfAttentionFusePass);
REGISTER_PASS_CAPABILITY(self_attention_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .EQ("transpose2", 0)
            .EQ("slice", 0)
            .EQ("scale", 0)
            .EQ("softmax", 0)
            .EQ("matmul_v2", 0));
paddle/fluid/framework/ir/mkldnn/self_attention_fuse_pass.h (new file, mode 100644)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <string>

#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"

namespace paddle {
namespace framework {
namespace ir {

// Fusing of the self-attention structure
class Graph;

class SelfAttentionFusePass : public FusePassBase {
 public:
  virtual ~SelfAttentionFusePass() {}

 protected:
  void ApplyImpl(ir::Graph* graph) const override;
};

}  // namespace ir
}  // namespace framework
}  // namespace paddle
paddle/fluid/inference/api/paddle_pass_builder.cc

@@ -367,6 +367,7 @@ void CpuPassStrategy::EnableMKLDNN() {
           "fc_mkldnn_pass",
           "fc_act_mkldnn_fuse_pass",
           "fc_elementwise_add_mkldnn_fuse_pass",   //
           "self_attention_fuse_pass",              //
           "batch_norm_act_fuse_pass",              //
           "softplus_activation_onednn_fuse_pass",  //
           "shuffle_channel_mkldnn_detect_pass",    //
...
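With self_attention_fuse_pass in this list, the fusion runs whenever CPU MKLDNN optimization is enabled. A typical way to trigger it from the Python inference API might look like the following sketch (the model directory is a placeholder):

import paddle.inference as paddle_infer

config = paddle_infer.Config("./vit_model")  # placeholder model directory
config.disable_gpu()
config.enable_mkldnn()         # selects CpuPassStrategy::EnableMKLDNN()
config.switch_ir_optim(True)   # run IR passes, including self_attention_fuse_pass
predictor = paddle_infer.create_predictor(config)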
paddle/fluid/operators/fused/CMakeLists.txt

@@ -11,6 +11,7 @@ register_operators(
    fusion_conv_inception_op
    fused_fc_elementwise_layernorm_op
    multihead_matmul_op
    self_dp_attention_op
    skip_layernorm_op
    yolo_box_head_op
    yolo_box_post_op
...
@@ -33,6 +34,14 @@ register_operators(
  # fusion_gru_op does not have CUDA kernel
  op_library(fusion_gru_op)
  op_library(fusion_lstm_op)
  if(WITH_AVX
     AND AVX512F_FOUND
     AND AVX512F_FLAG
     AND WITH_MKL)
    op_library(self_dp_attention_op)
    set_target_properties(self_dp_attention_op PROPERTIES COMPILE_FLAGS
                                               "-mfma ${AVX512F_FLAG}")
  endif()

if(WITH_XPU)
  op_library(resnet_basic_block_op)
...
paddle/fluid/operators/fused/scaled_dp_attention.h (new file, mode 100644)

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include <assert.h>
#include <immintrin.h>
#include <math.h>
#include <omp.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include <iostream>
#include <new>
#include <string>

#ifdef PADDLE_WITH_MKLDNN
#include "dnnl.hpp"  // NOLINT
#endif

namespace paddle {
namespace operators {

template <typename T, typename Tt>
void arraycpy(T* dst, const Tt* src, int n) {
#ifdef PADDLE_WITH_MKLML
#pragma omp simd
#endif
  for (int i = 0; i < n; i++) {
    dst[i] = static_cast<T>(src[i]);
  }
}
// batches x tokens x 3 x heads x head_size ->
// 3 x batches x heads x tokens x head_size (2 0 3 1 4)
template <typename T, typename Tt>
void transpose_before_bmm1(const T* qkvBuffer,
                           Tt* qkvTransBuffer,
                           int batchSize,
                           int tokenSize,
                           int headNum,
                           int headSize) {
  int hiddenSize = headNum * headSize;
  int blocksize = tokenSize * hiddenSize;  // dst buffer stride in each batch

  const T* qBuffer = qkvBuffer;
  const T* kBuffer = qkvBuffer + hiddenSize;
  const T* vBuffer = qkvBuffer + hiddenSize * 2;

  Tt* q_buffer = qkvTransBuffer;
  Tt* k_buffer = qkvTransBuffer + batchSize * blocksize;
  Tt* v_buffer = qkvTransBuffer + batchSize * blocksize * 2;

  int bmHead = headNum;
  int cols_per_bmHead = hiddenSize / headNum;  // 768/12 = 64

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
  for (int i = 0; i < batchSize; i++) {
    for (int k = 0; k < bmHead; k++) {
      for (int j = 0; j < tokenSize; j++) {
        const T* q_src_each_batch =
            reinterpret_cast<const T*>(qBuffer) + blocksize * 3 * i;
        const T* k_src_each_batch =
            reinterpret_cast<const T*>(kBuffer) + blocksize * 3 * i;
        const T* v_src_each_batch =
            reinterpret_cast<const T*>(vBuffer) + blocksize * 3 * i;

        int dst_offset_each_bmHead = k * tokenSize * cols_per_bmHead;
        int src_offset_each_line = k * cols_per_bmHead;
        int dst_offset_each_line = j * cols_per_bmHead;
        int src_offset_each_bmHead = j * hiddenSize * 3;

        Tt* q_dst_each_line = q_buffer + i * blocksize +
                              dst_offset_each_bmHead + dst_offset_each_line;
        const T* q_src_each_line =
            q_src_each_batch + src_offset_each_bmHead + src_offset_each_line;

        Tt* k_dst_each_line = k_buffer + i * blocksize +
                              dst_offset_each_bmHead + dst_offset_each_line;
        const T* k_src_each_line =
            k_src_each_batch + src_offset_each_bmHead + src_offset_each_line;

        Tt* v_dst_each_line = v_buffer + i * blocksize +
                              dst_offset_each_bmHead + dst_offset_each_line;
        const T* v_src_each_line =
            v_src_each_batch + src_offset_each_bmHead + src_offset_each_line;

        arraycpy<Tt, T>(q_dst_each_line, q_src_each_line, cols_per_bmHead);
        arraycpy<Tt, T>(k_dst_each_line, k_src_each_line, cols_per_bmHead);
        arraycpy<Tt, T>(v_dst_each_line, v_src_each_line, cols_per_bmHead);
      }
    }
  }
}
// batches x heads x tokens x head_size ->
// batches x tokens x heads x head_size (0 2 1 3)
template <typename T, typename Tt>
void transpose_after_bmm2(T* Buffer,
                          Tt* TransBuffer,
                          int batchSize,
                          int tokenSize,
                          int headNum,
                          int headSize) {
  int hiddenSize = headNum * headSize;
  int blocksize = tokenSize * hiddenSize;  // dst buffer stride in each batch

  int bmHead = headNum;
  int cols_per_bmHead = hiddenSize / headNum;  // 768/12 = 64

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(2)
#endif
  for (int i = 0; i < batchSize; i++) {
    for (int k = 0; k < tokenSize; k++) {
      int src_offset_each_head = k * cols_per_bmHead;
      int dst_offset_each_line = k * hiddenSize;

      for (int j = 0; j < bmHead; j++) {
        int src_offset_each_line = j * tokenSize * cols_per_bmHead;
        int dst_offset_each_head = j * cols_per_bmHead;

        Tt* q_dst_each_line = TransBuffer + dst_offset_each_head +
                              dst_offset_each_line + i * blocksize;
        const T* q_src_each_line = Buffer + src_offset_each_line +
                                   src_offset_each_head + i * blocksize;

        arraycpy<Tt, T>(q_dst_each_line, q_src_each_line, cols_per_bmHead);
      }
    }
  }
}
// C = A * B
// transa/transb: whether A/B need to be transposed
void sgemm(const float* A,
           const float* B,
           float* C,
           int m,
           int n,
           int k,
           bool transa,
           bool transb) {
#ifdef PADDLE_WITH_MKLDNN
  int lda = (transa ? m : k);
  int ldb = (transb ? k : n);
  int ldc = n;
  float alpha = 1;
  float beta = 0;
  char ta[] = "N";
  char tb[] = "N";
  if (transa) ta[0] = 'T';
  if (transb) tb[0] = 'T';

  dnnl_sgemm(ta[0], tb[0], m, n, k, alpha, A, lda, B, ldb, beta, C, ldc);
#else
  LOG(ERROR) << "scaled_dp_atten not supported without WITH_MKL!";
#endif
}
#if defined(__AVX512F__)
// exp based-on jit code
static inline __m512 vexp(const __m512& _x) {
  __m512 p16f_1 = _mm512_set1_ps(1.0f);
  __m512 p16f_half = _mm512_set1_ps(0.5f);
  __m512 p16f_127 = _mm512_set1_ps(127.f);
  __m512 p16f_exp_hi = _mm512_set1_ps(88.3762626647950f);
  __m512 p16f_exp_lo = _mm512_set1_ps(-88.3762626647949f);

  __m512 p16f_cephes_LOG2EF = _mm512_set1_ps(1.44269504088896341f);

  __m512 p16f_cephes_exp_p0 = _mm512_set1_ps(1.9875691500E-4f);
  __m512 p16f_cephes_exp_p1 = _mm512_set1_ps(1.3981999507E-3f);
  __m512 p16f_cephes_exp_p2 = _mm512_set1_ps(8.3334519073E-3f);
  __m512 p16f_cephes_exp_p3 = _mm512_set1_ps(4.1665795894E-2f);
  __m512 p16f_cephes_exp_p4 = _mm512_set1_ps(1.6666665459E-1f);
  __m512 p16f_cephes_exp_p5 = _mm512_set1_ps(5.0000001201E-1f);

  // Clamp x.
  __m512 x = _mm512_max_ps(_mm512_min_ps(_x, p16f_exp_hi), p16f_exp_lo);

  // Express exp(x) as exp(m*ln(2) + r), start by extracting
  // m = floor(x/ln(2) + 0.5).
  __m512 m = _mm512_floor_ps(_mm512_fmadd_ps(x, p16f_cephes_LOG2EF, p16f_half));

  // Get r = x - m*ln(2).
  __m512 p16f_nln2 = _mm512_set1_ps(-0.6931471805599453f);
  __m512 r = _mm512_fmadd_ps(m, p16f_nln2, x);
  __m512 r2 = _mm512_mul_ps(r, r);

  __m512 y = p16f_cephes_exp_p0;
  y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p1);
  y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p2);
  y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p3);
  y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p4);
  y = _mm512_fmadd_ps(y, r, p16f_cephes_exp_p5);
  y = _mm512_fmadd_ps(y, r2, r);
  y = _mm512_add_ps(y, p16f_1);

  // Build emm0 = 2^m.
  __m512i emm0 = _mm512_cvttps_epi32(_mm512_add_ps(m, p16f_127));
  emm0 = _mm512_slli_epi32(emm0, 23);

  // Return 2^m * exp(r).
  return _mm512_max_ps(_mm512_mul_ps(y, _mm512_castsi512_ps(emm0)), _x);
}
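vexp is a Cephes-style vectorized exponential: it clamps x to the float32-safe range, splits it as x = m*ln(2) + r, evaluates a degree-6 polynomial approximation of exp(r), and rebuilds 2^m by writing m+127 into the float exponent bits. A scalar Python sketch of the same scheme (same constants; it omits the final max-with-input step the SIMD code uses for edge cases):

import math

def vexp_scalar(x):
    # Clamp to the range where float32 exp stays finite.
    x = max(min(x, 88.3762626647950), -88.3762626647949)
    # Split x = m*ln(2) + r, with m = floor(x/ln(2) + 0.5).
    m = math.floor(x * 1.44269504088896341 + 0.5)
    r = x - m * 0.6931471805599453
    # Degree-6 polynomial for exp(r) on the reduced range (Horner form).
    p = [1.9875691500e-4, 1.3981999507e-3, 8.3334519073e-3,
         4.1665795894e-2, 1.6666665459e-1, 5.0000001201e-1]
    y = p[0]
    for c in p[1:]:
        y = y * r + c
    y = y * r * r + r + 1.0
    return y * 2.0 ** m  # the SIMD code builds 2^m directly in the exponent bits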
// need to do for res.
void softmax_sum_max(float* AB,
                     float* sum,
                     float* max,
                     float* pre_sum,
                     float* pre_max,
                     float refac,
                     int m,
                     int k) {
  assert(k % 16 == 0);
  float max_val = std::numeric_limits<float>::lowest();
  __m512 vrefac = _mm512_set1_ps(refac);
  for (int i = 0; i < m; ++i) {
    float* buf = AB + i * k;
    // max val for avoiding inf and nan
    __m512 vmax = _mm512_set1_ps(max_val);
    for (int off = 0; off < k; off += 16) {
      int remain = k - off;
      __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
      __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
      vmax = _mm512_mask_max_ps(vmax, mask, vmax, vx);
    }
    float _max = _mm512_reduce_max_ps(vmax);

    _max *= refac;
    _max = _max > max[i] ? _max : max[i];
    __m512 merr = _mm512_set1_ps(max[i] - _max);
    merr = vexp(merr);
    max[i] = _max;

    // exp and get sum
    __m512 vsum = _mm512_set1_ps(0);
    vmax = _mm512_set1_ps(_max);
    for (int off = 0; off < k; off += 16) {
      int remain = k - off;
      __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
      __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
      vx = vexp(vx * vrefac - vmax);
      _mm512_mask_storeu_ps(buf + off, mask, vx);
      vsum = _mm512_mask_add_ps(vsum, mask, vsum, vx);
    }
    float _sum = _mm512_reduce_add_ps(vsum);
    float fac = _mm512_cvtss_f32(merr);
    sum[i] = sum[i] * fac + _sum;
    _sum = sum[i];

    // Compute exp/sum(exp) and store
    __m512 vrsum = _mm512_set1_ps(1.0f / _sum);
    for (int off = 0; off < k; off += 16) {
      int remain = k - off;
      __mmask16 mask = (remain >= 16 ? 0xffff : (1 << remain) - 1);
      __m512 vx = _mm512_maskz_loadu_ps(mask, buf + off);
      vx = vx * vrsum;
      _mm512_mask_storeu_ps(buf + off, mask, vx);
    }
  }
}
void update_out_blk(float* output,
                    const float* exp_ABC,
                    float* pre_sum,
                    float* sum,
                    float* pre_max,
                    float* max,
                    int m,
                    int n) {
  assert(n % 16 == 0);
  for (int i = 0; i < m; ++i) {
    const float* buf = exp_ABC + i * n;
    float* outbuf = output + i * n;
    __m512 merr = _mm512_set1_ps(pre_max[i] - max[i]);
    merr = vexp(merr);
    __m512 vfac = _mm512_set1_ps(pre_sum[i] / sum[i]);
    for (int off = 0; off < n; off += 16) {
      __m512 vout = _mm512_loadu_ps(outbuf + off);
      __m512 vabc = _mm512_loadu_ps(buf + off);
      __m512 vupt = vout * merr * vfac + vabc;
      _mm512_storeu_ps(outbuf + off, vupt);
    }
    pre_sum[i] = sum[i];
    pre_max[i] = max[i];
  }
}
#endif
// hard code: axis = 1
// sum += sum(exp(A[i]))
// output = output * pre_sum / sum + (exp(A) / sum) x B
// pre_sum = sum
void incremental_tile_attention(const float* A,
                                const float* B,
                                const float* C,
                                int m,
                                int n,
                                int k,
                                float* pre_sum,
                                float* sum,
                                float* pre_max,
                                float* max,
                                float refac,
                                float* AB,
                                float* exp_ABC,
                                float* output) {
  sgemm(A, B, AB, m, k, n, false, true);
  softmax_sum_max(AB, sum, max, pre_sum, pre_max, refac, m, k);
  sgemm(AB, C, exp_ABC, m, n, k, false, false);
  update_out_blk(output, exp_ABC, pre_sum, sum, pre_max, max, m, n);
}
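These four calls form one step of a streaming ("online") softmax over key/value blocks, in the spirit of flash attention: each block updates a running row max and row sum, and the previously accumulated output is rescaled rather than recomputed. A NumPy sketch of the same update rule (illustrative; the state starts as pre_sum = 0, pre_max = -inf, out = 0, and the caller loops it over consecutive k/v blocks):

import numpy as np

def incremental_tile_attention_ref(q, k_blk, v_blk, pre_sum, pre_max, out, refac):
    scores = refac * (q @ k_blk.T)                     # sgemm(A, B^T)
    new_max = np.maximum(pre_max, scores.max(axis=1))  # updated running row max
    p = np.exp(scores - new_max[:, None])              # this block's softmax numerator
    corr = np.exp(pre_max - new_max)                   # correction for the old max
    new_sum = pre_sum * corr + p.sum(axis=1)           # updated running row sum
    # Rescale the old partial output and add this block's contribution.
    out = (out * (pre_sum * corr / new_sum)[:, None]
           + (p / new_sum[:, None]) @ v_blk)
    return new_sum, new_max, out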
// scaled dot-product attention: bmm1 + softmax + bmm2
void scaled_dp_attention(const float* query,
                         const float* key,
                         const float* value,
                         float scale,
                         int batch_size,
                         int itsize,
                         int otsize,
                         int num_head,
                         int head_size,
                         float* output) {
  // output = trans(softmax(query * trans(key)) * value)
  int iblk = std::min(512, itsize / 1);
  int oblk = std::min(512, otsize / 1);
  float refac = scale;
  assert(itsize % iblk == 0);
  assert(otsize % oblk == 0);

#ifdef PADDLE_WITH_MKLML
  int nth = omp_get_max_threads();
#else
  int nth = 1;
#endif
  float** pre_sum;
  float** sum;
  float** pre_max;
  float** max;
  float** qk_arr;
  float** exp_qkv_arr;
  pre_sum = new float*[nth];
  sum = new float*[nth];
  pre_max = new float*[nth];
  max = new float*[nth];
  qk_arr = new float*[nth];
  exp_qkv_arr = new float*[nth];
  for (int i = 0; i < nth; ++i) {
    pre_sum[i] = new float[iblk];
    sum[i] = new float[iblk];
    pre_max[i] = new float[iblk];
    max[i] = new float[iblk];
    qk_arr[i] = new float[iblk * oblk];
    exp_qkv_arr[i] = new float[iblk * head_size];
  }

#ifdef PADDLE_WITH_MKLML
#pragma omp parallel for collapse(3)
#endif
  for (int i = 0; i < batch_size; ++i) {
    for (int j = 0; j < num_head; ++j) {
      for (int m = 0; m < itsize; m += iblk) {
#ifdef PADDLE_WITH_MKLML
        int tid = omp_get_thread_num();
#else
        int tid = 0;
#endif
        int ooffset =
            i * num_head * otsize * head_size + j * otsize * head_size;
        const float* k = key + ooffset;
        const float* v = value + ooffset;

        int q_rblk = std::min(iblk, itsize - m);
        int ioffset =
            i * num_head * otsize * head_size + j * otsize * head_size;
        const float* q = query + ioffset + m * head_size;
        float* out = output + ioffset + m * head_size;

        // reset out
        for (int ii = 0; ii < q_rblk; ++ii) {
#ifdef PADDLE_WITH_MKLML
#pragma omp simd
#endif
          for (int jj = 0; jj < head_size; ++jj) {
            out[ii * head_size + jj] = 0;  // reset output
          }
        }
        // reset sum
#ifdef PADDLE_WITH_MKLML
#pragma omp simd
#endif
        for (int ii = 0; ii < q_rblk; ++ii) {
          pre_sum[tid][ii] = 0;
          sum[tid][ii] = 0;
          pre_max[tid][ii] = std::numeric_limits<float>::lowest();
          max[tid][ii] = std::numeric_limits<float>::lowest();
        }

        for (int b = 0; b < otsize; b += oblk) {
          int kv_rblk = std::min(oblk, otsize - b);
          const float* blk_k = k + b * head_size;
          const float* blk_v = v + b * head_size;

          incremental_tile_attention(q,
                                     blk_k,
                                     blk_v,
                                     q_rblk,
                                     head_size,
                                     kv_rblk,
                                     pre_sum[tid],
                                     sum[tid],
                                     pre_max[tid],
                                     max[tid],
                                     refac,
                                     qk_arr[tid],
                                     exp_qkv_arr[tid],
                                     out);
        }
      }
    }
  }

  for (int i = 0; i < nth; ++i) {
    delete[] pre_sum[i];
    delete[] sum[i];
    delete[] pre_max[i];
    delete[] max[i];
    delete[] qk_arr[i];
    delete[] exp_qkv_arr[i];
  }
  delete[] pre_sum;
  delete[] sum;
  delete[] pre_max;
  delete[] max;
  delete[] qk_arr;
  delete[] exp_qkv_arr;
  return;
}

}  // namespace operators
}  // namespace paddle
paddle/fluid/operators/fused/self_dp_attention_op.cc (new file, mode 100644)

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#include "paddle/fluid/operators/fused/self_dp_attention_op.h"
#include "paddle/fluid/operators/fused/scaled_dp_attention.h"

namespace paddle {
namespace operators {

void SelfDPAttenOp::InferShape(framework::InferShapeContext* ctx) const {
  OP_INOUT_CHECK(ctx->HasInput("X"), "Input", "X", "SelfDPAtten");
  OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "SelfDPAtten");

  auto dim_input = ctx->GetInputDim("X");
  PADDLE_ENFORCE_EQ(dim_input.size(),
                    5,
                    platform::errors::InvalidArgument(
                        "The size of input X dims should be 5, "
                        "[batchsize, tokensize, 3, nhead, headsize] "
                        ", but now Input X dim is:[%s] ",
                        dim_input));
  PADDLE_ENFORCE_EQ(dim_input[4] % 16,
                    0,
                    platform::errors::InvalidArgument(
                        "The last dim of input X should be a multiple of 16, "
                        ", but now the dim is:[%d] "
                        "Please remove self_attention_fuse_pass from the lists",
                        dim_input[4]));
  framework::DDim out_dims(
      {dim_input[0], dim_input[1], dim_input[3], dim_input[4]});
  ctx->SetOutputDim("Out", out_dims);
  ctx->ShareLoD("X", /*->*/ "Out");
}

phi::KernelKey SelfDPAttenOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return phi::KernelKey(OperatorWithKernel::IndicateVarDataType(ctx, "X"),
                        ctx.GetPlace());
}

void SelfDPAttenOpMaker::Make() {
  AddInput("X", "(LoDTensor) Input tensors of this operator.");
  AddOutput("Out", "(LoDTensor) Output tensor of this operator.");
  AddAttr<float>("alpha", "The scale of Out").SetDefault(1.0f);
  AddAttr<int>("head_number", "The number of heads of the matrix")
      .SetDefault(1);
  AddComment(R"DOC(
    Multihead Self-scaled-dp-Attention Operator.
  )DOC");
}

template <typename T>
class SelfDPAttenKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    using DeviceContext = phi::CPUContext;
    auto* in = ctx.Input<Tensor>("X");
    auto* out = ctx.Output<LoDTensor>("Out");
    auto place = ctx.GetPlace();
    auto* input_d = in->data<T>();
    auto* output_d = out->mutable_data<T>(place);
    float scale = static_cast<float>(ctx.Attr<float>("alpha"));
    int head_number = ctx.Attr<int>("head_number");
    auto input_dims = in->dims();
    // in should be (batch * seq * 3 * head_num * head_size)
    // out should be (batch * seq * head_num * head_size)
    int batch_size = input_dims[0];
    int seq_len = input_dims[1];
    int head_size = input_dims[4];

    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    phi::DenseTensor temp1 =
        ctx.AllocateTmpTensor<T, DeviceContext>(input_dims, dev_ctx);
    float* trans_input = temp1.mutable_data<float>(place);
    phi::DenseTensor temp2 =
        ctx.AllocateTmpTensor<T, DeviceContext>(input_dims, dev_ctx);
    float* trans_output = temp2.mutable_data<float>(place);

    transpose_before_bmm1<T, float>(
        input_d, trans_input, batch_size, seq_len, head_number, head_size);
    float* query = trans_input;
    float* key = trans_input + batch_size * head_number * seq_len * head_size;
    float* value =
        trans_input + batch_size * head_number * seq_len * head_size * 2;

    scaled_dp_attention(query,
                        key,
                        value,
                        scale,
                        batch_size,
                        seq_len,
                        seq_len,
                        head_number,
                        head_size,
                        trans_output);
    transpose_after_bmm2<float, T>(
        trans_output, output_d, batch_size, seq_len, head_number, head_size);
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(self_dp_attention,
                  ops::SelfDPAttenOp,
                  ops::SelfDPAttenOpMaker);
REGISTER_OP_KERNEL(self_dp_attention,
                   CPU,
                   phi::CPUPlace,
                   ops::SelfDPAttenKernel<float>,
                   ops::SelfDPAttenKernel<double>);
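For a concrete instance of the shape contract enforced by InferShape above, the dimensions used by the unit test below work out as follows (a worked example, not extra validation logic):

# Input X: [batchsize, tokensize, 3, nhead, headsize]
batch_size, token_size, num_heads, head_size = 1, 4097, 12, 64
x_dims = (batch_size, token_size, 3, num_heads, head_size)
assert len(x_dims) == 5 and x_dims[2] == 3
assert head_size % 16 == 0  # last dim must be a multiple of 16
out_dims = (batch_size, token_size, num_heads, head_size)  # [1, 4097, 12, 64]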
paddle/fluid/operators/fused/self_dp_attention_op.h (new file, mode 100644)

/* Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License. */

#pragma once

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using LoDTensor = phi::DenseTensor;
using Tensor = phi::DenseTensor;

class SelfDPAttenOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;
  void InferShape(framework::InferShapeContext* ctx) const override;

 protected:
  phi::KernelKey GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override;
};

class SelfDPAttenOpMaker : public framework::OpProtoAndCheckerMaker {
 public:
  void Make() override;
};

}  // namespace operators
}  // namespace paddle
test/mkldnn/test_fused_vit_attention.py (new file, mode 100644)

# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import paddle
import paddle.incubate
from paddle.fluid import core

paddle.enable_static()
np.random.seed(0)


def test_fuse_resenet_unit():
    place = paddle.CPUPlace()
    program = paddle.static.Program()
    startup_program = paddle.static.Program()
    batch_size = 1
    token_size = 4097
    hidden_size = 768
    num_heads = 12
    dtype = np.float32
    with paddle.static.program_guard(program, startup_program):
        x = paddle.static.data(
            "x", [batch_size, token_size, hidden_size * 3], dtype=dtype
        )
        qkv = x.reshape(
            (batch_size, token_size, 3, num_heads, hidden_size // num_heads)
        ).transpose((2, 0, 3, 1, 4))
        q, k, v = qkv[0], qkv[1], qkv[2]
        attn = q.matmul(k.transpose((0, 1, 3, 2)))
        attn = paddle.nn.functional.softmax(attn, axis=-1)
        out = (
            (attn.matmul(v))
            .transpose((0, 2, 1, 3))
            .reshape((-1, token_size, hidden_size))
        )
    graph = core.Graph(program.desc)
    core.get_pass("self_attention_fuse_pass").apply(graph)
    after_program = paddle.fluid.framework.IrGraph(graph).to_program()
    exe = paddle.static.Executor(place)
    exe.run(startup_program)
    feed = {
        "x": np.random.randn(batch_size, token_size, hidden_size * 3).astype(
            dtype
        )
    }
    before_out = exe.run(program, feed=feed, fetch_list=[out.name])
    after_out = exe.run(after_program, feed=feed, fetch_list=[out.name])
    np.testing.assert_allclose(
        before_out[0], after_out[0], rtol=1e-05, atol=0.005
    )


if __name__ == '__main__':
    test_fuse_resenet_unit()
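Assuming a build configured with WITH_MKL and AVX-512 support (the conditions the CMake guards above encode), the test runs standalone with `python test/mkldnn/test_fused_vit_attention.py`: it builds the static-graph attention block, applies self_attention_fuse_pass, and checks that the fused and unfused programs agree within rtol=1e-05 and atol=0.005.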