PaddlePaddle / Paddle, commit 0f59d4e6 (unverified)
Authored by 王明冬 on Jun 28, 2021; committed via GitHub on Jun 28, 2021.
add compat precondition for multihead_matmul_fuse_pass_v2,v3, test=develop (#33786)
Parent: 7f9b8f06
Showing 6 changed files with 498 additions and 245 deletions (+498, -245).
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc          +458  -211
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h            +10   -12
paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc    +19   -19
paddle/fluid/framework/ir/pass_tester_helper.h                     +5    -1
paddle/fluid/operators/compat/matmul.pbtxt                         +4    -0
paddle/fluid/operators/compat/softmax.pbtxt                        +2    -2
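In short, the patch gives the v2/v3 passes an explicit precondition: each pass constructor now declares, per fused op type, which inputs, outputs and attribute values the fusion can handle (AddOpCompat/OpCompat), and the subgraph handler refuses to fuse any matched subgraph that fails that check (IsCompat). The following is a condensed sketch of that structure, assembled from the diff below for orientation; it is not the verbatim patch, and elided parts are marked with ellipses.

// Condensed sketch (see the full diff below for the complete constraint set).
MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
  // One OpCompat entry per op that the fused kernel will replace.
  AddOpCompat(OpCompat("softmax"))
      .AddInput("X").IsTensor().End()
      .AddOutput("Out").IsTensor().End()
      .AddAttr("axis").IsIntIn({-1, 3}).End();  // input is (B, H, S, S), so only the last axis
  // ... analogous entries for mul, elementwise_add, reshape2, transpose2,
  //     scale and matmul ...
}

int MultiHeadMatmulV2FusePass::BuildFusionV2(Graph* graph, const std::string& name_scope,
                                             Scope* scope) const {
  GraphPatternDetector gpd;
  // ... pattern construction ...
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph, Graph* g) {
    if (!IsCompat(subgraph, g)) {  // the new precondition
      LOG(WARNING) << "Op compat check in multihead_matmul_fuse_pass_v2 failed.";
      return;                      // leave this subgraph unfused
    }
    // ... create the fused multihead_matmul op as before ...
  };
  // ...
}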
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.cc

@@ -422,13 +422,335 @@ PDNode* MultiHeadMatmulPattern::operator()() {
   return transpose2_2_out_var;
 }
 
-static int BuildFusionV2(Graph* graph, const std::string& name_scope,
-                         Scope* scope) {
+PDNode* MultiHeadMatmulV3Pattern::operator()() {
+  std::unordered_set<std::string> matmul_ops{"matmul", "matmul_v2"};
+  auto* input0 = pattern->NewNode(input0_repr());
+  input0->assert_is_op_input("matmul");
+
+  // First path with scale
+  auto* mul0 = pattern->NewNode(mul0_repr())->assert_is_op("matmul");
+  auto* mul0_w_var = pattern->NewNode(mul0_w_repr())->AsInput()->assert_is_op_input("matmul", "Y");
+  auto* mul0_out_var = pattern->NewNode(mul0_out_repr())->assert_is_op_output("matmul");
+
+  decltype(mul0) eltadd0;
+  decltype(mul0) eltadd0_b_var;
+  decltype(mul0) eltadd0_out_var;
+  mul0_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  eltadd0 = pattern->NewNode(eltadd0_repr())->assert_is_op("elementwise_add");
+  eltadd0_b_var = pattern->NewNode(eltadd0_b_repr())->AsInput()->assert_is_op_input("elementwise_add", "Y");
+  eltadd0_out_var = pattern->NewNode(eltadd0_out_repr())->assert_is_op_output("elementwise_add");
+  eltadd0_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_0 = pattern->NewNode(reshape2_0_repr())->assert_is_op("reshape2");
+  auto* reshape2_0_out_var = pattern->NewNode(reshape2_0_out_repr())->assert_is_op_output("reshape2");
+  reshape2_0_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_0 = pattern->NewNode(transpose2_0_repr())->assert_is_op("transpose2");
+  auto* transpose2_0_out_var = pattern->NewNode(transpose2_0_out_repr())->assert_is_op_output("transpose2");
+  transpose2_0_out_var->AsIntermediate()->assert_is_op_input("matmul", "X");
+
+  auto* matmul_qk = pattern->NewNode(matmul_qk_repr())->assert_is_op("matmul");
+  auto* matmul_qk_out_var = pattern->NewNode(matmul_qk_out_repr())->assert_is_op_output("matmul");
+  matmul_qk_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  auto* eltadd_qk = pattern->NewNode(eltadd_qk_repr())->assert_is_op("elementwise_add");
+  auto* eltadd_qk_b_var = pattern->NewNode(eltadd_qk_b_repr())->AsInput()->assert_is_op_input("elementwise_add", "Y");
+  auto* eltadd_qk_out_var = pattern->NewNode(eltadd_qk_out_repr())->assert_is_op_output("elementwise_add");
+  eltadd_qk_out_var->AsIntermediate()->assert_is_op_input("softmax");
+
+  auto* softmax_qk = pattern->NewNode(softmax_qk_repr())->assert_is_op("softmax");
+  auto* softmax_qk_out_var = pattern->NewNode(softmax_qk_out_repr())->assert_is_op_output("softmax");
+  softmax_qk_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);
+
+  auto* matmul_qkv = pattern->NewNode(matmul_qkv_repr())->assert_is_ops(matmul_ops);
+  auto* matmul_qkv_out_var = pattern->NewNode(matmul_qkv_out_repr())->assert_is_ops_output(matmul_ops);
+  matmul_qkv_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_qkv = pattern->NewNode(transpose2_qkv_repr())->assert_is_op("transpose2");
+  auto* transpose2_qkv_out_var = pattern->NewNode(transpose2_qkv_out_repr())->assert_is_op_output("transpose2");
+  transpose2_qkv_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_qkv = pattern->NewNode(reshape2_qkv_repr())->assert_is_op("reshape2");
+  auto* reshape2_qkv_out_var = pattern->NewNode(reshape2_qkv_out_repr())->assert_is_op_output("reshape2");
+  reshape2_qkv_out_var->assert_is_op_input("matmul");
+
+  // Second path to matmul
+  auto* mul1 = pattern->NewNode(mul1_repr())->assert_is_op("matmul");
+  auto* mul1_w_var = pattern->NewNode(mul1_w_repr())->AsInput()->assert_is_op_input("matmul", "Y");
+  auto* mul1_out_var = pattern->NewNode(mul1_out_repr())->assert_is_op_output("matmul");
+
+  decltype(mul1) eltadd1;
+  decltype(mul1) eltadd1_b_var;
+  decltype(mul1) eltadd1_out_var;
+  mul1_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  eltadd1 = pattern->NewNode(eltadd1_repr())->assert_is_op("elementwise_add");
+  eltadd1_b_var = pattern->NewNode(eltadd1_b_repr())->AsInput()->assert_is_op_input("elementwise_add", "Y");
+  eltadd1_out_var = pattern->NewNode(eltadd1_out_repr())->assert_is_op_output("elementwise_add");
+  eltadd1_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_1 = pattern->NewNode(reshape2_1_repr())->assert_is_op("reshape2");
+  auto* reshape2_1_out_var = pattern->NewNode(reshape2_1_out_repr())->assert_is_op_output("reshape2");
+  reshape2_1_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_1 = pattern->NewNode(transpose2_1_repr())->assert_is_op("transpose2");
+  auto* transpose2_1_out_var = pattern->NewNode(transpose2_1_out_repr())->assert_is_op_output("transpose2");
+  transpose2_1_out_var->AsIntermediate()->assert_is_op_input("matmul", "Y");  // link to matmul qk
+
+  // Third path to matmul
+  auto* mul2 = pattern->NewNode(mul2_repr())->assert_is_op("matmul");
+  auto* mul2_w_var = pattern->NewNode(mul2_w_repr())->AsInput()->assert_is_op_input("matmul", "Y");
+  auto* mul2_out_var = pattern->NewNode(mul2_out_repr())->assert_is_op_output("matmul");
+
+  decltype(mul2) eltadd2;
+  decltype(mul2) eltadd2_b_var;
+  decltype(mul2) eltadd2_out_var;
+  mul2_out_var->AsIntermediate()->assert_is_op_input("elementwise_add");
+
+  eltadd2 = pattern->NewNode(eltadd2_repr())->assert_is_op("elementwise_add");
+  eltadd2_b_var = pattern->NewNode(eltadd2_b_repr())->AsInput()->assert_is_op_input("elementwise_add", "Y");
+  eltadd2_out_var = pattern->NewNode(eltadd2_out_repr())->assert_is_op_output("elementwise_add");
+  eltadd2_out_var->AsIntermediate()->assert_is_op_input("reshape2");
+
+  auto* reshape2_2 = pattern->NewNode(reshape2_2_repr())->assert_is_op("reshape2");
+  auto* reshape2_2_out_var = pattern->NewNode(reshape2_2_out_repr())->assert_is_op_output("reshape2");
+  reshape2_2_out_var->AsIntermediate()->assert_is_op_input("transpose2");
+
+  auto* transpose2_2 = pattern->NewNode(transpose2_2_repr())->assert_is_op("transpose2");
+  auto* transpose2_2_out_var = pattern->NewNode(transpose2_2_out_repr())->assert_is_op_output("transpose2");
+  transpose2_2_out_var->AsIntermediate()->assert_is_ops_input(matmul_ops);  // link to matmul qkv
+
+  // Q path
+  mul0->LinksFrom({input0, mul0_w_var}).LinksTo({mul0_out_var});
+  eltadd0->LinksFrom({mul0_out_var, eltadd0_b_var}).LinksTo({eltadd0_out_var});
+  reshape2_0->LinksFrom({eltadd0_out_var}).LinksTo({reshape2_0_out_var});
+  transpose2_0->LinksFrom({reshape2_0_out_var}).LinksTo({transpose2_0_out_var});
+  // K path
+  mul1->LinksFrom({input0, mul1_w_var}).LinksTo({mul1_out_var});
+  eltadd1->LinksFrom({mul1_out_var, eltadd1_b_var}).LinksTo({eltadd1_out_var});
+  reshape2_1->LinksFrom({eltadd1_out_var}).LinksTo({reshape2_1_out_var});
+  transpose2_1->LinksFrom({reshape2_1_out_var}).LinksTo({transpose2_1_out_var});
+  // compute q*k
+  matmul_qk->LinksFrom({transpose2_0_out_var, transpose2_1_out_var}).LinksTo({matmul_qk_out_var});
+  eltadd_qk->LinksFrom({matmul_qk_out_var, eltadd_qk_b_var}).LinksTo({eltadd_qk_out_var});
+  softmax_qk->LinksFrom({eltadd_qk_out_var}).LinksTo({softmax_qk_out_var});
+  // V path
+  mul2->LinksFrom({input0, mul2_w_var}).LinksTo({mul2_out_var});
+  eltadd2->LinksFrom({mul2_out_var, eltadd2_b_var}).LinksTo({eltadd2_out_var});
+  reshape2_2->LinksFrom({eltadd2_out_var}).LinksTo({reshape2_2_out_var});
+  transpose2_2->LinksFrom({reshape2_2_out_var}).LinksTo({transpose2_2_out_var});
+  // compute q*k*v
+  matmul_qkv->LinksFrom({softmax_qk_out_var, transpose2_2_out_var}).LinksTo({matmul_qkv_out_var});
+  transpose2_qkv->LinksFrom({matmul_qkv_out_var}).LinksTo({transpose2_qkv_out_var});
+  reshape2_qkv->LinksFrom({transpose2_qkv_out_var}).LinksTo({reshape2_qkv_out_var});
+
+  return transpose2_2_out_var;
+}
+}  // namespace patterns
+
+void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
+  FusePassBase::Init(name_scope_, graph);
+  int fusion_count = patterns::BuildFusion(graph, name_scope_);
+  AddStatis(fusion_count);
+}
+
+MultiHeadMatmulV2FusePass::MultiHeadMatmulV2FusePass() {
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X").IsTensor().End()      // the shape should be (B, S, N*H)
+      .AddInput("Y").IsTensor().End()      // the shape should be (N*H, N*H)
+      .AddOutput("Out").IsTensor().End()   // the shape should be (B, S, N*H)
+      .AddAttr("x_num_col_dims").IsNumEQ(2).End()
+      .AddAttr("y_num_col_dims").IsNumEQ(1).End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X").IsTensor().End()      // in bias, shape is (B, S, N*H); in biasqk, shape is (B, H, S, S)
+      .AddInput("Y").IsTensor().End()      // in bias, shape is (N*H); in biasqk, shape is (B, H, S, S)
+      .AddOutput("Out").IsTensor().End()   // in bias, shape is (B, S, N*H); in biasqk, shape is (B, H, S, S)
+      .AddAttr("axis").IsIntIn({2, -1, 0}).End();  // in bias it equals 2; in biasqk it equals -1 or 0
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X").IsTensor().End()
+      .AddInput("Shape").IsTensor().IsOptional().End()
+      .AddInput("ShapeTensor").IsTensor().IsOptional().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddOutput("XShape").IsTensor().End()
+      .AddAttr("shape").IsType<std::vector<int>>().End();  // -->(B, S, H, N)  <--(B, S, N*H)
+
+  // -->: (B, S, H, N) -> (B, H, S, N)
+  // <--: (B, H, S, N) -> (B, S, H, N)
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddOutput("XShape").IsTensor().End()
+      .AddAttr("axis").IsType<std::vector<int>>().End();  // {0, 2, 1, 3}
+
+  AddOpCompat(OpCompat("scale"))
+      .AddInput("X").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddAttr("scale").IsType<float>().End()          // copied to the new op, so unconstrained
+      .AddAttr("bias").IsNumEQ(0.f).End()
+      .AddAttr("bias_after_scale").IsType<bool>().End();  // bias is 0, so unconstrained
+
+  // QK  (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
+  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X").IsTensor().End()
+      .AddInput("Y").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddAttr("alpha").IsNumEQ(1.0f).End()
+      .AddAttr("transpose_X").IsBoolEQ(false).End()
+      .AddAttr("transpose_Y").IsType<bool>().End();    // QK(true) QKV(false)
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddAttr("axis").IsIntIn({-1, 3}).End();  // shape is (B, H, S, S), so axis is -1 or 3
+}
+
+int MultiHeadMatmulV2FusePass::BuildFusionV2(
+    Graph* graph, const std::string& name_scope, Scope* scope) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
   // Create pattern.
-  MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
+  patterns::MultiHeadMatmulPattern multihead_pattern(pattern, name_scope);
   multihead_pattern();
   // Create New OpDesc
...
@@ -580,6 +902,11 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   int fusion_count{0};
   auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                      Graph* g) {
+    if (!IsCompat(subgraph, g)) {
+      LOG(WARNING) << "Op compat check in multihead_matmul_fuse_pass_v2 failed.";
+      return;
+    }
     // GET_IR_NODE_FROM_SUBGRAPH(dropout_out, dropout_out, multihead_pattern);
     GET_IR_NODE_FROM_SUBGRAPH(input0, input0, multihead_pattern);
...
@@ -714,197 +1041,141 @@ static int BuildFusionV2(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
 
-PDNode* MultiHeadMatmulV3Pattern::operator()() {
+void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
-  … (the rest of the old MultiHeadMatmulV3Pattern::operator()() body is removed from this location; it is identical to the definition shown at its new place above) …
+  FusePassBase::Init(name_scope_, graph);
+  auto* scope = param_scope();
+  PADDLE_ENFORCE_NOT_NULL(
+      scope, platform::errors::Fatal(
+                 "During the multiheadMatmul pass, The scope should not be null."));
+
+  int fusion_count = BuildFusionV2(graph, name_scope_, scope);
+  if (fusion_count > 0) {
+    graph->Set(kMultiheadMatmulPass, new bool(true));
+  }
+  AddStatis(fusion_count);
+}
+
+MultiHeadMatmulV3FusePass::MultiHeadMatmulV3FusePass() {
+  AddOpCompat(OpCompat("mul"))
+      .AddInput("X").IsTensor().End()      // the shape should be (B, S, N*H)
+      .AddInput("Y").IsTensor().End()      // the shape should be (N*H, N*H)
+      .AddOutput("Out").IsTensor().End()   // the shape should be (B, S, N*H)
+      .AddAttr("x_num_col_dims").IsNumEQ(2).End()
+      .AddAttr("y_num_col_dims").IsNumEQ(1).End();
+
+  AddOpCompat(OpCompat("elementwise_add"))
+      .AddInput("X").IsTensor().End()      // in bias, shape is (B, S, N*H); in biasqk, shape is (B, H, S, S)
+      .AddInput("Y").IsTensor().End()      // in bias, shape is (N*H); in biasqk, shape is (B, H, S, S)
+      .AddOutput("Out").IsTensor().End()   // in bias, shape is (B, S, N*H); in biasqk, shape is (B, H, S, S)
+      .AddAttr("axis").IsIntIn({2, -1, 0}).End();  // in bias it equals 2; in biasqk it equals -1 or 0
+
+  AddOpCompat(OpCompat("reshape2"))
+      .AddInput("X").IsTensor().End()
+      .AddInput("Shape").IsTensor().IsOptional().End()
+      .AddInput("ShapeTensor").IsTensor().IsOptional().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddOutput("XShape").IsTensor().End()
+      .AddAttr("shape").IsType<std::vector<int>>().End();  // -->(B, S, H, N)  <--(B, S, N*H)
+
+  // -->: (B, S, H, N) -> (B, H, S, N)
+  // <--: (B, H, S, N) -> (B, S, H, N)
+  AddOpCompat(OpCompat("transpose2"))
+      .AddInput("X").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddOutput("XShape").IsTensor().End()
+      .AddAttr("axis").IsType<std::vector<int>>().End();  // {0, 2, 1, 3}
+
+  // QK  (B, H, S, N)*(B, H, S, N) -> (B, H, S, S)
+  // QKV (B, H, S, S)*(B, H, S, N) -> (B, H, S, N)
+  AddOpCompat(OpCompat("matmul"))
+      .AddInput("X").IsTensor().End()
+      .AddInput("Y").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddAttr("alpha").IsType<float>().End()          // QK(any value, will be copied to the new op), QKV(1.0)
+      .AddAttr("transpose_X").IsBoolEQ(false).End()
+      .AddAttr("transpose_Y").IsType<bool>().End()     // QK(true) QKV(false)
+
+  AddOpCompat(OpCompat("softmax"))
+      .AddInput("X").IsTensor().End()
+      .AddOutput("Out").IsTensor().End()
+      .AddAttr("axis").IsIntIn({-1, 3}).End();  // shape is (B, H, S, S), so axis is -1 or 3
+}
-}
-
-static int BuildFusionV3(Graph* graph, const std::string& name_scope,
-                         Scope* scope) {
+int MultiHeadMatmulV3FusePass::BuildFusionV3(Graph* graph,
+                                             const std::string& name_scope,
+                                             Scope* scope) const {
   GraphPatternDetector gpd;
   auto* pattern = gpd.mutable_pattern();
 
   // Create pattern.
-  MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
+  patterns::MultiHeadMatmulV3Pattern multihead_pattern(pattern, name_scope);
   multihead_pattern();
   // Create New OpDesc
...
@@ -1155,30 +1426,6 @@ static int BuildFusionV3(Graph* graph, const std::string& name_scope,
   return fusion_count;
 }
 
-}  // namespace patterns
-
-void MultiHeadMatmulFusePass::ApplyImpl(Graph* graph) const {
-  FusePassBase::Init(name_scope_, graph);
-  int fusion_count = patterns::BuildFusion(graph, name_scope_);
-  AddStatis(fusion_count);
-}
-
-void MultiHeadMatmulV2FusePass::ApplyImpl(Graph* graph) const {
-  FusePassBase::Init(name_scope_, graph);
-  auto* scope = param_scope();
-  PADDLE_ENFORCE_NOT_NULL(
-      scope, platform::errors::Fatal(
-                 "During the multiheadMatmul pass, The scope should not be null."));
-  int fusion_count = patterns::BuildFusionV2(graph, name_scope_, scope);
-  if (fusion_count > 0) {
-    graph->Set(kMultiheadMatmulPass, new bool(true));
-  }
-  AddStatis(fusion_count);
-}
 void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
   FusePassBase::Init(name_scope_, graph);
   auto* scope = param_scope();
...
@@ -1187,7 +1434,7 @@ void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
...
@@ -1187,7 +1434,7 @@ void MultiHeadMatmulV3FusePass::ApplyImpl(Graph* graph) const {
platform
::
errors
::
Fatal
(
platform
::
errors
::
Fatal
(
"During the multiheadMatmul pass, The scope should not be null."
));
"During the multiheadMatmul pass, The scope should not be null."
));
int
fusion_count
=
patterns
::
BuildFusionV3
(
graph
,
name_scope_
,
scope
);
int
fusion_count
=
BuildFusionV3
(
graph
,
name_scope_
,
scope
);
if
(
fusion_count
>
0
)
{
if
(
fusion_count
>
0
)
{
graph
->
Set
(
kMultiheadMatmulPass
,
new
bool
(
true
));
graph
->
Set
(
kMultiheadMatmulPass
,
new
bool
(
true
));
}
}
...
...
paddle/fluid/framework/ir/multihead_matmul_fuse_pass.h

@@ -18,16 +18,6 @@
 #include <string>
 #include "paddle/fluid/framework/ir/fuse_pass_base.h"
-#include "paddle/fluid/framework/ir/graph.h"
+#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
-
-namespace paddle {
-namespace framework {
-namespace ir {
-class Graph;
-}  // namespace ir
-}  // namespace framework
-}  // namespace paddle
 
 namespace paddle {
 namespace framework {
...
@@ -158,22 +148,30 @@ class MultiHeadMatmulFusePass : public FusePassBase {
 class MultiHeadMatmulV2FusePass : public FusePassBase {
  public:
-  virtual ~MultiHeadMatmulV2FusePass() {}
+  MultiHeadMatmulV2FusePass();
 
  protected:
   void ApplyImpl(Graph* graph) const;
   const std::string name_scope_{"multihead_matmul_fuse_v2"};
+
+ private:
+  int BuildFusionV2(Graph* graph, const std::string& name_scope,
+                    Scope* scope) const;
 };
 
 class MultiHeadMatmulV3FusePass : public FusePassBase {
  public:
-  virtual ~MultiHeadMatmulV3FusePass() {}
+  MultiHeadMatmulV3FusePass();
 
  protected:
   void ApplyImpl(Graph* graph) const;
   const std::string name_scope_{"multihead_matmul_fuse_v3"};
+
+ private:
+  int BuildFusionV3(Graph* graph, const std::string& name_scope,
+                    Scope* scope) const;
 };
 
 }  // namespace ir
...
paddle/fluid/framework/ir/multihead_matmul_fuse_pass_tester.cc

@@ -64,7 +64,7 @@ TEST(MultiHeadMatmulFusePass, basic) {
   //   (transpose_qkv) reshape -> reshape_qkv
   //   (reshape_qkv)   mul     -> mul_qkv
   Layers layers;
-  auto* x = layers.data("x", {128, 768});
+  auto* x = layers.data("x", {1, 128, 768});
   auto out = layers.layer_norm(x);
   auto* layer_out = out[0];
...
@@ -72,41 +72,41 @@ TEST(MultiHeadMatmulFusePass, basic) {
   auto* weights_1 = layers.data("weights1", {768, 768}, true);
   auto* weights_2 = layers.data("weights2", {768, 768}, true);
 
-  auto* mul_out_0 = layers.mul(layer_out, weights_0);
-  auto* mul_out_1 = layers.mul(layer_out, weights_1);
-  auto* mul_out_2 = layers.mul(layer_out, weights_2);
+  auto* mul_out_0 = layers.mul(layer_out, weights_0, nullptr, 2);
+  auto* mul_out_1 = layers.mul(layer_out, weights_1, nullptr, 2);
+  auto* mul_out_2 = layers.mul(layer_out, weights_2, nullptr, 2);
 
   auto* b0 = layers.data("bias_0", {768}, true);
   auto* b1 = layers.data("bias_1", {768}, true);
   auto* b2 = layers.data("bias_2", {768}, true);
 
-  auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0);
-  auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1);
-  auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2);
+  auto* elementwise_out_0 = layers.elementwise_add(mul_out_0, b0, nullptr, 2);
+  auto* elementwise_out_1 = layers.elementwise_add(mul_out_1, b1, nullptr, 2);
+  auto* elementwise_out_2 = layers.elementwise_add(mul_out_2, b2, nullptr, 2);
 
-  std::vector<int> shape = {128, 12, 64};
-  auto* reshape_0 = layers.reshape2(elementwise_out_0, shape);
-  auto* reshape_1 = layers.reshape2(elementwise_out_1, shape);
-  auto* reshape_2 = layers.reshape2(elementwise_out_2, shape);
+  std::vector<int> shape = {1, 128, 12, 64};
+  auto* reshape_0 = layers.reshape2(elementwise_out_0, shape, true);
+  auto* reshape_1 = layers.reshape2(elementwise_out_1, shape, true);
+  auto* reshape_2 = layers.reshape2(elementwise_out_2, shape, true);
 
   std::vector<int> axis = {0, 2, 1, 3};
-  auto* transpose_0 = layers.transpose2(reshape_0, axis);
-  auto* transpose_1 = layers.transpose2(reshape_1, axis);
-  auto* transpose_2 = layers.transpose2(reshape_2, axis);
+  auto* transpose_0 = layers.transpose2(reshape_0, axis, true);
+  auto* transpose_1 = layers.transpose2(reshape_1, axis, true);
+  auto* transpose_2 = layers.transpose2(reshape_2, axis, true);
 
   auto* scale_0 = layers.scale(transpose_0, 0.125, 0, false);
-  auto* matmul_qk = layers.matmul(scale_0, transpose_1);
+  auto* matmul_qk = layers.matmul(scale_0, transpose_1, nullptr, false, true);
 
-  auto* bqk = layers.data("biasqk", {768}, true);
+  auto* bqk = layers.data("biasqk", {1, 12, 128, 128}, true);
   auto* elementwise_qk = layers.elementwise_add(matmul_qk, bqk);
   auto* softmax_qk = layers.softmax(elementwise_qk, -1);
 
   auto* matmul_qkv = layers.matmul(softmax_qk, transpose_2);
-  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3});
-  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {128, 768});
+  auto* transpose_qkv = layers.transpose2(matmul_qkv, {0, 2, 1, 3}, true);
+  auto* reshape_qkv_out = layers.reshape2(transpose_qkv, {1, 128, 768}, true);
   auto* weights_l = layers.data("weightsl", {768, 768}, true);
-  layers.mul(reshape_qkv_out, weights_l);
+  layers.mul(reshape_qkv_out, weights_l, nullptr, 2);
 
   std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
   graph->Set("__param_scope__", CreateParamScope());
...
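For context, a test like the one above usually finishes by running the registered pass over the constructed graph and counting the fused ops. A minimal, hedged sketch of that driver follows; the PassRegistry call, the pass name string and the GetNumOpNodes helper are assumptions based on the surrounding test infrastructure, not part of this diff.

// Hypothetical test driver (assumed API), shown only to illustrate how the
// compat precondition surfaces: an incompatible graph is simply left unfused.
auto pass = PassRegistry::Instance().Get("multihead_matmul_fuse_pass_v2");
graph.reset(pass->Apply(graph.release()));
int fused = GetNumOpNodes(graph, "multihead_matmul");  // expected to be 1 for the graph above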
paddle/fluid/framework/ir/pass_tester_helper.h

@@ -293,13 +293,17 @@ struct Layers {
     return outs;
   }
 
-  VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr) {
+  VarDesc* matmul(VarDesc* x, VarDesc* y, VarDesc* alpha = nullptr,
+                  bool transpose_x = false, bool transpose_y = false) {
     VarDesc* out = lod_tensor(unique_name());
     OpDesc* op = program_.MutableBlock(0)->AppendOp();
     op->SetType("matmul");
     op->SetInput("X", {x->Name()});
     op->SetInput("Y", {y->Name()});
     op->SetOutput("Out", {out->Name()});
+    op->SetAttr("transpose_X", transpose_x);
+    op->SetAttr("transpose_Y", transpose_y);
+    op->SetAttr("alpha", 1.0f);
     return out;
   }
...
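Because the two new parameters default to false, existing Layers::matmul call sites keep building the same op; only the QK matmul in the tester needs the transposed-Y form. For example, from the tester change above:

auto* qk  = layers.matmul(scale_0, transpose_1, nullptr, false, true);  // Q x K^T (transpose_Y = true)
auto* qkv = layers.matmul(softmax_qk, transpose_2);                     // defaults: no transpose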
paddle/fluid/operators/compat/matmul.pbtxt

@@ -23,6 +23,10 @@ def {
   }
 }
 extra {
+  attrs {
+    name: "head_number"
+    type: INT
+  }
   attrs {
     name: "Scale_out"
     type: FLOAT
...
paddle/fluid/operators/compat/softmax.pbtxt

@@ -10,12 +10,12 @@ def {
     name: "axis"
     type: INT
   }
+}
+extra {
   attrs {
     name: "data_format"
     type: STRING
   }
-}
-extra {
   attrs {
     name: "op_role"
     type: INT
...