Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
07e788f1
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
07e788f1
编写于
8月 01, 2023
作者:
H
hong19860320
提交者:
GitHub
8月 01, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU] Add fast_where fusion op and XPU micro kernel (#55628)
上级
744e1eaf
变更
14
展开全部
显示空白变更内容
内联
并排
Showing
14 changed file
with
2139 addition
and
1 deletion
+2139
-1
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+5
-0
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass.cc
+658
-0
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass_test.cc
...e/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass_test.cc
+304
-0
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+9
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+4
-0
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+8
-0
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+6
-0
paddle/phi/kernels/fusion/xpu/fast_where_xpu_kernel.cc
paddle/phi/kernels/fusion/xpu/fast_where_xpu_kernel.cc
+81
-0
paddle/phi/kernels/xpu/plugin/CMakeLists.txt
paddle/phi/kernels/xpu/plugin/CMakeLists.txt
+1
-1
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
+7
-0
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_where.xpu
...i/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_where.xpu
+191
-0
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_where.cpp
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_where.cpp
+128
-0
test/ir/inference/test_xpu_fast_where_xpu_fuse_pass.py
test/ir/inference/test_xpu_fast_where_xpu_fuse_pass.py
+736
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
07e788f1
...
...
@@ -280,6 +280,7 @@ if(WITH_XPU)
pass_library
(
matmul_weight_trans_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
reshape2_matmul_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
fast_where_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
endif
()
cc_library
(
...
...
@@ -599,4 +600,8 @@ if(WITH_XPU)
test_reshape2_matmul_xpu_fuse_pass
SRCS xpu/reshape2_matmul_xpu_fuse_pass_test.cc
DEPS reshape2_matmul_xpu_fuse_pass
)
cc_test
(
test_fast_where_xpu_fuse_pass
SRCS xpu/fast_where_xpu_fuse_pass_test.cc
DEPS fast_where_xpu_fuse_pass
)
endif
()
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass.cc
0 → 100644
浏览文件 @
07e788f1
此差异已折叠。
点击以展开。
paddle/fluid/framework/ir/xpu/fast_where_xpu_fuse_pass_test.cc
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace
paddle
{
namespace
framework
{
namespace
ir
{
#define APPLY_PASS \
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program())); \
auto pass = PassRegistry::Instance().Get("fast_where_xpu_fuse_pass"); \
pass->Apply(graph.get());
#define VERIFY_GRAPH(x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only one op node, but %d op nodes found.", \
num_op_nodes)); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST
(
FastWhereXPUFusePass
,
one_case0
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
x
,
scale_out
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
cast_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
y
,
x
)
}
TEST
(
FastWhereXPUFusePass
,
one_case1
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
x
,
cast_out
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
scale_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
x
,
y
)
}
TEST
(
FastWhereXPUFusePass
,
one_case2
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
scale_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
cast_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
y
,
x
)
}
TEST
(
FastWhereXPUFusePass
,
one_case3
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
scale_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
x
,
y
)
}
TEST
(
FastWhereXPUFusePass
,
one_case4
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
scale_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
cast_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
y
,
x
)
}
TEST
(
FastWhereXPUFusePass
,
one_case5
)
{
Layers
layers
;
auto
*
condition
=
layers
.
data
(
"condition"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
auto
*
cast_out
=
layers
.
cast
(
condition
,
0
,
5
);
cast_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale_out
=
layers
.
scale
(
cast_out
,
-
1.0
f
,
1.0
f
,
true
);
scale_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
y
,
scale_out
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
x
,
y
)
}
#undef VERIFY_GRAPH
#define VERIFY_GRAPH(logical_op, x, y) \
auto num_op_nodes = GetNumOpNodes(graph); \
PADDLE_ENFORCE_EQ( \
num_op_nodes, \
2, \
platform::errors::PreconditionNotMet( \
"The graph contains only two op nodes, but %d op nodes found.", \
num_op_nodes)); \
auto logical_op_nodes = GetOpNodes(graph, #logical_op); \
PADDLE_ENFORCE_EQ( \
logical_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a '%s' op node, but %d op nodes found.", \
#logical_op, \
logical_op_nodes.size())); \
auto fast_where_xpu_op_nodes = GetOpNodes(graph, "fast_where_xpu"); \
PADDLE_ENFORCE_EQ(fast_where_xpu_op_nodes.size(), \
1, \
platform::errors::PreconditionNotMet( \
"The graph contains only a fast_where_xpu op node, " \
"but %d op nodes found.", \
fast_where_xpu_op_nodes.size())); \
const auto& x_name = fast_where_xpu_op_nodes[0]->Op()->Input("x")[0]; \
PADDLE_ENFORCE_EQ(x_name, \
#x, \
platform::errors::PreconditionNotMet( \
"The input 'x' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#x, \
x_name)); \
const auto& y_name = fast_where_xpu_op_nodes[0]->Op()->Input("y")[0]; \
PADDLE_ENFORCE_EQ(y_name, \
#y, \
platform::errors::PreconditionNotMet( \
"The input 'y' of fast_where_xpu op should be '%s', " \
"but receive '%s'.", \
#y, \
y_name));
TEST
(
FastWhereXPUFusePass
,
cascade_case0
)
{
Layers
layers
;
auto
*
condition0
=
layers
.
data
(
"condition0"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
condition1
=
layers
.
data
(
"condition1"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
// fast_where_xpu0
auto
*
cast0_out
=
layers
.
cast
(
condition0
,
0
,
5
);
cast0_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast0_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale0_out
=
layers
.
scale
(
cast0_out
,
-
1.0
f
,
1.0
f
,
true
);
scale0_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
scale0_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add0_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add0_out
->
SetShape
({
20
,
7
});
// fast_where_xpu1
auto
*
cast1_out
=
layers
.
cast
(
condition1
,
0
,
5
);
cast1_out
->
SetShape
({
20
,
1
});
auto
*
mul2_out
=
layers
.
elementwise_mul
(
cast1_out
,
x
);
mul2_out
->
SetShape
({
20
,
7
});
auto
*
scale1_out
=
layers
.
scale
(
cast1_out
,
-
1.0
f
,
1.0
f
,
true
);
scale1_out
->
SetShape
({
20
,
1
});
auto
*
mul3_out
=
layers
.
elementwise_mul
(
scale1_out
,
add0_out
);
mul3_out
->
SetShape
({
20
,
7
});
auto
*
add1_out
=
layers
.
elementwise_add
(
mul2_out
,
mul3_out
);
add1_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
logical_or
,
x
,
y
)
}
TEST
(
FastWhereXPUFusePass
,
cascade_case1
)
{
Layers
layers
;
auto
*
condition0
=
layers
.
data
(
"condition0"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
condition1
=
layers
.
data
(
"condition1"
,
{
20
,
1
},
false
,
proto
::
VarType
::
BOOL
);
auto
*
x
=
layers
.
data
(
"x"
,
{
20
,
7
});
auto
*
y
=
layers
.
data
(
"y"
,
{
20
,
7
});
// fast_where_xpu0
auto
*
cast0_out
=
layers
.
cast
(
condition0
,
0
,
5
);
cast0_out
->
SetShape
({
20
,
1
});
auto
*
mul0_out
=
layers
.
elementwise_mul
(
cast0_out
,
x
);
mul0_out
->
SetShape
({
20
,
7
});
auto
*
scale0_out
=
layers
.
scale
(
cast0_out
,
-
1.0
f
,
1.0
f
,
true
);
scale0_out
->
SetShape
({
20
,
1
});
auto
*
mul1_out
=
layers
.
elementwise_mul
(
scale0_out
,
y
);
mul1_out
->
SetShape
({
20
,
7
});
auto
*
add0_out
=
layers
.
elementwise_add
(
mul0_out
,
mul1_out
);
add0_out
->
SetShape
({
20
,
7
});
// fast_where_xpu1
auto
*
cast1_out
=
layers
.
cast
(
condition1
,
0
,
5
);
cast1_out
->
SetShape
({
20
,
1
});
auto
*
mul2_out
=
layers
.
elementwise_mul
(
cast1_out
,
add0_out
);
mul2_out
->
SetShape
({
20
,
7
});
auto
*
scale1_out
=
layers
.
scale
(
cast1_out
,
-
1.0
f
,
1.0
f
,
true
);
scale1_out
->
SetShape
({
20
,
1
});
auto
*
mul3_out
=
layers
.
elementwise_mul
(
scale1_out
,
y
);
mul3_out
->
SetShape
({
20
,
7
});
auto
*
add1_out
=
layers
.
elementwise_add
(
mul2_out
,
mul3_out
);
add1_out
->
SetShape
({
20
,
7
});
APPLY_PASS
VERIFY_GRAPH
(
logical_and
,
x
,
y
)
}
#undef APPLY_PASS
#undef VERIFY_GRAPH
}
// namespace ir
}
// namespace framework
}
// namespace paddle
USE_PASS
(
fast_where_xpu_fuse_pass
);
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
07e788f1
...
...
@@ -545,6 +545,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"add_activation_xpu_fuse_pass"
,
"add_layernorm_xpu_fuse_pass"
,
"yolo_box_xpu_fuse_pass"
,
"fast_where_xpu_fuse_pass"
,
"link_xpu_op_max_pass"
,
"delete_isolated_node_pass"
,
// "auto_mixed_precision_pass",
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
07e788f1
...
...
@@ -53,6 +53,15 @@
data_type
:
tables
optional
:
mask, seq_lod, max_seq_len
-
op
:
fast_where_xpu
args
:
(Tensor condition, Tensor x, Tensor y)
output
:
Tensor(out)
infer_meta
:
func
:
FastWhereXPUInferMeta
kernel
:
func
:
fast_where_xpu
data_type
:
x
-
op
:
fc_xpu
args
:
(Tensor x, Tensor x_max, Tensor w, Tensor w_max, Tensor bias, int in_num_col_dims, bool transpose_x, float alpha, float beta, int act_type, float act_alpha, DataType out_dtype)
output
:
Tensor(out), Tensor(out_max)
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
07e788f1
...
...
@@ -295,6 +295,10 @@ XPUOpMap& get_kl2_ops() {
phi
::
DataType
::
BOOL
,
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
FLOAT32
})},
{
"fast_where_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
INT32
,
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"fc_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"fill"
,
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
07e788f1
...
...
@@ -721,4 +721,12 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
out_max
);
}
void
FastWhereXPUInferMeta
(
const
MetaTensor
&
condition
,
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
)
{
out
->
set_dims
(
x
.
dims
());
out
->
set_dtype
(
x
.
dtype
());
}
}
// namespace phi
paddle/phi/infermeta/fusion.h
浏览文件 @
07e788f1
...
...
@@ -175,4 +175,10 @@ void Conv2dTransposeXPUInferMeta(const MetaTensor& x,
const
std
::
string
&
act_type
,
MetaTensor
*
out
,
MetaTensor
*
out_max
);
void
FastWhereXPUInferMeta
(
const
MetaTensor
&
condition
,
const
MetaTensor
&
x
,
const
MetaTensor
&
y
,
MetaTensor
*
out
);
}
// namespace phi
paddle/phi/kernels/fusion/xpu/fast_where_xpu_kernel.cc
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "glog/logging.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
FastWhereXPUKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
condition
,
const
DenseTensor
&
x
,
const
DenseTensor
&
y
,
DenseTensor
*
out
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
auto
*
condition_data
=
condition
.
data
<
bool
>
();
auto
*
x_data
=
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
());
auto
*
y_data
=
reinterpret_cast
<
const
XPUType
*>
(
y
.
data
<
T
>
());
auto
*
out_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
auto
condition_dims
=
phi
::
vectorize
<
int
>
(
condition
.
dims
());
auto
x_dims
=
phi
::
vectorize
<
int
>
(
x
.
dims
());
auto
y_dims
=
phi
::
vectorize
<
int
>
(
y
.
dims
());
PADDLE_ENFORCE_EQ
(
x_dims
,
y_dims
,
errors
::
PreconditionNotMet
(
"The dimensions of inputs should be equal, but x_dims=["
,
x
.
dims
(),
"] and y_dims=["
,
y
.
dims
(),
"]"
));
#ifndef PADDLE_WITH_XPU_PLUGIN
LOG
(
WARNING
)
<<
"Add -DWITH_XPU_PLUGIN=ON to build xpu::plugin::fast_where(), or use "
"xpu::select() instead, which leads low performance."
;
int
r
=
xpu
::
select
<
XPUType
>
(
ctx
.
x_context
(),
condition_data
,
x_data
,
y_data
,
out_data
,
condition_dims
,
x_dims
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"select"
);
#else
xpu
::
ctx_guard
RAII_GUARD
(
ctx
.
x_context
());
if
(
condition_dims
!=
x_dims
)
{
bool
*
temp_data
=
RAII_GUARD
.
alloc_l3_or_gm
<
bool
>
(
x
.
numel
());
int
r
=
xpu
::
broadcast
<
bool
>
(
ctx
.
x_context
(),
condition_data
,
temp_data
,
condition_dims
,
x_dims
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"broadcast"
);
condition_data
=
temp_data
;
}
int
r
=
xpu
::
plugin
::
fast_where
<
XPUType
>
(
ctx
.
x_context
(),
condition_data
,
x_data
,
y_data
,
out_data
,
x
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"fast_where"
);
#endif
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
fast_where_xpu
,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
FastWhereXPUKernel
,
float
,
phi
::
dtype
::
float16
,
int
)
{}
paddle/phi/kernels/xpu/plugin/CMakeLists.txt
浏览文件 @
07e788f1
...
...
@@ -154,7 +154,7 @@ macro(
${
kernel_path
}
-D
${
xpu_n_macro
}
--target=
${
TARGET_ARCH
}
${
HOST_XPU_FLAGS
}
--basename
${
kernel_name
}
-fno-builtin --xpu-arch=
${
xpu_n
}
-fPIC
-Wno-int-to-void-pointer-cast -Wno-int-to-pointer-cast -Werror -mllvm
--xpu-inline-cost -mllvm --xpu-inline-hot-call
--xpu-inline-cost -mllvm --xpu-inline-hot-call
-I
${
XDNN_INC_DIR
}
-I
${
CMAKE_CURRENT_SOURCE_DIR
}
/include -I
${
CMAKE_CURRENT_SOURCE_DIR
}
/src
-I
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/kernel
-I
${
CMAKE_CURRENT_SOURCE_DIR
}
/src/kernel/include
${
arg_rule
}
...
...
paddle/phi/kernels/xpu/plugin/include/xpu/plugin.h
浏览文件 @
07e788f1
...
...
@@ -24,6 +24,13 @@ namespace api {
namespace
plugin
{
DLL_EXPORT
int
add2
(
Context
*
ctx
,
const
float
*
x
,
float
*
y
,
int
len
);
template
<
typename
T
>
DLL_EXPORT
int
fast_where
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
out
,
int64_t
len
);
}
// namespace plugin
}
// namespace api
...
...
paddle/phi/kernels/xpu/plugin/src/kernel/kunlun2cpp/fast_where.xpu
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/kernel/cluster.h"
#include "xpu/kernel/cluster_partition.h"
#include "xpu/kernel/cluster_primitive.h"
namespace xpu2 {
namespace plugin {
#define CALC_MASK(offset) \
mask |= static_cast<int>(condition[i + offset]) << offset;
static __device__ inline void do_select_16(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
int len_rounddown32 = rounddown32(len);
for (int i = 0; i < len_rounddown32; i += 32) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
CALC_MASK(16)
CALC_MASK(17)
CALC_MASK(18)
CALC_MASK(19)
CALC_MASK(20)
CALC_MASK(21)
CALC_MASK(22)
CALC_MASK(23)
CALC_MASK(24)
CALC_MASK(25)
CALC_MASK(26)
CALC_MASK(27)
CALC_MASK(28)
CALC_MASK(29)
CALC_MASK(30)
CALC_MASK(31)
vstore_lm_int16x32_mh(y + i, vload_lm_int16x32(x + i), mask);
}
for (int i = len_rounddown32; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
static __device__ inline void do_select_32(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
int len_rounddown16 = rounddown16(len);
for (int i = 0; i < len_rounddown16; i += 16) {
int mask = condition[i];
CALC_MASK(1)
CALC_MASK(2)
CALC_MASK(3)
CALC_MASK(4)
CALC_MASK(5)
CALC_MASK(6)
CALC_MASK(7)
CALC_MASK(8)
CALC_MASK(9)
CALC_MASK(10)
CALC_MASK(11)
CALC_MASK(12)
CALC_MASK(13)
CALC_MASK(14)
CALC_MASK(15)
vstore_lm_int32x16_mh(y + i, vload_lm_int32x16(x + i), mask);
}
for (int i = len_rounddown16; i < len; i++) {
y[i] = condition[i] ? x[i] : y[i];
}
mfence_lm();
}
template <typename T>
static __device__ void do_select(const int8_t* condition,
const T* x,
T* y,
int len) {}
template <>
__device__ void do_select<float16>(const int8_t* condition,
const float16* x,
float16* y,
int len) {
do_select_16(condition,
reinterpret_cast<const int16_t*>(x),
reinterpret_cast<int16_t*>(y),
len);
}
template <>
__device__ void do_select<float>(const int8_t* condition,
const float* x,
float* y,
int len) {
do_select_32(condition,
reinterpret_cast<const int32_t*>(x),
reinterpret_cast<int32_t*>(y),
len);
}
template <>
__device__ void do_select<int16_t>(const int8_t* condition,
const int16_t* x,
int16_t* y,
int len) {
do_select_16(condition, x, y, len);
}
template <>
__device__ void do_select<int32_t>(const int8_t* condition,
const int32_t* x,
int32_t* y,
int len) {
do_select_32(condition, x, y, len);
}
template <typename T>
__global__ void fast_where(
const int8_t* condition, const T* x, const T* y, T* z, int64_t len) {
int tid = core_id() * cluster_num() + cluster_id();
int nthreads = core_num() * cluster_num();
#ifdef __XPU3__
const int buf_len = 1536 / sizeof(T);
#else
const int buf_len = 512 / sizeof(T);
#endif
__simd__ int8_t local_condition[buf_len];
__simd__ T local_x[buf_len];
__simd__ T local_y[buf_len];
int loop = 0;
for (int64_t i = tid * buf_len; i < len; i += nthreads * buf_len) {
int read_len = min(static_cast<int64_t>(buf_len), len - i);
GM2LM_ASYNC(condition + i, local_condition, read_len * sizeof(int8_t));
GM2LM_ASYNC(x + i, local_x, read_len * sizeof(T));
GM2LM(y + i, local_y, read_len * sizeof(T));
do_select<T>(local_condition, local_x, local_y, read_len);
LM2GM_ASYNC(local_y, z + i, read_len * sizeof(T));
mfence();
#ifndef __XPU3__
loop++;
if ((loop & 0xF) == 0) {
sync_all();
}
#endif
}
}
#define _XPU_DEF__FAST_WHERE_(DTYPE) \
template __global__ void fast_where<DTYPE>(const int8_t* condition, \
const DTYPE* x, \
const DTYPE* y, \
DTYPE* z, \
int64_t len);
_XPU_DEF__FAST_WHERE_(float16);
_XPU_DEF__FAST_WHERE_(float);
_XPU_DEF__FAST_WHERE_(int16_t);
_XPU_DEF__FAST_WHERE_(int32_t);
} // namespace plugin
} // namespace xpu2
paddle/phi/kernels/xpu/plugin/src/wrapper/fast_where.cpp
0 → 100644
浏览文件 @
07e788f1
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
/*
* copyright (C) 2022 KUNLUNXIN, Inc
*/
#include "xpu/plugin.h"
#include "xpu/refactor/impl_public/wrapper_check.h"
namespace
xpu2
{
namespace
plugin
{
template
<
typename
T
>
__attribute__
((
global
))
void
fast_where
(
const
int8_t
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
);
}
}
// namespace xpu2
namespace
baidu
{
namespace
xpu
{
namespace
api
{
namespace
plugin
{
template
<
typename
T
>
static
int
cpu_wrapper
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
)
{
for
(
int64_t
i
=
0
;
i
<
len
;
i
++
)
{
z
[
i
]
=
condition
[
i
]
?
x
[
i
]
:
y
[
i
];
}
return
SUCCESS
;
}
template
<
>
int
cpu_wrapper
<
float16
>
(
Context
*
ctx
,
const
bool
*
condition
,
const
float16
*
x
,
const
float16
*
y
,
float16
*
z
,
int64_t
len
)
{
std
::
vector
<
float
>
x_fp32
(
len
);
std
::
vector
<
float
>
y_fp32
(
len
);
std
::
vector
<
float
>
z_fp32
(
len
);
int
ret
=
cast
<
float16
,
float
>
(
ctx
,
x
,
x_fp32
.
data
(),
len
);
ret
=
cast
<
float16
,
float
>
(
ctx
,
y
,
y_fp32
.
data
(),
len
);
ret
=
cpu_wrapper
<
float
>
(
ctx
,
condition
,
x_fp32
.
data
(),
y_fp32
.
data
(),
z_fp32
.
data
(),
len
);
ret
=
cast
<
float
,
float16
>
(
ctx
,
z_fp32
.
data
(),
z
,
len
);
WRAPPER_ASSERT_SUCCESS
(
ctx
,
ret
);
return
ret
;
}
template
<
typename
T
>
static
int
xpu2_wrapper
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
)
{
xpu2
::
plugin
::
fast_where
<
T
><<<
ctx
->
ncluster
(),
64
,
ctx
->
xpu_stream
>>>
(
reinterpret_cast
<
const
int8_t
*>
(
condition
),
x
,
y
,
z
,
len
);
return
SUCCESS
;
}
template
<
typename
T
>
int
fast_where
(
Context
*
ctx
,
const
bool
*
condition
,
const
T
*
x
,
const
T
*
y
,
T
*
z
,
int64_t
len
)
{
WRAPPER_CHECK_CTX
(
ctx
);
WRAPPER_DUMP_FUNCTION_T1
(
ctx
,
"fast_where"
,
float
);
WRAPPER_DUMP_PARAM5
(
ctx
,
condition
,
x
,
y
,
z
,
len
);
WRAPPER_DUMP
(
ctx
);
WRAPPER_ASSERT_GT
(
ctx
,
len
,
0
);
WRAPPER_CHECK_2PTRS
(
ctx
,
T
,
len
,
x
,
y
);
if
(
ctx
->
dev
().
type
()
==
api
::
kCPU
)
{
return
cpu_wrapper
<
T
>
(
ctx
,
condition
,
x
,
y
,
z
,
len
);
}
if
(
ctx
->
dev
().
type
()
==
api
::
kXPU2
)
{
return
xpu2_wrapper
<
T
>
(
ctx
,
condition
,
x
,
y
,
z
,
len
);
}
WRAPPER_UNIMPLEMENTED
(
ctx
);
}
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
float
*
,
const
float
*
,
float
*
,
int64_t
);
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
float16
*
,
const
float16
*
,
float16
*
,
int64_t
);
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
int16_t
*
,
const
int16_t
*
,
int16_t
*
,
int64_t
);
template
int
fast_where
(
Context
*
,
const
bool
*
condition
,
const
int32_t
*
,
const
int32_t
*
,
int32_t
*
,
int64_t
);
}
// namespace plugin
}
// namespace api
}
// namespace xpu
}
// namespace baidu
test/ir/inference/test_xpu_fast_where_xpu_fuse_pass.py
0 → 100644
浏览文件 @
07e788f1
此差异已折叠。
点击以展开。
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录