PaddlePaddle / X2Paddle — commit 2a82fdeb

fix name and add ops

Authored Jul 05, 2019 by Macrobull
Parent: 9828c2c7

Showing 10 changed files with 486 additions and 434 deletions (+486 −434)
Files changed:

  onnx2fluid/README.md                          +7   −5
  onnx2fluid/examples/gen_some_samples.py       +54  −47
  onnx2fluid/examples/onnx_model_zoo.sh         +29  −29
  onnx2fluid/onnx2fluid/conversion.py           +80  −43
  onnx2fluid/onnx2fluid/onnx_utils.py           +31  −6
  onnx2fluid/onnx2fluid/symbolic.py             +241 −262
  onnx2fluid/onnx2fluid/torch_export_helper.py  +7   −2
  onnx2fluid/onnx2fluid/validation.py           +5   −2
  onnx2fluid/onnx2fluid/writer.py               +31  −37
  onnx2fluid/requirements.txt                   +1   −1
onnx2fluid/README.md

@@ -17,13 +17,13 @@ onnx2fluid supports converting ONNX models to PaddlePaddle models for inference, …
 Tested successfully with the following configurations:

 * python 3.5+
-* onnx == 1.4.0
+* onnx == 1.4.1
-* paddlepaddle == 1.3.0 (optional, for validation only)
+* paddlepaddle == 1.5.0 (optional, for validation only)

 Using [Anaconda](https://docs.anaconda.com/anaconda/install):

 ```shell
 conda install -c conda-forge onnx
-pip install paddlepaddle==1.3.0
+pip install paddlepaddle==1.5.0
 ```

 ## Try it out

@@ -49,6 +49,8 @@ onnx2fluid sample_1.onnx -t sample_1.npz

 ## Usage

+A subset of **ONNX opset 9+** operators is currently supported, matching PyTorch **1.0/1.1 (stable opset)**; see the [ONNX documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md) for compatibility details.
+
 onnx2fluid:

 ```shell

@@ -79,5 +81,5 @@ onnx2fluid.validate [-d] [-t test_data.npz] [-p 1e-3] /path/to/onnx/model.onnx

 ## References

-* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.4/api_cn/layers_cn.html)
+* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_cn/layers_cn.html)
-* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.4/api_guides/low_level/inference.html#id4)
+* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_guides/low_level/inference.html#id4)
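Besides the CLI shown in the README, the same conversion can be reached programmatically through the `convert` function whose updated signature appears in the conversion.py diff below. A minimal sketch; the paths are placeholders and only the parameter names come from that diff:

```python
# Hedged sketch: calling the converter API directly instead of the CLI.
from onnx2fluid.conversion import convert

convert(
    'sample_1.onnx',             # onnx_model_filename
    'sample_1.inference_model',  # save_dir for the generated code and weights
    model_basename='model.py',   # filename of the emitted fluid code
)
```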
onnx2fluid/examples/gen_some_samples.py

@@ -20,50 +20,56 @@ from onnx2fluid.torch_export_helper import export_onnx_with_validation
 prefix = 'sample_'
 idx = 0

-######### example: RNN ########
-#
-#class Model(nn.Module):
-#    def __init__(self):
-#        super(Model, self).__init__()
-#        self.rnn = nn.RNN(4, 6, 2)
-#
-#    def forward(self, x):
-#        y = x
-#        y, h = self.rnn(y)
-#        return y
-#
-#
-#model = Model()
-#model.eval()
-#xb = torch.rand((2, 3, 4))
-#yp = model(xb)
-#idx += 1
-#print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
-
-######### example: random ########
-#
-#class Model(nn.Module):
-#    def __init__(self):
-#        super(Model, self).__init__()
-#
-#    def forward(self, x):
-#        y = torch.rand((2, 3))  # + torch.rand_like(xb)
-#        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
-#        return y
-#
-#
-#model = Model()
-#model.eval()
-#xb = torch.rand((2, 3))
-#yp = model(xb)
-#idx += 1
-#print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
+######## example: RNN ########
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.gru = nn.GRU(4, 5, 3)
+        self.lstm = nn.LSTM(5, 6, 2)
+
+    def forward(self, x):
+        y = x
+        y, h = self.gru(y)
+        y, h = self.lstm(y)
+        return y
+
+
+model = Model()
+model.eval()
+xb = torch.rand((2, 3, 4))
+yp = model(xb)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb], prefix + str(idx),
+                            ['x'], ['y'],
+                            verbose=True, training=False)
+
+######## example: random ########
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        y = torch.rand((2, 3))  # + torch.rand_like(xb)
+        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
+        return y
+
+
+model = Model()
+model.eval()
+xb = torch.rand((2, 3))
+yp = model(xb)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb], prefix + str(idx),
+                            ['x'], ['y'],
+                            verbose=True, training=False)

 ######## example: fc ########

@@ -175,7 +181,7 @@ class Model(nn.Module):
         super(Model, self).__init__()
         self.conv = nn.Conv2d(3, 8, 3)
         self.batch_norm = nn.BatchNorm2d(8)
-        self.pool = nn.AdaptiveAvgPool2d(2)
+        self.pool = nn.AdaptiveAvgPool2d(1)

     def forward(self, x):
         y = x

@@ -215,9 +221,10 @@ export_onnx_with_validation(model, [xb],
 #yp = model(xb)
 #idx += 1
 #print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
+#export_onnx_with_validation(
+#    model, [xb], prefix + str(idx),
+#    ['x'], ['y'],
+#    verbose=True, training=False)

 ######## example: empty ########
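The newly activated RNN example above calls `model.eval()` before export. For intuition, here is a minimal sketch of exporting a bare GRU with `torch.onnx.export`, which `export_onnx_with_validation` presumably wraps (an assumption; only the helper's call signature appears in this diff):

```python
# Hedged sketch of the export step the example exercises.
# eval() freezes training-only behavior (dropout etc.) before tracing.
import torch
import torch.nn as nn

model = nn.GRU(4, 5, 3)     # input_size=4, hidden_size=5, num_layers=3
model.eval()
xb = torch.rand((2, 3, 4))  # (seq_len, batch, input_size)
torch.onnx.export(model, (xb,), 'sample_gru.onnx',
                  input_names=['x'], output_names=['y', 'h'])
```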
onnx2fluid/examples/onnx_model_zoo.sh

@@ -24,14 +24,14 @@ bvlc_alexnet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -54,7 +54,7 @@ bvlc_googlenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -77,7 +77,7 @@ bvlc_reference_caffenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -100,7 +100,7 @@ bvlc_reference_rcnn_ilsvrc13()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1

@@ -123,14 +123,14 @@ densenet121()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 fc6_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc6_1

@@ -153,7 +153,7 @@ emotion_ferplus()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0

@@ -176,14 +176,14 @@ inception_v1()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -206,14 +206,14 @@ inception_v2()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -236,7 +236,7 @@ mobilenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0

@@ -259,7 +259,7 @@ resnet18()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd

@@ -282,14 +282,14 @@ resnet50()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1

@@ -312,7 +312,7 @@ resnet100_arcface()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data fc1

@@ -335,7 +335,7 @@ resnet101_duc()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data seg_loss

@@ -358,7 +358,7 @@ resnet152()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd

@@ -381,7 +381,7 @@ shufflenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1

@@ -404,7 +404,7 @@ squeezenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 softmaxout_1

@@ -427,7 +427,7 @@ squeezenet1v1()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0

@@ -448,10 +448,10 @@ ssd()
     rm -rf "$bn_tar/"
     echo "extracting ..."
     mkdir "$bn_tar"
-    tar xf "$fn_tar" -C "$bn_tar"/
+    tar xf "$fn_tar" -C "$bn_tar/"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" image bboxes,labels,scores

@@ -474,7 +474,7 @@ tiny_yolov2()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" image grid

@@ -497,7 +497,7 @@ vgg16bn()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd

@@ -520,7 +520,7 @@ vgg19()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -543,7 +543,7 @@ yolov3()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -x #
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0

@@ -566,7 +566,7 @@ zfnet512()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
    do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
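A note on the change repeated throughout this script: `"$bn_tar"/*` and `"$bn_tar/"*` expand identically in POSIX shells, since the glob `*` sits outside the quotes either way; the commit merely moves the slash inside the quoted region, matching the new `tar xf ... -C "$bn_tar/"` line, so the edit is stylistic rather than behavioral.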
onnx2fluid/onnx2fluid/conversion.py

@@ -17,6 +17,22 @@ __all__ = [
 DEFAULT_ONNX_OPSET_VERSION = 9


+def make_var_name(name):
+    """
+    make a valid variable name in Python code and filename in filesystem
+    """
+
+    if name == '':
+        return '_'
+    if name[0].isdigit():
+        return 'var_' + name
+    for s in ' \\|/:.-':
+        name = name.replace(s, '_')
+    if name.startswith('_'):
+        name = 'var' + name
+    return name
+
+
 def convert(onnx_model_filename,
             save_dir,
             model_basename='model.py',

@@ -30,6 +46,12 @@ def convert(onnx_model_filename,
     convert an ONNX model to Paddle fluid Python code and desc pb
     """

+    assert isinstance(onnx_model_filename, str)
+    assert isinstance(save_dir, str)
+    assert isinstance(model_basename, str)
+    assert isinstance(model_func_name, str)
+    assert onnx_opset_version is None or isinstance(onnx_opset_version, int)
+
     import onnx

     from onnx.checker import ValidationError

@@ -41,7 +63,6 @@ def convert(onnx_model_filename,
     from .onnx_utils import inferred_model_value_info
     from .onnx_utils import polish_model
     from .writer import Program, Writer
-    from .writer import make_var_name

     logger = logging.getLogger('convert')

@@ -88,17 +109,21 @@ def convert(onnx_model_filename,
     fluid_writer = Writer()

     # model components
-    #graph_name = onnx_graph.name
-    graph_inputs = [value.name for value in onnx_graph.input]
-    graph_outputs = [value.name for value in onnx_graph.output]
-    graph_params = []
-    graph_value_infos = inferred_model_value_info(onnx_model)
+    inp_vars = [make_var_name(value.name) for value in onnx_graph.input]
+    out_vars = [make_var_name(value.name) for value in onnx_graph.output]
+    par_vars = []
+    value_infos = inferred_model_value_info(onnx_model)
+    value_infos = {
+        make_var_name(key): value
+        for key, value in value_infos.items()
+    }

     # prepare additional value_info
     # for weights
     for name, weight in graph_weights(onnx_graph):
-        value_info = graph_value_infos[name]
-        value_info['embeded_as'] = []
+        var_name = make_var_name(name)
+        value_info = value_infos[var_name]
+        value_info['embedded_as'] = []
         value_info['get_weight'] = (lambda w: lambda: w.tolist())(
             weight)  # lazy getter
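One idiom in the hunk above deserves a second look: the lazy weight getter. A standalone sketch of why the double lambda is there (plain Python closure semantics, not repo code):

```python
# The inner lambda defers the potentially large tolist() conversion; the
# outer, immediately applied lambda pins the current `weight`, so each
# getter captures its own value rather than the shared loop variable.
import numpy as np

getters = []
for weight in [np.zeros(2), np.ones(2)]:
    getters.append((lambda w: lambda: w.tolist())(weight))

print(getters[0]())  # [0.0, 0.0] — not [1.0, 1.0]
print(getters[1]())  # [1.0, 1.0]
```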
@@ -108,19 +133,23 @@ def convert(onnx_model_filename,
     topo = 'forward'
     for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph,
                                                                    topo=topo):
-        logger.debug('translating op %s %s::%s ...', name, domain, op_type)
+        op_name = make_var_name(name)
+        inputs = [make_var_name(val) for val in inputs]
+        outputs = [make_var_name(val) for val in outputs]
+        logger.debug('translating op %s(%s) %s::%s ...', name, op_name, domain,
+                     op_type)
         if domain == DEFAULT_OP_DOMAIN:
             domain = ''
         try:
             fluid_writer.emit_op(
                 fluid_program,
-                name,
+                op_name,
                 domain,
                 op_type,
                 inputs,
                 outputs,
                 attrs,
-                graph_value_infos,
+                value_infos,
                 embed_params=embed_params,
             )
         except BaseException as e:

@@ -133,17 +162,16 @@ def convert(onnx_model_filename,
             len(fluid_program.op_descs))

     # type-shape info copy
-    for name, value_info in graph_value_infos.items():
-        var_name = make_var_name(name)
+    for var_name, value_info in value_infos.items():
         fluid_program.VarTypeShapeInfo(var_name, value_info,
                                        remove_batch=False)  #
-    bad_var_names = []
+    bad_vars = []
     for var_name, var_desc in fluid_program.var_descs.items():
         if not var_desc.type.lod_tensor.HasField('tensor'):
-            bad_var_names.append(var_name)
-    if len(bad_var_names) > 0:
+            bad_vars.append(var_name)
+    if len(bad_vars) > 0:
         logger.warning('type-shape not infered for var %s ...',
-                       ', '.join(bad_var_names[:5]))
+                       ', '.join(bad_vars[:5]))
         logger.warning('this causes little problem for PaddlePaddle, '
                        'but Paddle Mobile may not infer correctly')
         logger.warning('please consider running validation with -i '

@@ -151,40 +179,41 @@ def convert(onnx_model_filename,

     # weight writer
     for name, weight in graph_weights(onnx_graph):
-        graph_params.append(name)
-        value_info = graph_value_infos[name]
-        var_names = value_info.get('embeded_as', [])
-        if var_names:
-            if len(var_names) > 1:
+        var_name = make_var_name(name)
+        value_info = value_infos[var_name]
+        par_vars.append(var_name)
+        embedded_names = value_info.get('embedded_as', [])
+        if embedded_names:
+            if len(embedded_names) > 1:
                 logger.info(
                     'weight %s is shared between ops, more disk space will be consumed',
                     name)
             logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
-                         weight.dtype, weight.size, weight.nbytes, var_names)
-            for var_name in var_names:  # multiple references
-                fluid_writer.write_weight(
-                    weight, shutil.os.path.join(save_dir, var_name))
+                         weight.dtype, weight.size, weight.nbytes,
+                         embedded_names)
+            for embedded_name in embedded_names:  # multiple references
+                fluid_writer.write_weight(
+                    weight, shutil.os.path.join(save_dir, embedded_name))
         else:
             logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
                          weight.dtype, weight.size, weight.nbytes,
-                         make_var_name(name))
-            fluid_writer.write_weight(
-                weight, shutil.os.path.join(save_dir, make_var_name(name)))
-        fluid_writer.emit_param(fluid_program, name, value_info)
+                         var_name)
+            fluid_writer.write_weight(
+                weight, shutil.os.path.join(save_dir, var_name))
+        fluid_writer.emit_param(fluid_program, var_name, value_info)
     param_codes = fluid_program.codes
     fluid_program.codes = []
-    logger.info('%d weights converted', len(graph_params))
+    logger.info('%d weights converted', len(par_vars))

     # input writer
     external_inputs = []
-    for name in graph_inputs:
-        if name not in graph_params:
-            value_info = graph_value_infos[name]
+    for var_name in inp_vars:
+        if var_name not in par_vars:
+            value_info = value_infos[var_name]
             assert value_info['external']
-            external_inputs.append(name)
+            external_inputs.append(var_name)
     fluid_writer.emit_inputs(fluid_program,
                              external_inputs,
-                             graph_value_infos,
+                             value_infos,
                              remove_batch=False)  # TODO:
     input_codes = fluid_program.codes
     fluid_program.codes = []

@@ -192,11 +221,11 @@ def convert(onnx_model_filename,
     # output writer
     external_outputs = []
-    for name in graph_outputs:
-        if name not in graph_params:
-            value_info = graph_value_infos[name]
+    for var_name in out_vars:
+        if var_name not in par_vars:
+            value_info = value_infos[var_name]
             assert value_info['external']
-            external_outputs.append(name)
+            external_outputs.append(var_name)
     fluid_writer.emit_outputs(fluid_program, external_outputs)
     output_codes = [''] + fluid_program.codes  # add an empty line
     fluid_program.codes = []

@@ -204,10 +233,18 @@ def convert(onnx_model_filename,
     # code generation
     header_codes = fluid_writer.header_code(
-        model_func_name, 'From: {}'.format(onnx_model_filename))
+        model_func_name,
+        'From: {}'.format(onnx_model_filename),
+    )
     code_filename = shutil.os.path.join(save_dir, model_basename)
-    fluid_writer.write_code_file(code_filename, header_codes, input_codes,
-                                 param_codes, op_codes, output_codes)
+    fluid_writer.write_code_file(
+        code_filename,
+        header_codes,
+        input_codes,
+        param_codes,
+        op_codes,
+        output_codes,
+    )
     logger.info('code saved to %s, factory function: %s', code_filename,
                 model_func_name)
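The naming rules that `make_var_name` enforces are easiest to see on a few inputs. A standalone copy for illustration; the behavior is taken directly from the function added above:

```python
# Sanitize an ONNX value name into a valid Python identifier / filename.
def make_var_name(name):
    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:.-':
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name

print(make_var_name(''))                 # '_'
print(make_var_name('0_bias'))           # 'var_0_bias'  (leading digit)
print(make_var_name('conv1/weights:0'))  # 'conv1_weights_0'
print(make_var_name('_hidden'))          # 'var_hidden'
```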
onnx2fluid/onnx2fluid/onnx_utils.py

@@ -87,6 +87,9 @@ def get_attribute_value2(attr):
     get_attribute_value enhanced
     """

+    assert isinstance(attr, onnx.AttributeProto), 'attr is not a AttributeProto instance'
+
     if attr.type == onnx.AttributeProto.TENSOR:
         dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
         data = attr.t.raw_data

@@ -106,6 +109,9 @@ def tensor_dtype(tensor):
     get ONNX tensor in np.dtype
     """

+    assert isinstance(tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
+
     return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]

@@ -114,6 +120,9 @@ def tensor_shape(tensor):
     get ONNX tensor shape
     """

+    assert isinstance(tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
+
     return tuple([dim.dim_value for dim in tensor.type.tensor_type.shape.dim])

@@ -122,6 +131,8 @@ def node_attrs(node):
     convert ONNX node attributes to dict
     """

+    assert isinstance(node, onnx.NodeProto), 'node is not a NodeProto instance'
+
     return {attr.name: get_attribute_value2(attr)
             for attr in node.attribute}  # dict

@@ -224,9 +235,8 @@ def graph_ops(graph, topo='default'):
     generator for ONNX node graph with given topology
     """

-    if not isinstance(graph, onnx.GraphProto):
-        logger.error('graph is not a GraphProto instance')
-        return
+    assert isinstance(graph,
+                      onnx.GraphProto), 'graph is not a GraphProto instance'

     return node_iter(graph.node, node_topo(graph.node, topo))

@@ -236,9 +246,8 @@ def graph_weights(graph):
     generator for weights of an ONNX model
     """

-    if not isinstance(graph, onnx.GraphProto):
-        logger.error('graph is not a GraphProto instance')
-        return
+    assert isinstance(graph,
+                      onnx.GraphProto), 'graph is not a GraphProto instance'

     for initializer in graph.initializer:
         name = initializer.name

@@ -251,6 +260,9 @@ def inferred_model_value_info(model):
     collect value/type info for an ONNX model
     """

+    assert isinstance(model, onnx.ModelProto), 'model is not a ModelProto instance'
+
     model = infer_shapes(model)
     graph = model.graph
     value_info = Dict()

@@ -353,6 +365,10 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
     """
     skip ops can be bypassed for inference
     """
+
+    assert isinstance(model, onnx.ModelProto), 'model is not a ModelProto instance'
+
     if op_list is None:
         op_list = ('Dropout', 'Identity')

@@ -415,6 +431,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
     strip weights for inference
     """

+    assert isinstance(model, onnx.ModelProto), 'model is not a ModelProto instance'
+
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
     out_names = [val.name for val in model.graph.output]

@@ -456,6 +475,9 @@ def optimize_model_cast(model):
     strip cascade and unecessary onnx::Cast-9:
     """

+    assert isinstance(model, onnx.ModelProto), 'model is not a ModelProto instance'
+
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
     value_info = inferred_model_value_info(model)

@@ -513,6 +535,9 @@ def optimize_model_slice(model):
     strip cascade and unecessary onnx::Slice-1:9
     """

+    assert isinstance(model, onnx.ModelProto), 'model is not a ModelProto instance'
+
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
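The recurring change in this file swaps silent logging for hard failures. A minimal sketch of the two behaviors (standalone illustration, not copied from the repo):

```python
# Before: log and silently return None on a bad argument, letting the
# caller crash later. After: fail fast with a descriptive AssertionError.
import logging
import onnx

logger = logging.getLogger('onnx_utils')

def graph_ops_old(graph):
    if not isinstance(graph, onnx.GraphProto):
        logger.error('graph is not a GraphProto instance')
        return  # caller silently receives None

def graph_ops_new(graph):
    assert isinstance(graph, onnx.GraphProto), 'graph is not a GraphProto instance'
```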
onnx2fluid/onnx2fluid/symbolic.py

@@ -50,6 +50,9 @@ DEFAULT_OP_MAPPING = {
         dict(), None, None, False],
     ## unary ops ##
     'Abs': ['abs', ['X'], ['Out']],
+    'Acos': ['acos', ['X'], ['Out']],
+    'Asin': ['asin', ['X'], ['Out']],
+    'Atan': ['atan', ['X'], ['Out']],
     'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')],
     'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')],
     'Ceil': ['ceil', ['X'], ['Out']],
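The row format of DEFAULT_OP_MAPPING is not spelled out in the diff; from the entries above it appears to be [fluid_op, fluid_input_slots, fluid_output_slots, optional default attrs, ...]. A small self-contained restatement (my reading, not repo code):

```python
# Each row maps an ONNX op type to a fluid layer and its I/O slot names.
DEFAULT_OP_MAPPING = {
    'Abs':    ['abs',    ['X'], ['Out']],                     # unary, no attrs
    'Atan':   ['atan',   ['X'], ['Out']],                     # op added by this commit
    'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')],  # with attr defaults
}
```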
...
@@ -144,52 +147,36 @@ DEFAULT_IOA_CONSTRAINTS = {
...
@@ -144,52 +147,36 @@ DEFAULT_IOA_CONSTRAINTS = {
}
}
def
_make_var_name
(
name
):
def
_dtype
(
value_infos
,
name
):
"""
return
_np
.
dtype
(
value_infos
[
name
][
'dtype'
])
make a valid variable name in Python code and in filesystem
"""
if
name
==
''
:
return
'_'
if
name
[
0
].
isdigit
():
return
'var_'
+
name
for
s
in
'
\\
|/:-'
:
#
name
=
name
.
replace
(
s
,
'_'
)
if
name
.
startswith
(
'_'
):
name
=
'var'
+
name
return
name
def
_dtype
(
value_infos
,
val_name
):
def
_dtype_or_none
(
value_infos
,
name
):
return
_np
.
dtype
(
value_infos
[
val_name
][
'dtype'
])
if
name
not
in
value_infos
:
def
_dtype_or_none
(
value_infos
,
val_name
):
if
val_name
not
in
value_infos
:
return
None
return
None
value_info
=
value_infos
[
val_
name
]
value_info
=
value_infos
[
name
]
if
'dtype'
not
in
value_info
:
if
'dtype'
not
in
value_info
:
return
None
return
None
return
_np
.
dtype
(
value_info
[
'dtype'
])
return
_np
.
dtype
(
value_info
[
'dtype'
])
def
_shape
(
value_infos
,
val_
name
):
def
_shape
(
value_infos
,
name
):
return
list
(
value_infos
[
val_
name
][
'shape'
])
return
list
(
value_infos
[
name
][
'shape'
])
def
_shape_or_none
(
value_infos
,
val_
name
):
def
_shape_or_none
(
value_infos
,
name
):
if
val_
name
not
in
value_infos
:
if
name
not
in
value_infos
:
return
None
return
None
value_info
=
value_infos
[
val_
name
]
value_info
=
value_infos
[
name
]
if
'shape'
not
in
value_info
:
if
'shape'
not
in
value_info
:
return
None
return
None
return
list
(
value_info
[
'shape'
])
return
list
(
value_info
[
'shape'
])
def
_const_weight_or_none
(
value_infos
,
val_
name
):
def
_const_weight_or_none
(
value_infos
,
name
):
if
val_
name
not
in
value_infos
:
if
name
not
in
value_infos
:
return
None
return
None
value_info
=
value_infos
[
val_
name
]
value_info
=
value_infos
[
name
]
const_value
=
value_info
.
get
(
'const_value'
,
None
)
const_value
=
value_info
.
get
(
'const_value'
,
None
)
if
const_value
is
not
None
:
if
const_value
is
not
None
:
return
const_value
return
const_value
...
@@ -199,11 +186,11 @@ def _const_weight_or_none(value_infos, val_name):
...
@@ -199,11 +186,11 @@ def _const_weight_or_none(value_infos, val_name):
return
None
return
None
def
_check_embeddable
(
value_infos
,
*
val_
names
):
def
_check_embeddable
(
value_infos
,
*
names
):
keyword
=
'get_weight'
keyword
=
'get_weight'
for
val_name
in
val_
names
:
for
name
in
names
:
if
keyword
not
in
value_infos
[
val_
name
]:
if
keyword
not
in
value_infos
[
name
]:
_logger
.
warning
(
'parameter %s not embeddable'
,
val_
name
)
_logger
.
warning
(
'parameter %s not embeddable'
,
name
)
return
False
return
False
return
True
return
True
...
@@ -239,12 +226,10 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -239,12 +226,10 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_attrs
=
default_attrs
.
copy
()
fluid_attrs
=
default_attrs
.
copy
()
fluid_attrs
.
update
(
mapped_attrs
)
# as new attrs
fluid_attrs
.
update
(
mapped_attrs
)
# as new attrs
val_inps
=
inputs
if
input_perm
is
None
else
map
(
inputs
.
__getitem__
,
var_inps
=
inputs
if
input_perm
is
None
else
list
(
input_perm
)
map
(
inputs
.
__getitem__
,
input_perm
))
val_outs
=
outputs
if
output_perm
is
None
else
map
(
outputs
.
__getitem__
,
var_outs
=
outputs
if
output_perm
is
None
else
list
(
output_perm
)
map
(
outputs
.
__getitem__
,
output_perm
))
var_inps
=
[
_make_var_name
(
val
)
for
val
in
val_inps
]
var_outs
=
[
_make_var_name
(
val
)
for
val
in
val_outs
]
arg_name
=
', name={}'
.
format
(
arg_name
=
', name={}'
.
format
(
repr
(
name
))
if
fill_name_field
and
name
else
''
repr
(
name
))
if
fill_name_field
and
name
else
''
arg_attrs
=
[
arg_attrs
=
[
...
@@ -277,9 +262,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -277,9 +262,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
def
_assign
(
prog
,
mapping
):
def
_assign
(
prog
,
mapping
):
fluid_op
=
'assign'
fluid_op
=
'assign'
for
val_dst
,
val_src
in
mapping
.
items
():
for
var_dst
,
var_src
in
mapping
.
items
():
var_dst
=
_make_var_name
(
val_dst
)
var_src
=
_make_var_name
(
val_src
)
prog
.
Code
(
'{} = {} # assign'
.
format
(
var_dst
,
var_src
))
prog
.
Code
(
'{} = {} # assign'
.
format
(
var_dst
,
var_src
))
# prog.Code('{} = layers.{}({})'
# prog.Code('{} = layers.{}({})'
# .format(var_dst,
# .format(var_dst,
...
@@ -295,18 +278,18 @@ def _assign(prog, mapping):
...
@@ -295,18 +278,18 @@ def _assign(prog, mapping):
)
)
def
_zeros_like
(
prog
,
va
l_ref
,
val
_out
,
value_infos
):
def
_zeros_like
(
prog
,
va
r_ref
,
var
_out
,
value_infos
):
prog
.
Op
(
prog
.
Op
(
''
,
''
,
'Sub'
,
'Sub'
,
[
va
l_ref
,
val
_ref
],
[
va
r_ref
,
var
_ref
],
[
va
l_out
],
# val
[
va
r_out
],
{
'axis'
:
0
},
{
'axis'
:
0
},
value_infos
,
value_infos
,
)
)
def
_pad_if_asymmetric
(
prog
,
pads
,
va
l
_name
,
value_infos
):
# pads: SSEE
def
_pad_if_asymmetric
(
prog
,
pads
,
va
r
_name
,
value_infos
):
# pads: SSEE
assert
len
(
pads
)
&
1
==
0
assert
len
(
pads
)
&
1
==
0
ndims
=
len
(
pads
)
//
2
ndims
=
len
(
pads
)
//
2
symmetric
=
True
symmetric
=
True
...
@@ -315,36 +298,29 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
...
@@ -315,36 +298,29 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos): # pads: SSEE
symmetric
=
False
symmetric
=
False
break
break
if
symmetric
:
if
symmetric
:
return
pads
[:
ndims
],
va
l
_name
return
pads
[:
ndims
],
va
r
_name
va
l_padded
=
val
_name
+
'_padded'
# explicit variable
va
r_padded
=
var
_name
+
'_padded'
# explicit variable
prog
.
Op
(
prog
.
Op
(
''
,
''
,
'Pad'
,
'Pad'
,
[
va
l
_name
],
[
va
r
_name
],
[
va
l_padded
],
# val
[
va
r_padded
],
{
{
'mode'
:
'constant'
,
'mode'
:
'constant'
,
'value'
:
0.
,
'value'
:
0.
,
'pads'
:
pads
,
'pads'
:
pads
,
},
},
value_infos
=
value_infos
,
value_infos
=
value_infos
,
name
=
va
l
_padded
,
name
=
va
r
_padded
,
)
)
return
[
0
]
*
ndims
,
va
l
_padded
return
[
0
]
*
ndims
,
va
r
_padded
def
_adaptive_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
name
=
''
):
def
_adaptive_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
name
=
''
):
# I/O
# I/O
val_x
,
=
inputs
var_x
,
=
inputs
val_y
,
=
outputs
[:
1
]
var_y
,
var_indices
=
(
outputs
+
[
None
]
*
1
)[:
2
]
var_x
=
_make_var_name
(
val_x
)
var_y
=
_make_var_name
(
val_y
)
has_indices
=
len
(
outputs
)
>
1
if
has_indices
:
val_indices
=
outputs
[
1
]
var_indices
=
_make_var_name
(
val_indices
)
# interpretation
# interpretation
pool_size
=
attrs
[
'output_size'
]
# required
pool_size
=
attrs
[
'output_size'
]
# required
...
@@ -361,28 +337,28 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
...
@@ -361,28 +337,28 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
', pool_type={}'
', pool_type={}'
'{})'
.
format
(
'{})'
.
format
(
var_y
,
var_y
,
', {}'
.
format
(
var_indices
)
if
has
_indices
else
''
,
', {}'
.
format
(
var_indices
)
if
var
_indices
else
''
,
fluid_op
,
fluid_op
,
var_x
,
var_x
,
# attrs
# attrs
has_indices
,
bool
(
var_indices
)
,
pool_size
,
pool_size
,
repr
(
pool_type
),
repr
(
pool_type
),
name_attr
,
name_attr
,
))
))
fluid_op
=
'pool{}d'
.
format
(
poolnd
)
fluid_op
=
'pool{}d'
.
format
(
poolnd
)
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
if
has
_indices
:
if
var
_indices
:
prog
.
VarDesc
(
var_indices
)
prog
.
VarDesc
(
var_indices
)
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
]
+
([
var_indices
]
if
has
_indices
else
[]),
'Out'
,
'Indices'
),
([
var_y
]
+
([
var_indices
]
if
var
_indices
else
[]),
'Out'
,
'Indices'
),
{
{
'global_pooling'
:
False
,
'global_pooling'
:
False
,
'adaptive'
:
True
,
'adaptive'
:
True
,
'exclusive'
:
True
,
'exclusive'
:
True
,
'require_index'
:
has_indices
,
'require_index'
:
bool
(
var_indices
)
,
'pooling_type'
:
pool_type
,
'pooling_type'
:
pool_type
,
'ksize'
:
pool_size
,
'ksize'
:
pool_size
,
},
},
...
@@ -391,14 +367,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
...
@@ -391,14 +367,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
def
_global_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
def
_global_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
# I/O
# I/O
val_x
,
=
inputs
var_x
,
=
inputs
val_y
,
=
outputs
var_y
,
=
outputs
var_x
=
_make_var_name
(
val_x
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
input_shape
=
_shape_or_none
(
value_infos
,
va
l
_x
)
input_shape
=
_shape_or_none
(
value_infos
,
va
r
_x
)
output_shape
=
_shape_or_none
(
value_infos
,
va
l
_y
)
output_shape
=
_shape_or_none
(
value_infos
,
va
r
_y
)
assert
input_shape
is
not
None
or
output_shape
is
not
None
,
'poolnd not inferred'
# NC...
assert
input_shape
is
not
None
or
output_shape
is
not
None
,
'poolnd not inferred'
# NC...
if
input_shape
is
not
None
:
if
input_shape
is
not
None
:
poolnd
=
len
(
input_shape
)
-
2
# NC...
poolnd
=
len
(
input_shape
)
-
2
# NC...
...
@@ -436,14 +410,8 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -436,14 +410,8 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
def
_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
def
_pool
(
prog
,
pool_type
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
# I/O
# I/O
val_x
,
=
inputs
var_x
,
=
inputs
val_y
,
=
outputs
[:
1
]
var_y
,
var_indices
=
(
outputs
+
[
None
]
*
1
)[:
2
]
var_y
=
_make_var_name
(
val_y
)
has_indices
=
len
(
outputs
)
>
1
if
has_indices
:
val_indices
=
outputs
[
1
]
var_indices
=
_make_var_name
(
val_indices
)
# interpretation
# interpretation
assert
attrs
.
get
(
assert
attrs
.
get
(
...
@@ -457,8 +425,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -457,8 +425,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
poolnd
)
# optional
strides
=
attrs
.
get
(
'strides'
,
[
1
]
*
poolnd
)
# optional
ceil_mode
=
bool
(
attrs
.
get
(
'ceil_mode'
,
0
))
# optional
ceil_mode
=
bool
(
attrs
.
get
(
'ceil_mode'
,
0
))
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
(
poolnd
*
2
))
# optional
pads
=
attrs
.
get
(
'pads'
,
[
0
]
*
(
poolnd
*
2
))
# optional
paddings
,
val_x
=
_pad_if_asymmetric
(
prog
,
pads
,
val_x
,
value_infos
)
paddings
,
var_x
=
_pad_if_asymmetric
(
prog
,
pads
,
var_x
,
value_infos
)
var_x
=
_make_var_name
(
val_x
)
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
# generation
# generation
...
@@ -481,17 +448,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -481,17 +448,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
name_attr
,
name_attr
,
))
))
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
if
has
_indices
:
if
var
_indices
:
prog
.
VarDesc
(
var_indices
)
prog
.
VarDesc
(
var_indices
)
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
([
var_x
],
'X'
),
([
var_x
],
'X'
),
([
var_y
]
+
([
var_indices
]
if
has
_indices
else
[]),
'Out'
,
'Indices'
),
([
var_y
]
+
([
var_indices
]
if
var
_indices
else
[]),
'Out'
,
'Indices'
),
{
{
'global_pooling'
:
False
,
'global_pooling'
:
False
,
'adaptive'
:
False
,
'adaptive'
:
False
,
'exclusive'
:
True
,
'exclusive'
:
True
,
'require_index'
:
has_indices
,
'require_index'
:
bool
(
var_indices
)
,
'pooling_type'
:
pool_type
,
'pooling_type'
:
pool_type
,
'ksize'
:
pool_size
,
'ksize'
:
pool_size
,
'strides'
:
strides
,
'strides'
:
strides
,
...
@@ -503,11 +470,8 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
...
@@ -503,11 +470,8 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
def
_roi_pool
(
prog
,
fluid_op
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
):
def
_roi_pool
(
prog
,
fluid_op
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
):
# I/O
# I/O
val_x
,
val_rois
=
inputs
var_x
,
var_rois
=
inputs
val_y
,
=
outputs
var_y
,
=
outputs
var_x
=
_make_var_name
(
val_x
)
var_rois
=
_make_var_name
(
val_rois
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
spatial_scale
=
attrs
[
'spatial_scale'
]
# required
...
@@ -536,7 +500,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
...
@@ -536,7 +500,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
'{})'
.
format
(
'{})'
.
format
(
var_y
,
var_y
,
fluid_op
,
fluid_op
,
va
l
_x
,
va
r
_x
,
var_rois
,
var_rois
,
# attrs
# attrs
spatial_scale
,
spatial_scale
,
...
@@ -546,7 +510,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
...
@@ -546,7 +510,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
))
))
prog
.
VarDesc
(
var_y
)
prog
.
VarDesc
(
var_y
)
if
is_max_pool
:
if
is_max_pool
:
var_argmax
=
_make_var_name
(
name
+
'.argmax'
)
# hidden variable
var_argmax
=
name
+
'.argmax'
# hidden variable
prog
.
VarDesc
(
var_argmax
)
prog
.
VarDesc
(
var_argmax
)
prog
.
OpDesc
(
prog
.
OpDesc
(
fluid_op
,
fluid_op
,
...
@@ -558,19 +522,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
...
@@ -558,19 +522,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
def
_interpolate
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
def
_interpolate
(
prog
,
inputs
,
outputs
,
attrs
,
value_infos
,
name
=
''
):
# I/O
# I/O
val_x
,
val_scales
=
inputs
var_x
,
var_scales
=
inputs
val_y
,
=
outputs
var_y
,
=
outputs
var_x
=
_make_var_name
(
val_x
)
var_y
=
_make_var_name
(
val_y
)
# interpretation
# interpretation
# output shape
# output shape
out_shape_
=
_shape_or_none
(
value_infos
,
va
l
_y
)
out_shape_
=
_shape_or_none
(
value_infos
,
va
r
_y
)
if
out_shape_
is
not
None
:
if
out_shape_
is
not
None
:
assert
len
(
out_shape_
)
==
4
,
'only 4-D Tensor as X and Y supported'
assert
len
(
out_shape_
)
==
4
,
'only 4-D Tensor as X and Y supported'
out_shape_
=
out_shape_
[
2
:]
out_shape_
=
out_shape_
[
2
:]
# try scales
# try scales
scales
=
_const_weight_or_none
(
value_infos
,
va
l
_scales
)
scales
=
_const_weight_or_none
(
value_infos
,
va
r
_scales
)
if
scales
is
not
None
:
if
scales
is
not
None
:
assert
len
(
scales
)
==
4
,
'only 4-D Tensor as X and Y supported'
assert
len
(
scales
)
==
4
,
'only 4-D Tensor as X and Y supported'
assert
scales
[
0
]
==
1
and
scales
[
assert
scales
[
0
]
==
1
and
scales
[
...
@@ -585,7 +547,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
...
@@ -585,7 +547,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
else
:
else
:
out_shape
=
None
out_shape
=
None
if
out_shape_
is
None
:
if
out_shape_
is
None
:
in_shape
=
_shape_or_none
(
value_infos
,
va
l
_x
)
in_shape
=
_shape_or_none
(
value_infos
,
va
r
_x
)
assert
in_shape
is
not
None
,
'out_shape required but not inferrable'
assert
in_shape
is
not
None
,
'out_shape required but not inferrable'
assert
len
(
in_shape
)
==
4
,
'only 4-D Tensor as X and Y supported'
assert
len
(
in_shape
)
==
4
,
'only 4-D Tensor as X and Y supported'
out_shape_
=
[
in_shape
[
2
]
*
scale
,
in_shape
[
3
]
*
scale
]
out_shape_
=
[
in_shape
[
2
]
*
scale
,
in_shape
[
3
]
*
scale
]
...
@@ -642,10 +604,8 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
...
@@ -642,10 +604,8 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
"""
"""
# I/O
# I/O
val_theta
,
=
inputs
var_theta
,
=
inputs
val_grid
,
=
outputs
var_grid
,
=
outputs
var_theta
=
_make_var_name
(
val_theta
)
var_grid
=
_make_var_name
(
val_grid
)
# interpretation
# interpretation
fluid_op
=
'affine_grid'
fluid_op
=
'affine_grid'
...
@@ -701,10 +661,8 @@ def BatchNormalization(prog,
...
@@ -701,10 +661,8 @@ def BatchNormalization(prog,
"""
"""
# I/O
# I/O
val_x
,
val_scale
,
val_b
,
val_mean
,
val_var
=
inputs
var_x
,
var_scale
,
var_b
,
var_mean
,
var_var
=
inputs
val_y
,
=
outputs
var_y
,
=
outputs
var_x
=
_make_var_name
(
val_x
)
var_y
=
_make_var_name
(
val_y
)
var_saved_mean
=
name
+
'.saved_mean'
# dummy output
var_saved_mean
=
name
+
'.saved_mean'
# dummy output
var_saved_variance
=
name
+
'.saved_variance'
# dummy output
var_saved_variance
=
name
+
'.saved_variance'
# dummy output
...
@@ -714,28 +672,28 @@ def BatchNormalization(prog,
...
@@ -714,28 +672,28 @@ def BatchNormalization(prog,
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
epsilon
=
attrs
.
get
(
'epsilon'
,
1e-5
)
# optional
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
name_attr
=
', name={}'
.
format
(
repr
(
name
))
if
name
else
''
if
embed_params
:
if
embed_params
:
embed_params
=
_check_embeddable
(
value_infos
,
va
l_scale
,
val
_b
,
embed_params
=
_check_embeddable
(
value_infos
,
va
r_scale
,
var
_b
,
va
l_mean
,
val
_var
)
va
r_mean
,
var
_var
)
if
not
embed_params
and
name
:
if
not
embed_params
and
name
:
_logger
.
warning
(
'for op %s(%s -> BatchNormalization -> %s)'
,
name
,
_logger
.
warning
(
'for op %s(%s -> BatchNormalization -> %s)'
,
name
,
inputs
,
outputs
)
inputs
,
outputs
)
_logger
.
warning
(
'broken Python code will be generated'
)
_logger
.
warning
(
'broken Python code will be generated'
)
if
embed_params
:
if
embed_params
:
assert
name
!=
''
assert
name
!=
''
var_scale
=
name
+
'.w_0'
embedded_scale
=
name
+
'.w_0'
var_b
=
name
+
'.b_0'
embedded_b
=
name
+
'.b_0'
var_mean
=
name
+
'.w_1'
embedded_mean
=
name
+
'.w_1'
var_var
=
name
+
'.w_2'
embedded_var
=
name
+
'.w_2'
value_infos
[
val_scale
][
'embeded_as'
].
append
(
var_scale
)
value_infos
[
var_scale
][
'embedded_as'
].
append
(
embedded_scale
)
value_infos
[
val_b
][
'embeded_as'
].
append
(
var_b
)
value_infos
[
var_b
][
'embedded_as'
].
append
(
embedded_b
)
value_infos
[
val_mean
][
'embeded_as'
].
append
(
var_mean
)
value_infos
[
var_mean
][
'embedded_as'
].
append
(
embedded_mean
)
value_infos
[
val_var
][
'embeded_as'
].
append
(
var_var
)
value_infos
[
var_var
][
'embedded_as'
].
append
(
embedded_var
)
var_scale
=
embedded_scale
var_b
=
embedded_b
var_mean
=
embedded_mean
var_var
=
embedded_var
param_attr
=
''
param_attr
=
''
else
:
else
:
var_scale
=
_make_var_name
(
val_scale
)
var_b
=
_make_var_name
(
val_b
)
var_mean
=
_make_var_name
(
val_mean
)
var_var
=
_make_var_name
(
val_var
)
param_attr
=
(
', param_attr={}, bias_attr={}'
param_attr
=
(
', param_attr={}, bias_attr={}'
', moving_mean_name={}, moving_variance_name={}'
).
format
(
', moving_mean_name={}, moving_variance_name={}'
).
format
(
repr
(
var_scale
),
repr
(
var_b
),
repr
(
var_mean
),
repr
(
var_scale
),
repr
(
var_b
),
repr
(
var_mean
),
...
@@ -780,16 +738,14 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...
@@ -780,16 +738,14 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
"""
"""
# I/O
# I/O
val_input
,
=
inputs
var_input
,
=
inputs
val_output
,
=
outputs
var_output
,
=
outputs
var_input
=
_make_var_name
(
val_input
)
var_output
=
_make_var_name
(
val_output
)
# interpretation
# interpretation
dtype
=
attrs
[
'to'
]
# required
dtype
=
attrs
[
'to'
]
# required
if
not
isinstance
(
dtype
,
_np
.
dtype
):
# additional: possible np.dtype
if
not
isinstance
(
dtype
,
_np
.
dtype
):
# additional: possible np.dtype
dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
dtype
]
dtype
=
TENSOR_TYPE_TO_NP_TYPE
[
dtype
]
output_dtype
=
_dtype_or_none
(
value_infos
,
va
l
_output
)
output_dtype
=
_dtype_or_none
(
value_infos
,
va
r
_output
)
if
output_dtype
is
not
None
:
if
output_dtype
is
not
None
:
assert
dtype
==
output_dtype
,
'dtype of to unmatches output'
assert
dtype
==
output_dtype
,
'dtype of to unmatches output'
@@ -812,7 +768,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
         ([var_output], 'Out'),
         {
-            'in_dtype': prog.Dtype(_dtype(value_infos, val_input)),  # holy, required
+            'in_dtype': prog.Dtype(_dtype(value_infos, var_input)),  # holy, required
             'out_dtype': prog.Dtype(dtype),
         },
     )
@@ -824,9 +780,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """

     # I/O
-    val_concat_result, = outputs
-    var_inps = [_make_var_name(val) for val in inputs]
-    var_concat_result = _make_var_name(val_concat_result)
+    var_ret, = outputs

     # interpretation
     fluid_op = 'concat'
@@ -837,18 +791,18 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     prog.Code('{} = layers.{}({}'
               ', axis={}'
               '{})'.format(
-                  var_concat_result,
+                  var_ret,
                   fluid_op,
-                  '[' + ', '.join(var_inps) + ']',
+                  '[' + ', '.join(inputs) + ']',
                   # attrs
                   axis,
                   name_attr,
               ))
-    prog.VarDesc(var_concat_result)
+    prog.VarDesc(var_ret)
     prog.OpDesc(
         fluid_op,
-        (var_inps, *(['X'] * len(var_inps))),
-        ([var_concat_result], 'Out'),
+        (inputs, *(['X'] * len(inputs))),
+        ([var_ret], 'Out'),
         {'axis': axis},
     )
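With `var_inps` gone, the generated concat call references the ONNX value names directly. A quick sketch of what the `prog.Code` template above renders (sample names, `name_attr` omitted for brevity):

```python
# Rendering the concat code template with illustrative values.
inputs, var_ret, fluid_op, axis = ['x0', 'x1'], 'y', 'concat', 1
code = '{} = layers.{}({}, axis={})'.format(
    var_ret, fluid_op, '[' + ', '.join(inputs) + ']', axis)
print(code)  # y = layers.concat([x0, x1], axis=1)
```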
@@ -860,13 +814,12 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     # I/O
     assert len(inputs) == 0, 'constant op accept no inputs'
-    val_output, = outputs
-    var_output = _make_var_name(val_output)
+    var_output, = outputs

     # interpretation
     value = attrs['value']  # required
     dtype = _np.dtype(value.dtype)
-    output_dtype = _dtype_or_none(value_infos, val_output)
+    output_dtype = _dtype_or_none(value_infos, var_output)
     if output_dtype is not None:
         assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
@@ -874,13 +827,13 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     #    dtype = _np.dtype('float32')  # HINT: force to float32
     shape = attrs.get('shape', None)  #
     if shape is None:
-        shape = _shape_or_none(value_infos, val_output)
+        shape = _shape_or_none(value_infos, var_output)
     if shape is None:
         shape = list(value.shape)
         _logger.warning(
             'in op (Constant -> %s): '
             'attribute "shape" of %s not inferred, '
-            'using value as 1-D tensor may lead to fails', outputs, val_output)
+            'using value as 1-D tensor may lead to fails', outputs, var_output)

     # generation
     value = value.tolist()
@@ -911,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     prog.Code('# {} = {} # passed directly as literal'.format(
         var_output, value))
-    value_infos[val_output]['const_value'] = value
+    value_infos[var_output]['const_value'] = value


 def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
@@ -920,13 +873,12 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     """

     # I/O
-    val_shape, = inputs
-    val_output, = outputs
-    var_shape = _make_var_name(val_shape)
+    var_shape, = inputs
+    var_output, = outputs

-    shape = _const_weight_or_none(value_infos, val_shape)
+    shape = _const_weight_or_none(value_infos, var_shape)
     if shape is None:
-        shape = _shape_or_none(value_infos, val_output)
+        shape = _shape_or_none(value_infos, var_output)
     assert shape is not None, (
         'given shape is neither const value nor deductible from output, '
         'this is not supported')
@@ -959,53 +911,47 @@ def Conv(prog,
     """

     # I/O
-    val_x, val_w = inputs[:2]
-    val_y, = outputs
-    var_y = _make_var_name(val_y)
-    has_bias = len(inputs) == 3
-    if has_bias:
-        val_b, = inputs[2:]
+    var_x, var_w = inputs[:2]
+    var_y, var_b = (outputs + [None] * 1)[:2]

     # interpretation
     assert attrs.get(
         'auto_pad', 'NOTSET'
     ) == 'NOTSET', 'only auto_pad == NOTSET is supported'  # optional
-    kernel_shape = _shape(value_infos, val_w)[2:]  # OI...
+    kernel_shape = _shape(value_infos, var_w)[2:]  # OI...
     assert kernel_shape == attrs[
         'kernel_shape'], 'kernel_shape in attr unmatches value_info'  # HW
     convnd = len(kernel_shape)
     assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
-    num_out_channels = _shape(value_infos, val_w)[0]  # OI...
+    num_out_channels = _shape(value_infos, var_w)[0]  # OI...
     fluid_op = 'conv{}d'.format(convnd)

     num_groups = attrs.get('group', 1)  # optional
     strides = attrs.get('strides', [1] * convnd)  # optional
     dilations = attrs.get('dilations', [1] * convnd)  # optional
     pads = attrs.get('pads', [0] * (convnd * 2))  # optional
-    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    var_x = _make_var_name(val_x)
+    paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
-        embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
-                        or _check_embeddable(value_infos, val_b))
+        embed_params = (_check_embeddable(value_infos, var_w) and not var_b
+                        or _check_embeddable(value_infos, var_b))
         if not embed_params and name:
             _logger.warning('for op %s(%s -> Conv -> %s)', name, inputs,
                             outputs)
             _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
-        var_w = name + '.w_0'
-        value_infos[val_w]['embeded_as'].append(var_w)
-        if has_bias:
-            var_b = name + '.b_0'
-            value_infos[val_b]['embeded_as'].append(var_b)
+        embedded_w = name + '.w_0'
+        value_infos[var_w]['embedded_as'].append(embedded_w)
+        var_w = embedded_w
+        if var_b:
+            embedded_b = name + '.b_0'
+            value_infos[var_b]['embedded_as'].append(embedded_b)
+            var_b = embedded_b
             param_attr = ''
         else:
             param_attr = ', bias_attr=False'
     else:
-        var_w = _make_var_name(val_w)
-        var_b = _make_var_name(val_b) if has_bias else False
         param_attr = ', param_attr={}, bias_attr={}'.format(
             repr(var_w),
             repr(var_b) if var_b else False)
@@ -1036,7 +982,7 @@ def Conv(prog,
     prog.OpDesc(
         fluid_op,
         ([var_x, var_w], 'Input', 'Filter'),  # , 'Bias', 'ResidualData'
-        ([var_conv if has_bias else var_y], 'Output'),
+        ([var_conv if var_b else var_y], 'Output'),
         {
             'strides': strides,
             'paddings': paddings,
@@ -1044,13 +990,13 @@ def Conv(prog,
             'groups': num_groups,
         },
     )
-    if has_bias:
+    if var_b:
         prog.VarDesc(var_conv)
         prog.IntermediateOp(
             '',
             'Add',
             [var_conv, var_b],  #
-            [val_y],
+            [var_y],
             {'axis': 1},
             value_infos=value_infos,
             name=(name + '.bias'),
@@ -1073,13 +1019,8 @@ def ConvTranspose(prog,
     """

     # I/O
-    val_x, val_w = inputs[:2]
-    val_y, = outputs
-    var_y = _make_var_name(val_y)
-    has_bias = len(inputs) == 3
-    if has_bias:
-        val_b, = inputs[2:]
+    var_x, var_w = inputs[:2]
+    var_y, var_b = (outputs + [None] * 1)[:2]

     # interpretation
     assert attrs.get(
@@ -1088,41 +1029,40 @@ def ConvTranspose(prog,
     assert sum(attrs.get(
         'output_padding',
         [])) == 0, 'only zero output_padding is supported'  # optional ?
-    kernel_shape = _shape(value_infos, val_w)[2:]  # IO...
+    kernel_shape = _shape(value_infos, var_w)[2:]  # IO...
     assert kernel_shape == attrs[
         'kernel_shape'], 'kernel_shape in attr unmatches value_info'  # HW
     convnd = len(kernel_shape)
     assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose is supported'
-    num_out_channels = _shape(value_infos, val_w)[1]  # IO...
+    num_out_channels = _shape(value_infos, var_w)[1]  # IO...
     fluid_op = 'conv{}d_transpose'.format(convnd)

     num_groups = attrs.get('group', 1)  # optional
     strides = attrs.get('strides', [1] * convnd)  # optional
     dilations = attrs.get('dilations', [1] * convnd)  # optional
     pads = attrs.get('pads', [0] * (convnd * 2))  # optional
-    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    var_x = _make_var_name(val_x)
+    paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
-        embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
-                        or _check_embeddable(value_infos, val_b))
+        embed_params = (_check_embeddable(value_infos, var_w) and not var_b
+                        or _check_embeddable(value_infos, var_b))
         if not embed_params and name:
             _logger.warning('for op %s(%s -> ConvTranspose -> %s)', name,
                             inputs, outputs)
             _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
-        var_w = name + '.w_0'
-        value_infos[val_w]['embeded_as'].append(var_w)
-        if has_bias:
-            var_b = name + '.b_0'
-            value_infos[val_b]['embeded_as'].append(var_b)
+        embedded_w = name + '.w_0'
+        value_infos[var_w]['embedded_as'].append(embedded_w)
+        var_w = embedded_w
+        if var_b:
+            embedded_b = name + '.b_0'
+            value_infos[var_b]['embedded_as'].append(embedded_b)
+            var_b = embedded_b
             param_attr = ''
         else:
             param_attr = ', bias_attr=False'
     else:
-        var_w = _make_var_name(val_w)
-        var_b = _make_var_name(val_b) if has_bias else False
         param_attr = ', param_attr={}, bias_attr={}'.format(
             repr(var_w),
             repr(var_b) if var_b else False)
@@ -1154,7 +1094,7 @@ def ConvTranspose(prog,
     prog.OpDesc(
         fluid_op,
         ([var_x, var_w], 'Input', 'Filter'),  # , 'Bias', 'ResidualData'
-        ([var_conv if has_bias else var_y], 'Output'),
+        ([var_conv if var_b else var_y], 'Output'),
         {
             'strides': strides,
             'paddings': paddings,
@@ -1163,13 +1103,13 @@ def ConvTranspose(prog,
             'groups': num_groups,
         },
     )
-    if has_bias:
+    if var_b:
         prog.VarDesc(var_conv)
         prog.IntermediateOp(
             '',
             'Add',
             [var_conv, var_b],  #
-            [val_y],
+            [var_y],
             {'axis': 1},
             value_infos=value_infos,
             name=(name + '.bias'),
@@ -1184,27 +1124,27 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     """

     # due to fluid fc don't support transposed weight, we use matmul + ew_add
-    val_a, val_b, val_c = inputs
-    val_y, = outputs
+    var_a, var_b, var_c = inputs
+    var_y, = outputs

     alpha = attrs.get('alpha', 1.)  # optional
     beta = attrs.get('beta', 1.)  # optional
     trans_a = bool(attrs.get('transA', 0))  # optional
     trans_b = bool(attrs.get('transB', 0))  # optional

-    val_mm = name + '_mm'  # explicit variable
+    var_mm = name + '_mm'  # explicit variable
     prog.Op(
         '',
         'MatMul',
-        [val_a, val_b],
-        [val_mm],  # val
+        [var_a, var_b],
+        [var_mm],  # val
         {
             'transpose_x': trans_a,
             'transpose_y': trans_b,
             'alpha': alpha,
         },
         value_infos=value_infos,
-        name=val_mm,
+        name=var_mm,
     )
     prog.op_descs[-1].attrs.extend(
         prog.OpDescAttrs({
@@ -1216,17 +1156,17 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         prog.Op(
             '',
             'Add',
-            [val_mm, val_c],
-            [val_y],  # val
+            [var_mm, var_c],
+            [var_y],  # val
             {'axis': 1},
             value_infos=value_infos,
             name=(name + '_beta'),
         )
     else:
-        val_beta = name + '_beta'  # explicit variable
-        val_vm = name + '_vm'  # explicit variable
+        var_beta = name + '_beta'  # explicit variable
+        var_vm = name + '_vm'  # explicit variable
         if beta.is_integer():
-            vm_dtype = _dtype_or_none(value_infos, val_c)
+            vm_dtype = _dtype_or_none(value_infos, var_c)
             if vm_dtype is None:
                 vm_dtype = _np.dtype('float32')
             _logger.warning(
@@ -1239,16 +1179,16 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
             '',
             'Constant',
             [],
-            [val_beta],  # val
+            [var_beta],  # val
             {'value': beta},
             value_infos=value_infos,
-            name=val_beta,
+            name=var_beta,
         )
         prog.Op(
             '',
             'Mul',
-            [val_c, val_beta],
-            [val_vm],  # val
+            [var_c, var_beta],
+            [var_vm],  # val
             dict(),
             value_infos=value_infos,
             name=(name + '_scale'),
@@ -1256,8 +1196,8 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         prog.Op(
             '',
             'Add',
-            [val_mm, val_vm],
-            [val_y],  # val
+            [var_mm, var_vm],
+            [var_y],  # val
             {'axis': 1},
             name=(name + '_bias'),
         )
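Gemm is decomposed rather than mapped to fluid's `fc`, since fc cannot consume a transposed weight. The chain above computes `y = alpha * A @ B + beta * C`; a NumPy reference of the same arithmetic:

```python
# Reference for the MatMul / (Constant + Mul) / Add chain emitted for Gemm.
import numpy as np

A = np.random.rand(2, 3).astype(np.float32)
B = np.random.rand(3, 4).astype(np.float32)
C = np.random.rand(4).astype(np.float32)
alpha, beta = 0.5, 2.0

mm = alpha * (A @ B)  # MatMul carries the 'alpha' attribute
vm = beta * C         # Constant + Mul pair, emitted when beta != 1
y = mm + vm           # final Add, with axis=1 broadcasting
print(y.shape)        # (2, 4)
```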
@@ -1305,6 +1245,64 @@ def GlobalMaxPool(prog,
         name=name)


+def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
+    """
+    onnx::GRU-7:
+    """
+
+    var_x, var_w, var_r, var_b, var_len, var_xh = (inputs + [None] * 3)[:6]
+    var_y, var_yh = (outputs + [None] * 2)[:2]
+
+    # interpretation
+    fluid_op = 'gru_unit'
+    param_attr = ''
+
+    # generation
+    prog.Code('{}, _, {} = layers.{}({}, {}, {}'
+              '{})'.format(
+                  var_yh,
+                  var_y,
+                  fluid_op,
+                  var_x,
+                  var_xh,
+                  0,
+                  param_attr,
+              ))
+
+    # raise NotImplementedError()
+
+
+def LSTM(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
+    """
+    onnx::LSTM-7:
+    """
+
+    var_x, var_w, var_r, var_b, var_len, var_xh, var_xc, var_p = (
+        inputs + [None] * 5)[:8]
+    var_y, var_yh, var_yc = (outputs + [None] * 3)[:3]
+
+    # interpretation
+    fluid_op = 'lstm_unit'
+    param_attr = ''
+
+    # generation
+    prog.Code('{}, {}, {} = layers.{}({}, {}, {}'
+              '{})'.format(
+                  var_y,
+                  var_yh,
+                  var_yc,
+                  fluid_op,
+                  var_x,
+                  var_xh,
+                  var_xc,
+                  param_attr,
+              ))
+
+    # raise NotImplementedError()
+
+
 def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
             **kwargs):
     """
@@ -1329,17 +1327,15 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
     """

     # I/O
-    val_data, = inputs
-    val_output, = outputs
-    var_data = _make_var_name(val_data)
-    var_output = _make_var_name(val_output)
+    var_data, = inputs
+    var_output, = outputs

     # interpretation
     pads = attrs['pads']  # required
     mode = attrs.get('mode', 'constant')  # optional
     value = attrs.get('value', 0.)  # optional
-    data_shape = _shape_or_none(value_infos, val_data)
-    output_shape = _shape_or_none(value_infos, val_output)
+    data_shape = _shape_or_none(value_infos, var_data)
+    output_shape = _shape_or_none(value_infos, var_output)
     assume_pad2d = False
     if len(pads) == 4:
         assume_pad2d |= mode != 'constant'
@@ -1400,14 +1396,12 @@ def PRelu(prog,
     """

     # I/O
-    val_x, val_slope = inputs
-    val_y, = outputs
-    var_x = _make_var_name(val_x)
-    var_y = _make_var_name(val_y)
+    var_x, var_slope = inputs
+    var_y, = outputs

     # interpretation
     mode = 'channel'
-    slope_shape = _shape_or_none(value_infos, val_slope)
+    slope_shape = _shape_or_none(value_infos, var_slope)
     if slope_shape is not None:
         if len(slope_shape) == 0:
             mode = 'all'
@@ -1418,18 +1412,18 @@ def PRelu(prog,
     fluid_op = 'prelu'
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
-        embed_params = _check_embeddable(value_infos, val_slope)
+        embed_params = _check_embeddable(value_infos, var_slope)
         if not embed_params and name:
             _logger.warning('for op %s(%s -> PRelu -> %s)', name, inputs,
                             outputs)
             _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
-        var_slope = name + '.w_0'
-        value_infos[val_slope]['embeded_as'].append(var_slope)
+        embedded_slope = name + '.w_0'
+        value_infos[var_slope]['embedded_as'].append(embedded_slope)
+        var_slope = embedded_slope
         param_attr = ''
     else:
-        var_slope = _make_var_name(val_slope)
         param_attr = ', param_attr={}'.format(repr(var_slope))

     # generation
@@ -1467,17 +1461,14 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     """

     # I/O
-    val_data, val_shape = inputs
-    val_reshaped, = outputs
-    var_data = _make_var_name(val_data)
-    var_shape = _make_var_name(val_shape)
-    var_reshaped = _make_var_name(val_reshaped)
+    var_data, var_shape = inputs
+    var_reshaped, = outputs

     # interpretation
-    shape = _const_weight_or_none(value_infos, val_shape)
-    is_const_shape = shape and 'const_value' in value_infos[val_shape]
+    shape = _const_weight_or_none(value_infos, var_shape)
+    is_const_shape = shape and 'const_value' in value_infos[var_shape]
     if shape is None:
-        shape = _shape_or_none(value_infos, val_reshaped)
+        shape = _shape_or_none(value_infos, var_reshaped)
 #    assert shape is not None, ('given shape is neither const value nor deductible from output, '
@@ -1493,8 +1484,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     name_attr = ', name={}'.format(repr(name)) if name else ''

     # generation
-    val_shape_int32 = val_shape + '_int32'  # explicit variable
-    var_shape_int32 = _make_var_name(val_shape_int32)
+    var_shape_int32 = var_shape + '_int32'  # explicit variable
     prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape))
     if is_const_shape:
         prog.Code('{} = layers.{}({}'
@@ -1511,8 +1501,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         prog.Op(
             '',
             'Cast',
-            [val_shape],
-            [val_shape_int32],  # var
+            [var_shape],
+            [var_shape_int32],  # var
             {'to': _np.dtype('int32')},  # use np.dtype
             value_infos=value_infos,
             name=(name + '_cast'),
@@ -1595,17 +1585,15 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     """

     # I/O
-    val_data, = inputs
-    val_output, = outputs
-    var_data = _make_var_name(val_data)
-    var_output = _make_var_name(val_output)
+    var_data, = inputs
+    var_output, = outputs

     # interpretation
     fluid_op = 'slice'
     axes = attrs['axes']  # required
     starts = attrs['starts']  # required
     ends = attrs['ends']  # required
-    shape = _shape_or_none(value_infos, val_data)
+    shape = _shape_or_none(value_infos, var_data)
     if shape is not None:
 #        ndims = len(shape)
 #        for idx, value in enumerate(axes):
@@ -1654,9 +1642,7 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """

     # I/O
-    val_input, = inputs
-    var_outs = [_make_var_name(val) for val in outputs]
-    var_input = _make_var_name(val_input)
+    var_input, = inputs

     # interpretation
     fluid_op = 'split'
@@ -1668,7 +1654,7 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     prog.Code('{} = layers.{}({}, {}'
               ', dim={}'
               '{})'.format(
-                  ', '.join(var_outs),
+                  ', '.join(outputs),
                   fluid_op,
                   var_input,
                   split,
@@ -1676,12 +1662,12 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
                   axis,
                   name_attr,
               ))
-    for var_out in var_outs:
+    for var_out in outputs:
         prog.VarDesc(var_out)
     prog.OpDesc(
         fluid_op,
         (var_input, 'X'),
-        ([var_outs], *(['Out'] * len(var_outs))),
+        ([outputs], *(['Out'] * len(outputs))),
         {
             'axis': axis,
             'sections': split,
@@ -1695,9 +1681,7 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
     """

     # I/O
-    val_sum, = outputs
-    var_inps = [_make_var_name(val) for val in inputs]
-    var_sum = _make_var_name(val_sum)
+    var_sum, = outputs

     # interpretation
     fluid_op = 'sums'
@@ -1706,14 +1690,14 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
     prog.Code('{} = layers.{}({})'.format(
         var_sum,
         fluid_op,
-        '[' + ', '.join(var_inps) + ']',
+        '[' + ', '.join(inputs) + ']',
         # attrs
     ))
     fluid_op = 'sum'
     prog.VarDesc(var_sum)
     prog.OpDesc(
         fluid_op,
-        (var_inps, *(['X'] * len(var_inps))),
+        (inputs, *(['X'] * len(inputs))),
         ([var_sum], 'Out'),
         dict(),
     )
@@ -1725,14 +1709,11 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
     """

     # I/O
-    val_input, val_repeats = inputs
-    val_output, = outputs
-    var_input = _make_var_name(val_input)
-    var_repeats = _make_var_name(val_repeats)
-    var_output = _make_var_name(val_output)
+    var_input, var_repeats = inputs
+    var_output, = outputs

     # interpretation
-    repeats = _const_weight_or_none(value_infos, val_repeats)
+    repeats = _const_weight_or_none(value_infos, var_repeats)
     assert repeats is not None, 'only const repeats is supported'
     fluid_op = 'expand'
     name_attr = ', name={}'.format(repr(name)) if name else ''
@@ -1764,10 +1745,8 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """

     # I/O
-    val_data, = inputs
-    val_transposed, = outputs
-    var_data = _make_var_name(val_data)
-    var_transposed = _make_var_name(val_transposed)
+    var_data, = inputs
+    var_transposed, = outputs

     # interpretation
     fluid_op = 'transpose'

onnx2fluid/onnx2fluid/torch_export_helper.py  View file @ 2a82fdeb
@@ -6,6 +6,9 @@ Created on Fri Mar 22 11:22:46 2019
 @author: Macrobull
 """

+from __future__ import division
+
+import logging
 import numpy as np
 import torch
@@ -24,6 +27,8 @@ from typing import (
     Union,
 )

+logger = logging.getLogger(__name__)
+
 __all__ = [
     'export_data',
     'export_onnx_with_validation',
@@ -76,7 +81,7 @@ def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
         return str(obj)

     prefix_ = prefix + ('_' if prefix else '')
-    fp = open('{}.txt'.format(prefix or 'meta'), 'w')
+    fp = open('{}.txt'.format(prefix or 'meta'), mode='w')
     for key, value in state_dict.items():
         data = None
         if torch.is_tensor(value):
@@ -93,7 +98,7 @@ def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
 def export_onnx_with_validation(
-        model: torch.nn.Module,
+        model: torch.nn.Module,  # or JITScriptModule
         inputs: Sequence[Union[torch.Tensor, Sequence[object]]],
         export_basepath: Text,
         input_names: Optional[List[Text]] = None,

onnx2fluid/onnx2fluid/validation.py  View file @ 2a82fdeb
@@ -43,7 +43,8 @@ def fluid_prog_shape_infer(prog):
     import paddle.fluid as fluid

-    assert isinstance(prog, fluid.framework.Program)
+    assert isinstance(prog,
+                      fluid.framework.Program), 'prog is not a Program instance'

     logger.info('performing type-shape inference ...')
     for block in prog.blocks:
@@ -84,6 +85,8 @@ def validate(fluid_model_filename,
     inference the converted Paddle fluid model, validate with given golden data
     """

+    assert isinstance(fluid_model_filename, str)
+
     import numpy as np
     import paddle.fluid as fluid
@@ -153,7 +156,7 @@ def validate(fluid_model_filename,
         input_data = flatten_dict(input_data)
         output_data = flatten_dict(output_data)
         input_names = input_data.keys()
-        output_names = output_data.keys()
+        #        output_names = output_data.keys()
         logger.info('with %d inputs and %d outputs', len(input_data),
                     len(output_data))
     else:

onnx2fluid/onnx2fluid/writer.py  View file @ 2a82fdeb
@@ -16,7 +16,6 @@ from collections import OrderedDict as Dict
 logger = logging.getLogger(__name__)

 from . import symbolic
-from .symbolic import _make_var_name as make_var_name

 try:
     import paddle.fluid.proto.framework_pb2 as framework_pb2
@@ -63,7 +62,7 @@ def make_attr_name(name):
     assert name != '', 'name should not be empty'
-    for s in ' \\|/:-':  #
+    for s in ' \\|/:.-':  #
         name = name.replace(s, '_')
     if not name.startswith('_'):
         name = '_' + name
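Escaping `.` as well matters because ONNX parameter names routinely contain dots. A standalone copy of the replace loop for illustration (the return value beyond the shown lines is an assumption):

```python
# Sanitize a name into a valid fluid identifier; '.' is newly escaped here.
def sanitize(name):
    assert name != '', 'name should not be empty'
    for s in ' \\|/:.-':
        name = name.replace(s, '_')
    if not name.startswith('_'):
        name = '_' + name
    return name

print(sanitize('conv1.w_0'))  # _conv1_w_0
```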
@@ -207,7 +206,7 @@ class Program(object):
         return desc

     def VarDesc(self,
-                var_name,
+                name,
                 persistable=False,
                 value_info=None,
                 remove_batch=None):
@@ -215,18 +214,16 @@ class Program(object):
         add VarDesc,
         """

-        assert var_name not in self.var_descs, 'var naming conflicted'
+        assert name not in self.var_descs, 'var naming conflicted'

         var_desc = framework_pb2.VarDesc()
-        var_desc.name = var_name
+        var_desc.name = name
         var_desc.persistable = persistable
         var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
-        self.var_descs[var_name] = var_desc
+        self.var_descs[name] = var_desc

         if value_info:
-            self.VarTypeShapeInfo(var_name,
-                                  value_info, remove_batch=remove_batch)
+            self.VarTypeShapeInfo(name,
+                                  value_info, remove_batch=remove_batch)

     def Op(self, domain, op_type, *args, **kwargs):
         """
@@ -260,19 +257,19 @@ class Program(object):
         else:
             self.code_mutable = code_mutable

-    def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
+    def VarTypeShapeInfo(self, name, value_info, remove_batch=None):
         """
         set value_info for var
         """

-        if var_name not in self.var_descs:
+        if name not in self.var_descs:
             return

         dtype = value_info.get('dtype', None)
         if dtype is None:
             return

-        var_desc = self.var_descs[var_name]
+        var_desc = self.var_descs[name]
         tensor_desc = var_desc.type.lod_tensor.tensor
         tensor_desc.data_type = self.Dtype(dtype)  # required
@@ -292,8 +289,7 @@ class Writer(object):
     fluid code and desc writter
     """

-    #     CODE_INDENT = ' ' * 4
-    CODE_INDENT = ' ' * 4  # '\t'
+    CODE_INDENT = '\t'

     @staticmethod
     def header_code(func_name, info=''):
@@ -313,6 +309,7 @@ class Writer(object):
         codes.append('from paddle.fluid import initializer, layers')
         codes.append('')
         codes.append('')
+
         codes.append('def {}():'.format(func_name))
         return codes
@@ -342,24 +339,26 @@ class Writer(object):
         emit an ONNX weight into program
         """

-        if value_info.get('embeded_as', []):
-            var_names = value_info['embeded_as']
-            prog.Code('# parameter {} embeded as {}'.format(name, var_names))
-            for var_name in var_names:
-                prog.VarDesc(var_name, persistable=True, value_info=value_info)
+        if value_info.get('embedded_as', []):
+            embedded_names = value_info['embedded_as']
+            prog.Code('# parameter {} embedded as {}'.format(
+                name, embedded_names))
+            for embedded_name in embedded_names:
+                prog.VarDesc(embedded_name, persistable=True, value_info=value_info)
         else:
-            var_name = make_var_name(name)
             attr_name = make_attr_name(name)
-            prog.Code('# parameter {}: {}'.format(name, var_name))
+            prog.Code('# parameter {}'.format(name))
             prog.Code('{} = ParamAttr(name={})'  # , trainable=True
-                      .format(attr_name, repr(var_name)))
+                      .format(attr_name, repr(name)))
             prog.Code(
                 '{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
                 ', default_initializer=initializer.Constant(0))'  #, is_bias={}
-                .format(var_name, value_info['shape'],
+                .format(name, value_info['shape'],
                         repr(value_info['dtype'].name), repr(name),
                         attr_name))  #, value_info.get('is_bias', False)))
-            prog.VarDesc(var_name, persistable=True, value_info=value_info)
+            prog.VarDesc(name, persistable=True, value_info=value_info)

     @staticmethod
     def emit_inputs(prog, names, value_infos, remove_batch=None):
@@ -368,7 +367,6 @@ class Writer(object):
         """

         for idx, name in enumerate(names):
-            var_name = make_var_name(name)
             value_info = value_infos[name]
             shape = value_info['shape']
             if remove_batch is None:
@@ -377,13 +375,13 @@ class Writer(object):
             if remove_batch:
                 shape = shape[1:]

-            prog.Code('# input {}: {}'.format(name, var_name))
+            prog.Code('# input {}'.format(name))
             prog.Code((
                 '{} = layers.data(name={}, shape={}, dtype={}, '
                 'append_batch_size={})'  # , stop_gradient=True
             ).format(
-                var_name,
-                repr(var_name),
+                name,
+                repr(name),
                 shape,
                 repr(value_info['dtype'].name),
                 remove_batch,
@@ -391,12 +389,10 @@ class Writer(object):
             prog.OpDesc(
                 'feed',
                 (['feed'], 'X'),
-                ([var_name], 'Out'),
+                ([name], 'Out'),
                 {'col': idx},
             )
-            prog.VarDesc(var_name,
+            prog.VarDesc(name,
                          value_info=value_info, remove_batch=remove_batch)

     @staticmethod
     def emit_outputs(prog, names):  #, value_infos
@@ -406,12 +402,11 @@ class Writer(object):
         code = 'return '
         for idx, name in enumerate(names):
-            var_name = make_var_name(name)
-            code += var_name + ', '
+            code += name + ', '
             prog.OpDesc(
                 'fetch',
-                ([var_name], 'X'),
+                ([name], 'X'),
                 (['fetch'], 'Out'),
                 {'col': idx},
             )
@@ -458,8 +453,7 @@ class Writer(object):
         for name, weight in weights.items():
             assert isinstance(weights, dict), 'dict type weights required'
-            var_name = make_var_name(name)
-            filename = os.path.join(save_dir, var_name)
+            filename = os.path.join(save_dir, name)
             Writer.write_weight(weight, filename)
             logger.debug('saved weight %s to %s', name, filename)
onnx2fluid/requirements.txt  View file @ 2a82fdeb

 -e .
 onnx>=1.4
-paddlepaddle
+paddlepaddle>=1.5
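requirements.txt now pins paddlepaddle to 1.5 or newer, matching the version targeted by the validation path. A quick environment check (assuming both packages import cleanly):

```python
# Verify an installed environment satisfies the updated requirements.
import onnx
import paddle

print(onnx.__version__)    # expect >= 1.4
print(paddle.__version__)  # expect >= 1.5
```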