Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
X2Paddle
提交
2228423e
X
X2Paddle
项目概览
PaddlePaddle
/
X2Paddle
大约 2 年 前同步成功
通知
329
Star
698
Fork
167
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
26
列表
看板
标记
里程碑
合并请求
4
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
X
X2Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
26
Issue
26
列表
看板
标记
里程碑
合并请求
4
合并请求
4
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2228423e
编写于
3月 28, 2019
作者:
J
Jason
提交者:
GitHub
3月 28, 2019
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #16 from MacroBull/master
Rename onnx2paddle to onnx2fluid
上级
a4796334
f0dede1f
变更
20
显示空白变更内容
内联
并排
Showing
20 changed file
with
4323 addition
and
1582 deletion
+4323
-1582
onnx2fluid/.gitignore
onnx2fluid/.gitignore
+1
-0
onnx2fluid/README.md
onnx2fluid/README.md
+7
-0
onnx2fluid/examples/convert_data_npz_0.py
onnx2fluid/examples/convert_data_npz_0.py
+1
-1
onnx2fluid/examples/convert_data_pb_0.py
onnx2fluid/examples/convert_data_pb_0.py
+0
-0
onnx2fluid/examples/gen_some_samples.py
onnx2fluid/examples/gen_some_samples.py
+26
-27
onnx2fluid/examples/onnx_model_zoo.sh
onnx2fluid/examples/onnx_model_zoo.sh
+35
-35
onnx2fluid/onnx2fluid/__init__.py
onnx2fluid/onnx2fluid/__init__.py
+0
-0
onnx2fluid/onnx2fluid/__main__.py
onnx2fluid/onnx2fluid/__main__.py
+93
-0
onnx2fluid/onnx2fluid/cmdline.py
onnx2fluid/onnx2fluid/cmdline.py
+42
-37
onnx2fluid/onnx2fluid/conversion.py
onnx2fluid/onnx2fluid/conversion.py
+134
-71
onnx2fluid/onnx2fluid/framework_pb2.py
onnx2fluid/onnx2fluid/framework_pb2.py
+1634
-0
onnx2fluid/onnx2fluid/onnx_utils.py
onnx2fluid/onnx2fluid/onnx_utils.py
+107
-86
onnx2fluid/onnx2fluid/symbolic.py
onnx2fluid/onnx2fluid/symbolic.py
+2041
-0
onnx2fluid/onnx2fluid/torch_export_helper.py
onnx2fluid/onnx2fluid/torch_export_helper.py
+19
-13
onnx2fluid/onnx2fluid/validation.py
onnx2fluid/onnx2fluid/validation.py
+72
-59
onnx2fluid/onnx2fluid/writer.py
onnx2fluid/onnx2fluid/writer.py
+102
-78
onnx2fluid/requirements.txt
onnx2fluid/requirements.txt
+1
-1
onnx2fluid/setup.cfg
onnx2fluid/setup.cfg
+8
-8
onnx2fluid/setup.py
onnx2fluid/setup.py
+0
-1
onnx2paddle/onnx2paddle/framework_pb2.py
onnx2paddle/onnx2paddle/framework_pb2.py
+0
-1165
未找到文件。
onnx2
paddle
/.gitignore
→
onnx2
fluid
/.gitignore
浏览文件 @
2228423e
...
@@ -57,3 +57,4 @@ coverage.xml
...
@@ -57,3 +57,4 @@ coverage.xml
/examples/*.aria2
/examples/*.aria2
/examples/*.onnx
/examples/*.onnx
/examples/*.np?
/examples/*.np?
**/.*
onnx2
paddle
/README.md
→
onnx2
fluid
/README.md
浏览文件 @
2228423e
Onnx2
paddle
Onnx2
Fluid
===
===
Inference model conversion from ONNX/PyTorch to Paddle
Inference model conversion from ONNX/PyTorch to Paddle
fluid
快速开始
快速开始
---
---
...
...
onnx2
paddle
/examples/convert_data_npz_0.py
→
onnx2
fluid
/examples/convert_data_npz_0.py
浏览文件 @
2228423e
onnx2
paddle
/examples/convert_data_pb_0.py
→
onnx2
fluid
/examples/convert_data_pb_0.py
浏览文件 @
2228423e
文件已移动
onnx2
paddle
/examples/gen_some_samples.py
→
onnx2
fluid
/examples/gen_some_samples.py
浏览文件 @
2228423e
...
@@ -6,7 +6,7 @@ Created on Fri Mar 22 11:19:45 2019
...
@@ -6,7 +6,7 @@ Created on Fri Mar 22 11:19:45 2019
@author: Macrobull
@author: Macrobull
Not all ops in this file are supported by both Pytorch and ONNX
Not all ops in this file are supported by both Pytorch and ONNX
This only demostrates the conversion/validation workflow from Pytorch to ONNX to Paddle
This only demostrates the conversion/validation workflow from Pytorch to ONNX to Paddle
fluid
"""
"""
...
@@ -16,12 +16,10 @@ import torch
...
@@ -16,12 +16,10 @@ import torch
import
torch.nn
as
nn
import
torch.nn
as
nn
import
torch.nn.functional
as
F
import
torch.nn.functional
as
F
from
onnx2paddle.torch_export_helper
import
export_onnx_with_validation
from
onnx2fluid.torch_export_helper
import
export_onnx_with_validation
idx
=
0
idx
=
0
######### example: RNN ########
######### example: RNN ########
#
#
#class Model(nn.Module):
#class Model(nn.Module):
...
@@ -44,7 +42,6 @@ idx = 0
...
@@ -44,7 +42,6 @@ idx = 0
# ['x'], ['y'],
# ['x'], ['y'],
# verbose=True, training=False)
# verbose=True, training=False)
######### example: random ########
######### example: random ########
#
#
#class Model(nn.Module):
#class Model(nn.Module):
...
@@ -66,9 +63,9 @@ idx = 0
...
@@ -66,9 +63,9 @@ idx = 0
# ['x'], ['y'],
# ['x'], ['y'],
# verbose=True, training=False)
# verbose=True, training=False)
######## example: fc ########
######## example: fc ########
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
...
@@ -85,13 +82,12 @@ xb = torch.rand((2, 3))
...
@@ -85,13 +82,12 @@ xb = torch.rand((2, 3))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
(
xb
,
),
't'
+
str
(
idx
),
export_onnx_with_validation
(
[
'x'
],
[
'y'
],
model
,
(
xb
,
),
't'
+
str
(
idx
),
[
'x'
],
[
'y'
],
verbose
=
True
,
training
=
False
)
verbose
=
True
,
training
=
False
)
######## example: compare ########
######## example: compare ########
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
...
@@ -110,12 +106,15 @@ xb1 = torch.rand((2, 3))
...
@@ -110,12 +106,15 @@ xb1 = torch.rand((2, 3))
ya
,
yb
,
yc
=
model
(
xb0
,
xb1
)
ya
,
yb
,
yc
=
model
(
xb0
,
xb1
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
(
xb0
,
xb1
),
't'
+
str
(
idx
),
export_onnx_with_validation
(
[
'x0'
,
'x1'
],
[
'ya'
,
'yb'
,
'yc'
],
model
,
(
xb0
,
xb1
),
verbose
=
True
,
training
=
False
)
't'
+
str
(
idx
),
[
'x0'
,
'x1'
],
[
'ya'
,
'yb'
,
'yc'
],
verbose
=
True
,
training
=
False
)
######## example: affine_grid ########
######## example: affine_grid ########
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
...
@@ -130,13 +129,15 @@ theta = torch.rand((2, 2, 3))
...
@@ -130,13 +129,15 @@ theta = torch.rand((2, 2, 3))
grid
=
model
(
theta
)
grid
=
model
(
theta
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
(
theta
,
),
't'
+
str
(
idx
),
export_onnx_with_validation
(
[
'theta'
],
[
'grid'
],
model
,
(
theta
,
),
verbose
=
True
,
training
=
False
)
't'
+
str
(
idx
),
[
'theta'
],
[
'grid'
],
verbose
=
True
,
training
=
False
)
######## example: conv2d_transpose ########
######## example: conv2d_transpose ########
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
...
@@ -155,12 +156,12 @@ xb = torch.rand((2, 3, 4, 5))
...
@@ -155,12 +156,12 @@ xb = torch.rand((2, 3, 4, 5))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
(
xb
,
),
't'
+
str
(
idx
),
export_onnx_with_validation
(
[
'x'
],
[
'y'
],
model
,
(
xb
,
),
't'
+
str
(
idx
),
[
'x'
],
[
'y'
],
verbose
=
True
,
training
=
False
)
verbose
=
True
,
training
=
False
)
######## example: conv2d ########
######## example: conv2d ########
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
...
@@ -181,10 +182,8 @@ xb = torch.rand((2, 3, 4, 5))
...
@@ -181,10 +182,8 @@ xb = torch.rand((2, 3, 4, 5))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
(
xb
,
),
't'
+
str
(
idx
),
export_onnx_with_validation
(
[
'x'
],
[
'y'
],
model
,
(
xb
,
),
't'
+
str
(
idx
),
[
'x'
],
[
'y'
],
verbose
=
True
,
training
=
False
)
verbose
=
True
,
training
=
False
)
######### example: conv1d ########
######### example: conv1d ########
#
#
...
@@ -210,6 +209,7 @@ export_onnx_with_validation(model, (xb, ), 't' + str(idx),
...
@@ -210,6 +209,7 @@ export_onnx_with_validation(model, (xb, ), 't' + str(idx),
######## example: empty ########
######## example: empty ########
class
Model
(
nn
.
Module
):
class
Model
(
nn
.
Module
):
def
__init__
(
self
):
def
__init__
(
self
):
super
(
Model
,
self
).
__init__
()
super
(
Model
,
self
).
__init__
()
...
@@ -223,6 +223,5 @@ xb = torch.rand((2, 3))
...
@@ -223,6 +223,5 @@ xb = torch.rand((2, 3))
yp
=
model
(
xb
)
yp
=
model
(
xb
)
idx
+=
1
idx
+=
1
print
(
'index: '
,
idx
)
print
(
'index: '
,
idx
)
export_onnx_with_validation
(
model
,
(
xb
,
),
't'
+
str
(
idx
),
export_onnx_with_validation
(
[
'y'
],
[
'y'
],
model
,
(
xb
,
),
't'
+
str
(
idx
),
[
'y'
],
[
'y'
],
verbose
=
True
,
training
=
False
)
verbose
=
True
,
training
=
False
)
onnx2
paddle
/examples/onnx_model_zoo.sh
→
onnx2
fluid
/examples/onnx_model_zoo.sh
浏览文件 @
2228423e
#! /usr/bin/env sh
#! /usr/bin/env sh
get_url
=
"
proxychains4
aria2c -c -s8 -x8"
get_url
=
"aria2c -c -s8 -x8"
base_url
=
"https://s3.amazonaws.com/download.onnx/models/opset_9/"
base_url
=
"https://s3.amazonaws.com/download.onnx/models/opset_9/"
flags
=
"-
d
e -o /tmp/export/"
flags
=
"-e -o /tmp/export/"
bvlc_alexnet
()
bvlc_alexnet
()
{
{
...
@@ -18,13 +18,13 @@ bvlc_alexnet()
...
@@ -18,13 +18,13 @@ bvlc_alexnet()
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0.py
"
$npz
"
"data_0"
"prob_1"
python convert_data_npz_0.py
"
$npz
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$npz
done
done
for
pb_dir
in
$bn_tar
/
*
/
for
pb_dir
in
$bn_tar
/
*
/
do
do
echo
"converting
$pb_dir
..."
echo
"converting
$pb_dir
..."
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
echo
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -42,7 +42,7 @@ bvlc_googlenet()
...
@@ -42,7 +42,7 @@ bvlc_googlenet()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -60,7 +60,7 @@ bvlc_reference_caffenet()
...
@@ -60,7 +60,7 @@ bvlc_reference_caffenet()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -77,8 +77,8 @@ bvlc_reference_rcnn_ilsvrc13()
...
@@ -77,8 +77,8 @@ bvlc_reference_rcnn_ilsvrc13()
for
pb_dir
in
$bn_tar
/
*
/
for
pb_dir
in
$bn_tar
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"
softmaxout
_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"
fc_rcnn
_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -96,14 +96,14 @@ inception_v1()
...
@@ -96,14 +96,14 @@ inception_v1()
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0.py
"
$npz
"
"data_0"
"prob_1"
python convert_data_npz_0.py
"
$npz
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$npz
done
done
for
pb_dir
in
$bn_tar
/
*
/
for
pb_dir
in
$bn_tar
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -121,14 +121,14 @@ inception_v2()
...
@@ -121,14 +121,14 @@ inception_v2()
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0.py
"
$npz
"
"data_0"
"prob_1"
python convert_data_npz_0.py
"
$npz
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$npz
done
done
for
pb_dir
in
$bn_tar
/
*
/
for
pb_dir
in
$bn_tar
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -146,14 +146,14 @@ resnet50()
...
@@ -146,14 +146,14 @@ resnet50()
do
do
echo
"converting
$npz
..."
echo
"converting
$npz
..."
python convert_data_npz_0.py
"
$npz
"
"gpu_0/data_0"
"gpu_0/softmaxout_1"
python convert_data_npz_0.py
"
$npz
"
"gpu_0/data_0"
"gpu_0/softmaxout_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$npz
done
done
for
pb_dir
in
$bn_tar
/
*
/
for
pb_dir
in
$bn_tar
/
*
/
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"gpu_0/data_0"
"gpu_0/softmaxout_1"
python convert_data_pb_0.py
"
$pb_dir
"
"gpu_0/data_0"
"gpu_0/softmaxout_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -171,7 +171,7 @@ shufflenet()
...
@@ -171,7 +171,7 @@ shufflenet()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"gpu_0/data_0"
"gpu_0/softmaxout_1"
python convert_data_pb_0.py
"
$pb_dir
"
"gpu_0/data_0"
"gpu_0/softmaxout_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -189,7 +189,7 @@ squeezenet()
...
@@ -189,7 +189,7 @@ squeezenet()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"softmaxout_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"softmaxout_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -207,7 +207,7 @@ tiny_yolov2()
...
@@ -207,7 +207,7 @@ tiny_yolov2()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"image"
"grid"
python convert_data_pb_0.py
"
$pb_dir
"
"image"
"grid"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
-x
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
-x
done
done
}
}
...
@@ -225,7 +225,7 @@ vgg19()
...
@@ -225,7 +225,7 @@ vgg19()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python convert_data_pb_0.py
"
$pb_dir
"
"data_0"
"prob_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
...
@@ -243,20 +243,20 @@ zfnet512()
...
@@ -243,20 +243,20 @@ zfnet512()
do
do
echo
"converting
$pb_dir
"
echo
"converting
$pb_dir
"
python convert_data_pb_0.py
"
$pb_dir
"
"gpu_0/data_0"
"gpu_0/softmax_1"
python convert_data_pb_0.py
"
$pb_dir
"
"gpu_0/data_0"
"gpu_0/softmax_1"
python
-m
onnx2
paddle
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
python
-m
onnx2
fluid
$flags
"
$fn_model
"
-t
$(
dirname
"
$pb_dir
/x"
)
.npz
done
done
}
}
bvlc_alexnet
# data error
bvlc_alexnet
bvlc_googlenet
# desc error
bvlc_googlenet
bvlc_reference_caffenet
bvlc_reference_caffenet
bvlc_reference_rcnn_ilsvrc13
bvlc_reference_rcnn_ilsvrc13
inception_v1
###
inception_v1
inception_v2
###
inception_v2
resnet50
# data error
resnet50
shufflenet
###
shufflenet
squeezenet
squeezenet
tiny_yolov2
# not supported
tiny_yolov2
# not supported
vgg19
vgg19
zfnet512
# data error
zfnet512
onnx2
paddle/onnx2paddle
/__init__.py
→
onnx2
fluid/onnx2fluid
/__init__.py
浏览文件 @
2228423e
文件已移动
onnx2
paddle/onnx2paddle
/__main__.py
→
onnx2
fluid/onnx2fluid
/__main__.py
浏览文件 @
2228423e
...
@@ -5,7 +5,7 @@
...
@@ -5,7 +5,7 @@
#
#
################################################################################
################################################################################
"""
"""
本文件允许模块包以python -m onnx2
paddle
方式直接执行。
本文件允许模块包以python -m onnx2
fluid
方式直接执行。
Authors: Macrobull
Authors: Macrobull
Date: 2019/02/22 10:25:46
Date: 2019/02/22 10:25:46
...
@@ -21,43 +21,67 @@ import argparse
...
@@ -21,43 +21,67 @@ import argparse
import
logging
import
logging
import
sys
import
sys
parser
=
argparse
.
ArgumentParser
(
parser
=
argparse
.
ArgumentParser
(
description
=
'onnx2paddle
'
,
description
=
'onnx2fluid
'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
)
)
parser
.
add_argument
(
'model'
,
nargs
=
1
,
parser
.
add_argument
(
'model'
,
nargs
=
1
,
help
=
'path to model.onnx'
,
help
=
'path to model.onnx'
,
)
)
parser
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
parser
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
'enable debug logging and checking'
,
help
=
'enable debug logging and checking'
,
)
)
parser
.
add_argument
(
'--output-dir'
,
'-o'
,
type
=
str
,
default
=
''
,
parser
.
add_argument
(
'--output_dir'
,
'-o'
,
type
=
str
,
default
=
''
,
help
=
'output directory'
,
help
=
'output directory'
,
)
)
parser
.
add_argument
(
'--test_data'
,
'-t'
,
type
=
str
,
default
=
''
,
parser
.
add_argument
(
'--test_data'
,
'-t'
,
type
=
str
,
default
=
''
,
help
=
'I/O golden data for validation, e.g. test.npy, test.npz'
,
help
=
'I/O golden data for validation, e.g. test.npy, test.npz'
,
)
)
parser
.
add_argument
(
'--embed_params'
,
'-e'
,
action
=
'store_true'
,
parser
.
add_argument
(
help
=
'try to embed parameters for trainable Paddle layers'
,
'--embed_params'
,
)
'-e'
,
parser
.
add_argument
(
'--pedantic'
,
action
=
'store_true'
,
default
=
True
,
action
=
'store_true'
,
help
=
'try to embed parameters for trainable Paddle fluid layers'
,
)
parser
.
add_argument
(
'--pedantic'
,
action
=
'store_true'
,
default
=
True
,
help
=
'accept and convert only standard ONNX opset'
,
help
=
'accept and convert only standard ONNX opset'
,
)
)
parser
.
add_argument
(
'--no-pedantic'
,
'-x'
,
action
=
'store_false'
,
parser
.
add_argument
(
'--no-pedantic'
,
'-x'
,
action
=
'store_false'
,
dest
=
'pedantic'
,
dest
=
'pedantic'
,
help
=
'process non-standard ONNX ops, this may lead to fails'
,
help
=
'process non-standard ONNX ops, this may lead to fails'
,
)
)
parser
.
add_argument
(
'--precision'
,
'-p'
,
type
=
int
,
default
=
4
,
parser
.
add_argument
(
'--precision'
,
'-p'
,
type
=
int
,
default
=
4
,
help
=
'assertion decimal for validation'
,
help
=
'assertion decimal for validation'
,
)
)
args
=
parser
.
parse_args
()
args
=
parser
.
parse_args
()
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level
=
logging
.
DEBUG
if
args
.
debug
else
logging
.
INFO
logging_level
=
logging
.
DEBUG
if
args
.
debug
else
logging
.
INFO
logging
.
basicConfig
(
format
=
logging_format
,
level
=
logging_level
)
logging
.
basicConfig
(
format
=
logging_format
,
level
=
logging_level
)
try
:
try
:
from
.
import
cmdline
from
.
import
cmdline
except
ImportError
:
except
ImportError
:
...
@@ -66,5 +90,4 @@ except ImportError:
...
@@ -66,5 +90,4 @@ except ImportError:
# imports
# imports
main
=
cmdline
.
main
main
=
cmdline
.
main
sys
.
exit
(
main
(
**
args
.
__dict__
))
sys
.
exit
(
main
(
**
args
.
__dict__
))
onnx2
paddle/onnx2paddle
/cmdline.py
→
onnx2
fluid/onnx2fluid
/cmdline.py
浏览文件 @
2228423e
...
@@ -21,7 +21,6 @@ import logging
...
@@ -21,7 +21,6 @@ import logging
import
shutil
import
shutil
import
zipfile
import
zipfile
__all__
=
[
__all__
=
[
'main'
,
'main'
,
]
]
...
@@ -42,7 +41,7 @@ def main(**kwargs):
...
@@ -42,7 +41,7 @@ def main(**kwargs):
# imports
# imports
convert
=
conversion
.
convert
convert
=
conversion
.
convert
logger
=
logging
.
getLogger
(
'onnx2
paddle
'
)
logger
=
logging
.
getLogger
(
'onnx2
fluid
'
)
debug
=
kwargs
.
get
(
'debug'
,
False
)
debug
=
kwargs
.
get
(
'debug'
,
False
)
# prepare arguments
# prepare arguments
...
@@ -58,7 +57,9 @@ def main(**kwargs):
...
@@ -58,7 +57,9 @@ def main(**kwargs):
onnx_opset_pedantic
=
kwargs
.
get
(
'pedantic'
,
True
)
onnx_opset_pedantic
=
kwargs
.
get
(
'pedantic'
,
True
)
# convert
# convert
convert
(
filename
,
save_dir
,
convert
(
filename
,
save_dir
,
model_basename
=
model_basename
,
model_basename
=
model_basename
,
model_func_name
=
model_func_name
,
model_func_name
=
model_func_name
,
embed_params
=
embed_params
,
embed_params
=
embed_params
,
...
@@ -80,16 +81,18 @@ def main(**kwargs):
...
@@ -80,16 +81,18 @@ def main(**kwargs):
# in fact fluid can not fully clear the context
# in fact fluid can not fully clear the context
# continuous validation may be inaccurate
# continuous validation may be inaccurate
precision
=
10
**
-
kwargs
.
get
(
'precision'
,
4
)
precision
=
10
**
-
kwargs
.
get
(
'precision'
,
4
)
logger
.
info
(
'starting validation on desc ...'
)
logger
.
info
(
'starting validation on desc ...'
)
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
),
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
),
golden_data_filename
,
golden_data_filename
,
precision
=
precision
,
precision
=
precision
,
)
)
logger
.
info
(
'starting validation on code ...'
)
logger
.
info
(
'starting validation on code ...'
)
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
),
passed
&=
validate
(
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
),
golden_data_filename
,
golden_data_filename
,
model_func_name
=
model_func_name
,
model_func_name
=
model_func_name
,
precision
=
precision
,
precision
=
precision
,
...
@@ -112,20 +115,22 @@ def main(**kwargs):
...
@@ -112,20 +115,22 @@ def main(**kwargs):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
basicConfig
(
logging
.
basicConfig
(
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
level
=
logging
.
DEBUG
,
level
=
logging
.
DEBUG
,
)
)
# main(model=['../examples/t5.onnx'],
# main(model=['../examples/t5.onnx'],
# output_dir='/tmp/export/',
# output_dir='/tmp/export/',
# embed_params=False,
# embed_params=False,
# pedantic=False,
# pedantic=False,
# test_data='../examples/t5.npz',
# test_data='../examples/t5.npz',
# debug=True)
# debug=True)
main
(
model
=
[
'../examples/shufflenet/model.onnx'
],
main
(
model
=
[
'../examples/inception_v2/model.onnx'
],
output_dir
=
'/tmp/export/'
,
output_dir
=
'/tmp/export/'
,
embed_params
=
True
,
embed_params
=
True
,
pedantic
=
False
,
pedantic
=
False
,
test_data
=
'../examples/shufflenet/test_data_set_0
.npz'
,
test_data
=
'../examples/inception_v2/test_data_set_2
.npz'
,
debug
=
True
)
debug
=
True
)
onnx2
paddle/onnx2paddle
/conversion.py
→
onnx2
fluid/onnx2fluid
/conversion.py
浏览文件 @
2228423e
...
@@ -12,19 +12,21 @@ from __future__ import division
...
@@ -12,19 +12,21 @@ from __future__ import division
import
logging
import
logging
import
shutil
import
shutil
__all__
=
[
__all__
=
[
'convert'
,
'convert'
,
]
]
def
convert
(
onnx_model_filename
,
save_dir
,
def
convert
(
onnx_model_filename
,
model_basename
=
'model.py'
,
model_func_name
=
'inference'
,
save_dir
,
model_basename
=
'model.py'
,
model_func_name
=
'inference'
,
embed_params
=
False
,
embed_params
=
False
,
onnx_opset_version
=
9
,
onnx_opset_pedantic
=
True
,
onnx_opset_version
=
9
,
onnx_opset_pedantic
=
True
,
debug
=
False
):
debug
=
False
):
"""
"""
convert an ONNX model to Paddle Python code and desc pb
convert an ONNX model to Paddle
fluid
Python code and desc pb
"""
"""
import
onnx
import
onnx
...
@@ -62,7 +64,8 @@ def convert(onnx_model_filename, save_dir,
...
@@ -62,7 +64,8 @@ def convert(onnx_model_filename, save_dir,
if
onnx_opset_pedantic
:
# WORKAROUND: RuntimeError: No Adapter For OP
if
onnx_opset_pedantic
:
# WORKAROUND: RuntimeError: No Adapter For OP
onnx_model
=
convert_version
(
onnx_model
,
onnx_opset_version
)
onnx_model
=
convert_version
(
onnx_model
,
onnx_opset_version
)
else
:
# TODO: add new argument for this option
else
:
# TODO: add new argument for this option
logger
.
warning
(
'opset conversion skipped for onnx_opset_pedantic is OFF'
)
logger
.
warning
(
'opset conversion skipped for onnx_opset_pedantic is OFF'
)
onnx_model
=
polish_model
(
onnx_model
)
onnx_model
=
polish_model
(
onnx_model
)
except
ValidationError
as
e
:
except
ValidationError
as
e
:
if
onnx_opset_pedantic
:
if
onnx_opset_pedantic
:
...
@@ -90,13 +93,13 @@ def convert(onnx_model_filename, save_dir,
...
@@ -90,13 +93,13 @@ def convert(onnx_model_filename, save_dir,
onnx
.
save
(
model
,
debug_model_filename
+
'.optimized_and_inffered.onnx'
)
onnx
.
save
(
model
,
debug_model_filename
+
'.optimized_and_inffered.onnx'
)
# onnx.save(model, '/tmp/export/optimized_and_inffered.onnx')
# onnx.save(model, '/tmp/export/optimized_and_inffered.onnx')
# I/O instances
# I/O instances
onnx_graph
=
onnx_model
.
graph
onnx_graph
=
onnx_model
.
graph
paddle
_program
=
Program
()
fluid
_program
=
Program
()
paddle
_writer
=
Writer
()
fluid
_writer
=
Writer
()
# model components
# model components
# graph_name = onnx_graph.name
# graph_name = onnx_graph.name
graph_inputs
=
[
value
.
name
for
value
in
onnx_graph
.
input
]
graph_inputs
=
[
value
.
name
for
value
in
onnx_graph
.
input
]
graph_outputs
=
[
value
.
name
for
value
in
onnx_graph
.
output
]
graph_outputs
=
[
value
.
name
for
value
in
onnx_graph
.
output
]
graph_params
=
[]
graph_params
=
[]
...
@@ -107,29 +110,37 @@ def convert(onnx_model_filename, save_dir,
...
@@ -107,29 +110,37 @@ def convert(onnx_model_filename, save_dir,
for
name
,
weight
in
graph_weights
(
onnx_graph
):
for
name
,
weight
in
graph_weights
(
onnx_graph
):
value_info
=
graph_value_infos
[
name
]
value_info
=
graph_value_infos
[
name
]
value_info
[
'embeded_as'
]
=
[]
value_info
[
'embeded_as'
]
=
[]
value_info
[
'get_weight'
]
=
lambda
:
weight
.
tolist
()
# lazy getter
value_info
[
'get_weight'
]
=
(
lambda
w
:
lambda
:
w
.
tolist
())(
weight
)
# lazy getter
logger
.
info
(
'conversion started'
)
logger
.
info
(
'conversion started'
)
# op set conversion
# op set conversion
# topo = 'backward' if embed_params else 'forward'
# topo = 'backward' if embed_params else 'forward'
topo
=
'forward'
topo
=
'forward'
for
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
in
graph_ops
(
onnx_graph
,
topo
=
topo
):
for
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
in
graph_ops
(
onnx_graph
,
topo
=
topo
):
logger
.
debug
(
'translating op %s %s::%s ...'
,
name
,
domain
,
op_type
)
logger
.
debug
(
'translating op %s %s::%s ...'
,
name
,
domain
,
op_type
)
if
domain
==
DEFAULT_OP_DOMAIN
:
if
domain
==
DEFAULT_OP_DOMAIN
:
domain
=
''
domain
=
''
try
:
try
:
paddle_writer
.
emit_op
(
paddle_program
,
name
,
domain
,
op_type
,
fluid_writer
.
emit_op
(
inputs
,
outputs
,
attrs
,
fluid_program
,
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
graph_value_infos
,
graph_value_infos
,
embed_params
=
embed_params
,
embed_params
=
embed_params
,
)
)
except
BaseException
as
e
:
except
BaseException
as
e
:
logger
.
fatal
(
'conversion failed for:
\n\t
%s -> %s::%s -> %s'
,
logger
.
fatal
(
'conversion failed for:
\n\t
%s -> %s::%s -> %s'
,
inputs
,
inputs
,
domain
,
op_type
,
outputs
)
domain
,
op_type
,
outputs
)
raise
e
raise
e
op_codes
=
paddle
_program
.
codes
op_codes
=
fluid
_program
.
codes
paddle
_program
.
codes
=
[]
fluid
_program
.
codes
=
[]
logger
.
info
(
'%d ops converted'
,
len
(
paddle
_program
.
op_descs
))
logger
.
info
(
'%d ops converted'
,
len
(
fluid
_program
.
op_descs
))
# weight writer
# weight writer
for
name
,
weight
in
graph_weights
(
onnx_graph
):
for
name
,
weight
in
graph_weights
(
onnx_graph
):
...
@@ -138,18 +149,24 @@ def convert(onnx_model_filename, save_dir,
...
@@ -138,18 +149,24 @@ def convert(onnx_model_filename, save_dir,
var_names
=
value_info
.
get
(
'embeded_as'
,
[])
var_names
=
value_info
.
get
(
'embeded_as'
,
[])
if
var_names
:
if
var_names
:
if
len
(
var_names
)
>
1
:
if
len
(
var_names
)
>
1
:
logger
.
info
(
'weight %s is shared between ops, more disk space will be consumed'
,
name
)
logger
.
info
(
logger
.
debug
(
'saving weight %s with size of %d, in %d bytes, as %s ...'
,
'weight %s is shared between ops, more disk space will be consumed'
,
name
)
logger
.
debug
(
'saving weight %s with size of %d, in %d bytes, as %s ...'
,
name
,
weight
.
size
,
weight
.
nbytes
,
var_names
)
name
,
weight
.
size
,
weight
.
nbytes
,
var_names
)
for
var_name
in
var_names
:
# multiple references
for
var_name
in
var_names
:
# multiple references
paddle_writer
.
write_weight
(
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
var_name
))
fluid_writer
.
write_weight
(
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
var_name
))
else
:
else
:
logger
.
debug
(
'saving weight %s with size of %d, in %d bytes, to %s ...'
,
logger
.
debug
(
'saving weight %s with size of %d, in %d bytes, to %s ...'
,
name
,
weight
.
size
,
weight
.
nbytes
,
make_var_name
(
name
))
name
,
weight
.
size
,
weight
.
nbytes
,
make_var_name
(
name
))
paddle_writer
.
write_weight
(
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
make_var_name
(
name
)))
fluid_writer
.
write_weight
(
paddle_writer
.
emit_param
(
paddle_program
,
name
,
value_info
)
weight
,
shutil
.
os
.
path
.
join
(
save_dir
,
make_var_name
(
name
)))
param_codes
=
paddle_program
.
codes
fluid_writer
.
emit_param
(
fluid_program
,
name
,
value_info
)
paddle_program
.
codes
=
[]
param_codes
=
fluid_program
.
codes
fluid_program
.
codes
=
[]
logger
.
info
(
'%d weights converted'
,
len
(
graph_params
))
logger
.
info
(
'%d weights converted'
,
len
(
graph_params
))
# input writer
# input writer
...
@@ -159,9 +176,11 @@ def convert(onnx_model_filename, save_dir,
...
@@ -159,9 +176,11 @@ def convert(onnx_model_filename, save_dir,
value_info
=
graph_value_infos
[
name
]
value_info
=
graph_value_infos
[
name
]
assert
value_info
[
'external'
]
assert
value_info
[
'external'
]
external_inputs
.
append
(
name
)
external_inputs
.
append
(
name
)
paddle_writer
.
emit_inputs
(
paddle_program
,
external_inputs
,
graph_value_infos
,
remove_batch
=
False
)
# TODO:
fluid_writer
.
emit_inputs
(
input_codes
=
paddle_program
.
codes
fluid_program
,
external_inputs
,
graph_value_infos
,
paddle_program
.
codes
=
[]
remove_batch
=
False
)
# TODO:
input_codes
=
fluid_program
.
codes
fluid_program
.
codes
=
[]
logger
.
info
(
'%d inputs converted'
,
len
(
external_inputs
))
logger
.
info
(
'%d inputs converted'
,
len
(
external_inputs
))
# output writer
# output writer
...
@@ -171,49 +190,93 @@ def convert(onnx_model_filename, save_dir,
...
@@ -171,49 +190,93 @@ def convert(onnx_model_filename, save_dir,
value_info
=
graph_value_infos
[
name
]
value_info
=
graph_value_infos
[
name
]
assert
value_info
[
'external'
]
assert
value_info
[
'external'
]
external_outputs
.
append
(
name
)
external_outputs
.
append
(
name
)
paddle_writer
.
emit_outputs
(
paddle
_program
,
external_outputs
)
fluid_writer
.
emit_outputs
(
fluid
_program
,
external_outputs
)
output_codes
=
[
''
]
+
paddle_program
.
codes
# add an empty line
output_codes
=
[
''
]
+
fluid_program
.
codes
# add an empty line
paddle
_program
.
codes
=
[]
fluid
_program
.
codes
=
[]
logger
.
info
(
'%d outputs converted'
,
len
(
external_outputs
))
logger
.
info
(
'%d outputs converted'
,
len
(
external_outputs
))
# code generation
# code generation
header_codes
=
fluid_writer
.
header_code
(
model_func_name
,
'From: {}'
.
format
(
onnx_model_filename
))
code_filename
=
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
)
code_filename
=
shutil
.
os
.
path
.
join
(
save_dir
,
model_basename
)
paddle_writer
.
write_code_file
(
code_filename
,
paddle_writer
.
header_code
(
model_func_name
),
fluid_writer
.
write_code_file
(
code_filename
,
header_codes
,
input_codes
,
input_codes
,
param_codes
,
op_codes
,
output_codes
)
param_codes
,
op_codes
,
output_codes
)
logger
.
info
(
'code saved to %s, factory function: %s'
,
code_filename
,
model_func_name
)
logger
.
info
(
'code saved to %s, factory function: %s'
,
code_filename
,
model_func_name
)
# desc generation
# desc generation
desc_filename
=
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
)
desc_filename
=
shutil
.
os
.
path
.
join
(
save_dir
,
'__model__'
)
paddle_writer
.
write_desc_file
(
desc_filename
,
fluid_writer
.
write_desc_file
(
op_descs
=
paddle_program
.
op_descs
,
desc_filename
,
var_descs
=
paddle_program
.
var_descs
,
op_descs
=
fluid_program
.
op_descs
,
var_descs
=
fluid_program
.
var_descs
,
)
)
logger
.
info
(
'program saved to %s'
,
desc_filename
)
logger
.
info
(
'program saved to %s'
,
desc_filename
)
logger
.
info
(
'conversion finished'
)
logger
.
info
(
'conversion finished'
)
# globals().update(locals())
# globals().update(locals())
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
basicConfig
(
import
argparse
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
level
=
logging
.
DEBUG
,
parser
=
argparse
.
ArgumentParser
(
description
=
'onnx2fluid.convert'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
)
parser
.
add_argument
(
'model'
,
nargs
=
1
,
help
=
'path to model.onnx'
,
)
parser
.
add_argument
(
'--debug'
,
'-d'
,
action
=
'store_true'
,
help
=
'enable debug logging and checking'
,
)
)
parser
.
add_argument
(
'--output_dir'
,
'-o'
,
type
=
str
,
default
=
''
,
help
=
'output directory'
,
)
parser
.
add_argument
(
'--embed_params'
,
'-e'
,
action
=
'store_true'
,
help
=
'try to embed parameters for trainable Paddle fluid layers'
,
)
parser
.
add_argument
(
'--pedantic'
,
action
=
'store_true'
,
default
=
True
,
help
=
'accept and convert only standard ONNX opset'
,
)
parser
.
add_argument
(
'--no-pedantic'
,
'-x'
,
action
=
'store_false'
,
dest
=
'pedantic'
,
help
=
'process non-standard ONNX ops, this may lead to fails'
,
)
args
=
parser
.
parse_args
()
model_list
=
[
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
'../examples/t1.onnx'
,
logging_level
=
logging
.
DEBUG
if
args
.
debug
else
logging
.
INFO
'../examples/t2.onnx'
,
logging
.
basicConfig
(
format
=
logging_format
,
level
=
logging_level
)
'../examples/t3.onnx'
,
'../examples/t4.onnx'
,
debug
=
args
.
debug
'../examples/t5.onnx'
,
model_filename
=
args
.
model
[
0
]
'../examples/t6.onnx'
,
save_dir
=
args
.
output_dir
# '../examples/t7.onnx',
embed_params
=
args
.
embed_params
# '../examples/t8.onnx',
pedantic
=
args
.
pedantic
]
convert
(
for
model
in
model_list
:
model_filename
,
pathname
,
_
=
shutil
.
os
.
path
.
splitext
(
model
)
save_dir
,
convert
(
model
,
pathname
,
embed_params
=
embed_params
,
onnx_opset_pedantic
=
False
,
debug
=
True
)
onnx_opset_pedantic
=
pedantic
,
convert
(
model
,
pathname
+
'.embeded'
,
debug
=
debug
)
embed_params
=
True
,
onnx_opset_pedantic
=
False
,
debug
=
True
)
onnx2fluid/onnx2fluid/framework_pb2.py
0 → 100644
浏览文件 @
2228423e
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import
sys
_b
=
sys
.
version_info
[
0
]
<
3
and
(
lambda
x
:
x
)
or
(
lambda
x
:
x
.
encode
(
'latin1'
))
from
google.protobuf.internal
import
enum_type_wrapper
from
google.protobuf
import
descriptor
as
_descriptor
from
google.protobuf
import
message
as
_message
from
google.protobuf
import
reflection
as
_reflection
from
google.protobuf
import
symbol_database
as
_symbol_database
from
google.protobuf
import
descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db
=
_symbol_database
.
Default
()
DESCRIPTOR
=
_descriptor
.
FileDescriptor
(
name
=
'framework.proto'
,
package
=
'paddle.framework.proto'
,
syntax
=
'proto2'
,
serialized_pb
=
_b
(
'
\n\x0f\x66
ramework.proto
\x12\x16
paddle.framework.proto
\"\x1d\n\x07
Version
\x12\x12\n\x07
version
\x18\x01
\x01
(
\x03
:
\x01\x30\"\xec\x03\n\x06
OpDesc
\x12\x0c\n\x04
type
\x18\x03
\x02
(
\t\x12\x32\n\x06
inputs
\x18\x01
\x03
(
\x0b\x32\"
.paddle.framework.proto.OpDesc.Var
\x12\x33\n\x07
outputs
\x18\x02
\x03
(
\x0b\x32\"
.paddle.framework.proto.OpDesc.Var
\x12\x32\n\x05\x61
ttrs
\x18\x04
\x03
(
\x0b\x32
#.paddle.framework.proto.OpDesc.Attr
\x12\x18\n\t
is_target
\x18\x05
\x01
(
\x08
:
\x05\x66\x61
lse
\x1a\xef\x01\n\x04\x41
ttr
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12
.
\n\x04
type
\x18\x02
\x02
(
\x0e\x32
.paddle.framework.proto.AttrType
\x12\t\n\x01
i
\x18\x03
\x01
(
\x05\x12\t\n\x01\x66\x18\x04
\x01
(
\x02\x12\t\n\x01
s
\x18\x05
\x01
(
\t\x12\x0c\n\x04
ints
\x18\x06
\x03
(
\x05\x12\x0e\n\x06\x66
loats
\x18\x07
\x03
(
\x02\x12\x0f\n\x07
strings
\x18\x08
\x03
(
\t\x12\t\n\x01\x62\x18\n
\x01
(
\x08\x12\r\n\x05\x62
ools
\x18\x0b
\x03
(
\x08\x12\x11\n\t
block_idx
\x18\x0c
\x01
(
\x05\x12\t\n\x01
l
\x18\r
\x01
(
\x03\x12\x12\n\n
blocks_idx
\x18\x0e
\x03
(
\x05\x12\r\n\x05
longs
\x18\x0f
\x03
(
\x03\x1a
+
\n\x03
Var
\x12\x11\n\t
parameter
\x18\x01
\x02
(
\t\x12\x11\n\t
arguments
\x18\x02
\x03
(
\t\"\xb3\x03\n\x07
OpProto
\x12\x0c\n\x04
type
\x18\x01
\x02
(
\t\x12\x33\n\x06
inputs
\x18\x02
\x03
(
\x0b\x32
#.paddle.framework.proto.OpProto.Var
\x12\x34\n\x07
outputs
\x18\x03
\x03
(
\x0b\x32
#.paddle.framework.proto.OpProto.Var
\x12\x33\n\x05\x61
ttrs
\x18\x04
\x03
(
\x0b\x32
$.paddle.framework.proto.OpProto.Attr
\x12\x0f\n\x07\x63
omment
\x18\x05
\x02
(
\t\x1a
x
\n\x03
Var
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12\x0f\n\x07\x63
omment
\x18\x02
\x02
(
\t\x12\x19\n\n
duplicable
\x18\x03
\x01
(
\x08
:
\x05\x66\x61
lse
\x12\x1b\n\x0c
intermediate
\x18\x04
\x01
(
\x08
:
\x05\x66\x61
lse
\x12\x1a\n\x0b\x64
ispensable
\x18\x05
\x01
(
\x08
:
\x05\x66\x61
lse
\x1a
o
\n\x04\x41
ttr
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12
.
\n\x04
type
\x18\x02
\x02
(
\x0e\x32
.paddle.framework.proto.AttrType
\x12\x0f\n\x07\x63
omment
\x18\x03
\x02
(
\t\x12\x18\n\t
generated
\x18\x04
\x01
(
\x08
:
\x05\x66\x61
lse
\"\xda\x08\n\x07
VarType
\x12\x32\n\x04
type
\x18\x01
\x02
(
\x0e\x32
$.paddle.framework.proto.VarType.Type
\x12\x41\n\r
selected_rows
\x18\x02
\x01
(
\x0b\x32
*.paddle.framework.proto.VarType.TensorDesc
\x12\x41\n\n
lod_tensor
\x18\x03
\x01
(
\x0b\x32
-.paddle.framework.proto.VarType.LoDTensorDesc
\x12
H
\n\x0c
tensor_array
\x18\x04
\x01
(
\x0b\x32\x32
.paddle.framework.proto.VarType.LoDTensorArrayDesc
\x12
:
\n\x06
reader
\x18\x05
\x01
(
\x0b\x32
*.paddle.framework.proto.VarType.ReaderDesc
\x12\x34\n\x05
tuple
\x18\x07
\x01
(
\x0b\x32
%.paddle.framework.proto.VarType.Tuple
\x1a
S
\n\n
TensorDesc
\x12\x37\n\t
data_type
\x18\x01
\x02
(
\x0e\x32
$.paddle.framework.proto.VarType.Type
\x12\x0c\n\x04\x64
ims
\x18\x02
\x03
(
\x03\x1a\x61\n\r
LoDTensorDesc
\x12
:
\n\x06
tensor
\x18\x01
\x02
(
\x0b\x32
*.paddle.framework.proto.VarType.TensorDesc
\x12\x14\n\t
lod_level
\x18\x02
\x01
(
\x05
:
\x01\x30\x1a\x66\n\x12
LoDTensorArrayDesc
\x12
:
\n\x06
tensor
\x18\x01
\x02
(
\x0b\x32
*.paddle.framework.proto.VarType.TensorDesc
\x12\x14\n\t
lod_level
\x18\x02
\x01
(
\x05
:
\x01\x30\x1a
O
\n\n
ReaderDesc
\x12\x41\n\n
lod_tensor
\x18\x01
\x03
(
\x0b\x32
-.paddle.framework.proto.VarType.LoDTensorDesc
\x1a\x43\n\x05
Tuple
\x12
:
\n\x0c\x65
lement_type
\x18\x01
\x03
(
\x0e\x32
$.paddle.framework.proto.VarType.Type
\"\xa2\x02\n\x04
Type
\x12\x08\n\x04\x42
OOL
\x10\x00\x12\t\n\x05
INT16
\x10\x01\x12\t\n\x05
INT32
\x10\x02\x12\t\n\x05
INT64
\x10\x03\x12\x08\n\x04\x46
P16
\x10\x04\x12\x08\n\x04\x46
P32
\x10\x05\x12\x08\n\x04\x46
P64
\x10\x06\x12\n\n\x06
SIZE_T
\x10\x13\x12\t\n\x05
UINT8
\x10\x14\x12\x08\n\x04
INT8
\x10\x15\x12\x0e\n\n
LOD_TENSOR
\x10\x07\x12\x11\n\r
SELECTED_ROWS
\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44
_MINIBATCH
\x10\t\x12\x0e\n\n
FETCH_LIST
\x10\n\x12\x0f\n\x0b
STEP_SCOPES
\x10\x0b\x12\x12\n\x0e
LOD_RANK_TABLE
\x10\x0c\x12\x14\n\x10
LOD_TENSOR_ARRAY
\x10\r\x12\x0e\n\n
PLACE_LIST
\x10\x0e\x12\n\n\x06
READER
\x10\x0f\x12\x07\n\x03
RAW
\x10\x11\x12\t\n\x05
TUPLE
\x10\x12\"
b
\n\x07
VarDesc
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12
-
\n\x04
type
\x18\x02
\x02
(
\x0b\x32\x1f
.paddle.framework.proto.VarType
\x12\x1a\n\x0b
persistable
\x18\x03
\x01
(
\x08
:
\x05\x66\x61
lse
\"\xa7\x01\n\t
BlockDesc
\x12\x0b\n\x03
idx
\x18\x01
\x02
(
\x05\x12\x12\n\n
parent_idx
\x18\x02
\x02
(
\x05\x12
-
\n\x04
vars
\x18\x03
\x03
(
\x0b\x32\x1f
.paddle.framework.proto.VarDesc
\x12
+
\n\x03
ops
\x18\x04
\x03
(
\x0b\x32\x1e
.paddle.framework.proto.OpDesc
\x12\x1d\n\x11\x66
orward_block_idx
\x18\x05
\x01
(
\x05
:
\x02
-1
\"
r
\n\x0b
ProgramDesc
\x12\x31\n\x06\x62
locks
\x18\x01
\x03
(
\x0b\x32
!.paddle.framework.proto.BlockDesc
\x12\x30\n\x07
version
\x18\x02
\x01
(
\x0b\x32\x1f
.paddle.framework.proto.Version*
\x94\x01\n\x08\x41
ttrType
\x12\x07\n\x03
INT
\x10\x00\x12\t\n\x05\x46
LOAT
\x10\x01\x12\n\n\x06
STRING
\x10\x02\x12\x08\n\x04
INTS
\x10\x03\x12\n\n\x06\x46
LOATS
\x10\x04\x12\x0b\n\x07
STRINGS
\x10\x05\x12\x0b\n\x07\x42
OOLEAN
\x10\x06\x12\x0c\n\x08\x42
OOLEANS
\x10\x07\x12\t\n\x05\x42
LOCK
\x10\x08\x12\x08\n\x04
LONG
\x10\t\x12\n\n\x06\x42
LOCKS
\x10\n\x12\t\n\x05
LONGS
\x10\x0b\x42\x02
H
\x03
'
))
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
_ATTRTYPE
=
_descriptor
.
EnumDescriptor
(
name
=
'AttrType'
,
full_name
=
'paddle.framework.proto.AttrType'
,
filename
=
None
,
file
=
DESCRIPTOR
,
values
=
[
_descriptor
.
EnumValueDescriptor
(
name
=
'INT'
,
index
=
0
,
number
=
0
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FLOAT'
,
index
=
1
,
number
=
1
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STRING'
,
index
=
2
,
number
=
2
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INTS'
,
index
=
3
,
number
=
3
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FLOATS'
,
index
=
4
,
number
=
4
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STRINGS'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOLEAN'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOLEANS'
,
index
=
7
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BLOCK'
,
index
=
8
,
number
=
8
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LONG'
,
index
=
9
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BLOCKS'
,
index
=
10
,
number
=
10
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LONGS'
,
index
=
11
,
number
=
11
,
options
=
None
,
type
=
None
),
],
containing_type
=
None
,
options
=
None
,
serialized_start
=
2511
,
serialized_end
=
2659
,
)
_sym_db
.
RegisterEnumDescriptor
(
_ATTRTYPE
)
AttrType
=
enum_type_wrapper
.
EnumTypeWrapper
(
_ATTRTYPE
)
INT
=
0
FLOAT
=
1
STRING
=
2
INTS
=
3
FLOATS
=
4
STRINGS
=
5
BOOLEAN
=
6
BOOLEANS
=
7
BLOCK
=
8
LONG
=
9
BLOCKS
=
10
LONGS
=
11
_VARTYPE_TYPE
=
_descriptor
.
EnumDescriptor
(
name
=
'Type'
,
full_name
=
'paddle.framework.proto.VarType.Type'
,
filename
=
None
,
file
=
DESCRIPTOR
,
values
=
[
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOL'
,
index
=
0
,
number
=
0
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT16'
,
index
=
1
,
number
=
1
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT32'
,
index
=
2
,
number
=
2
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT64'
,
index
=
3
,
number
=
3
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP16'
,
index
=
4
,
number
=
4
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP32'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP64'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'SIZE_T'
,
index
=
7
,
number
=
19
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'UINT8'
,
index
=
8
,
number
=
20
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT8'
,
index
=
9
,
number
=
21
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_TENSOR'
,
index
=
10
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'SELECTED_ROWS'
,
index
=
11
,
number
=
8
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FEED_MINIBATCH'
,
index
=
12
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FETCH_LIST'
,
index
=
13
,
number
=
10
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STEP_SCOPES'
,
index
=
14
,
number
=
11
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_RANK_TABLE'
,
index
=
15
,
number
=
12
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_TENSOR_ARRAY'
,
index
=
16
,
number
=
13
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'PLACE_LIST'
,
index
=
17
,
number
=
14
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'READER'
,
index
=
18
,
number
=
15
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'RAW'
,
index
=
19
,
number
=
17
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'TUPLE'
,
index
=
20
,
number
=
18
,
options
=
None
,
type
=
None
),
],
containing_type
=
None
,
options
=
None
,
serialized_start
=
1832
,
serialized_end
=
2122
,
)
_sym_db
.
RegisterEnumDescriptor
(
_VARTYPE_TYPE
)
_VERSION
=
_descriptor
.
Descriptor
(
name
=
'Version'
,
full_name
=
'paddle.framework.proto.Version'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'version'
,
full_name
=
'paddle.framework.proto.Version.version'
,
index
=
0
,
number
=
1
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
True
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
43
,
serialized_end
=
72
,
)
_OPDESC_ATTR
=
_descriptor
.
Descriptor
(
name
=
'Attr'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.type'
,
index
=
1
,
number
=
2
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'i'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.i'
,
index
=
2
,
number
=
3
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'f'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.f'
,
index
=
3
,
number
=
4
,
type
=
2
,
cpp_type
=
6
,
label
=
1
,
has_default_value
=
False
,
default_value
=
float
(
0
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
's'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.s'
,
index
=
4
,
number
=
5
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'ints'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.ints'
,
index
=
5
,
number
=
6
,
type
=
5
,
cpp_type
=
1
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'floats'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.floats'
,
index
=
6
,
number
=
7
,
type
=
2
,
cpp_type
=
6
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'strings'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.strings'
,
index
=
7
,
number
=
8
,
type
=
9
,
cpp_type
=
9
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'b'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.b'
,
index
=
8
,
number
=
10
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
False
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'bools'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.bools'
,
index
=
9
,
number
=
11
,
type
=
8
,
cpp_type
=
7
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'block_idx'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.block_idx'
,
index
=
10
,
number
=
12
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'l'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.l'
,
index
=
11
,
number
=
13
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'blocks_idx'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.blocks_idx'
,
index
=
12
,
number
=
14
,
type
=
5
,
cpp_type
=
1
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'longs'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.longs'
,
index
=
13
,
number
=
15
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
283
,
serialized_end
=
522
,
)
_OPDESC_VAR
=
_descriptor
.
Descriptor
(
name
=
'Var'
,
full_name
=
'paddle.framework.proto.OpDesc.Var'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'parameter'
,
full_name
=
'paddle.framework.proto.OpDesc.Var.parameter'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'arguments'
,
full_name
=
'paddle.framework.proto.OpDesc.Var.arguments'
,
index
=
1
,
number
=
2
,
type
=
9
,
cpp_type
=
9
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
524
,
serialized_end
=
567
,
)
_OPDESC
=
_descriptor
.
Descriptor
(
name
=
'OpDesc'
,
full_name
=
'paddle.framework.proto.OpDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpDesc.type'
,
index
=
0
,
number
=
3
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'inputs'
,
full_name
=
'paddle.framework.proto.OpDesc.inputs'
,
index
=
1
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'outputs'
,
full_name
=
'paddle.framework.proto.OpDesc.outputs'
,
index
=
2
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'attrs'
,
full_name
=
'paddle.framework.proto.OpDesc.attrs'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'is_target'
,
full_name
=
'paddle.framework.proto.OpDesc.is_target'
,
index
=
4
,
number
=
5
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[
_OPDESC_ATTR
,
_OPDESC_VAR
,
],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
75
,
serialized_end
=
567
,
)
_OPPROTO_VAR
=
_descriptor
.
Descriptor
(
name
=
'Var'
,
full_name
=
'paddle.framework.proto.OpProto.Var'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.OpProto.Var.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'comment'
,
full_name
=
'paddle.framework.proto.OpProto.Var.comment'
,
index
=
1
,
number
=
2
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'duplicable'
,
full_name
=
'paddle.framework.proto.OpProto.Var.duplicable'
,
index
=
2
,
number
=
3
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'intermediate'
,
full_name
=
'paddle.framework.proto.OpProto.Var.intermediate'
,
index
=
3
,
number
=
4
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'dispensable'
,
full_name
=
'paddle.framework.proto.OpProto.Var.dispensable'
,
index
=
4
,
number
=
5
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
772
,
serialized_end
=
892
,
)
_OPPROTO_ATTR
=
_descriptor
.
Descriptor
(
name
=
'Attr'
,
full_name
=
'paddle.framework.proto.OpProto.Attr'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.type'
,
index
=
1
,
number
=
2
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'comment'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.comment'
,
index
=
2
,
number
=
3
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'generated'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.generated'
,
index
=
3
,
number
=
4
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
894
,
serialized_end
=
1005
,
)
_OPPROTO
=
_descriptor
.
Descriptor
(
name
=
'OpProto'
,
full_name
=
'paddle.framework.proto.OpProto'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpProto.type'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'inputs'
,
full_name
=
'paddle.framework.proto.OpProto.inputs'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'outputs'
,
full_name
=
'paddle.framework.proto.OpProto.outputs'
,
index
=
2
,
number
=
3
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'attrs'
,
full_name
=
'paddle.framework.proto.OpProto.attrs'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'comment'
,
full_name
=
'paddle.framework.proto.OpProto.comment'
,
index
=
4
,
number
=
5
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[
_OPPROTO_VAR
,
_OPPROTO_ATTR
,
],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
570
,
serialized_end
=
1005
,
)
_VARTYPE_TENSORDESC
=
_descriptor
.
Descriptor
(
name
=
'TensorDesc'
,
full_name
=
'paddle.framework.proto.VarType.TensorDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'data_type'
,
full_name
=
'paddle.framework.proto.VarType.TensorDesc.data_type'
,
index
=
0
,
number
=
1
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'dims'
,
full_name
=
'paddle.framework.proto.VarType.TensorDesc.dims'
,
index
=
1
,
number
=
2
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
1393
,
serialized_end
=
1476
,
)
_VARTYPE_LODTENSORDESC
=
_descriptor
.
Descriptor
(
name
=
'LoDTensorDesc'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'tensor'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorDesc.tensor'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'lod_level'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorDesc.lod_level'
,
index
=
1
,
number
=
2
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
True
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
1478
,
serialized_end
=
1575
,
)
_VARTYPE_LODTENSORARRAYDESC
=
_descriptor
.
Descriptor
(
name
=
'LoDTensorArrayDesc'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorArrayDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'tensor'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'lod_level'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level'
,
index
=
1
,
number
=
2
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
True
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
1577
,
serialized_end
=
1679
,
)
_VARTYPE_READERDESC
=
_descriptor
.
Descriptor
(
name
=
'ReaderDesc'
,
full_name
=
'paddle.framework.proto.VarType.ReaderDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'lod_tensor'
,
full_name
=
'paddle.framework.proto.VarType.ReaderDesc.lod_tensor'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
1681
,
serialized_end
=
1760
,
)
_VARTYPE_TUPLE
=
_descriptor
.
Descriptor
(
name
=
'Tuple'
,
full_name
=
'paddle.framework.proto.VarType.Tuple'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'element_type'
,
full_name
=
'paddle.framework.proto.VarType.Tuple.element_type'
,
index
=
0
,
number
=
1
,
type
=
14
,
cpp_type
=
8
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
1762
,
serialized_end
=
1829
,
)
_VARTYPE
=
_descriptor
.
Descriptor
(
name
=
'VarType'
,
full_name
=
'paddle.framework.proto.VarType'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.VarType.type'
,
index
=
0
,
number
=
1
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'selected_rows'
,
full_name
=
'paddle.framework.proto.VarType.selected_rows'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'lod_tensor'
,
full_name
=
'paddle.framework.proto.VarType.lod_tensor'
,
index
=
2
,
number
=
3
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'tensor_array'
,
full_name
=
'paddle.framework.proto.VarType.tensor_array'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'reader'
,
full_name
=
'paddle.framework.proto.VarType.reader'
,
index
=
4
,
number
=
5
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'tuple'
,
full_name
=
'paddle.framework.proto.VarType.tuple'
,
index
=
5
,
number
=
7
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[
_VARTYPE_TENSORDESC
,
_VARTYPE_LODTENSORDESC
,
_VARTYPE_LODTENSORARRAYDESC
,
_VARTYPE_READERDESC
,
_VARTYPE_TUPLE
,
],
enum_types
=
[
_VARTYPE_TYPE
,
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
1008
,
serialized_end
=
2122
,
)
_VARDESC
=
_descriptor
.
Descriptor
(
name
=
'VarDesc'
,
full_name
=
'paddle.framework.proto.VarDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.VarDesc.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.VarDesc.type'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'persistable'
,
full_name
=
'paddle.framework.proto.VarDesc.persistable'
,
index
=
2
,
number
=
3
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
2124
,
serialized_end
=
2222
,
)
_BLOCKDESC
=
_descriptor
.
Descriptor
(
name
=
'BlockDesc'
,
full_name
=
'paddle.framework.proto.BlockDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'idx'
,
full_name
=
'paddle.framework.proto.BlockDesc.idx'
,
index
=
0
,
number
=
1
,
type
=
5
,
cpp_type
=
1
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'parent_idx'
,
full_name
=
'paddle.framework.proto.BlockDesc.parent_idx'
,
index
=
1
,
number
=
2
,
type
=
5
,
cpp_type
=
1
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'vars'
,
full_name
=
'paddle.framework.proto.BlockDesc.vars'
,
index
=
2
,
number
=
3
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'ops'
,
full_name
=
'paddle.framework.proto.BlockDesc.ops'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'forward_block_idx'
,
full_name
=
'paddle.framework.proto.BlockDesc.forward_block_idx'
,
index
=
4
,
number
=
5
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
True
,
default_value
=-
1
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
2225
,
serialized_end
=
2392
,
)
_PROGRAMDESC
=
_descriptor
.
Descriptor
(
name
=
'ProgramDesc'
,
full_name
=
'paddle.framework.proto.ProgramDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'blocks'
,
full_name
=
'paddle.framework.proto.ProgramDesc.blocks'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'version'
,
full_name
=
'paddle.framework.proto.ProgramDesc.version'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[],
nested_types
=
[],
enum_types
=
[],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[],
serialized_start
=
2394
,
serialized_end
=
2508
,
)
_OPDESC_ATTR
.
fields_by_name
[
'type'
].
enum_type
=
_ATTRTYPE
_OPDESC_ATTR
.
containing_type
=
_OPDESC
_OPDESC_VAR
.
containing_type
=
_OPDESC
_OPDESC
.
fields_by_name
[
'inputs'
].
message_type
=
_OPDESC_VAR
_OPDESC
.
fields_by_name
[
'outputs'
].
message_type
=
_OPDESC_VAR
_OPDESC
.
fields_by_name
[
'attrs'
].
message_type
=
_OPDESC_ATTR
_OPPROTO_VAR
.
containing_type
=
_OPPROTO
_OPPROTO_ATTR
.
fields_by_name
[
'type'
].
enum_type
=
_ATTRTYPE
_OPPROTO_ATTR
.
containing_type
=
_OPPROTO
_OPPROTO
.
fields_by_name
[
'inputs'
].
message_type
=
_OPPROTO_VAR
_OPPROTO
.
fields_by_name
[
'outputs'
].
message_type
=
_OPPROTO_VAR
_OPPROTO
.
fields_by_name
[
'attrs'
].
message_type
=
_OPPROTO_ATTR
_VARTYPE_TENSORDESC
.
fields_by_name
[
'data_type'
].
enum_type
=
_VARTYPE_TYPE
_VARTYPE_TENSORDESC
.
containing_type
=
_VARTYPE
_VARTYPE_LODTENSORDESC
.
fields_by_name
[
'tensor'
].
message_type
=
_VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC
.
containing_type
=
_VARTYPE
_VARTYPE_LODTENSORARRAYDESC
.
fields_by_name
[
'tensor'
].
message_type
=
_VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC
.
containing_type
=
_VARTYPE
_VARTYPE_READERDESC
.
fields_by_name
[
'lod_tensor'
].
message_type
=
_VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC
.
containing_type
=
_VARTYPE
_VARTYPE_TUPLE
.
fields_by_name
[
'element_type'
].
enum_type
=
_VARTYPE_TYPE
_VARTYPE_TUPLE
.
containing_type
=
_VARTYPE
_VARTYPE
.
fields_by_name
[
'type'
].
enum_type
=
_VARTYPE_TYPE
_VARTYPE
.
fields_by_name
[
'selected_rows'
].
message_type
=
_VARTYPE_TENSORDESC
_VARTYPE
.
fields_by_name
[
'lod_tensor'
].
message_type
=
_VARTYPE_LODTENSORDESC
_VARTYPE
.
fields_by_name
[
'tensor_array'
].
message_type
=
_VARTYPE_LODTENSORARRAYDESC
_VARTYPE
.
fields_by_name
[
'reader'
].
message_type
=
_VARTYPE_READERDESC
_VARTYPE
.
fields_by_name
[
'tuple'
].
message_type
=
_VARTYPE_TUPLE
_VARTYPE_TYPE
.
containing_type
=
_VARTYPE
_VARDESC
.
fields_by_name
[
'type'
].
message_type
=
_VARTYPE
_BLOCKDESC
.
fields_by_name
[
'vars'
].
message_type
=
_VARDESC
_BLOCKDESC
.
fields_by_name
[
'ops'
].
message_type
=
_OPDESC
_PROGRAMDESC
.
fields_by_name
[
'blocks'
].
message_type
=
_BLOCKDESC
_PROGRAMDESC
.
fields_by_name
[
'version'
].
message_type
=
_VERSION
DESCRIPTOR
.
message_types_by_name
[
'Version'
]
=
_VERSION
DESCRIPTOR
.
message_types_by_name
[
'OpDesc'
]
=
_OPDESC
DESCRIPTOR
.
message_types_by_name
[
'OpProto'
]
=
_OPPROTO
DESCRIPTOR
.
message_types_by_name
[
'VarType'
]
=
_VARTYPE
DESCRIPTOR
.
message_types_by_name
[
'VarDesc'
]
=
_VARDESC
DESCRIPTOR
.
message_types_by_name
[
'BlockDesc'
]
=
_BLOCKDESC
DESCRIPTOR
.
message_types_by_name
[
'ProgramDesc'
]
=
_PROGRAMDESC
DESCRIPTOR
.
enum_types_by_name
[
'AttrType'
]
=
_ATTRTYPE
Version
=
_reflection
.
GeneratedProtocolMessageType
(
'Version'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VERSION
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db
.
RegisterMessage
(
Version
)
OpDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'OpDesc'
,
(
_message
.
Message
,
),
dict
(
Attr
=
_reflection
.
GeneratedProtocolMessageType
(
'Attr'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_OPDESC_ATTR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
)),
Var
=
_reflection
.
GeneratedProtocolMessageType
(
'Var'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_OPDESC_VAR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
)),
DESCRIPTOR
=
_OPDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db
.
RegisterMessage
(
OpDesc
)
_sym_db
.
RegisterMessage
(
OpDesc
.
Attr
)
_sym_db
.
RegisterMessage
(
OpDesc
.
Var
)
OpProto
=
_reflection
.
GeneratedProtocolMessageType
(
'OpProto'
,
(
_message
.
Message
,
),
dict
(
Var
=
_reflection
.
GeneratedProtocolMessageType
(
'Var'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_OPPROTO_VAR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
)),
Attr
=
_reflection
.
GeneratedProtocolMessageType
(
'Attr'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_OPPROTO_ATTR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
)),
DESCRIPTOR
=
_OPPROTO
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db
.
RegisterMessage
(
OpProto
)
_sym_db
.
RegisterMessage
(
OpProto
.
Var
)
_sym_db
.
RegisterMessage
(
OpProto
.
Attr
)
VarType
=
_reflection
.
GeneratedProtocolMessageType
(
'VarType'
,
(
_message
.
Message
,
),
dict
(
TensorDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'TensorDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VARTYPE_TENSORDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
)),
LoDTensorDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'LoDTensorDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VARTYPE_LODTENSORDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
)),
LoDTensorArrayDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'LoDTensorArrayDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VARTYPE_LODTENSORARRAYDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
)),
ReaderDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'ReaderDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VARTYPE_READERDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
)),
Tuple
=
_reflection
.
GeneratedProtocolMessageType
(
'Tuple'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VARTYPE_TUPLE
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
)),
DESCRIPTOR
=
_VARTYPE
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db
.
RegisterMessage
(
VarType
)
_sym_db
.
RegisterMessage
(
VarType
.
TensorDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
LoDTensorDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
LoDTensorArrayDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
ReaderDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
Tuple
)
VarDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'VarDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_VARDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db
.
RegisterMessage
(
VarDesc
)
BlockDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'BlockDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_BLOCKDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db
.
RegisterMessage
(
BlockDesc
)
ProgramDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'ProgramDesc'
,
(
_message
.
Message
,
),
dict
(
DESCRIPTOR
=
_PROGRAMDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db
.
RegisterMessage
(
ProgramDesc
)
DESCRIPTOR
.
has_options
=
True
DESCRIPTOR
.
_options
=
_descriptor
.
_ParseOptions
(
descriptor_pb2
.
FileOptions
(),
_b
(
'H
\003
'
))
# @@protoc_insertion_point(module_scope)
onnx2
paddle/onnx2paddle
/onnx_utils.py
→
onnx2
fluid/onnx2fluid
/onnx_utils.py
浏览文件 @
2228423e
...
@@ -18,28 +18,30 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
...
@@ -18,28 +18,30 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
from
onnx.numpy_helper
import
to_array
from
onnx.numpy_helper
import
to_array
from
onnx.shape_inference
import
infer_shapes
from
onnx.shape_inference
import
infer_shapes
logger
=
logging
.
getLogger
(
__name__
)
logger
=
logging
.
getLogger
(
__name__
)
__all__
=
[
__all__
=
[
'print_pb_structure'
,
'print_pb_structure'
,
'build_value_refs'
,
'build_value_refs'
,
'node_attrs'
,
'node_topo'
,
'node_iter'
,
'node_attrs'
,
'node_topo'
,
'node_iter'
,
'tensor_shape'
,
'tensor_shape'
,
'graph_ops'
,
'graph_weights'
,
'graph_ops'
,
'graph_weights'
,
'inferred_model_value_info'
,
'inferred_model_value_info'
,
'optimize_model_skip_op_for_inference'
,
'optimize_model_skip_op_for_inference'
,
'optimize_model_strip_initializer'
,
'optimize_model_strip_initializer'
,
'optimize_model_cast'
,
'optimize_model_slice'
,
'optimize_model_cast'
,
'optimize_model_slice'
,
]
]
ONNX_INT_MAX
=
2
**
63
-
1
ONNX_INT_MAX
=
2
**
63
-
1
DEFAULT_OP_DOMAIN
=
'ai.onnx'
DEFAULT_OP_DOMAIN
=
'ai.onnx'
def
print_pb_structure
(
message
,
def
print_pb_structure
(
message
,
loop_iterative
=
False
,
depth
=
0
):
loop_iterative
=
False
,
depth
=
0
):
"""
"""
print pb fields in its structure
print pb fields in its structure
"""
"""
...
@@ -47,14 +49,17 @@ def print_pb_structure(message,
...
@@ -47,14 +49,17 @@ def print_pb_structure(message,
if
hasattr
(
message
,
'DESCRIPTOR'
)
and
hasattr
(
message
.
DESCRIPTOR
,
'fields'
):
if
hasattr
(
message
,
'DESCRIPTOR'
)
and
hasattr
(
message
.
DESCRIPTOR
,
'fields'
):
for
field
in
message
.
DESCRIPTOR
.
fields
:
for
field
in
message
.
DESCRIPTOR
.
fields
:
print
(
'
\t
'
*
depth
+
'-'
,
field
.
name
)
print
(
'
\t
'
*
depth
+
'-'
,
field
.
name
)
print_pb_structure
(
getattr
(
message
,
field
.
name
),
print_pb_structure
(
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
getattr
(
message
,
field
.
name
),
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
if
loop_iterative
and
hasattr
(
message
,
'MergeFrom'
)
and
hasattr
(
message
,
'__len__'
):
if
loop_iterative
and
hasattr
(
message
,
'MergeFrom'
)
and
hasattr
(
message
,
'__len__'
):
for
idx
,
item
in
enumerate
(
message
):
for
idx
,
item
in
enumerate
(
message
):
print
(
'
\t
'
*
depth
+
'-'
,
idx
)
print
(
'
\t
'
*
depth
+
'-'
,
idx
)
print_pb_structure
(
item
,
print_pb_structure
(
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
item
,
loop_iterative
=
loop_iterative
,
depth
=
(
depth
+
1
))
def
build_value_refs
(
nodes
):
def
build_value_refs
(
nodes
):
...
@@ -80,7 +85,8 @@ def get_attribute_value2(attr):
...
@@ -80,7 +85,8 @@ def get_attribute_value2(attr):
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
if
attr
.
type
==
onnx
.
AttributeProto
.
TENSOR
:
dtype
=
np
.
dtype
(
TENSOR_TYPE_TO_NP_TYPE
[
attr
.
t
.
data_type
])
dtype
=
np
.
dtype
(
TENSOR_TYPE_TO_NP_TYPE
[
attr
.
t
.
data_type
])
data
=
attr
.
t
.
raw_data
data
=
attr
.
t
.
raw_data
value
=
np
.
frombuffer
(
data
,
dtype
=
dtype
,
count
=
(
len
(
data
)
//
dtype
.
itemsize
))
value
=
np
.
frombuffer
(
data
,
dtype
=
dtype
,
count
=
(
len
(
data
)
//
dtype
.
itemsize
))
else
:
else
:
value
=
get_attribute_value
(
attr
)
value
=
get_attribute_value
(
attr
)
return
value
return
value
...
@@ -91,7 +97,8 @@ def node_attrs(node):
...
@@ -91,7 +97,8 @@ def node_attrs(node):
convert ONNX node attributes to dict
convert ONNX node attributes to dict
"""
"""
return
{
attr
.
name
:
get_attribute_value2
(
attr
)
for
attr
in
node
.
attribute
}
# dict
return
{
attr
.
name
:
get_attribute_value2
(
attr
)
for
attr
in
node
.
attribute
}
# dict
def
tensor_shape
(
tensor
):
def
tensor_shape
(
tensor
):
...
@@ -168,8 +175,7 @@ def node_topo(nodes, topo='default'):
...
@@ -168,8 +175,7 @@ def node_topo(nodes, topo='default'):
raise
ValueError
(
'unkown given topo: {}'
.
format
(
topo
))
raise
ValueError
(
'unkown given topo: {}'
.
format
(
topo
))
def
node_iter
(
nodes
,
def
node_iter
(
nodes
,
indices
=
None
):
indices
=
None
):
"""
"""
generator for ONNX node graph with given indices
generator for ONNX node graph with given indices
"""
"""
...
@@ -194,8 +200,7 @@ def node_iter(nodes,
...
@@ -194,8 +200,7 @@ def node_iter(nodes,
yield
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
yield
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
def
graph_ops
(
graph
,
def
graph_ops
(
graph
,
topo
=
'default'
):
topo
=
'default'
):
"""
"""
generator for ONNX node graph with given topology
generator for ONNX node graph with given topology
"""
"""
...
@@ -244,7 +249,7 @@ def inferred_model_value_info(model):
...
@@ -244,7 +249,7 @@ def inferred_model_value_info(model):
external
=
True
,
external
=
True
,
)
)
for
item
in
graph
.
output
:
for
item
in
graph
.
output
:
# assert item.name not in value_info, 'bypass-model not supported'
# assert item.name not in value_info, 'bypass-model not supported'
value_info
[
item
.
name
]
=
dict
(
value_info
[
item
.
name
]
=
dict
(
dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
item
.
type
.
tensor_type
.
elem_type
],
dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
item
.
type
.
tensor_type
.
elem_type
],
shape
=
tensor_shape
(
item
),
shape
=
tensor_shape
(
item
),
...
@@ -283,9 +288,7 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
...
@@ -283,9 +288,7 @@ def skip_node_backward(nodes, src_input_name, dst_output_name, output_refs):
return
processed
return
processed
def
optimize_model_skip_op_for_inference
(
def
optimize_model_skip_op_for_inference
(
model
,
op_list
=
None
):
model
,
op_list
=
None
):
"""
"""
skip ops can be bypassed for inference
skip ops can be bypassed for inference
"""
"""
...
@@ -297,21 +300,23 @@ def optimize_model_skip_op_for_inference(
...
@@ -297,21 +300,23 @@ def optimize_model_skip_op_for_inference(
ret
=
type
(
model
)()
ret
=
type
(
model
)()
ret
.
CopyFrom
(
model
)
ret
.
CopyFrom
(
model
)
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret_nodes
=
ret
.
graph
.
node
ret_nodes
=
ret
.
graph
.
node
nodes_to_remove
=
[]
nodes_to_remove
=
[]
for
node_idx
,
node
in
enumerate
(
nodes
):
for
node_idx
,
node
in
enumerate
(
nodes
):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
continue
continue
op_type
=
node
.
op_type
op_type
=
node
.
op_type
if
not
(
op_type
in
op_list
):
if
not
(
op_type
in
op_list
):
continue
continue
if
op_type
in
[
'Dropout'
]:
if
op_type
in
[
'Dropout'
]:
input_name
=
node
.
input
[
0
]
input_name
=
node
.
input
[
0
]
output_name
=
node
.
output
[
0
]
output_name
=
node
.
output
[
0
]
elif
not
(
len
(
node
.
input
)
==
1
and
len
(
node
.
output
)
==
1
):
elif
not
(
len
(
node
.
input
)
==
1
and
len
(
node
.
output
)
==
1
):
logger
.
warning
(
'currently only 1-input-1-output op supported, skip required %d: %s'
,
logger
.
warning
(
'currently only 1-input-1-output op supported, skip required %d: %s'
,
node_idx
,
node
.
op_type
)
node_idx
,
node
.
op_type
)
continue
continue
else
:
else
:
...
@@ -319,16 +324,18 @@ def optimize_model_skip_op_for_inference(
...
@@ -319,16 +324,18 @@ def optimize_model_skip_op_for_inference(
output_name
=
node
.
output
[
0
]
output_name
=
node
.
output
[
0
]
if
output_name
in
input_refs
:
if
output_name
in
input_refs
:
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
input_name
,
input_refs
)
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
input_name
,
input_refs
)
elif
input_name
in
output_refs
:
elif
input_name
in
output_refs
:
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
output_name
,
output_refs
)
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
output_name
,
output_refs
)
else
:
else
:
processed
=
-
1
processed
=
-
1
if
processed
>
0
:
if
processed
>
0
:
nodes_to_remove
.
append
(
node_idx
)
nodes_to_remove
.
append
(
node_idx
)
logger
.
debug
(
'skip op %d: %s -> %s -> %s'
,
logger
.
debug
(
'skip op %d: %s -> %s -> %s'
,
node_idx
,
input_name
,
node
_idx
,
input_name
,
node
.
op_type
,
output_name
)
node
.
op_type
,
output_name
)
elif
processed
==
0
:
elif
processed
==
0
:
logger
.
warning
(
'weird, no node processed'
)
logger
.
warning
(
'weird, no node processed'
)
else
:
else
:
...
@@ -342,8 +349,7 @@ def optimize_model_skip_op_for_inference(
...
@@ -342,8 +349,7 @@ def optimize_model_skip_op_for_inference(
return
ret
return
ret
def
optimize_model_strip_initializer
(
model
,
def
optimize_model_strip_initializer
(
model
,
keep_input_only
=
True
):
keep_input_only
=
True
):
"""
"""
strip weights for inference
strip weights for inference
"""
"""
...
@@ -354,7 +360,8 @@ def optimize_model_strip_initializer(model,
...
@@ -354,7 +360,8 @@ def optimize_model_strip_initializer(model,
ret
=
type
(
model
)()
ret
=
type
(
model
)()
ret
.
CopyFrom
(
model
)
ret
.
CopyFrom
(
model
)
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
# strip initializers
# strip initializers
ret
.
graph
.
ClearField
(
'initializer'
)
ret
.
graph
.
ClearField
(
'initializer'
)
...
@@ -366,8 +373,7 @@ def optimize_model_strip_initializer(model,
...
@@ -366,8 +373,7 @@ def optimize_model_strip_initializer(model,
elif
not
keep_input_only
and
name
in
output_refs
:
elif
not
keep_input_only
and
name
in
output_refs
:
ret_initializers
.
add
().
CopyFrom
(
initializer
)
ret_initializers
.
add
().
CopyFrom
(
initializer
)
else
:
else
:
logger
.
debug
(
'initializer %s(%s[%d]) stripped'
,
logger
.
debug
(
'initializer %s(%s[%d]) stripped'
,
name
,
name
,
TENSOR_TYPE_TO_NP_TYPE
[
initializer
.
data_type
],
TENSOR_TYPE_TO_NP_TYPE
[
initializer
.
data_type
],
len
(
initializer
.
raw_data
))
len
(
initializer
.
raw_data
))
...
@@ -379,8 +385,8 @@ def optimize_model_strip_initializer(model,
...
@@ -379,8 +385,8 @@ def optimize_model_strip_initializer(model,
if
name
in
input_refs
or
name
in
out_names
:
if
name
in
input_refs
or
name
in
out_names
:
ret_inputs
.
add
().
CopyFrom
(
item
)
ret_inputs
.
add
().
CopyFrom
(
item
)
else
:
else
:
logger
.
debug
(
'input %s(%s%s) stripped'
,
logger
.
debug
(
name
,
'input %s(%s%s) stripped'
,
name
,
TENSOR_TYPE_TO_NP_TYPE
[
item
.
type
.
tensor_type
.
elem_type
],
TENSOR_TYPE_TO_NP_TYPE
[
item
.
type
.
tensor_type
.
elem_type
],
tensor_shape
(
item
))
tensor_shape
(
item
))
return
ret
return
ret
...
@@ -397,13 +403,14 @@ def optimize_model_cast(model):
...
@@ -397,13 +403,14 @@ def optimize_model_cast(model):
ret
=
type
(
model
)()
ret
=
type
(
model
)()
ret
.
CopyFrom
(
model
)
ret
.
CopyFrom
(
model
)
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret_nodes
=
ret
.
graph
.
node
ret_nodes
=
ret
.
graph
.
node
nodes_to_remove
=
[]
nodes_to_remove
=
[]
for
node_idx
,
node
in
enumerate
(
nodes
):
for
node_idx
,
node
in
enumerate
(
nodes
):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
continue
continue
if
not
(
node
.
op_type
==
'Cast'
):
if
not
(
node
.
op_type
==
'Cast'
):
continue
continue
attrs
=
node_attrs
(
node
)
attrs
=
node_attrs
(
node
)
output_dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
attrs
[
'to'
]]
output_dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
attrs
[
'to'
]]
...
@@ -417,21 +424,23 @@ def optimize_model_cast(model):
...
@@ -417,21 +424,23 @@ def optimize_model_cast(model):
output_name
=
node
.
output
[
0
]
output_name
=
node
.
output
[
0
]
if
output_name
in
input_refs
:
if
output_name
in
input_refs
:
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
input_name
,
input_refs
)
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
input_name
,
input_refs
)
elif
input_name
in
output_refs
:
elif
input_name
in
output_refs
:
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
output_name
,
output_refs
)
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
output_name
,
output_refs
)
else
:
else
:
processed
=
-
1
processed
=
-
1
if
processed
>
0
:
if
processed
>
0
:
nodes_to_remove
.
append
(
node_idx
)
nodes_to_remove
.
append
(
node_idx
)
logger
.
debug
(
'skip %s: %s -> %s Cast op'
,
logger
.
debug
(
'skip %s: %s -> %s Cast op'
,
node
.
name
,
input_dtype
,
node
.
name
,
input_dtype
,
output_dtype
)
output_dtype
)
elif
processed
==
0
:
elif
processed
==
0
:
logger
.
warning
(
'weird, no node processed'
)
logger
.
warning
(
'weird, no node processed'
)
else
:
else
:
logger
.
debug
(
'keep standalone %s: %s -> %s Cast op'
,
logger
.
debug
(
'keep standalone %s: %s -> %s Cast op'
,
node
.
name
,
node
.
name
,
input_dtype
,
output_dtype
)
input_dtype
,
output_dtype
)
nodes_to_remove
.
sort
(
reverse
=
True
)
nodes_to_remove
.
sort
(
reverse
=
True
)
for
node_idx
in
nodes_to_remove
:
for
node_idx
in
nodes_to_remove
:
...
@@ -452,13 +461,14 @@ def optimize_model_slice(model):
...
@@ -452,13 +461,14 @@ def optimize_model_slice(model):
chain
=
[]
chain
=
[]
while
True
:
while
True
:
node
=
nodes
[
node_idx
]
node
=
nodes
[
node_idx
]
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
if
not
(
node
.
domain
==
DEFAULT_OP_DOMAIN
or
node
.
domain
==
''
):
return
chain
return
chain
if
not
node
.
op_type
==
'Slice'
:
if
not
node
.
op_type
==
'Slice'
:
return
chain
return
chain
chain
.
append
(
node_idx
)
chain
.
append
(
node_idx
)
output_name
=
node
.
output
[
0
]
output_name
=
node
.
output
[
0
]
if
output_name
not
in
input_refs
or
len
(
input_refs
[
output_name
])
!=
1
:
if
output_name
not
in
input_refs
or
len
(
input_refs
[
output_name
])
!=
1
:
return
chain
return
chain
node_idx
=
list
(
input_refs
[
output_name
])[
0
]
node_idx
=
list
(
input_refs
[
output_name
])[
0
]
...
@@ -468,7 +478,8 @@ def optimize_model_slice(model):
...
@@ -468,7 +478,8 @@ def optimize_model_slice(model):
for
slice_node_idx
in
slice_chain
:
for
slice_node_idx
in
slice_chain
:
node
=
nodes
[
slice_node_idx
]
node
=
nodes
[
slice_node_idx
]
attrs
=
node_attrs
(
node
)
attrs
=
node_attrs
(
node
)
for
axis
,
start
,
end
in
zip
(
attrs
[
'axes'
],
attrs
[
'starts'
],
attrs
[
'ends'
]):
for
axis
,
start
,
end
in
zip
(
attrs
[
'axes'
],
attrs
[
'starts'
],
attrs
[
'ends'
]):
if
start
==
0
and
end
==
ONNX_INT_MAX
:
if
start
==
0
and
end
==
ONNX_INT_MAX
:
continue
continue
if
axis
in
merged_slice
:
if
axis
in
merged_slice
:
...
@@ -480,7 +491,8 @@ def optimize_model_slice(model):
...
@@ -480,7 +491,8 @@ def optimize_model_slice(model):
ret
=
type
(
model
)()
ret
=
type
(
model
)()
ret
.
CopyFrom
(
model
)
ret
.
CopyFrom
(
model
)
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret
.
graph
.
ClearField
(
'value_info'
)
# WORKAROUND: onnx do not drop old value_info
ret_nodes
=
ret
.
graph
.
node
ret_nodes
=
ret
.
graph
.
node
nodes_to_remove
=
[]
nodes_to_remove
=
[]
for
node_idx
in
range
(
len
(
nodes
)):
for
node_idx
in
range
(
len
(
nodes
)):
...
@@ -502,40 +514,48 @@ def optimize_model_slice(model):
...
@@ -502,40 +514,48 @@ def optimize_model_slice(model):
output_name
=
last_node
.
output
[
0
]
output_name
=
last_node
.
output
[
0
]
processed
=
-
1
processed
=
-
1
if
output_name
in
input_refs
:
# 0, [1...]
if
output_name
in
input_refs
:
# 0, [1...]
new_input_name
=
first_node
.
output
[
0
]
if
len
(
merged_slice
)
>
0
else
input_name
new_input_name
=
first_node
.
output
[
0
]
if
len
(
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
new_input_name
,
input_refs
)
merged_slice
)
>
0
else
input_name
processed
=
skip_node_forward
(
ret_nodes
,
output_name
,
new_input_name
,
input_refs
)
if
processed
>
0
:
if
processed
>
0
:
if
len
(
merged_slice
)
>
0
:
if
len
(
merged_slice
)
>
0
:
remain_idx
=
slice_chain
[
0
]
remain_idx
=
slice_chain
[
0
]
remove_chain
=
slice_chain
[
1
:]
remove_chain
=
slice_chain
[
1
:]
slice_node
=
ret_nodes
[
remain_idx
]
slice_node
=
ret_nodes
[
remain_idx
]
for
attr
in
slice_node
.
attribute
:
for
attr
in
slice_node
.
attribute
:
attr
.
CopyFrom
(
make_attribute
(
attr
.
name
,
attrs
[
attr
.
name
]))
attr
.
CopyFrom
(
make_attribute
(
attr
.
name
,
attrs
[
attr
.
name
]))
logger
.
debug
(
'merged slice chain %s -> %s%s -> %s'
,
logger
.
debug
(
'merged slice chain %s -> %s%s -> %s'
,
input_name
,
remain_idx
,
remove_chain
,
output_name
)
input_name
,
remain_idx
,
remove_chain
,
output_name
)
else
:
else
:
remove_chain
=
slice_chain
remove_chain
=
slice_chain
if
processed
<
0
and
input_name
in
output_refs
:
if
processed
<
0
and
input_name
in
output_refs
:
new_output_name
=
last_node
.
input
[
0
]
if
len
(
merged_slice
)
>
0
else
output_name
new_output_name
=
last_node
.
input
[
0
]
if
len
(
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
new_output_name
,
output_refs
)
merged_slice
)
>
0
else
output_name
processed
=
skip_node_backward
(
ret_nodes
,
input_name
,
new_output_name
,
output_refs
)
if
processed
>
0
:
if
processed
>
0
:
if
len
(
merged_slice
)
>
0
:
if
len
(
merged_slice
)
>
0
:
remain_idx
=
slice_chain
[
-
1
]
remain_idx
=
slice_chain
[
-
1
]
remove_chain
=
slice_chain
[:
-
1
]
remove_chain
=
slice_chain
[:
-
1
]
slice_node
=
ret_nodes
[
remain_idx
]
slice_node
=
ret_nodes
[
remain_idx
]
for
attr
in
slice_node
.
attribute
:
for
attr
in
slice_node
.
attribute
:
attr
.
CopyFrom
(
make_attribute
(
attr
.
name
,
attrs
[
attr
.
name
]))
attr
.
CopyFrom
(
make_attribute
(
attr
.
name
,
attrs
[
attr
.
name
]))
logger
.
debug
(
'merged slice chain %s -> %s%s -> %s'
,
logger
.
debug
(
'merged slice chain %s -> %s%s -> %s'
,
input_name
,
remove_chain
,
remain_idx
,
output_name
)
input_name
,
remove_chain
,
remain_idx
,
output_name
)
else
:
else
:
remove_chain
=
slice_chain
remove_chain
=
slice_chain
if
processed
>
0
:
if
processed
>
0
:
nodes_to_remove
.
extend
(
remove_chain
)
nodes_to_remove
.
extend
(
remove_chain
)
if
len
(
merged_slice
)
==
0
:
if
len
(
merged_slice
)
==
0
:
logger
.
debug
(
'skip slice chain %s -> %s -> %s'
,
logger
.
debug
(
'skip slice chain %s -> %s -> %s'
,
input_name
,
input_name
,
slice_chain
,
output_name
)
slice_chain
,
output_name
)
elif
processed
<
0
:
# NEVERFIX: not merge standalone slice chain
elif
processed
<
0
:
# NEVERFIX: not merge standalone slice chain
logger
.
debug
(
'keep standalone slice chain %s -> %s -> %s'
,
logger
.
debug
(
'keep standalone slice chain %s -> %s -> %s'
,
input_name
,
slice_chain
,
output_name
)
input_name
,
slice_chain
,
output_name
)
...
@@ -549,7 +569,8 @@ def optimize_model_slice(model):
...
@@ -549,7 +569,8 @@ def optimize_model_slice(model):
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
basicConfig
(
logging
.
basicConfig
(
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
level
=
logging
.
DEBUG
,
level
=
logging
.
DEBUG
,
)
)
...
...
onnx2
paddle/onnx2paddle
/symbolic.py
→
onnx2
fluid/onnx2fluid
/symbolic.py
浏览文件 @
2228423e
#!/usr/bin/env python3
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# -*- coding: utf-8 -*-
"""
"""
ONNX to Paddle symbolic translation
ONNX to Paddle
fluid
symbolic translation
Created on Mon Feb 25 09:33:43 2019
Created on Mon Feb 25 09:33:43 2019
...
@@ -18,20 +18,23 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
...
@@ -18,20 +18,23 @@ from onnx.mapping import TENSOR_TYPE_TO_NP_TYPE
_logger
=
_logging
.
getLogger
(
__name__
)
_logger
=
_logging
.
getLogger
(
__name__
)
ONNX_INT_MAX
=
2
**
63
-
1
ONNX_INT_MAX
=
2
**
63
-
1
FLUID_INT_MAX
=
2
**
31
-
1
DEFAULT_ONNX_OP_DOMAIN
=
''
DEFAULT_ONNX_OP_DOMAIN
=
''
DEFAULT_
PADDLE
_OP_NAMESCOPE
=
'/'
DEFAULT_
FLUID
_OP_NAMESCOPE
=
'/'
DEFAULT_OP_MAPPING_FIELD_VALUES
=
_dict
()
DEFAULT_OP_MAPPING_FIELD_VALUES
=
_dict
()
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'PADDLE_OP'
]
=
''
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FLUID_OP'
]
=
''
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'PADDLE_INPUT_ARGS'
]
=
None
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FLUID_INPUT_ARGS'
]
=
None
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'PADDLE_OUTPUT_ARGS'
]
=
None
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FLUID_OUTPUT_ARGS'
]
=
None
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'ATTR_MAPPING'
]
=
dict
()
# dict(onnx_attr_from=paddle_attr_to)
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'ATTR_MAPPING'
]
=
dict
(
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'DEFAULTS'
]
=
dict
()
# dict(paddle_attr=default)
)
# dict(onnx_attr_from=fluid_attr_to)
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'INPUT_PERM'
]
=
None
# sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'DEFAULTS'
]
=
dict
()
# dict(fluid_attr=default)
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'OUTPUT_PERM'
]
=
None
# sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'INPUT_PERM'
]
=
None
# sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'OUTPUT_PERM'
]
=
None
# sampler: [idx_onnx_arg...]
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FILL_NAME_FIELD'
]
=
True
DEFAULT_OP_MAPPING_FIELD_VALUES
[
'FILL_NAME_FIELD'
]
=
True
DEFAULT_OP_MAPPING
=
{
DEFAULT_OP_MAPPING
=
{
...
@@ -60,7 +63,7 @@ DEFAULT_OP_MAPPING = {
...
@@ -60,7 +63,7 @@ DEFAULT_OP_MAPPING = {
'Reciprocal'
:
[
'reciprocal'
,
[
'X'
],
[
'Out'
]],
'Reciprocal'
:
[
'reciprocal'
,
[
'X'
],
[
'Out'
]],
'Relu'
:
[
'relu'
,
[
'X'
],
[
'Out'
]],
'Relu'
:
[
'relu'
,
[
'X'
],
[
'Out'
]],
'Selu'
:
[
'selu'
,
[
'X'
],
[
'Out'
],
dict
(
gamma
=
'scale'
)],
'Selu'
:
[
'selu'
,
[
'X'
],
[
'Out'
],
dict
(
gamma
=
'scale'
)],
'Shape'
:
[
'shape'
,
[
'X'
],
[
'Out'
]],
# FIXME: out is int64
-
int32
'Shape'
:
[
'shape'
,
[
'X'
],
[
'Out'
]],
# FIXME: out is int64
vs
int32
'Shrink'
:
[
'softshrink'
,
[
'X'
],
[
'Out'
],
dict
(
bias
=
''
,
labmd
=
''
)],
'Shrink'
:
[
'softshrink'
,
[
'X'
],
[
'Out'
],
dict
(
bias
=
''
,
labmd
=
''
)],
'Sigmoid'
:
[
'sigmoid'
,
[
'X'
],
[
'Out'
]],
'Sigmoid'
:
[
'sigmoid'
,
[
'X'
],
[
'Out'
]],
'Sin'
:
[
'sin'
,
[
'X'
],
[
'Out'
]],
'Sin'
:
[
'sin'
,
[
'X'
],
[
'Out'
]],
...
@@ -74,25 +77,24 @@ DEFAULT_OP_MAPPING = {
...
@@ -74,25 +77,24 @@ DEFAULT_OP_MAPPING = {
'Transpose'
:
[
'transpose'
,
[
'X'
],
[
'Out'
]],
# FIXME: emit transpose2
'Transpose'
:
[
'transpose'
,
[
'X'
],
[
'Out'
]],
# FIXME: emit transpose2
'Unsqueeze'
:
[
'unsqueeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit unsqueeze2
'Unsqueeze'
:
[
'unsqueeze'
,
[
'X'
],
[
'Out'
]],
# attrs bypassed, FIXME: emit unsqueeze2
## binary ops ##
## binary ops ##
# FIXME: axis=-1 in Paddle is broken, refer it in specialization
'Add'
:
[
'elementwise_add'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=-
1
)],
'Add'
:
[
'elementwise_add'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
# 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
# 'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
'And'
:
[
'logical_and'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'And'
:
[
'logical_and'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Div'
:
[
'elementwise_div'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
'Div'
:
[
'elementwise_div'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
-
1
)],
'Equal'
:
[
'equal'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Equal'
:
[
'equal'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Greater'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Greater'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Less'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'Less'
:
[
'less_than'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(),
None
,
None
,
False
],
'MatMul'
:
[
'matmul'
,
[
'X'
,
'Y'
],
[
'Out'
]],
# defaults excluded for transpose_x
-
transpose_X
'MatMul'
:
[
'matmul'
,
[
'X'
,
'Y'
],
[
'Out'
]],
# defaults excluded for transpose_x
vs
transpose_X
'Max'
:
[
'elementwise_max'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
'Max'
:
[
'elementwise_max'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
-
1
)],
'Min'
:
[
'elementwise_min'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
'Min'
:
[
'elementwise_min'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
-
1
)],
'Mul'
:
[
'elementwise_mul'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
'Mul'
:
[
'elementwise_mul'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
-
1
)],
'Not'
:
[
'logical_not'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Not'
:
[
'logical_not'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'OneHot'
:
# assuming values=[0, 1], axis=-1 and drop them
'OneHot'
:
# assuming values=[0, 1], axis=-1 and drop them
[
'one_hot'
,
[
'Input'
,
'Depth'
],
[
'Out'
],
dict
(
axis
=
''
),
dict
(),
[
'one_hot'
,
[
'Input'
,
'Depth'
],
[
'Out'
],
dict
(
axis
=
''
),
dict
(),
[
0
,
1
],
None
,
False
],
[
0
,
1
],
None
,
False
],
'Or'
:
[
'logical_or'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Or'
:
[
'logical_or'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Pow'
:
[
'elementwise_pow'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
# TODO: pow for scalar exponent
'Pow'
:
[
'elementwise_pow'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
-
1
)],
# TODO: pow for scalar exponent
'Sub'
:
[
'elementwise_sub'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
0
)],
'Sub'
:
[
'elementwise_sub'
,
[
'X'
,
'Y'
],
[
'Out'
],
dict
(),
dict
(
axis
=
-
1
)],
'Xor'
:
[
'logical_xor'
,
[
'X'
,
'Y'
],
[
'Out'
]],
'Xor'
:
[
'logical_xor'
,
[
'X'
,
'Y'
],
[
'Out'
]],
# reduce ops
# reduce ops
'ReduceMax'
:
[
'reduce_max'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
'ReduceMax'
:
[
'reduce_max'
,
[
'X'
],
[
'Out'
],
dict
(
axes
=
'dim'
,
keepdims
=
'keep_dim'
)],
...
@@ -106,29 +108,34 @@ DEFAULT_OP_MAPPING = {
...
@@ -106,29 +108,34 @@ DEFAULT_OP_MAPPING = {
}
}
DEFAULT_IOA_CONSTRAINT
=
{
DEFAULT_IOA_CONSTRAINT
=
{
'ArgMax'
:
'ArgMax'
:
[
[(
lambda
i
,
o
,
a
:
a
.
get
(
'keepdims'
,
1
)
==
1
,
'only keepdims = 0 is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'keepdims'
,
1
)
==
1
,
'only keepdims = 0 is supported'
),
],
],
'ArgMin'
:
'ArgMin'
:
[
[(
lambda
i
,
o
,
a
:
a
.
get
(
'keepdims'
,
1
)
==
1
,
'only keepdims = 0 is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'keepdims'
,
1
)
==
1
,
'only keepdims = 0 is supported'
),
],
],
'Gather'
:
'Gather'
:
[
[
(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
0
)
==
0
,
'only axis = 0 is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
0
)
==
0
,
'only axis = 0 is supported'
),
],
],
'Shrink'
:
'Shrink'
:
[
[(
lambda
i
,
o
,
a
:
a
.
get
(
'bias'
,
0
)
==
a
.
get
(
'lambd'
,
0.5
),
'only SoftShrink with bias = lambd is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'bias'
,
0
)
==
a
.
get
(
'lambd'
,
0.5
),
'only SoftShrink with bias = lambd is supported'
),
],
],
# 'Softmax':
# 'Softmax':
# [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle Softmax works on dim -2 only'),
# [(lambda i, o, a: a.get('axis', 1) == -2, 'Paddle fluid Softmax works on dim -2 only'),
# ],
# ],
'OneHot'
:
'OneHot'
:
[
[(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
-
1
)
==
-
1
,
'only axis = -1 is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
-
1
)
==
-
1
,
'only axis = -1 is supported'
),
],
],
'Scatter'
:
'Scatter'
:
[
[
(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
0
)
==
0
,
'only axis = 0 is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
0
)
==
0
,
'only axis = 0 is supported'
),
],
],
'TopK'
:
'TopK'
:
[
[(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
-
1
)
==
-
1
,
'only axis = -1 is supported'
),
(
lambda
i
,
o
,
a
:
a
.
get
(
'axis'
,
-
1
)
==
-
1
,
'only axis = -1 is supported'
),
],
],
}
}
...
@@ -142,7 +149,7 @@ def _make_var_name(name):
...
@@ -142,7 +149,7 @@ def _make_var_name(name):
return
'_'
return
'_'
if
name
[
0
].
isdigit
():
if
name
[
0
].
isdigit
():
return
'var_'
+
name
return
'var_'
+
name
for
s
in
' *?\/-:'
:
for
s
in
' *?
\
\
/-:'
:
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
name
=
'var'
+
name
...
@@ -188,82 +195,91 @@ def _shape_or_none(value_infos, val_name):
...
@@ -188,82 +195,91 @@ def _shape_or_none(value_infos, val_name):
# return value_info.get('const_value', var_name)
# return value_info.get('const_value', var_name)
def
_default
(
prog
,
op_type
,
inputs
,
outputs
,
attrs
,
def
_default
(
prog
,
op_type
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
*
args
,
name
=
''
,
**
kwargs
):
info
=
DEFAULT_OP_MAPPING
[
op_type
]
info
=
DEFAULT_OP_MAPPING
[
op_type
]
info
.
extend
(
list
(
DEFAULT_OP_MAPPING_FIELD_VALUES
.
values
())[
len
(
info
):])
info
.
extend
(
list
(
DEFAULT_OP_MAPPING_FIELD_VALUES
.
values
())[
len
(
info
):])
(
paddle_op
,
(
paddle_input_args
,
paddle_output_args
,
fluid_op
,
attr_mapping
,
default_attrs
,
fluid_input_args
,
input_perm
,
output_perm
,
fluid_output_args
,
attr_mapping
,
default_attrs
,
input_perm
,
output_perm
,
fill_name_field
,
fill_name_field
,
)
=
info
)
=
info
if
paddle
_op
in
DEFAULT_IOA_CONSTRAINT
:
if
fluid
_op
in
DEFAULT_IOA_CONSTRAINT
:
for
predicate
,
message
in
DEFAULT_IOA_CONSTRAINT
[
paddle
_op
]:
for
predicate
,
message
in
DEFAULT_IOA_CONSTRAINT
[
fluid
_op
]:
assert
predicate
(
inputs
,
outputs
,
attrs
),
message
assert
predicate
(
inputs
,
outputs
,
attrs
),
message
# bypass if key absent, drop if mapped key is '' or '_'
# bypass if key absent, drop if mapped key is '' or '_'
mapped_attrs
=
{
attr_mapping
.
get
(
key
,
key
):
value
for
key
,
value
in
attrs
.
items
()}
mapped_attrs
=
{
attr_mapping
.
get
(
key
,
key
):
value
for
key
,
value
in
attrs
.
items
()
}
if
''
in
mapped_attrs
:
if
''
in
mapped_attrs
:
mapped_attrs
.
pop
(
''
)
mapped_attrs
.
pop
(
''
)
if
'_'
in
mapped_attrs
:
if
'_'
in
mapped_attrs
:
mapped_attrs
.
pop
(
'_'
)
mapped_attrs
.
pop
(
'_'
)
paddle
_attrs
=
default_attrs
.
copy
()
fluid
_attrs
=
default_attrs
.
copy
()
paddle_attrs
.
update
(
mapped_attrs
)
# as new attrs
fluid_attrs
.
update
(
mapped_attrs
)
# as new attrs
val_inps
=
inputs
if
input_perm
is
None
else
map
(
lambda
i
:
inputs
[
i
],
input_perm
)
val_inps
=
inputs
if
input_perm
is
None
else
map
(
lambda
i
:
inputs
[
i
],
val_outs
=
outputs
if
output_perm
is
None
else
map
(
lambda
i
:
outputs
[
i
],
output_perm
)
input_perm
)
val_outs
=
outputs
if
output_perm
is
None
else
map
(
lambda
i
:
outputs
[
i
],
output_perm
)
var_inps
=
[
_make_var_name
(
val
)
for
val
in
val_inps
]
var_inps
=
[
_make_var_name
(
val
)
for
val
in
val_inps
]
var_outs
=
[
_make_var_name
(
val
)
for
val
in
val_outs
]
var_outs
=
[
_make_var_name
(
val
)
for
val
in
val_outs
]
arg_name
=
', name={}'
.
format
(
repr
(
name
))
if
fill_name_field
and
name
else
''
arg_name
=
', name={}'
.
format
(
arg_attrs
=
[
', {}={}'
.
format
(
key
,
value
)
for
key
,
value
in
paddle_attrs
.
items
()]
repr
(
name
))
if
fill_name_field
and
name
else
''
arg_attrs
=
[
prog
.
Code
(
'{} = layers.{}({}{}{})'
', {}={}'
.
format
(
key
,
value
)
for
key
,
value
in
fluid_attrs
.
items
()
.
format
(
', '
.
join
(
var_outs
),
]
paddle_op
,
prog
.
Code
(
'{} = layers.{}({}{}{})'
.
format
(
', '
.
join
(
var_outs
),
fluid_op
,
', '
.
join
(
var_inps
),
', '
.
join
(
var_inps
),
''
.
join
(
arg_attrs
),
''
.
join
(
arg_attrs
),
arg_name
,
arg_name
,
))
))
for
va
l_out
,
var_out
in
zip
(
val_outs
,
var_outs
)
:
for
va
r_out
in
var_outs
:
prog
.
VarDesc
(
var_out
)
prog
.
VarDesc
(
var_out
)
# dummy var_out
# dummy var_out
num_vars
=
len
(
var_outs
)
num_vars
=
len
(
var_outs
)
num_args
=
len
(
paddle
_output_args
)
num_args
=
len
(
fluid
_output_args
)
if
num_vars
<
num_args
:
if
num_vars
<
num_args
:
assert
fill_name_field
,
'name required to naming dummy output variable'
assert
fill_name_field
,
'name required to naming dummy output variable'
for
idx_out
in
range
(
num_vars
,
num_args
):
for
idx_out
in
range
(
num_vars
,
num_args
):
var_out
=
_make_var_name
(
name
+
'.'
+
paddle_output_args
[
idx_out
].
lower
())
var_out
=
_make_var_name
(
name
+
'.'
+
fluid_output_args
[
idx_out
].
lower
())
var_outs
.
append
(
var_out
)
var_outs
.
append
(
var_out
)
prog
.
VarDesc
(
var_out
)
prog
.
VarDesc
(
var_out
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
(
var_inps
,
*
fluid_input_args
),
(
var_inps
,
*
paddle_input_args
),
(
var_outs
,
*
fluid_output_args
),
fluid_attrs
)
(
var_outs
,
*
paddle_output_args
),
paddle_attrs
)
def
_assign
(
prog
,
attrs
):
def
_assign
(
prog
,
attrs
):
mapping
=
attrs
[
'mapping'
]
# additional
mapping
=
attrs
[
'mapping'
]
# additional
paddle
_op
=
'assign'
fluid
_op
=
'assign'
for
val_dst
,
val_src
in
mapping
.
items
():
for
val_dst
,
val_src
in
mapping
.
items
():
var_dst
=
_make_var_name
(
val_dst
)
var_dst
=
_make_var_name
(
val_dst
)
var_src
=
_make_var_name
(
val_src
)
var_src
=
_make_var_name
(
val_src
)
prog
.
Code
(
'{} = {}'
.
format
(
var_dst
,
var_src
))
prog
.
Code
(
'{} = {}'
.
format
(
var_dst
,
var_src
))
# prog.Code('{} = layers.{}({})'
# prog.Code('{} = layers.{}({})'
# .format(var_dst,
# .format(var_dst,
# paddle
_op,
# fluid
_op,
# var_src,
# var_src,
# ))
# ))
prog
.
VarDesc
(
var_dst
)
prog
.
VarDesc
(
var_dst
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_src
],
'X'
),
([
var_src
],
'X'
),
([
var_dst
],
'Out'
),
([
var_dst
],
'Out'
),
dict
(),
dict
(),
...
@@ -283,10 +299,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
...
@@ -283,10 +299,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
return
pads
[:
ndims
],
None
return
pads
[:
ndims
],
None
val_padded
=
val_name
+
'_padded'
val_padded
=
val_name
+
'_padded'
prog
.
Op
(
''
,
'Pad'
,
prog
.
Op
(
''
,
'Pad'
,
[
val_name
],
[
val_name
],
[
val_padded
],
# val
[
val_padded
],
# val
dict
(
mode
=
'constant'
,
dict
(
mode
=
'constant'
,
value
=
0.
,
value
=
0.
,
pads
=
pads
,
pads
=
pads
,
),
),
...
@@ -295,7 +314,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
...
@@ -295,7 +314,13 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
)
)
return
[
0
]
*
ndims
,
val_padded
return
[
0
]
*
ndims
,
val_padded
def
_adaptive_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
def
_adaptive_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
name
=
''
):
# I/O
# I/O
val_x
,
=
inputs
val_x
,
=
inputs
...
@@ -312,11 +337,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -312,11 +337,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
pool_size
=
attrs
[
'output_size'
]
# required
pool_size
=
attrs
[
'output_size'
]
# required
output_shape
=
_shape_or_none
(
value_infos
,
val_y
)
output_shape
=
_shape_or_none
(
value_infos
,
val_y
)
if
output_shape
is
not
None
:
if
output_shape
is
not
None
:
assert
pool_size
==
output_shape
[
2
:],
'pool_size unmatches shape of Y'
# NC...
assert
pool_size
==
output_shape
[
2
:],
'pool_size unmatches shape of Y'
# NC...
poolnd
=
len
(
pool_size
)
poolnd
=
len
(
pool_size
)
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d supported'
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d supported'
paddle
_op
=
'adaptive_pool{}d'
.
format
(
poolnd
)
fluid
_op
=
'adaptive_pool{}d'
.
format
(
poolnd
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
# generation
...
@@ -324,9 +350,10 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -324,9 +350,10 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
', require_index={}'
', require_index={}'
', pool_size={}'
', pool_size={}'
', pool_type={}'
', pool_type={}'
'{})'
'{})'
.
format
(
.
format
(
var_y
,
', {}'
.
format
(
var_indices
)
if
has_indices
else
''
,
var_y
,
paddle_op
,
', {}'
.
format
(
var_indices
)
if
has_indices
else
''
,
fluid_op
,
var_x
,
var_x
,
# attrs
# attrs
has_indices
,
has_indices
,
...
@@ -334,14 +361,16 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -334,14 +361,16 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
repr
(
pool_type
),
repr
(
pool_type
),
name_attr
,
name_attr
,
))
))
paddle
_op
=
'pool{}d'
.
format
(
poolnd
)
fluid
_op
=
'pool{}d'
.
format
(
poolnd
)
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
if
has_indices
:
if
has_indices
:
prog
.
VarDesc
(
var_indices
)
prog
.
VarDesc
(
var_indices
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
dict
(
global_pooling
=
False
,
dict
(
global_pooling
=
False
,
adaptive
=
True
,
adaptive
=
True
,
exclusive
=
True
,
exclusive
=
True
,
require_index
=
has_indices
,
require_index
=
has_indices
,
...
@@ -351,8 +380,7 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -351,8 +380,7 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
)
)
def
_global_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
def
_global_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
name
=
''
):
# I/O
# I/O
val_x
,
=
inputs
val_x
,
=
inputs
val_y
,
=
outputs
val_y
,
=
outputs
...
@@ -369,33 +397,34 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -369,33 +397,34 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos,
poolnd
=
len
(
output_shape
)
-
2
# NC...
poolnd
=
len
(
output_shape
)
-
2
# NC...
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d supported'
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d supported'
paddle
_op
=
'pool{}d'
.
format
(
poolnd
)
fluid
_op
=
'pool{}d'
.
format
(
poolnd
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}, global_pooling=True'
prog
.
Code
(
'{} = layers.{}({}, global_pooling=True'
', pool_type={}'
', pool_type={}'
'{})'
'{})'
.
format
(
.
format
(
var_y
,
var_y
,
paddle
_op
,
fluid
_op
,
var_x
,
var_x
,
# attrs
# attrs
repr
(
pool_type
),
repr
(
pool_type
),
name_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
],
'Out'
),
([
var_y
],
'Out'
),
dict
(
global_pooling
=
True
,
dict
(
global_pooling
=
True
,
adaptive
=
False
,
adaptive
=
False
,
pooling_type
=
pool_type
,
pooling_type
=
pool_type
,
),
),
)
)
def
_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
def
_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
name
=
''
):
# I/O
# I/O
val_x
,
=
inputs
val_x
,
=
inputs
val_y
,
=
outputs
[:
1
]
val_y
,
=
outputs
[:
1
]
...
@@ -407,12 +436,14 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -407,12 +436,14 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
var_indices
=
_make_var_name
(
val_indices
)
var_indices
=
_make_var_name
(
val_indices
)
# interpretation
# interpretation
assert
attrs
.
get
(
'auto_pad'
,
'NOTSET'
)
==
'NOTSET'
,
'only auto_pad = NOTSET supported'
# optional
assert
attrs
.
get
(
'auto_pad'
,
'NOTSET'
)
==
'NOTSET'
,
'only auto_pad = NOTSET supported'
# optional
pool_size
=
attrs
[
'kernel_shape'
]
# required
pool_size
=
attrs
[
'kernel_shape'
]
# required
poolnd
=
len
(
pool_size
)
poolnd
=
len
(
pool_size
)
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d supported'
assert
2
<=
poolnd
<=
3
,
'only pool2d and pool3d supported'
paddle
_op
=
'pool{}d'
.
format
(
poolnd
)
fluid
_op
=
'pool{}d'
.
format
(
poolnd
)
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
poolnd
)
# optional
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
poolnd
)
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
len
(
pool_size
*
2
))
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
len
(
pool_size
*
2
))
# optional
paddings
,
val_x_padded
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
paddings
,
val_x_padded
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
...
@@ -429,9 +460,10 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -429,9 +460,10 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
', pool_stride={}'
', pool_stride={}'
', pool_padding={}'
', pool_padding={}'
', ceil_mode={}'
', ceil_mode={}'
'{})'
'{})'
.
format
(
.
format
(
var_y
,
', {}'
.
format
(
var_indices
)
if
has_indices
else
''
,
var_y
,
paddle_op
,
', {}'
.
format
(
var_indices
)
if
has_indices
else
''
,
fluid_op
,
var_x
,
var_x
,
# attrs
# attrs
pool_size
,
pool_size
,
...
@@ -444,23 +476,25 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
...
@@ -444,23 +476,25 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos,
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
if
has_indices
:
if
has_indices
:
prog
.
VarDesc
(
var_indices
)
prog
.
VarDesc
(
var_indices
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
([
var_y
]
+
([
var_indices
]
if
has_indices
else
[]),
'Out'
,
'Indices'
),
dict
(
global_pooling
=
False
,
dict
(
global_pooling
=
False
,
adaptive
=
False
,
adaptive
=
False
,
exclusive
=
True
,
exclusive
=
True
,
require_index
=
has_indices
,
require_index
=
has_indices
,
pooling_type
=
pool_type
,
pooling_type
=
pool_type
,
ksize
=
pool_size
,
ksize
=
pool_size
,
strides
=
strides
,
strides
=
strides
,
pool_padding
=
paddings
,
paddings
=
paddings
,
ceil_mode
=
ceil_mode
,
ceil_mode
=
ceil_mode
,
),
),
)
)
def
_roi_pool
(
prog
,
paddle
_op
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
):
def
_roi_pool
(
prog
,
fluid
_op
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
):
# I/O
# I/O
val_x
,
val_rois
=
inputs
val_x
,
val_rois
=
inputs
val_y
,
=
outputs
val_y
,
=
outputs
...
@@ -469,7 +503,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
...
@@ -469,7 +503,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
var_y
=
_make_var_name
(
val_y
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
pooled_height
,
pooled_width
=
attrs
[
'pooled_shape'
]
# required
pooled_height
,
pooled_width
=
attrs
[
'pooled_shape'
]
# required
od_attrs
=
dict
(
od_attrs
=
dict
(
spatial_scale
=
spatial_scale
,
spatial_scale
=
spatial_scale
,
...
@@ -477,7 +511,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
...
@@ -477,7 +511,7 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
pooled_width
=
pooled_width
,
pooled_width
=
pooled_width
,
)
)
feature_attr
=
''
feature_attr
=
''
is_max_pool
=
paddle
_op
==
'roi_pool'
is_max_pool
=
fluid
_op
==
'roi_pool'
if
'sampling_ratio'
in
attrs
:
if
'sampling_ratio'
in
attrs
:
sampling_ratio
=
attrs
[
'sampling_ratio'
]
sampling_ratio
=
attrs
[
'sampling_ratio'
]
od_attrs
[
'sampling_ratio'
]
=
sampling_ratio
od_attrs
[
'sampling_ratio'
]
=
sampling_ratio
...
@@ -492,10 +526,11 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
...
@@ -492,10 +526,11 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
', spatial_scale={}'
', spatial_scale={}'
', pooled_height={}'
', pooled_height={}'
', pooled_width={}'
', pooled_width={}'
'{})'
'{})'
.
format
(
.
format
(
var_y
,
var_y
,
paddle_op
,
fluid_op
,
val_x
,
var_rois
,
val_x
,
var_rois
,
# attrs
# attrs
spatial_scale
,
spatial_scale
,
pooled_height
,
pooled_height
,
...
@@ -506,7 +541,8 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
...
@@ -506,7 +541,8 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
if
is_max_pool
:
if
is_max_pool
:
var_argmax
=
_make_var_name
(
name
+
'.argmax'
)
# implicit variable
var_argmax
=
_make_var_name
(
name
+
'.argmax'
)
# implicit variable
prog
.
VarDesc
(
var_argmax
)
prog
.
VarDesc
(
var_argmax
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
,
var_rois
],
'X'
,
'Rois'
),
([
var_x
,
var_rois
],
'X'
,
'Rois'
),
([
var_y
]
+
([
var_argmax
]
if
is_max_pool
else
[]),
'Out'
,
'Argmax'
),
([
var_y
]
+
([
var_argmax
]
if
is_max_pool
else
[]),
'Out'
,
'Argmax'
),
od_attrs
,
od_attrs
,
...
@@ -514,7 +550,9 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
...
@@ -514,7 +550,9 @@ def _roi_pool(prog, paddle_op, inputs, outputs, attrs, value_infos, name):
def
_zeros_like
(
prog
,
val_ref
,
val_out
,
value_infos
):
def
_zeros_like
(
prog
,
val_ref
,
val_out
,
value_infos
):
prog
.
Op
(
''
,
'Sub'
,
prog
.
Op
(
''
,
'Sub'
,
[
val_ref
,
val_ref
],
[
val_ref
,
val_ref
],
[
val_out
],
# val
[
val_out
],
# val
dict
(
axis
=
0
),
dict
(
axis
=
0
),
...
@@ -522,47 +560,54 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
...
@@ -522,47 +560,54 @@ def _zeros_like(prog, val_ref, val_out, value_infos):
)
)
def
AdaptiveAveragePool
(
def
AdaptiveAveragePool
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
name
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
aten::adaptive_avg_poolnd
aten::adaptive_avg_poolnd
"""
"""
return
_adaptive_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
return
_adaptive_pool
(
name
=
name
)
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
def
AdaptiveMaxPool
(
def
AdaptiveMaxPool
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
name
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
aten::adaptive_max_poolnd
aten::adaptive_max_poolnd
"""
"""
return
_adaptive_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
return
_adaptive_pool
(
name
=
name
)
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
def
AveragePool
(
def
AveragePool
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
name
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
onnx::AveragePool-10:
onnx::AveragePool-10:
"""
"""
return
_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
return
_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
name
=
name
)
def
AffineGrid
(
def
AffineGrid
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
aten::affine_grid
aten::affine_grid
"""
"""
...
@@ -574,33 +619,39 @@ def AffineGrid(
...
@@ -574,33 +619,39 @@ def AffineGrid(
var_grid
=
_make_var_name
(
val_grid
)
var_grid
=
_make_var_name
(
val_grid
)
# interpretation
# interpretation
paddle
_op
=
'affine_grid'
fluid
_op
=
'affine_grid'
size
=
attrs
[
'size'
]
# required
size
=
attrs
[
'size'
]
# required
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', out_shape={}'
', out_shape={}'
'{})'
'{})'
.
format
(
.
format
(
var_grid
,
var_grid
,
paddle
_op
,
fluid
_op
,
var_theta
,
var_theta
,
# attrs
# attrs
size
,
size
,
name_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_grid
)
prog
.
VarDesc
(
var_grid
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_theta
],
'Theta'
),
([
var_theta
],
'Theta'
),
([
var_grid
],
'Output'
),
([
var_grid
],
'Output'
),
dict
(
output_shape
=
size
),
# f**k you API
dict
(
output_shape
=
size
),
# f**k you API
)
)
def
BatchNormalization
(
def
BatchNormalization
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
name
=
''
,
embed_params
=
False
,
outputs
,
*
args
,
**
kwargs
):
attrs
,
value_infos
,
name
=
''
,
embed_params
=
False
,
*
args
,
**
kwargs
):
"""
"""
onnx::BatchNormalization-9:
onnx::BatchNormalization-9:
"""
"""
...
@@ -612,7 +663,7 @@ def BatchNormalization(
...
@@ -612,7 +663,7 @@ def BatchNormalization(
var_y
=
_make_var_name
(
val_y
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
paddle
_op
=
'batch_norm'
fluid
_op
=
'batch_norm'
momentum
=
attrs
.
get
(
'momentum'
,
.
9
)
# optional
momentum
=
attrs
.
get
(
'momentum'
,
.
9
)
# optional
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
...
@@ -633,8 +684,9 @@ def BatchNormalization(
...
@@ -633,8 +684,9 @@ def BatchNormalization(
var_mean
=
_make_var_name
(
val_mean
)
var_mean
=
_make_var_name
(
val_mean
)
var_var
=
_make_var_name
(
val_var
)
var_var
=
_make_var_name
(
val_var
)
param_attr
=
(
', param_attr={}, bias_attr={}'
param_attr
=
(
', param_attr={}, bias_attr={}'
', moving_mean_name={}, moving_variance_name={}'
', moving_mean_name={}, moving_variance_name={}'
).
format
(
).
format
(
repr
(
var_scale
),
repr
(
var_b
),
repr
(
var_mean
),
repr
(
var_var
))
repr
(
var_scale
),
repr
(
var_b
),
repr
(
var_mean
),
repr
(
var_var
))
var_saved_mean
=
'{}.saved_mean'
.
format
(
name
)
# dropped var
var_saved_mean
=
'{}.saved_mean'
.
format
(
name
)
# dropped var
var_saved_variance
=
'{}.saved_variance'
.
format
(
name
)
# dropped var
var_saved_variance
=
'{}.saved_variance'
.
format
(
name
)
# dropped var
...
@@ -642,24 +694,27 @@ def BatchNormalization(
...
@@ -642,24 +694,27 @@ def BatchNormalization(
prog
.
Code
(
'{} = layers.{}({}, is_test=True, data_layout="NCHW"'
prog
.
Code
(
'{} = layers.{}({}, is_test=True, data_layout="NCHW"'
', momentum={}'
', momentum={}'
', epsilon={}'
', epsilon={}'
'{}{})'
'{}{})'
.
format
(
.
format
(
var_y
,
var_y
,
paddle
_op
,
fluid
_op
,
var_x
,
var_x
,
# attrs
# attrs
momentum
,
momentum
,
epsilon
,
epsilon
,
param_attr
,
name_attr
,
param_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_saved_mean
)
prog
.
VarDesc
(
var_saved_mean
)
prog
.
VarDesc
(
var_saved_variance
)
prog
.
VarDesc
(
var_saved_variance
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
([
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
],
fluid_op
,
'X'
,
'Scale'
,
'Bias'
,
'Mean'
,
'Variance'
),
([
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
],
'X'
,
'Scale'
,
'Bias'
,
([
var_y
,
var_mean
,
var_saved_mean
,
var_saved_variance
,
var_var
],
'Mean'
,
'Variance'
),
'Y'
,
'MeanOut'
,
'SavedMean'
,
'SavedVariance'
,
'VarianceOut'
),
([
var_y
,
var_mean
,
var_saved_mean
,
var_saved_variance
,
var_var
],
'Y'
,
dict
(
is_test
=
1
,
'MeanOut'
,
'SavedMean'
,
'SavedVariance'
,
'VarianceOut'
),
dict
(
is_test
=
1
,
data_layout
=
'NCHW'
,
data_layout
=
'NCHW'
,
use_global_stats
=
False
,
use_global_stats
=
False
,
momentum
=
momentum
,
momentum
=
momentum
,
...
@@ -667,9 +722,7 @@ def BatchNormalization(
...
@@ -667,9 +722,7 @@ def BatchNormalization(
)
)
def
Cast
(
def
Cast
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::Cast-9:
onnx::Cast-9:
"""
"""
...
@@ -688,33 +741,31 @@ def Cast(
...
@@ -688,33 +741,31 @@ def Cast(
if
output_dtype
:
if
output_dtype
:
assert
dtype
==
output_dtype
,
'dtype of to unmatches output'
assert
dtype
==
output_dtype
,
'dtype of to unmatches output'
paddle
_op
=
'cast'
fluid
_op
=
'cast'
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', dtype={}'
', dtype={}'
')'
')'
.
format
(
.
format
(
var_output
,
var_output
,
paddle
_op
,
fluid
_op
,
var_input
,
var_input
,
# attrs
# attrs
repr
(
dtype
.
name
),
repr
(
dtype
.
name
),
))
))
prog
.
VarDesc
(
var_output
)
prog
.
VarDesc
(
var_output
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_input
],
'X'
),
([
var_input
],
'X'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
in_dtype
=
prog
.
Dtype
(
_dtype
(
value_infos
,
val_input
)),
# holy, required
dict
(
in_dtype
=
prog
.
Dtype
(
_dtype
(
value_infos
,
val_input
)),
# holy, required
out_dtype
=
prog
.
Dtype
(
dtype
),
out_dtype
=
prog
.
Dtype
(
dtype
),
)
))
)
def
Concat
(
def
Concat
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
onnx::Concat-4:
onnx::Concat-4:
"""
"""
...
@@ -725,32 +776,31 @@ def Concat(
...
@@ -725,32 +776,31 @@ def Concat(
var_concat_result
=
_make_var_name
(
val_concat_result
)
var_concat_result
=
_make_var_name
(
val_concat_result
)
# interpretation
# interpretation
paddle
_op
=
'concat'
fluid
_op
=
'concat'
axis
=
attrs
[
'axis'
]
# required
axis
=
attrs
[
'axis'
]
# required
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', axis={}'
', axis={}'
'{})'
'{})'
.
format
(
.
format
(
var_concat_result
,
var_concat_result
,
paddle
_op
,
fluid
_op
,
'['
+
', '
.
join
(
var_inps
)
+
']'
,
'['
+
', '
.
join
(
var_inps
)
+
']'
,
# attrs
# attrs
axis
,
axis
,
name_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_concat_result
)
prog
.
VarDesc
(
var_concat_result
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
(
var_inps
,
*
([
'X'
]
*
len
(
var_inps
))),
(
var_inps
,
*
([
'X'
]
*
len
(
var_inps
))),
([
var_concat_result
],
'Out'
),
([
var_concat_result
],
'Out'
),
dict
(
axis
=
axis
),
dict
(
axis
=
axis
),
)
)
def
Constant
(
def
Constant
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::Constant-9:
onnx::Constant-9:
"""
"""
...
@@ -766,36 +816,40 @@ def Constant(
...
@@ -766,36 +816,40 @@ def Constant(
output_dtype
=
_dtype_or_none
(
value_infos
,
val_output
)
output_dtype
=
_dtype_or_none
(
value_infos
,
val_output
)
if
output_dtype
:
if
output_dtype
:
assert
dtype
==
output_dtype
,
'tensor dtype unmatches storage dtype'
assert
dtype
==
output_dtype
,
'tensor dtype unmatches storage dtype'
# dtype = np.dtype('float32') # force to float32
# dtype = np.dtype('float32') # force to float32
shape
=
attrs
.
get
(
'shape'
,
None
)
# additional, maybe var_name
shape
=
attrs
.
get
(
'shape'
,
None
)
# additional, maybe var_name
if
shape
is
None
:
if
shape
is
None
:
shape
=
_shape_or_none
(
value_infos
,
val_output
)
shape
=
_shape_or_none
(
value_infos
,
val_output
)
if
shape
is
None
:
if
shape
is
None
:
shape
=
list
(
value
.
shape
)
shape
=
list
(
value
.
shape
)
_logger
.
warning
(
'shape of %s not inferred, using value as 1-D tensor may lead to fails'
,
val_output
)
_logger
.
warning
(
'shape of %s not inferred, using value as 1-D tensor may lead to fails'
,
val_output
)
# generation
# generation
if
value
.
size
==
1
:
# scalar
if
value
.
size
==
1
:
# scalar
paddle
_op
=
'fill_constant'
fluid
_op
=
'fill_constant'
prog
.
Code
(
'{} = layers.{}(shape={}, dtype={}, value={})'
prog
.
Code
(
'{} = layers.{}(shape={}, dtype={}, value={})'
.
format
(
.
format
(
var_output
,
var_output
,
paddle
_op
,
fluid
_op
,
# attrs
# attrs
shape
,
repr
(
dtype
.
name
),
value
[
0
],
# shape can be list or var_name
shape
,
repr
(
dtype
.
name
),
value
[
0
],
# shape can be list or var_name
))
))
value_infos
[
val_output
][
'const_value'
]
=
value
[
0
]
value_infos
[
val_output
][
'const_value'
]
=
value
[
0
]
prog
.
VarDesc
(
var_output
)
prog
.
VarDesc
(
var_output
)
else
:
# list parameter -> const_value
else
:
# list parameter -> const_value
prog
.
Code
(
'{} = {}'
prog
.
Code
(
'{} = {}'
.
format
(
.
format
(
var_output
,
var_output
,
value
.
tolist
(),
value
.
tolist
(),
))
))
value_infos
[
val_output
][
'const_value'
]
=
value
.
tolist
()
value_infos
[
val_output
][
'const_value'
]
=
value
.
tolist
()
def
ConstantOfShape
(
def
ConstantOfShape
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::ConstantOfShape-9:
onnx::ConstantOfShape-9:
"""
"""
...
@@ -815,10 +869,15 @@ def ConstantOfShape(
...
@@ -815,10 +869,15 @@ def ConstantOfShape(
Constant
(
prog
,
[],
outputs
,
attrs
,
value_infos
)
Constant
(
prog
,
[],
outputs
,
attrs
,
value_infos
)
def
Conv
(
def
Conv
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
name
=
''
,
embed_params
=
False
,
outputs
,
*
args
,
**
kwargs
):
attrs
,
value_infos
,
name
=
''
,
embed_params
=
False
,
*
args
,
**
kwargs
):
"""
"""
onnx::ConstantOfShape-1:
onnx::ConstantOfShape-1:
"""
"""
...
@@ -833,14 +892,17 @@ def Conv(
...
@@ -833,14 +892,17 @@ def Conv(
val_b
,
=
inputs
[
2
:]
val_b
,
=
inputs
[
2
:]
# interpretation
# interpretation
assert
attrs
.
get
(
'auto_pad'
,
'NOTSET'
)
==
'NOTSET'
,
'only auto_pad == NOTSET supported'
# optional
assert
attrs
.
get
(
'auto_pad'
,
'NOTSET'
)
==
'NOTSET'
,
'only auto_pad == NOTSET supported'
# optional
kernel_shape
=
_shape
(
value_infos
,
val_w
)[
2
:]
# OI...
kernel_shape
=
_shape
(
value_infos
,
val_w
)[
2
:]
# OI...
assert
kernel_shape
==
attrs
[
'kernel_shape'
],
'kernel_shape in attr unmatches value_info'
# HW
assert
kernel_shape
==
attrs
[
'kernel_shape'
],
'kernel_shape in attr unmatches value_info'
# HW
convnd
=
len
(
kernel_shape
)
convnd
=
len
(
kernel_shape
)
assert
2
<=
convnd
<=
3
,
'only conv2d and conv3d supported'
assert
2
<=
convnd
<=
3
,
'only conv2d and conv3d supported'
num_out_channels
=
_shape
(
value_infos
,
val_w
)[
0
]
# OI...
num_out_channels
=
_shape
(
value_infos
,
val_w
)[
0
]
# OI...
paddle
_op
=
'conv{}d'
.
format
(
convnd
)
fluid
_op
=
'conv{}d'
.
format
(
convnd
)
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
convnd
)
# optional
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
convnd
)
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
convnd
*
2
)
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
convnd
*
2
)
# optional
paddings
,
val_x_padded
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
paddings
,
val_x_padded
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
...
@@ -864,7 +926,8 @@ def Conv(
...
@@ -864,7 +926,8 @@ def Conv(
var_w
=
_make_var_name
(
val_w
)
var_w
=
_make_var_name
(
val_w
)
var_b
=
_make_var_name
(
val_b
)
if
has_bias
else
False
var_b
=
_make_var_name
(
val_b
)
if
has_bias
else
False
param_attr
=
', param_attr={}, bias_attr={}'
.
format
(
param_attr
=
', param_attr={}, bias_attr={}'
.
format
(
repr
(
var_w
),
repr
(
var_b
)
if
var_b
else
False
)
repr
(
var_w
),
repr
(
var_b
)
if
var_b
else
False
)
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
...
@@ -874,9 +937,9 @@ def Conv(
...
@@ -874,9 +937,9 @@ def Conv(
', padding={}'
', padding={}'
', dilation={}'
', dilation={}'
', groups={}'
', groups={}'
'{}{})'
'{}{})'
.
format
(
.
format
(
var_y
,
var_y
,
paddle
_op
,
fluid
_op
,
var_x
,
var_x
,
# attrs
# attrs
num_out_channels
,
num_out_channels
,
...
@@ -885,13 +948,16 @@ def Conv(
...
@@ -885,13 +948,16 @@ def Conv(
paddings
,
paddings
,
dilations
,
dilations
,
num_groups
,
num_groups
,
param_attr
,
name_attr
,
param_attr
,
name_attr
,
))
))
var_conv
=
_make_var_name
(
name
+
'.conv'
)
# hidden variable
var_conv
=
_make_var_name
(
name
+
'.conv'
)
# hidden variable
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
dict
(
strides
=
strides
,
dict
(
strides
=
strides
,
paddings
=
paddings
,
paddings
=
paddings
,
dilations
=
dilations
,
dilations
=
dilations
,
groups
=
num_groups
,
groups
=
num_groups
,
...
@@ -899,7 +965,8 @@ def Conv(
...
@@ -899,7 +965,8 @@ def Conv(
if
has_bias
:
if
has_bias
:
prog
.
VarDesc
(
var_conv
)
prog
.
VarDesc
(
var_conv
)
prog
.
IntermediateOp
(
prog
.
IntermediateOp
(
''
,
'Add'
,
''
,
'Add'
,
[
var_conv
,
var_b
],
[
var_conv
,
var_b
],
[
var_y
],
# var
[
var_y
],
# var
dict
(
axis
=
1
),
dict
(
axis
=
1
),
...
@@ -910,10 +977,15 @@ def Conv(
...
@@ -910,10 +977,15 @@ def Conv(
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
def
ConvTranspose
(
def
ConvTranspose
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
name
=
''
,
embed_params
=
False
,
outputs
,
*
args
,
**
kwargs
):
attrs
,
value_infos
,
name
=
''
,
embed_params
=
False
,
*
args
,
**
kwargs
):
"""
"""
onnx::ConvTranspose-1:
onnx::ConvTranspose-1:
"""
"""
...
@@ -928,15 +1000,20 @@ def ConvTranspose(
...
@@ -928,15 +1000,20 @@ def ConvTranspose(
val_b
,
=
inputs
[
2
:]
val_b
,
=
inputs
[
2
:]
# interpretation
# interpretation
assert
attrs
.
get
(
'auto_pad'
,
'NOTSET'
)
==
'NOTSET'
,
'only auto_pad == NOTSET supported'
# optional
assert
attrs
.
get
(
assert
sum
(
attrs
.
get
(
'output_padding'
,
[]))
==
0
,
'only zero output_padding supported'
# optional ?
'auto_pad'
,
'NOTSET'
)
==
'NOTSET'
,
'only auto_pad == NOTSET supported'
# optional
assert
sum
(
attrs
.
get
(
'output_padding'
,
[]))
==
0
,
'only zero output_padding supported'
# optional ?
kernel_shape
=
_shape
(
value_infos
,
val_w
)[
2
:]
# IO...
kernel_shape
=
_shape
(
value_infos
,
val_w
)[
2
:]
# IO...
assert
kernel_shape
==
attrs
[
'kernel_shape'
],
'kernel_shape in attr unmatches value_info'
# HW
assert
kernel_shape
==
attrs
[
'kernel_shape'
],
'kernel_shape in attr unmatches value_info'
# HW
convnd
=
len
(
kernel_shape
)
convnd
=
len
(
kernel_shape
)
assert
2
<=
convnd
<=
3
,
'only conv2d_transpose and conv3d_transpose supported'
assert
2
<=
convnd
<=
3
,
'only conv2d_transpose and conv3d_transpose supported'
num_out_channels
=
_shape
(
value_infos
,
val_w
)[
1
]
# IO...
num_out_channels
=
_shape
(
value_infos
,
val_w
)[
1
]
# IO...
paddle
_op
=
'conv{}d_transpose'
.
format
(
convnd
)
fluid
_op
=
'conv{}d_transpose'
.
format
(
convnd
)
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
convnd
)
# optional
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
convnd
)
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
convnd
*
2
)
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
convnd
*
2
)
# optional
paddings
,
val_x_padded
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
paddings
,
val_x_padded
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
...
@@ -960,20 +1037,21 @@ def ConvTranspose(
...
@@ -960,20 +1037,21 @@ def ConvTranspose(
var_w
=
_make_var_name
(
val_w
)
var_w
=
_make_var_name
(
val_w
)
var_b
=
_make_var_name
(
val_b
)
if
has_bias
else
False
var_b
=
_make_var_name
(
val_b
)
if
has_bias
else
False
param_attr
=
', param_attr={}, bias_attr={}'
.
format
(
param_attr
=
', param_attr={}, bias_attr={}'
.
format
(
repr
(
var_w
),
repr
(
var_b
)
if
var_b
else
False
)
repr
(
var_w
),
repr
(
var_b
)
if
var_b
else
False
)
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', num_filters={}'
', num_filters={}'
# ', output_size={}'
# ', output_size={}'
', filter_size={}'
', filter_size={}'
', padding={}'
', padding={}'
', stride={}'
', stride={}'
', dilation={}'
', dilation={}'
', groups={}'
', groups={}'
'{}{})'
'{}{})'
.
format
(
.
format
(
var_y
,
var_y
,
paddle
_op
,
fluid
_op
,
var_x
,
var_x
,
# attrs
# attrs
num_out_channels
,
num_out_channels
,
...
@@ -982,13 +1060,16 @@ def ConvTranspose(
...
@@ -982,13 +1060,16 @@ def ConvTranspose(
strides
,
strides
,
dilations
,
dilations
,
num_groups
,
num_groups
,
param_attr
,
name_attr
,
param_attr
,
name_attr
,
))
))
var_conv
=
_make_var_name
(
name
+
'.conv'
)
# hidden variable
var_conv
=
_make_var_name
(
name
+
'.conv'
)
# hidden variable
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_x
,
var_w
],
'Input'
,
'Filter'
),
# , 'Bias', 'ResidualData'
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
([
var_conv
if
has_bias
else
var_y
],
'Output'
),
dict
(
strides
=
strides
,
dict
(
strides
=
strides
,
paddings
=
paddings
,
paddings
=
paddings
,
dilations
=
dilations
,
dilations
=
dilations
,
# output_size=output_size,
# output_size=output_size,
...
@@ -997,7 +1078,8 @@ def ConvTranspose(
...
@@ -997,7 +1078,8 @@ def ConvTranspose(
if
has_bias
:
if
has_bias
:
prog
.
VarDesc
(
var_conv
)
prog
.
VarDesc
(
var_conv
)
prog
.
IntermediateOp
(
prog
.
IntermediateOp
(
''
,
'Add'
,
''
,
'Add'
,
[
var_conv
,
var_b
],
[
var_conv
,
var_b
],
[
var_y
],
# var
[
var_y
],
# var
dict
(
axis
=
1
),
dict
(
axis
=
1
),
...
@@ -1025,14 +1107,12 @@ def ConvTranspose(
...
@@ -1025,14 +1107,12 @@ def ConvTranspose(
# )
# )
def
Gemm
(
def
Gemm
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
onnx::Gemm-9:
onnx::Gemm-9:
"""
"""
# due to
paddle
fc don't support transposed weight, we use matmul + ew_add
# due to
fluid
fc don't support transposed weight, we use matmul + ew_add
val_a
,
val_b
,
val_c
=
inputs
val_a
,
val_b
,
val_c
=
inputs
val_y
,
=
outputs
val_y
,
=
outputs
...
@@ -1042,23 +1122,29 @@ def Gemm(
...
@@ -1042,23 +1122,29 @@ def Gemm(
trans_b
=
bool
(
attrs
.
get
(
'transB'
,
0
))
# optional
trans_b
=
bool
(
attrs
.
get
(
'transB'
,
0
))
# optional
val_mm
=
name
+
'_mm'
# explicit variable
val_mm
=
name
+
'_mm'
# explicit variable
prog
.
Op
(
''
,
'MatMul'
,
prog
.
Op
(
''
,
'MatMul'
,
[
val_a
,
val_b
],
[
val_a
,
val_b
],
[
val_mm
],
# val
[
val_mm
],
# val
dict
(
transpose_x
=
trans_a
,
dict
(
transpose_x
=
trans_a
,
transpose_y
=
trans_b
,
transpose_y
=
trans_b
,
alpha
=
alpha
,
alpha
=
alpha
,
),
),
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
val_mm
,
name
=
val_mm
,
)
)
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
OpDescAttrs
(
dict
(
prog
.
op_descs
[
-
1
].
attrs
.
extend
(
prog
.
OpDescAttrs
(
dict
(
transpose_X
=
trans_a
,
transpose_X
=
trans_a
,
transpose_Y
=
trans_b
,
transpose_Y
=
trans_b
,
)))
# f**k you API
)))
# f**k you API
if
beta
!=
0
:
if
beta
!=
0
:
if
beta
==
1.
:
# exactly
if
beta
==
1.
:
# exactly
prog
.
Op
(
''
,
'Add'
,
prog
.
Op
(
''
,
'Add'
,
[
val_mm
,
val_c
],
[
val_mm
,
val_c
],
[
val_y
],
# val
[
val_y
],
# val
dict
(
axis
=
1
),
dict
(
axis
=
1
),
...
@@ -1072,21 +1158,27 @@ def Gemm(
...
@@ -1072,21 +1158,27 @@ def Gemm(
if
vm_dtype
is
None
:
if
vm_dtype
is
None
:
vm_dtype
=
np
.
dtype
(
'float32'
)
vm_dtype
=
np
.
dtype
(
'float32'
)
beta
=
np
.
dtype
(
vm_dtype
).
type
(
beta
)
beta
=
np
.
dtype
(
vm_dtype
).
type
(
beta
)
prog
.
Op
(
''
,
'Constant'
,
prog
.
Op
(
''
,
'Constant'
,
[],
[],
[
val_beta
],
# val
[
val_beta
],
# val
dict
(
value
=
beta
),
dict
(
value
=
beta
),
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
val_beta
,
name
=
val_beta
,
)
)
prog
.
Op
(
''
,
'Mul'
,
prog
.
Op
(
''
,
'Mul'
,
[
val_c
,
val_beta
],
[
val_c
,
val_beta
],
[
val_vm
],
# val
[
val_vm
],
# val
dict
(),
dict
(),
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
(
name
+
'_scale'
),
name
=
(
name
+
'_scale'
),
)
)
prog
.
Op
(
''
,
'Add'
,
prog
.
Op
(
''
,
'Add'
,
[
val_mm
,
val_vm
],
[
val_mm
,
val_vm
],
[
val_y
],
# val
[
val_y
],
# val
dict
(
axis
=
1
),
dict
(
axis
=
1
),
...
@@ -1094,28 +1186,36 @@ def Gemm(
...
@@ -1094,28 +1186,36 @@ def Gemm(
)
)
def
GlobalAveragePool
(
def
GlobalAveragePool
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
name
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
onnx::GlobalAveragePool-1:
onnx::GlobalAveragePool-1:
"""
"""
return
_global_pool
(
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
return
_global_pool
(
name
=
name
)
prog
,
'avg'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
def
GlobalMaxPool
(
def
GlobalMaxPool
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
name
=
''
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
onnx::GlobalMaxPool-1:
onnx::GlobalMaxPool-1:
"""
"""
return
_global_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
return
_global_pool
(
name
=
name
)
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
#def LRN(
#def LRN(
...
@@ -1132,7 +1232,7 @@ def GlobalMaxPool(
...
@@ -1132,7 +1232,7 @@ def GlobalMaxPool(
# var_y = _make_var_name(val_y)
# var_y = _make_var_name(val_y)
#
#
# # interpretation
# # interpretation
#
paddle
_op = 'lrn'
#
fluid
_op = 'lrn'
# size = attrs['size'] # required
# size = attrs['size'] # required
# alpha = attrs.get('alpha', 0.0001) # optional
# alpha = attrs.get('alpha', 0.0001) # optional
# beta = attrs.get('beta', 0.75) # optional
# beta = attrs.get('beta', 0.75) # optional
...
@@ -1147,7 +1247,7 @@ def GlobalMaxPool(
...
@@ -1147,7 +1247,7 @@ def GlobalMaxPool(
# ', beta={}'
# ', beta={}'
# '{})'
# '{})'
# .format(var_y,
# .format(var_y,
#
paddle
_op,
#
fluid
_op,
# var_x,
# var_x,
# # attrs
# # attrs
# size,
# size,
...
@@ -1159,7 +1259,7 @@ def GlobalMaxPool(
...
@@ -1159,7 +1259,7 @@ def GlobalMaxPool(
# var_mid = name + '.mid' # hidden variable
# var_mid = name + '.mid' # hidden variable
# prog.VarDesc(var_y)
# prog.VarDesc(var_y)
# prog.VarDesc(var_mid)
# prog.VarDesc(var_mid)
# prog.OpDesc(
paddle
_op,
# prog.OpDesc(
fluid
_op,
# ([var_x], 'X'),
# ([var_x], 'X'),
# ([var_y, var_mid], 'Out', 'MidOut'),
# ([var_y, var_mid], 'Out', 'MidOut'),
# dict(n=size,
# dict(n=size,
...
@@ -1170,21 +1270,17 @@ def GlobalMaxPool(
...
@@ -1170,21 +1270,17 @@ def GlobalMaxPool(
# )
# )
def
MaxPool
(
def
MaxPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
**
kwargs
):
name
=
''
,
*
args
,
**
kwargs
):
"""
"""
onnx::MaxPool-10:
onnx::MaxPool-10:
"""
"""
return
_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
return
_pool
(
prog
,
'max'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
name
)
name
=
name
)
def
MaxRoiPool
(
def
MaxRoiPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
onnx::MaxRoiPool-1:
onnx::MaxRoiPool-1:
"""
"""
...
@@ -1192,9 +1288,7 @@ def MaxRoiPool(
...
@@ -1192,9 +1288,7 @@ def MaxRoiPool(
_roi_pool
(
prog
,
'roi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
_roi_pool
(
prog
,
'roi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
def
RoiAlign
(
def
RoiAlign
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
caffe2::RoiAlign
caffe2::RoiAlign
"""
"""
...
@@ -1202,10 +1296,7 @@ def RoiAlign(
...
@@ -1202,10 +1296,7 @@ def RoiAlign(
_roi_pool
(
prog
,
'roi_align'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
_roi_pool
(
prog
,
'roi_align'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
def
Pad
(
def
Pad
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
"""
"""
onnx::Pad-2:
onnx::Pad-2:
"""
"""
...
@@ -1231,14 +1322,17 @@ def Pad(
...
@@ -1231,14 +1322,17 @@ def Pad(
assume_pad2d
|=
output_shape
and
len
(
output_shape
)
==
4
# NCHW
assume_pad2d
|=
output_shape
and
len
(
output_shape
)
==
4
# NCHW
od_attrs
=
dict
(
pad_value
=
value
)
od_attrs
=
dict
(
pad_value
=
value
)
if
assume_pad2d
:
if
assume_pad2d
:
paddle
_op
=
'pad2d'
fluid
_op
=
'pad2d'
pad2d_attr
=
', mode={}, data_format="NCHW"'
.
format
(
repr
(
mode
))
pad2d_attr
=
', mode={}, data_format="NCHW"'
.
format
(
repr
(
mode
))
od_attrs
[
'mode'
]
=
mode
od_attrs
[
'mode'
]
=
mode
od_attrs
[
'data_format'
]
=
"NCHW"
else
:
else
:
assert
mode
==
'constant'
,
'mode {} is supported only in pad2d'
.
format
(
mode
)
assert
mode
==
'constant'
,
'mode {} is supported only in pad2d'
.
format
(
paddle_op
=
'pad'
mode
)
fluid_op
=
'pad'
pad2d_attr
=
''
pad2d_attr
=
''
paddings
=
np
.
array
(
pads
).
reshape
((
-
1
,
2
)).
transpose
().
flatten
().
tolist
()
# SSEE -> SESE
paddings
=
np
.
array
(
pads
).
reshape
(
(
-
1
,
2
)).
transpose
().
flatten
().
tolist
()
# SSEE -> SESE
od_attrs
[
'paddings'
]
=
paddings
od_attrs
[
'paddings'
]
=
paddings
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
...
@@ -1246,27 +1340,34 @@ def Pad(
...
@@ -1246,27 +1340,34 @@ def Pad(
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', paddings={}'
', paddings={}'
', pad_value={}'
', pad_value={}'
'{}{})'
'{}{})'
.
format
(
.
format
(
var_output
,
var_output
,
paddle
_op
,
fluid
_op
,
var_data
,
var_data
,
# attrs
# attrs
paddings
,
paddings
,
value
,
value
,
pad2d_attr
,
name_attr
,
pad2d_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_output
)
prog
.
VarDesc
(
var_output
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_data
],
'X'
),
([
var_data
],
'X'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
od_attrs
,
od_attrs
,
)
)
def
PRelu
(
def
PRelu
(
prog
,
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
inputs
,
name
=
''
,
embed_params
=
False
,
outputs
,
*
args
,
**
kwargs
):
attrs
,
value_infos
,
name
=
''
,
embed_params
=
False
,
*
args
,
**
kwargs
):
"""
"""
onnx::PRelu-9:
onnx::PRelu-9:
"""
"""
...
@@ -1278,7 +1379,7 @@ def PRelu(
...
@@ -1278,7 +1379,7 @@ def PRelu(
var_y
=
_make_var_name
(
val_y
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
paddle
_op
=
'prelu'
fluid
_op
=
'prelu'
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
...
@@ -1291,24 +1392,24 @@ def PRelu(
...
@@ -1291,24 +1392,24 @@ def PRelu(
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}, mode="all"'
prog
.
Code
(
'{} = layers.{}({}, mode="all"'
'{}{})'
'{}{})'
.
format
(
.
format
(
var_y
,
var_y
,
paddle
_op
,
fluid
_op
,
var_x
,
var_x
,
# attrs
# attrs
param_attr
,
name_attr
,
param_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
],
'Out'
),
([
var_y
],
'Out'
),
dict
(
mode
=
'all'
),
dict
(
mode
=
'all'
),
)
)
def
PsRoiPool
(
def
PsRoiPool
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
caffe2::PsRoiPool
caffe2::PsRoiPool
"""
"""
...
@@ -1316,9 +1417,7 @@ def PsRoiPool(
...
@@ -1316,9 +1417,7 @@ def PsRoiPool(
_roi_pool
(
prog
,
'psroi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
_roi_pool
(
prog
,
'psroi_pool'
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
)
def
Reshape
(
def
Reshape
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
,
*
args
,
**
kwargs
):
"""
"""
onnx::Reshape-5:
onnx::Reshape-5:
"""
"""
...
@@ -1330,7 +1429,7 @@ def Reshape(
...
@@ -1330,7 +1429,7 @@ def Reshape(
var_reshaped
=
_make_var_name
(
val_reshaped
)
var_reshaped
=
_make_var_name
(
val_reshaped
)
# interpretation
# interpretation
paddle
_op
=
'reshape'
fluid
_op
=
'reshape'
is_const_shape
=
'const_value'
in
value_infos
[
val_shape
]
is_const_shape
=
'const_value'
in
value_infos
[
val_shape
]
var_shape
=
_make_var_name
(
val_shape
)
# for code
var_shape
=
_make_var_name
(
val_shape
)
# for code
if
is_const_shape
:
if
is_const_shape
:
...
@@ -1343,9 +1442,9 @@ def Reshape(
...
@@ -1343,9 +1442,9 @@ def Reshape(
if
is_const_shape
:
if
is_const_shape
:
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', shape={}'
', shape={}'
'{})'
'{})'
.
format
(
.
format
(
var_reshaped
,
var_reshaped
,
paddle
_op
,
fluid
_op
,
var_data
,
var_data
,
# attrs
# attrs
var_shape
,
var_shape
,
...
@@ -1353,7 +1452,9 @@ def Reshape(
...
@@ -1353,7 +1452,9 @@ def Reshape(
))
))
else
:
else
:
var_shape_int32
=
var_shape
+
'_int32'
var_shape_int32
=
var_shape
+
'_int32'
prog
.
Op
(
''
,
'Cast'
,
prog
.
Op
(
''
,
'Cast'
,
[
var_shape
],
[
var_shape
],
[
var_shape_int32
],
# var
[
var_shape_int32
],
# var
dict
(
to
=
np
.
dtype
(
'int32'
)),
dict
(
to
=
np
.
dtype
(
'int32'
)),
...
@@ -1363,36 +1464,36 @@ def Reshape(
...
@@ -1363,36 +1464,36 @@ def Reshape(
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', shape={}'
', shape={}'
', actual_shape={}'
', actual_shape={}'
'{})'
'{})'
.
format
(
.
format
(
var_reshaped
,
var_reshaped
,
paddle
_op
,
fluid
_op
,
var_data
,
var_data
,
# attrs
# attrs
shape
,
shape
,
var_shape_int32
,
var_shape_int32
,
name_attr
,
name_attr
,
))
))
paddle
_op
=
'reshape2'
fluid
_op
=
'reshape2'
var_xshape
=
_make_var_name
(
name
+
'.xshape'
)
var_xshape
=
_make_var_name
(
name
+
'.xshape'
)
prog
.
VarDesc
(
var_reshaped
)
prog
.
VarDesc
(
var_reshaped
)
prog
.
VarDesc
(
var_xshape
)
prog
.
VarDesc
(
var_xshape
)
if
is_const_shape
:
if
is_const_shape
:
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_data
],
'X'
),
([
var_data
],
'X'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
dict
(
shape
=
shape
),
dict
(
shape
=
shape
),
)
)
else
:
else
:
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_data
,
var_shape_int32
],
'X'
,
'Shape'
),
([
var_data
,
var_shape_int32
],
'X'
,
'Shape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
([
var_reshaped
,
var_xshape
],
'Out'
,
'XShape'
),
dict
(
shape
=
shape
),
dict
(
shape
=
shape
),
)
)
def
Slice
(
def
Slice
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
onnx::Slice-1:9
onnx::Slice-1:9
"""
"""
...
@@ -1404,17 +1505,17 @@ def Slice(
...
@@ -1404,17 +1505,17 @@ def Slice(
var_output
=
_make_var_name
(
val_output
)
var_output
=
_make_var_name
(
val_output
)
# interpretation
# interpretation
paddle
_op
=
'slice'
fluid
_op
=
'slice'
axes
=
attrs
[
'axes'
]
# required
axes
=
attrs
[
'axes'
]
# required
starts
=
attrs
[
'starts'
]
# required
starts
=
attrs
[
'starts'
]
# required
ends
=
attrs
[
'ends'
]
# required
ends
=
attrs
[
'ends'
]
# required
shape
=
_shape_or_none
(
value_infos
,
val_data
)
shape
=
_shape_or_none
(
value_infos
,
val_data
)
if
shape
:
if
shape
:
ndims
=
len
(
shape
)
#
ndims = len(shape)
for
idx
,
value
in
enumerate
(
axes
):
#
for idx, value in enumerate(axes):
if
value
>
ONNX_INT_MAX
//
2
:
#
if value > ONNX_INT_MAX // 2:
axes
[
idx
]
=
ndims
+
value
-
ONNX_INT_MAX
-
1
#
axes[idx] = ndims + value - ONNX_INT_MAX - 1
#
HINT
: Paddle 1.3 Doc: '对于未知大小维度的末尾进行切片,则建议传入 INT_MAX' not works ?
#
FIXME
: Paddle 1.3 Doc: '对于未知大小维度的末尾进行切片,则建议传入 INT_MAX' not works ?
for
idx
,
value
in
enumerate
(
starts
):
for
idx
,
value
in
enumerate
(
starts
):
if
value
>
ONNX_INT_MAX
//
2
:
if
value
>
ONNX_INT_MAX
//
2
:
value
=
value
-
ONNX_INT_MAX
-
1
value
=
value
-
ONNX_INT_MAX
-
1
...
@@ -1429,9 +1530,9 @@ def Slice(
...
@@ -1429,9 +1530,9 @@ def Slice(
', axes={}'
', axes={}'
', starts={}'
', starts={}'
', ends={}'
', ends={}'
')'
')'
.
format
(
.
format
(
var_output
,
var_output
,
paddle
_op
,
fluid
_op
,
var_data
,
var_data
,
# attrs
# attrs
axes
,
axes
,
...
@@ -1439,19 +1540,19 @@ def Slice(
...
@@ -1439,19 +1540,19 @@ def Slice(
ends
,
ends
,
))
))
prog
.
VarDesc
(
var_output
)
prog
.
VarDesc
(
var_output
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_data
],
'X'
),
([
var_data
],
'X'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
axes
=
axes
,
dict
(
axes
=
axes
,
starts
=
starts
,
starts
=
starts
,
ends
=
ends
,
ends
=
ends
,
),
),
)
)
def
Sum
(
def
Sum
(
prog
,
inputs
,
outputs
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
*
args
,
**
kwargs
):
"""
"""
onnx::Sum-8:
onnx::Sum-8:
"""
"""
...
@@ -1462,27 +1563,25 @@ def Sum(
...
@@ -1462,27 +1563,25 @@ def Sum(
var_sum
=
_make_var_name
(
val_sum
)
var_sum
=
_make_var_name
(
val_sum
)
# interpretation
# interpretation
paddle
_op
=
'sums'
fluid
_op
=
'sums'
# generation
# generation
prog
.
Code
(
'{} = layers.{}({})'
prog
.
Code
(
'{} = layers.{}({})'
.
format
(
.
format
(
var_sum
,
var_sum
,
paddle
_op
,
fluid
_op
,
'['
+
', '
.
join
(
var_inps
)
+
']'
,
'['
+
', '
.
join
(
var_inps
)
+
']'
,
# attrs
# attrs
))
))
prog
.
VarDesc
(
var_sum
)
prog
.
VarDesc
(
var_sum
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
(
var_inps
,
*
([
'X'
]
*
len
(
var_inps
))),
(
var_inps
,
*
([
'X'
]
*
len
(
var_inps
))),
([
var_sum
],
'Out'
),
([
var_sum
],
'Out'
),
dict
(),
dict
(),
)
)
def
Tile
(
def
Tile
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
,
*
args
,
**
kwargs
):
"""
"""
onnx::ConstantOfShape-6:
onnx::ConstantOfShape-6:
"""
"""
...
@@ -1494,7 +1593,7 @@ def Tile(
...
@@ -1494,7 +1593,7 @@ def Tile(
var_output
=
_make_var_name
(
val_output
)
var_output
=
_make_var_name
(
val_output
)
# interpretation
# interpretation
paddle
_op
=
'expand'
fluid
_op
=
'expand'
is_const_repeats
=
'const_value'
in
value_infos
[
val_repeats
]
is_const_repeats
=
'const_value'
in
value_infos
[
val_repeats
]
if
is_const_repeats
:
if
is_const_repeats
:
code_repeats
=
_make_var_name
(
val_repeats
)
# for code
code_repeats
=
_make_var_name
(
val_repeats
)
# for code
...
@@ -1507,16 +1606,17 @@ def Tile(
...
@@ -1507,16 +1606,17 @@ def Tile(
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}'
prog
.
Code
(
'{} = layers.{}({}'
', expand_times={}'
', expand_times={}'
'{})'
'{})'
.
format
(
.
format
(
var_output
,
var_output
,
paddle
_op
,
fluid
_op
,
var_input
,
var_input
,
# attrs
# attrs
code_repeats
,
code_repeats
,
name_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_output
)
prog
.
VarDesc
(
var_output
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
([
var_input
],
'X'
),
([
var_input
],
'X'
),
([
var_output
],
'Out'
),
([
var_output
],
'Out'
),
dict
(
expand_times
=
repeats
),
dict
(
expand_times
=
repeats
),
...
@@ -1537,29 +1637,25 @@ def Tile(
...
@@ -1537,29 +1637,25 @@ def Tile(
# var_shape = _make_var_name(val_shape)
# var_shape = _make_var_name(val_shape)
#
#
# # interpretation
# # interpretation
#
paddle
_op = 'shape'
#
fluid
_op = 'shape'
## value_infos[val_shape]['remove_batch'] = False
## value_infos[val_shape]['remove_batch'] = False
#
#
# # generation
# # generation
# prog.Code('{} = layers.{}({})'
# prog.Code('{} = layers.{}({})'
# .format(var_shape,
# .format(var_shape,
#
paddle
_op,
#
fluid
_op,
# var_data,
# var_data,
# # attrs
# # attrs
# ))
# ))
# prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape))
# prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape))
# prog.OpDesc(
paddle
_op,
# prog.OpDesc(
fluid
_op,
# ([var_data], 'X'),
# ([var_data], 'X'),
# ([var_shape], 'Out'),
# ([var_shape], 'Out'),
# dict(),
# dict(),
# )
# )
def
Split
(
def
Split
(
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
prog
,
inputs
,
outputs
,
attrs
,
*
args
,
name
=
''
,
**
kwargs
):
"""
"""
onnx::Split-2:
onnx::Split-2:
"""
"""
...
@@ -1570,7 +1666,7 @@ def Split(
...
@@ -1570,7 +1666,7 @@ def Split(
var_input
=
_make_var_name
(
val_input
)
var_input
=
_make_var_name
(
val_input
)
# interpretation
# interpretation
paddle
_op
=
'split'
fluid
_op
=
'split'
split
=
attrs
[
'split'
]
# required
split
=
attrs
[
'split'
]
# required
axis
=
attrs
.
get
(
'axis'
,
0
)
# optional
axis
=
attrs
.
get
(
'axis'
,
0
)
# optional
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
...
@@ -1578,21 +1674,23 @@ def Split(
...
@@ -1578,21 +1674,23 @@ def Split(
# generation
# generation
prog
.
Code
(
'{} = layers.{}({}, {}'
prog
.
Code
(
'{} = layers.{}({}, {}'
', dim={}'
', dim={}'
'{})'
'{})'
.
format
(
.
format
(
', '
.
join
(
var_outs
),
', '
.
join
(
var_outs
),
paddle
_op
,
fluid
_op
,
var_input
,
var_input
,
split
,
split
,
# attrs
# attrs
axis
,
axis
,
name_attr
,
name_attr
,
))
))
for
va
l_out
,
var_out
in
zip
(
outputs
,
var_outs
)
:
for
va
r_out
in
var_outs
:
prog
.
VarDesc
(
var_out
)
prog
.
VarDesc
(
var_out
)
prog
.
OpDesc
(
paddle_op
,
prog
.
OpDesc
(
fluid_op
,
(
var_input
,
'X'
),
(
var_input
,
'X'
),
([
var_outs
],
*
([
'Out'
]
*
len
(
var_outs
))),
([
var_outs
],
*
([
'Out'
]
*
len
(
var_outs
))),
dict
(
axis
=
axis
,
dict
(
axis
=
axis
,
sections
=
split
,
sections
=
split
,
),
),
)
)
...
@@ -1600,7 +1698,8 @@ def Split(
...
@@ -1600,7 +1698,8 @@ def Split(
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
_logging
.
basicConfig
(
_logging
.
basicConfig
(
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
level
=
_logging
.
DEBUG
,
level
=
_logging
.
DEBUG
,
)
)
logger
=
_logging
.
getLogger
(
'symbolic_test'
)
logger
=
_logging
.
getLogger
(
'symbolic_test'
)
...
@@ -1608,7 +1707,10 @@ if __name__ == '__main__':
...
@@ -1608,7 +1707,10 @@ if __name__ == '__main__':
from
writer
import
Program
from
writer
import
Program
prog
=
Program
()
prog
=
Program
()
AdaptiveAveragePool
(
prog
,
[
'X'
],
[
'Y'
],
AdaptiveAveragePool
(
prog
,
[
'X'
],
[
'Y'
],
dict
(
output_size
=
[
3
,
3
]),
dict
(
output_size
=
[
3
,
3
]),
dict
(
Y
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
)),
dict
(
Y
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
)),
name
=
'AdaptiveAveragePool2d'
,
name
=
'AdaptiveAveragePool2d'
,
...
@@ -1616,7 +1718,10 @@ if __name__ == '__main__':
...
@@ -1616,7 +1718,10 @@ if __name__ == '__main__':
logger
.
info
(
'AdaptiveAveragePool2d program:
\n
%s'
,
prog
)
logger
.
info
(
'AdaptiveAveragePool2d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
AdaptiveAveragePool
(
prog
,
[
'X'
],
[
'Y'
],
AdaptiveAveragePool
(
prog
,
[
'X'
],
[
'Y'
],
dict
(
output_size
=
[
3
,
3
,
3
]),
dict
(
output_size
=
[
3
,
3
,
3
]),
dict
(
Y
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
)),
dict
(
Y
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
)),
name
=
'AdaptiveAveragePool3d'
,
name
=
'AdaptiveAveragePool3d'
,
...
@@ -1624,18 +1729,26 @@ if __name__ == '__main__':
...
@@ -1624,18 +1729,26 @@ if __name__ == '__main__':
logger
.
info
(
'AdaptiveAveragePool3d program:
\n
%s'
,
prog
)
logger
.
info
(
'AdaptiveAveragePool3d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
AffineGrid
(
prog
,
[
'Theta'
],
[
'Grid'
],
AffineGrid
(
prog
,
[
'Theta'
],
[
'Grid'
],
dict
(
size
=
[
2
,
2
,
8
,
8
]),
dict
(
size
=
[
2
,
2
,
8
,
8
]),
dict
(
Grid
=
dict
(
shape
=
(
2
,
8
,
8
,
2
),
dtype
=
np
.
float32
)),
dict
(
Grid
=
dict
(
shape
=
(
2
,
8
,
8
,
2
),
dtype
=
np
.
float32
)),
)
)
logger
.
info
(
'AffineGrid program:
\n
%s'
,
prog
)
logger
.
info
(
'AffineGrid program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
BatchNormalization
(
prog
,
[
'X'
,
'scale'
,
'B'
,
'mean'
,
'var'
],
[
'Y'
],
BatchNormalization
(
dict
(
epsilon
=
1e-5
,
prog
,
[
'X'
,
'scale'
,
'B'
,
'mean'
,
'var'
],
[
'Y'
],
dict
(
epsilon
=
1e-5
,
momentum
=
.
9
,
momentum
=
.
9
,
),
),
dict
(
scale
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
dict
(
scale
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
mean
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
mean
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
var
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
var
=
dict
(
shape
=
(
3
,
),
dtype
=
np
.
float32
),
...
@@ -1647,30 +1760,43 @@ if __name__ == '__main__':
...
@@ -1647,30 +1760,43 @@ if __name__ == '__main__':
logger
.
info
(
'BatchNormalization program:
\n
%s'
,
prog
)
logger
.
info
(
'BatchNormalization program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Cast
(
prog
,
[
'input'
],
[
'output'
],
Cast
(
prog
,
[
'input'
],
[
'output'
],
dict
(
to
=
2
),
# TensorProto.UINT8
dict
(
to
=
2
),
# TensorProto.UINT8
dict
(
input
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
),
dict
(
input
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
),
output
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
uint8
)),
output
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
uint8
)),
)
)
logger
.
info
(
'Cast program:
\n
%s'
,
prog
)
logger
.
info
(
'Cast program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
_default
(
prog
,
'Clip'
,
[
'input'
],
[
'output'
],
_default
(
prog
,
'Clip'
,
[
'input'
],
[
'output'
],
dict
(
min
=-
1.
,
max
=
1.
),
dict
(
min
=-
1.
,
max
=
1.
),
dict
(
output
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
)),
dict
(
output
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
)),
)
)
logger
.
info
(
'Clip program:
\n
%s'
,
prog
)
logger
.
info
(
'Clip program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Conv
(
prog
,
[
'X'
,
'W'
],
[
'Y'
],
Conv
(
dict
(
auto_pad
=
'NOTSET'
,
prog
,
[
'X'
,
'W'
],
[
'Y'
],
dict
(
auto_pad
=
'NOTSET'
,
dilations
=
[
1
,
1
],
dilations
=
[
1
,
1
],
group
=
1
,
group
=
1
,
kernel_shape
=
[
3
,
3
],
kernel_shape
=
[
3
,
3
],
pads
=
[
1
,
1
,
1
,
1
],
pads
=
[
1
,
1
,
1
,
1
],
strides
=
[
1
,
1
],
strides
=
[
1
,
1
],
),
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
),
dtype
=
np
.
float32
),
),
),
name
=
'ConvNoBias2d'
,
name
=
'ConvNoBias2d'
,
...
@@ -1679,15 +1805,20 @@ if __name__ == '__main__':
...
@@ -1679,15 +1805,20 @@ if __name__ == '__main__':
logger
.
info
(
'ConvNoBias2d program:
\n
%s'
,
prog
)
logger
.
info
(
'ConvNoBias2d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Conv
(
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
Conv
(
dict
(
auto_pad
=
'NOTSET'
,
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
dict
(
auto_pad
=
'NOTSET'
,
dilations
=
[
1
,
1
],
dilations
=
[
1
,
1
],
group
=
1
,
group
=
1
,
kernel_shape
=
[
3
,
3
],
kernel_shape
=
[
3
,
3
],
pads
=
[
1
,
1
,
1
,
1
],
pads
=
[
1
,
1
,
1
,
1
],
strides
=
[
1
,
1
],
strides
=
[
1
,
1
],
),
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
),
dtype
=
np
.
float32
),
),
),
...
@@ -1697,17 +1828,22 @@ if __name__ == '__main__':
...
@@ -1697,17 +1828,22 @@ if __name__ == '__main__':
logger
.
info
(
'Conv2d program:
\n
%s'
,
prog
)
logger
.
info
(
'Conv2d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
ConvTranspose
(
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
ConvTranspose
(
dict
(
auto_pad
=
'NOTSET'
,
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
dict
(
auto_pad
=
'NOTSET'
,
dilations
=
[
1
,
1
],
dilations
=
[
1
,
1
],
group
=
1
,
group
=
1
,
kernel_shape
=
[
3
,
3
],
kernel_shape
=
[
3
,
3
],
# output_padding=[1, 1, 1, 1],
# output_padding=[1, 1, 1, 1],
# output_shape=[6, 8],
# output_shape=[6, 8],
pads
=
[
1
,
1
,
1
,
1
],
pads
=
[
1
,
1
,
1
,
1
],
strides
=
[
1
,
1
],
strides
=
[
1
,
1
],
),
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
6
,
8
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
6
,
8
),
dtype
=
np
.
float32
),
),
),
...
@@ -1717,15 +1853,20 @@ if __name__ == '__main__':
...
@@ -1717,15 +1853,20 @@ if __name__ == '__main__':
logger
.
info
(
'ConvTransposed2d program:
\n
%s'
,
prog
)
logger
.
info
(
'ConvTransposed2d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Conv
(
prog
,
[
'X'
,
'W'
],
[
'Y'
],
Conv
(
dict
(
auto_pad
=
'NOTSET'
,
prog
,
[
'X'
,
'W'
],
[
'Y'
],
dict
(
auto_pad
=
'NOTSET'
,
dilations
=
[
1
,
1
,
1
],
dilations
=
[
1
,
1
,
1
],
group
=
1
,
group
=
1
,
kernel_shape
=
[
3
,
3
,
3
],
kernel_shape
=
[
3
,
3
,
3
],
pads
=
[
1
,
1
,
1
,
1
,
1
,
1
],
pads
=
[
1
,
1
,
1
,
1
,
1
,
1
],
strides
=
[
1
,
1
,
1
],
strides
=
[
1
,
1
,
1
],
),
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
,
8
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
,
8
),
dtype
=
np
.
float32
),
),
),
name
=
'ConvNoBias3d'
,
name
=
'ConvNoBias3d'
,
...
@@ -1734,15 +1875,20 @@ if __name__ == '__main__':
...
@@ -1734,15 +1875,20 @@ if __name__ == '__main__':
logger
.
info
(
'ConvNoBias3d program:
\n
%s'
,
prog
)
logger
.
info
(
'ConvNoBias3d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Conv
(
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
Conv
(
dict
(
auto_pad
=
'NOTSET'
,
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
dict
(
auto_pad
=
'NOTSET'
,
dilations
=
[
1
,
1
,
1
],
dilations
=
[
1
,
1
,
1
],
group
=
1
,
group
=
1
,
kernel_shape
=
[
3
,
3
,
3
],
kernel_shape
=
[
3
,
3
,
3
],
pads
=
[
1
,
1
,
1
,
1
,
1
,
1
],
pads
=
[
1
,
1
,
1
,
1
,
1
,
1
],
strides
=
[
1
,
1
,
1
],
strides
=
[
1
,
1
,
1
],
),
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
,
8
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
4
,
6
,
8
),
dtype
=
np
.
float32
),
),
),
...
@@ -1752,17 +1898,22 @@ if __name__ == '__main__':
...
@@ -1752,17 +1898,22 @@ if __name__ == '__main__':
logger
.
info
(
'Conv3d program:
\n
%s'
,
prog
)
logger
.
info
(
'Conv3d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
ConvTranspose
(
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
ConvTranspose
(
dict
(
auto_pad
=
'NOTSET'
,
prog
,
[
'X'
,
'W'
,
'B'
],
[
'Y'
],
dict
(
auto_pad
=
'NOTSET'
,
dilations
=
[
1
,
1
,
1
],
dilations
=
[
1
,
1
,
1
],
group
=
1
,
group
=
1
,
kernel_shape
=
[
3
,
3
,
3
],
kernel_shape
=
[
3
,
3
,
3
],
# output_padding=[1, 1, 1, 1],
# output_padding=[1, 1, 1, 1],
# output_shape=[6, 8],
# output_shape=[6, 8],
pads
=
[
1
,
1
,
1
,
1
,
1
,
1
],
pads
=
[
1
,
1
,
1
,
1
,
1
,
1
],
strides
=
[
1
,
1
,
1
],
strides
=
[
1
,
1
,
1
],
),
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
W
=
dict
(
shape
=
(
2
,
3
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
B
=
dict
(
shape
=
(
2
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
6
,
8
,
9
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
2
,
6
,
8
,
9
),
dtype
=
np
.
float32
),
),
),
...
@@ -1772,20 +1923,29 @@ if __name__ == '__main__':
...
@@ -1772,20 +1923,29 @@ if __name__ == '__main__':
logger
.
info
(
'ConvTransposed3d program:
\n
%s'
,
prog
)
logger
.
info
(
'ConvTransposed3d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
_default
(
prog
,
'Equal'
,
[
'A'
,
'B'
],
[
'C'
],
_default
(
prog
,
'Equal'
,
[
'A'
,
'B'
],
[
'C'
],
dict
(),
dict
(),
dict
(
C
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
bool
)),
dict
(
C
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
bool
)),
)
)
logger
.
info
(
'Equal program:
\n
%s'
,
prog
)
logger
.
info
(
'Equal program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Gemm
(
prog
,
[
'A'
,
'B'
,
'C'
],
[
'Y'
],
Gemm
(
dict
(
alpha
=
1.
,
prog
,
[
'A'
,
'B'
,
'C'
],
[
'Y'
],
dict
(
alpha
=
1.
,
beta
=
1.
,
beta
=
1.
,
transA
=
0
,
transA
=
0
,
transB
=
1
,
transB
=
1
,
),
),
dict
(
B
=
dict
(
shape
=
(
8
,
3
),
dtype
=
np
.
float32
),
dict
(
B
=
dict
(
shape
=
(
8
,
3
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
),
Y
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
),
),
),
name
=
'Gemm'
,
name
=
'Gemm'
,
...
@@ -1793,34 +1953,48 @@ if __name__ == '__main__':
...
@@ -1793,34 +1953,48 @@ if __name__ == '__main__':
logger
.
info
(
'Gemm program:
\n
%s'
,
prog
)
logger
.
info
(
'Gemm program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
_default
(
prog
,
'Less'
,
[
'A'
,
'B'
],
[
'C'
],
_default
(
prog
,
'Less'
,
[
'A'
,
'B'
],
[
'C'
],
dict
(),
dict
(),
dict
(
C
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
bool
)),
dict
(
C
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
bool
)),
)
)
logger
.
info
(
'Less program:
\n
%s'
,
prog
)
logger
.
info
(
'Less program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
_default
(
prog
,
'MatMul'
,
[
'A'
,
'B'
],
[
'Y'
],
_default
(
prog
,
'MatMul'
,
[
'A'
,
'B'
],
[
'Y'
],
dict
(),
dict
(),
dict
(
Y
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
)),
dict
(
Y
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
)),
name
=
'MatMul'
name
=
'MatMul'
)
)
logger
.
info
(
'MatMul program:
\n
%s'
,
prog
)
logger
.
info
(
'MatMul program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
_default
(
prog
,
'OneHot'
,
[
'indices'
,
'depth'
,
'values'
],
[
'output'
],
_default
(
prog
,
'OneHot'
,
[
'indices'
,
'depth'
,
'values'
],
[
'output'
],
dict
(
axis
=-
1
),
dict
(
axis
=-
1
),
dict
(
output
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
)),
dict
(
output
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
)),
)
)
logger
.
info
(
'OneHot program:
\n
%s'
,
prog
)
logger
.
info
(
'OneHot program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Pad
(
prog
,
[
'data'
],
[
'output'
],
Pad
(
dict
(
mode
=
'constant'
,
prog
,
[
'data'
],
[
'output'
],
dict
(
mode
=
'constant'
,
pads
=
[
0
,
1
],
pads
=
[
0
,
1
],
value
=
0.
,
value
=
0.
,
),
),
dict
(
data
=
dict
(
shape
=
(
2
,
7
),
dtype
=
np
.
float32
),
dict
(
data
=
dict
(
shape
=
(
2
,
7
),
dtype
=
np
.
float32
),
output
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
),
output
=
dict
(
shape
=
(
2
,
8
),
dtype
=
np
.
float32
),
),
),
name
=
'Pad'
,
name
=
'Pad'
,
...
@@ -1828,12 +2002,17 @@ if __name__ == '__main__':
...
@@ -1828,12 +2002,17 @@ if __name__ == '__main__':
logger
.
info
(
'Pad program:
\n
%s'
,
prog
)
logger
.
info
(
'Pad program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Pad
(
prog
,
[
'data'
],
[
'output'
],
Pad
(
dict
(
mode
=
'reflect'
,
prog
,
[
'data'
],
[
'output'
],
dict
(
mode
=
'reflect'
,
pads
=
[
0
,
1
,
2
,
3
],
pads
=
[
0
,
1
,
2
,
3
],
value
=
0.
,
value
=
0.
,
),
),
dict
(
data
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
dict
(
data
=
dict
(
shape
=
(
2
,
3
,
3
,
3
),
dtype
=
np
.
float32
),
output
=
dict
(
shape
=
(
2
,
3
,
5
,
7
),
dtype
=
np
.
float32
),
output
=
dict
(
shape
=
(
2
,
3
,
5
,
7
),
dtype
=
np
.
float32
),
),
),
name
=
'Pad2d'
,
name
=
'Pad2d'
,
...
@@ -1841,7 +2020,10 @@ if __name__ == '__main__':
...
@@ -1841,7 +2020,10 @@ if __name__ == '__main__':
logger
.
info
(
'Pad2d program:
\n
%s'
,
prog
)
logger
.
info
(
'Pad2d program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
PRelu
(
prog
,
[
'X'
,
'slope'
],
[
'Y'
],
PRelu
(
prog
,
[
'X'
,
'slope'
],
[
'Y'
],
dict
(),
dict
(),
dict
(
Y
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
)),
dict
(
Y
=
dict
(
shape
=
(
2
,
3
),
dtype
=
np
.
float32
)),
name
=
'PRelu'
,
name
=
'PRelu'
,
...
@@ -1849,11 +2031,11 @@ if __name__ == '__main__':
...
@@ -1849,11 +2031,11 @@ if __name__ == '__main__':
logger
.
info
(
'PRelu program:
\n
%s'
,
prog
)
logger
.
info
(
'PRelu program:
\n
%s'
,
prog
)
prog
=
Program
()
prog
=
Program
()
Tile
(
prog
,
[
'input'
,
'repeats'
],
[
'output'
],
Tile
(
prog
,
[
'input'
,
'repeats'
],
[
'output'
],
dict
(),
dict
(),
dict
(
repeats
=
dict
(
const_value
=
[
1
,
2
]),
dict
(
output
=
dict
(
shape
=
(
2
,
2
,
4
),
dtype
=
np
.
float32
)
repeats
=
dict
(
const_value
=
[
1
,
2
]),
),
output
=
dict
(
shape
=
(
2
,
2
,
4
),
dtype
=
np
.
float32
)),
name
=
'Tile'
name
=
'Tile'
)
)
logger
.
info
(
'Tile program:
\n
%s'
,
prog
)
logger
.
info
(
'Tile program:
\n
%s'
,
prog
)
onnx2
paddle/onnx2paddle
/torch_export_helper.py
→
onnx2
fluid/onnx2fluid
/torch_export_helper.py
浏览文件 @
2228423e
...
@@ -24,8 +24,7 @@ def _ensure_tuple(obj):
...
@@ -24,8 +24,7 @@ def _ensure_tuple(obj):
return
(
obj
,
)
return
(
obj
,
)
def
_flatten_list
(
obj
,
def
_flatten_list
(
obj
,
out
=
None
):
out
=
None
):
assert
isinstance
(
obj
,
list
)
assert
isinstance
(
obj
,
list
)
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
...
@@ -37,8 +36,7 @@ def _flatten_list(obj,
...
@@ -37,8 +36,7 @@ def _flatten_list(obj,
return
out
return
out
def
export_data
(
state_dict
,
def
export_data
(
state_dict
,
prefix
=
''
):
prefix
=
''
):
"""
"""
export binary data with meta text for raw C++ inference engines
export binary data with meta text for raw C++ inference engines
"""
"""
...
@@ -65,10 +63,14 @@ def export_data(state_dict,
...
@@ -65,10 +63,14 @@ def export_data(state_dict,
fp
.
close
()
fp
.
close
()
def
export_onnx_with_validation
(
model
,
inputs
,
export_basepath
,
def
export_onnx_with_validation
(
model
,
input_names
=
None
,
output_names
=
None
,
inputs
,
export_basepath
,
input_names
=
None
,
output_names
=
None
,
use_npz
=
True
,
use_npz
=
True
,
*
args
,
**
kwargs
):
*
args
,
**
kwargs
):
"""
"""
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
export PyTorch model to ONNX model and export sample inputs and outputs in a Numpy file
"""
"""
...
@@ -96,10 +98,14 @@ def export_onnx_with_validation(model, inputs, export_basepath,
...
@@ -96,10 +98,14 @@ def export_onnx_with_validation(model, inputs, export_basepath,
return
ret
return
ret
torch_inputs
=
_ensure_tuple
(
inputs
)
# WORKAROUND: for torch.onnx
torch_inputs
=
_ensure_tuple
(
inputs
)
# WORKAROUND: for torch.onnx
outputs
=
torch
.
onnx
.
export
(
model
,
torch_inputs
,
export_basepath
+
'.onnx'
,
outputs
=
torch
.
onnx
.
export
(
model
,
torch_inputs
,
export_basepath
+
'.onnx'
,
input_names
=
_flatten_list
(
input_names
),
input_names
=
_flatten_list
(
input_names
),
output_names
=
_flatten_list
(
output_names
),
output_names
=
_flatten_list
(
output_names
),
*
args
,
**
kwargs
)
*
args
,
**
kwargs
)
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
if
outputs
is
None
:
# WORKAROUND: for torch.onnx
outputs
=
model
(
*
inputs
)
outputs
=
model
(
*
inputs
)
torch_outputs
=
_ensure_tuple
(
outputs
)
torch_outputs
=
_ensure_tuple
(
outputs
)
...
...
onnx2
paddle/onnx2paddle
/validation.py
→
onnx2
fluid/onnx2fluid
/validation.py
浏览文件 @
2228423e
...
@@ -13,8 +13,7 @@ import os
...
@@ -13,8 +13,7 @@ import os
import
sys
import
sys
def
_flatten_dict
(
obj
,
def
_flatten_dict
(
obj
,
out
=
None
):
out
=
None
):
assert
isinstance
(
obj
,
dict
)
assert
isinstance
(
obj
,
dict
)
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
...
@@ -34,12 +33,13 @@ def _ensure_list(obj):
...
@@ -34,12 +33,13 @@ def _ensure_list(obj):
return
[
obj
]
return
[
obj
]
def
validate
(
paddle_model_filename
,
golden_data_filename
,
def
validate
(
fluid_model_filename
,
golden_data_filename
,
model_func_name
=
'inference'
,
model_func_name
=
'inference'
,
precision
=
1e-4
,
precision
=
1e-4
,
save_inference_model
=
False
):
save_inference_model
=
False
):
"""
"""
inferece the converted Paddle model, validate with given golden data
inferece the converted Paddle
fluid
model, validate with given golden data
"""
"""
import
numpy
as
np
import
numpy
as
np
...
@@ -52,17 +52,17 @@ def validate(paddle_model_filename, golden_data_filename,
...
@@ -52,17 +52,17 @@ def validate(paddle_model_filename, golden_data_filename,
exe
.
run
(
fluid
.
default_startup_program
())
exe
.
run
(
fluid
.
default_startup_program
())
# load model
# load model
paddle_model_dir
,
basename
=
os
.
path
.
split
(
paddle
_model_filename
)
fluid_model_dir
,
basename
=
os
.
path
.
split
(
fluid
_model_filename
)
if
basename
==
'__model__'
:
# is desc model
if
basename
==
'__model__'
:
# is desc model
logger
.
debug
(
'using desc file %s'
,
basename
)
logger
.
debug
(
'using desc file %s'
,
basename
)
prog
,
in_names
,
var_outs
=
fluid
.
io
.
load_inference_model
(
paddle
_model_dir
,
exe
)
prog
,
_
,
var_outs
=
fluid
.
io
.
load_inference_model
(
fluid
_model_dir
,
exe
)
out_names
=
var_outs
# HINT: pass var if fetch ops already created
out_names
=
var_outs
# HINT: pass var if fetch ops already created
logger
.
info
(
'model load passed'
)
logger
.
info
(
'model load passed'
)
elif
basename
.
endswith
(
'.py'
):
# is python code
elif
basename
.
endswith
(
'.py'
):
# is python code
logger
.
debug
(
'using python code file %s'
,
basename
)
logger
.
debug
(
'using python code file %s'
,
basename
)
module_name
,
_
=
os
.
path
.
splitext
(
basename
)
module_name
,
_
=
os
.
path
.
splitext
(
basename
)
sys_path
=
sys
.
path
.
copy
()
sys_path
=
sys
.
path
.
copy
()
sys
.
path
.
append
(
paddle
_model_dir
)
sys
.
path
.
append
(
fluid
_model_dir
)
try
:
try
:
module
=
importlib
.
import_module
(
module_name
)
module
=
importlib
.
import_module
(
module_name
)
func
=
getattr
(
module
,
model_func_name
)
func
=
getattr
(
module
,
model_func_name
)
...
@@ -71,18 +71,21 @@ def validate(paddle_model_filename, golden_data_filename,
...
@@ -71,18 +71,21 @@ def validate(paddle_model_filename, golden_data_filename,
module
=
importlib
.
import_module
(
module_name
)
module
=
importlib
.
import_module
(
module_name
)
func
=
getattr
(
module
,
model_func_name
)
func
=
getattr
(
module
,
model_func_name
)
sys
.
path
=
sys_path
sys
.
path
=
sys_path
logger
.
debug
(
'from %s imported %s: %s'
,
module_name
,
model_func_name
,
func
)
logger
.
debug
(
'from %s imported %s: %s'
,
module_name
,
model_func_name
,
func
)
var_outs
=
func
()
var_outs
=
func
()
var_outs
=
_ensure_list
(
var_outs
)
var_outs
=
_ensure_list
(
var_outs
)
out_names
=
[
var
.
name
for
var
in
var_outs
]
# HINT: pass string to create fetch ops
out_names
=
[
var
.
name
for
var
in
var_outs
]
# HINT: pass string to create fetch ops
logger
.
info
(
'import passed'
)
logger
.
info
(
'import passed'
)
prog
=
fluid
.
default_main_program
()
prog
=
fluid
.
default_main_program
()
fluid
.
io
.
load_persistables
(
executor
=
exe
,
dirname
=
paddle_model_dir
,
main_program
=
prog
)
fluid
.
io
.
load_persistables
(
executor
=
exe
,
dirname
=
fluid_model_dir
,
main_program
=
prog
)
logger
.
info
(
'weight load passed'
)
logger
.
info
(
'weight load passed'
)
else
:
else
:
raise
ValueError
(
'unsupported Paddle model'
)
raise
ValueError
(
'unsupported Paddle
fluid
model'
)
# load data
# load data
logger
.
info
(
'using golden data %s'
,
golden_data_filename
)
logger
.
info
(
'using golden data %s'
,
golden_data_filename
)
...
@@ -100,10 +103,15 @@ def validate(paddle_model_filename, golden_data_filename,
...
@@ -100,10 +103,15 @@ def validate(paddle_model_filename, golden_data_filename,
# DEBUG: reload test for python code
# DEBUG: reload test for python code
if
basename
.
endswith
(
'.py'
)
and
save_inference_model
:
if
basename
.
endswith
(
'.py'
)
and
save_inference_model
:
fluid
.
io
.
save_inference_model
(
paddle_model_dir
,
input_data
.
keys
(),
var_outs
,
exe
,
fluid
.
io
.
save_inference_model
(
main_program
=
prog
,
export_for_deployment
=
True
)
fluid_model_dir
,
input_data
.
keys
(),
var_outs
,
exe
,
main_program
=
prog
,
export_for_deployment
=
True
)
logger
.
info
(
'model re-save passed'
)
logger
.
info
(
'model re-save passed'
)
fluid
.
io
.
load_inference_model
(
paddle
_model_dir
,
exe
)
fluid
.
io
.
load_inference_model
(
fluid
_model_dir
,
exe
)
logger
.
info
(
'model re-load passed'
)
logger
.
info
(
'model re-load passed'
)
# execute
# execute
...
@@ -124,49 +132,54 @@ def validate(paddle_model_filename, golden_data_filename,
...
@@ -124,49 +132,54 @@ def validate(paddle_model_filename, golden_data_filename,
else
:
else
:
logger
.
info
(
'accuracy not passed'
)
logger
.
info
(
'accuracy not passed'
)
# globals().update(locals())
# globals().update(locals())
return
passed
return
passed
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
logging
.
basicConfig
(
import
argparse
format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
,
level
=
logging
.
DEBUG
,
)
logger
=
logging
.
getLogger
(
'validation_test'
)
model_rc_list
=
[
'../examples/t{}/model.py'
,
'../examples/t{}/__model__'
,
'../examples/t{}.embeded/model.py'
,
'../examples/t{}.embeded/__model__'
,
]
import
numpy
as
np
idx_model
=
np
.
random
.
randint
(
1
,
7
)
model
=
np
.
random
.
choice
(
model_rc_list
).
format
(
idx_model
)
precision
=
10
**
(
np
.
random
.
rand
()
*
-
4
-
2
)
debug
=
False
model
=
'/tmp/export/model.py'
# model = '../examples/t1/__model__'
# model = '../examples/t1.embeded/model.py'
# model = '../examples/t1.embeded/__model__'
debug
=
True
logger
.
info
(
'args: %s %.6f'
,
model
,
precision
)
data_dir
,
dir_name
=
os
.
path
.
split
(
os
.
path
.
split
(
model
)[
0
])
parser
=
argparse
.
ArgumentParser
(
data_pathname
=
os
.
path
.
splitext
(
dir_name
)[
0
]
description
=
'onnx2fluid.validate'
,
formatter_class
=
argparse
.
ArgumentDefaultsHelpFormatter
,
# proto debug test
)
from
framework_pb2
import
ProgramDesc
parser
.
add_argument
(
pd
=
ProgramDesc
()
'model'
,
pd
.
ParseFromString
(
open
(
os
.
path
.
join
(
data_dir
,
dir_name
,
'__model__'
),
'rb'
).
read
())
nargs
=
1
,
help
=
'path to model.py or __model__'
,
# validate
)
# validate(model, os.path.join(data_dir, data_pathname + '.npz'),
parser
.
add_argument
(
# precision=precision, save_inference_model=debug)
'--debug'
,
validate
(
model
,
'../examples/bvlc_alexnet/test_data_0.npz'
,
'-d'
,
precision
=
precision
,
save_inference_model
=
debug
)
action
=
'store_true'
,
help
=
'enable debug logging and checking'
,
)
parser
.
add_argument
(
'--test_data'
,
'-t'
,
type
=
str
,
help
=
'I/O golden data for validation, e.g. test.npy, test.npz'
,
)
parser
.
add_argument
(
'--precision'
,
'-p'
,
type
=
int
,
default
=
4
,
help
=
'assertion decimal for validation'
,
)
args
=
parser
.
parse_args
()
logging_format
=
'[%(levelname)8s]%(name)s::%(funcName)s:%(lineno)04d: %(message)s'
logging_level
=
logging
.
DEBUG
if
args
.
debug
else
logging
.
INFO
logging
.
basicConfig
(
format
=
logging_format
,
level
=
logging_level
)
debug
=
args
.
debug
fluid_model_filename
=
args
.
model
[
0
]
golden_data_filename
=
args
.
test_data
precision
=
args
.
precision
validate
(
fluid_model_filename
,
golden_data_filename
,
precision
=
precision
,
save_inference_model
=
debug
)
onnx2
paddle/onnx2paddle
/writer.py
→
onnx2
fluid/onnx2fluid
/writer.py
浏览文件 @
2228423e
...
@@ -34,15 +34,13 @@ except ImportError:
...
@@ -34,15 +34,13 @@ except ImportError:
logger
.
warning
(
'importing paddle.fluid.proto.framework_pb2d failed,'
logger
.
warning
(
'importing paddle.fluid.proto.framework_pb2d failed,'
'using fallback framework_pb2'
)
'using fallback framework_pb2'
)
__all__
=
[
__all__
=
[
'Program'
,
'Program'
,
'Writer'
,
'Writer'
,
]
]
def
_irepr
(
obj
,
def
_irepr
(
obj
,
to
=
'_'
):
to
=
'_'
):
"""inline repr"""
"""inline repr"""
s
=
repr
(
obj
)
s
=
repr
(
obj
)
...
@@ -53,8 +51,7 @@ def _irepr(obj,
...
@@ -53,8 +51,7 @@ def _irepr(obj,
return
s
return
s
def
_flatten_list
(
obj
,
def
_flatten_list
(
obj
,
out
=
None
):
out
=
None
):
if
out
is
None
:
if
out
is
None
:
out
=
type
(
obj
)()
out
=
type
(
obj
)()
for
item
in
obj
:
for
item
in
obj
:
...
@@ -72,7 +69,7 @@ def make_attr_name(name):
...
@@ -72,7 +69,7 @@ def make_attr_name(name):
if
name
==
''
:
if
name
==
''
:
raise
ValueError
(
'name should not be empty'
)
raise
ValueError
(
'name should not be empty'
)
for
s
in
' *?\
/-:'
:
#
for
s
in
' *?
\
\
/-:'
:
#
name
=
name
.
replace
(
s
,
'_'
)
name
=
name
.
replace
(
s
,
'_'
)
if
not
name
.
startswith
(
'_'
):
if
not
name
.
startswith
(
'_'
):
name
=
'_'
+
name
name
=
'_'
+
name
...
@@ -168,11 +165,8 @@ class Program(object):
...
@@ -168,11 +165,8 @@ class Program(object):
return
(
'Program(code mutable: {}) with:
\n
'
return
(
'Program(code mutable: {}) with:
\n
'
'codes: {}
\n
'
'codes: {}
\n
'
'op_descs: {}
\n
'
'op_descs: {}
\n
'
'var_descs: {}
\n
'
).
format
(
'var_descs: {}
\n
'
).
format
(
self
.
code_mutable
,
self
.
codes
,
self
.
code_mutable
,
self
.
op_descs
,
self
.
var_descs
)
self
.
codes
,
self
.
op_descs
,
self
.
var_descs
)
def
__repr__
(
self
):
def
__repr__
(
self
):
return
self
.
__str__
()
return
self
.
__str__
()
...
@@ -185,8 +179,11 @@ class Program(object):
...
@@ -185,8 +179,11 @@ class Program(object):
if
self
.
code_mutable
:
if
self
.
code_mutable
:
self
.
codes
.
append
(
code
)
self
.
codes
.
append
(
code
)
def
OpDesc
(
self
,
name
,
def
OpDesc
(
self
,
input_val_keys
=
None
,
output_val_keys
=
None
,
attrs
=
None
):
name
,
input_val_keys
=
None
,
output_val_keys
=
None
,
attrs
=
None
):
"""
"""
add OpDesc
add OpDesc
"""
"""
...
@@ -202,10 +199,15 @@ class Program(object):
...
@@ -202,10 +199,15 @@ class Program(object):
self
.
op_descs
.
append
(
desc
)
self
.
op_descs
.
append
(
desc
)
return
desc
return
desc
def
VarDesc
(
self
,
name
,
def
VarDesc
(
self
,
persistable
=
False
,
value_info
=
None
,
remove_batch
=
None
):
name
,
persistable
=
False
,
value_info
=
None
,
remove_batch
=
None
,
dummy_dtype
=
'float32'
):
"""
"""
add VarDesc
add VarDesc,
dummy_dtype: WORKAROUND for Netron viewer
"""
"""
var_desc
=
framework_pb2
.
VarDesc
()
var_desc
=
framework_pb2
.
VarDesc
()
...
@@ -213,6 +215,10 @@ class Program(object):
...
@@ -213,6 +215,10 @@ class Program(object):
var_desc
.
persistable
=
persistable
var_desc
.
persistable
=
persistable
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
var_desc
.
type
.
type
=
framework_pb2
.
VarType
.
LOD_TENSOR
# REMOVEIT: WORKAROUND: Netron: null.tensor error
tensor_desc
=
var_desc
.
type
.
lod_tensor
.
tensor
tensor_desc
.
data_type
=
self
.
Dtype
(
dummy_dtype
)
# required
if
value_info
and
'dtype'
in
value_info
:
if
value_info
and
'dtype'
in
value_info
:
tensor_desc
=
var_desc
.
type
.
lod_tensor
.
tensor
tensor_desc
=
var_desc
.
type
.
lod_tensor
.
tensor
tensor_desc
.
data_type
=
self
.
Dtype
(
value_info
[
'dtype'
])
# required
tensor_desc
.
data_type
=
self
.
Dtype
(
value_info
[
'dtype'
])
# required
...
@@ -220,7 +226,8 @@ class Program(object):
...
@@ -220,7 +226,8 @@ class Program(object):
tensor_desc
.
dims
.
extend
(
value_info
[
'shape'
])
tensor_desc
.
dims
.
extend
(
value_info
[
'shape'
])
if
len
(
value_info
[
'shape'
])
>
0
:
# skip scalars
if
len
(
value_info
[
'shape'
])
>
0
:
# skip scalars
if
remove_batch
is
None
:
if
remove_batch
is
None
:
remove_batch
=
value_info
.
get
(
'remove_batch'
,
not
persistable
)
remove_batch
=
value_info
.
get
(
'remove_batch'
,
not
persistable
)
if
remove_batch
:
if
remove_batch
:
tensor_desc
.
dims
[
0
]
=
-
1
tensor_desc
.
dims
[
0
]
=
-
1
...
@@ -240,8 +247,8 @@ class Program(object):
...
@@ -240,8 +247,8 @@ class Program(object):
fn
=
getattr
(
symbolic
,
op_type
)
fn
=
getattr
(
symbolic
,
op_type
)
fn
(
self
,
*
args
,
**
kwargs
)
fn
(
self
,
*
args
,
**
kwargs
)
else
:
else
:
raise
ValueError
(
'conversion for {}::{} not supported'
raise
ValueError
(
'conversion for {}::{} not supported'
.
format
(
.
format
(
domain
,
op_type
))
domain
,
op_type
))
def
IntermediateOp
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
def
IntermediateOp
(
self
,
domain
,
op_type
,
*
args
,
**
kwargs
):
"""
"""
...
@@ -267,14 +274,15 @@ class Writer(object):
...
@@ -267,14 +274,15 @@ class Writer(object):
CODE_INDENT
=
' '
*
4
CODE_INDENT
=
' '
*
4
@
staticmethod
@
staticmethod
def
header_code
(
func_name
):
def
header_code
(
func_name
,
info
=
''
):
"""
"""
Python header codes
Python header codes
"""
"""
codes
=
list
()
codes
=
list
()
codes
.
append
(
'"""'
)
codes
.
append
(
'"""'
)
codes
.
append
(
'This code is generated by onnx2paddle.'
)
codes
.
append
(
'This code is generated by onnx2fluid.'
)
codes
.
append
(
'{}'
.
format
(
info
))
codes
.
append
(
'"""'
)
codes
.
append
(
'"""'
)
codes
.
append
(
''
)
codes
.
append
(
''
)
codes
.
append
(
'from __future__ import division'
)
codes
.
append
(
'from __future__ import division'
)
...
@@ -287,16 +295,25 @@ class Writer(object):
...
@@ -287,16 +295,25 @@ class Writer(object):
return
codes
return
codes
@
staticmethod
@
staticmethod
def
emit_op
(
prog
,
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
def
emit_op
(
prog
,
name
,
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
*
args
,
**
kwargs
):
"""
"""
emit an ONNX op into program
emit an ONNX op into program
"""
"""
prog
.
Code
(
'# {}, {}::{}: {} -> {}, {}'
prog
.
Code
(
'# {}, {}::{}: {} -> {}, {}'
.
format
(
name
,
domain
,
op_type
,
.
format
(
name
,
domain
,
op_type
,
inputs
,
outputs
,
_irepr
(
attrs
,
to
=
', '
)))
inputs
,
outputs
,
prog
.
Op
(
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
_irepr
(
attrs
,
to
=
', '
)))
value_infos
=
value_infos
,
name
=
name
,
prog
.
Op
(
*
args
,
**
kwargs
)
domain
,
op_type
,
inputs
,
outputs
,
attrs
,
value_infos
=
value_infos
,
name
=
name
,
*
args
,
**
kwargs
)
@
staticmethod
@
staticmethod
def
emit_param
(
prog
,
name
,
value_info
):
def
emit_param
(
prog
,
name
,
value_info
):
...
@@ -315,16 +332,16 @@ class Writer(object):
...
@@ -315,16 +332,16 @@ class Writer(object):
prog
.
Code
(
'# parameter: {}'
.
format
(
name
))
prog
.
Code
(
'# parameter: {}'
.
format
(
name
))
prog
.
Code
(
'{} = ParamAttr(name={})'
# , trainable=True
prog
.
Code
(
'{} = ParamAttr(name={})'
# , trainable=True
.
format
(
attr_name
,
repr
(
var_name
)))
.
format
(
attr_name
,
repr
(
var_name
)))
prog
.
Code
(
'{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
prog
.
Code
(
'{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
', default_initializer=initializer.Constant(0))'
#, is_bias={}
', default_initializer=initializer.Constant(0))'
#, is_bias={}
.
format
(
var_name
,
.
format
(
var_name
,
value_info
[
'shape'
]
,
value_info
[
'shape'
],
repr
(
value_info
[
'dtype'
].
name
),
repr
(
value_info
[
'dtype'
].
name
),
repr
(
name
),
repr
(
name
),
attr_name
))
#, value_info.get('is_bias', False)))
attr_name
))
#, value_info.get('is_bias', False)))
prog
.
VarDesc
(
var_name
,
persistable
=
True
,
value_info
=
value_info
)
prog
.
VarDesc
(
var_name
,
persistable
=
True
,
value_info
=
value_info
)
@
staticmethod
@
staticmethod
def
emit_inputs
(
prog
,
names
,
value_infos
,
def
emit_inputs
(
prog
,
names
,
value_infos
,
remove_batch
=
None
):
remove_batch
=
None
):
"""
"""
emit ONNX inputs into program
emit ONNX inputs into program
"""
"""
...
@@ -334,24 +351,30 @@ class Writer(object):
...
@@ -334,24 +351,30 @@ class Writer(object):
value_info
=
value_infos
[
name
]
value_info
=
value_infos
[
name
]
shape
=
value_info
[
'shape'
]
shape
=
value_info
[
'shape'
]
if
remove_batch
is
None
:
if
remove_batch
is
None
:
remove_batch
=
value_info
.
get
(
'remove_batch'
,
True
)
# HINT: True by default ?
remove_batch
=
value_info
.
get
(
'remove_batch'
,
True
)
# HINT: True by default ?
if
remove_batch
:
if
remove_batch
:
shape
=
shape
[
1
:]
shape
=
shape
[
1
:]
prog
.
Code
(
'# input: {}'
.
format
(
name
))
prog
.
Code
(
'# input: {}'
.
format
(
name
))
prog
.
Code
((
'{} = layers.data(name={}, shape={}, dtype={}, '
prog
.
Code
((
'{} = layers.data(name={}, shape={}, dtype={}, '
'append_batch_size={})'
# , stop_gradient=True
'append_batch_size={})'
# , stop_gradient=True
).
format
(
var_name
,
repr
(
name
),
).
format
(
var_name
,
repr
(
name
),
shape
,
shape
,
repr
(
value_info
[
'dtype'
].
name
),
repr
(
value_info
[
'dtype'
].
name
),
remove_batch
,
remove_batch
,
))
))
prog
.
OpDesc
(
'feed'
,
prog
.
OpDesc
(
'feed'
,
([
'feed'
],
'X'
),
([
'feed'
],
'X'
),
([
var_name
],
'Out'
),
([
var_name
],
'Out'
),
dict
(
col
=
idx
),
dict
(
col
=
idx
),
)
)
prog
.
VarDesc
(
var_name
,
value_info
=
value_info
,
remove_batch
=
remove_batch
)
prog
.
VarDesc
(
var_name
,
value_info
=
value_info
,
remove_batch
=
remove_batch
)
@
staticmethod
@
staticmethod
def
emit_outputs
(
prog
,
names
):
#, value_infos
def
emit_outputs
(
prog
,
names
):
#, value_infos
...
@@ -364,7 +387,8 @@ class Writer(object):
...
@@ -364,7 +387,8 @@ class Writer(object):
var_name
=
make_var_name
(
name
)
var_name
=
make_var_name
(
name
)
code
+=
var_name
+
', '
code
+=
var_name
+
', '
prog
.
OpDesc
(
'fetch'
,
prog
.
OpDesc
(
'fetch'
,
([
var_name
],
'X'
),
([
var_name
],
'X'
),
([
'fetch'
],
'Out'
),
([
'fetch'
],
'Out'
),
dict
(
col
=
idx
),
dict
(
col
=
idx
),
...
...
onnx2
paddle
/requirements.txt
→
onnx2
fluid
/requirements.txt
浏览文件 @
2228423e
onnx2
paddle
/setup.cfg
→
onnx2
fluid
/setup.cfg
浏览文件 @
2228423e
...
@@ -2,14 +2,14 @@
...
@@ -2,14 +2,14 @@
# https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
# https://setuptools.readthedocs.io/en/latest/setuptools.html#configuring-setup-using-setup-cfg-files
[metadata]
[metadata]
# 项目名称,发布、安装时以此作为包名
# 项目名称,发布、安装时以此作为包名
name = onnx2
paddle
name = onnx2
fluid
# 作者姓名和邮箱地址
# 作者姓名和邮箱地址
author = Macrobull
author = Macrobull
# author_email = .Github@github.com
# author_email = .Github@github.com
# 项目版本号,1.0以上版本才视为正式版
# 项目版本号,1.0以上版本才视为正式版
version = 0.1.0
version = 0.1.0
# 项目概要描述信息,一句话让用户明白项目概要,不支持中文
# 项目概要描述信息,一句话让用户明白项目概要,不支持中文
description = Inference model conversion from ONNX/PyTorch to Paddle
description = Inference model conversion from ONNX/PyTorch to Paddle
fluid
# 项目的详细描述内容和格式,包括readme和changelog等,通常使用md或rst等格式
# 项目的详细描述内容和格式,包括readme和changelog等,通常使用md或rst等格式
long_description = file: README.md, CHANGELOG.md
long_description = file: README.md, CHANGELOG.md
long_description_content_type = text/markdown
long_description_content_type = text/markdown
...
@@ -25,7 +25,7 @@ classifier =
...
@@ -25,7 +25,7 @@ classifier =
Programming Language :: Python :: 3.5
Programming Language :: Python :: 3.5
# 关键字,用于检索,方便用户搜索到你的项目
# 关键字,用于检索,方便用户搜索到你的项目
keywords =
keywords =
onnx paddle
onnx paddle
paddle
[options]
[options]
# 包名称,find:表示自动寻找,可在options.packages.find中进行详细配置
# 包名称,find:表示自动寻找,可在options.packages.find中进行详细配置
...
@@ -44,21 +44,21 @@ install_requires =
...
@@ -44,21 +44,21 @@ install_requires =
# mock
# mock
# 单测代码目录
# 单测代码目录
#test_suite = onnx2
paddle
.tests
#test_suite = onnx2
fluid
.tests
# 自动添加被版本控制的数据文件
# 自动添加被版本控制的数据文件
include_package_data = True
include_package_data = True
# 项目是纯py项目,可以直接执行zip源码包
# 项目是纯py项目,可以直接执行zip源码包
zip_safe = False
zip_safe = False
# 可以通过以下配置将指定的函数变成命令行工具,允许用户直接执行
# 可以通过以下配置将指定的函数变成命令行工具,允许用户直接执行
#
[options.entry_points]
[options.entry_points]
#
console_scripts =
console_scripts =
# onnx2paddle = onnx2paddle
.cmdline:main
onnx2fluid = onnx2fluid
.cmdline:main
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 可以通过以下配置向包中添加conf或data等非py文件,安装时会一同安装到site-packages目录下
# 仅支持文件,不支持目录,但可以使用通配
# 仅支持文件,不支持目录,但可以使用通配
#[options.package_data]
#[options.package_data]
#onnx2
paddle
=
#onnx2
fluid
=
# conf/*
# conf/*
# data/*
# data/*
...
...
onnx2
paddle
/setup.py
→
onnx2
fluid
/setup.py
浏览文件 @
2228423e
...
@@ -15,4 +15,3 @@ Date: 2019/02/22 10:25:46
...
@@ -15,4 +15,3 @@ Date: 2019/02/22 10:25:46
import
setuptools
import
setuptools
setuptools
.
setup
()
setuptools
.
setup
()
onnx2paddle/onnx2paddle/framework_pb2.py
已删除
100644 → 0
浏览文件 @
a4796334
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: framework.proto
import
sys
_b
=
sys
.
version_info
[
0
]
<
3
and
(
lambda
x
:
x
)
or
(
lambda
x
:
x
.
encode
(
'latin1'
))
from
google.protobuf.internal
import
enum_type_wrapper
from
google.protobuf
import
descriptor
as
_descriptor
from
google.protobuf
import
message
as
_message
from
google.protobuf
import
reflection
as
_reflection
from
google.protobuf
import
symbol_database
as
_symbol_database
from
google.protobuf
import
descriptor_pb2
# @@protoc_insertion_point(imports)
_sym_db
=
_symbol_database
.
Default
()
DESCRIPTOR
=
_descriptor
.
FileDescriptor
(
name
=
'framework.proto'
,
package
=
'paddle.framework.proto'
,
syntax
=
'proto2'
,
serialized_pb
=
_b
(
'
\n\x0f\x66
ramework.proto
\x12\x16
paddle.framework.proto
\"\x1d\n\x07
Version
\x12\x12\n\x07
version
\x18\x01
\x01
(
\x03
:
\x01\x30\"\xec\x03\n\x06
OpDesc
\x12\x0c\n\x04
type
\x18\x03
\x02
(
\t\x12\x32\n\x06
inputs
\x18\x01
\x03
(
\x0b\x32\"
.paddle.framework.proto.OpDesc.Var
\x12\x33\n\x07
outputs
\x18\x02
\x03
(
\x0b\x32\"
.paddle.framework.proto.OpDesc.Var
\x12\x32\n\x05\x61
ttrs
\x18\x04
\x03
(
\x0b\x32
#.paddle.framework.proto.OpDesc.Attr
\x12\x18\n\t
is_target
\x18\x05
\x01
(
\x08
:
\x05\x66\x61
lse
\x1a\xef\x01\n\x04\x41
ttr
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12
.
\n\x04
type
\x18\x02
\x02
(
\x0e\x32
.paddle.framework.proto.AttrType
\x12\t\n\x01
i
\x18\x03
\x01
(
\x05\x12\t\n\x01\x66\x18\x04
\x01
(
\x02\x12\t\n\x01
s
\x18\x05
\x01
(
\t\x12\x0c\n\x04
ints
\x18\x06
\x03
(
\x05\x12\x0e\n\x06\x66
loats
\x18\x07
\x03
(
\x02\x12\x0f\n\x07
strings
\x18\x08
\x03
(
\t\x12\t\n\x01\x62\x18\n
\x01
(
\x08\x12\r\n\x05\x62
ools
\x18\x0b
\x03
(
\x08\x12\x11\n\t
block_idx
\x18\x0c
\x01
(
\x05\x12\t\n\x01
l
\x18\r
\x01
(
\x03\x12\x12\n\n
blocks_idx
\x18\x0e
\x03
(
\x05\x12\r\n\x05
longs
\x18\x0f
\x03
(
\x03\x1a
+
\n\x03
Var
\x12\x11\n\t
parameter
\x18\x01
\x02
(
\t\x12\x11\n\t
arguments
\x18\x02
\x03
(
\t\"\xb3\x03\n\x07
OpProto
\x12\x0c\n\x04
type
\x18\x01
\x02
(
\t\x12\x33\n\x06
inputs
\x18\x02
\x03
(
\x0b\x32
#.paddle.framework.proto.OpProto.Var
\x12\x34\n\x07
outputs
\x18\x03
\x03
(
\x0b\x32
#.paddle.framework.proto.OpProto.Var
\x12\x33\n\x05\x61
ttrs
\x18\x04
\x03
(
\x0b\x32
$.paddle.framework.proto.OpProto.Attr
\x12\x0f\n\x07\x63
omment
\x18\x05
\x02
(
\t\x1a
x
\n\x03
Var
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12\x0f\n\x07\x63
omment
\x18\x02
\x02
(
\t\x12\x19\n\n
duplicable
\x18\x03
\x01
(
\x08
:
\x05\x66\x61
lse
\x12\x1b\n\x0c
intermediate
\x18\x04
\x01
(
\x08
:
\x05\x66\x61
lse
\x12\x1a\n\x0b\x64
ispensable
\x18\x05
\x01
(
\x08
:
\x05\x66\x61
lse
\x1a
o
\n\x04\x41
ttr
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12
.
\n\x04
type
\x18\x02
\x02
(
\x0e\x32
.paddle.framework.proto.AttrType
\x12\x0f\n\x07\x63
omment
\x18\x03
\x02
(
\t\x12\x18\n\t
generated
\x18\x04
\x01
(
\x08
:
\x05\x66\x61
lse
\"\xda\x08\n\x07
VarType
\x12\x32\n\x04
type
\x18\x01
\x02
(
\x0e\x32
$.paddle.framework.proto.VarType.Type
\x12\x41\n\r
selected_rows
\x18\x02
\x01
(
\x0b\x32
*.paddle.framework.proto.VarType.TensorDesc
\x12\x41\n\n
lod_tensor
\x18\x03
\x01
(
\x0b\x32
-.paddle.framework.proto.VarType.LoDTensorDesc
\x12
H
\n\x0c
tensor_array
\x18\x04
\x01
(
\x0b\x32\x32
.paddle.framework.proto.VarType.LoDTensorArrayDesc
\x12
:
\n\x06
reader
\x18\x05
\x01
(
\x0b\x32
*.paddle.framework.proto.VarType.ReaderDesc
\x12\x34\n\x05
tuple
\x18\x07
\x01
(
\x0b\x32
%.paddle.framework.proto.VarType.Tuple
\x1a
S
\n\n
TensorDesc
\x12\x37\n\t
data_type
\x18\x01
\x02
(
\x0e\x32
$.paddle.framework.proto.VarType.Type
\x12\x0c\n\x04\x64
ims
\x18\x02
\x03
(
\x03\x1a\x61\n\r
LoDTensorDesc
\x12
:
\n\x06
tensor
\x18\x01
\x02
(
\x0b\x32
*.paddle.framework.proto.VarType.TensorDesc
\x12\x14\n\t
lod_level
\x18\x02
\x01
(
\x05
:
\x01\x30\x1a\x66\n\x12
LoDTensorArrayDesc
\x12
:
\n\x06
tensor
\x18\x01
\x02
(
\x0b\x32
*.paddle.framework.proto.VarType.TensorDesc
\x12\x14\n\t
lod_level
\x18\x02
\x01
(
\x05
:
\x01\x30\x1a
O
\n\n
ReaderDesc
\x12\x41\n\n
lod_tensor
\x18\x01
\x03
(
\x0b\x32
-.paddle.framework.proto.VarType.LoDTensorDesc
\x1a\x43\n\x05
Tuple
\x12
:
\n\x0c\x65
lement_type
\x18\x01
\x03
(
\x0e\x32
$.paddle.framework.proto.VarType.Type
\"\xa2\x02\n\x04
Type
\x12\x08\n\x04\x42
OOL
\x10\x00\x12\t\n\x05
INT16
\x10\x01\x12\t\n\x05
INT32
\x10\x02\x12\t\n\x05
INT64
\x10\x03\x12\x08\n\x04\x46
P16
\x10\x04\x12\x08\n\x04\x46
P32
\x10\x05\x12\x08\n\x04\x46
P64
\x10\x06\x12\n\n\x06
SIZE_T
\x10\x13\x12\t\n\x05
UINT8
\x10\x14\x12\x08\n\x04
INT8
\x10\x15\x12\x0e\n\n
LOD_TENSOR
\x10\x07\x12\x11\n\r
SELECTED_ROWS
\x10\x08\x12\x12\n\x0e\x46\x45\x45\x44
_MINIBATCH
\x10\t\x12\x0e\n\n
FETCH_LIST
\x10\n\x12\x0f\n\x0b
STEP_SCOPES
\x10\x0b\x12\x12\n\x0e
LOD_RANK_TABLE
\x10\x0c\x12\x14\n\x10
LOD_TENSOR_ARRAY
\x10\r\x12\x0e\n\n
PLACE_LIST
\x10\x0e\x12\n\n\x06
READER
\x10\x0f\x12\x07\n\x03
RAW
\x10\x11\x12\t\n\x05
TUPLE
\x10\x12\"
b
\n\x07
VarDesc
\x12\x0c\n\x04
name
\x18\x01
\x02
(
\t\x12
-
\n\x04
type
\x18\x02
\x02
(
\x0b\x32\x1f
.paddle.framework.proto.VarType
\x12\x1a\n\x0b
persistable
\x18\x03
\x01
(
\x08
:
\x05\x66\x61
lse
\"\xa7\x01\n\t
BlockDesc
\x12\x0b\n\x03
idx
\x18\x01
\x02
(
\x05\x12\x12\n\n
parent_idx
\x18\x02
\x02
(
\x05\x12
-
\n\x04
vars
\x18\x03
\x03
(
\x0b\x32\x1f
.paddle.framework.proto.VarDesc
\x12
+
\n\x03
ops
\x18\x04
\x03
(
\x0b\x32\x1e
.paddle.framework.proto.OpDesc
\x12\x1d\n\x11\x66
orward_block_idx
\x18\x05
\x01
(
\x05
:
\x02
-1
\"
r
\n\x0b
ProgramDesc
\x12\x31\n\x06\x62
locks
\x18\x01
\x03
(
\x0b\x32
!.paddle.framework.proto.BlockDesc
\x12\x30\n\x07
version
\x18\x02
\x01
(
\x0b\x32\x1f
.paddle.framework.proto.Version*
\x94\x01\n\x08\x41
ttrType
\x12\x07\n\x03
INT
\x10\x00\x12\t\n\x05\x46
LOAT
\x10\x01\x12\n\n\x06
STRING
\x10\x02\x12\x08\n\x04
INTS
\x10\x03\x12\n\n\x06\x46
LOATS
\x10\x04\x12\x0b\n\x07
STRINGS
\x10\x05\x12\x0b\n\x07\x42
OOLEAN
\x10\x06\x12\x0c\n\x08\x42
OOLEANS
\x10\x07\x12\t\n\x05\x42
LOCK
\x10\x08\x12\x08\n\x04
LONG
\x10\t\x12\n\n\x06\x42
LOCKS
\x10\n\x12\t\n\x05
LONGS
\x10\x0b\x42\x02
H
\x03
'
)
)
_sym_db
.
RegisterFileDescriptor
(
DESCRIPTOR
)
_ATTRTYPE
=
_descriptor
.
EnumDescriptor
(
name
=
'AttrType'
,
full_name
=
'paddle.framework.proto.AttrType'
,
filename
=
None
,
file
=
DESCRIPTOR
,
values
=
[
_descriptor
.
EnumValueDescriptor
(
name
=
'INT'
,
index
=
0
,
number
=
0
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FLOAT'
,
index
=
1
,
number
=
1
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STRING'
,
index
=
2
,
number
=
2
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INTS'
,
index
=
3
,
number
=
3
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FLOATS'
,
index
=
4
,
number
=
4
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STRINGS'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOLEAN'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOLEANS'
,
index
=
7
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BLOCK'
,
index
=
8
,
number
=
8
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LONG'
,
index
=
9
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'BLOCKS'
,
index
=
10
,
number
=
10
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LONGS'
,
index
=
11
,
number
=
11
,
options
=
None
,
type
=
None
),
],
containing_type
=
None
,
options
=
None
,
serialized_start
=
2511
,
serialized_end
=
2659
,
)
_sym_db
.
RegisterEnumDescriptor
(
_ATTRTYPE
)
AttrType
=
enum_type_wrapper
.
EnumTypeWrapper
(
_ATTRTYPE
)
INT
=
0
FLOAT
=
1
STRING
=
2
INTS
=
3
FLOATS
=
4
STRINGS
=
5
BOOLEAN
=
6
BOOLEANS
=
7
BLOCK
=
8
LONG
=
9
BLOCKS
=
10
LONGS
=
11
_VARTYPE_TYPE
=
_descriptor
.
EnumDescriptor
(
name
=
'Type'
,
full_name
=
'paddle.framework.proto.VarType.Type'
,
filename
=
None
,
file
=
DESCRIPTOR
,
values
=
[
_descriptor
.
EnumValueDescriptor
(
name
=
'BOOL'
,
index
=
0
,
number
=
0
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT16'
,
index
=
1
,
number
=
1
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT32'
,
index
=
2
,
number
=
2
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT64'
,
index
=
3
,
number
=
3
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP16'
,
index
=
4
,
number
=
4
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP32'
,
index
=
5
,
number
=
5
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FP64'
,
index
=
6
,
number
=
6
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'SIZE_T'
,
index
=
7
,
number
=
19
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'UINT8'
,
index
=
8
,
number
=
20
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'INT8'
,
index
=
9
,
number
=
21
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_TENSOR'
,
index
=
10
,
number
=
7
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'SELECTED_ROWS'
,
index
=
11
,
number
=
8
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FEED_MINIBATCH'
,
index
=
12
,
number
=
9
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'FETCH_LIST'
,
index
=
13
,
number
=
10
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'STEP_SCOPES'
,
index
=
14
,
number
=
11
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_RANK_TABLE'
,
index
=
15
,
number
=
12
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'LOD_TENSOR_ARRAY'
,
index
=
16
,
number
=
13
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'PLACE_LIST'
,
index
=
17
,
number
=
14
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'READER'
,
index
=
18
,
number
=
15
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'RAW'
,
index
=
19
,
number
=
17
,
options
=
None
,
type
=
None
),
_descriptor
.
EnumValueDescriptor
(
name
=
'TUPLE'
,
index
=
20
,
number
=
18
,
options
=
None
,
type
=
None
),
],
containing_type
=
None
,
options
=
None
,
serialized_start
=
1832
,
serialized_end
=
2122
,
)
_sym_db
.
RegisterEnumDescriptor
(
_VARTYPE_TYPE
)
_VERSION
=
_descriptor
.
Descriptor
(
name
=
'Version'
,
full_name
=
'paddle.framework.proto.Version'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'version'
,
full_name
=
'paddle.framework.proto.Version.version'
,
index
=
0
,
number
=
1
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
True
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
43
,
serialized_end
=
72
,
)
_OPDESC_ATTR
=
_descriptor
.
Descriptor
(
name
=
'Attr'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.type'
,
index
=
1
,
number
=
2
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'i'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.i'
,
index
=
2
,
number
=
3
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'f'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.f'
,
index
=
3
,
number
=
4
,
type
=
2
,
cpp_type
=
6
,
label
=
1
,
has_default_value
=
False
,
default_value
=
float
(
0
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
's'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.s'
,
index
=
4
,
number
=
5
,
type
=
9
,
cpp_type
=
9
,
label
=
1
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'ints'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.ints'
,
index
=
5
,
number
=
6
,
type
=
5
,
cpp_type
=
1
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'floats'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.floats'
,
index
=
6
,
number
=
7
,
type
=
2
,
cpp_type
=
6
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'strings'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.strings'
,
index
=
7
,
number
=
8
,
type
=
9
,
cpp_type
=
9
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'b'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.b'
,
index
=
8
,
number
=
10
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
False
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'bools'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.bools'
,
index
=
9
,
number
=
11
,
type
=
8
,
cpp_type
=
7
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'block_idx'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.block_idx'
,
index
=
10
,
number
=
12
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'l'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.l'
,
index
=
11
,
number
=
13
,
type
=
3
,
cpp_type
=
2
,
label
=
1
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'blocks_idx'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.blocks_idx'
,
index
=
12
,
number
=
14
,
type
=
5
,
cpp_type
=
1
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'longs'
,
full_name
=
'paddle.framework.proto.OpDesc.Attr.longs'
,
index
=
13
,
number
=
15
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
283
,
serialized_end
=
522
,
)
_OPDESC_VAR
=
_descriptor
.
Descriptor
(
name
=
'Var'
,
full_name
=
'paddle.framework.proto.OpDesc.Var'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'parameter'
,
full_name
=
'paddle.framework.proto.OpDesc.Var.parameter'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'arguments'
,
full_name
=
'paddle.framework.proto.OpDesc.Var.arguments'
,
index
=
1
,
number
=
2
,
type
=
9
,
cpp_type
=
9
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
524
,
serialized_end
=
567
,
)
_OPDESC
=
_descriptor
.
Descriptor
(
name
=
'OpDesc'
,
full_name
=
'paddle.framework.proto.OpDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpDesc.type'
,
index
=
0
,
number
=
3
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'inputs'
,
full_name
=
'paddle.framework.proto.OpDesc.inputs'
,
index
=
1
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'outputs'
,
full_name
=
'paddle.framework.proto.OpDesc.outputs'
,
index
=
2
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'attrs'
,
full_name
=
'paddle.framework.proto.OpDesc.attrs'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'is_target'
,
full_name
=
'paddle.framework.proto.OpDesc.is_target'
,
index
=
4
,
number
=
5
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[
_OPDESC_ATTR
,
_OPDESC_VAR
,
],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
75
,
serialized_end
=
567
,
)
_OPPROTO_VAR
=
_descriptor
.
Descriptor
(
name
=
'Var'
,
full_name
=
'paddle.framework.proto.OpProto.Var'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.OpProto.Var.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'comment'
,
full_name
=
'paddle.framework.proto.OpProto.Var.comment'
,
index
=
1
,
number
=
2
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'duplicable'
,
full_name
=
'paddle.framework.proto.OpProto.Var.duplicable'
,
index
=
2
,
number
=
3
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'intermediate'
,
full_name
=
'paddle.framework.proto.OpProto.Var.intermediate'
,
index
=
3
,
number
=
4
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'dispensable'
,
full_name
=
'paddle.framework.proto.OpProto.Var.dispensable'
,
index
=
4
,
number
=
5
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
772
,
serialized_end
=
892
,
)
_OPPROTO_ATTR
=
_descriptor
.
Descriptor
(
name
=
'Attr'
,
full_name
=
'paddle.framework.proto.OpProto.Attr'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.type'
,
index
=
1
,
number
=
2
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'comment'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.comment'
,
index
=
2
,
number
=
3
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'generated'
,
full_name
=
'paddle.framework.proto.OpProto.Attr.generated'
,
index
=
3
,
number
=
4
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
894
,
serialized_end
=
1005
,
)
_OPPROTO
=
_descriptor
.
Descriptor
(
name
=
'OpProto'
,
full_name
=
'paddle.framework.proto.OpProto'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.OpProto.type'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'inputs'
,
full_name
=
'paddle.framework.proto.OpProto.inputs'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'outputs'
,
full_name
=
'paddle.framework.proto.OpProto.outputs'
,
index
=
2
,
number
=
3
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'attrs'
,
full_name
=
'paddle.framework.proto.OpProto.attrs'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'comment'
,
full_name
=
'paddle.framework.proto.OpProto.comment'
,
index
=
4
,
number
=
5
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[
_OPPROTO_VAR
,
_OPPROTO_ATTR
,
],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
570
,
serialized_end
=
1005
,
)
_VARTYPE_TENSORDESC
=
_descriptor
.
Descriptor
(
name
=
'TensorDesc'
,
full_name
=
'paddle.framework.proto.VarType.TensorDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'data_type'
,
full_name
=
'paddle.framework.proto.VarType.TensorDesc.data_type'
,
index
=
0
,
number
=
1
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'dims'
,
full_name
=
'paddle.framework.proto.VarType.TensorDesc.dims'
,
index
=
1
,
number
=
2
,
type
=
3
,
cpp_type
=
2
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
1393
,
serialized_end
=
1476
,
)
_VARTYPE_LODTENSORDESC
=
_descriptor
.
Descriptor
(
name
=
'LoDTensorDesc'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'tensor'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorDesc.tensor'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'lod_level'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorDesc.lod_level'
,
index
=
1
,
number
=
2
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
True
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
1478
,
serialized_end
=
1575
,
)
_VARTYPE_LODTENSORARRAYDESC
=
_descriptor
.
Descriptor
(
name
=
'LoDTensorArrayDesc'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorArrayDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'tensor'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorArrayDesc.tensor'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'lod_level'
,
full_name
=
'paddle.framework.proto.VarType.LoDTensorArrayDesc.lod_level'
,
index
=
1
,
number
=
2
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
True
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
1577
,
serialized_end
=
1679
,
)
_VARTYPE_READERDESC
=
_descriptor
.
Descriptor
(
name
=
'ReaderDesc'
,
full_name
=
'paddle.framework.proto.VarType.ReaderDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'lod_tensor'
,
full_name
=
'paddle.framework.proto.VarType.ReaderDesc.lod_tensor'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
1681
,
serialized_end
=
1760
,
)
_VARTYPE_TUPLE
=
_descriptor
.
Descriptor
(
name
=
'Tuple'
,
full_name
=
'paddle.framework.proto.VarType.Tuple'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'element_type'
,
full_name
=
'paddle.framework.proto.VarType.Tuple.element_type'
,
index
=
0
,
number
=
1
,
type
=
14
,
cpp_type
=
8
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
1762
,
serialized_end
=
1829
,
)
_VARTYPE
=
_descriptor
.
Descriptor
(
name
=
'VarType'
,
full_name
=
'paddle.framework.proto.VarType'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.VarType.type'
,
index
=
0
,
number
=
1
,
type
=
14
,
cpp_type
=
8
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'selected_rows'
,
full_name
=
'paddle.framework.proto.VarType.selected_rows'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'lod_tensor'
,
full_name
=
'paddle.framework.proto.VarType.lod_tensor'
,
index
=
2
,
number
=
3
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'tensor_array'
,
full_name
=
'paddle.framework.proto.VarType.tensor_array'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'reader'
,
full_name
=
'paddle.framework.proto.VarType.reader'
,
index
=
4
,
number
=
5
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'tuple'
,
full_name
=
'paddle.framework.proto.VarType.tuple'
,
index
=
5
,
number
=
7
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[
_VARTYPE_TENSORDESC
,
_VARTYPE_LODTENSORDESC
,
_VARTYPE_LODTENSORARRAYDESC
,
_VARTYPE_READERDESC
,
_VARTYPE_TUPLE
,
],
enum_types
=
[
_VARTYPE_TYPE
,
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
1008
,
serialized_end
=
2122
,
)
_VARDESC
=
_descriptor
.
Descriptor
(
name
=
'VarDesc'
,
full_name
=
'paddle.framework.proto.VarDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'name'
,
full_name
=
'paddle.framework.proto.VarDesc.name'
,
index
=
0
,
number
=
1
,
type
=
9
,
cpp_type
=
9
,
label
=
2
,
has_default_value
=
False
,
default_value
=
_b
(
""
).
decode
(
'utf-8'
),
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'type'
,
full_name
=
'paddle.framework.proto.VarDesc.type'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
2
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'persistable'
,
full_name
=
'paddle.framework.proto.VarDesc.persistable'
,
index
=
2
,
number
=
3
,
type
=
8
,
cpp_type
=
7
,
label
=
1
,
has_default_value
=
True
,
default_value
=
False
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
2124
,
serialized_end
=
2222
,
)
_BLOCKDESC
=
_descriptor
.
Descriptor
(
name
=
'BlockDesc'
,
full_name
=
'paddle.framework.proto.BlockDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'idx'
,
full_name
=
'paddle.framework.proto.BlockDesc.idx'
,
index
=
0
,
number
=
1
,
type
=
5
,
cpp_type
=
1
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'parent_idx'
,
full_name
=
'paddle.framework.proto.BlockDesc.parent_idx'
,
index
=
1
,
number
=
2
,
type
=
5
,
cpp_type
=
1
,
label
=
2
,
has_default_value
=
False
,
default_value
=
0
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'vars'
,
full_name
=
'paddle.framework.proto.BlockDesc.vars'
,
index
=
2
,
number
=
3
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'ops'
,
full_name
=
'paddle.framework.proto.BlockDesc.ops'
,
index
=
3
,
number
=
4
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'forward_block_idx'
,
full_name
=
'paddle.framework.proto.BlockDesc.forward_block_idx'
,
index
=
4
,
number
=
5
,
type
=
5
,
cpp_type
=
1
,
label
=
1
,
has_default_value
=
True
,
default_value
=-
1
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
2225
,
serialized_end
=
2392
,
)
_PROGRAMDESC
=
_descriptor
.
Descriptor
(
name
=
'ProgramDesc'
,
full_name
=
'paddle.framework.proto.ProgramDesc'
,
filename
=
None
,
file
=
DESCRIPTOR
,
containing_type
=
None
,
fields
=
[
_descriptor
.
FieldDescriptor
(
name
=
'blocks'
,
full_name
=
'paddle.framework.proto.ProgramDesc.blocks'
,
index
=
0
,
number
=
1
,
type
=
11
,
cpp_type
=
10
,
label
=
3
,
has_default_value
=
False
,
default_value
=
[],
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
_descriptor
.
FieldDescriptor
(
name
=
'version'
,
full_name
=
'paddle.framework.proto.ProgramDesc.version'
,
index
=
1
,
number
=
2
,
type
=
11
,
cpp_type
=
10
,
label
=
1
,
has_default_value
=
False
,
default_value
=
None
,
message_type
=
None
,
enum_type
=
None
,
containing_type
=
None
,
is_extension
=
False
,
extension_scope
=
None
,
options
=
None
),
],
extensions
=
[
],
nested_types
=
[],
enum_types
=
[
],
options
=
None
,
is_extendable
=
False
,
syntax
=
'proto2'
,
extension_ranges
=
[],
oneofs
=
[
],
serialized_start
=
2394
,
serialized_end
=
2508
,
)
_OPDESC_ATTR
.
fields_by_name
[
'type'
].
enum_type
=
_ATTRTYPE
_OPDESC_ATTR
.
containing_type
=
_OPDESC
_OPDESC_VAR
.
containing_type
=
_OPDESC
_OPDESC
.
fields_by_name
[
'inputs'
].
message_type
=
_OPDESC_VAR
_OPDESC
.
fields_by_name
[
'outputs'
].
message_type
=
_OPDESC_VAR
_OPDESC
.
fields_by_name
[
'attrs'
].
message_type
=
_OPDESC_ATTR
_OPPROTO_VAR
.
containing_type
=
_OPPROTO
_OPPROTO_ATTR
.
fields_by_name
[
'type'
].
enum_type
=
_ATTRTYPE
_OPPROTO_ATTR
.
containing_type
=
_OPPROTO
_OPPROTO
.
fields_by_name
[
'inputs'
].
message_type
=
_OPPROTO_VAR
_OPPROTO
.
fields_by_name
[
'outputs'
].
message_type
=
_OPPROTO_VAR
_OPPROTO
.
fields_by_name
[
'attrs'
].
message_type
=
_OPPROTO_ATTR
_VARTYPE_TENSORDESC
.
fields_by_name
[
'data_type'
].
enum_type
=
_VARTYPE_TYPE
_VARTYPE_TENSORDESC
.
containing_type
=
_VARTYPE
_VARTYPE_LODTENSORDESC
.
fields_by_name
[
'tensor'
].
message_type
=
_VARTYPE_TENSORDESC
_VARTYPE_LODTENSORDESC
.
containing_type
=
_VARTYPE
_VARTYPE_LODTENSORARRAYDESC
.
fields_by_name
[
'tensor'
].
message_type
=
_VARTYPE_TENSORDESC
_VARTYPE_LODTENSORARRAYDESC
.
containing_type
=
_VARTYPE
_VARTYPE_READERDESC
.
fields_by_name
[
'lod_tensor'
].
message_type
=
_VARTYPE_LODTENSORDESC
_VARTYPE_READERDESC
.
containing_type
=
_VARTYPE
_VARTYPE_TUPLE
.
fields_by_name
[
'element_type'
].
enum_type
=
_VARTYPE_TYPE
_VARTYPE_TUPLE
.
containing_type
=
_VARTYPE
_VARTYPE
.
fields_by_name
[
'type'
].
enum_type
=
_VARTYPE_TYPE
_VARTYPE
.
fields_by_name
[
'selected_rows'
].
message_type
=
_VARTYPE_TENSORDESC
_VARTYPE
.
fields_by_name
[
'lod_tensor'
].
message_type
=
_VARTYPE_LODTENSORDESC
_VARTYPE
.
fields_by_name
[
'tensor_array'
].
message_type
=
_VARTYPE_LODTENSORARRAYDESC
_VARTYPE
.
fields_by_name
[
'reader'
].
message_type
=
_VARTYPE_READERDESC
_VARTYPE
.
fields_by_name
[
'tuple'
].
message_type
=
_VARTYPE_TUPLE
_VARTYPE_TYPE
.
containing_type
=
_VARTYPE
_VARDESC
.
fields_by_name
[
'type'
].
message_type
=
_VARTYPE
_BLOCKDESC
.
fields_by_name
[
'vars'
].
message_type
=
_VARDESC
_BLOCKDESC
.
fields_by_name
[
'ops'
].
message_type
=
_OPDESC
_PROGRAMDESC
.
fields_by_name
[
'blocks'
].
message_type
=
_BLOCKDESC
_PROGRAMDESC
.
fields_by_name
[
'version'
].
message_type
=
_VERSION
DESCRIPTOR
.
message_types_by_name
[
'Version'
]
=
_VERSION
DESCRIPTOR
.
message_types_by_name
[
'OpDesc'
]
=
_OPDESC
DESCRIPTOR
.
message_types_by_name
[
'OpProto'
]
=
_OPPROTO
DESCRIPTOR
.
message_types_by_name
[
'VarType'
]
=
_VARTYPE
DESCRIPTOR
.
message_types_by_name
[
'VarDesc'
]
=
_VARDESC
DESCRIPTOR
.
message_types_by_name
[
'BlockDesc'
]
=
_BLOCKDESC
DESCRIPTOR
.
message_types_by_name
[
'ProgramDesc'
]
=
_PROGRAMDESC
DESCRIPTOR
.
enum_types_by_name
[
'AttrType'
]
=
_ATTRTYPE
Version
=
_reflection
.
GeneratedProtocolMessageType
(
'Version'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VERSION
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.Version)
))
_sym_db
.
RegisterMessage
(
Version
)
OpDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'OpDesc'
,
(
_message
.
Message
,),
dict
(
Attr
=
_reflection
.
GeneratedProtocolMessageType
(
'Attr'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_OPDESC_ATTR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Attr)
))
,
Var
=
_reflection
.
GeneratedProtocolMessageType
(
'Var'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_OPDESC_VAR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc.Var)
))
,
DESCRIPTOR
=
_OPDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpDesc)
))
_sym_db
.
RegisterMessage
(
OpDesc
)
_sym_db
.
RegisterMessage
(
OpDesc
.
Attr
)
_sym_db
.
RegisterMessage
(
OpDesc
.
Var
)
OpProto
=
_reflection
.
GeneratedProtocolMessageType
(
'OpProto'
,
(
_message
.
Message
,),
dict
(
Var
=
_reflection
.
GeneratedProtocolMessageType
(
'Var'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_OPPROTO_VAR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Var)
))
,
Attr
=
_reflection
.
GeneratedProtocolMessageType
(
'Attr'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_OPPROTO_ATTR
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto.Attr)
))
,
DESCRIPTOR
=
_OPPROTO
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.OpProto)
))
_sym_db
.
RegisterMessage
(
OpProto
)
_sym_db
.
RegisterMessage
(
OpProto
.
Var
)
_sym_db
.
RegisterMessage
(
OpProto
.
Attr
)
VarType
=
_reflection
.
GeneratedProtocolMessageType
(
'VarType'
,
(
_message
.
Message
,),
dict
(
TensorDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'TensorDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VARTYPE_TENSORDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.TensorDesc)
))
,
LoDTensorDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'LoDTensorDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VARTYPE_LODTENSORDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorDesc)
))
,
LoDTensorArrayDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'LoDTensorArrayDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VARTYPE_LODTENSORARRAYDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.LoDTensorArrayDesc)
))
,
ReaderDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'ReaderDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VARTYPE_READERDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.ReaderDesc)
))
,
Tuple
=
_reflection
.
GeneratedProtocolMessageType
(
'Tuple'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VARTYPE_TUPLE
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType.Tuple)
))
,
DESCRIPTOR
=
_VARTYPE
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarType)
))
_sym_db
.
RegisterMessage
(
VarType
)
_sym_db
.
RegisterMessage
(
VarType
.
TensorDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
LoDTensorDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
LoDTensorArrayDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
ReaderDesc
)
_sym_db
.
RegisterMessage
(
VarType
.
Tuple
)
VarDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'VarDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_VARDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.VarDesc)
))
_sym_db
.
RegisterMessage
(
VarDesc
)
BlockDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'BlockDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_BLOCKDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.BlockDesc)
))
_sym_db
.
RegisterMessage
(
BlockDesc
)
ProgramDesc
=
_reflection
.
GeneratedProtocolMessageType
(
'ProgramDesc'
,
(
_message
.
Message
,),
dict
(
DESCRIPTOR
=
_PROGRAMDESC
,
__module__
=
'framework_pb2'
# @@protoc_insertion_point(class_scope:paddle.framework.proto.ProgramDesc)
))
_sym_db
.
RegisterMessage
(
ProgramDesc
)
DESCRIPTOR
.
has_options
=
True
DESCRIPTOR
.
_options
=
_descriptor
.
_ParseOptions
(
descriptor_pb2
.
FileOptions
(),
_b
(
'H
\003
'
))
# @@protoc_insertion_point(module_scope)
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录