Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
项目经理老王
Mace
提交
b0d9a3aa
Mace
项目概览
项目经理老王
/
Mace
与 Fork 源项目一致
Fork自
Xiaomi / Mace
通知
1
Star
0
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
DevOps
流水线
流水线任务
计划
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
Mace
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
DevOps
DevOps
流水线
流水线任务
计划
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
流水线任务
提交
Issue看板
前往新版Gitcode,体验更适合开发者的 AI 搜索 >>
提交
b0d9a3aa
编写于
10月 27, 2020
作者:
L
like15
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
fix: Add framework_type arg in Net so that we can get it from Net besides Op
上级
7f2c41fe
变更
6
隐藏空白更改
内联
并排
Showing
6 changed file
with
25 addition
and
3 deletion
+25
-3
tools/python/transform/base_converter.py
tools/python/transform/base_converter.py
+14
-0
tools/python/transform/caffe_converter.py
tools/python/transform/caffe_converter.py
+2
-0
tools/python/transform/onnx_converter.py
tools/python/transform/onnx_converter.py
+2
-0
tools/python/transform/pytorch_converter.py
tools/python/transform/pytorch_converter.py
+2
-0
tools/python/transform/tensorflow_converter.py
tools/python/transform/tensorflow_converter.py
+2
-0
tools/python/transform/transformer.py
tools/python/transform/transformer.py
+3
-3
未找到文件。
tools/python/transform/base_converter.py
浏览文件 @
b0d9a3aa
...
...
@@ -714,3 +714,17 @@ class ConverterUtil(object):
return
DataFormat
.
OIHW
else
:
return
None
@
staticmethod
def
set_framework_type
(
net
,
framework_type
):
framework_type_arg
=
net
.
arg
.
add
()
framework_type_arg
.
name
=
MaceKeyword
.
mace_framework_type_str
framework_type_arg
.
i
=
framework_type
@
staticmethod
def
framework_type
(
net
):
framework_type_arg
=
ConverterUtil
.
get_arg
(
net
,
MaceKeyword
.
mace_framework_type_str
)
if
framework_type_arg
is
None
:
return
None
return
framework_type_arg
.
i
tools/python/transform/caffe_converter.py
浏览文件 @
b0d9a3aa
...
...
@@ -209,6 +209,8 @@ class CaffeConverter(base_converter.ConverterInterface):
self
.
_mace_net_def
=
mace_pb2
.
NetDef
()
ConverterUtil
.
set_filter_format
(
self
.
_mace_net_def
,
DataFormat
.
OIHW
)
ConverterUtil
.
add_data_format_arg
(
self
.
_mace_net_def
,
DataFormat
.
NCHW
)
ConverterUtil
.
set_framework_type
(
self
.
_mace_net_def
,
FrameworkType
.
CAFFE
.
value
)
self
.
_caffe_net
=
CaffeNet
()
self
.
_caffe_layers
=
caffe_pb2
.
NetParameter
()
caffe_weights
=
caffe_pb2
.
NetParameter
()
...
...
tools/python/transform/onnx_converter.py
浏览文件 @
b0d9a3aa
...
...
@@ -415,6 +415,8 @@ class OnnxConverter(base_converter.ConverterInterface):
ConverterUtil
.
set_filter_format
(
self
.
_mace_net_def
,
DataFormat
.
OIHW
)
ConverterUtil
.
add_data_format_arg
(
self
.
_mace_net_def
,
self
.
_data_format
)
ConverterUtil
.
set_framework_type
(
self
.
_mace_net_def
,
FrameworkType
.
ONNX
.
value
)
onnx_model
=
onnx
.
load
(
src_model_file
)
ir_version
=
onnx_model
.
ir_version
...
...
tools/python/transform/pytorch_converter.py
浏览文件 @
b0d9a3aa
...
...
@@ -204,6 +204,8 @@ class PytorchConverter(base_converter.ConverterInterface):
self
.
_mace_net_def
=
mace_pb2
.
NetDef
()
ConverterUtil
.
set_filter_format
(
self
.
_mace_net_def
,
DataFormat
.
OIHW
)
ConverterUtil
.
add_data_format_arg
(
self
.
_mace_net_def
,
DataFormat
.
NCHW
)
ConverterUtil
.
set_framework_type
(
self
.
_mace_net_def
,
FrameworkType
.
PYTORCH
.
value
)
self
.
_op_converters
=
{
NodeKind
.
AdaptiveAvgPool2D
:
self
.
convert_pool
,
NodeKind
.
Add
:
self
.
convert_add
,
...
...
tools/python/transform/tensorflow_converter.py
浏览文件 @
b0d9a3aa
...
...
@@ -306,6 +306,8 @@ class TensorflowConverter(base_converter.ConverterInterface):
self
.
_mace_net_def
=
mace_pb2
.
NetDef
()
ConverterUtil
.
set_filter_format
(
self
.
_mace_net_def
,
DataFormat
.
HWIO
)
ConverterUtil
.
add_data_format_arg
(
self
.
_mace_net_def
,
DataFormat
.
NHWC
)
ConverterUtil
.
set_framework_type
(
self
.
_mace_net_def
,
FrameworkType
.
TENSORFLOW
.
value
)
# import tensorflow graph
tf_graph_def
=
tf
.
GraphDef
()
...
...
tools/python/transform/transformer.py
浏览文件 @
b0d9a3aa
...
...
@@ -1322,8 +1322,7 @@ class Transformer(base_converter.ConverterInterface):
# transform `input(4D) -> reshape(2D) -> matmul` to `fc(2D)`
# fc output is 2D in transformer, using as 4D in op kernel
# work for TensorFlow/PyTorch/ONNX
framework
=
ConverterUtil
.
get_arg
(
op
,
MaceKeyword
.
mace_framework_type_str
).
i
framework
=
ConverterUtil
.
framework_type
(
net
)
is_torch
=
framework
==
FrameworkType
.
PYTORCH
.
value
is_tf
=
framework
==
FrameworkType
.
TENSORFLOW
.
value
is_onnx
=
framework
==
FrameworkType
.
ONNX
.
value
...
...
@@ -1333,7 +1332,8 @@ class Transformer(base_converter.ConverterInterface):
op
.
input
[
1
]
in
self
.
_consts
and
\
len
(
op
.
output_shape
[
0
].
dims
)
==
2
and
\
(
is_tf
or
is_torch
or
is_onnx
)
and
\
op
.
input
[
0
]
in
self
.
_producer
:
op
.
input
[
0
]
in
self
.
_producer
and
\
op
.
output
[
0
]
in
self
.
_consumers
:
input_op
=
self
.
_producer
[
op
.
input
[
0
]]
input_shape
=
input_op
.
output_shape
[
0
].
dims
# check input op
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录