Commit 5efaaaa3 (unverified), authored Sep 05, 2023 by jiangfan06, committed by GitHub on Sep 05, 2023.
[XPU] Add element_mul_add_fuse_pass and elementwise_madd_xpu kernel (#56629)
Parent: 6dd9a024
Showing 13 changed files with 657 additions and 1 deletion (+657 −1)
paddle/fluid/framework/ir/CMakeLists.txt  +2 −0
paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc  +333 −0
paddle/fluid/inference/api/paddle_pass_builder.cc  +1 −0
paddle/phi/api/yaml/fused_ops.yaml  +9 −0
paddle/phi/backends/xpu/xpu2_op_list.cc  +4 −0
paddle/phi/infermeta/fusion.cc  +9 −0
paddle/phi/infermeta/fusion.h  +5 −0
paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc  +61 −0
paddle/phi/kernels/xpu/concat_kernel.cc  +3 −1
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h  +3 −0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu  +77 −0
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp  +76 −0
test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py  +74 −0
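In effect, the new pass collapses an elementwise_mul whose output feeds an elementwise_add into one addcmul_xpu op. A minimal scalar reference of the fused semantics, as a standalone illustration rather than Paddle code (the general pattern computes out = x * y + w; the special XY pattern reuses x as w, giving out = x * y + x):

#include <cstddef>
#include <vector>

// Reference semantics of addcmul_xpu for same-shape inputs:
// out[i] = x[i] * y[i] + w[i]
std::vector<float> addcmul_ref(const std::vector<float>& x,
                               const std::vector<float>& y,
                               const std::vector<float>& w) {
  std::vector<float> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) {
    out[i] = x[i] * y[i] + w[i];
  }
  return out;
}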
paddle/fluid/framework/ir/CMakeLists.txt
...
@@ -290,6 +290,8 @@ if(WITH_XPU)
   pass_library(fast_where_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
   pass_library(fast_layernorm_xpu_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
+  pass_library(elementwise_mul_add_fuse_pass inference DIR xpu DEPS ${XPU_PASS_DEPS})
 endif()
 cc_library(
...
paddle/fluid/framework/ir/xpu/elementwise_mul_add_fuse_pass.cc
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace phi {
class DenseTensor;
}  // namespace phi

namespace paddle {
namespace framework {
class Scope;
}  // namespace framework
}  // namespace paddle

namespace paddle {
namespace framework {
namespace ir {
namespace patterns {

/*
fuse elementwise_mul + elementwise_add op to addcmul_xpu op
For example:
graph:
          x     y
           \   /
            \ /
      elementwise_mul   w
               \       /
                \     /
             elementwise_add
                   |
                   |
                 output
------------------------------------------------------
After the pass is applied:
          x     y     w
           \    |    /
            \   |   /
            addcmul_xpu
                 |
                 |
               output
*/
struct ElementwiseMulAddFusePass : public PatternBase {
  ElementwiseMulAddFusePass(PDPattern* pattern, const std::string& name_scope);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise_mul);
  PATTERN_DECL_NODE(elementwise_add);
  // declare variable node's name
  PATTERN_DECL_NODE(mul_x);
  PATTERN_DECL_NODE(mul_y);
  PATTERN_DECL_NODE(mul_out);
  PATTERN_DECL_NODE(add_w);
  PATTERN_DECL_NODE(add_out);
};

ElementwiseMulAddFusePass::ElementwiseMulAddFusePass(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto elementwise_mul =
      pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
  auto elementwise_add =
      pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
  auto mul_x = pattern->NewNode(mul_x_repr())
                   ->AsInput()
                   ->assert_is_op_input("elementwise_mul", "X");
  auto mul_y = pattern->NewNode(mul_y_repr())
                   ->AsInput()
                   ->assert_is_op_input("elementwise_mul", "Y");
  auto mul_out = pattern->NewNode(mul_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output("elementwise_mul", "Out")
                     ->assert_is_op_input("elementwise_add", "X")
                     ->assert_has_n_outputs(1);
  elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out});
  auto add_w = pattern->NewNode(add_w_repr())
                   ->AsInput()
                   ->assert_is_op_input("elementwise_add", "Y");
  auto add_out = pattern->NewNode(add_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output("elementwise_add", "Out");
  elementwise_add->LinksFrom({mul_out, add_w}).LinksTo({add_out});
}

/*
special case for addcmul_xpu op:
graph:
          x     y
           \   /
            \ /
      elementwise_mul   x
               \       /
                \     /
             elementwise_add
                   |
                   |
                 output
------------------------------------------------------
After the pass is applied:
          x     y
           \   /
            \ /
        addcmul_xpu
             |
             |
           output
*/
struct ElementwiseMulAddFuseXYPattern : public PatternBase {
  ElementwiseMulAddFuseXYPattern(PDPattern* pattern,
                                 const std::string& name_scope);

  // declare operator node's name
  PATTERN_DECL_NODE(elementwise_mul);
  PATTERN_DECL_NODE(elementwise_add);
  // declare variable node's name
  PATTERN_DECL_NODE(mul_x);
  PATTERN_DECL_NODE(mul_y);
  PATTERN_DECL_NODE(mul_out);
  PATTERN_DECL_NODE(add_out);
};

ElementwiseMulAddFuseXYPattern::ElementwiseMulAddFuseXYPattern(
    PDPattern* pattern, const std::string& name_scope)
    : PatternBase(pattern, name_scope, name_scope) {
  auto elementwise_mul =
      pattern->NewNode(elementwise_mul_repr())->assert_is_op("elementwise_mul");
  auto elementwise_add =
      pattern->NewNode(elementwise_add_repr())->assert_is_op("elementwise_add");
  auto mul_x = pattern->NewNode(mul_x_repr())
                   ->AsInput()
                   ->assert_is_op_input("elementwise_mul", "X")
                   ->assert_is_op_input("elementwise_add", "Y");
  auto mul_y = pattern->NewNode(mul_y_repr())
                   ->AsInput()
                   ->assert_is_op_input("elementwise_mul", "Y");
  auto mul_out = pattern->NewNode(mul_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output("elementwise_mul", "Out")
                     ->assert_is_op_input("elementwise_add", "X");
  elementwise_mul->LinksFrom({mul_x, mul_y}).LinksTo({mul_out});
  auto add_out = pattern->NewNode(add_out_repr())
                     ->AsOutput()
                     ->assert_is_op_output("elementwise_add", "Out");
  elementwise_add->LinksFrom({mul_out, mul_x}).LinksTo({add_out});
}
}  // namespace patterns

class ElementwiseMulAddFusePass : public FusePassBase {
 protected:
  void ApplyImpl(ir::Graph* graph) const override;

 private:
  void FuseElementwiseMulAdd(ir::Graph* graph) const;
  void FuseElementwiseMulAddWithOnlyXY(ir::Graph* graph) const;

  const std::string name_scope_{"elementwise_mul_add_fuse_pass"};
};

void ElementwiseMulAddFusePass::ApplyImpl(ir::Graph* graph) const {
  PADDLE_ENFORCE_NOT_NULL(
      graph, platform::errors::PreconditionNotMet("graph should not be null."));
  Init(name_scope_, graph);
  FuseElementwiseMulAdd(graph);
  FuseElementwiseMulAddWithOnlyXY(graph);
}

void ElementwiseMulAddFusePass::FuseElementwiseMulAdd(ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::ElementwiseMulAddFusePass pattern(gpd.mutable_pattern(),
                                              name_scope_);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle ElementwiseMulAddFusePass";
    // declare operator node's name
    GET_IR_NODE(elementwise_mul);
    GET_IR_NODE(elementwise_add);
    // declare variable node's name
    GET_IR_NODE(mul_x);
    GET_IR_NODE(mul_y);
    GET_IR_NODE(mul_out);
    GET_IR_NODE(add_w);
    GET_IR_NODE(add_out);

    bool flag = true;
    auto var_type = mul_x->Var()->GetDataType();
    if (var_type != proto::VarType::FP16 && var_type != proto::VarType::FP32) {
      flag = false;
    }
    auto x_shape = mul_x->Var()->GetShape();
    auto y_shape = mul_y->Var()->GetShape();
    auto w_shape = add_w->Var()->GetShape();
    if (x_shape.size() == y_shape.size() && x_shape.size() == w_shape.size()) {
      for (size_t i = 0; i < x_shape.size(); ++i) {
        if (x_shape[i] != y_shape[i] || x_shape[i] != w_shape[i] ||
            x_shape[i] == -1) {
          flag = false;
        }
      }
    } else {
      flag = false;
    }
    if (flag) {
      auto* block = elementwise_mul->Op()->Block();
      // delete useless node
      std::unordered_set<const Node*> delete_nodes;

      // Generate addcmul_xpu op
      framework::OpDesc fused_op_desc(block);
      fused_op_desc.SetType("addcmul_xpu");
      fused_op_desc.SetInput("x", {mul_x->Name()});
      fused_op_desc.SetInput("y", {mul_y->Name()});
      fused_op_desc.SetInput("w", {add_w->Name()});
      fused_op_desc.SetOutput("out", {add_out->Name()});
      auto* fused_op = graph->CreateOpNode(&fused_op_desc);
      IR_NODE_LINK_TO(mul_x, fused_op);
      IR_NODE_LINK_TO(mul_y, fused_op);
      IR_NODE_LINK_TO(add_w, fused_op);
      IR_NODE_LINK_TO(fused_op, add_out);
      delete_nodes.insert({elementwise_mul, elementwise_add, mul_out});
      GraphSafeRemoveNodes(graph, delete_nodes);
      found_subgraph_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

void ElementwiseMulAddFusePass::FuseElementwiseMulAddWithOnlyXY(
    ir::Graph* graph) const {
  GraphPatternDetector gpd;
  patterns::ElementwiseMulAddFuseXYPattern pattern(gpd.mutable_pattern(),
                                                   name_scope_);
  int found_subgraph_count = 0;
  auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
                     Graph* graph) {
    VLOG(4) << "handle ElementwiseMulAddFusePass";
    // declare operator node's name
    GET_IR_NODE(elementwise_mul);
    GET_IR_NODE(elementwise_add);
    // declare variable node's name
    GET_IR_NODE(mul_x);
    GET_IR_NODE(mul_y);
    GET_IR_NODE(mul_out);
    GET_IR_NODE(add_out);

    bool flag = true;
    auto var_type = mul_x->Var()->GetDataType();
    if (var_type != proto::VarType::FP16 && var_type != proto::VarType::FP32) {
      flag = false;
    }
    auto x_shape = mul_x->Var()->GetShape();
    auto y_shape = mul_y->Var()->GetShape();
    if (x_shape.size() == y_shape.size()) {
      for (size_t i = 0; i < x_shape.size(); ++i) {
        if (x_shape[i] != y_shape[i] || x_shape[i] == -1) {
          flag = false;
        }
      }
    } else {
      flag = false;
    }
    if (flag) {
      auto* block = elementwise_mul->Op()->Block();
      // delete useless node
      std::unordered_set<const Node*> delete_nodes;

      // Generate addcmul_xpu op
      framework::OpDesc fused_op_desc(block);
      fused_op_desc.SetType("addcmul_xpu");
      fused_op_desc.SetInput("x", {mul_x->Name()});
      fused_op_desc.SetInput("y", {mul_y->Name()});
      fused_op_desc.SetInput("w", {mul_x->Name()});
      fused_op_desc.SetOutput("out", {add_out->Name()});
      auto* fused_op = graph->CreateOpNode(&fused_op_desc);
      IR_NODE_LINK_TO(mul_x, fused_op);
      IR_NODE_LINK_TO(mul_y, fused_op);
      IR_NODE_LINK_TO(fused_op, add_out);
      delete_nodes.insert({elementwise_mul, elementwise_add, mul_out});
      GraphSafeRemoveNodes(graph, delete_nodes);
      found_subgraph_count++;
    }
  };
  gpd(graph, handler);
  AddStatis(found_subgraph_count);
}

}  // namespace ir
}  // namespace framework
}  // namespace paddle

REGISTER_PASS(elementwise_mul_add_fuse_pass,
              paddle::framework::ir::ElementwiseMulAddFusePass);
REGISTER_PASS_CAPABILITY(elementwise_mul_add_fuse_pass)
    .AddCombination(
        paddle::framework::compatible::OpVersionComparatorCombination()
            .GE("elementwise_add", 0)
            .LE("elementwise_add", 1)
            .GE("elementwise_mul", 0)
            .LE("elementwise_mul", 1));
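For reference, a registered pass like this is normally exercised through the pass registry; a rough sketch of how an IR test might drive it (assuming the usual Paddle test scaffolding, with `graph` built beforehand from a ProgramDesc containing the mul+add subgraph):

#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"

// Sketch: fetch the pass by its registered name and rewrite the graph in
// place; matched mul+add subgraphs come back as addcmul_xpu nodes.
void RunElementwiseMulAddFusePass(
    std::unique_ptr<paddle::framework::ir::Graph>* graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "elementwise_mul_add_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}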
paddle/fluid/inference/api/paddle_pass_builder.cc
...
@@ -552,6 +552,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
       "fast_layernorm_xpu_fuse_pass",
       "yolo_box_xpu_fuse_pass",
       "fast_where_xpu_fuse_pass",
+      "elementwise_mul_add_fuse_pass",
       "link_xpu_op_max_pass",
       "delete_isolated_node_pass",
       // "auto_mixed_precision_pass",
...
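On the deployment side, XpuPassStrategy is selected when XPU inference is enabled, so the new pass runs automatically. A minimal sketch of a config that would trigger it (C++ inference API; the model path is a placeholder and EnableXpu's defaults are assumed):

#include <memory>
#include "paddle_inference_api.h"

int main() {
  paddle_infer::Config config("./model_dir");  // placeholder model directory
  config.EnableXpu();  // selects XpuPassStrategy, which now includes
                       // elementwise_mul_add_fuse_pass
  auto predictor = paddle_infer::CreatePredictor(config);
  return predictor ? 0 : 1;
}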
paddle/phi/api/yaml/fused_ops.yaml
...
@@ -23,6 +23,15 @@
     func : add_layernorm_xpu
     data_type : x
 
+- op : addcmul_xpu
+  args : (Tensor x, Tensor y, Tensor w)
+  output : Tensor(out)
+  infer_meta :
+    func : AddCMulXPUInferMeta
+  kernel :
+    func : addcmul_xpu
+    data_type : x
+
 - op : conv1d_xpu
   args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, str padding_algorithm, int dilations, int strides, int groups, int act_type, float act_param)
   output : Tensor(out), Tensor(out_max)
...
paddle/phi/backends/xpu/xpu2_op_list.cc
...
@@ -36,6 +36,8 @@ XPUOpMap& get_kl2_ops() {
      {"adam_dense_param_sparse_grad",
       XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"adagrad", XPUKernelSet({phi::DataType::FLOAT32})},
+     {"addcmul_xpu",
+      XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
      {"arg_max",
       XPUKernelSet({phi::DataType::INT32,
                     phi::DataType::FLOAT32,
...
@@ -161,6 +163,8 @@ XPUOpMap& get_kl2_ops() {
       XPUKernelSet({phi::DataType::FLOAT32,
                     phi::DataType::FLOAT16,
                     phi::DataType::FLOAT64,
+                    phi::DataType::BOOL,
+                    phi::DataType::INT8,
                     phi::DataType::INT64,
                     phi::DataType::INT32})},
      {"conv2d_grad",
...
paddle/phi/infermeta/fusion.cc
...
@@ -821,6 +821,15 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
   out->set_layout(x.layout());
 }
 
+void AddCMulXPUInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         const MetaTensor& w,
+                         MetaTensor* out) {
+  out->set_dims(x.dims());
+  out->set_dtype(x.dtype());
+  out->set_layout(x.layout());
+}
+
 void FusedScaleBiasReluConvBnstatsInferMeta(const MetaTensor& x,
                                             const MetaTensor& w,
...
paddle/phi/infermeta/fusion.h
...
@@ -201,6 +201,11 @@ void FastLayernormXPUInferMeta(const MetaTensor& x,
                                float epsilon,
                                MetaTensor* out);
 
+void AddCMulXPUInferMeta(const MetaTensor& x,
+                         const MetaTensor& y,
+                         const MetaTensor& w,
+                         MetaTensor* out);
+
 void FusedScaleBiasReluConvBnstatsInferMeta(const MetaTensor& x,
                                             const MetaTensor& w,
...
paddle/phi/kernels/fusion/xpu/addcmul_xpu_kernel.cc
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {

template <typename T, typename Context>
void AddCMulXPUKernel(const Context& ctx,
                      const DenseTensor& x,
                      const DenseTensor& y,
                      const DenseTensor& w,
                      DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  const auto* x_data = x.data<T>();
  const auto* y_data = y.data<T>();
  const auto* w_data = w.data<T>();
  auto* out_data = ctx.template Alloc<T>(out);

#ifdef PADDLE_WITH_XPU_PLUGIN
  int r = xpu::plugin::fast_addcmul(ctx.x_context(),
                                    reinterpret_cast<const XPUType*>(w_data),
                                    reinterpret_cast<const XPUType*>(x_data),
                                    reinterpret_cast<const XPUType*>(y_data),
                                    reinterpret_cast<XPUType*>(out_data),
                                    x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_addcmul");
#else
  int r = xpu::addcmul(ctx.x_context(),
                       reinterpret_cast<const XPUType*>(w_data),
                       reinterpret_cast<const XPUType*>(x_data),
                       reinterpret_cast<const XPUType*>(y_data),
                       reinterpret_cast<XPUType*>(out_data),
                       1.0f,
                       x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "addcmul");
#endif
}

}  // namespace fusion
}  // namespace phi

PD_REGISTER_KERNEL(addcmul_xpu,
                   XPU,
                   ALL_LAYOUT,
                   phi::fusion::AddCMulXPUKernel,
                   float,
                   phi::dtype::float16) {}
paddle/phi/kernels/xpu/concat_kernel.cc
...
@@ -119,4 +119,6 @@ PD_REGISTER_KERNEL(concat,
                    double,
                    phi::dtype::float16,
                    int64_t,
-                   int) {}
+                   int,
+                   int8_t,
+                   bool) {}
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
...
@@ -114,6 +114,9 @@ DLL_EXPORT int fast_embedding(Context* ctx,
                               int64_t ym,
                               int64_t padding_idx,
                               TID start_index = 0);
+template <typename T>
+DLL_EXPORT int fast_addcmul(
+    Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len);
 
 }  // namespace plugin
 }  // namespace api
...
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_addcmul.xpu
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
template <typename T>
static inline __device__ void primitive_addcmul(T* x, const T* y, int len) {
float32x16_t vx0;
float32x16_t vy0;
float32x16_t vx1;
float32x16_t vy1;
for (int i = 0; i < len; i += 32) {
vload2_lm(x + i, vx0, vx1);
vload2_lm(y + i, vy0, vy1);
vx0 = vvmac_float32x16(vx0, vy0, vx0);
vx1 = vvmac_float32x16(vx1, vy1, vx1);
vstore2_lm(x + i, vx0, vx1);
}
mfence_lm();
}
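// Note (editorial): per element, the vectorized loop above computes
// x[i] = x[i] * y[i] + x[i]. The host wrapper (fast_addcmul.cpp) launches
// this kernel only when the caller's w pointer aliases x, which is why the
// device kernel needs no separate w input.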
template <typename T>
__global__ void fast_addcmul(const T* x, const T* y, T* z, int64_t len) {
int cid = core_id();
const int ncores = core_num();
int tid = cid * cluster_num() + cluster_id();
int nthreads = cluster_num() * ncores;
const int buf_len = 512 / sizeof(T);
__simd__ float local_x_after_cast[buf_len];
__simd__ float local_y_after_cast[buf_len];
T* local_x = (T*)(local_x_after_cast);
T* local_y = (T*)(local_y_after_cast);
int loop = 0;
for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
int read_len = min(static_cast<int64_t>(buf_len), len - i);
GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
GM2LM(y + i, local_y, read_len * sizeof(T));
primitive_addcmul<T>(local_x, local_y, read_len);
LM2GM_ASYNC(local_x, z + i, read_len * sizeof(T));
mfence_lm();
#ifndef __XPU3__
loop++;
if ((loop & 0xF) == 0) {
sync_all();
}
#endif
}
}
#define _XPU_DEF__FAST_ADDCMUL_(DTYPE) \
template __global__ void fast_addcmul<DTYPE>( \
const DTYPE* x, const DTYPE* y, DTYPE* z, int64_t len);
_XPU_DEF__FAST_ADDCMUL_(float);
_XPU_DEF__FAST_ADDCMUL_(float16);
} // namespace plugin
} // namespace xpu2
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_addcmul.cpp
new file mode 100644
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
#include "xpu/refactor/util/vector_util.h"
namespace xpu2 {
namespace plugin {

template <typename T>
__attribute__((global)) void fast_addcmul(
    const T* x, const T* y, T* z, int64_t len);

}  // namespace plugin
}  // namespace xpu2

namespace baidu {
namespace xpu {
namespace api {
namespace plugin {

template <typename T>
static int xpu2_wrapper(
    Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) {
  if (x == w) {
    xpu2::plugin::fast_addcmul<T>
        <<<ctx->ncluster(), 64, ctx->xpu_stream>>>(x, y, z, len);
  } else {
    return addcmul(ctx, w, x, y, z, 1.0f, len);
  }
  return SUCCESS;
}

template <typename T>
int fast_addcmul(
    Context* ctx, const T* w, const T* x, const T* y, T* z, int64_t len) {
  WRAPPER_CHECK_CTX(ctx);
  WRAPPER_DUMP_FUNCTION_T1(ctx, "fast_mul_add", T);
  WRAPPER_DUMP_PARAM4(ctx, w, x, y, z);
  WRAPPER_DUMP_PARAM2(ctx, len, ctx->_l3_mgr.get_size());
  WRAPPER_DUMP(ctx);
  WRAPPER_CHECK_4PTRS(ctx, T, len, w, x, y, z);
  if (ctx->dev().type() == api::kXPU2) {
    return xpu2_wrapper<T>(ctx, w, x, y, z, len);
  }
  WRAPPER_UNIMPLEMENTED(ctx);
}

template int fast_addcmul(
    Context*, const float*, const float*, const float*, float*, int64_t);
template int fast_addcmul(Context*,
                          const float16*,
                          const float16*,
                          const float16*,
                          float16*,
                          int64_t);

}  // namespace plugin
}  // namespace api
}  // namespace xpu
}  // namespace baidu
test/ir/inference/test_xpu_elementwise_mul_add_fuse_pass.py
new file mode 100644
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
from functools import partial

import hypothesis.strategies as st
import numpy as np
from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig


class TestGatherAddTransposePass(PassAutoScanTest):
    def sample_predictor_configs(self, program_config):
        config = self.create_inference_config(use_xpu=True)
        yield config, ["addcmul_xpu"], (1e-3, 1e-3)

    def sample_program_config(self, draw):
        x_shape = draw(
            st.lists(
                st.integers(min_value=1, max_value=4), min_size=3, max_size=4
            )
        )

        def generate_data(shape):
            return np.random.random(shape).astype(np.float32)

        mul_op = OpConfig(
            "elementwise_mul",
            inputs={"X": ["mul_x"], "Y": ["mul_y"]},
            outputs={"Out": ["mul_out"]},
        )
        add_op = OpConfig(
            "elementwise_add",
            inputs={"X": ["mul_out"], "Y": ["add_w"]},
            outputs={"Out": ["add_out"]},
        )
        ops = [mul_op, add_op]
        program_config = ProgramConfig(
            ops=ops,
            inputs={
                "mul_x": TensorConfig(data_gen=partial(generate_data, x_shape)),
                "mul_y": TensorConfig(data_gen=partial(generate_data, x_shape)),
                "add_w": TensorConfig(data_gen=partial(generate_data, x_shape)),
            },
            weights={},
            outputs=["add_out"],
        )
        return program_config

    def test(self):
        self.run_and_statis(
            quant=False,
            max_examples=25,
            passes=["elementwise_mul_add_fuse_pass"],
        )


if __name__ == "__main__":
    unittest.main()