Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
wjd2002
Ncnn
提交
8c40a592
N
Ncnn
项目概览
wjd2002
/
Ncnn
10 个月 前同步成功
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
N
Ncnn
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
未验证
提交
8c40a592
编写于
6月 20, 2023
作者:
N
nihui
提交者:
GitHub
6月 20, 2023
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pnnx insert reshape for ncnn global pooling (#4812)
上级
9022b716
变更
5
隐藏空白更改
内联
并排
Showing 5 changed files with 331 additions and 1 deletion
+331
-1
tools/pnnx/src/CMakeLists.txt
tools/pnnx/src/CMakeLists.txt
+1
-0
tools/pnnx/src/pass_ncnn.cpp
tools/pnnx/src/pass_ncnn.cpp
+2
-0
tools/pnnx/src/pass_ncnn/Tensor_repeat.cpp
tools/pnnx/src/pass_ncnn/Tensor_repeat.cpp
+35
-1
tools/pnnx/src/pass_ncnn/insert_reshape_global_pooling.cpp
tools/pnnx/src/pass_ncnn/insert_reshape_global_pooling.cpp
+268
-0
tools/pnnx/src/pass_ncnn/insert_reshape_global_pooling.h
tools/pnnx/src/pass_ncnn/insert_reshape_global_pooling.h
+25
-0
未找到文件。
tools/pnnx/src/CMakeLists.txt
浏览文件 @
8c40a592
...
@@ -384,6 +384,7 @@ set(pnnx_pass_ncnn_SRCS
...
@@ -384,6 +384,7 @@ set(pnnx_pass_ncnn_SRCS
pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp
pass_ncnn/insert_reshape_numpy_binaryop_broadcast.cpp
pass_ncnn/insert_reshape_linear.cpp
pass_ncnn/insert_reshape_linear.cpp
pass_ncnn/insert_reshape_pooling.cpp
pass_ncnn/insert_reshape_pooling.cpp
pass_ncnn/insert_reshape_global_pooling.cpp
pass_ncnn/F_adaptive_avg_pool1d.cpp
pass_ncnn/F_adaptive_avg_pool1d.cpp
pass_ncnn/F_adaptive_avg_pool2d.cpp
pass_ncnn/F_adaptive_avg_pool2d.cpp
...
...
tools/pnnx/src/pass_ncnn.cpp
浏览文件 @
8c40a592
...
@@ -46,6 +46,7 @@
...
@@ -46,6 +46,7 @@
#include "pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h"
#include "pass_ncnn/insert_reshape_numpy_binaryop_broadcast.h"
#include "pass_ncnn/insert_reshape_linear.h"
#include "pass_ncnn/insert_reshape_linear.h"
#include "pass_ncnn/insert_reshape_pooling.h"
#include "pass_ncnn/insert_reshape_pooling.h"
#include "pass_ncnn/insert_reshape_global_pooling.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/dead_code_elimination.h"
#include "pass_level4/canonicalize.h"
#include "pass_level4/canonicalize.h"
...
@@ -89,6 +90,7 @@ void pass_ncnn(Graph& g)
...
@@ -89,6 +90,7 @@ void pass_ncnn(Graph& g)
ncnn
::
insert_reshape_numpy_binaryop_broadcast
(
g
);
ncnn
::
insert_reshape_numpy_binaryop_broadcast
(
g
);
ncnn
::
insert_reshape_pooling
(
g
);
ncnn
::
insert_reshape_pooling
(
g
);
ncnn
::
insert_reshape_linear
(
g
);
ncnn
::
insert_reshape_linear
(
g
);
ncnn
::
insert_reshape_global_pooling
(
g
);
ncnn
::
fuse_convert_shufflechannel_slice
(
g
);
ncnn
::
fuse_convert_shufflechannel_slice
(
g
);
...
...
tools/pnnx/src/pass_ncnn/Tensor_repeat.cpp
浏览文件 @
8c40a592
...
@@ -45,7 +45,41 @@ pnnx.Output output 1 0 out
...
@@ -45,7 +45,41 @@ pnnx.Output output 1 0 out
{
{
const
std
::
vector
<
int
>&
sizes
=
captured_params
.
at
(
"sizes"
).
ai
;
const
std
::
vector
<
int
>&
sizes
=
captured_params
.
at
(
"sizes"
).
ai
;
op
->
params
[
"2"
]
=
sizes
;
const
int
batch_index
=
op
->
outputs
[
0
]
->
params
[
"__batch_index"
].
i
;
if
(
batch_index
!=
0
&&
batch_index
!=
233
)
{
fprintf
(
stderr
,
"repeat tensor with batch index %d is not supported yet!
\n
"
,
batch_index
);
}
// drop sizes batch index
std
::
vector
<
int
>
new_sizes
;
for
(
int
i
=
0
;
i
<
(
int
)
sizes
.
size
();
i
++
)
{
if
(
i
==
batch_index
&&
sizes
[
i
]
==
1
)
continue
;
new_sizes
.
push_back
(
sizes
[
i
]);
}
if
(
new_sizes
.
size
()
==
5
&&
batch_index
==
233
)
{
if
(
new_sizes
[
0
]
==
1
)
{
fprintf
(
stderr
,
"assume repeat 5-rank tensor has batch_index 0
\n
"
);
new_sizes
.
erase
(
new_sizes
.
begin
());
}
}
const
int
sizes_rank
=
(
int
)
new_sizes
.
size
();
if
(
sizes_rank
>
5
)
{
fprintf
(
stderr
,
"repeat to %d-rank tensor is not supported yet!
\n
"
,
sizes_rank
);
return
;
}
op
->
params
[
"2"
]
=
new_sizes
;
}
}
};
};
...
...
tools/pnnx/src/pass_ncnn/insert_reshape_global_pooling.cpp
0 → 100644
浏览文件 @
8c40a592
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.
#include "insert_reshape_global_pooling.h"
#include <algorithm>
#include <set>
namespace
pnnx
{
namespace
ncnn
{
// Report whether op is an operator known to behave the same for a flattened
// global-pooling result: it produces equivalent output whether the input is
// shaped (1,c,1,1,1)/(1,c,1,1)/(1,c,1) or (1,c).  Covers elementwise
// activations and math ops plus conv/linear/batch-norm style layers.
static bool is_known_operator_handle_flatten_0(const Operator* op)
{
    static const std::set<std::string> known_flatten_safe_ops = {
        "F.batch_norm",
        "F.celu",
        "F.conv1d",
        "F.conv2d",
        "F.conv3d",
        "F.elu",
        "F.gelu",
        "F.glu",
        "F.hardshrink",
        "F.hardsigmoid",
        "F.hardswish",
        "F.hardtanh",
        "F.leaky_relu",
        "F.linear",
        "F.log_softmax",
        "F.logsigmoid",
        "F.prelu",
        "F.relu",
        "F.relu6",
        "F.rrelu",
        "F.selu",
        "F.sigmoid",
        "F.silu",
        "F.softmax",
        "F.softmin",
        "F.softplus",
        "F.softshrink",
        "F.softsign",
        "F.tanh",
        "F.tanhshrink",
        "F.threshold",
        "nn.BatchNorm1d",
        "nn.BatchNorm2d",
        "nn.BatchNorm3d",
        "nn.CELU",
        "nn.Conv1d",
        "nn.Conv2d",
        "nn.Conv3d",
        "nn.ELU",
        "nn.GELU",
        "nn.GLU",
        "nn.Hardshrink",
        "nn.Hardsigmoid",
        "nn.Hardswish",
        "nn.Hardtanh",
        "nn.LeakyReLU",
        "nn.Linear",
        "nn.LogSigmoid",
        "nn.LogSoftmax",
        "nn.PReLU",
        "nn.ReLU",
        "nn.ReLU6",
        "nn.RReLU",
        "nn.SELU",
        "nn.Sigmoid",
        "nn.SiLU",
        "nn.Softmax",
        "nn.Softmin",
        "nn.Softplus",
        "nn.Softshrink",
        "nn.Softsign",
        "nn.Tanh",
        "nn.Tanhshrink",
        "nn.Threshold",
        "torch.abs",
        "torch.acos",
        "torch.acosh",
        "torch.asin",
        "torch.asinh",
        "torch.atan",
        "torch.atanh",
        "torch.atan2",
        "torch.ceil",
        "torch.clamp",
        "torch.cos",
        "torch.cosh",
        "torch.exp",
        "torch.floor",
        "torch.imag",
        "torch.log",
        "torch.log10",
        "torch.neg",
        "torch.pow",
        "torch.real",
        "torch.reciprocal",
        "torch.rsqrt",
        "torch.sign",
        "torch.sin",
        "torch.sinh",
        "torch.sqrt",
        "torch.square",
        "torch.tan",
        "torch.tanh",
        "torch.trunc",
    };

    // std::set lookup replaces the original linear scan; membership is unchanged.
    return known_flatten_safe_ops.count(op->type) != 0;
}
// Classify op as a global adaptive pooling.
// Returns 3 for a 2-D global pooling (output_size == {1,1}),
// 4 for a 3-D global pooling (output_size == {1,1,1}),
// and 0 when op is not a global pooling at all.
static int is_global_pooling(const Operator* op)
{
    static const std::set<std::string> adaptive_pool_ops = {
        "F.adaptive_avg_pool2d",
        "F.adaptive_avg_pool3d",
        "F.adaptive_max_pool2d",
        "F.adaptive_max_pool3d",
        "nn.AdaptiveAvgPool2d",
        "nn.AdaptiveAvgPool3d",
        "nn.AdaptiveMaxPool2d",
        "nn.AdaptiveMaxPool3d",
    };

    if (adaptive_pool_ops.count(op->type) == 0)
        return 0;

    // Only an all-ones output_size makes the pooling "global".
    // output_size=(1,1)
    // output_size=(1,1,1)
    const std::vector<int>& output_size = op->params.at("output_size").ai;

    if (output_size == std::vector<int>{1, 1})
        return 3;

    if (output_size == std::vector<int>{1, 1, 1})
        return 4;

    return 0;
}
// Walk forward from a global-pooling output operand and insert a
// Tensor.reshape ((1,-1,1,1) or (1,-1,1,1,1), chosen by pooled_rank) in front
// of the first consumer that cannot transparently handle a flattened input.
//
// operand     - the operand to inspect (initially the pooling output)
// pooled_rank - result of is_global_pooling(): 3 -> reshape to {1,-1,1,1},
//               4 -> reshape to {1,-1,1,1,1}
// graph       - graph to mutate
//
// Returns 1 if a reshape was inserted (the graph changed and iteration state
// is invalid), 0 otherwise.  The caller is expected to re-run the pass until
// it returns 0.
static int insert_reshape_global_pooling_forward(Operand* operand, int pooled_rank, Graph& graph)
{
    for (size_t i = 0; i < operand->consumers.size(); i++)
    {
        Operator* op = operand->consumers[i];

        if (op->type == "Tensor.reshape" || op->type == "Tensor.view")
        {
            // reshape discard flatten state
            // NOTE: break (not continue) — later consumers of this operand are
            // not examined once an explicit reshape/view consumer is seen.
            break;
        }

        if (is_known_operator_handle_flatten_0(op))
        {
            // This operator tolerates a flattened input; keep walking through
            // its outputs instead of inserting a reshape here.
            for (Operand* r : op->outputs)
            {
                int ret = insert_reshape_global_pooling_forward(r, pooled_rank, graph);
                if (ret)
                    return ret;
            }

            continue;
        }

        if (op->type == "pnnx.Expression")
        {
            // if it can be auto-broadcast
            // (1,c) with (1,c,d,h,w)/(1,c,h,w)/(1,c,w)/(1,c)
            if (operand->shape.size() == 4 && op->outputs[0]->shape.size() >= 2)
            {
                // Same channel count on axis 1 — broadcasting works, no reshape
                // needed for this consumer chain.
                if (operand->shape[1] == op->outputs[0]->shape[1])
                    break;
            }
        }

        fprintf(stderr, "insert_reshape_global_pooling_forward %s %s\n", op->name.c_str(), operand->name.c_str());

        // insert reshape (1,c,1,1) before op
        Operator* reshape0 = graph.new_operator_before("Tensor.reshape", op->name + "_ncnnreshape0", op);

        Operand* reshape0_out = graph.new_operand(op->name + "_ncnnreshape0_out");

        // Rewire: operand -> reshape0 -> reshape0_out -> op
        reshape0->inputs.push_back(operand);
        reshape0->outputs.push_back(reshape0_out);

        // Replace op with reshape0 in operand's consumer list (this slot only).
        operand->consumers[i] = reshape0;

        // Every input slot of op that referenced operand now reads reshape0_out.
        for (size_t j = 0; j < op->inputs.size(); j++)
        {
            if (op->inputs[j] == operand)
            {
                op->inputs[j] = reshape0_out;
            }
        }

        reshape0_out->producer = reshape0;
        reshape0_out->consumers.push_back(op);
        reshape0_out->params["__batch_index"] = 0;

        // -1 lets ncnn infer the channel count c at runtime.
        if (pooled_rank == 3)
            reshape0->params["shape"] = std::vector<int>{1, -1, 1, 1};
        if (pooled_rank == 4)
            reshape0->params["shape"] = std::vector<int>{1, -1, 1, 1, 1};

        // Graph mutated — stop immediately so the caller restarts the scan.
        return 1;
    }

    return 0;
}
// Pass entry point: for every global pooling in the graph, insert reshape ops
// in front of downstream consumers that cannot handle the flattened pooling
// output.  Runs to a fixpoint: each successful insertion invalidates the
// operator iteration, so the scan restarts until a full sweep changes nothing.
void insert_reshape_global_pooling(Graph& graph)
{
    for (;;)
    {
        bool changed = false;

        for (Operator* op : graph.ops)
        {
            const int pooled_rank = is_global_pooling(op);
            if (pooled_rank == 0)
                continue;

            // look for all output consumers
            // insert reshape (1,c,1,1) if it cannot handle flatten
            if (insert_reshape_global_pooling_forward(op->outputs[0], pooled_rank, graph))
            {
                // Graph was mutated; restart the sweep from scratch.
                changed = true;
                break;
            }
        }

        if (!changed)
            return;
    }
}
}
// namespace ncnn
}
// namespace pnnx
tools/pnnx/src/pass_ncnn/insert_reshape_global_pooling.h
0 → 100644
浏览文件 @
8c40a592
// Tencent is pleased to support the open source community by making ncnn available.
//
// Copyright (C) 2023 THL A29 Limited, a Tencent company. All rights reserved.
//
// Licensed under the BSD 3-Clause License (the "License"); you may not use this file except
// in compliance with the License. You may obtain a copy of the License at
//
// https://opensource.org/licenses/BSD-3-Clause
//
// Unless required by applicable law or agreed to in writing, software distributed
// under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR
// CONDITIONS OF ANY KIND, either express or implied. See the License for the
// specific language governing permissions and limitations under the License.

// Fix: add an include guard — the original header had none, so a double
// include would re-open the namespaces and re-declare the function.
#ifndef PNNX_PASS_NCNN_INSERT_RESHAPE_GLOBAL_POOLING_H
#define PNNX_PASS_NCNN_INSERT_RESHAPE_GLOBAL_POOLING_H

#include "pass_ncnn.h"

namespace pnnx {

namespace ncnn {

// Insert Tensor.reshape operators after global adaptive poolings so that
// consumers which cannot handle ncnn's flattened global-pooling output
// receive a (1,c,1,1)/(1,c,1,1,1)-shaped blob instead.
void insert_reshape_global_pooling(Graph& graph);

} // namespace ncnn

} // namespace pnnx

#endif // PNNX_PASS_NCNN_INSERT_RESHAPE_GLOBAL_POOLING_H
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录