Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
项目经理老王
Mace
提交
71dfbe45
Mace
项目概览
项目经理老王
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
71dfbe45
编写于
2月 13, 2019
作者:
刘
刘琦
浏览文件
操作
浏览文件
下载
差异文件
Merge branch 'master' into 'master'
Fix reshape format bugs See merge request !977
上级
a8b3c4f4
89832e0d
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
38 addition
and
24 deletion
+38
-24
mace/ops/reshape.cc
mace/ops/reshape.cc
+10
-1
mace/python/tools/converter_tool/shape_inference.py
mace/python/tools/converter_tool/shape_inference.py
+8
-7
mace/python/tools/converter_tool/transformer.py
mace/python/tools/converter_tool/transformer.py
+20
-16
未找到文件。
mace/ops/reshape.cc
浏览文件 @
71dfbe45
...
...
@@ -75,8 +75,17 @@ class ReshapeOp : public Operation {
<<
"Input size not match reshaped tensor size"
;
out_shape
[
unknown_idx
]
=
missing
;
}
Tensor
*
output
=
this
->
Output
(
OUTPUT
);
// NHWC -> NCHW
auto
df
=
static_cast
<
DataFormat
>
(
Operation
::
GetOptionalArg
<
int
>
(
"data_format"
,
DataFormat
::
DF_NONE
));
if
(
df
==
DataFormat
::
NHWC
&&
D
==
DeviceType
::
CPU
&&
out_shape
.
size
()
==
4
&&
shape
->
is_weight
())
{
std
::
vector
<
int
>
dst_dims
=
{
0
,
3
,
1
,
2
};
std
::
vector
<
index_t
>
out_shape_gpu
=
TransposeShape
<
index_t
,
index_t
>
(
out_shape
,
dst_dims
);
out_shape
=
out_shape_gpu
;
}
output
->
ReuseTensorBuffer
(
*
input
);
output
->
Reshape
(
out_shape
);
...
...
mace/python/tools/converter_tool/shape_inference.py
浏览文件 @
71dfbe45
...
...
@@ -276,18 +276,19 @@ class ShapeInference(object):
output_shape
[
idx
]
=
input_size
/
product
self
.
add_output_shape
(
op
,
[
output_shape
])
else
:
output_shape
=
list
(
self
.
_output_shape_cache
[
op
.
input
[
0
]])
output_shape
=
[]
axis
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_axis_str
).
i
end_axis
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_end_axis_str
).
i
# noqa
if
end_axis
<
0
:
end_axis
=
len
(
output_shape
)
+
end_axis
end_axis
=
end_axis
if
end_axis
>
0
else
end_axis
+
len
(
list
(
self
.
_output_shape_cache
[
op
.
input
[
0
]]))
dim
=
1
for
i
in
range
(
0
,
axis
):
output_shape
[
i
]
=
self
.
_output_shape_cache
[
op
.
input
[
0
]][
i
]
output_shape
.
append
(
self
.
_output_shape_cache
[
op
.
input
[
0
]][
i
])
for
i
in
range
(
axis
,
end_axis
+
1
):
dim
*=
self
.
_output_shape_cache
[
op
.
input
[
0
]][
i
]
output_shape
[
i
]
=
1
for
i
in
range
(
end_axis
+
1
,
len
(
output_shape
)):
output_shape
[
i
]
=
self
.
_output_shape_cache
[
op
.
input
[
0
]][
i
]
output_shape
.
append
(
-
1
)
for
i
in
range
(
end_axis
+
1
,
len
(
list
(
self
.
_output_shape_cache
[
op
.
input
[
0
]]))):
output_shape
.
append
(
self
.
_output_shape_cache
[
op
.
input
[
0
]][
i
])
output_shape
[
axis
]
=
dim
self
.
add_output_shape
(
op
,
[
output_shape
])
mace/python/tools/converter_tool/transformer.py
浏览文件 @
71dfbe45
...
...
@@ -1790,31 +1790,35 @@ class Transformer(base_converter.ConverterInterface):
if
op
.
type
==
MaceOp
.
Reshape
.
name
and
\
len
(
op
.
input
)
==
1
:
print
(
"Transform Caffe Reshape"
)
if
op
.
arg
[
3
].
name
==
'dim'
:
dims
=
[]
dim_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_dim_str
)
axis_arg
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_axis_str
)
# transform caffe reshape op
if
dim_arg
:
dims
=
dim_arg
.
ints
shape_tensor
=
net
.
tensors
.
add
()
shape_tensor
.
name
=
op
.
name
+
'_shape'
shape_tensor
.
dims
.
append
(
len
(
op
.
output_shape
[
0
].
dims
))
shape_tensor
.
data_type
=
mace_pb2
.
DT_INT32
shape_tensor
.
int32_data
.
extend
(
op
.
arg
[
3
].
ints
)
op
.
input
.
append
(
shape_tensor
.
name
)
else
:
axis
=
op
.
arg
[
3
].
i
dims
=
[
1
]
*
len
(
op
.
output_shape
[
0
].
dims
)
end_axis
=
op
.
arg
[
4
].
i
end_axis
=
end_axis
if
end_axis
>=
0
else
end_axis
+
len
(
dims
)
# noqa
# transform caffe flatten op
elif
axis_arg
is
not
None
:
axis
=
axis_arg
.
i
for
i
in
range
(
0
,
axis
):
dims
[
i
]
=
0
for
i
in
range
(
axis
+
1
,
end_axis
+
1
):
dims
[
i
]
=
1
for
i
in
range
(
end_axis
+
1
,
len
(
dims
)):
dims
[
i
]
=
0
dims
[
axis
]
=
-
1
dims
.
append
(
0
)
dims
.
append
(
-
1
)
for
i
in
range
(
axis
+
1
,
len
(
op
.
output_shape
[
0
].
dims
)):
dims
.
append
(
0
)
shape_tensor
=
net
.
tensors
.
add
()
shape_tensor
.
name
=
op
.
name
+
'_shape'
shape_tensor
.
dims
.
append
(
len
(
dims
))
shape_tensor
.
data_type
=
mace_pb2
.
DT_INT32
shape_tensor
.
int32_data
.
extend
(
dims
)
op
.
input
.
append
(
shape_tensor
.
name
)
else
:
mace_check
(
False
,
"Only support reshape and flatten"
)
# NCHW -> NHWC
if
len
(
dims
)
==
4
:
self
.
transpose_shape
(
dims
,
[
0
,
2
,
3
,
1
])
shape_tensor
.
int32_data
.
extend
(
dims
)
op
.
input
.
append
(
shape_tensor
.
name
)
def
fold_fc_reshape
(
self
):
net
=
self
.
_model
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录