Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
22c7a6eb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
22c7a6eb
编写于
8月 02, 2023
作者:
W
wz1qqx
提交者:
GitHub
8月 02, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[XPU]Add conv1d fuse pass (#55719)
上级
63b7fc80
变更
12
展开全部
隐藏空白更改
内联
并排
Showing
12 changed file
with
973 addition
and
11 deletion
+973
-11
paddle/fluid/framework/ir/CMakeLists.txt
paddle/fluid/framework/ir/CMakeLists.txt
+1
-0
paddle/fluid/framework/ir/xpu/conv1d_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/conv1d_xpu_fuse_pass.cc
+714
-0
paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
+0
-6
paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
+3
-4
paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
+1
-1
paddle/fluid/inference/api/paddle_pass_builder.cc
paddle/fluid/inference/api/paddle_pass_builder.cc
+1
-0
paddle/phi/api/yaml/fused_ops.yaml
paddle/phi/api/yaml/fused_ops.yaml
+10
-0
paddle/phi/backends/xpu/xpu2_op_list.cc
paddle/phi/backends/xpu/xpu2_op_list.cc
+3
-0
paddle/phi/infermeta/fusion.cc
paddle/phi/infermeta/fusion.cc
+93
-0
paddle/phi/infermeta/fusion.h
paddle/phi/infermeta/fusion.h
+17
-0
paddle/phi/kernels/fusion/xpu/conv1d_xpu_kernel.cc
paddle/phi/kernels/fusion/xpu/conv1d_xpu_kernel.cc
+111
-0
paddle/phi/kernels/xpu/activation_kernel.cc
paddle/phi/kernels/xpu/activation_kernel.cc
+19
-0
未找到文件。
paddle/fluid/framework/ir/CMakeLists.txt
浏览文件 @
22c7a6eb
...
...
@@ -238,6 +238,7 @@ if(WITH_XPU)
pass_library
(
cast_mixed_precision_op_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
yolo_box_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv1d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
conv2d_xpu_fuse_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
pass_library
(
redundant_onnx_ops_elimination_pass inference DIR xpu DEPS
${
XPU_PASS_DEPS
}
)
...
...
paddle/fluid/framework/ir/xpu/conv1d_xpu_fuse_pass.cc
0 → 100644
浏览文件 @
22c7a6eb
此差异已折叠。
点击以展开。
paddle/fluid/framework/ir/xpu/conv2d_xpu_fuse_pass.cc
浏览文件 @
22c7a6eb
...
...
@@ -562,12 +562,6 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
}
conv2d_xpu_op_desc
.
SetAttr
(
"act_type"
,
ConvertActivationType
(
act_type
));
conv2d_xpu_op_desc
.
SetAttr
(
"act_param"
,
act_param_
);
std
::
vector
<
int
>
conv_bias
;
if
(
has_bias
)
{
conv_bias
.
push_back
(
1
);
}
else
{
conv_bias
.
push_back
(
0
);
}
conv2d_xpu_op_desc
.
SetAttr
(
"padding_algorithm"
,
conv
->
Op
()
->
GetAttrIfExists
<
std
::
string
>
(
"padding_algorithm"
));
...
...
paddle/fluid/framework/ir/xpu/link_xpu_op_max_pass.cc
浏览文件 @
22c7a6eb
...
...
@@ -165,10 +165,9 @@ void LinkXPUOpMaxPass::LinkConv2dMax(ir::Graph* graph, bool with_branch) const {
GET_IR_NODE
(
x
);
GET_IR_NODE
(
branch
);
auto
*
fusion_op_desc
=
fusion_op
->
Op
();
if
(
fusion_op_desc
->
HasAttr
(
"has_branch"
))
{
bool
fusion_op_branch
=
PADDLE_GET_CONST
(
bool
,
fusion_op_desc
->
GetAttr
(
"has_branch"
));
if
(
fusion_op_branch
!=
with_branch
)
{
bool
fusion_op_has_branch
=
fusion_op_desc
->
HasInput
(
"branch"
);
if
(
fusion_op_has_branch
)
{
if
(
fusion_op_has_branch
!=
with_branch
)
{
return
;
}
}
...
...
paddle/fluid/framework/ir/xpu/reduce_ops_fuse_pass.cc
浏览文件 @
22c7a6eb
...
...
@@ -295,7 +295,7 @@ void ReduceOpsFusePass::FuseReduceMean(ir::Graph* graph) const {
framework
::
OpDesc
reduce_op_desc
(
block
);
reduce_op_desc
.
SetType
(
"reduce_mean"
);
reduce_op_desc
.
SetInput
(
"X"
,
{
x
->
Name
()});
reduce_op_desc
.
SetAttr
(
"dim"
,
std
::
vector
<
int
>
{
-
2
});
reduce_op_desc
.
SetAttr
(
"dim"
,
std
::
vector
<
int
>
{
-
1
});
reduce_op_desc
.
SetAttr
(
"reduce_all"
,
false
);
reduce_op_desc
.
SetAttr
(
"keep_dim"
,
true
);
reduce_op_desc
.
SetOutput
(
"Out"
,
{
squeeze2_out
->
Name
()});
...
...
paddle/fluid/inference/api/paddle_pass_builder.cc
浏览文件 @
22c7a6eb
...
...
@@ -526,6 +526,7 @@ XpuPassStrategy::XpuPassStrategy() : PassStrategy({}) {
"one_beam_size_fuse_pass"
,
"fold_interp_outsize_fuse_pass"
,
"fold_two_squeeze2_fuse_pass"
,
"conv1d_xpu_fuse_pass"
,
"redundant_onnx_ops_elimination_pass"
,
"reduce_ops_fuse_pass"
,
"delete_cast_op_pass"
,
...
...
paddle/phi/api/yaml/fused_ops.yaml
浏览文件 @
22c7a6eb
...
...
@@ -23,6 +23,16 @@
func
:
add_layernorm_xpu
data_type
:
x
-
op
:
conv1d_xpu
args
:
(Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, str padding_algorithm, int dilations, int strides, int groups, int act_type, float act_param)
output
:
Tensor(out), Tensor(out_max)
infer_meta
:
func
:
Conv1dXPUInferMeta
kernel
:
func
:
conv1d_xpu
data_type
:
x
optional
:
bias, branch, branch_max, x_max
-
op
:
conv2d_transpose_xpu
args
:
(Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, int[] strides, int[] paddings, int[] output_padding, IntArray output_size, str padding_algorithm, int groups, int[] dilations, str data_format, bool has_bias, bool with_act, str act_type)
output
:
Tensor(out), Tensor(out_max)
...
...
paddle/phi/backends/xpu/xpu2_op_list.cc
浏览文件 @
22c7a6eb
...
...
@@ -167,6 +167,8 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"conv2d"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"conv1d_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"conv2d_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"conv3d_grad"
,
...
...
@@ -261,6 +263,7 @@ XPUOpMap& get_kl2_ops() {
phi
::
DataType
::
FLOAT16
,
phi
::
DataType
::
INT64
,
phi
::
DataType
::
INT32
})},
{
"elu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"embedding_with_eltwise_add_xpu"
,
XPUKernelSet
({
phi
::
DataType
::
FLOAT32
,
phi
::
DataType
::
FLOAT16
})},
{
"empty"
,
...
...
paddle/phi/infermeta/fusion.cc
浏览文件 @
22c7a6eb
...
...
@@ -139,6 +139,99 @@ inline int ConvOutSize(int input_size,
return
output_size
;
}
void
Conv1dXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
filter
,
const
MetaTensor
&
filter_max
,
const
MetaTensor
&
bias
,
const
MetaTensor
&
branch
,
const
MetaTensor
&
branch_max
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
string
&
padding_algorithm
,
int
dilations
,
int
strides
,
int
groups
,
int
act_type
,
float
act_param
,
MetaTensor
*
out
,
MetaTensor
*
out_max
)
{
auto
in_dims
=
x
.
dims
();
auto
filter_dims
=
filter
.
dims
();
// do some checks
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
3
,
phi
::
errors
::
InvalidArgument
(
"The input of Op(Conv_xpu) should be a 3-D Tensor. But "
"received: input's dimension is %u, input's shape is [%s]."
,
in_dims
.
size
(),
in_dims
));
PADDLE_ENFORCE_EQ
(
in_dims
.
size
(),
filter_dims
.
size
(),
phi
::
errors
::
InvalidArgument
(
"The input's dimension and filter's dimension of "
"Op(Conv_xpu) should be equal. But received: the input's shape is "
"[%s], "
"the input's dimension is %d; the filter's shape is [%s], "
"the filter's dimension is %d."
,
in_dims
,
in_dims
.
size
(),
filter_dims
,
filter_dims
.
size
()));
const
auto
input_channels
=
in_dims
[
1
];
PADDLE_ENFORCE_GT
(
dilations
,
0
,
phi
::
errors
::
InvalidArgument
(
"The dilation of Op(Conv) should be larget than 0, but received "
"dilation is %d."
,
dilations
));
PADDLE_ENFORCE_EQ
(
input_channels
,
filter_dims
[
1
]
*
groups
,
phi
::
errors
::
InvalidArgument
(
"The number of input's channels should be equal to filter's channels "
"* groups for Op(Conv_xpu). But received: the input's channels is "
"%d, "
"the input's shape is [%s]; the filter's channels is %d, the "
"filter's shape is [%s]; the groups is %d. "
,
input_channels
,
in_dims
,
filter_dims
[
1
],
filter_dims
,
groups
));
PADDLE_ENFORCE_EQ
(
filter_dims
[
0
]
%
groups
,
0
,
phi
::
errors
::
InvalidArgument
(
"The number of output's channels (filter's first dimension) of "
"Op(Conv) should be divided by groups. But received: "
"the output channels is %d, the filter's shape is [%s], "
"the groups is %d."
,
filter_dims
[
0
],
filter_dims
,
groups
));
std
::
vector
<
int64_t
>
out_shape
({
in_dims
[
0
],
filter_dims
[
0
]});
out_shape
.
push_back
(
ConvOutSize
(
in_dims
[
2
],
filter_dims
[
2
],
dilations
,
paddings
[
0
],
paddings
[
1
],
strides
));
// set output and output max dims
out
->
set_dims
(
DDim
(
out_shape
.
data
(),
out_shape
.
size
()));
out
->
set_dtype
(
x
.
dtype
());
out
->
set_layout
(
x
.
layout
());
out_max
->
set_dims
(
phi
::
make_ddim
({
6
}));
}
void
Conv2dXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
filter
,
...
...
paddle/phi/infermeta/fusion.h
浏览文件 @
22c7a6eb
...
...
@@ -42,6 +42,23 @@ void AddLayernormXPUInferMeta(const MetaTensor& x,
MetaTensor
*
variance
,
MetaTensor
*
z_add
);
void
Conv1dXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
filter
,
const
MetaTensor
&
filter_max
,
const
MetaTensor
&
bias
,
const
MetaTensor
&
branch
,
const
MetaTensor
&
branch_max
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
string
&
padding_algorithm
,
int
dilations
,
int
strides
,
int
groups
,
int
act_type
,
float
act_param
,
MetaTensor
*
out
,
MetaTensor
*
out_max
);
void
Conv2dXPUInferMeta
(
const
MetaTensor
&
x
,
const
MetaTensor
&
x_max
,
const
MetaTensor
&
filter
,
...
...
paddle/phi/kernels/fusion/xpu/conv1d_xpu_kernel.cc
0 → 100644
浏览文件 @
22c7a6eb
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "glog/logging.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/cpu/conv_util.h"
namespace
phi
{
namespace
fusion
{
template
<
typename
T
,
typename
Context
>
void
Conv1dXPUKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
paddle
::
optional
<
DenseTensor
>&
x_max
,
const
DenseTensor
&
filter
,
const
DenseTensor
&
filter_max
,
const
paddle
::
optional
<
DenseTensor
>&
bias
,
const
paddle
::
optional
<
DenseTensor
>&
branch
,
const
paddle
::
optional
<
DenseTensor
>&
branch_max
,
const
std
::
vector
<
int
>&
paddings
,
const
std
::
string
&
padding_algorithm
,
int
dilations
,
int
strides
,
int
groups
,
int
act_type
,
float
act_param
,
DenseTensor
*
out
,
DenseTensor
*
out_max
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
auto
input_dims
=
x
.
dims
();
auto
filter_dims
=
filter
.
dims
();
int
batch
=
static_cast
<
int
>
(
input_dims
[
0
]);
int
in_c
=
static_cast
<
int
>
(
input_dims
[
1
]);
int
in_xw
=
static_cast
<
int
>
(
input_dims
[
2
]);
int
out_c
=
static_cast
<
int
>
(
filter_dims
[
0
]);
int
ksize_w
=
static_cast
<
int
>
(
filter_dims
[
2
]);
std
::
vector
<
int64_t
>
paddings_vec
(
std
::
begin
(
paddings
),
std
::
end
(
paddings
));
auto
*
input_data
=
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
());
const
float
*
input_max_data
=
x_max
.
get_ptr
()
==
nullptr
?
nullptr
:
x_max
.
get_ptr
()
->
data
<
float
>
();
auto
*
filter_data
=
filter
.
data
<
int16_t
>
();
auto
*
filter_max_data
=
filter_max
.
data
<
float
>
();
auto
*
branch_data
=
branch
.
get_ptr
()
==
nullptr
?
nullptr
:
reinterpret_cast
<
const
XPUType
*>
(
branch
.
get_ptr
()
->
data
<
T
>
());
const
float
*
branch_max_data
=
branch_max
.
get_ptr
()
==
nullptr
?
nullptr
:
branch_max
.
get_ptr
()
->
data
<
float
>
();
const
float
*
bias_data
=
bias
.
get_ptr
()
==
nullptr
?
nullptr
:
bias
.
get_ptr
()
->
data
<
float
>
();
auto
*
out_data
=
reinterpret_cast
<
XPUType
*>
(
ctx
.
template
Alloc
<
T
>(
out
));
auto
*
out_max_data
=
ctx
.
template
Alloc
<
float
>(
out_max
);
xpu
::
Activation_t
act
(
static_cast
<
xpu
::
Activation_t
::
act_enum
>
(
act_type
));
if
(
act_type
==
xpu
::
Activation_t
::
LEAKY_RELU
)
{
act
.
leaky_alpha
=
act_param
;
}
else
if
(
act_type
==
xpu
::
Activation_t
::
HARD_SIGMOID
)
{
act
.
hard_sigmoid_slope
=
act_param
;
}
int
r
=
xpu
::
conv1d_fusion
<
XPUType
,
int16_t
,
XPUType
,
int16_t
>
(
// TX/TW/TY/TGEMM
/* baidu::xpu::api::Context* ctx */
ctx
.
x_context
(),
/* const TX* x */
input_data
,
/* const TW* weight */
filter_data
,
/* TY* y */
out_data
,
/* int64_t n */
batch
,
/* int64_t c */
in_c
,
/* int64_t xw */
in_xw
,
/* int64_t f */
out_c
,
/* int64_t ksize_w */
ksize_w
,
/* int64_t stride_w */
strides
,
/* const std::vector<int64_t>& pad */
paddings_vec
,
/* int64_t dilation_w */
dilations
,
/* int64_t group */
groups
,
/* const float* x_maxptr */
input_max_data
,
/* const float* w_maxptr */
filter_max_data
,
/* float* y_maxptr */
out_max_data
,
/* bool is_nchw */
true
,
/* const float* bias */
bias_data
,
/* const TY* branch */
branch_data
,
/* const baidu::xpu::api::Activation_t& act */
act
,
/* const float* branch_maxptr */
branch_max_data
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"conv1d_xpu"
);
}
}
// namespace fusion
}
// namespace phi
PD_REGISTER_KERNEL
(
conv1d_xpu
,
XPU
,
ALL_LAYOUT
,
phi
::
fusion
::
Conv1dXPUKernel
,
float
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/xpu/activation_kernel.cc
浏览文件 @
22c7a6eb
...
...
@@ -415,6 +415,23 @@ void SwishKernel(const Context& dev_ctx,
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"swish"
);
}
template
<
typename
T
,
typename
Context
>
void
EluKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
float
alpha
,
DenseTensor
*
out
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
dev_ctx
.
template
Alloc
<
T
>(
out
);
// template<typename T> int elu(Context* ctx, const T* x, T* y, int64_t len,
// float alpha = 1.0f, const float* max_x = nullptr, float* max_y = nullptr)
int
r
=
xpu
::
elu
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
x
.
numel
(),
alpha
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"elu"
);
}
template
<
typename
T
,
typename
Context
>
void
Relu6Kernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
...
...
@@ -545,6 +562,8 @@ PD_REGISTER_KERNEL(
relu
,
XPU
,
ALL_LAYOUT
,
phi
::
ReluKernel
,
float
,
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
silu
,
XPU
,
ALL_LAYOUT
,
phi
::
SiluKernel
,
float
,
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
elu
,
XPU
,
ALL_LAYOUT
,
phi
::
EluKernel
,
float
,
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
sigmoid
,
XPU
,
ALL_LAYOUT
,
phi
::
SigmoidKernel
,
float
,
phi
::
dtype
::
float16
)
{}
PD_REGISTER_KERNEL
(
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录