Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
07e788f1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
07e788f1
编写于
8月 01, 2023
作者:
H
hong19860320
提交者:
GitHub
8月 01, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Add fast_where fusion op and XPU micro kernel (#55628)
上级
744e1eaf
变更
14
隐藏空白更改
内联
并排
Showing
14 changed file
with
2139 addition
and
1 deletion
+2139
-1
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+5
-0
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass.cc
+658
-0
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass_test.cc
...e/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass_test.cc
+304
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+9
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+4
-0
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+8
-0
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+6
-0
paddle/phi/kernels/fusion/xpu/fast_where_xpu_kernel.cc
paddle/phi/kernels/fusion/xpu/fast_where_xpu_kernel.cc
+81
-0
paddle/phi/kernels/xpu/plugin/CMakeLists.txt
paddle/phi/kernels/xpu/plugin/CMakeLists.txt
+1
-1
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+7
-0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_where.xpu
...i/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_where.xpu
+191
-0
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_where.cpp
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_where.cpp
+128
-0
test/ir/inference/test_xpu_fast_where_xpu_fuse_pass.py
test/ir/inference/test_xpu_fast_where_xpu_fuse_pass.py
+736
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
07e788f1
...
...
@@ -280,6 +280,7 @@ if(WITH_XPU)
pass_library
(
matmul_weight_trans_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fast_where_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
endif
()
cc_library
(
...
...
@@ -599,4 +600,8 @@ if(WITH_XPU)
test_reshape2_matmul_xpu_fuse_pass
SRCS xpu/reshape2_matmul_xpu_fuse_pass_test.cc
DEPS reshape2_matmul_xpu_fuse_pass
)
cc_test
(
test_fast_where_xpu_fuse_pass
SRCS xpu/fast_where_xpu_fuse_pass_test.cc
DEPS fast_where_xpu_fuse_pass
)
endif
()
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass.cc
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <string>
#include "glog/logging.h"
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/xpu/pass_utils.h"
#include "paddle/fluid/framework/ir/xpu/quant_utils.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace
phi
{
class
DenseTensor
;
}
// namespace phi
namespace
paddle
{
namespace
framework
{
class
Scope
;
}
// namespace framework
}
// namespace paddle
namespace
paddle
{
namespace
framework
{
namespace
ir
{
namespace
patterns
{
/*
Fuse cast+scale+mul+mul+add ops to fast_where_xpu op reduce memory access.
Case 0: when mode = 0,
condition
|
cast
|
/ \
/ \
scale \
x / y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition y x
\ | /
\ | /
\ | /
fast_where_xpu
Case 1: when mode = 1,
condition
|
cast
|
/ \
/ scale
/ \
x / y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition x y
\ | /
\ | /
\ | /
fast_where_xpu
Case 2: when mode = 0,
condition
|
cast
|
/ \
scale \
/ \
/ x \ y
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition y x
\ | /
\ | /
\ | /
fast_where_xpu
Case 3: when mode = 1,
condition
|
cast
|
/ \
/ scale
/ \
/ x \ y
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition x y
\ | /
\ | /
\ | /
fast_where_xpu
Case 4: when mode = 0,
condition
|
cast
|
/ \
scale \
/ \
/ x y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition y x
\ | /
\ | /
\ | /
fast_where_xpu
Case 5: when mode = 1,
condition
|
cast
|
/ \
/ scale
/ \
/ x y \
\ / \ /
mul0 mul1
\ /
\ /
\ /
add
After the pass is applied,
condition x y
\ | /
\ | /
\ | /
fast_where_xpu
*/
struct
OneFastWhereXPUPattern
:
public
PatternBase
{
OneFastWhereXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
mode
);
// declare operator node's name
PATTERN_DECL_NODE
(
cast
);
PATTERN_DECL_NODE
(
scale
);
PATTERN_DECL_NODE
(
mul0
);
PATTERN_DECL_NODE
(
mul1
);
PATTERN_DECL_NODE
(
add
);
// declare variable node's name
// cast
PATTERN_DECL_NODE
(
condition
);
PATTERN_DECL_NODE
(
cast_out
);
// scale
PATTERN_DECL_NODE
(
scale_out
);
// mul0
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
mul0_out
);
// mul1
PATTERN_DECL_NODE
(
y
);
PATTERN_DECL_NODE
(
mul1_out
);
// add
PATTERN_DECL_NODE
(
add_out
);
private:
int
mode_
{
0
};
};
OneFastWhereXPUPattern
::
OneFastWhereXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
mode
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
mode_
(
mode
)
{
// cast
auto
condition
=
pattern
->
NewNode
(
condition_repr
())
->
assert_is_op_input
(
"cast"
,
"X"
);
auto
cast_out
=
pattern
->
NewNode
(
cast_out_repr
())
->
assert_is_op_output
(
"cast"
,
"Out"
)
->
assert_is_op_input
(
"scale"
,
"X"
)
->
assert_is_op_input
(
"elementwise_mul"
);
auto
cast
=
pattern
->
NewNode
(
cast_repr
())
->
assert_is_op
(
"cast"
)
->
assert_more
([](
Node
*
n
)
{
auto
in_dtype_val
=
PADDLE_GET_CONST
(
int
,
n
->
Op
()
->
GetAttr
(
"in_dtype"
));
auto
out_dtype_val
=
PADDLE_GET_CONST
(
int
,
n
->
Op
()
->
GetAttr
(
"out_dtype"
));
return
in_dtype_val
==
0
&&
(
out_dtype_val
==
4
||
out_dtype_val
==
5
);
});
// scale
auto
scale_out
=
pattern
->
NewNode
(
scale_out_repr
())
->
assert_is_op_output
(
"scale"
,
"Out"
)
->
assert_is_op_input
(
"elementwise_mul"
);
auto
scale
=
pattern
->
NewNode
(
scale_repr
())
->
assert_is_op
(
"scale"
)
->
assert_more
([](
Node
*
n
)
{
auto
bias_val
=
PADDLE_GET_CONST
(
float
,
n
->
Op
()
->
GetAttr
(
"bias"
));
auto
scale_val
=
PADDLE_GET_CONST
(
float
,
n
->
Op
()
->
GetAttr
(
"scale"
));
return
fabs
(
bias_val
-
1.0
f
)
<=
1e-5
f
&&
fabs
(
scale_val
+
1.0
f
)
<=
1e-5
f
;
});
// mul0
auto
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
"elementwise_mul"
);
auto
mul0_out
=
pattern
->
NewNode
(
mul0_out_repr
())
->
assert_is_op_output
(
"elementwise_mul"
,
"Out"
)
->
assert_is_op_input
(
"elementwise_add"
);
auto
mul0
=
pattern
->
NewNode
(
mul0_repr
())
->
assert_is_op
(
"elementwise_mul"
)
->
assert_more
([](
Node
*
node
)
{
auto
node1
=
node
->
inputs
[
0
];
auto
node2
=
node
->
inputs
[
1
];
auto
node1_shape
=
node1
->
Var
()
->
GetShape
();
auto
node2_shape
=
node2
->
Var
()
->
GetShape
();
if
(
node1_shape
.
size
()
!=
node2_shape
.
size
())
return
false
;
for
(
size_t
i
=
0
;
i
<
node1_shape
.
size
();
i
++
)
{
if
(
node1_shape
[
i
]
!=
node2_shape
[
i
]
&&
(
node1_shape
[
i
]
!=
1
&&
node2_shape
[
i
]
!=
1
))
{
return
false
;
}
}
return
true
;
});
// mul1
auto
y
=
pattern
->
NewNode
(
y_repr
())
->
assert_is_op_input
(
"elementwise_mul"
);
auto
mul1_out
=
pattern
->
NewNode
(
mul1_out_repr
())
->
assert_is_op_output
(
"elementwise_mul"
,
"Out"
)
->
assert_is_op_input
(
"elementwise_add"
);
auto
mul1
=
pattern
->
NewNode
(
mul1_repr
())
->
assert_is_op
(
"elementwise_mul"
)
->
assert_more
([](
Node
*
node
)
{
auto
node1
=
node
->
inputs
[
0
];
auto
node2
=
node
->
inputs
[
1
];
auto
node1_shape
=
node1
->
Var
()
->
GetShape
();
auto
node2_shape
=
node2
->
Var
()
->
GetShape
();
if
(
node1_shape
.
size
()
!=
node2_shape
.
size
())
return
false
;
for
(
size_t
i
=
0
;
i
<
node1_shape
.
size
();
i
++
)
{
if
(
node1_shape
[
i
]
!=
node2_shape
[
i
]
&&
(
node1_shape
[
i
]
!=
1
&&
node2_shape
[
i
]
!=
1
))
{
return
false
;
}
}
return
true
;
});
// add
auto
add_out
=
pattern
->
NewNode
(
add_out_repr
())
->
assert_is_op_output
(
"elementwise_add"
,
"Out"
);
auto
add
=
pattern
->
NewNode
(
add_repr
())
->
assert_is_op
(
"elementwise_add"
)
->
assert_more
([](
Node
*
node
)
{
auto
node_in1
=
node
->
inputs
[
0
];
auto
node_in2
=
node
->
inputs
[
1
];
if
(
node_in1
->
inputs
.
size
()
==
1
&&
node_in1
->
inputs
[
0
]
->
Op
()
->
Type
()
==
"elementwise_mul"
&&
node_in2
->
inputs
.
size
()
==
1
&&
node_in2
->
inputs
[
0
]
->
Op
()
->
Type
()
==
"elementwise_mul"
)
{
auto
shape1
=
node_in1
->
Var
()
->
GetShape
();
auto
shape2
=
node_in2
->
Var
()
->
GetShape
();
return
shape1
==
shape2
;
}
return
false
;
});
cast
->
LinksFrom
({
condition
}).
LinksTo
({
cast_out
});
scale
->
LinksFrom
({
cast_out
}).
LinksTo
({
scale_out
});
PADDLE_ENFORCE_LE
(
mode
,
1
,
platform
::
errors
::
InvalidArgument
(
"one_fast_where_xpu_fuse_pass mode(%d) is not supported."
,
mode
));
if
(
mode
==
0
)
{
mul0
->
LinksFrom
({
x
,
scale_out
}).
LinksTo
({
mul0_out
});
mul1
->
LinksFrom
({
y
,
cast_out
}).
LinksTo
({
mul1_out
});
}
else
if
(
mode
==
1
)
{
mul0
->
LinksFrom
({
x
,
cast_out
}).
LinksTo
({
mul0_out
});
mul1
->
LinksFrom
({
y
,
scale_out
}).
LinksTo
({
mul1_out
});
}
add
->
LinksFrom
({
mul0_out
,
mul1_out
}).
LinksTo
({
add_out
});
}
/*
Fuse cascade fast_where_xpu ops to one fast_where_xpu op reduce memory access.
Case 0: when mode = 0,
x--------------
| |
| condition0 | y
| \ | /
| \ | /
| \ | /
condition1 | fast_where_xpu0
\ | /
\ | /
\ | /
fast_where_xpu1
After the pass is applied,
condition0 condition1
\ /
\ /
or
\ x y
\ | /
\ | /
fast_where_xpu
Case 1: when mode = 1,
condition0 x y
\ | / |
\ | / |
\ | / |
fast_where_xpu0 |
| |
condition1 | |
\ | /
\ | /
\ | /
fast_where_xpu1
After the pass is applied,
condition0 condition1
\ /
\ /
\ /
and
\ x y
\ | /
\ | /
fast_where_xpu
Other cases:
x ---------------------
| |
| condition0 y |
| \ | /
| \ | /
| \ | /
condition1 | fast_where_xpu0
\ | /
\ | /
\ | /
fast_where_xpu1
----------
| |
condition0 x y |
\ | / |
\ | / |
\ | / |
fast_where_xpu0 |
| |
condition1 | |
\ | /
\ | /
\ | /
fast_where_xpu1
*/
struct
CascadeFastWhereXPUPattern
:
public
PatternBase
{
CascadeFastWhereXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
mode
);
// declare operator node's name
PATTERN_DECL_NODE
(
fast_where_xpu0
);
PATTERN_DECL_NODE
(
fast_where_xpu1
);
// declare variable node's name
PATTERN_DECL_NODE
(
condition0
);
PATTERN_DECL_NODE
(
condition1
);
PATTERN_DECL_NODE
(
x
);
PATTERN_DECL_NODE
(
y
);
PATTERN_DECL_NODE
(
fast_where_xpu0_out
);
PATTERN_DECL_NODE
(
fast_where_xpu1_out
);
private:
int
mode_
{
0
};
};
CascadeFastWhereXPUPattern
::
CascadeFastWhereXPUPattern
(
PDPattern
*
pattern
,
const
std
::
string
&
name_scope
,
int
mode
)
:
PatternBase
(
pattern
,
name_scope
,
name_scope
),
mode_
(
mode
)
{
// declare operator nodes
auto
fast_where_xpu0
=
pattern
->
NewNode
(
fast_where_xpu0_repr
())
->
assert_is_op
(
"fast_where_xpu"
);
auto
fast_where_xpu1
=
pattern
->
NewNode
(
fast_where_xpu1_repr
())
->
assert_is_op
(
"fast_where_xpu"
);
// declare vairable nodes
auto
condition0
=
pattern
->
NewNode
(
condition0_repr
())
->
assert_is_op_input
(
"fast_where_xpu"
,
"condition"
);
auto
condition1
=
pattern
->
NewNode
(
condition1_repr
())
->
assert_is_op_input
(
"fast_where_xpu"
,
"condition"
);
auto
fast_where_xpu0_out
=
pattern
->
NewNode
(
fast_where_xpu0_out_repr
())
->
assert_is_op_output
(
"fast_where_xpu"
,
"out"
);
auto
fast_where_xpu1_out
=
pattern
->
NewNode
(
fast_where_xpu1_out_repr
())
->
assert_is_op_output
(
"fast_where_xpu"
,
"out"
);
auto
x
=
pattern
->
NewNode
(
x_repr
())
->
assert_is_op_input
(
"fast_where_xpu"
,
"x"
);
auto
y
=
pattern
->
NewNode
(
y_repr
())
->
assert_is_op_input
(
"fast_where_xpu"
,
"y"
);
fast_where_xpu0
->
LinksFrom
({
condition0
,
x
,
y
}).
LinksTo
({
fast_where_xpu0_out
});
PADDLE_ENFORCE_LE
(
mode
,
1
,
platform
::
errors
::
InvalidArgument
(
"cascade_fast_where_xpu_fuse_pass mode(%d) is not supported."
,
mode
));
if
(
mode
==
0
)
{
fast_where_xpu0_out
->
assert_is_op_input
(
"fast_where_xpu"
,
"y"
);
fast_where_xpu1
->
LinksFrom
({
condition1
,
x
,
fast_where_xpu0_out
})
.
LinksTo
({
fast_where_xpu1_out
});
}
else
if
(
mode
==
1
)
{
fast_where_xpu0_out
->
assert_is_op_input
(
"fast_where_xpu"
,
"x"
);
fast_where_xpu1
->
LinksFrom
({
condition1
,
fast_where_xpu0_out
,
y
})
.
LinksTo
({
fast_where_xpu1_out
});
}
}
}
// namespace patterns
class
OneFastWhereXPUFusePass
:
public
FusePassBase
{
public:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
int
ApplySubgraph
(
ir
::
Graph
*
graph
,
int
mode
)
const
;
const
std
::
string
name_scope_
{
"one_fast_where_xpu_fuse_pass"
};
};
int
OneFastWhereXPUFusePass
::
ApplySubgraph
(
ir
::
Graph
*
graph
,
int
mode
)
const
{
GraphPatternDetector
gpd
;
patterns
::
OneFastWhereXPUPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
mode
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FastWhereXPUFusePass fuse"
;
// declare operator node's name
GET_IR_NODE
(
cast
);
GET_IR_NODE
(
scale
);
GET_IR_NODE
(
mul0
);
GET_IR_NODE
(
mul1
);
GET_IR_NODE
(
add
);
// declare variable node's name
// scale
GET_IR_NODE
(
condition
);
GET_IR_NODE
(
cast_out
);
GET_IR_NODE
(
scale_out
);
// mul0
GET_IR_NODE
(
x
);
GET_IR_NODE
(
mul0_out
);
// mul1
GET_IR_NODE
(
y
);
GET_IR_NODE
(
mul1_out
);
// add
GET_IR_NODE
(
add_out
);
auto
*
block
=
add
->
Op
()
->
Block
();
framework
::
OpDesc
fast_where_xpu_op_desc
(
block
);
fast_where_xpu_op_desc
.
SetType
(
"fast_where_xpu"
);
fast_where_xpu_op_desc
.
SetInput
(
"condition"
,
{
condition
->
Name
()});
if
(
mode
==
0
)
{
fast_where_xpu_op_desc
.
SetInput
(
"x"
,
{
y
->
Name
()});
fast_where_xpu_op_desc
.
SetInput
(
"y"
,
{
x
->
Name
()});
}
else
if
(
mode
==
1
)
{
fast_where_xpu_op_desc
.
SetInput
(
"x"
,
{
x
->
Name
()});
fast_where_xpu_op_desc
.
SetInput
(
"y"
,
{
y
->
Name
()});
}
fast_where_xpu_op_desc
.
SetOutput
(
"out"
,
{
add_out
->
Name
()});
auto
fast_where_xpu_op_node
=
graph
->
CreateOpNode
(
&
fast_where_xpu_op_desc
);
IR_NODE_LINK_TO
(
x
,
fast_where_xpu_op_node
);
IR_NODE_LINK_TO
(
y
,
fast_where_xpu_op_node
);
IR_NODE_LINK_TO
(
condition
,
fast_where_xpu_op_node
);
IR_NODE_LINK_TO
(
fast_where_xpu_op_node
,
add_out
);
std
::
unordered_set
<
const
Node
*>
delete_nodes
=
{
cast
,
cast_out
,
scale
,
scale_out
,
mul0
,
mul0_out
,
mul1
,
mul1_out
,
add
};
GraphSafeRemoveNodes
(
graph
,
delete_nodes
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
return
found_subgraph_count
;
}
void
OneFastWhereXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
int
found_subgraph_count
=
0
;
for
(
auto
mode
:
{
0
,
1
})
{
found_subgraph_count
+=
ApplySubgraph
(
graph
,
mode
);
}
AddStatis
(
found_subgraph_count
);
}
class
CascadeFastWhereXPUFusePass
:
public
FusePassBase
{
public:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
int
ApplySubgraph
(
ir
::
Graph
*
graph
,
int
mode
)
const
;
const
std
::
string
name_scope_
{
"cascade_fast_where_xpu_fuse_pass"
};
};
int
CascadeFastWhereXPUFusePass
::
ApplySubgraph
(
ir
::
Graph
*
graph
,
int
mode
)
const
{
GraphPatternDetector
gpd
;
patterns
::
CascadeFastWhereXPUPattern
pattern
(
gpd
.
mutable_pattern
(),
name_scope_
,
mode
);
int
found_subgraph_count
=
0
;
auto
handler
=
[
&
](
const
GraphPatternDetector
::
subgraph_t
&
subgraph
,
Graph
*
graph
)
{
VLOG
(
4
)
<<
"handle FastWhereXPUFusePass fuse"
;
// declare operator node's name
GET_IR_NODE
(
fast_where_xpu0
);
GET_IR_NODE
(
fast_where_xpu1
);
// declare variable node's name
GET_IR_NODE
(
condition0
);
GET_IR_NODE
(
condition1
);
GET_IR_NODE
(
x
);
GET_IR_NODE
(
y
);
GET_IR_NODE
(
fast_where_xpu0_out
);
GET_IR_NODE
(
fast_where_xpu1_out
);
// Reuse variables
fast_where_xpu0_out
->
Var
()
->
SetShape
(
condition0
->
Var
()
->
GetShape
());
fast_where_xpu0_out
->
Var
()
->
SetDataType
(
condition0
->
Var
()
->
GetDataType
());
// Change the first fast_where_xpu op to logical op
fast_where_xpu0
->
Op
()
->
RemoveInput
(
"condition"
);
fast_where_xpu0
->
Op
()
->
RemoveInput
(
"x"
);
fast_where_xpu0
->
Op
()
->
RemoveInput
(
"y"
);
fast_where_xpu0
->
Op
()
->
RemoveOutput
(
"out"
);
fast_where_xpu0
->
Op
()
->
SetInput
(
"X"
,
std
::
vector
<
std
::
string
>
({
condition0
->
Name
()}));
fast_where_xpu0
->
Op
()
->
SetInput
(
"Y"
,
std
::
vector
<
std
::
string
>
({
condition1
->
Name
()}));
fast_where_xpu0
->
Op
()
->
SetOutput
(
"Out"
,
std
::
vector
<
std
::
string
>
({
fast_where_xpu0_out
->
Name
()}));
// Reserve the second first_where_xpu but change its inputs
fast_where_xpu1
->
Op
()
->
SetInput
(
"condition"
,
std
::
vector
<
std
::
string
>
({
fast_where_xpu0_out
->
Name
()}));
fast_where_xpu1
->
Op
()
->
SetInput
(
"x"
,
std
::
vector
<
std
::
string
>
({
x
->
Name
()}));
fast_where_xpu1
->
Op
()
->
SetInput
(
"y"
,
std
::
vector
<
std
::
string
>
({
y
->
Name
()}));
if
(
mode
==
0
)
{
fast_where_xpu0
->
Op
()
->
SetType
(
"logical_or"
);
}
else
if
(
mode
==
1
)
{
fast_where_xpu0
->
Op
()
->
SetType
(
"logical_and"
);
}
IR_NODE_UNLINK
(
x
,
fast_where_xpu0
);
IR_NODE_UNLINK
(
y
,
fast_where_xpu0
);
IR_NODE_LINK_TO
(
condition1
,
fast_where_xpu0
);
IR_NODE_UNLINK
(
condition1
,
fast_where_xpu1
);
IR_NODE_LINK_TO
(
x
,
fast_where_xpu1
);
IR_NODE_LINK_TO
(
y
,
fast_where_xpu1
);
found_subgraph_count
++
;
};
gpd
(
graph
,
handler
);
return
found_subgraph_count
;
}
void
CascadeFastWhereXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
PADDLE_ENFORCE_NOT_NULL
(
graph
,
platform
::
errors
::
PreconditionNotMet
(
"graph should not be null."
));
Init
(
name_scope_
,
graph
);
int
total_found_subgraph_count
=
0
;
int
cur_found_subgraph_count
=
0
;
do
{
cur_found_subgraph_count
=
0
;
for
(
auto
mode
:
{
0
,
1
})
{
cur_found_subgraph_count
+=
ApplySubgraph
(
graph
,
mode
);
}
total_found_subgraph_count
+=
cur_found_subgraph_count
;
}
while
(
cur_found_subgraph_count
>
0
);
AddStatis
(
total_found_subgraph_count
);
}
class
FastWhereXPUFusePass
:
public
FusePassBase
{
protected:
void
ApplyImpl
(
ir
::
Graph
*
graph
)
const
override
;
private:
const
std
::
string
name_scope_
{
"fast_where_xpu_fuse_pass"
};
};
void
FastWhereXPUFusePass
::
ApplyImpl
(
ir
::
Graph
*
graph
)
const
{
VLOG
(
4
)
<<
"handle fast_where_xpu op fusion."
;
OneFastWhereXPUFusePass
one_fast_where_xpu_fuse_pass
;
one_fast_where_xpu_fuse_pass
.
ApplyImpl
(
graph
);
CascadeFastWhereXPUFusePass
cascade_fast_where_xpu_fuse_pass
;
cascade_fast_where_xpu_fuse_pass
.
ApplyImpl
(
graph
);
}
}
// namespace ir
}
// namespace framework
}
// namespace paddle
REGISTER_PASS
(
fast_where_xpu_fuse_pass
,
paddle
::
framework
::
ir
::
FastWhereXPUFusePass
);
REGISTER_PASS_CAPABILITY
(
fast_where_xpu_fuse_pass
)
.
AddCombination
(
paddle
::
framework
::
compatible
::
OpVersionComparatorCombination
().
EQ
(
"fast_where_xpu_fuse_pass"
,
0
));
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass_test.cc
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define APPLY_PASS \
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); \
auto pass = PassRegistry::Instance().Get("fast_where_xpu_fuse_pass"); \
pass->Apply(graph.get());
#define VERIFY_GRAPH(x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only one op node, but %d op nodes found.", \
num_op_nodes)); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST
(
FastWhereXPUFusePass
,
one_case0
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
x
,
scale_out
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
cast_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
y
,
x
)
}
TEST
(
FastWhereXPUFusePass
,
one_case1
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
x
,
cast_out
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
scale_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
x
,
y
)
}
TEST
(
FastWhereXPUFusePass
,
one_case2
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
scale_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
cast_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
y
,
x
)
}
TEST
(
FastWhereXPUFusePass
,
one_case3
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
scale_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
x
,
y
)
}
TEST
(
FastWhereXPUFusePass
,
one_case4
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
scale_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
cast_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
y
,
x
)
}
TEST
(
FastWhereXPUFusePass
,
one_case5
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
scale_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
x
,
y
)
}
#undef VERIFY_GRAPH
#define VERIFY_GRAPH(logical_op, x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
2, \
platform::errors::PreconditionNotMet( \
"The graph contains only two op nodes, but %d op nodes found.", \
num_op_nodes)); \
auto logical_op_nodes = GetOpNodes(graph, #logical_op); \
PADDLE_ENFORCE_EQ( \
logical_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a '%s' op node, but %d op nodes found.", \
#logical_op, \
logical_op_nodes.size())); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST
(
FastWhereXPUFusePass
,
cascade_case0
)
{
Layers
layers
;
auto
*
condition0
=
layers
.
data
(
"condition0"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
condition1
=
layers
.
data
(
"condition1"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
// fast_where_xpu0
auto
*
cast0_out
=
layers
.
cast
(
condition0
,
0
,
5
);
cast0_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast0_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale0_out
=
layers
.
scale
(
cast0_out
,
-
1.0
f
,
1.0
f
,
true
);
scale0_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
scale0_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add0_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add0_out
->
SetShape
({
20
,
7
});
// fast_where_xpu1
auto
*
cast1_out
=
layers
.
cast
(
condition1
,
0
,
5
);
cast1_out
->
SetShape
({
20
,
1
});
auto
*
mul2_out
=
layers
.
elementwise_mul
(
cast1_out
,
x
);
mul2_out
->
SetShape
({
20
,
7
});
auto
*
scale1_out
=
layers
.
scale
(
cast1_out
,
-
1.0
f
,
1.0
f
,
true
);
scale1_out
->
SetShape
({
20
,
1
});
auto
*
mul3_out
=
layers
.
elementwise_mul
(
scale1_out
,
add0_out
);
mul3_out
->
SetShape
({
20
,
7
});
auto
*
add1_out
=
layers
.
elementwise_add
(
mul2_out
,
mul3_out
);
add1_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
logical_or
,
x
,
y
)
}
TEST
(
FastWhereXPUFusePass
,
cascade_case1
)
{
Layers
layers
;
auto
*
condition0
=
layers
.
data
(
"condition0"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
condition1
=
layers
.
data
(
"condition1"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
// fast_where_xpu0
auto
*
cast0_out
=
layers
.
cast
(
condition0
,
0
,
5
);
cast0_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast0_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale0_out
=
layers
.
scale
(
cast0_out
,
-
1.0
f
,
1.0
f
,
true
);
scale0_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
scale0_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add0_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add0_out
->
SetShape
({
20
,
7
});
// fast_where_xpu1
auto
*
cast1_out
=
layers
.
cast
(
condition1
,
0
,
5
);
cast1_out
->
SetShape
({
20
,
1
});
auto
*
mul2_out
=
layers
.
elementwise_mul
(
cast1_out
,
add0_out
);
mul2_out
->
SetShape
({
20
,
7
});
auto
*
scale1_out
=
layers
.
scale
(
cast1_out
,
-
1.0
f
,
1.0
f
,
true
);
scale1_out
->
SetShape
({
20
,
1
});
auto
*
mul3_out
=
layers
.
elementwise_mul
(
scale1_out
,
y
);
mul3_out
->
SetShape
({
20
,
7
});
auto
*
add1_out
=
layers
.
elementwise_add
(
mul2_out
,
mul3_out
);
add1_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
logical_and
,
x
,
y
)
}
#undef APPLY_PASS
#undef VERIFY_GRAPH
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fast_where_xpu_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
07e788f1
...
...
@@ -545,6 +545,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"add_activation_xpu_fuse_pass"
,
"add_layernorm_xpu_fuse_pass"
,
"yolo_box_xpu_fuse_pass"
,
"fast_where_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
"delete_isolated_node_pass"
,
// "auto_mixed_precision_pass",
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
07e788f1
...
...
@@ -53,6 +53,15 @@
data_type
:
tables
optional
:
mask, seq_lod, max_seq_len
-
op
:
fast_where_xpu
args
:
(Tensor condition, Tensor x, Tensor y)
output
:
Tensor(out)
infer_meta
:
func
:
FastWhereXPUInferMeta
kernel
:
func
:
fast_where_xpu
data_type
:
x
-
op
:
fc_xpu
args
:
(Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype)
output
:
Tensor(out), Tensor(out_max)
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
07e788f1
...
...
@@ -295,6 +295,10 @@ XPUOpMap& get_kl2_ops() {
phi
::
DataType
::
BOOL
,
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
FLOAT32
})},
{
"fast_where_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
INT32
,
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"fc_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"fill"
,
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
07e788f1
...
...
@@ -721,4 +721,12 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
out_max
);
}
void
FastWhereXPUInferMeta
(
const
MetaTensor
&
condition
,
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
)
{
out
->
set_dims
(
x
.
dims
());
out
->
set_dtype
(
x
.
dtype
());
}
}
// namespace phi
paddle/phi/infermeta/fusion.h
浏览文件 @
07e788f1
...
...
@@ -175,4 +175,10 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
const
std
::
string
&
act_type
,
MetaTensor
*
out
,
MetaTensor
*
out_max
);
void
FastWhereXPUInferMeta
(
const
MetaTensor
&
condition
,
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/fusion/xpu/fast_where_xpu_kernel.cc
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
FastWhereXPUKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
DenseTensor
*
out
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
auto
*
condition_data
=
condition
.
data
<
bool
>
();
auto
*
x_data
=
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
());
auto
*
y_data
=
reinterpret_cast
<
const
XPUType
*>
(
y
.
data
<
T
>
());
auto
*
out_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
auto
condition_dims
=
phi
::
vectorize
<
int
>
(
condition
.
dims
());
auto
x_dims
=
phi
::
vectorize
<
int
>
(
x
.
dims
());
auto
y_dims
=
phi
::
vectorize
<
int
>
(
y
.
dims
());
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
errors
::
PreconditionNotMet
(
"The dimensions of inputs should be equal, but x_dims=["
,
x
.
dims
(),
"] and y_dims=["
,
y
.
dims
(),
"]"
));
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG
(
WARNING
)
<<
"Add -DWITH_XPU_PLUGIN=ON to build xpu::plugin::fast_where(), or use "
"xpu::select() instead, which leads low performance."
;
int
r
=
xpu
::
select
<
XPUType
>
(
ctx
.
x_context
(),
condition_data
,
x_data
,
y_data
,
out_data
,
condition_dims
,
x_dims
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"select"
);
#else
xpu
::
ctx_guard
RAII_GUARD
(
ctx
.
x_context
());
if
(
condition_dims
!=
x_dims
)
{
bool
*
temp_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
bool
>
(
x
.
numel
());
int
r
=
xpu
::
broadcast
<
bool
>
(
ctx
.
x_context
(),
condition_data
,
temp_data
,
condition_dims
,
x_dims
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"broadcast"
);
condition_data
=
temp_data
;
}
int
r
=
xpu
::
plugin
::
fast_where
<
XPUType
>
(
ctx
.
x_context
(),
condition_data
,
x_data
,
y_data
,
out_data
,
x
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"fast_where"
);
#endif
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fast_where_xpu
,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FastWhereXPUKernel
,
float
,
phi
::
dtype
::
float16
,
int
)
{}
paddle/phi/kernels/xpu/plugin/CMakeLists.txt
浏览文件 @
07e788f1
...
...
@@ -154,7 +154,7 @@ macro(
${
kernel_path
}
-D
${
xpu_n_macro
}
--target=
${
TARGET_ARCH
}
${
HOST_XPU_FLAGS
}
--basename
${
kernel_name
}
-fno-builtin --xpu-arch=
${
xpu_n
}
-fPIC
-Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror -mllvm
--xpu-inline-cost -mllvm --xpu-inline-hot-call
--xpu-inline-cost -mllvm --xpu-inline-hot-call
-I
${
XDNN_INC_DIR
}
-I
${
CMAKE_CURRENT_SOURCE_DIR
}
/include -I
${
CMAKE_CURRENT_SOURCE_DIR
}
/src
-I
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/kernel
-I
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/kernel/include
${
arg_rule
}
...
...
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
浏览文件 @
07e788f1
...
...
@@ -24,6 +24,13 @@ namespace api {
namespace
plugin
{
DLL_EXPORT
int
add2
(
Context
*
ctx
,
const
float
*
x
,
float
*
y
,
int
len
);
template
<
typename
T
>
DLL_EXPORT
int
fast_where
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
out
,
int64_t
len
);
}
// namespace plugin
}
// namespace api
...
...
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_where.xpu
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
#define CALC_MASK(offset) \
mask |= static_cast<int>(condition[i + offset]) << offset;
static __device__ inline void do_select_16(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
int len_rounddown32 = rounddown32(len);
for (int i = 0; i < len_rounddown32; i += 32) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
CALC_MASK(16)
CALC_MASK(17)
CALC_MASK(18)
CALC_MASK(19)
CALC_MASK(20)
CALC_MASK(21)
CALC_MASK(22)
CALC_MASK(23)
CALC_MASK(24)
CALC_MASK(25)
CALC_MASK(26)
CALC_MASK(27)
CALC_MASK(28)
CALC_MASK(29)
CALC_MASK(30)
CALC_MASK(31)
vstore_lm_int16x32_mh(y + i, vload_lm_int16x32(x + i), mask);
}
for (int i = len_rounddown32; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
static __device__ inline void do_select_32(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
int len_rounddown16 = rounddown16(len);
for (int i = 0; i < len_rounddown16; i += 16) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
vstore_lm_int32x16_mh(y + i, vload_lm_int32x16(x + i), mask);
}
for (int i = len_rounddown16; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
template <typename T>
static __device__ void do_select(const int8_t* condition,
const T* x,
T* y,
int len) {}
template <>
__device__ void do_select<float16>(const int8_t* condition,
const float16* x,
float16* y,
int len) {
do_select_16(condition,
reinterpret_cast<const int16_t*>(x),
reinterpret_cast<int16_t*>(y),
len);
}
template <>
__device__ void do_select<float>(const int8_t* condition,
const float* x,
float* y,
int len) {
do_select_32(condition,
reinterpret_cast<const int32_t*>(x),
reinterpret_cast<int32_t*>(y),
len);
}
template <>
__device__ void do_select<int16_t>(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
do_select_16(condition, x, y, len);
}
template <>
__device__ void do_select<int32_t>(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
do_select_32(condition, x, y, len);
}
template <typename T>
__global__ void fast_where(
const int8_t* condition, const T* x, const T* y, T* z, int64_t len) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
#ifdef __XPU3__
const int buf_len = 1536 / sizeof(T);
#else
const int buf_len = 512 / sizeof(T);
#endif
__simd__ int8_t local_condition[buf_len];
__simd__ T local_x[buf_len];
__simd__ T local_y[buf_len];
int loop = 0;
for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
int read_len = min(static_cast<int64_t>(buf_len), len - i);
GM2LM_ASYNC(condition + i, local_condition, read_len * sizeof(int8_t));
GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
GM2LM(y + i, local_y, read_len * sizeof(T));
do_select<T>(local_condition, local_x, local_y, read_len);
LM2GM_ASYNC(local_y, z + i, read_len * sizeof(T));
mfence();
#ifndef __XPU3__
loop++;
if ((loop & 0xF) == 0) {
sync_all();
}
#endif
}
}
#define _XPU_DEF__FAST_WHERE_(DTYPE) \
template __global__ void fast_where<DTYPE>(const int8_t* condition, \
const DTYPE* x, \
const DTYPE* y, \
DTYPE* z, \
int64_t len);
_XPU_DEF__FAST_WHERE_(float16);
_XPU_DEF__FAST_WHERE_(float);
_XPU_DEF__FAST_WHERE_(int16_t);
_XPU_DEF__FAST_WHERE_(int32_t);
} // namespace plugin
} // namespace xpu2
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_where.cpp
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace
xpu2
{
namespace
plugin
{
template
<
typename
T
>
__attribute__
((
global
))
void
fast_where
(
const
int8_t
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
);
}
}
// namespace xpu2
namespace
baidu
{
namespace
xpu
{
namespace
api
{
namespace
plugin
{
template
<
typename
T
>
static
int
cpu_wrapper
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
i
++
)
{
z
[
i
]
=
condition
[
i
]
?
x
[
i
]
:
y
[
i
];
}
return
SUCCESS
;
}
template
<
>
int
cpu_wrapper
<
float16
>
(
Context
*
ctx
,
const
bool
*
condition
,
const
float16
*
x
,
const
float16
*
y
,
float16
*
z
,
int64_t
len
)
{
std
::
vector
<
float
>
x_fp32
(
len
);
std
::
vector
<
float
>
y_fp32
(
len
);
std
::
vector
<
float
>
z_fp32
(
len
);
int
ret
=
cast
<
float16
,
float
>
(
ctx
,
x
,
x_fp32
.
data
(),
len
);
ret
=
cast
<
float16
,
float
>
(
ctx
,
y
,
y_fp32
.
data
(),
len
);
ret
=
cpu_wrapper
<
float
>
(
ctx
,
condition
,
x_fp32
.
data
(),
y_fp32
.
data
(),
z_fp32
.
data
(),
len
);
ret
=
cast
<
float
,
float16
>
(
ctx
,
z_fp32
.
data
(),
z
,
len
);
WRAPPER_ASSERT_SUCCESS
(
ctx
,
ret
);
return
ret
;
}
template
<
typename
T
>
static
int
xpu2_wrapper
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
)
{
xpu2
::
plugin
::
fast_where
<
T
><<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
reinterpret_cast
<
const
int8_t
*>
(
condition
),
x
,
y
,
z
,
len
);
return
SUCCESS
;
}
template
<
typename
T
>
int
fast_where
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
)
{
WRAPPER_CHECK_CTX
(
ctx
);
WRAPPER_DUMP_FUNCTION_T1
(
ctx
,
"fast_where"
,
float
);
WRAPPER_DUMP_PARAM5
(
ctx
,
condition
,
x
,
y
,
z
,
len
);
WRAPPER_DUMP
(
ctx
);
WRAPPER_ASSERT_GT
(
ctx
,
len
,
0
);
WRAPPER_CHECK_2PTRS
(
ctx
,
T
,
len
,
x
,
y
);
if
(
ctx
->
dev
().
type
()
==
api
::
kCPU
)
{
return
cpu_wrapper
<
T
>
(
ctx
,
condition
,
x
,
y
,
z
,
len
);
}
if
(
ctx
->
dev
().
type
()
==
api
::
kXPU2
)
{
return
xpu2_wrapper
<
T
>
(
ctx
,
condition
,
x
,
y
,
z
,
len
);
}
WRAPPER_UNIMPLEMENTED
(
ctx
);
}
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
float
*
,
const
float
*
,
float
*
,
int64_t
);
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
float16
*
,
const
float16
*
,
float16
*
,
int64_t
);
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
int16_t
*
,
const
int16_t
*
,
int16_t
*
,
int64_t
);
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
int32_t
*
,
const
int32_t
*
,
int32_t
*
,
int64_t
);
}
// namespace plugin
}
// namespace api
}
// namespace xpu
}
// namespace baidu
test/ir/inference/test_xpu_fast_where_xpu_fuse_pass.py
0 → 100644
浏览文件 @
07e788f1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import
unittest
from
functools
import
partial
import
hypothesis.strategies
as
st
import
numpy
as
np
from
auto_scan_test
import
PassAutoScanTest
from
program_config
import
OpConfig
,
ProgramConfig
,
TensorConfig
class
TestFastWhereXPUFusePassOneCase0
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
cast_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition"
]},
outputs
=
{
"Out"
:
[
"cast_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
scale_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"scale_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"x"
],
"Y"
:
[
"scale_out"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"y"
],
"Y"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
ops
=
[
cast_op
,
scale_op
,
mul0_op
,
mul1_op
,
add_op
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassOneCase1
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
cast_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition"
]},
outputs
=
{
"Out"
:
[
"cast_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"x"
],
"Y"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
scale_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"scale_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"y"
],
"Y"
:
[
"scale_out"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
ops
=
[
cast_op
,
mul0_op
,
scale_op
,
mul1_op
,
add_op
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassOneCase2
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
cast_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition"
]},
outputs
=
{
"Out"
:
[
"cast_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
scale_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"scale_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast_out"
],
"Y"
:
[
"y"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
ops
=
[
cast_op
,
scale_op
,
mul0_op
,
mul1_op
,
add_op
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassOneCase3
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
cast_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition"
]},
outputs
=
{
"Out"
:
[
"cast_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
scale_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"scale_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale_out"
],
"Y"
:
[
"y"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
ops
=
[
cast_op
,
mul0_op
,
scale_op
,
mul1_op
,
add_op
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassOneCase4
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
cast_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition"
]},
outputs
=
{
"Out"
:
[
"cast_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
scale_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"scale_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"y"
],
"Y"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
ops
=
[
cast_op
,
scale_op
,
mul0_op
,
mul1_op
,
add_op
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassOneCase5
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
cast_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition"
]},
outputs
=
{
"Out"
:
[
"cast_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
scale_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast_out"
]},
outputs
=
{
"Out"
:
[
"scale_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"y"
],
"Y"
:
[
"scale_out"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
ops
=
[
cast_op
,
mul0_op
,
scale_op
,
mul1_op
,
add_op
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassCascadeCase0
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"logical_or"
,
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
# fast_where_xpu0
cast0_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition0"
]},
outputs
=
{
"Out"
:
[
"cast0_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast0_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
scale0_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast0_out"
]},
outputs
=
{
"Out"
:
[
"scale0_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale0_out"
],
"Y"
:
[
"y"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add0_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
# fast_where_xpu1
cast1_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition1"
]},
outputs
=
{
"Out"
:
[
"cast1_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul2_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast1_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul2_out"
]},
axis
=-
1
,
)
scale1_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast1_out"
]},
outputs
=
{
"Out"
:
[
"scale1_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul3_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale1_out"
],
"Y"
:
[
"add0_out"
]},
outputs
=
{
"Out"
:
[
"mul3_out"
]},
axis
=-
1
,
)
add1_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul2_out"
],
"Y"
:
[
"mul3_out"
]},
outputs
=
{
"Out"
:
[
"add1_out"
]},
axis
=-
1
,
)
ops
=
[
cast0_op
,
mul0_op
,
scale0_op
,
mul1_op
,
add0_op
,
cast1_op
,
mul2_op
,
scale1_op
,
mul3_op
,
add1_op
,
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition0"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)
),
"condition1"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)
),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
class
TestFastWhereXPUFusePassCascadeCase1
(
PassAutoScanTest
):
def
sample_predictor_configs
(
self
,
program_config
):
config
=
self
.
create_inference_config
(
use_xpu
=
True
)
yield
config
,
[
"logical_and"
,
"fast_where_xpu"
],
(
1e-3
,
1e-3
)
def
sample_program_config
(
self
,
draw
):
value_shape
=
draw
(
st
.
lists
(
st
.
integers
(
min_value
=
1
,
max_value
=
4
),
min_size
=
2
,
max_size
=
4
)
)
condition_shape
=
value_shape
condition_shape
[
-
1
]
=
1
def
generate_condition
():
return
np
.
random
.
random
(
condition_shape
).
astype
(
bool
)
def
generate_value
():
return
np
.
random
.
random
(
value_shape
).
astype
(
np
.
float32
)
# fast_where_xpu0
cast0_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition0"
]},
outputs
=
{
"Out"
:
[
"cast0_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul0_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast0_out"
],
"Y"
:
[
"x"
]},
outputs
=
{
"Out"
:
[
"mul0_out"
]},
axis
=-
1
,
)
scale0_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast0_out"
]},
outputs
=
{
"Out"
:
[
"scale0_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul1_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale0_out"
],
"Y"
:
[
"y"
]},
outputs
=
{
"Out"
:
[
"mul1_out"
]},
axis
=-
1
,
)
add0_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul0_out"
],
"Y"
:
[
"mul1_out"
]},
outputs
=
{
"Out"
:
[
"add0_out"
]},
axis
=-
1
,
)
# fast_where_xpu1
cast1_op
=
OpConfig
(
"cast"
,
inputs
=
{
"X"
:
[
"condition1"
]},
outputs
=
{
"Out"
:
[
"cast1_out"
]},
in_dtype
=
0
,
out_dtype
=
5
,
)
mul2_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"cast1_out"
],
"Y"
:
[
"add0_out"
]},
outputs
=
{
"Out"
:
[
"mul2_out"
]},
axis
=-
1
,
)
scale1_op
=
OpConfig
(
"scale"
,
inputs
=
{
"X"
:
[
"cast1_out"
]},
outputs
=
{
"Out"
:
[
"scale1_out"
]},
scale
=-
1
,
bias
=
1
,
base_after_scale
=
True
,
)
mul3_op
=
OpConfig
(
"elementwise_mul"
,
inputs
=
{
"X"
:
[
"scale1_out"
],
"Y"
:
[
"y"
]},
outputs
=
{
"Out"
:
[
"mul3_out"
]},
axis
=-
1
,
)
add1_op
=
OpConfig
(
"elementwise_add"
,
inputs
=
{
"X"
:
[
"mul2_out"
],
"Y"
:
[
"mul3_out"
]},
outputs
=
{
"Out"
:
[
"add1_out"
]},
axis
=-
1
,
)
ops
=
[
cast0_op
,
mul0_op
,
scale0_op
,
mul1_op
,
add0_op
,
cast1_op
,
mul2_op
,
scale1_op
,
mul3_op
,
add1_op
,
]
program_config
=
ProgramConfig
(
ops
=
ops
,
weights
=
{},
inputs
=
{
"condition0"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)
),
"condition1"
:
TensorConfig
(
data_gen
=
partial
(
generate_condition
)
),
"x"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
"y"
:
TensorConfig
(
data_gen
=
partial
(
generate_value
)),
},
outputs
=
ops
[
-
1
].
outputs
[
"Out"
],
)
return
program_config
def
test
(
self
):
self
.
run_and_statis
(
quant
=
False
,
max_examples
=
25
,
passes
=
[
"fast_where_xpu_fuse_pass"
],
)
if
__name__
==
"__main__"
:
np
.
random
.
seed
(
200
)
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录