Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MegEngine 天元
MegEngine
提交
3a5347ed
MegEngine
项目概览
MegEngine 天元
/
MegEngine
大约 1 年 前同步成功
通知
399
Star
4705
Fork
582
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
MegEngine
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
提交
3a5347ed
编写于
3月 23, 2022
作者:
M
Megvii Engine Team
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
perf(imperative): speed up pooling
GitOrigin-RevId: 9f60b45eebf81fbb7f483328815d3744dc4d5811
上级
c0b267ff
变更
3
隐藏空白更改
内联
并排
Showing
3 changed files
with
162 additions
and
63 deletions
+162
-63
dnn/src/common/pooling.cpp
dnn/src/common/pooling.cpp
+57
-52
imperative/src/impl/ops/pooling.cpp
imperative/src/impl/ops/pooling.cpp
+105
-0
imperative/src/impl/ops/specializations.cpp
imperative/src/impl/ops/specializations.cpp
+0
-11
未找到文件。
dnn/src/common/pooling.cpp
浏览文件 @
3a5347ed
...
@@ -16,50 +16,55 @@
...
@@ -16,50 +16,55 @@
namespace
megdnn
{
namespace
megdnn
{
void
PoolingBase
::
deduce_layout_fwd
(
const
TensorLayout
&
src
,
TensorLayout
&
dst
)
{
void
PoolingBase
::
deduce_layout_fwd
(
const
TensorLayout
&
src
,
TensorLayout
&
dst
)
{
auto
errmsg
=
auto
&
p
=
param
();
megdnn_layout_msg
(
src
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
auto
pformat
=
p
.
format
;
"pad_h="
+
std
::
to_string
(
param
().
pad_h
)
+
", "
+
"pad_w="
+
std
::
to_string
(
param
().
pad_w
)
+
", "
+
// the overhead of generating error message is about 18x of the other part of this
"stride_h="
+
std
::
to_string
(
param
().
stride_h
)
+
", "
+
// function so we use a function to wrap the error message and get it only when need.
"stride_w="
+
std
::
to_string
(
param
().
stride_w
)
+
", "
+
auto
get_errmsg
=
[
&
](
void
)
->
std
::
string
{
"window_h="
+
std
::
to_string
(
param
().
window_h
)
+
", "
+
std
::
string
errmsg
=
"window_w="
+
std
::
to_string
(
param
().
window_w
)
+
", "
+
megdnn_layout_msg
(
src
)
+
", "
+
megdnn_layout_msg
(
dst
)
+
", "
+
"is_max="
+
std
::
to_string
(
param
().
mode
==
Mode
::
MAX
)
+
", "
+
"pad_h="
+
std
::
to_string
(
param
().
pad_h
)
+
", "
+
"is_nhwc="
+
std
::
to_string
(
param
().
format
==
Param
::
Format
::
NHWC
)
+
", "
+
"pad_w="
+
std
::
to_string
(
param
().
pad_w
)
+
", "
+
"is_nhwcd4="
+
std
::
to_string
(
param
().
format
==
Param
::
Format
::
NHWCD4
);
"stride_h="
+
std
::
to_string
(
param
().
stride_h
)
+
", "
+
auto
errmsg_c
=
errmsg
.
c_str
();
"stride_w="
+
std
::
to_string
(
param
().
stride_w
)
+
", "
+
"window_h="
+
std
::
to_string
(
param
().
window_h
)
+
", "
+
MEGDNN_MARK_USED_VAR
(
errmsg_c
);
"window_w="
+
std
::
to_string
(
param
().
window_w
)
+
", "
+
"is_max="
+
std
::
to_string
(
param
().
mode
==
Mode
::
MAX
)
+
", "
+
"is_nhwc="
+
std
::
to_string
(
pformat
==
Param
::
Format
::
NHWC
)
+
", "
+
"is_nhwcd4="
+
std
::
to_string
(
pformat
==
Param
::
Format
::
NHWCD4
);
return
errmsg
;
};
MEGDNN_MARK_USED_VAR
(
get_errmsg
);
megdnn_assert_contiguous
(
src
);
megdnn_assert_contiguous
(
src
);
size_t
spatial_pos
,
c_pos
,
batch_pos
=
0
;
size_t
spatial_pos
,
c_pos
,
batch_pos
=
0
;
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW
)
{
if
(
pformat
==
Param
::
Format
::
NCHW
)
{
megdnn_assert
(
src
.
ndim
==
4
_z
,
"%s"
,
errmsg_c
);
megdnn_assert
(
src
.
ndim
==
4
_z
,
"%s"
,
get_errmsg
().
c_str
()
);
spatial_pos
=
2
;
spatial_pos
=
2
;
c_pos
=
1
;
c_pos
=
1
;
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
NHWC
)
{
}
else
if
(
pformat
==
Param
::
Format
::
NHWC
)
{
megdnn_assert
(
src
.
ndim
==
4
_z
,
"%s"
,
errmsg_c
);
megdnn_assert
(
src
.
ndim
==
4
_z
,
"%s"
,
get_errmsg
().
c_str
()
);
spatial_pos
=
1
;
spatial_pos
=
1
;
c_pos
=
3
;
c_pos
=
3
;
}
else
if
(
}
else
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
pformat
==
Param
::
Format
::
NCHW4
||
pformat
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW44
||
pformat
==
Param
::
Format
::
NCHW88
||
pformat
==
Param
::
Format
::
NCHW32
||
param
().
format
==
Param
::
Format
::
NCHW88
||
pformat
==
Param
::
Format
::
NCHW64
)
{
param
().
format
==
Param
::
Format
::
NCHW32
||
megdnn_assert
(
src
.
ndim
==
5
_z
,
"%s"
,
get_errmsg
().
c_str
());
param
().
format
==
Param
::
Format
::
NCHW64
)
{
megdnn_assert
(
src
.
ndim
==
5
_z
,
"%s"
,
errmsg_c
);
spatial_pos
=
2
;
spatial_pos
=
2
;
c_pos
=
1
;
c_pos
=
1
;
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
CHWN4
)
{
}
else
if
(
pformat
==
Param
::
Format
::
CHWN4
)
{
spatial_pos
=
1
;
spatial_pos
=
1
;
c_pos
=
0
;
c_pos
=
0
;
batch_pos
=
3
;
batch_pos
=
3
;
}
else
{
}
else
{
megdnn_assert
(
megdnn_assert
(
p
aram
().
format
==
Param
::
Format
::
NHWCD4
&&
src
.
ndim
==
5
_z
,
"%s"
,
pformat
==
Param
::
Format
::
NHWCD4
&&
src
.
ndim
==
5
_z
,
"%s"
,
errmsg_c
);
get_errmsg
().
c_str
()
);
spatial_pos
=
1
;
spatial_pos
=
1
;
c_pos
=
2
;
c_pos
=
2
;
}
}
...
@@ -67,31 +72,34 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
...
@@ -67,31 +72,34 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
size_t
c
=
src
[
c_pos
];
size_t
c
=
src
[
c_pos
];
size_t
ih
=
src
[
spatial_pos
];
size_t
ih
=
src
[
spatial_pos
];
size_t
iw
=
src
[
spatial_pos
+
1
];
size_t
iw
=
src
[
spatial_pos
+
1
];
if
(
p
aram
().
format
==
Param
::
Format
::
NHWCD4
)
{
if
(
pformat
==
Param
::
Format
::
NHWCD4
)
{
c
*=
4
;
c
*=
4
;
iw
=
src
[
spatial_pos
+
2
];
iw
=
src
[
spatial_pos
+
2
];
}
}
if
(
param
().
format
==
Param
::
Format
::
NCHW4
||
if
(
pformat
==
Param
::
Format
::
NCHW4
||
pformat
==
Param
::
Format
::
NCHW44
||
param
().
format
==
Param
::
Format
::
NCHW44
||
pformat
==
Param
::
Format
::
CHWN4
)
{
param
().
format
==
Param
::
Format
::
CHWN4
)
{
c
*=
4
;
c
*=
4
;
}
}
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW88
)
{
if
(
pformat
==
Param
::
Format
::
NCHW88
)
{
c
*=
8
;
c
*=
8
;
}
}
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW32
)
{
if
(
pformat
==
Param
::
Format
::
NCHW32
)
{
c
*=
32
;
c
*=
32
;
}
}
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW64
)
{
if
(
pformat
==
Param
::
Format
::
NCHW64
)
{
c
*=
64
;
c
*=
64
;
}
}
size_t
oh
,
ow
;
size_t
oh
,
ow
;
size_t
fh
=
this
->
param
().
window_h
;
size_t
fh
=
p
.
window_h
;
size_t
fw
=
this
->
param
().
window_w
;
size_t
fw
=
p
.
window_w
;
size_t
sh
=
this
->
param
().
stride_h
;
size_t
sh
=
p
.
stride_h
;
size_t
sw
=
this
->
param
().
stride_w
;
size_t
sw
=
p
.
stride_w
;
size_t
ph
=
this
->
param
().
pad_h
;
size_t
ph
=
p
.
pad_h
;
size_t
pw
=
this
->
param
().
pad_w
;
size_t
pw
=
p
.
pad_w
;
// moving some python assert to here
// megdnn_assert()
if
(
ph
>=
fh
||
pw
>=
fw
)
{
if
(
ph
>=
fh
||
pw
>=
fw
)
{
megdnn_log_warn
(
megdnn_log_warn
(
"pooling padding size (%zu %zu) should not be bigger than "
"pooling padding size (%zu %zu) should not be bigger than "
...
@@ -99,26 +107,23 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
...
@@ -99,26 +107,23 @@ void PoolingBase::deduce_layout_fwd(const TensorLayout& src, TensorLayout& dst)
pw
,
ph
,
fw
,
fh
);
pw
,
ph
,
fw
,
fh
);
}
}
infer_conv_shape2d
(
ih
,
iw
,
fh
,
fw
,
sh
,
sw
,
ph
,
pw
,
oh
,
ow
);
infer_conv_shape2d
(
ih
,
iw
,
fh
,
fw
,
sh
,
sw
,
ph
,
pw
,
oh
,
ow
);
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW
)
{
if
(
pformat
==
Param
::
Format
::
NCHW
)
{
dst
=
TensorLayout
(
TensorShape
({
n
,
c
,
oh
,
ow
}),
src
.
dtype
);
dst
=
TensorLayout
(
TensorShape
({
n
,
c
,
oh
,
ow
}),
src
.
dtype
);
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
NHWC
)
{
}
else
if
(
pformat
==
Param
::
Format
::
NHWC
)
{
megdnn_assert
(
p
aram
().
format
==
Param
::
Format
::
NHWC
,
"invalid pooling format"
);
megdnn_assert
(
pformat
==
Param
::
Format
::
NHWC
,
"invalid pooling format"
);
dst
=
TensorLayout
({
n
,
oh
,
ow
,
c
},
src
.
dtype
,
src
.
format
);
dst
=
TensorLayout
({
n
,
oh
,
ow
,
c
},
src
.
dtype
,
src
.
format
);
}
else
if
(
}
else
if
(
pformat
==
Param
::
Format
::
NCHW4
||
pformat
==
Param
::
Format
::
NCHW44
)
{
param
().
format
==
Param
::
Format
::
NCHW4
||
param
().
format
==
Param
::
Format
::
NCHW44
)
{
dst
=
TensorLayout
{{
n
,
c
/
4
,
oh
,
ow
,
4
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
c
/
4
,
oh
,
ow
,
4
},
src
.
dtype
,
src
.
format
};
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW88
)
{
}
else
if
(
pformat
==
Param
::
Format
::
NCHW88
)
{
dst
=
TensorLayout
{{
n
,
c
/
8
,
oh
,
ow
,
8
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
c
/
8
,
oh
,
ow
,
8
},
src
.
dtype
,
src
.
format
};
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW32
)
{
}
else
if
(
pformat
==
Param
::
Format
::
NCHW32
)
{
dst
=
TensorLayout
{{
n
,
c
/
32
,
oh
,
ow
,
32
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
c
/
32
,
oh
,
ow
,
32
},
src
.
dtype
,
src
.
format
};
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
NCHW64
)
{
}
else
if
(
pformat
==
Param
::
Format
::
NCHW64
)
{
dst
=
TensorLayout
{{
n
,
c
/
64
,
oh
,
ow
,
64
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
c
/
64
,
oh
,
ow
,
64
},
src
.
dtype
,
src
.
format
};
}
else
if
(
p
aram
().
format
==
Param
::
Format
::
CHWN4
)
{
}
else
if
(
pformat
==
Param
::
Format
::
CHWN4
)
{
dst
=
TensorLayout
{{
c
/
4
,
oh
,
ow
,
n
,
4
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
c
/
4
,
oh
,
ow
,
n
,
4
},
src
.
dtype
,
src
.
format
};
}
else
{
}
else
{
megdnn_assert
(
megdnn_assert
(
pformat
==
Param
::
Format
::
NHWCD4
,
"invalid pooling format"
);
param
().
format
==
Param
::
Format
::
NHWCD4
,
"invalid pooling format"
);
dst
=
TensorLayout
{{
n
,
oh
,
c
/
4
,
ow
,
4
},
src
.
dtype
,
src
.
format
};
dst
=
TensorLayout
{{
n
,
oh
,
c
/
4
,
ow
,
4
},
src
.
dtype
,
src
.
format
};
}
}
}
}
...
...
imperative/src/impl/ops/pooling.cpp
0 → 100644
浏览文件 @
3a5347ed
/**
* \file imperative/src/impl/ops/pooling.cpp
* MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
*
* Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
*/
#include "megbrain/opr/dnn/pooling.h"
#include "megbrain/imperative/ops/autogen.h"
#include "megbrain/opr/utility.h"
#include "megbrain/opr/internal/megdnn_opr_wrapper.h"
#include "../algo_chooser.h"
#include "../blob_manager_impl.h"
#include "../dnn_op_helper.h"
#include "../op_trait.h"
namespace
mgb
::
imperative
{
namespace
{
namespace
pooling
{
// using OprHandle = opr::intl::UniqPtrWithCN<megdnn::Pooling>;
// static ThinHashMap<CompNode, OprHandle> dnn_oprs;
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
pool
=
static_cast
<
const
Pooling
&>
(
def
);
OperatorNodeConfig
config
{
pool
.
make_name
()};
return
opr
::
Pooling
::
make
(
inputs
[
0
],
pool
.
param
(),
pool
.
policy
(),
config
);
}
// Shape/dtype inference for Pooling without executing the kernel.
// Returns {output descs, validated}; `validated` is false when the input
// layout is still unknown (ndim == 0) and only the dtype can be reported,
// true once megdnn has deduced a concrete output layout.
std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
        const OpDef& def, const SmallVector<LogicalTensorDesc>& inputs) {
    mgb_assert(
            inputs.size() == 1, "num of inputs of pooling should be 1 but you give %zu",
            inputs.size());

    auto&& pooling = def.cast_final_safe<Pooling>();
    const auto& input = inputs[0];
    auto& node = input.comp_node;

    // Input shape not known yet: emit an unvalidated placeholder carrying
    // only the dtype so downstream inference can proceed.
    if (input.layout.ndim == 0) {
        return {{{TensorLayout{input.layout.dtype}, node, {}}}, false};
    }

    // Let a megdnn Pooling operator deduce the output layout from the
    // concrete input layout and the op parameters.
    DnnOprCaller<megdnn::Pooling> caller(node);
    auto&& dnn_opr = caller.op;
    dnn_opr->param() = pooling.param();
    TensorLayout out_layout;
    dnn_opr->deduce_layout(input.layout, out_layout);
    return {{{out_layout, node, {}}}, true};
}
// Eager-mode execution of Pooling on one concrete input tensor.
//
// def          : the Pooling OpDef (param + execution policy).
// inputs       : exactly one input tensor.
// output_descs : output layout slot; reused as-is when `validated` is true,
//                otherwise re-deduced here from the real input layout.
// validated    : whether output_descs[0].layout was already verified upstream.
//
// Returns the single output tensor produced by megdnn::Pooling::exec.
SmallVector<TensorPtr> apply_on_physical_tensor(
        const OpDef& def, const SmallVector<TensorPtr>& inputs,
        SmallVector<LogicalTensorDesc>& output_descs, const bool& validated) {
    mgb_assert(
            inputs.size() == 1, "num of inputs of pooling should be 1 but you give %zu",
            inputs.size());

    auto&& op_def = def.cast_final_safe<Pooling>();
    auto cn = inputs[0]->comp_node();
    megdnn::TensorND inp_tensornd = inputs[0]->dnn_tensor();

    DnnOprCaller<megdnn::Pooling> caller(cn);
    auto&& dnn_opr = caller.op;
    dnn_opr->param() = op_def.param();

    // Trust the pre-validated layout when available; otherwise deduce it now.
    TensorLayout& oup_layout = output_descs[0].layout;
    if (!validated) {
        dnn_opr->deduce_layout(inp_tensornd.layout, oup_layout);
    }
    DeviceTensorND out_devtensor =
            BlobManager::inst()->alloc_workspace_with_defrag(cn, oup_layout);

    // Pick an algorithm and query its workspace requirement.
    size_t wk_size = setup_algo<megdnn::Pooling>(
            {inp_tensornd.layout, oup_layout}, dnn_opr.get(), 0, false, false, cn,
            op_def.policy(), false);

    // NOTE(fix): the workspace blob must outlive exec() — dnn_wk only holds a
    // raw pointer into the blob's storage, so the blob is kept in function
    // scope instead of dying at the end of the `if` (which left raw_ptr
    // dangling in the original).
    megdnn::Workspace dnn_wk;
    auto wk_blob = wk_size != 0 ? Blob::make(cn, wk_size) : nullptr;
    if (wk_blob) {
        dnn_wk.raw_ptr = wk_blob->storage().get();
        dnn_wk.size = wk_size;
    }

    // NOTE(fix): pass the allocated workspace to exec(). The original passed
    // an empty `{}` workspace, so the allocation above was dead and any
    // algorithm requiring scratch memory would fail its workspace check.
    dnn_opr->exec(inp_tensornd, out_devtensor.as_megdnn(), dnn_wk);
    return {Tensor::make(out_devtensor)};
}
// Register the imperative-runtime trait entry points for Pooling; any hook
// not listed here falls back to the default implementation.
OP_TRAIT_REG(Pooling, Pooling)
        .apply_on_var_node(apply_on_var_node)
        .infer_output_attrs_fallible(infer_output_attrs_fallible)
        .apply_on_physical_tensor(apply_on_physical_tensor)
        .fallback();
}
// namespace pooling
}
// namespace
}
// namespace mgb::imperative
imperative/src/impl/ops/specializations.cpp
浏览文件 @
3a5347ed
...
@@ -333,17 +333,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias)
...
@@ -333,17 +333,6 @@ OP_TRAIT_REG(BatchConvBias, BatchConvBias)
}
// namespace batch_conv_bias
}
// namespace batch_conv_bias
}
// namespace
}
// namespace
namespace
{
namespace
pooling
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
&&
pool
=
static_cast
<
const
Pooling
&>
(
def
);
OperatorNodeConfig
config
{
pool
.
make_name
()};
return
opr
::
Pooling
::
make
(
inputs
[
0
],
pool
.
param
(),
pool
.
policy
(),
config
);
}
OP_TRAIT_REG
(
Pooling
,
Pooling
).
apply_on_var_node
(
apply_on_var_node
).
fallback
();
}
// namespace pooling
}
// namespace
namespace
{
namespace
{
namespace
matrix_mul
{
namespace
matrix_mul
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
auto
apply_on_var_node
(
const
OpDef
&
def
,
const
VarNodeArray
&
inputs
)
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录