Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
fcbdcb82
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
1 年多 前同步成功
通知
328
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
fcbdcb82
编写于
7月 08, 2019
作者:
M
Macrobull
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add ops and update readme
上级
2a82fdeb
变更
11
展开全部
隐藏空白更改
内联
并排
Showing
11 changed file
with
706 addition
and
352 deletion
+706
-352
onnx2fluid/README.md
onnx2fluid/README.md
+4
-2
onnx2fluid/README_en.md
onnx2fluid/README_en.md
+10
-6
onnx2fluid/examples/convert_data_npz_0.py
onnx2fluid/examples/convert_data_npz_0.py
+0
-48
onnx2fluid/examples/convert_data_pb_0.py
onnx2fluid/examples/convert_data_pb_0.py
+0
-64
onnx2fluid/examples/gen_some_samples.py
onnx2fluid/examples/gen_some_samples.py
+60
-12
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+1
-1
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+10
-6
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+11
-7
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+587
-185
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+1
-1
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+22
-20
未找到文件。
onnx2fluid/README.md
浏览文件 @
fcbdcb82
...
...
@@ -54,7 +54,7 @@ onnx2fluid sample_1.onnx -t sample_1.npz
onnx2fluid:
```
shell
onnx2fluid
[
-dexy
]
[
-o
/path/to/export_dir/]
[
-z
archive.zip]
[
-t
test_data.npz] /path/to/onnx/model.onnx
onnx2fluid
[
-dexy
]
[
-o
/path/to/export_dir/]
[
-z
archive.zip]
[
-t
test_data.npz]
[
-i
[
input_name1,input_name2]]
/path/to/onnx/model.onnx
optional arguments:
--debug
,
-d
启用调试
...
...
@@ -65,6 +65,8 @@ optional arguments:
--output_dir
,
-o
指定输出目录
--archive
[
ARCHIVE],
-z
[
ARCHIVE]
如果验证通过,打包到指定的ZIP文件
--infer_inputs
,
-i
[
input_name1,input_name2]
调用PaddlePaddle fluid类形推导完善模型
```
转换工具onnx2fluid.conversion:
...
...
@@ -76,7 +78,7 @@ onnx2fluid.conversion [-dexy] [-o /path/to/export_dir/] /path/to/onnx/model.onnx
验证工具onnx2fluid.validate:
```
shell
onnx2fluid.validate
[
-d
]
[
-t
test_data.npz]
[
-p
1e-3] /path/to/onnx/model.onnx
onnx2fluid.validate
[
-d
]
[
-t
test_data.npz]
[
-
i
[
input_name1,input_name2]]
[
-
p
1e-3] /path/to/onnx/model.onnx
```
## 参考
...
...
onnx2fluid/README_en.md
浏览文件 @
fcbdcb82
...
...
@@ -19,8 +19,8 @@ PyTorch to Paddlepaddle model conversion can be easily achieved with PyTorch ONN
## Environment and dependency
*
python 3.5+ (python 2 not fully supported yet)
*
onnx
== 1.4.0
*
paddlepaddle
=
= 1.3.0 (optional for validation)
*
onnx
>= 1.4
*
paddlepaddle
>
= 1.3.0 (optional for validation)
## Get started
...
...
@@ -47,10 +47,12 @@ onnx2fluid sample_unet.onnx -t sample_unet.npz
## Usage
**ONNX opset 9+**
is mainly supported, corresponded to PyTorch
**1.0/1.1(stable opset)**
,for more information:
[
ONNX doc
](
https://github.com/onnx/onnx/blob/master/docs/Operators.md
)
onnx2fluid (all in one):
```
shell
onnx2fluid
[
-dexy
]
[
-o
/path/to/export_dir/]
[
-z
archive.zip]
[
-t
test_data.npz] /path/to/onnx/model.onnx
onnx2fluid
[
-dexy
]
[
-o
/path/to/export_dir/]
[
-z
archive.zip]
[
-t
test_data.npz]
[
-i
[
input_name1,input_name2]]
/path/to/onnx/model.onnx
optional arguments:
--debug
,
-d
enable
debug logging and checking
...
...
@@ -61,6 +63,8 @@ optional arguments:
--output_dir
,
-o
output directory
--archive
[
ARCHIVE],
-z
[
ARCHIVE]
compress outputs to ZIP file
if
conversion successed
--infer_inputs
,
-i
[
input_name1,input_name2]
invoke PaddlePaddle fluid type-shape inference
```
onnx2fluid.conversion:
...
...
@@ -72,10 +76,10 @@ onnx2fluid.conversion [-dexy] [-o /path/to/export_dir/] /path/to/onnx/model.onnx
onnx2fluid.validate:
```
shell
onnx2fluid.validate
[
-d
]
[
-t
test_data.npz]
[
-p
1e-3] /path/to/onnx/model.onnx
onnx2fluid.validate
[
-d
]
[
-t
test_data.npz]
[
-
i
[
input_name1,input_name2]]
[
-
p
1e-3] /path/to/onnx/model.onnx
```
## Reference
*
[
PaddlePaddle fluid operators
](
http://www.paddlepaddle.org/documentation/docs/en/1.
4
/api/layers.html
)
*
load converted model via
[
load_inference_model
](
http://www.paddlepaddle.org/documentation/docs/en/1.
4
/api/io.html#permalink-1-load_inference_model
)
*
[
PaddlePaddle fluid operators
](
http://www.paddlepaddle.org/documentation/docs/en/1.
5
/api/layers.html
)
*
load converted model via
[
load_inference_model
](
http://www.paddlepaddle.org/documentation/docs/en/1.
5
/api/io.html#permalink-1-load_inference_model
)
onnx2fluid/examples/convert_data_npz_0.py
已删除
100644 → 0
浏览文件 @
2a82fdeb
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019
@author: Macrobull
"""
import
sys
import
numpy
as
np
from
collections
import
OrderedDict
as
Dict
def
_make_var_name
(
name
):
"""
make a valid variable name in Python code
"""
if
name
==
''
:
return
'_'
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
' *?
\\
/-:'
:
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
return
name
fn
=
sys
.
argv
[
1
]
input_names
=
sys
.
argv
[
2
].
split
(
':'
)
output_name
=
sys
.
argv
[
3
].
split
(
':'
)
squeeze_data
=
len
(
sys
.
argv
)
>
4
data
=
np
.
load
(
fn
,
encoding
=
'bytes'
)
input_data
=
data
[
'inputs'
]
output_data
=
data
[
'outputs'
]
while
squeeze_data
and
input_data
.
ndim
>
4
and
input_data
.
shape
[
0
]
==
1
:
input_data
=
input_data
.
squeeze
(
0
)
while
squeeze_data
and
output_data
.
ndim
>
2
and
output_data
.
shape
[
0
]
==
1
:
output_data
=
output_data
.
squeeze
(
0
)
inputs
=
Dict
(
zip
(
map
(
_make_var_name
,
input_names
),
[
input_data
]))
outputs
=
Dict
(
zip
(
map
(
_make_var_name
,
output_name
),
[
output_data
]))
np
.
savez
(
fn
,
inputs
=
inputs
,
outputs
=
outputs
)
# overwrite
onnx2fluid/examples/convert_data_pb_0.py
已删除
100644 → 0
浏览文件 @
2a82fdeb
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Wed Mar 27 11:50:03 2019
@author: Macrobull
"""
import
os
,
sys
import
numpy
as
np
import
onnx
import
onnx.numpy_helper
as
numpy_helper
from
collections
import
OrderedDict
as
Dict
from
glob
import
glob
def
_make_var_name
(
name
):
"""
make a valid variable name in Python code
"""
if
name
==
''
:
return
'_'
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
' *?
\\
/-:'
:
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
return
name
data_dir
=
os
.
path
.
dirname
(
sys
.
argv
[
1
])
input_names
=
sys
.
argv
[
2
].
split
(
':'
)
output_name
=
sys
.
argv
[
3
].
split
(
':'
)
squeeze_data
=
len
(
sys
.
argv
)
>
4
# Load inputs
inputs
=
[]
for
fn
in
glob
(
os
.
path
.
join
(
data_dir
,
'input_*.pb'
)):
tensor
=
onnx
.
TensorProto
()
with
open
(
fn
,
'rb'
)
as
f
:
tensor
.
ParseFromString
(
f
.
read
())
tensor
=
numpy_helper
.
to_array
(
tensor
)
while
squeeze_data
and
tensor
.
ndim
>
4
and
tensor
.
shape
[
0
]
==
1
:
tensor
=
tensor
.
squeeze
(
0
)
inputs
.
append
(
tensor
)
# Load outputs
outputs
=
[]
for
fn
in
glob
(
os
.
path
.
join
(
data_dir
,
'output_*.pb'
)):
tensor
=
onnx
.
TensorProto
()
with
open
(
fn
,
'rb'
)
as
f
:
tensor
.
ParseFromString
(
f
.
read
())
tensor
=
numpy_helper
.
to_array
(
tensor
)
while
squeeze_data
and
tensor
.
ndim
>
2
and
tensor
.
shape
[
0
]
==
1
:
tensor
=
tensor
.
squeeze
(
0
)
outputs
.
append
(
tensor
)
inputs
=
Dict
(
zip
(
map
(
_make_var_name
,
input_names
),
inputs
))
outputs
=
Dict
(
zip
(
map
(
_make_var_name
,
output_name
),
outputs
))
np
.
savez
(
data_dir
,
inputs
=
inputs
,
outputs
=
outputs
)
onnx2fluid/examples/gen_some_samples.py
浏览文件 @
fcbdcb82
...
...
@@ -20,34 +20,74 @@ from onnx2fluid.torch_export_helper import export_onnx_with_validation
prefix
=
'sample_'
idx
=
0
######## example: RNN cell ########
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
self
.
gru
=
nn
.
GRUCell
(
6
,
5
)
self
.
lstm
=
nn
.
LSTMCell
(
5
,
4
)
def
forward
(
self
,
x
,
h1
,
h2
,
c2
):
h
=
self
.
gru
(
x
,
h1
)
h
,
c
=
self
.
lstm
(
h
,
(
h2
,
c2
))
return
h
,
c
model
=
Model
()
model
.
eval
()
xb
=
torch
.
rand
((
7
,
6
))
h1
=
torch
.
zeros
((
7
,
5
))
h2
=
torch
.
zeros
((
7
,
4
))
c2
=
torch
.
zeros
((
7
,
4
))
yp
=
model
(
xb
,
h1
,
h2
,
c2
)
idx
+=
1
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
[
xb
,
h1
,
h2
,
c2
],
prefix
+
str
(
idx
),
[
'x'
,
'h1'
,
'h2'
,
'c2'
],
[
'h'
,
'c'
],
verbose
=
True
,
training
=
False
)
######## example: RNN ########
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
self
.
gru
=
nn
.
GRU
(
4
,
5
,
3
)
self
.
lstm
=
nn
.
LSTM
(
5
,
6
,
2
)
self
.
gru
=
nn
.
GRU
(
6
,
5
,
3
)
self
.
lstm
=
nn
.
LSTM
(
5
,
4
,
2
)
def
forward
(
self
,
x
):
y
=
x
y
,
h
=
self
.
gru
(
y
)
y
,
h
=
self
.
lstm
(
y
)
def
forward
(
self
,
x
,
h1
,
h2
,
c2
):
y
,
h1
=
self
.
gru
(
x
,
h1
)
y
,
(
h2
,
c2
)
=
self
.
lstm
(
y
,
(
h2
,
c2
))
return
y
model
=
Model
()
model
.
eval
()
xb
=
torch
.
rand
((
2
,
3
,
4
))
yp
=
model
(
xb
)
xb
=
torch
.
rand
((
8
,
1
,
6
))
h1
=
torch
.
zeros
((
3
,
1
,
5
))
h2
=
torch
.
zeros
((
2
,
1
,
4
))
c2
=
torch
.
zeros
((
2
,
1
,
4
))
yp
=
model
(
xb
,
h1
,
h2
,
c2
)
idx
+=
1
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
[
xb
],
prefix
+
str
(
idx
),
[
'x'
],
[
'y'
],
export_onnx_with_validation
(
model
,
[
xb
,
h1
,
h2
,
c2
],
prefix
+
str
(
idx
),
[
'x'
,
'h1'
,
'h2'
,
'c2'
],
[
'y'
],
verbose
=
True
,
training
=
False
)
######## example: random ########
"""
symbolic registration:
def rand(g, *shapes):
shapes_list = list(shapes)
shape = _maybe_get_const(shapes_list[0], "is")
return g.op('RandomUniform', shape_i=shape)
"""
class
Model
(
nn
.
Module
):
...
...
@@ -55,8 +95,9 @@ class Model(nn.Module):
super
(
Model
,
self
).
__init__
()
def
forward
(
self
,
x
):
y
=
torch
.
rand
((
2
,
3
))
# + torch.rand_like(xb)
y
=
y
+
torch
.
randn
((
2
,
3
))
# + torch.randn_like(xb)
y
=
torch
.
rand
((
2
,
3
))
# + torch.rand_like(x)
y
=
y
+
torch
.
randn
((
2
,
3
))
# + torch.randn_like(x)
y
=
y
+
x
return
y
...
...
@@ -124,6 +165,13 @@ export_onnx_with_validation(model, [xb0, xb1],
training
=
False
)
######## example: affine_grid ########
"""
symbolic registration:
@parse_args('v', 'is')
def affine_grid_generator(g, theta, size):
return g.op('AffineGrid', theta, size_i=size)
"""
class
Model
(
nn
.
Module
):
...
...
onnx2fluid/onnx2fluid/cmdline.py
浏览文件 @
fcbdcb82
...
...
@@ -61,7 +61,7 @@ def main(**kwargs):
passed
=
True
golden_data_filename
=
kwargs
.
pop
(
'test_data'
,
''
)
infer_inputs
=
kwargs
.
pop
(
'infer_inputs'
,
None
)
if
golden_data_filename
or
infer_inputs
:
if
golden_data_filename
or
infer_inputs
is
not
None
:
from
.validation
import
validate
save_inference_model
=
infer_inputs
is
not
None
...
...
onnx2fluid/onnx2fluid/conversion.py
浏览文件 @
fcbdcb82
...
...
@@ -91,7 +91,7 @@ def convert(onnx_model_filename,
# onnx model optimization
logger
.
info
(
'model has %d ops'
,
len
(
onnx_model
.
graph
.
node
))
logger
.
info
(
'optimizing model ...'
)
onnx_model
=
polish_model
(
onnx_model
)
onnx_model
=
polish_model
(
onnx_model
,
checking
=
onnx_opset_pedantic
)
# prepare filesystem
shutil
.
rmtree
(
save_dir
,
ignore_errors
=
True
)
...
...
@@ -123,6 +123,7 @@ def convert(onnx_model_filename,
for
name
,
weight
in
graph_weights
(
onnx_graph
):
var_name
=
make_var_name
(
name
)
value_info
=
value_infos
[
var_name
]
value_info
[
'lod'
]
=
[
0
]
value_info
[
'embedded_as'
]
=
[]
value_info
[
'get_weight'
]
=
(
lambda
w
:
lambda
:
w
.
tolist
())(
weight
)
# lazy getter
...
...
@@ -134,8 +135,8 @@ def convert(onnx_model_filename,
for
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
in
graph_ops
(
onnx_graph
,
topo
=
topo
):
op_name
=
make_var_name
(
name
)
inputs
=
[
make_var_name
(
val
)
for
val
in
inputs
]
outputs
=
[
make_var_name
(
val
)
for
val
in
outputs
]
inputs
=
list
(
map
(
make_var_name
,
inputs
))
outputs
=
list
(
map
(
make_var_name
,
outputs
))
logger
.
debug
(
'translating op %s(%s) %s::%s ...'
,
name
,
op_name
,
domain
,
op_type
)
if
domain
==
DEFAULT_OP_DOMAIN
:
...
...
@@ -192,13 +193,16 @@ def convert(onnx_model_filename,
weight
.
dtype
,
weight
.
size
,
weight
.
nbytes
,
embedded_names
)
for
embedded_name
in
embedded_names
:
# multiple references
fluid_writer
.
write_weight
(
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
embedded_name
))
fluid_writer
.
write_weight
(
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
embedded_name
),
lod
=
value_info
[
'lod'
])
else
:
logger
.
debug
(
'saving weight %s(%s[%d], %dB) to %s ...'
,
name
,
weight
.
dtype
,
weight
.
size
,
weight
.
nbytes
,
var_name
)
fluid_writer
.
write_weight
(
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
var_name
))
shutil
.
os
.
path
.
join
(
save_dir
,
var_name
),
lod
=
value_info
[
'lod'
])
fluid_writer
.
emit_param
(
fluid_program
,
var_name
,
value_info
)
param_codes
=
fluid_program
.
codes
fluid_program
.
codes
=
[]
...
...
onnx2fluid/onnx2fluid/onnx_utils.py
浏览文件 @
fcbdcb82
...
...
@@ -319,17 +319,20 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
return
processed
def
polish_model
(
model
,
extras
=
True
):
def
polish_model
(
model
,
internals
=
True
,
extras
=
True
,
checking
=
True
):
"""
polish_model enhanced for inference
"""
check_model
(
model
)
if
checking
:
check_model
(
model
)
strip_doc_string
(
model
)
passes
=
optimizer
.
get_available_passes
()
passes
=
list
(
filter
(
lambda
name
:
not
name
.
startswith
(
'split_'
),
passes
))
#
logger
.
debug
(
'builtin optimizations to perform in ONNX:
\n\t
%s'
,
passes
)
model
=
optimizer
.
optimize
(
model
,
passes
=
passes
)
if
internals
:
passes
=
optimizer
.
get_available_passes
()
passes
=
list
(
filter
(
lambda
name
:
not
name
.
startswith
(
'split_'
),
passes
))
#
logger
.
debug
(
'builtin optimizations to perform in ONNX:
\n\t
%s'
,
passes
)
model
=
optimizer
.
optimize
(
model
,
passes
=
passes
)
if
extras
:
for
optimize
in
(
optimize_model_skip_op_for_inference
,
...
...
@@ -339,7 +342,8 @@ def polish_model(model, extras=True):
):
model
=
optimize
(
model
)
model
=
infer_shapes
(
model
)
check_model
(
model
)
if
checking
:
check_model
(
model
)
return
model
...
...
onnx2fluid/onnx2fluid/symbolic.py
浏览文件 @
fcbdcb82
此差异已折叠。
点击以展开。
onnx2fluid/onnx2fluid/validation.py
浏览文件 @
fcbdcb82
...
...
@@ -159,7 +159,7 @@ def validate(fluid_model_filename,
# output_names = output_data.keys()
logger
.
info
(
'with %d inputs and %d outputs'
,
len
(
input_data
),
len
(
output_data
))
el
se
:
el
if
save_inference_model
:
assert
inference_input_names
,
'input names required for type-shape inference'
input_names
=
inference_input_names
...
...
onnx2fluid/onnx2fluid/writer.py
浏览文件 @
fcbdcb82
...
...
@@ -96,7 +96,7 @@ class Program(object):
return
Program
.
DTYPE_TO_FRAMEWORK_DTYPE
[
dtype
]
@
staticmethod
def
OpDescVars
(
vals
,
*
key
s
):
def
OpDescVars
(
keys
,
val
s
):
"""
make (OpDesc.Var)s
"""
...
...
@@ -150,13 +150,11 @@ class Program(object):
else
:
raise
ValueError
(
'unsupported attribute {} = {}'
.
format
(
key
,
value
))
else
:
# WORKAROUND: shape of scalars is []
raise
ValueError
(
'unsupported attribute {} = {}'
.
format
(
key
,
value
))
# od_attr.type = framework_pb2.INTS
# logger.warning('using attribute %s = %s as INTS', key, value)
else
:
# WORKAROUND: [] not inferred
# raise ValueError('unsupported attribute {} = {}'.format(key, value))
od_attr
.
type
=
framework_pb2
.
INTS
logger
.
warning
(
'using attribute %s = %s as INTS'
,
key
,
value
)
else
:
raise
ValueError
(
'unsupported attribute {} = {}'
.
format
(
key
,
value
))
...
...
@@ -187,8 +185,8 @@ class Program(object):
def
OpDesc
(
self
,
op_type
,
input_
val_key
s
=
None
,
output_
val_key
s
=
None
,
input_
key_val
s
=
None
,
output_
key_val
s
=
None
,
attrs
=
None
):
"""
add OpDesc
...
...
@@ -196,10 +194,10 @@ class Program(object):
desc
=
framework_pb2
.
OpDesc
()
desc
.
type
=
op_type
if
input_
val_key
s
:
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_
val_key
s
))
if
output_
val_key
s
:
desc
.
outputs
.
extend
(
self
.
OpDescVars
(
*
output_
val_key
s
))
if
input_
key_val
s
:
desc
.
inputs
.
extend
(
self
.
OpDescVars
(
*
input_
key_val
s
))
if
output_
key_val
s
:
desc
.
outputs
.
extend
(
self
.
OpDescVars
(
*
output_
key_val
s
))
if
attrs
:
desc
.
attrs
.
extend
(
self
.
OpDescAttrs
(
attrs
))
self
.
op_descs
.
append
(
desc
)
...
...
@@ -388,8 +386,8 @@ class Writer(object):
))
prog
.
OpDesc
(
'feed'
,
([
'
feed'
],
'X'
),
([
name
],
'Out'
),
([
'
X'
],
[
'feed'
]
),
([
'Out'
],
[
name
]
),
{
'col'
:
idx
},
)
prog
.
VarDesc
(
name
,
value_info
=
value_info
,
remove_batch
=
remove_batch
)
...
...
@@ -406,8 +404,8 @@ class Writer(object):
prog
.
OpDesc
(
'fetch'
,
([
name
],
'X'
),
([
'
fetch'
],
'Out'
),
([
'X'
],
[
name
]
),
([
'
Out'
],
[
'fetch'
]
),
{
'col'
:
idx
},
)
# var is emitted over ops
...
...
@@ -424,12 +422,16 @@ class Writer(object):
return
codes
@
staticmethod
def
write_weight
(
weight
,
filename
):
def
write_weight
(
weight
,
filename
,
lod
=
None
):
"""
write single weight in fluid desc
"""
assert
isinstance
(
weight
,
np
.
ndarray
),
'weight is not an ndarray'
assert
lod
is
None
or
isinstance
(
lod
,
list
),
'lod should be None or list'
lod
=
lod
or
[
0
]
tensor_desc
=
framework_pb2
.
VarType
.
TensorDesc
()
tensor_desc
.
data_type
=
Program
.
Dtype
(
weight
.
dtype
)
...
...
@@ -437,7 +439,7 @@ class Writer(object):
fp
=
open
(
filename
,
'wb'
)
np
.
array
([
0
],
dtype
=
np
.
int32
).
tofile
(
fp
)
# version
np
.
array
(
[
0
]
,
dtype
=
np
.
int64
).
tofile
(
fp
)
# LOD level
np
.
array
(
lod
,
dtype
=
np
.
int64
).
tofile
(
fp
)
# LOD level
np
.
array
([
0
],
dtype
=
np
.
int32
).
tofile
(
fp
)
# tensor version
np
.
array
([
tensor_desc
.
ByteSize
()],
dtype
=
np
.
int32
).
tofile
(
fp
)
fp
.
write
(
tensor_desc
.
SerializeToString
())
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录