Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
9d147284
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 1 年 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
9d147284
编写于
7月 09, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
better type-shape inference support
上级
74feae9b
变更
10
展开全部
隐藏空白更改
内联
并排
Showing
10 changed file
with
483 addition
and
382 deletion
+483
-382
onnx2fluid/examples/convert_data_npz.py
onnx2fluid/examples/convert_data_npz.py
+2
-2
onnx2fluid/examples/convert_data_pb.py
onnx2fluid/examples/convert_data_pb.py
+2
-2
onnx2fluid/examples/onnx_model_zoo.sh
onnx2fluid/examples/onnx_model_zoo.sh
+115
-84
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+6
-3
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+2
-2
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+28
-21
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+297
-235
onnx2fluid/onnx2fluid/torch_export_helper.py
onnx2fluid/onnx2fluid/torch_export_helper.py
+1
-1
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+3
-2
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+27
-30
未找到文件。
onnx2fluid/examples/convert_data_npz.py
浏览文件 @
9d147284
...
...
@@ -17,8 +17,8 @@ def make_var_name(name):
make a valid variable name in Python code
"""
if
name
==
''
:
return
'_'
assert
name
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:-'
:
#
...
...
onnx2fluid/examples/convert_data_pb.py
浏览文件 @
9d147284
...
...
@@ -20,8 +20,8 @@ def make_var_name(name):
make a valid variable name in Python code
"""
if
name
==
''
:
return
'_'
assert
name
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:-'
:
#
...
...
onnx2fluid/examples/onnx_model_zoo.sh
浏览文件 @
9d147284
...
...
@@ -2,14 +2,17 @@
# setopt SH_WORD_SPLIT # if zsh
# alias python="python3" # if ...
# alias http_get="wget -c" # if no aria2
alias
http_get
=
"aria2c -c -s8 -x8"
base_url
=
"https://s3.amazonaws.com/download.onnx/models/opset_9/"
convert_cmd
=
"python -m onnx2fluid"
validate_cmd
=
"
$convert_cmd
.validation"
convert_flags
=
"-e -o /tmp/export/"
validate_flags1
=
"/tmp/export/model.py"
validate_flags2
=
"/tmp/export/__model__"
# alias http_get="wget -c" # if no aria2
alias
http_get
=
"aria2c -c -s8 -x8"
# alias python="python3" # if ...
validate_flags3
=
"/tmp/export/__model__ -i"
bvlc_alexnet
()
...
...
@@ -23,21 +26,23 @@ bvlc_alexnet()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
npz
in
"
$bn_tar
/"
*
.npz
do
echo
"converting
$npz
..."
python convert_data_npz.py
"
$npz
"
data_0 prob_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
$validate_cmd
$validate_flags1
-t
"
$npz
"
$validate_cmd
$validate_flags2
-t
"
$npz
"
done
$validate_cmd
$validate_flags3
-t
"
$npz
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -53,14 +58,15 @@ bvlc_googlenet()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -76,14 +82,15 @@ bvlc_reference_caffenet()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -99,14 +106,15 @@ bvlc_reference_rcnn_ilsvrc13()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 fc-rcnn_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -122,21 +130,23 @@ densenet121()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
npz
in
"
$bn_tar
/"
*
.npz
do
echo
"converting
$npz
..."
python convert_data_npz.py
"
$npz
"
data_0 fc6_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
$validate_cmd
$validate_flags1
-t
"
$npz
"
$validate_cmd
$validate_flags2
-t
"
$npz
"
done
$validate_cmd
$validate_flags3
-t
"
$npz
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 fc6_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -152,14 +162,15 @@ emotion_ferplus()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
Input3 Plus692_Output_0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -175,21 +186,23 @@ inception_v1()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
npz
in
"
$bn_tar
/"
*
.npz
do
echo
"converting
$npz
..."
python convert_data_npz.py
"
$npz
"
data_0 prob_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
$validate_cmd
$validate_flags1
-t
"
$npz
"
$validate_cmd
$validate_flags2
-t
"
$npz
"
done
$validate_cmd
$validate_flags3
-t
"
$npz
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -205,21 +218,23 @@ inception_v2()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
npz
in
"
$bn_tar
/"
*
.npz
do
echo
"converting
$npz
..."
python convert_data_npz.py
"
$npz
"
data_0 prob_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
$validate_cmd
$validate_flags1
-t
"
$npz
"
$validate_cmd
$validate_flags2
-t
"
$npz
"
done
$validate_cmd
$validate_flags3
-t
"
$npz
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -235,14 +250,15 @@ mobilenet()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data mobilenetv20_output_flatten0_reshape0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -258,14 +274,15 @@ resnet18()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data resnetv15_dense0_fwd
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -281,21 +298,23 @@ resnet50()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
npz
in
"
$bn_tar
/"
*
.npz
do
echo
"converting
$npz
..."
python convert_data_npz.py
"
$npz
"
gpu_0/data_0 gpu_0/softmaxout_1
-s
python
-m
onnx2fluid.validation
$validate_flags1
-t
"
$npz
"
python
-m
onnx2fluid.validation
$validate_flags2
-t
"
$npz
"
$validate_cmd
$validate_flags1
-t
"
$npz
"
$validate_cmd
$validate_flags2
-t
"
$npz
"
done
$validate_cmd
$validate_flags3
-t
"
$npz
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmaxout_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -311,14 +330,15 @@ resnet100_arcface()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data fc1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -334,14 +354,15 @@ resnet101_duc()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data seg_loss
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -357,14 +378,15 @@ resnet152()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data resnetv27_dense0_fwd
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -380,14 +402,15 @@ shufflenet()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmax_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -403,14 +426,15 @@ squeezenet()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 softmaxout_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -426,14 +450,15 @@ squeezenet1v1()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data squeezenet0_flatten0_reshape0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -450,14 +475,15 @@ ssd()
mkdir
"
$bn_tar
"
tar
xf
"
$fn_tar
"
-C
"
$bn_tar
/"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
image bboxes,labels,scores
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -473,14 +499,15 @@ tiny_yolov2()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
image grid
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -496,14 +523,15 @@ vgg16bn()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-y
$convert_cm
d
$convert_flags
"
$fn_model
"
-y
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
data vgg0_dense2_fwd
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -519,14 +547,15 @@ vgg19()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
data_0 prob_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -535,21 +564,22 @@ yolov3()
{
bn_tar
=
"yolov3"
fn_tar
=
"
$bn_tar
.tar.gz"
fn_model
=
"
$bn_tar
/
model
.onnx"
fn_model
=
"
$bn_tar
/
yolov3
.onnx"
http_get
"https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/
$fn_tar
"
rm
-rf
"
$bn_tar
/"
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
-x
#
$convert_cm
d
$convert_flags
"
$fn_model
"
-x
#
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
..."
python convert_data_pb.py
"
$pb_dir
"
input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
@@ -565,14 +595,15 @@ zfnet512()
echo
"extracting ..."
tar
xf
"
$fn_tar
"
python
-m
onnx2flui
d
$convert_flags
"
$fn_model
"
$convert_cm
d
$convert_flags
"
$fn_model
"
for
pb_dir
in
"
$bn_tar
/"
*
/
do
echo
"converting
$pb_dir
"
python convert_data_pb.py
"
$pb_dir
"
gpu_0/data_0 gpu_0/softmax_1
python
-m
onnx2fluid.validation
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2fluid.validation
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags1
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
$validate_cmd
$validate_flags2
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
$validate_cmd
$validate_flags3
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
rm
-rf
"
$bn_tar
/"
}
...
...
onnx2fluid/onnx2fluid/cmdline.py
浏览文件 @
9d147284
...
...
@@ -61,11 +61,14 @@ def main(**kwargs):
passed
=
True
golden_data_filename
=
kwargs
.
pop
(
'test_data'
,
''
)
infer_inputs
=
kwargs
.
pop
(
'infer_inputs'
,
None
)
if
golden_data_filename
or
infer_inputs
is
not
None
:
save_inference_model
=
infer_inputs
is
not
None
if
golden_data_filename
or
save_inference_model
:
from
.validation
import
validate
save_inference_model
=
infer_inputs
is
not
None
inference_input_names
=
infer_inputs
and
infer_inputs
.
split
(
','
)
if
save_inference_model
:
inference_input_names
=
infer_inputs
.
split
(
','
)
else
:
inference_input_names
=
None
logger
.
info
(
'starting validation on desc ...'
)
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
),
...
...
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
9d147284
...
...
@@ -23,7 +23,7 @@ def make_var_name(name):
"""
if
name
==
''
:
return
'
_
'
return
''
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:.-'
:
...
...
@@ -170,7 +170,7 @@ def convert(onnx_model_filename,
for
var_name
,
var_desc
in
fluid_program
.
var_descs
.
items
():
if
not
var_desc
.
type
.
lod_tensor
.
HasField
(
'tensor'
):
bad_vars
.
append
(
var_name
)
if
len
(
bad_vars
)
>
0
:
if
bad_vars
:
logger
.
warning
(
'type-shape not infered for var %s ...'
,
', '
.
join
(
bad_vars
[:
5
]))
logger
.
warning
(
'this causes little problem for PaddlePaddle, '
...
...
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
9d147284
...
...
@@ -99,6 +99,9 @@ def get_attribute_value2(attr):
elif
attr
.
type
==
onnx
.
AttributeProto
.
STRING
:
value
=
attr
.
s
value
=
value
.
decode
()
if
isinstance
(
value
,
bytes
)
else
value
elif
attr
.
type
==
onnx
.
AttributeProto
.
STRINGS
:
value
=
attr
.
strings
value
=
[
s
.
decode
()
if
isinstance
(
s
,
bytes
)
else
s
for
s
in
value
]
else
:
value
=
get_attribute_value
(
attr
)
return
value
...
...
@@ -161,12 +164,12 @@ def node_topo(nodes, topo='default'):
for
node_idx
,
degree
in
enumerate
(
node_in_degrees
):
if
degree
==
0
:
queue
.
append
(
node_idx
)
while
len
(
queue
)
>
0
:
while
queue
:
node_idx
=
queue
.
pop
(
0
)
node_topo
.
append
(
node_idx
)
for
val_name
in
nodes
[
node_idx
].
output
:
output_refs
[
val_name
].
remove
(
node_idx
)
if
len
(
output_refs
[
val_name
])
>
0
:
if
output_refs
[
val_name
]
:
continue
output_refs
.
pop
(
val_name
)
if
val_name
not
in
input_refs
:
...
...
@@ -186,12 +189,12 @@ def node_topo(nodes, topo='default'):
for
node_idx
,
degree
in
enumerate
(
node_out_degrees
):
if
degree
==
0
:
queue
.
append
(
node_idx
)
while
len
(
queue
)
>
0
:
while
queue
:
node_idx
=
queue
.
pop
(
0
)
node_topo
.
append
(
node_idx
)
for
val_name
in
nodes
[
node_idx
].
input
:
input_refs
[
val_name
].
remove
(
node_idx
)
if
len
(
input_refs
[
val_name
])
>
0
:
if
input_refs
[
val_name
]
:
continue
input_refs
.
pop
(
val_name
)
if
val_name
not
in
output_refs
:
...
...
@@ -210,7 +213,10 @@ def node_iter(nodes, indices=None):
generator for ONNX node graph with given indices
"""
for
index
in
indices
or
range
(
len
(
nodes
)):
if
indices
is
None
:
indices
=
range
(
len
(
nodes
))
for
index
in
indices
:
node
=
nodes
[
index
]
name
=
node
.
name
domain
=
node
.
domain
...
...
@@ -221,9 +227,11 @@ def node_iter(nodes, indices=None):
if
name
==
''
:
name
=
'op_'
+
str
(
index
)
else
:
# make_op_name
for
s
in
'
\\
|/:-'
:
#
name
=
name
.
replace
(
s
,
'_'
)
# else: # make_op_name
# for s in ' \\|/:-': #
# name = name.replace(s, '_')
if
domain
==
''
:
domain
=
DEFAULT_OP_DOMAIN
...
...
@@ -356,10 +364,11 @@ def polish_and_save(model_filename,
run polish_model and save
"""
if
save_filename
is
None
:
save_filename
=
model_filename
.
replace
(
'.onnx'
,
suffix
+
'.onnx'
)
model
=
onnx
.
load
(
model_filename
)
model
=
polish_model
(
model
,
*
args
,
**
kwargs
)
save_filename
=
save_filename
or
model_filename
.
replace
(
'.onnx'
,
suffix
+
'.onnx'
)
onnx
.
save
(
model
,
save_filename
)
logger
.
info
(
'polished model saved to: %s'
,
save_filename
)
return
save_filename
...
...
@@ -495,7 +504,7 @@ def optimize_model_cast(model):
for
node_idx
,
node
in
enumerate
(
nodes
):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
continue
if
no
t
(
node
.
op_type
==
'Cast'
)
:
if
no
de
.
op_type
!=
'Cast'
:
continue
attrs
=
node_attrs
(
node
)
output_dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
attrs
[
'to'
]]
...
...
@@ -551,7 +560,7 @@ def optimize_model_slice(model):
node
=
nodes
[
node_idx
]
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
return
chain
if
no
t
node
.
op_type
=
=
'Slice'
:
if
no
de
.
op_type
!
=
'Slice'
:
return
chain
chain
.
append
(
node_idx
)
output_name
=
node
.
output
[
0
]
...
...
@@ -585,10 +594,10 @@ def optimize_model_slice(model):
nodes_to_remove
=
[]
for
node_idx
in
range
(
len
(
nodes
)):
slice_chain
=
build_slice_node_chain
(
node_idx
)
if
len
(
slice_chain
)
==
0
:
if
not
slice_chain
:
continue
merged_slice
=
merge_slice
(
slice_chain
)
if
len
(
merged_slice
)
>
0
and
len
(
slice_chain
)
==
1
:
# no need to merge
if
merged_slice
and
len
(
slice_chain
)
==
1
:
# no need to merge
continue
attrs
=
{
'axes'
:
[],
'starts'
:
[],
'ends'
:
[]}
...
...
@@ -602,12 +611,11 @@ def optimize_model_slice(model):
output_name
=
last_node
.
output
[
0
]
processed
=
-
1
if
output_name
in
input_refs
:
# 0, [1...]
new_input_name
=
first_node
.
output
[
0
]
if
len
(
merged_slice
)
>
0
else
input_name
new_input_name
=
first_node
.
output
[
0
]
if
merged_slice
else
input_name
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
new_input_name
,
input_refs
)
if
processed
>
0
:
if
len
(
merged_slice
)
>
0
:
if
merged_slice
:
remain_idx
=
slice_chain
[
0
]
remove_chain
=
slice_chain
[
1
:]
slice_node
=
ret_nodes
[
remain_idx
]
...
...
@@ -621,12 +629,11 @@ def optimize_model_slice(model):
remove_chain
=
slice_chain
if
processed
<
0
and
input_name
in
output_refs
:
new_output_name
=
last_node
.
input
[
0
]
if
len
(
merged_slice
)
>
0
else
output_name
new_output_name
=
last_node
.
input
[
0
]
if
merged_slice
else
output_name
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
new_output_name
,
output_refs
)
if
processed
>
0
:
if
len
(
merged_slice
)
>
0
:
if
merged_slice
:
remain_idx
=
slice_chain
[
-
1
]
remove_chain
=
slice_chain
[:
-
1
]
slice_node
=
ret_nodes
[
remain_idx
]
...
...
@@ -641,7 +648,7 @@ def optimize_model_slice(model):
if
processed
>
0
:
nodes_to_remove
.
extend
(
remove_chain
)
if
len
(
merged_slice
)
==
0
:
if
not
merged_slice
:
logger
.
debug
(
'skip slice chain %s -> %s -> %s'
,
input_name
,
slice_chain
,
output_name
)
elif
processed
<
0
:
# NEVERFIX: not merge standalone slice chain
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
9d147284
此差异已折叠。
点击以展开。
onnx2fluid/onnx2fluid/torch_export_helper.py
浏览文件 @
9d147284
...
...
@@ -119,7 +119,7 @@ def export_onnx_with_validation(
return
list
(
map
(
tensors_to_arrays
,
tensors
))
def
zip_dict
(
keys
:
Union
[
Iterable
[
Any
],
None
],
keys
:
Optional
[
Iterable
[
Any
]
],
values
:
Sequence
[
Union
[
Any
,
Sequence
[
Any
]]],
)
->
MyDict
[
Text
,
Union
[
object
,
MyDict
[
Text
,
object
]]]:
keys
=
keys
or
range
(
len
(
values
))
...
...
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
9d147284
...
...
@@ -160,7 +160,8 @@ def validate(fluid_model_filename,
logger
.
info
(
'with %d inputs and %d outputs'
,
len
(
input_data
),
len
(
output_data
))
elif
save_inference_model
:
assert
inference_input_names
,
'input names required for type-shape inference'
assert
inference_input_names
is
not
None
,
(
'input names required for type-shape inference'
)
input_names
=
inference_input_names
logger
.
info
(
'using input names: %s'
,
', '
.
join
(
input_names
))
...
...
@@ -178,7 +179,7 @@ def validate(fluid_model_filename,
fluid
.
io
.
load_inference_model
(
fluid_model_dir
,
exe
)
logger
.
info
(
'model re-load passed'
)
if
not
golden_data_filename
:
if
golden_data_filename
==
''
:
return
True
# execute
...
...
onnx2fluid/onnx2fluid/writer.py
浏览文件 @
9d147284
...
...
@@ -133,7 +133,7 @@ class Program(object):
od_attr
.
type
=
framework_pb2
.
STRING
od_attr
.
s
=
value
elif
isinstance
(
value
,
list
):
if
len
(
value
)
>
0
:
# TODO: test all items
if
value
:
# TODO: test all items
if
isinstance
(
value
[
0
],
bool
):
# bool.mro() = [bool, int, object]
od_attr
.
type
=
framework_pb2
.
BOOLEANS
...
...
@@ -183,23 +183,16 @@ class Program(object):
if
self
.
code_mutable
:
self
.
codes
.
append
(
code
)
def
OpDesc
(
self
,
op_type
,
input_key_vals
=
None
,
output_key_vals
=
None
,
attrs
=
None
):
def
OpDesc
(
self
,
op_type
,
input_key_vals
,
output_key_vals
,
attrs
):
"""
add OpDesc
"""
desc
=
framework_pb2
.
OpDesc
()
desc
.
type
=
op_type
if
input_key_vals
:
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_key_vals
))
if
output_key_vals
:
desc
.
outputs
.
extend
(
self
.
OpDescVars
(
*
output_key_vals
))
if
attrs
:
desc
.
attrs
.
extend
(
self
.
OpDescAttrs
(
attrs
))
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_key_vals
))
desc
.
outputs
.
extend
(
self
.
OpDescVars
(
*
output_key_vals
))
desc
.
attrs
.
extend
(
self
.
OpDescAttrs
(
attrs
))
self
.
op_descs
.
append
(
desc
)
return
desc
...
...
@@ -212,7 +205,7 @@ class Program(object):
add VarDesc,
"""
assert
name
not
in
self
.
var_descs
,
'var nam
ing conflicted'
assert
name
not
in
self
.
var_descs
,
'var nam
e {} conflicts'
.
format
(
name
)
var_desc
=
framework_pb2
.
VarDesc
()
var_desc
.
name
=
name
...
...
@@ -220,10 +213,10 @@ class Program(object):
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
self
.
var_descs
[
name
]
=
var_desc
if
value_info
:
if
value_info
is
not
None
:
self
.
VarTypeShapeInfo
(
name
,
value_info
,
remove_batch
=
remove_batch
)
def
Op
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
def
Op
(
self
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
**
kwargs
):
"""
convert an ONNX op and add it to program
"""
...
...
@@ -232,15 +225,17 @@ class Program(object):
raise
ValueError
(
'only default domain supported'
)
if
op_type
in
symbolic
.
DEFAULT_OP_MAPPING
:
symbolic
.
_default
(
self
,
op_type
,
*
args
,
**
kwargs
)
symbolic
.
_default
(
self
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
**
kwargs
)
elif
hasattr
(
symbolic
,
op_type
):
fn
=
getattr
(
symbolic
,
op_type
)
fn
(
self
,
*
args
,
**
kwargs
)
fn
(
self
,
inputs
,
outputs
,
attrs
,
*
args
,
**
kwargs
)
else
:
raise
ValueError
(
'conversion for {}::{} not supported'
.
format
(
domain
,
op_type
))
def
IntermediateOp
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
def
IntermediateOp
(
self
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
**
kwargs
):
"""
convert an intermediate ONNX op declaring in desc program only
"""
...
...
@@ -248,7 +243,7 @@ class Program(object):
code_mutable
=
self
.
code_mutable
self
.
code_mutable
=
False
try
:
self
.
Op
(
domain
,
op_type
,
*
args
,
**
kwargs
)
self
.
Op
(
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
**
kwargs
)
except
BaseException
as
e
:
self
.
code_mutable
=
code_mutable
raise
e
...
...
@@ -272,14 +267,15 @@ class Program(object):
tensor_desc
.
data_type
=
self
.
Dtype
(
dtype
)
# required
shape
=
value_info
.
get
(
'shape'
,
None
)
if
shape
is
not
None
:
tensor_desc
.
dims
.
extend
(
shape
)
if
len
(
shape
)
>
0
:
# skip scalars
if
remove_batch
is
None
:
remove_batch
=
value_info
.
get
(
'remove_batch'
,
False
)
#not persistable)
if
remove_batch
:
tensor_desc
.
dims
[
0
]
=
-
1
if
not
shape
:
# None or scalars
return
tensor_desc
.
dims
.
extend
(
shape
)
if
remove_batch
is
None
:
remove_batch
=
value_info
.
get
(
'remove_batch'
,
False
)
#not persistable)
if
remove_batch
:
tensor_desc
.
dims
[
0
]
=
-
1
class
Writer
(
object
):
...
...
@@ -337,8 +333,8 @@ class Writer(object):
emit an ONNX weight into program
"""
if
value_info
.
get
(
'embedded_as'
,
[]):
embedded_names
=
value_info
[
'embedded_as'
]
embedded_names
=
value_info
.
get
(
'embedded_as'
,
[])
if
embedded_names
:
prog
.
Code
(
'# parameter {} embedded as {}'
.
format
(
name
,
embedded_names
))
for
embedded_name
in
embedded_names
:
...
...
@@ -431,7 +427,8 @@ class Writer(object):
assert
lod
is
None
or
isinstance
(
lod
,
list
),
'lod should be None or list'
lod
=
lod
or
[
0
]
if
lod
is
None
:
lod
=
[
0
]
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
.
data_type
=
Program
.
Dtype
(
weight
.
dtype
)
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录