Commit 15c6da62
Authored Apr 13, 2022 by Megvii Engine Team
feat(imperative/amp): add nhwc support for adaptive pooling
GitOrigin-RevId: 7c5755308e4355f38fffd5634f0733836699a8c1
Parent: c28a875f

Showing 2 changed files with 62 additions and 28 deletions (+62 / -28)
imperative/src/impl/ops/adaptive_pooling.cpp        +47 / -18
imperative/src/impl/transformations/format.cpp      +15 / -10
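Both files implement the same idea: with an NHWC pooling parameter, the requested output size (OH, OW) lands on axes 1 and 2 of the destination layout and the channel extent stays on axis 3, instead of the NCHW ordering (N, C, OH, OW). A minimal standalone sketch of that shape mapping (illustrative C++ only; the Format enum and helper are made-up names, not MegEngine API):

#include <array>
#include <cstddef>

enum class Format { NCHW, NHWC };

// Adaptive-pooling output shape for a 4-d input, mirroring the branches
// added in this commit: NCHW -> (N, C, OH, OW), NHWC -> (N, OH, OW, C).
std::array<std::size_t, 4> adaptive_pool_out_shape(
        const std::array<std::size_t, 4>& src, std::size_t oh, std::size_t ow, Format fmt) {
    if (fmt == Format::NCHW) {
        return {src[0], src[1], oh, ow};
    }
    return {src[0], oh, ow, src[3]};
}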
imperative/src/impl/ops/adaptive_pooling.cpp
...
...
@@ -37,12 +37,10 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
         return {{{TensorLayout(src.layout.dtype), src.comp_node}}, false};
     }
+    const dt_int32* oshp2d = nullptr;
     dst_layout.ndim = 4u;
     if (nr_inp == 1) {
-        dst_layout[0] = src.layout[0];
-        dst_layout[1] = src.layout[1];
-        dst_layout[2] = pool.shape[0];
-        dst_layout[3] = pool.shape[1];
+        oshp2d = pool.shape.data();
     } else {
         auto&& tshp = inputs[1];
         if (tshp.value.empty()) {
...
...
@@ -52,11 +50,21 @@ std::tuple<SmallVector<LogicalTensorDesc>, bool> infer_output_attrs_fallible(
             tshp.layout.ndim == 1,
             "target shape of AdaptivePooling expects ndim=1; got ndim=%lu actually",
             tshp.layout.ndim);
-        dst_layout[0] = src.layout[0];
-        dst_layout[1] = src.layout[1];
-        auto* ptr = tshp.value.ptr<dt_int32>();
-        dst_layout[2] = ptr[0];
-        dst_layout[3] = ptr[1];
+        oshp2d = tshp.value.ptr<dt_int32>();
     }
+    auto param_format = pool.param().format;
+    if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
+        dst_layout[0] = src.layout[0];
+        dst_layout[1] = src.layout[1];
+        dst_layout[2] = oshp2d[0];
+        dst_layout[3] = oshp2d[1];
+    } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
+        dst_layout[0] = src.layout[0];
+        dst_layout[1] = oshp2d[0];
+        dst_layout[2] = oshp2d[1];
+        dst_layout[3] = src.layout[3];
+    } else {
+        mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
+    }
     dst_layout.init_contiguous_stride();
     return {{{dst_layout, src.comp_node}}, true};
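Once the format-dependent branch has filled in the four extents, dst_layout.init_contiguous_stride() assigns plain row-major strides over that axis order, so an NHWC result of shape (N, OH, OW, C) gets strides (OH*OW*C, OW*C, C, 1). A small standalone illustration of that stride rule (a sketch of the behaviour, not the TensorLayout implementation):

#include <array>
#include <cstddef>

// Row-major contiguous strides for a 4-d shape (innermost axis last),
// which is what a dense layout gets from init_contiguous_stride.
std::array<std::size_t, 4> contiguous_strides(const std::array<std::size_t, 4>& shape) {
    std::array<std::size_t, 4> stride{};
    std::size_t acc = 1;
    for (int i = 3; i >= 0; --i) {
        stride[i] = acc;
        acc *= shape[i];
    }
    return stride;
}
// shape (N, OH, OW, C) -> strides (OH*OW*C, OW*C, C, 1)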
...
...
@@ -71,26 +79,47 @@ SmallVector<TensorPtr> apply_on_physical_tensor(
     using TensorND = megdnn::TensorND;
     auto&& src_layout = inputs[0]->layout();
     TensorLayout dst_layout = output_descs[0].layout;
+    auto param_format = pool.format;
     if (!validated) {
         TensorShape tshp;
         dst_layout.ndim = src_layout.ndim;
-        dst_layout[0] = src_layout[0];
-        dst_layout[1] = src_layout[1];
+        const dt_int32* oshp2d = nullptr;
         if (inputs.size() == 2) {
             auto&& tshp_nd = inputs[1];
             cg::copy_tensor_value_to_shape(tshp, tshp_nd->get_value().proxy_to_default_cpu());
-            dst_layout[2] = tshp[0];
-            dst_layout[3] = tshp[1];
+            oshp2d = tshp_nd->get_value().proxy_to_default_cpu().ptr<dt_int32>();
         } else {
-            dst_layout[2] = pool.shape[0];
-            dst_layout[3] = pool.shape[1];
+            oshp2d = pool.shape.data();
         }
+        if (param_format == opr::AdaptivePooling::Param::Format::NCHW) {
+            dst_layout[0] = src_layout[0];
+            dst_layout[1] = src_layout[1];
+            dst_layout[2] = oshp2d[0];
+            dst_layout[3] = oshp2d[1];
+        } else if (param_format == opr::AdaptivePooling::Param::Format::NHWC) {
+            dst_layout[0] = src_layout[0];
+            dst_layout[1] = oshp2d[0];
+            dst_layout[2] = oshp2d[1];
+            dst_layout[3] = src_layout[3];
+        } else {
+            mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
+        }
         dst_layout.init_contiguous_stride();
     }
-    size_t IH = src_layout[2], IW = src_layout[3], OH = dst_layout[2],
-           OW = dst_layout[3];
+    size_t IH, IW, OH, OW;
+    if (param_format == param::AdaptivePooling::Format::NCHW) {
+        IH = src_layout[2];
+        IW = src_layout[3];
+        OH = dst_layout[2];
+        OW = dst_layout[3];
+    } else if (param_format == param::AdaptivePooling::Format::NHWC) {
+        IH = src_layout[1];
+        IW = src_layout[2];
+        OH = dst_layout[1];
+        OW = dst_layout[2];
+    } else {
+        mgb_throw(MegBrainError, "AdaptivePooling only support NCHW or NHWC format");
+    }
     DnnOprCaller<megdnn::Pooling> dnn_opr(cn);
     auto&& param = dnn_opr.op->param();
     param.mode = pool.mode;
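The last part of this hunk picks the spatial extents from format-dependent axes (2, 3 for NCHW; 1, 2 for NHWC) before configuring the megdnn Pooling call. Adaptive pooling is commonly lowered to a fixed-window pooling built from exactly these four numbers; the sketch below shows one standard way to do that and is only an illustration of the idea, not necessarily the formula used in the elided code that follows:

#include <cstddef>

struct PoolWindow {
    std::size_t stride_h, stride_w, window_h, window_w;
};

// One common lowering of adaptive pooling to fixed-window pooling:
// stride = floor(in / out), window sized so the last output element
// still reaches the end of the input.
PoolWindow lower_adaptive(std::size_t IH, std::size_t IW, std::size_t OH, std::size_t OW) {
    PoolWindow p;
    p.stride_h = IH / OH;
    p.stride_w = IW / OW;
    p.window_h = IH - (OH - 1) * p.stride_h;
    p.window_w = IW - (OW - 1) * p.stride_w;
    return p;
}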
...
...
imperative/src/impl/transformations/format.cpp
...
...
@@ -105,7 +105,7 @@ std::vector<int32_t> convert_nchw2nhwc_vector(const std::vector<int32_t>& shape)
     } else {
         mgb_throw(
                 MegBrainError,
-                "Unsupported shape ndim %u in convert NCHW shape to NHWC.",
+                "Unsupported shape ndim %lu in convert NCHW shape to NHWC.",
                 shape.size());
     }
 }
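The only change here is the format specifier: shape.size() returns std::size_t, so %u is a type mismatch that compilers flag under -Wformat, while %lu matches on the usual LP64 targets (the fully portable choice would be %zu). A minimal illustration:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
    std::vector<std::int32_t> shape{1, 3, 32, 32};
    // %zu is the portable specifier for size_t; %lu assumes
    // size_t == unsigned long, which holds on LP64 platforms.
    std::printf("ndim %zu\n", shape.size());
}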
...
...
@@ -184,7 +184,8 @@ ValueRefList reshape_rule(
         // output is still NHWC format
         auto nhwc_shape = convert_nchw2nhwc_vector(op.shape);
         auto outputs = imperative::apply(
-                *Reshape::make(op.axis, nhwc_shape), {t.unwrap_input(inputs[0])});
+                *Reshape::make(op.axis, nhwc_shape),
+                {t.unwrap_input(inputs[0])});
         return t.wrap_outputs(outputs, FT::NHWC);
     } else {
         // will not maintain src's format
...
...
@@ -395,12 +396,17 @@ ValueRefList batchnorm_rule(
     return identity_rule_helper(op, inputs, t);
 }
 
-ValueRefList checknonfinite_rule(
-        const CheckNonFinite& op, Span<ValueRef>& inputs, const bool& auto_convert,
+ValueRefList adaptive_pooling_rule(
+        const AdaptivePooling& op, Span<ValueRef>& inputs, const bool& auto_convert,
         const FormatTransformation& t) {
-    auto&& inputs_ = t.unwrap_inputs(inputs);
-    auto&& outputs_ = imperative::apply(op, inputs_);
-    return t.wrap_outputs(outputs_);
+    auto&& inp_format = inputs[0].cast(t.value_type()).format();
+    if (inp_format == FT::NHWC) {
+        auto&& new_param = op.param();
+        new_param.format = AdaptivePooling::Format::NHWC;
+        auto new_op = AdaptivePooling::make(new_param, op.shape);
+        return identity_rule_helper(*new_op, inputs, t);
+    }
+    return identity_rule_helper(op, inputs, t);
 }
 
 // clang-format off
...
...
@@ -417,7 +423,6 @@ ValueRefList checknonfinite_rule(
     cb(Identity)
 #define FOREACH_FORMAT_OP(cb) \
-    cb(AdaptivePooling) \
     cb(WarpAffine) \
     cb(Resize)
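cb(AdaptivePooling) is dropped from this X-macro list because the op now has the dedicated adaptive_pooling_rule defined above and registered explicitly in FormatRuleRegistry below; the list keeps only ops that still use the generic format rule. For readers unfamiliar with the pattern, an X-macro simply expands a callback once per listed name; a self-contained sketch (PRINT_OP is an illustrative stand-in for REGISTER_OP_RULE):

#include <cstdio>

// X-macro: expands cb(...) once per listed op name.
#define FOREACH_FORMAT_OP(cb) \
    cb(WarpAffine)            \
    cb(Resize)

#define PRINT_OP(op) std::printf("register generic format rule for " #op "\n");

int main() {
    FOREACH_FORMAT_OP(PRINT_OP)  // one printf per listed op
}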
...
...
@@ -494,7 +499,7 @@ struct FormatRuleRegistry {
         register_format_rule(setsubtensor_rule<IndexingSetMultiAxisVec>);
         register_format_rule(concat_rule);
         register_format_rule(batchnorm_rule);
-        register_format_rule(checknonfinite_rule);
+        register_format_rule(adaptive_pooling_rule);
         FOREACH_MULTI_INPS_NO_PARAM_OP(REGISTER_OP_RULE)
         FOREACH_IDENTITY_OP(REGISTER_OP_RULE)
         FOREACH_FORMAT_OP(REGISTER_OP_RULE)
...
...
@@ -506,7 +511,7 @@ struct FormatRuleRegistry {
 ValueRefList FormatTransformation::apply_transformation(
         const Operator& op, Span<ValueRef> inputs) {
-    //mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
+    // mgb_log_warn("Format::apply_transformation %s", op.to_string().c_str());
     if (auto* apply_op = op.as<ApplyOp>()) {
         // all inputs should be FormattedTensorValue
         auto iter = format_rules.find(apply_op->op().dyn_typeinfo());
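apply_transformation dispatches on the op's runtime type: format_rules.find(apply_op->op().dyn_typeinfo()) looks up the per-op rule that register_format_rule stored, and falls through to generic handling when no rule is registered. A simplified, self-contained sketch of that dispatch pattern (std::type_index stands in for MegBrain's Typeinfo; all names here are illustrative):

#include <cstdio>
#include <functional>
#include <typeindex>
#include <unordered_map>

struct Op { virtual ~Op() = default; };
struct AdaptivePooling : Op {};

using Rule = std::function<void(const Op&)>;
std::unordered_map<std::type_index, Rule> format_rules;

// Store a rule keyed by the concrete op type it handles.
template <typename T>
void register_format_rule(void (*rule)(const T&)) {
    format_rules[typeid(T)] = [rule](const Op& op) { rule(static_cast<const T&>(op)); };
}

void adaptive_pooling_rule(const AdaptivePooling&) {
    std::printf("rewrite param.format to NHWC when the input is NHWC\n");
}

void apply_transformation(const Op& op) {
    auto it = format_rules.find(typeid(op));
    if (it != format_rules.end()) {
        it->second(op);  // dedicated rule registered for this op type
    } else {
        std::printf("no rule registered: fall back to generic handling\n");
    }
}

int main() {
    register_format_rule(adaptive_pooling_rule);
    AdaptivePooling op;
    apply_transformation(op);  // dispatches to adaptive_pooling_rule
}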
...
...