Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
8d325d82
P
Paddle
项目概览
PaddlePaddle
/
Paddle
1 年多 前同步成功
通知
2302
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
8d325d82
编写于
2月 23, 2023
作者:
C
csy0225
提交者:
GitHub
2月 23, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Migrate xpu_embedding_with_eltwise_add_fuse_pass (#50590)
上级
d7673e2f
变更
13
隐藏空白更改
内联
并排
Showing
13 changed file
with
655 addition
and
42 deletion
+655
-42
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/delete_dropout_op_pass.cc
paddle/fluid/framework/ir/delete_dropout_op_pass.cc
+39
-35
paddle/fluid/framework/ir/graph_pattern_detector.cc
paddle/fluid/framework/ir/graph_pattern_detector.cc
+9
-5
paddle/fluid/framework/ir/graph_pattern_detector.h
paddle/fluid/framework/ir/graph_pattern_detector.h
+1
-1
paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
...mework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
+313
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-1
paddle/phi/api/yaml/static_ops.yaml
paddle/phi/api/yaml/static_ops.yaml
+9
-0
paddle/phi/backends/xpu/xpu1_op_list.cc
paddle/phi/backends/xpu/xpu1_op_list.cc
+2
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+2
-0
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+22
-0
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+5
-0
paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
...rnels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
+84
-0
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_embedding_with_eltwise_add_xpu_fuse_pass.py
...ence/test_xpu_embedding_with_eltwise_add_xpu_fuse_pass.py
+167
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
8d325d82
...
...
@@ -221,6 +221,7 @@ if(WITH_XPU)
SRCS xpu/pass_utils.cc
DEPS pass
)
set
(
XPU_PASS_DEPS xpu_quant_utils xpu_pass_utils
)
pass_library
(
embedding_with_eltwise_add_xpu_fuse_pass inference DIR xpu
)
pass_library
(
fc_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
multi_encoder_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
...
...
paddle/fluid/framework/ir/delete_dropout_op_pass.cc
浏览文件 @
8d325d82
...
...
@@ -30,46 +30,50 @@ namespace ir {
void
DeleteDropoutOpPass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
const
std
::
string
pattern_name
=
"delete_dropout_op_pattern"
;
FusePassBase
::
Init
(
pattern_name
,
graph
);
int
found_subgraph_count
=
0
;
GraphPatternDetector
gpd
;
patterns
::
DeleteDropoutOpPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
();
for
(
auto
with_mask
:
{
true
,
false
})
{
GraphPatternDetector
gpd
;
patterns
::
DeleteDropoutOpPattern
pattern
(
gpd
.
mutable_pattern
(),
pattern_name
);
pattern
(
with_mask
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_IR_NODE
(
dropout_op_x
);
GET_IR_NODE
(
dropout_op
);
GET_IR_NODE
(
dropout_op_out
);
GET_IR_NODE
(
dropout_op_mask
);
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
g
)
{
GET_IR_NODE
(
dropout_op_x
);
GET_IR_NODE
(
dropout_op
);
GET_IR_NODE
(
dropout_op_out
);
// link dropout_op_out to pre_op
auto
dropout_op_x_name
=
dropout_op_x
->
Var
()
->
Name
();
auto
dropout_op_out_name
=
dropout_op_out
->
Var
()
->
Name
();
auto
pre_ops
=
dropout_op_x
->
inputs
;
if
(
pre_ops
.
empty
())
return
;
auto
pre_op_desc
=
pre_ops
[
0
]
->
Op
();
auto
pre_op_outs
=
pre_op_desc
->
Outputs
();
for
(
auto
&
out_var
:
pre_op_outs
)
{
auto
names
=
out_var
.
second
;
for
(
size_t
i
=
0
;
i
<
names
.
size
();
i
++
)
{
if
(
names
[
i
]
==
dropout_op_x_name
)
{
names
[
i
]
=
dropout_op_out_name
;
pre_op_desc
->
SetOutput
(
out_var
.
first
,
names
);
break
;
// link dropout_op_x to next_op
auto
dropout_op_x_name
=
dropout_op_x
->
Var
()
->
Name
();
auto
dropout_op_out_name
=
dropout_op_out
->
Var
()
->
Name
();
auto
next_op_nodes
=
dropout_op_out
->
outputs
;
for
(
auto
next_op_node
:
next_op_nodes
)
{
auto
next_op_desc
=
next_op_node
->
Op
();
auto
next_op_inputs
=
next_op_desc
->
Inputs
();
for
(
auto
&
input_var
:
next_op_inputs
)
{
auto
names
=
input_var
.
second
;
for
(
size_t
i
=
0
;
i
<
names
.
size
();
i
++
)
{
if
(
names
[
i
]
==
dropout_op_out_name
)
{
names
[
i
]
=
dropout_op_x_name
;
next_op_desc
->
SetInput
(
input_var
.
first
,
names
);
break
;
}
}
}
IR_NODE_LINK_TO
(
dropout_op_x
,
next_op_node
);
}
}
IR_NODE_LINK_TO
(
pre_ops
[
0
],
dropout_op_out
)
;
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
{
dropout_op_x
,
dropout_op
,
dropout_op_mask
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
// delete useless node
std
::
unordered_set
<
const
Node
*>
delete_nodes
{
dropout_op
,
dropout_op_out
}
;
if
(
with_mask
)
{
GET_IR_NODE
(
dropout_op_mask
);
delete_nodes
.
insert
(
dropout_op_mask
);
}
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
}
AddStatis
(
found_subgraph_count
);
}
...
...
paddle/fluid/framework/ir/graph_pattern_detector.cc
浏览文件 @
8d325d82
...
...
@@ -3032,7 +3032,7 @@ PDNode *patterns::TransposeFlattenConcat::operator()(
return
concat_out
;
}
void
patterns
::
DeleteDropoutOpPattern
::
operator
()()
{
void
patterns
::
DeleteDropoutOpPattern
::
operator
()(
bool
with_mask
)
{
auto
dropout_op_x
=
pattern
->
NewNode
(
dropout_op_x_repr
())
->
assert_is_op_input
(
"dropout"
,
"X"
)
->
AsInput
();
...
...
@@ -3042,10 +3042,14 @@ void patterns::DeleteDropoutOpPattern::operator()() {
std
::
string
(
"upscale_in_train"
));
auto
dropout_op_out
=
pattern
->
NewNode
(
dropout_op_out_repr
())
->
assert_is_op_output
(
"dropout"
,
"Out"
);
auto
dropout_op_mask
=
pattern
->
NewNode
(
dropout_op_mask_repr
())
->
assert_is_op_output
(
"dropout"
,
"Mask"
);
dropout_op
->
LinksFrom
({
dropout_op_x
})
.
LinksTo
({
dropout_op_out
,
dropout_op_mask
});
if
(
with_mask
)
{
auto
dropout_op_mask
=
pattern
->
NewNode
(
dropout_op_mask_repr
())
->
assert_is_op_output
(
"dropout"
,
"Mask"
);
dropout_op
->
LinksFrom
({
dropout_op_x
})
.
LinksTo
({
dropout_op_out
,
dropout_op_mask
});
}
else
{
dropout_op
->
LinksFrom
({
dropout_op_x
}).
LinksTo
({
dropout_op_out
});
}
}
void
patterns
::
DeleteQuantOpFuse
::
operator
()(
PDNode
*
input_act_node
,
...
...
paddle/fluid/framework/ir/graph_pattern_detector.h
浏览文件 @
8d325d82
...
...
@@ -1759,7 +1759,7 @@ struct DeleteDropoutOpPattern : public PatternBase {
DeleteDropoutOpPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
)
:
PatternBase
(
pattern
,
name_scope
,
"delete_dropout_op_pattern"
)
{}
void
operator
()();
void
operator
()(
bool
with_mask
);
PATTERN_DECL_NODE
(
dropout_op_x
);
PATTERN_DECL_NODE
(
dropout_op
);
...
...
paddle/fluid/framework/ir/xpu/embedding_with_eltwise_add_xpu_fuse_pass.cc
0 → 100644
浏览文件 @
8d325d82
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
static
bool
GetBoolFromEnv
(
const
std
::
string
&
str
,
bool
def
=
false
)
{
char
*
variable
=
std
::
getenv
(
str
.
c_str
());
if
(
!
variable
)
{
return
def
;
}
if
(
strcmp
(
variable
,
"false"
)
==
0
||
strcmp
(
variable
,
"0"
)
==
0
)
{
return
false
;
}
else
{
return
true
;
}
}
namespace
patterns
{
struct
EmbeddingWithEltwiseAddXPUPattern
:
public
PatternBase
{
EmbeddingWithEltwiseAddXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
n_embedding_
,
const
std
::
string
&
op_type
,
const
std
::
string
&
pre_op_type
);
// declare operator node's name
PATTERN_DECL_NODE
(
embedding0
);
PATTERN_DECL_NODE
(
embedding1
);
PATTERN_DECL_NODE
(
ewadd01
);
// declare variable node's name
PATTERN_DECL_NODE
(
x0
);
PATTERN_DECL_NODE
(
x1
);
PATTERN_DECL_NODE
(
table0
);
PATTERN_DECL_NODE
(
table1
);
PATTERN_DECL_NODE
(
embedding_out0
);
PATTERN_DECL_NODE
(
embedding_out1
);
PATTERN_DECL_NODE
(
ewadd01_out
);
std
::
unordered_map
<
std
::
string
,
std
::
string
>
node_reprs
;
private:
int
n_embedding_
;
std
::
string
op_type_
;
std
::
string
pre_op_type_
;
};
EmbeddingWithEltwiseAddXPUPattern
::
EmbeddingWithEltwiseAddXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
n_embedding
,
const
std
::
string
&
op_type
,
const
std
::
string
&
pre_op_type
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
n_embedding_
(
n_embedding
),
op_type_
(
op_type
),
pre_op_type_
(
pre_op_type
)
{
for
(
int
i
=
0
;
i
<
n_embedding
;
i
++
)
{
node_reprs
[
"x"
+
std
::
to_string
(
i
)]
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"x"
+
std
::
to_string
(
i
));
node_reprs
[
"table"
+
std
::
to_string
(
i
)]
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"table"
+
std
::
to_string
(
i
));
node_reprs
[
"embedding"
+
std
::
to_string
(
i
)]
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"embedding"
+
std
::
to_string
(
i
));
node_reprs
[
"embedding_out"
+
std
::
to_string
(
i
)]
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
"embedding_out"
+
std
::
to_string
(
i
));
if
(
i
-
1
>=
0
)
{
auto
ewadd_name
=
string
::
Sprintf
(
"ewadd%d%d"
,
i
-
1
,
i
);
node_reprs
[
ewadd_name
]
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
ewadd_name
);
auto
ewadd_out_name
=
string
::
Sprintf
(
"ewadd%d%d_out"
,
i
-
1
,
i
);
node_reprs
[
ewadd_out_name
]
=
PDNodeName
(
name_scope_
,
repr_
,
id_
,
ewadd_out_name
);
}
}
PDNode
*
x0
=
pattern
->
NewNode
(
x0_repr
())
->
assert_is_op_input
(
op_type_
,
"Ids"
)
->
assert_var_not_persistable
()
->
AsInput
();
PDNode
*
x1
=
pattern
->
NewNode
(
x1_repr
())
->
assert_is_op_input
(
op_type_
,
"Ids"
)
->
assert_var_not_persistable
()
->
AsInput
();
PDNode
*
embedding0
=
pattern
->
NewNode
(
embedding0_repr
())
->
assert_is_op
(
op_type_
);
auto
*
table0
=
pattern
->
NewNode
(
table0_repr
())
->
assert_is_op_input
(
op_type_
,
"W"
)
->
AsInput
();
auto
*
embedding_out0
=
pattern
->
NewNode
(
embedding_out0_repr
())
->
assert_is_op_output
(
op_type_
,
"Out"
)
->
assert_is_op_input
(
"elementwise_add"
,
"X"
);
auto
*
table1
=
pattern
->
NewNode
(
table1_repr
())
->
assert_is_op_input
(
op_type_
,
"W"
)
->
AsInput
();
auto
*
embedding1
=
pattern
->
NewNode
(
embedding1_repr
())
->
assert_is_op
(
op_type_
);
auto
*
embedding_out1
=
pattern
->
NewNode
(
embedding_out1_repr
())
->
assert_is_op_output
(
op_type_
,
"Out"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
ewadd01
=
pattern
->
NewNode
(
ewadd01_repr
())
->
assert_is_op
(
"elementwise_add"
);
auto
*
ewadd01_out
=
pattern
->
NewNode
(
ewadd01_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
);
embedding0
->
LinksFrom
({
x0
,
table0
});
embedding1
->
LinksFrom
({
x1
,
table1
});
embedding0
->
LinksTo
({
embedding_out0
});
embedding1
->
LinksTo
({
embedding_out1
});
ewadd01
->
LinksFrom
({
embedding_out0
,
embedding_out1
});
ewadd01
->
LinksTo
({
ewadd01_out
});
auto
*
last_ewadd_out
=
ewadd01_out
;
for
(
int
i
=
2
;
i
<
n_embedding
;
++
i
)
{
auto
x_name
=
node_reprs
[
"x"
+
std
::
to_string
(
i
)];
auto
table_name
=
node_reprs
[
"table"
+
std
::
to_string
(
i
)];
auto
embedding_name
=
node_reprs
[
"embedding"
+
std
::
to_string
(
i
)];
auto
embedding_out_name
=
node_reprs
[
"embedding_out"
+
std
::
to_string
(
i
)];
auto
*
new_table
=
pattern
->
NewNode
(
table_name
)
->
assert_is_op_input
(
op_type_
,
"W"
)
->
AsInput
();
auto
*
new_embedding
=
pattern
->
NewNode
(
embedding_name
)
->
assert_is_op
(
op_type_
);
auto
*
new_embedding_out
=
pattern
->
NewNode
(
embedding_out_name
)
->
assert_is_op_output
(
op_type_
,
"Out"
)
->
assert_is_op_input
(
"elementwise_add"
,
"Y"
);
auto
*
new_x
=
pattern
->
NewNode
(
x_name
)
->
assert_is_op_input
(
op_type_
,
"Ids"
)
->
AsInput
();
new_embedding
->
LinksFrom
({
new_x
,
new_table
});
new_embedding
->
LinksTo
({
new_embedding_out
});
auto
ewadd_name
=
node_reprs
[
"ewadd"
+
std
::
to_string
(
i
-
1
)
+
std
::
to_string
(
i
)];
auto
ewadd_out_name
=
node_reprs
[
"ewadd"
+
std
::
to_string
(
i
-
1
)
+
std
::
to_string
(
i
)
+
"_out"
];
auto
*
new_ewadd
=
pattern
->
NewNode
(
ewadd_name
)
->
assert_is_op
(
"elementwise_add"
);
auto
*
new_ewadd_out
=
pattern
->
NewNode
(
ewadd_out_name
)
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
);
new_ewadd
->
LinksFrom
({
last_ewadd_out
,
new_embedding_out
});
new_ewadd
->
LinksTo
({
new_ewadd_out
});
last_ewadd_out
=
new_ewadd_out
;
}
last_ewadd_out
->
AsOutput
();
}
}
// namespace patterns
class
EmbeddingWithEltwiseAddXPUFusePass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
void
ApplyImpl
(
ir
::
Graph
*
graph
,
int
n_embedding
,
const
std
::
string
op_type
,
const
std
::
string
pre_op_type
)
const
;
const
std
::
string
name_scope_
{
"embedding_with_eltwise_add_xpu_fuse_pass"
};
};
void
EmbeddingWithEltwiseAddXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
FusePassBase
::
Init
(
name_scope_
,
graph
);
std
::
vector
<
std
::
string
>
pre_op_types
{
"reshape2"
,
"squeeze2"
,
""
};
std
::
vector
<
std
::
string
>
op_types
{
"lookup_table"
,
"lookup_table_v2"
};
for
(
auto
&
pre_op_type
:
pre_op_types
)
{
for
(
int
n_embedding
:
{
4
,
3
,
2
})
{
for
(
auto
&
op_type
:
op_types
)
{
ApplyImpl
(
graph
,
n_embedding
,
op_type
,
pre_op_type
);
}
}
}
}
void
EmbeddingWithEltwiseAddXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
,
int
n_embedding
,
const
std
::
string
op_type
,
const
std
::
string
pre_op_type
)
const
{
GraphPatternDetector
gpd
;
patterns
::
EmbeddingWithEltwiseAddXPUPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
n_embedding
,
op_type
,
pre_op_type
);
int
found_subgraph_count
=
0
;
#define GET_IR_NODE_FROM_SUBGRAPH_BY_NAME(name, rt_node, pat) \
PADDLE_ENFORCE_NE( \
subgraph.count(pat.PatternBase::pattern->RetrieveNode(name)), \
0UL, \
platform::errors::NotFound("Node not found for PDNode %s", name)); \
Node* rt_node = subgraph.at(pat.PatternBase::pattern->RetrieveNode(name)); \
PADDLE_ENFORCE_NOT_NULL( \
rt_node, \
platform::errors::NotFound("node %s not exists in the sub-graph", \
#rt_node));
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
std
::
vector
<
std
::
string
>
x_names
;
std
::
vector
<
std
::
string
>
table_names
;
std
::
vector
<
Node
*>
x_nodes
;
std
::
vector
<
Node
*>
table_nodes
;
std
::
vector
<
Node
*>
embedding_nodes
;
auto
output_name
=
pattern
.
node_reprs
[
string
::
Sprintf
(
"ewadd%d%d_out"
,
n_embedding
-
2
,
n_embedding
-
1
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
output_name
,
output_node
,
pattern
)
std
::
unordered_set
<
const
Node
*>
delete_nodes
;
for
(
int
i
=
0
;
i
<
n_embedding
;
++
i
)
{
// Ids
auto
x_name
=
pattern
.
node_reprs
[
"x"
+
std
::
to_string
(
i
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
x_name
,
x_node
,
pattern
)
x_nodes
.
push_back
(
x_node
);
x_names
.
push_back
(
x_node
->
Name
());
// Tables
auto
table_name
=
pattern
.
node_reprs
[
"table"
+
std
::
to_string
(
i
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
table_name
,
table_node
,
pattern
)
table_nodes
.
push_back
(
table_node
);
table_names
.
push_back
(
table_node
->
Name
());
// Embedding
auto
embedding_name
=
pattern
.
node_reprs
[
"embedding"
+
std
::
to_string
(
i
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
embedding_name
,
embedding_node
,
pattern
)
embedding_nodes
.
push_back
(
embedding_node
);
delete_nodes
.
insert
(
embedding_node
);
auto
embedding_out_name
=
pattern
.
node_reprs
[
"embedding_out"
+
std
::
to_string
(
i
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
embedding_out_name
,
embedding_out_node
,
pattern
)
delete_nodes
.
insert
(
embedding_out_node
);
if
(
i
-
1
>=
0
)
{
auto
ewadd_name
=
pattern
.
node_reprs
[
string
::
Sprintf
(
"ewadd%d%d"
,
i
-
1
,
i
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
ewadd_name
,
ewadd_node
,
pattern
)
delete_nodes
.
insert
(
ewadd_node
);
auto
ewadd_out_name
=
pattern
.
node_reprs
[
string
::
Sprintf
(
"ewadd%d%d_out"
,
i
-
1
,
i
)];
GET_IR_NODE_FROM_SUBGRAPH_BY_NAME
(
ewadd_out_name
,
ewadd_out_node
,
pattern
)
if
(
i
!=
n_embedding
-
1
)
{
delete_nodes
.
insert
(
ewadd_out_node
);
}
}
}
// Generate embedding_with_eltwise_add_xpu op
framework
::
OpDesc
embedding_with_eltwise_add_xpu_op_desc
;
embedding_with_eltwise_add_xpu_op_desc
.
SetType
(
"embedding_with_eltwise_add_xpu"
);
embedding_with_eltwise_add_xpu_op_desc
.
SetInput
(
"ids"
,
x_names
);
embedding_with_eltwise_add_xpu_op_desc
.
SetInput
(
"tables"
,
table_names
);
embedding_with_eltwise_add_xpu_op_desc
.
SetOutput
(
"out"
,
{
output_node
->
Name
()});
embedding_with_eltwise_add_xpu_op_desc
.
SetAttr
(
"n_embedding"
,
n_embedding
);
int64_t
padding_idx
=
PADDLE_GET_CONST
(
int64_t
,
embedding_nodes
[
0
]
->
Op
()
->
GetAttr
(
"padding_idx"
));
if
(
GetBoolFromEnv
(
"XPU_PADDING_IDX"
,
true
))
{
padding_idx
=
-
1
;
}
embedding_with_eltwise_add_xpu_op_desc
.
SetAttr
(
"padding_idx"
,
static_cast
<
int64_t
>
(
padding_idx
));
auto
*
embedding_with_eltwise_add_xpu_op
=
graph
->
CreateOpNode
(
&
embedding_with_eltwise_add_xpu_op_desc
);
for
(
size_t
i
=
0
;
i
<
x_nodes
.
size
();
i
++
)
{
SAFE_IR_NODE_LINK_TO
(
x_nodes
[
i
],
embedding_with_eltwise_add_xpu_op
);
SAFE_IR_NODE_LINK_TO
(
table_nodes
[
i
],
embedding_with_eltwise_add_xpu_op
);
}
SAFE_IR_NODE_LINK_TO
(
embedding_with_eltwise_add_xpu_op
,
output_node
);
// delete useless node
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
AddStatis
(
found_subgraph_count
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
embedding_with_eltwise_add_xpu_fuse_pass
,
paddle
::
framework
::
ir
::
EmbeddingWithEltwiseAddXPUFusePass
);
REGISTER_PASS_CAPABILITY
(
embedding_with_eltwise_add_xpu_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"embedding_with_eltwise_add_xpu"
,
0
));
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
8d325d82
...
...
@@ -521,7 +521,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"generate_sequence_xpu_fuse_pass"
,
"multi_encoder_xpu_fuse_pass"
,
"multi_encoder_xpu_slice_fuse_pass"
,
//
"embedding_with_eltwise_add_xpu_fuse_pass",
"embedding_with_eltwise_add_xpu_fuse_pass"
,
"fc_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
});
...
...
paddle/phi/api/yaml/static_ops.yaml
浏览文件 @
8d325d82
-
op
:
embedding_with_eltwise_add_xpu
args
:
(Tensor[] ids, Tensor[] tables, int64_t padding_idx)
output
:
Tensor
infer_meta
:
func
:
EmbeddingWithEltwiseAddXPUInferMeta
kernel
:
func
:
embedding_with_eltwise_add_xpu
data_type
:
tables
-
op
:
fc_xpu
args
:
(Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha)
output
:
Tensor(out), Tensor(out_max)
...
...
paddle/phi/backends/xpu/xpu1_op_list.cc
浏览文件 @
8d325d82
...
...
@@ -80,6 +80,8 @@ XPUOpMap& get_kl1_ops() {
{
"elementwise_pow"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"elementwise_sub_grad"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"elementwise_sub"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"embedding_with_eltwise_add_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"equal"
,
XPUKernelSet
({
phi
::
DataType
::
INT64
})},
{
"expand_as_v2"
,
XPUKernelSet
({
phi
::
DataType
::
INT32
,
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
8d325d82
...
...
@@ -212,6 +212,8 @@ XPUOpMap& get_kl2_ops() {
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
INT64
,
phi
::
DataType
::
INT32
})},
{
"embedding_with_eltwise_add_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
})},
{
"empty"
,
XPUKernelSet
({
phi
::
DataType
::
INT64
,
phi
::
DataType
::
INT32
,
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
8d325d82
...
...
@@ -21,6 +21,28 @@ limitations under the License. */
namespace
phi
{
void
EmbeddingWithEltwiseAddXPUInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
ids
,
const
std
::
vector
<
const
MetaTensor
*>&
tables
,
MetaTensor
*
out
)
{
PADDLE_ENFORCE_GT
(
ids
.
size
(),
0UL
,
phi
::
errors
::
InvalidArgument
(
"The input ids in EmbeddingWithEltwiseAddXPUInferMeta "
"can't be empty."
));
PADDLE_ENFORCE_GT
(
tables
.
size
(),
0UL
,
phi
::
errors
::
InvalidArgument
(
"The input tables in "
"EmbeddingWithEltwiseAddXPUInferMeta can't be empty."
));
auto
id_dims
=
ids
[
0
]
->
dims
();
auto
table_dims
=
tables
[
0
]
->
dims
();
out
->
set_dims
(
phi
::
make_ddim
({
id_dims
[
0
],
id_dims
[
1
],
table_dims
[
1
]}));
out
->
set_dtype
(
tables
[
0
]
->
dtype
());
out
->
set_layout
(
ids
[
0
]
->
layout
());
}
void
FcXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
w
,
...
...
paddle/phi/infermeta/fusion.h
浏览文件 @
8d325d82
...
...
@@ -22,6 +22,11 @@ namespace phi {
// Common InferMeta Functions for fusion operators.
// NOTE: The InferMeta Functions in this file are arranged in alphabetic order.
void
EmbeddingWithEltwiseAddXPUInferMeta
(
const
std
::
vector
<
const
MetaTensor
*>&
ids
,
const
std
::
vector
<
const
MetaTensor
*>&
tables
,
MetaTensor
*
out
);
void
FcXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
w
,
...
...
paddle/phi/kernels/fusion/xpu/embedding_with_eltwise_add_xpu_kernel.cc
0 → 100644
浏览文件 @
8d325d82
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
EmbeddingWithEltwiseAddXpuKernel
(
const
Context
&
ctx
,
const
std
::
vector
<
const
DenseTensor
*>&
ids
,
const
std
::
vector
<
const
DenseTensor
*>&
tables
,
int64_t
padding_idx
,
DenseTensor
*
out
)
{
auto
&
id_dims
=
ids
[
0
]
->
dims
();
int
idx_len
=
id_dims
[
0
]
*
id_dims
[
1
];
int
emb_layer_num
=
ids
.
size
();
int
embed_dim
=
tables
[
0
]
->
dims
()[
1
];
std
::
vector
<
int
>
table_lens_cpu
;
std
::
vector
<
const
float
*>
arg_tables
;
for
(
auto
*
table
:
tables
)
{
auto
&
table_dims
=
table
->
dims
();
PADDLE_ENFORCE_EQ
(
table_dims
.
size
(),
2
,
errors
::
InvalidArgument
(
"The table_dims size [%d] should be equal 2."
,
table_dims
.
size
()));
/* shape like [table_len, embed_dim] */
PADDLE_ENFORCE_EQ
(
table_dims
[
1
],
embed_dim
,
errors
::
InvalidArgument
(
"Every embed_dim [%d] should be equal the first one [%d]."
,
table_dims
[
1
],
embed_dim
));
table_lens_cpu
.
push_back
(
table_dims
[
0
]);
arg_tables
.
push_back
(
table
->
data
<
float
>
());
}
std
::
vector
<
std
::
vector
<
int
>>
int_idx
(
emb_layer_num
,
std
::
vector
<
int
>
(
idx_len
,
0
));
std
::
vector
<
xpu
::
VectorParam
<
int
>>
arg_ids
;
for
(
int
i
=
0
;
i
<
emb_layer_num
;
i
++
)
{
for
(
int
j
=
0
;
j
<
idx_len
;
j
++
)
{
int_idx
[
i
][
j
]
=
static_cast
<
int
>
(
ids
[
i
]
->
data
<
int64_t
>
()[
j
]);
}
arg_ids
.
push_back
(
xpu
::
VectorParam
<
int
>
{
int_idx
[
i
].
data
(),
idx_len
,
nullptr
});
}
ctx
.
template
Alloc
<
T
>(
out
);
int
r
=
xpu
::
multi_embedding_fusion
<
float
,
float
,
int
>
(
ctx
.
x_context
(),
arg_tables
,
/* tables */
out
->
data
<
T
>
(),
arg_ids
,
table_lens_cpu
,
embed_dim
,
std
::
vector
<
float
>
(
table_lens_cpu
.
size
(),
1.0
f
),
std
::
vector
<
int
>
(
table_lens_cpu
.
size
(),
padding_idx
));
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"embedding_with_eltwise_add_xpu"
);
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
embedding_with_eltwise_add_xpu
,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
EmbeddingWithEltwiseAddXpuKernel
,
float
)
{
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
CPU
);
}
python/paddle/fluid/tests/unittests/ir/inference/test_xpu_embedding_with_eltwise_add_xpu_fuse_pass.py
0 → 100644
浏览文件 @
8d325d82
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
functools
import
partial
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
class
TestEmbeddingWithEltwiseAddXPUFusePass
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"embedding_with_eltwise_add_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
# lookup_table_v2
lookup_table_num
=
draw
(
st
.
sampled_from
([
2
,
3
,
4
]))
print
(
"lookup_table_num: "
,
lookup_table_num
)
ids_shape
=
draw
(
st
.
sampled_from
([[
1
,
32
]]))
w_shape
=
draw
(
st
.
sampled_from
([[
1000
,
32
]]))
padding_idx
=
draw
(
st
.
sampled_from
([
-
1
]))
axis
=
draw
(
st
.
sampled_from
([
-
1
]))
def
gen_lookup_table_ops
():
lookup_table_op_config_list
=
[]
lookup_table_op_0
=
OpConfig
(
"lookup_table_v2"
,
inputs
=
{
"Ids"
:
[
"lookup_table_ids_0"
],
"W"
:
[
"lookup_table_w_0"
],
},
outputs
=
{
"Out"
:
[
"lookup_table_out_0"
]},
padding_idx
=
padding_idx
,
)
lookup_table_op_1
=
OpConfig
(
"lookup_table_v2"
,
inputs
=
{
"Ids"
:
[
"lookup_table_ids_1"
],
"W"
:
[
"lookup_table_w_1"
],
},
outputs
=
{
"Out"
:
[
"lookup_table_out_1"
]},
padding_idx
=
padding_idx
,
)
lookup_table_ops_list
=
[
lookup_table_op_0
,
lookup_table_op_1
]
if
lookup_table_num
>=
3
:
lookup_table_op_2
=
OpConfig
(
"lookup_table_v2"
,
inputs
=
{
"Ids"
:
[
"lookup_table_ids_2"
],
"W"
:
[
"lookup_table_w_2"
],
},
outputs
=
{
"Out"
:
[
"lookup_table_out_2"
]},
padding_idx
=
padding_idx
,
)
lookup_table_ops_list
.
append
(
lookup_table_op_2
)
if
lookup_table_num
>=
4
:
lookup_table_op_3
=
OpConfig
(
"lookup_table_v2"
,
inputs
=
{
"Ids"
:
[
"lookup_table_ids_3"
],
"W"
:
[
"lookup_table_w_3"
],
},
outputs
=
{
"Out"
:
[
"lookup_table_out_3"
]},
padding_idx
=
padding_idx
,
)
lookup_table_ops_list
.
append
(
lookup_table_op_3
)
return
lookup_table_ops_list
add_op_num
=
lookup_table_num
-
1
def
gen_eltwise_add_ops
():
add_op_0
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"lookup_table_out_0"
],
"Y"
:
[
"lookup_table_out_1"
],
},
outputs
=
{
"Out"
:
[
"add_op_0_out"
]},
axis
=
axis
,
)
add_op_list
=
[
add_op_0
]
if
add_op_num
>=
2
:
add_op_1
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"add_op_0_out"
],
"Y"
:
[
"lookup_table_out_2"
]},
outputs
=
{
"Out"
:
[
"add_op_1_out"
]},
axis
=
axis
,
)
add_op_list
.
append
(
add_op_1
)
if
add_op_num
>=
3
:
add_op_2
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"add_op_1_out"
],
"Y"
:
[
"lookup_table_out_3"
]},
outputs
=
{
"Out"
:
[
"add_op_2_out"
]},
axis
=
axis
,
)
add_op_list
.
append
(
add_op_2
)
return
add_op_list
lookup_table_op_list
=
gen_lookup_table_ops
()
add_op_list
=
gen_eltwise_add_ops
()
# ops
ops
=
[]
ops
.
extend
(
lookup_table_op_list
)
ops
.
extend
(
add_op_list
)
# inputs
def
generate_input
(
*
args
,
**
kwargs
):
return
np
.
random
.
randint
(
0
,
w_shape
[
0
],
ids_shape
).
astype
(
np
.
int64
)
def
gen_lookup_table_inputs_data
(
*
args
,
**
kwargs
):
inputs
=
{}
for
i
in
range
(
lookup_table_num
):
input_name
=
"lookup_table_ids_{}"
.
format
(
i
)
inputs
[
input_name
]
=
TensorConfig
(
data_gen
=
partial
(
generate_input
)
)
return
inputs
inputs
=
gen_lookup_table_inputs_data
()
# weights
def
gen_lookup_table_weights_data
():
weights
=
{}
for
i
in
range
(
lookup_table_num
):
w_name
=
"lookup_table_w_{}"
.
format
(
i
)
weights
[
w_name
]
=
TensorConfig
(
shape
=
w_shape
)
return
weights
weights
=
gen_lookup_table_weights_data
()
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
weights
,
inputs
=
inputs
,
outputs
=
add_op_list
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
3
,
min_success_num
=
3
,
passes
=
[
"embedding_with_eltwise_add_xpu_fuse_pass"
],
)
if
__name__
==
"__main__"
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录