Commit 2a82fdeb, authored by: M Macrobull

fix name and add ops

Parent: 9828c2c7
@@ -17,13 +17,13 @@ onnx2fluid supports converting ONNX models to PaddlePaddle models for inference; users ...

 Tested successfully in the following environment:

 * python 3.5+
-* onnx == 1.4.0
-* paddlepaddle == 1.3.0 (optional, only for validation)
+* onnx == 1.4.1
+* paddlepaddle == 1.5.0 (optional, only for validation)

 Using [Anaconda](https://docs.anaconda.com/anaconda/install):

 ``` shell
 conda install -c conda-forge onnx
-pip install paddlepaddle==1.3.0
+pip install paddlepaddle==1.5.0
 ```

 ## Try it out

@@ -49,6 +49,8 @@ onnx2fluid sample_1.onnx -t sample_1.npz

 ## Usage

+Some operators of **ONNX opset 9+** are currently supported, matching PyTorch **1.0/1.1 (stable opset)**; see the [ONNX documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md) for more compatibility information.
+
 onnx2fluid:

 ```shell

@@ -79,5 +81,5 @@ onnx2fluid.validate [-d] [-t test_data.npz] [-p 1e-3] /path/to/onnx/model.onnx

 ## References

-* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.4/api_cn/layers_cn.html)
-* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.4/api_guides/low_level/inference.html#id4)
+* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_cn/layers_cn.html)
+* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_guides/low_level/inference.html#id4)
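The supported opset matches what PyTorch 1.0/1.1 emits, so a typical input model comes straight from `torch.onnx.export`. Below is a minimal, illustrative sketch of producing a `sample_1.onnx` for the commands above; the tiny model and the explicit `opset_version=9` are assumptions for the example, not requirements stated in the README:

```python
import torch
import torch.nn as nn

# a tiny stand-in network; any traceable nn.Module works
model = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
model.eval()

xb = torch.rand((3, 4))  # dummy input used for tracing
torch.onnx.export(model, (xb, ), 'sample_1.onnx',
                  input_names=['x'], output_names=['y'],
                  opset_version=9)
```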
@@ -20,50 +20,56 @@ from onnx2fluid.torch_export_helper import export_onnx_with_validation

 prefix = 'sample_'
 idx = 0

-######### example: RNN ########
-#
-#class Model(nn.Module):
-#    def __init__(self):
-#        super(Model, self).__init__()
-#        self.rnn = nn.RNN(4, 6, 2)
-#
-#    def forward(self, x):
-#        y = x
-#        y, h = self.rnn(y)
-#        return y
-#
-#
-#model = Model()
-#model.eval()
-#xb = torch.rand((2, 3, 4))
-#yp = model(xb)
-#idx += 1
-#print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)

-######### example: random ########
-#
-#class Model(nn.Module):
-#    def __init__(self):
-#        super(Model, self).__init__()
-#
-#    def forward(self, x):
-#        y = torch.rand((2, 3))  # + torch.rand_like(xb)
-#        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
-#        return y
-#
-#
-#model = Model()
-#model.eval()
-#xb = torch.rand((2, 3))
-#yp = model(xb)
-#idx += 1
-#print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
+######## example: RNN ########
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.gru = nn.GRU(4, 5, 3)
+        self.lstm = nn.LSTM(5, 6, 2)
+
+    def forward(self, x):
+        y = x
+        y, h = self.gru(y)
+        y, h = self.lstm(y)
+        return y
+
+
+model = Model()
+model.eval()
+xb = torch.rand((2, 3, 4))
+yp = model(xb)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb],
+                            prefix + str(idx), ['x'], ['y'],
+                            verbose=True,
+                            training=False)
+
+######## example: random ########
+
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        y = torch.rand((2, 3))  # + torch.rand_like(xb)
+        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
+        return y
+
+
+model = Model()
+model.eval()
+xb = torch.rand((2, 3))
+yp = model(xb)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb],
+                            prefix + str(idx), ['x'], ['y'],
+                            verbose=True,
+                            training=False)

 ######## example: fc ########

@@ -175,7 +181,7 @@ class Model(nn.Module):
         super(Model, self).__init__()
         self.conv = nn.Conv2d(3, 8, 3)
         self.batch_norm = nn.BatchNorm2d(8)
-        self.pool = nn.AdaptiveAvgPool2d(2)
+        self.pool = nn.AdaptiveAvgPool2d(1)

     def forward(self, x):
         y = x

@@ -215,9 +221,10 @@ export_onnx_with_validation(model, [xb],
 #yp = model(xb)
 #idx += 1
 #print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
+#export_onnx_with_validation(
+#        model, [xb], prefix + str(idx),
+#        ['x'], ['y'],
+#        verbose=True, training=False)

 ######## example: empty ########

...
@@ -24,14 +24,14 @@ bvlc_alexnet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
    do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -54,7 +54,7 @@ bvlc_googlenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -77,7 +77,7 @@ bvlc_reference_caffenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -100,7 +100,7 @@ bvlc_reference_rcnn_ilsvrc13()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1

@@ -123,14 +123,14 @@ densenet121()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 fc6_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc6_1

@@ -153,7 +153,7 @@ emotion_ferplus()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0

@@ -176,14 +176,14 @@ inception_v1()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -206,14 +206,14 @@ inception_v2()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -236,7 +236,7 @@ mobilenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0

@@ -259,7 +259,7 @@ resnet18()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd

@@ -282,14 +282,14 @@ resnet50()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1

@@ -312,7 +312,7 @@ resnet100_arcface()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data fc1

@@ -335,7 +335,7 @@ resnet101_duc()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data seg_loss

@@ -358,7 +358,7 @@ resnet152()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd

@@ -381,7 +381,7 @@ shufflenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1

@@ -404,7 +404,7 @@ squeezenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 softmaxout_1

@@ -427,7 +427,7 @@ squeezenet1v1()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0

@@ -448,10 +448,10 @@ ssd()
     rm -rf "$bn_tar/"
     echo "extracting ..."
     mkdir "$bn_tar"
-    tar xf "$fn_tar" -C "$bn_tar"/
+    tar xf "$fn_tar" -C "$bn_tar/"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" image bboxes,labels,scores

@@ -474,7 +474,7 @@ tiny_yolov2()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" image grid

@@ -497,7 +497,7 @@ vgg16bn()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd

@@ -520,7 +520,7 @@ vgg19()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1

@@ -543,7 +543,7 @@ yolov3()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -x #
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0

@@ -566,7 +566,7 @@ zfnet512()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1

...
@@ -17,6 +17,22 @@ __all__ = [

 DEFAULT_ONNX_OPSET_VERSION = 9

+
+def make_var_name(name):
+    """
+    make a valid variable name in Python code and filename in filesystem
+    """
+
+    if name == '':
+        return '_'
+    if name[0].isdigit():
+        return 'var_' + name
+    for s in ' \\|/:.-':
+        name = name.replace(s, '_')
+    if name.startswith('_'):
+        name = 'var' + name
+    return name
+

 def convert(onnx_model_filename,
             save_dir,
             model_basename='model.py',
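This `make_var_name` is the heart of the "fix name" half of the commit: raw ONNX value names may contain `/`, `:`, `.`, `-` and other characters that are illegal in Python identifiers and awkward as file names, so every graph name is now funneled through it once, up front. A standalone sanity check of the rules above (the sample names are illustrative, in the style of the I/O names seen in the validation scripts):

```python
def make_var_name(name):  # copied from the hunk above so this runs standalone
    if name == '':
        return '_'
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:.-':
        name = name.replace(s, '_')
    if name.startswith('_'):
        name = 'var' + name
    return name

assert make_var_name('') == '_'
assert make_var_name('7x') == 'var_7x'                  # leading digit
assert make_var_name('gpu_0/data_0') == 'gpu_0_data_0'  # '/' replaced
assert make_var_name('fc-rcnn_1') == 'fc_rcnn_1'        # '-' replaced
assert make_var_name('.hidden') == 'var_hidden'         # no leading '_'
```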
@@ -30,6 +46,12 @@ def convert(onnx_model_filename,
     convert an ONNX model to Paddle fluid Python code and desc pb
     """

+    assert isinstance(onnx_model_filename, str)
+    assert isinstance(save_dir, str)
+    assert isinstance(model_basename, str)
+    assert isinstance(model_func_name, str)
+    assert onnx_opset_version is None or isinstance(onnx_opset_version, int)
+
     import onnx
     from onnx.checker import ValidationError

@@ -41,7 +63,6 @@ def convert(onnx_model_filename,
     from .onnx_utils import inferred_model_value_info
     from .onnx_utils import polish_model
     from .writer import Program, Writer
-    from .writer import make_var_name

     logger = logging.getLogger('convert')
@@ -88,17 +109,21 @@ def convert(onnx_model_filename,
     fluid_writer = Writer()

     # model components
-    # graph_name = onnx_graph.name
-    graph_inputs = [value.name for value in onnx_graph.input]
-    graph_outputs = [value.name for value in onnx_graph.output]
-    graph_params = []
-    graph_value_infos = inferred_model_value_info(onnx_model)
+    inp_vars = [make_var_name(value.name) for value in onnx_graph.input]
+    out_vars = [make_var_name(value.name) for value in onnx_graph.output]
+    par_vars = []
+    value_infos = inferred_model_value_info(onnx_model)
+    value_infos = {
+        make_var_name(key): value
+        for key, value in value_infos.items()
+    }

     # prepare additional value_info
     # for weights
     for name, weight in graph_weights(onnx_graph):
-        value_info = graph_value_infos[name]
-        value_info['embeded_as'] = []
+        var_name = make_var_name(name)
+        value_info = value_infos[var_name]
+        value_info['embedded_as'] = []
         value_info['get_weight'] = (lambda w: lambda: w.tolist())(
             weight)  # lazy getter

@@ -108,19 +133,23 @@ def convert(onnx_model_filename,
     topo = 'forward'
     for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph,
                                                                    topo=topo):
-        logger.debug('translating op %s %s::%s ...', name, domain, op_type)
+        op_name = make_var_name(name)
+        inputs = [make_var_name(val) for val in inputs]
+        outputs = [make_var_name(val) for val in outputs]
+        logger.debug('translating op %s(%s) %s::%s ...', name, op_name, domain,
+                     op_type)
         if domain == DEFAULT_OP_DOMAIN:
             domain = ''
         try:
             fluid_writer.emit_op(
                 fluid_program,
-                name,
+                op_name,
                 domain,
                 op_type,
                 inputs,
                 outputs,
                 attrs,
-                graph_value_infos,
+                value_infos,
                 embed_params=embed_params,
             )
         except BaseException as e:

@@ -133,17 +162,16 @@ def convert(onnx_model_filename,
                     len(fluid_program.op_descs))

     # type-shape info copy
-    for name, value_info in graph_value_infos.items():
-        var_name = make_var_name(name)
+    for var_name, value_info in value_infos.items():
         fluid_program.VarTypeShapeInfo(var_name, value_info,
                                        remove_batch=False)  #
-    bad_var_names = []
+    bad_vars = []
     for var_name, var_desc in fluid_program.var_descs.items():
         if not var_desc.type.lod_tensor.HasField('tensor'):
-            bad_var_names.append(var_name)
-    if len(bad_var_names) > 0:
+            bad_vars.append(var_name)
+    if len(bad_vars) > 0:
         logger.warning('type-shape not infered for var %s ...',
-                       ', '.join(bad_var_names[:5]))
+                       ', '.join(bad_vars[:5]))
         logger.warning('this causes little problem for PaddlePaddle, '
                        'but Paddle Mobile may not infer correctly')
         logger.warning('please consider running validation with -i '
@@ -151,40 +179,41 @@ def convert(onnx_model_filename,
     # weight writer
     for name, weight in graph_weights(onnx_graph):
-        graph_params.append(name)
-        value_info = graph_value_infos[name]
-        var_names = value_info.get('embeded_as', [])
-        if var_names:
-            if len(var_names) > 1:
+        var_name = make_var_name(name)
+        par_vars.append(var_name)
+        value_info = value_infos[var_name]
+        embedded_names = value_info.get('embedded_as', [])
+        if embedded_names:
+            if len(embedded_names) > 1:
                 logger.info(
                     'weight %s is shared between ops, more disk space will be consumed',
                     name)
             logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
-                         weight.dtype, weight.size, weight.nbytes, var_names)
-            for var_name in var_names:  # multiple references
+                         weight.dtype, weight.size, weight.nbytes,
+                         embedded_names)
+            for embedded_name in embedded_names:  # multiple references
                 fluid_writer.write_weight(
-                    weight, shutil.os.path.join(save_dir, var_name))
+                    weight, shutil.os.path.join(save_dir, embedded_name))
         else:
             logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
-                         weight.dtype, weight.size, weight.nbytes,
-                         make_var_name(name))
-            fluid_writer.write_weight(
-                weight, shutil.os.path.join(save_dir, make_var_name(name)))
-        fluid_writer.emit_param(fluid_program, name, value_info)
+                         weight.dtype, weight.size, weight.nbytes, var_name)
+            fluid_writer.write_weight(weight,
+                                      shutil.os.path.join(save_dir, var_name))
+        fluid_writer.emit_param(fluid_program, var_name, value_info)
     param_codes = fluid_program.codes
     fluid_program.codes = []
-    logger.info('%d weights converted', len(graph_params))
+    logger.info('%d weights converted', len(par_vars))

     # input writer
     external_inputs = []
-    for name in graph_inputs:
-        if name not in graph_params:
-            value_info = graph_value_infos[name]
+    for var_name in inp_vars:
+        if var_name not in par_vars:
+            value_info = value_infos[var_name]
             assert value_info['external']
-            external_inputs.append(name)
+            external_inputs.append(var_name)
     fluid_writer.emit_inputs(fluid_program,
                              external_inputs,
-                             graph_value_infos,
+                             value_infos,
                              remove_batch=False)  # TODO:
     input_codes = fluid_program.codes
     fluid_program.codes = []

@@ -192,11 +221,11 @@ def convert(onnx_model_filename,
     # output writer
     external_outputs = []
-    for name in graph_outputs:
-        if name not in graph_params:
-            value_info = graph_value_infos[name]
+    for var_name in out_vars:
+        if var_name not in par_vars:
+            value_info = value_infos[var_name]
             assert value_info['external']
-            external_outputs.append(name)
+            external_outputs.append(var_name)
     fluid_writer.emit_outputs(fluid_program, external_outputs)
     output_codes = [''] + fluid_program.codes  # add an empty line
     fluid_program.codes = []

@@ -204,10 +233,18 @@ def convert(onnx_model_filename,
     # code generation
     header_codes = fluid_writer.header_code(
-        model_func_name, 'From: {}'.format(onnx_model_filename))
+        model_func_name,
+        'From: {}'.format(onnx_model_filename),
+    )
     code_filename = shutil.os.path.join(save_dir, model_basename)
-    fluid_writer.write_code_file(code_filename, header_codes, input_codes,
-                                 param_codes, op_codes, output_codes)
+    fluid_writer.write_code_file(
+        code_filename,
+        header_codes,
+        input_codes,
+        param_codes,
+        op_codes,
+        output_codes,
+    )
     logger.info('code saved to %s, factory function: %s', code_filename,
                 model_func_name)

...
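With the renaming now done once inside `convert`, callers deal only in original ONNX names. A minimal sketch of driving the conversion from Python; the module path `onnx2fluid.conversion` and the file paths are assumptions for illustration, not taken from the commit:

```python
from onnx2fluid.conversion import convert

# paths are placeholders; embed_params folds weights into the generated code
convert('sample_1.onnx',
        'sample_1.paddle/',
        model_basename='model.py',
        model_func_name='inference',
        embed_params=True)
```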
@@ -87,6 +87,9 @@ def get_attribute_value2(attr):
     get_attribute_value enhanced
     """

+    assert isinstance(
+        attr, onnx.AttributeProto), 'attr is not a AttributeProto instance'
+
     if attr.type == onnx.AttributeProto.TENSOR:
         dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
         data = attr.t.raw_data

@@ -106,6 +109,9 @@ def tensor_dtype(tensor):
     get ONNX tensor in np.dtype
     """

+    assert isinstance(
+        tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
+
     return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]

@@ -114,6 +120,9 @@ def tensor_shape(tensor):
     get ONNX tensor shape
     """

+    assert isinstance(
+        tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
+
     return tuple([dim.dim_value for dim in tensor.type.tensor_type.shape.dim])

@@ -122,6 +131,8 @@ def node_attrs(node):
     convert ONNX node attributes to dict
     """

+    assert isinstance(node, onnx.NodeProto), 'node is not a NodeProto instance'
+
     return {attr.name: get_attribute_value2(attr)
             for attr in node.attribute}  # dict

@@ -224,9 +235,8 @@ def graph_ops(graph, topo='default'):
     generator for ONNX node graph with given topology
     """

-    if not isinstance(graph, onnx.GraphProto):
-        logger.error('graph is not a GraphProto instance')
-        return
+    assert isinstance(graph,
+                      onnx.GraphProto), 'graph is not a GraphProto instance'

     return node_iter(graph.node, node_topo(graph.node, topo))

@@ -236,9 +246,8 @@ def graph_weights(graph):
     generator for weights of an ONNX model
     """

-    if not isinstance(graph, onnx.GraphProto):
-        logger.error('graph is not a GraphProto instance')
-        return
+    assert isinstance(graph,
+                      onnx.GraphProto), 'graph is not a GraphProto instance'

     for initializer in graph.initializer:
         name = initializer.name

@@ -251,6 +260,9 @@ def inferred_model_value_info(model):
     collect value/type info for an ONNX model
     """

+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
+
     model = infer_shapes(model)
     graph = model.graph
     value_info = Dict()

@@ -353,6 +365,10 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
     """
     skip ops can be bypassed for inference
     """
+
+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
+
     if op_list is None:
         op_list = ('Dropout', 'Identity')

@@ -415,6 +431,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
     strip weights for inference
     """

+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
+
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
     out_names = [val.name for val in model.graph.output]

@@ -456,6 +475,9 @@ def optimize_model_cast(model):
     strip cascade and unecessary onnx::Cast-9:
     """

+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
+
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
     value_info = inferred_model_value_info(model)

@@ -513,6 +535,9 @@ def optimize_model_slice(model):
     strip cascade and unecessary onnx::Slice-1:9
     """

+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
+
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)

...
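The added asserts only tighten these helpers' contracts; behavior is otherwise unchanged. For reference, a small sketch of the kind of protobuf `node_attrs` consumes, built with stock `onnx.helper` (the attribute values are arbitrary):

```python
import onnx.helper

node = onnx.helper.make_node(
    'Conv',
    inputs=['x', 'w'],
    outputs=['y'],
    kernel_shape=[3, 3],
    pads=[1, 1, 1, 1],
)
# node_attrs(node) walks node.attribute, decoding each AttributeProto with
# get_attribute_value2, and would yield:
#   {'kernel_shape': [3, 3], 'pads': [1, 1, 1, 1]}
```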
@@ -50,6 +50,9 @@ DEFAULT_OP_MAPPING = {
             dict(), None, None, False],
     ## unary ops ##
     'Abs': ['abs', ['X'], ['Out']],
+    'Acos': ['acos', ['X'], ['Out']],
+    'Asin': ['asin', ['X'], ['Out']],
+    'Atan': ['atan', ['X'], ['Out']],
     'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')],
     'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')],
     'Ceil': ['ceil', ['X'], ['Out']],
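The three new entries give the inverse trigonometric ops the same one-line treatment as `Abs`: each ONNX op maps directly onto a fluid layer, with `X`/`Out` as the I/O keys. A simplified, hypothetical sketch of how such an entry drives code emission (this is not the project's actual `_default` implementation, just the shape of the idea):

```python
DEFAULT_OP_MAPPING = {
    'Atan': ['atan', ['X'], ['Out']],
}

def emit_default(op_type, inputs, outputs):
    # first element is the fluid layer name; I/O key lists follow
    fluid_op = DEFAULT_OP_MAPPING[op_type][0]
    return '{} = layers.{}({})'.format(
        ', '.join(outputs), fluid_op, ', '.join(inputs))

print(emit_default('Atan', ['x'], ['y']))  # y = layers.atan(x)
```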
@@ -144,52 +147,36 @@ DEFAULT_IOA_CONSTRAINTS = {
 }

-def _make_var_name(name):
-    """
-    make a valid variable name in Python code and in filesystem
-    """
-    if name == '':
-        return '_'
-    if name[0].isdigit():
-        return 'var_' + name
-    for s in ' \\|/:-':  #
-        name = name.replace(s, '_')
-    if name.startswith('_'):
-        name = 'var' + name
-    return name
-
-
-def _dtype(value_infos, val_name):
-    return _np.dtype(value_infos[val_name]['dtype'])
-
-
-def _dtype_or_none(value_infos, val_name):
-    if val_name not in value_infos:
+def _dtype(value_infos, name):
+    return _np.dtype(value_infos[name]['dtype'])
+
+
+def _dtype_or_none(value_infos, name):
+    if name not in value_infos:
         return None
-    value_info = value_infos[val_name]
+    value_info = value_infos[name]
     if 'dtype' not in value_info:
         return None
     return _np.dtype(value_info['dtype'])

-def _shape(value_infos, val_name):
-    return list(value_infos[val_name]['shape'])
+def _shape(value_infos, name):
+    return list(value_infos[name]['shape'])

-def _shape_or_none(value_infos, val_name):
-    if val_name not in value_infos:
+def _shape_or_none(value_infos, name):
+    if name not in value_infos:
         return None
-    value_info = value_infos[val_name]
+    value_info = value_infos[name]
     if 'shape' not in value_info:
         return None
     return list(value_info['shape'])

-def _const_weight_or_none(value_infos, val_name):
-    if val_name not in value_infos:
+def _const_weight_or_none(value_infos, name):
+    if name not in value_infos:
         return None
-    value_info = value_infos[val_name]
+    value_info = value_infos[name]
     const_value = value_info.get('const_value', None)
     if const_value is not None:
         return const_value

@@ -199,11 +186,11 @@ def _const_weight_or_none(value_infos, val_name):
     return None

-def _check_embeddable(value_infos, *val_names):
+def _check_embeddable(value_infos, *names):
     keyword = 'get_weight'
-    for val_name in val_names:
-        if keyword not in value_infos[val_name]:
-            _logger.warning('parameter %s not embeddable', val_name)
+    for name in names:
+        if keyword not in value_infos[name]:
+            _logger.warning('parameter %s not embeddable', name)
             return False
     return True

@@ -239,12 +226,10 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
     fluid_attrs = default_attrs.copy()
     fluid_attrs.update(mapped_attrs)  # as new attrs

-    val_inps = inputs if input_perm is None else map(inputs.__getitem__,
-                                                     input_perm)
-    val_outs = outputs if output_perm is None else map(outputs.__getitem__,
-                                                       output_perm)
-    var_inps = [_make_var_name(val) for val in val_inps]
-    var_outs = [_make_var_name(val) for val in val_outs]
+    var_inps = inputs if input_perm is None else list(
+        map(inputs.__getitem__, input_perm))
+    var_outs = outputs if output_perm is None else list(
+        map(outputs.__getitem__, output_perm))
     arg_name = ', name={}'.format(
         repr(name)) if fill_name_field and name else ''
     arg_attrs = [

@@ -277,9 +262,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):

 def _assign(prog, mapping):
     fluid_op = 'assign'
-    for val_dst, val_src in mapping.items():
-        var_dst = _make_var_name(val_dst)
-        var_src = _make_var_name(val_src)
+    for var_dst, var_src in mapping.items():
         prog.Code('{} = {}  # assign'.format(var_dst, var_src))
 #        prog.Code('{} = layers.{}({})'
 #                  .format(var_dst,

@@ -295,18 +278,18 @@ def _assign(prog, mapping):
     )

-def _zeros_like(prog, val_ref, val_out, value_infos):
+def _zeros_like(prog, var_ref, var_out, value_infos):
     prog.Op(
         '',
         'Sub',
-        [val_ref, val_ref],
-        [val_out],  # val
+        [var_ref, var_ref],
+        [var_out],
         {'axis': 0},
         value_infos,
     )

-def _pad_if_asymmetric(prog, pads, val_name, value_infos):  # pads: SSEE
+def _pad_if_asymmetric(prog, pads, var_name, value_infos):  # pads: SSEE
     assert len(pads) & 1 == 0
     ndims = len(pads) // 2
     symmetric = True

@@ -315,36 +298,29 @@ def _pad_if_asymmetric(prog, pads, val_name, value_infos):  # pads: SSEE
             symmetric = False
             break
     if symmetric:
-        return pads[:ndims], val_name
+        return pads[:ndims], var_name

-    val_padded = val_name + '_padded'  # explicit variable
+    var_padded = var_name + '_padded'  # explicit variable
     prog.Op(
         '',
         'Pad',
-        [val_name],
-        [val_padded],  # val
+        [var_name],
+        [var_padded],
         {
             'mode': 'constant',
             'value': 0.,
             'pads': pads,
         },
         value_infos=value_infos,
-        name=val_padded,
+        name=var_padded,
     )
-    return [0] * ndims, val_padded
+    return [0] * ndims, var_padded
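`_pad_if_asymmetric` deals with ONNX's SSEE pad layout: all start (leading) paddings first, then all end (trailing) ones, while fluid's conv/pool attributes take one padding per spatial dimension. A worked example of the contract, assuming the symmetry check elided above compares `pads[i]` against `pads[i + ndims]`:

```python
# symmetric: starts [1, 2] equal ends [1, 2]
#   _pad_if_asymmetric(prog, [1, 2, 1, 2], 'x', value_infos)
#   -> ([1, 2], 'x')          # folded per-dim paddings, original variable
#
# asymmetric: starts [0, 0] differ from ends [1, 1]
#   _pad_if_asymmetric(prog, [0, 0, 1, 1], 'x', value_infos)
#   -> ([0, 0], 'x_padded')   # explicit ONNX Pad op emitted, zero paddings
```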
def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''): def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
# I/O # I/O
val_x, = inputs var_x, = inputs
val_y, = outputs[:1] var_y, var_indices = (outputs + [None] * 1)[:2]
var_x = _make_var_name(val_x)
var_y = _make_var_name(val_y)
has_indices = len(outputs) > 1
if has_indices:
val_indices = outputs[1]
var_indices = _make_var_name(val_indices)
# interpretation # interpretation
pool_size = attrs['output_size'] # required pool_size = attrs['output_size'] # required
...@@ -361,28 +337,28 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''): ...@@ -361,28 +337,28 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
', pool_type={}' ', pool_type={}'
'{})'.format( '{})'.format(
var_y, var_y,
', {}'.format(var_indices) if has_indices else '', ', {}'.format(var_indices) if var_indices else '',
fluid_op, fluid_op,
var_x, var_x,
# attrs # attrs
has_indices, bool(var_indices),
pool_size, pool_size,
repr(pool_type), repr(pool_type),
name_attr, name_attr,
)) ))
fluid_op = 'pool{}d'.format(poolnd) fluid_op = 'pool{}d'.format(poolnd)
prog.VarDesc(var_y) prog.VarDesc(var_y)
if has_indices: if var_indices:
prog.VarDesc(var_indices) prog.VarDesc(var_indices)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'), ([var_y] + ([var_indices] if var_indices else []), 'Out', 'Indices'),
{ {
'global_pooling': False, 'global_pooling': False,
'adaptive': True, 'adaptive': True,
'exclusive': True, 'exclusive': True,
'require_index': has_indices, 'require_index': bool(var_indices),
'pooling_type': pool_type, 'pooling_type': pool_type,
'ksize': pool_size, 'ksize': pool_size,
}, },
...@@ -391,14 +367,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''): ...@@ -391,14 +367,12 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
# I/O # I/O
val_x, = inputs var_x, = inputs
val_y, = outputs var_y, = outputs
var_x = _make_var_name(val_x)
var_y = _make_var_name(val_y)
# interpretation # interpretation
input_shape = _shape_or_none(value_infos, val_x) input_shape = _shape_or_none(value_infos, var_x)
output_shape = _shape_or_none(value_infos, val_y) output_shape = _shape_or_none(value_infos, var_y)
assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC... assert input_shape is not None or output_shape is not None, 'poolnd not inferred' # NC...
if input_shape is not None: if input_shape is not None:
poolnd = len(input_shape) - 2 # NC... poolnd = len(input_shape) - 2 # NC...
...@@ -436,14 +410,8 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -436,14 +410,8 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
# I/O # I/O
val_x, = inputs var_x, = inputs
val_y, = outputs[:1] var_y, var_indices = (outputs + [None] * 1)[:2]
var_y = _make_var_name(val_y)
has_indices = len(outputs) > 1
if has_indices:
val_indices = outputs[1]
var_indices = _make_var_name(val_indices)
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -457,8 +425,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -457,8 +425,7 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
strides = attrs.get('strides', [1] * poolnd) # optional strides = attrs.get('strides', [1] * poolnd) # optional
ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional
pads = attrs.get('pads', [0] * (poolnd * 2)) # optional pads = attrs.get('pads', [0] * (poolnd * 2)) # optional
paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos) paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)
var_x = _make_var_name(val_x)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
# generation # generation
...@@ -481,17 +448,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -481,17 +448,17 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
name_attr, name_attr,
)) ))
prog.VarDesc(var_y) prog.VarDesc(var_y)
if has_indices: if var_indices:
prog.VarDesc(var_indices) prog.VarDesc(var_indices)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
([var_x], 'X'), ([var_x], 'X'),
([var_y] + ([var_indices] if has_indices else []), 'Out', 'Indices'), ([var_y] + ([var_indices] if var_indices else []), 'Out', 'Indices'),
{ {
'global_pooling': False, 'global_pooling': False,
'adaptive': False, 'adaptive': False,
'exclusive': True, 'exclusive': True,
'require_index': has_indices, 'require_index': bool(var_indices),
'pooling_type': pool_type, 'pooling_type': pool_type,
'ksize': pool_size, 'ksize': pool_size,
'strides': strides, 'strides': strides,
...@@ -503,11 +470,8 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -503,11 +470,8 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
# I/O # I/O
val_x, val_rois = inputs var_x, var_rois = inputs
val_y, = outputs var_y, = outputs
var_x = _make_var_name(val_x)
var_rois = _make_var_name(val_rois)
var_y = _make_var_name(val_y)
# interpretation # interpretation
spatial_scale = attrs['spatial_scale'] # required spatial_scale = attrs['spatial_scale'] # required
...@@ -536,7 +500,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): ...@@ -536,7 +500,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
'{})'.format( '{})'.format(
var_y, var_y,
fluid_op, fluid_op,
val_x, var_x,
var_rois, var_rois,
# attrs # attrs
spatial_scale, spatial_scale,
...@@ -546,7 +510,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): ...@@ -546,7 +510,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
)) ))
prog.VarDesc(var_y) prog.VarDesc(var_y)
if is_max_pool: if is_max_pool:
var_argmax = _make_var_name(name + '.argmax') # hidden variable var_argmax = name + '.argmax' # hidden variable
prog.VarDesc(var_argmax) prog.VarDesc(var_argmax)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
...@@ -558,19 +522,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): ...@@ -558,19 +522,17 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
# I/O # I/O
val_x, val_scales = inputs var_x, var_scales = inputs
val_y, = outputs var_y, = outputs
var_x = _make_var_name(val_x)
var_y = _make_var_name(val_y)
# interpretation # interpretation
# output shape # output shape
out_shape_ = _shape_or_none(value_infos, val_y) out_shape_ = _shape_or_none(value_infos, var_y)
if out_shape_ is not None: if out_shape_ is not None:
assert len(out_shape_) == 4, 'only 4-D Tensor as X and Y supported' assert len(out_shape_) == 4, 'only 4-D Tensor as X and Y supported'
out_shape_ = out_shape_[2:] out_shape_ = out_shape_[2:]
# try scales # try scales
scales = _const_weight_or_none(value_infos, val_scales) scales = _const_weight_or_none(value_infos, var_scales)
if scales is not None: if scales is not None:
assert len(scales) == 4, 'only 4-D Tensor as X and Y supported' assert len(scales) == 4, 'only 4-D Tensor as X and Y supported'
assert scales[0] == 1 and scales[ assert scales[0] == 1 and scales[
...@@ -585,7 +547,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): ...@@ -585,7 +547,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
else: else:
out_shape = None out_shape = None
if out_shape_ is None: if out_shape_ is None:
in_shape = _shape_or_none(value_infos, val_x) in_shape = _shape_or_none(value_infos, var_x)
assert in_shape is not None, 'out_shape required but not inferrable' assert in_shape is not None, 'out_shape required but not inferrable'
assert len(in_shape) == 4, 'only 4-D Tensor as X and Y supported' assert len(in_shape) == 4, 'only 4-D Tensor as X and Y supported'
out_shape_ = [in_shape[2] * scale, in_shape[3] * scale] out_shape_ = [in_shape[2] * scale, in_shape[3] * scale]
...@@ -642,10 +604,8 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -642,10 +604,8 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
""" """
# I/O # I/O
val_theta, = inputs var_theta, = inputs
val_grid, = outputs var_grid, = outputs
var_theta = _make_var_name(val_theta)
var_grid = _make_var_name(val_grid)
# interpretation # interpretation
fluid_op = 'affine_grid' fluid_op = 'affine_grid'
...@@ -701,10 +661,8 @@ def BatchNormalization(prog, ...@@ -701,10 +661,8 @@ def BatchNormalization(prog,
""" """
# I/O # I/O
val_x, val_scale, val_b, val_mean, val_var = inputs var_x, var_scale, var_b, var_mean, var_var = inputs
val_y, = outputs var_y, = outputs
var_x = _make_var_name(val_x)
var_y = _make_var_name(val_y)
var_saved_mean = name + '.saved_mean' # dummy output var_saved_mean = name + '.saved_mean' # dummy output
var_saved_variance = name + '.saved_variance' # dummy output var_saved_variance = name + '.saved_variance' # dummy output
...@@ -714,28 +672,28 @@ def BatchNormalization(prog, ...@@ -714,28 +672,28 @@ def BatchNormalization(prog,
epsilon = attrs.get('epsilon', 1e-5) # optional epsilon = attrs.get('epsilon', 1e-5) # optional
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
if embed_params: if embed_params:
embed_params = _check_embeddable(value_infos, val_scale, val_b, embed_params = _check_embeddable(value_infos, var_scale, var_b,
val_mean, val_var) var_mean, var_var)
if not embed_params and name: if not embed_params and name:
_logger.warning('for op %s(%s -> BatchNormalization -> %s)', name, _logger.warning('for op %s(%s -> BatchNormalization -> %s)', name,
inputs, outputs) inputs, outputs)
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
if embed_params: if embed_params:
assert name != '' assert name != ''
var_scale = name + '.w_0' embedded_scale = name + '.w_0'
var_b = name + '.b_0' embedded_b = name + '.b_0'
var_mean = name + '.w_1' embedded_mean = name + '.w_1'
var_var = name + '.w_2' embedded_var = name + '.w_2'
value_infos[val_scale]['embeded_as'].append(var_scale) value_infos[var_scale]['embedded_as'].append(embedded_scale)
value_infos[val_b]['embeded_as'].append(var_b) value_infos[var_b]['embedded_as'].append(embedded_b)
value_infos[val_mean]['embeded_as'].append(var_mean) value_infos[var_mean]['embedded_as'].append(embedded_mean)
value_infos[val_var]['embeded_as'].append(var_var) value_infos[var_var]['embedded_as'].append(embedded_var)
var_scale = embedded_scale
var_b = embedded_b
var_mean = embedded_mean
var_var = embedded_var
param_attr = '' param_attr = ''
else: else:
var_scale = _make_var_name(val_scale)
var_b = _make_var_name(val_b)
var_mean = _make_var_name(val_mean)
var_var = _make_var_name(val_var)
param_attr = (', param_attr={}, bias_attr={}' param_attr = (', param_attr={}, bias_attr={}'
', moving_mean_name={}, moving_variance_name={}').format( ', moving_mean_name={}, moving_variance_name={}').format(
repr(var_scale), repr(var_b), repr(var_mean), repr(var_scale), repr(var_b), repr(var_mean),
...@@ -780,16 +738,14 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -780,16 +738,14 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
""" """
# I/O # I/O
val_input, = inputs var_input, = inputs
val_output, = outputs var_output, = outputs
var_input = _make_var_name(val_input)
var_output = _make_var_name(val_output)
# interpretation # interpretation
dtype = attrs['to'] # required dtype = attrs['to'] # required
if not isinstance(dtype, _np.dtype): # additional: possible np.dtype if not isinstance(dtype, _np.dtype): # additional: possible np.dtype
dtype = TENSOR_TYPE_TO_NP_TYPE[dtype] dtype = TENSOR_TYPE_TO_NP_TYPE[dtype]
output_dtype = _dtype_or_none(value_infos, val_output) output_dtype = _dtype_or_none(value_infos, var_output)
if output_dtype is not None: if output_dtype is not None:
assert dtype == output_dtype, 'dtype of to unmatches output' assert dtype == output_dtype, 'dtype of to unmatches output'
...@@ -812,7 +768,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -812,7 +768,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
([var_output], 'Out'), ([var_output], 'Out'),
{ {
'in_dtype': prog.Dtype(_dtype(value_infos, 'in_dtype': prog.Dtype(_dtype(value_infos,
val_input)), # holy, required var_input)), # holy, required
'out_dtype': prog.Dtype(dtype), 'out_dtype': prog.Dtype(dtype),
}, },
) )
...@@ -824,9 +780,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -824,9 +780,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
""" """
# I/O # I/O
val_concat_result, = outputs var_ret, = outputs
var_inps = [_make_var_name(val) for val in inputs]
var_concat_result = _make_var_name(val_concat_result)
# interpretation # interpretation
fluid_op = 'concat' fluid_op = 'concat'
...@@ -837,18 +791,18 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -837,18 +791,18 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
', axis={}' ', axis={}'
'{})'.format( '{})'.format(
var_concat_result, var_ret,
fluid_op, fluid_op,
'[' + ', '.join(var_inps) + ']', '[' + ', '.join(inputs) + ']',
# attrs # attrs
axis, axis,
name_attr, name_attr,
)) ))
prog.VarDesc(var_concat_result) prog.VarDesc(var_ret)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(var_inps, *(['X'] * len(var_inps))), (inputs, *(['X'] * len(inputs))),
([var_concat_result], 'Out'), ([var_ret], 'Out'),
{'axis': axis}, {'axis': axis},
) )
...@@ -860,13 +814,12 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -860,13 +814,12 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
# I/O # I/O
assert len(inputs) == 0, 'constant op accept no inputs' assert len(inputs) == 0, 'constant op accept no inputs'
val_output, = outputs var_output, = outputs
var_output = _make_var_name(val_output)
# interpretation # interpretation
value = attrs['value'] # required value = attrs['value'] # required
dtype = _np.dtype(value.dtype) dtype = _np.dtype(value.dtype)
output_dtype = _dtype_or_none(value_infos, val_output) output_dtype = _dtype_or_none(value_infos, var_output)
if output_dtype is not None: if output_dtype is not None:
assert dtype == output_dtype, 'tensor dtype unmatches storage dtype' assert dtype == output_dtype, 'tensor dtype unmatches storage dtype'
...@@ -874,13 +827,13 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -874,13 +827,13 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
# dtype = _np.dtype('float32') # HINT: force to float32 # dtype = _np.dtype('float32') # HINT: force to float32
shape = attrs.get('shape', None) # shape = attrs.get('shape', None) #
if shape is None: if shape is None:
shape = _shape_or_none(value_infos, val_output) shape = _shape_or_none(value_infos, var_output)
if shape is None: if shape is None:
shape = list(value.shape) shape = list(value.shape)
_logger.warning( _logger.warning(
'in op (Constant -> %s): ' 'in op (Constant -> %s): '
'attribute "shape" of %s not inferred, ' 'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails', outputs, val_output) 'using value as 1-D tensor may lead to fails', outputs, var_output)
# generation # generation
value = value.tolist() value = value.tolist()
...@@ -911,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -911,7 +864,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
prog.Code('# {} = {} # passed directly as literal'.format( prog.Code('# {} = {} # passed directly as literal'.format(
var_output, value)) var_output, value))
value_infos[val_output]['const_value'] = value value_infos[var_output]['const_value'] = value
def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...@@ -920,13 +873,12 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -920,13 +873,12 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
""" """
# I/O # I/O
val_shape, = inputs var_shape, = inputs
val_output, = outputs var_output, = outputs
var_shape = _make_var_name(val_shape)
shape = _const_weight_or_none(value_infos, val_shape) shape = _const_weight_or_none(value_infos, var_shape)
if shape is None: if shape is None:
shape = _shape_or_none(value_infos, val_output) shape = _shape_or_none(value_infos, var_output)
assert shape is not None, ( assert shape is not None, (
'given shape is neither const value nor deductible from output, ' 'given shape is neither const value nor deductible from output, '
'this is not supported') 'this is not supported')
@@ -959,53 +911,47 @@ def Conv(prog,
     """

     # I/O
-    val_x, val_w = inputs[:2]
-    val_y, = outputs
-    var_y = _make_var_name(val_y)
-
-    has_bias = len(inputs) == 3
-    if has_bias:
-        val_b, = inputs[2:]
+    var_x, var_w = inputs[:2]
+    var_y, var_b = (outputs + [None] * 1)[:2]

     # interpretation
     assert attrs.get(
         'auto_pad', 'NOTSET'
     ) == 'NOTSET', 'only auto_pad == NOTSET is supported'  # optional
-    kernel_shape = _shape(value_infos, val_w)[2:]  # OI...
+    kernel_shape = _shape(value_infos, var_w)[2:]  # OI...
     assert kernel_shape == attrs[
         'kernel_shape'], 'kernel_shape in attr unmatches value_info'  # HW
     convnd = len(kernel_shape)
     assert 2 <= convnd <= 3, 'only conv2d and conv3d is supported'
-    num_out_channels = _shape(value_infos, val_w)[0]  # OI...
+    num_out_channels = _shape(value_infos, var_w)[0]  # OI...

     fluid_op = 'conv{}d'.format(convnd)
     num_groups = attrs.get('group', 1)  # optional
     strides = attrs.get('strides', [1] * convnd)  # optional
     dilations = attrs.get('dilations', [1] * convnd)  # optional
     pads = attrs.get('pads', [0] * (convnd * 2))  # optional
-    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    var_x = _make_var_name(val_x)
+    paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)

     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
-        embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
-                        or _check_embeddable(value_infos, val_b))
+        embed_params = (_check_embeddable(value_infos, var_w) and not var_b
+                        or _check_embeddable(value_infos, var_b))
         if not embed_params and name:
             _logger.warning('for op %s(%s -> Conv -> %s)', name, inputs,
                             outputs)
             _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
-        var_w = name + '.w_0'
-        value_infos[val_w]['embeded_as'].append(var_w)
-        if has_bias:
-            var_b = name + '.b_0'
-            value_infos[val_b]['embeded_as'].append(var_b)
+        embedded_w = name + '.w_0'
+        value_infos[var_w]['embedded_as'].append(embedded_w)
+        var_w = embedded_w
+        if var_b:
+            embedded_b = name + '.b_0'
+            value_infos[var_b]['embedded_as'].append(embedded_b)
+            var_b = embedded_b
             param_attr = ''
         else:
             param_attr = ', bias_attr=False'
     else:
-        var_w = _make_var_name(val_w)
-        var_b = _make_var_name(val_b) if has_bias else False
         param_attr = ', param_attr={}, bias_attr={}'.format(
             repr(var_w),
             repr(var_b) if var_b else False)
@@ -1036,7 +982,7 @@ def Conv(prog,
     prog.OpDesc(
         fluid_op,
         ([var_x, var_w], 'Input', 'Filter'),  # , 'Bias', 'ResidualData'
-        ([var_conv if has_bias else var_y], 'Output'),
+        ([var_conv if var_b else var_y], 'Output'),
         {
             'strides': strides,
             'paddings': paddings,
@@ -1044,13 +990,13 @@ def Conv(prog,
             'groups': num_groups,
         },
     )
-    if has_bias:
+    if var_b:
         prog.VarDesc(var_conv)
         prog.IntermediateOp(
             '',
             'Add',
             [var_conv, var_b],  #
-            [val_y],
+            [var_y],
             {'axis': 1},
             value_infos=value_infos,
             name=(name + '.bias'),
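For orientation: the `prog.Code` template that pairs with this desc sits outside the hunk, but the Python it generates looks roughly like the sketch below. All names and shapes here are made up for illustration; the keyword values mirror the `strides`/`paddings`/`dilations`/`groups` attributes captured above, and the bias, when present, becomes a separate elementwise Add op rather than a `bias_attr`.

```python
# a sketch of the generated call, not converter source; names are hypothetical
y = layers.conv2d(
    x,
    num_filters=64,           # from the filter shape, OI...
    filter_size=[3, 3],       # kernel_shape
    stride=[1, 1],
    padding=[1, 1],
    dilation=[1, 1],
    groups=1,
    param_attr='Conv_0.w_0',  # the embed_params naming scheme
    bias_attr=False,          # bias handled by a follow-up Add
    name='Conv_0')
```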
@@ -1073,13 +1019,8 @@ def ConvTranspose(prog,
     """

     # I/O
-    val_x, val_w = inputs[:2]
-    val_y, = outputs
-    var_y = _make_var_name(val_y)
-
-    has_bias = len(inputs) == 3
-    if has_bias:
-        val_b, = inputs[2:]
+    var_x, var_w = inputs[:2]
+    var_y, var_b = (outputs + [None] * 1)[:2]

     # interpretation
     assert attrs.get(
@@ -1088,41 +1029,40 @@ def ConvTranspose(prog,
     assert sum(attrs.get(
         'output_padding',
         [])) == 0, 'only zero output_padding is supported'  # optional ?
-    kernel_shape = _shape(value_infos, val_w)[2:]  # IO...
+    kernel_shape = _shape(value_infos, var_w)[2:]  # IO...
     assert kernel_shape == attrs[
         'kernel_shape'], 'kernel_shape in attr unmatches value_info'  # HW
     convnd = len(kernel_shape)
     assert 2 <= convnd <= 3, 'only conv2d_transpose and conv3d_transpose is supported'
-    num_out_channels = _shape(value_infos, val_w)[1]  # IO...
+    num_out_channels = _shape(value_infos, var_w)[1]  # IO...

     fluid_op = 'conv{}d_transpose'.format(convnd)
     num_groups = attrs.get('group', 1)  # optional
     strides = attrs.get('strides', [1] * convnd)  # optional
     dilations = attrs.get('dilations', [1] * convnd)  # optional
     pads = attrs.get('pads', [0] * (convnd * 2))  # optional
-    paddings, val_x = _pad_if_asymmetric(prog, pads, val_x, value_infos)
-    var_x = _make_var_name(val_x)
+    paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos)

     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
-        embed_params = (_check_embeddable(value_infos, val_w) and not has_bias
-                        or _check_embeddable(value_infos, val_b))
+        embed_params = (_check_embeddable(value_infos, var_w) and not var_b
+                        or _check_embeddable(value_infos, var_b))
         if not embed_params and name:
             _logger.warning('for op %s(%s -> ConvTranspose -> %s)', name,
                             inputs, outputs)
             _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
-        var_w = name + '.w_0'
-        value_infos[val_w]['embeded_as'].append(var_w)
-        if has_bias:
-            var_b = name + '.b_0'
-            value_infos[val_b]['embeded_as'].append(var_b)
+        embedded_w = name + '.w_0'
+        value_infos[var_w]['embedded_as'].append(embedded_w)
+        var_w = embedded_w
+        if var_b:
+            embedded_b = name + '.b_0'
+            value_infos[var_b]['embedded_as'].append(embedded_b)
+            var_b = embedded_b
             param_attr = ''
         else:
             param_attr = ', bias_attr=False'
     else:
-        var_w = _make_var_name(val_w)
-        var_b = _make_var_name(val_b) if has_bias else False
         param_attr = ', param_attr={}, bias_attr={}'.format(
             repr(var_w),
             repr(var_b) if var_b else False)
@@ -1154,7 +1094,7 @@ def ConvTranspose(prog,
     prog.OpDesc(
         fluid_op,
         ([var_x, var_w], 'Input', 'Filter'),  # , 'Bias', 'ResidualData'
-        ([var_conv if has_bias else var_y], 'Output'),
+        ([var_conv if var_b else var_y], 'Output'),
         {
             'strides': strides,
             'paddings': paddings,
@@ -1163,13 +1103,13 @@ def ConvTranspose(prog,
             'groups': num_groups,
         },
     )
-    if has_bias:
+    if var_b:
         prog.VarDesc(var_conv)
         prog.IntermediateOp(
             '',
             'Add',
             [var_conv, var_b],  #
-            [val_y],
+            [var_y],
             {'axis': 1},
             value_infos=value_infos,
             name=(name + '.bias'),
@@ -1184,27 +1124,27 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     """

     # due to fluid fc don't support transposed weight, we use matmul + ew_add
-    val_a, val_b, val_c = inputs
-    val_y, = outputs
+    var_a, var_b, var_c = inputs
+    var_y, = outputs

     alpha = attrs.get('alpha', 1.)  # optional
     beta = attrs.get('beta', 1.)  # optional
     trans_a = bool(attrs.get('transA', 0))  # optional
     trans_b = bool(attrs.get('transB', 0))  # optional

-    val_mm = name + '_mm'  # explicit variable
+    var_mm = name + '_mm'  # explicit variable
     prog.Op(
         '',
         'MatMul',
-        [val_a, val_b],
-        [val_mm],  # val
+        [var_a, var_b],
+        [var_mm],  # val
         {
             'transpose_x': trans_a,
             'transpose_y': trans_b,
             'alpha': alpha,
         },
         value_infos=value_infos,
-        name=val_mm,
+        name=var_mm,
     )
     prog.op_descs[-1].attrs.extend(
         prog.OpDescAttrs({
@@ -1216,17 +1156,17 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         prog.Op(
             '',
             'Add',
-            [val_mm, val_c],
-            [val_y],  # val
+            [var_mm, var_c],
+            [var_y],  # val
             {'axis': 1},
             value_infos=value_infos,
             name=(name + '_beta'),
         )
     else:
-        val_beta = name + '_beta'  # explicit variable
-        val_vm = name + '_vm'  # explicit variable
+        var_beta = name + '_beta'  # explicit variable
+        var_vm = name + '_vm'  # explicit variable
         if beta.is_integer():
-            vm_dtype = _dtype_or_none(value_infos, val_c)
+            vm_dtype = _dtype_or_none(value_infos, var_c)
             if vm_dtype is None:
                 vm_dtype = _np.dtype('float32')
                 _logger.warning(
@@ -1239,16 +1179,16 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
             '',
             'Constant',
             [],
-            [val_beta],  # val
+            [var_beta],  # val
             {'value': beta},
             value_infos=value_infos,
-            name=val_beta,
+            name=var_beta,
         )
         prog.Op(
             '',
             'Mul',
-            [val_c, val_beta],
-            [val_vm],  # val
+            [var_c, var_beta],
+            [var_vm],  # val
             dict(),
             value_infos=value_infos,
             name=(name + '_scale'),
@@ -1256,8 +1196,8 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         prog.Op(
             '',
             'Add',
-            [val_mm, val_vm],
-            [val_y],  # val
+            [var_mm, var_vm],
+            [var_y],  # val
             {'axis': 1},
             name=(name + '_bias'),
         )
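The decomposition above implements ONNX `Gemm(A, B, C) = alpha * op(A) @ op(B) + beta * C` as a MatMul that absorbs `alpha`, an optional scale of `C` by `beta`, and an elementwise Add. A quick NumPy check of that algebra (illustration only, not converter code):

```python
import numpy as np

def gemm_reference(a, b, c, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    # ONNX Gemm semantics
    a = a.T if trans_a else a
    b = b.T if trans_b else b
    return alpha * (a @ b) + beta * c

def gemm_decomposed(a, b, c, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    # matmul absorbing alpha, then scale C by beta, then elementwise add --
    # the same three ops the converter emits
    mm = alpha * ((a.T if trans_a else a) @ (b.T if trans_b else b))
    vm = beta * c
    return mm + vm

a, b, c = np.random.rand(2, 3), np.random.rand(4, 3), np.random.rand(2, 4)
assert np.allclose(gemm_reference(a, b, c, 0.5, 2.0, False, True),
                   gemm_decomposed(a, b, c, 0.5, 2.0, False, True))
```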
@@ -1305,6 +1245,64 @@ def GlobalMaxPool(prog,
         name=name)
+def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
+    """
+    onnx::GRU-7:
+    """
+
+    var_x, var_w, var_r, var_b, var_len, var_xh = (inputs + [None] * 3)[:6]
+    var_y, var_yh = (outputs + [None] * 2)[:2]
+
+    # interpretation
+    fluid_op = 'gru_unit'
+    param_attr = ''
+
+    # generation
+    prog.Code('{}, _, {} = layers.{}({}, {}, {}'
+              '{})'.format(
+                  var_yh,
+                  var_y,
+                  fluid_op,
+                  var_x,
+                  var_xh,
+                  0,
+                  param_attr,
+              ))
+
+    # raise NotImplementedError()
+
+
+def LSTM(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
+    """
+    onnx::LSTM-7:
+    """
+
+    var_x, var_w, var_r, var_b, var_len, var_xh, var_xc, var_p = (
+        inputs + [None] * 5)[:8]
+    var_y, var_yh, var_yc = (outputs + [None] * 3)[:3]
+
+    # interpretation
+    fluid_op = 'lstm_unit'
+    param_attr = ''
+
+    # generation
+    prog.Code('{}, {}, {} = layers.{}({}, {}, {}'
+              '{})'.format(
+                  var_y,
+                  var_yh,
+                  var_yc,
+                  fluid_op,
+                  var_x,
+                  var_xh,
+                  var_xc,
+                  param_attr,
+              ))
+
+    # raise NotImplementedError()
+
+
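Both new symbols use the same right-padding idiom to cope with ONNX's trailing optional inputs and outputs, and both read as placeholders for now: the emitted `gru_unit`/`lstm_unit` calls ignore the `W`, `R` and `B` weights, and the commented-out `NotImplementedError` marks them as provisional. The idiom in isolation, with hypothetical values:

```python
# ONNX nodes may omit trailing optional slots, so the I/O lists are
# right-padded with None before unpacking (names here are hypothetical)
inputs = ['X', 'W', 'R']  # B, sequence_lens and initial_h omitted
var_x, var_w, var_r, var_b, var_len, var_xh = (inputs + [None] * 3)[:6]
assert (var_b, var_len, var_xh) == (None, None, None)
```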
 def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args,
             **kwargs):
     """
@@ -1329,17 +1327,15 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
     """

     # I/O
-    val_data, = inputs
-    val_output, = outputs
-    var_data = _make_var_name(val_data)
-    var_output = _make_var_name(val_output)
+    var_data, = inputs
+    var_output, = outputs

     # interpretation
     pads = attrs['pads']  # required
     mode = attrs.get('mode', 'constant')  # optional
     value = attrs.get('value', 0.)  # optional
-    data_shape = _shape_or_none(value_infos, val_data)
-    output_shape = _shape_or_none(value_infos, val_output)
+    data_shape = _shape_or_none(value_infos, var_data)
+    output_shape = _shape_or_none(value_infos, var_output)
     assume_pad2d = False
     if len(pads) == 4:
         assume_pad2d |= mode != 'constant'
@@ -1400,14 +1396,12 @@ def PRelu(prog,
     """

     # I/O
-    val_x, val_slope = inputs
-    val_y, = outputs
-    var_x = _make_var_name(val_x)
-    var_y = _make_var_name(val_y)
+    var_x, var_slope = inputs
+    var_y, = outputs

     # interpretation
     mode = 'channel'
-    slope_shape = _shape_or_none(value_infos, val_slope)
+    slope_shape = _shape_or_none(value_infos, var_slope)
     if slope_shape is not None:
         if len(slope_shape) == 0:
             mode = 'all'
@@ -1418,18 +1412,18 @@ def PRelu(prog,
     fluid_op = 'prelu'
     name_attr = ', name={}'.format(repr(name)) if name else ''
     if embed_params:
-        embed_params = _check_embeddable(value_infos, val_slope)
+        embed_params = _check_embeddable(value_infos, var_slope)
         if not embed_params and name:
             _logger.warning('for op %s(%s -> PRelu -> %s)', name, inputs,
                             outputs)
             _logger.warning('broken Python code will be generated')
     if embed_params:
         assert name != ''
-        var_slope = name + '.w_0'
-        value_infos[val_slope]['embeded_as'].append(var_slope)
+        embedded_slope = name + '.w_0'
+        value_infos[var_slope]['embedded_as'].append(embedded_slope)
+        var_slope = embedded_slope
         param_attr = ''
     else:
-        var_slope = _make_var_name(val_slope)
         param_attr = ', param_attr={}'.format(repr(var_slope))

     # generation
@@ -1467,17 +1461,14 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     """

     # I/O
-    val_data, val_shape = inputs
-    val_reshaped, = outputs
-    var_data = _make_var_name(val_data)
-    var_shape = _make_var_name(val_shape)
-    var_reshaped = _make_var_name(val_reshaped)
+    var_data, var_shape = inputs
+    var_reshaped, = outputs

     # interpretation
-    shape = _const_weight_or_none(value_infos, val_shape)
-    is_const_shape = shape and 'const_value' in value_infos[val_shape]
+    shape = _const_weight_or_none(value_infos, var_shape)
+    is_const_shape = shape and 'const_value' in value_infos[var_shape]
     if shape is None:
-        shape = _shape_or_none(value_infos, val_reshaped)
+        shape = _shape_or_none(value_infos, var_reshaped)
 #    assert shape is not None, ('given shape is neither const value nor deductible from output, '
@@ -1493,8 +1484,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     name_attr = ', name={}'.format(repr(name)) if name else ''

     # generation
-    val_shape_int32 = val_shape + '_int32'  # explicit variable
-    var_shape_int32 = _make_var_name(val_shape_int32)
+    var_shape_int32 = var_shape + '_int32'  # explicit variable
     prog.Code('# shape:{}={} # const as literal'.format(var_shape, shape))
     if is_const_shape:
         prog.Code('{} = layers.{}({}'
@@ -1511,8 +1501,8 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
         prog.Op(
             '',
             'Cast',
-            [val_shape],
-            [val_shape_int32],  # var
+            [var_shape],
+            [var_shape_int32],  # var
             {'to': _np.dtype('int32')},  # use np.dtype
             value_infos=value_infos,
             name=(name + '_cast'),
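When the target shape is only known at run time, the hunk above first inserts a Cast, because ONNX carries shape tensors as int64 while the fluid `reshape` layer expects int32. In generated-code terms the non-const branch amounts to roughly the following sketch (variable names hypothetical, literal shape made up):

```python
# sketch of generated code for a runtime shape, not converter source
shape_int32 = layers.cast(shape_var, dtype='int32')  # the inserted Cast op
reshaped = layers.reshape(data_var,
                          shape=[-1, 0],             # inferred literal shape
                          actual_shape=shape_int32)  # runtime shape tensor
```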
@@ -1595,17 +1585,15 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
     """

     # I/O
-    val_data, = inputs
-    val_output, = outputs
-    var_data = _make_var_name(val_data)
-    var_output = _make_var_name(val_output)
+    var_data, = inputs
+    var_output, = outputs

     # interpretation
     fluid_op = 'slice'
     axes = attrs['axes']  # required
     starts = attrs['starts']  # required
     ends = attrs['ends']  # required
-    shape = _shape_or_none(value_infos, val_data)
+    shape = _shape_or_none(value_infos, var_data)
     if shape is not None:
 #        ndims = len(shape)
 #        for idx, value in enumerate(axes):
@@ -1654,9 +1642,7 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """

     # I/O
-    val_input, = inputs
-    var_outs = [_make_var_name(val) for val in outputs]
-    var_input = _make_var_name(val_input)
+    var_input, = inputs

     # interpretation
     fluid_op = 'split'
@@ -1668,7 +1654,7 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     prog.Code('{} = layers.{}({}, {}'
               ', dim={}'
               '{})'.format(
-                  ', '.join(var_outs),
+                  ', '.join(outputs),
                   fluid_op,
                   var_input,
                   split,
@@ -1676,12 +1662,12 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
                   axis,
                   name_attr,
               ))
-    for var_out in var_outs:
+    for var_out in outputs:
         prog.VarDesc(var_out)
     prog.OpDesc(
         fluid_op,
         (var_input, 'X'),
-        ([var_outs], *(['Out'] * len(var_outs))),
+        ([outputs], *(['Out'] * len(outputs))),
         {
             'axis': axis,
             'sections': split,
@@ -1695,9 +1681,7 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
     """

     # I/O
-    val_sum, = outputs
-    var_inps = [_make_var_name(val) for val in inputs]
-    var_sum = _make_var_name(val_sum)
+    var_sum, = outputs

     # interpretation
     fluid_op = 'sums'
@@ -1706,14 +1690,14 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
     prog.Code('{} = layers.{}({})'.format(
         var_sum,
         fluid_op,
         '[' + ', '.join(inputs) + ']',
         # attrs
     ))
     fluid_op = 'sum'
     prog.VarDesc(var_sum)
     prog.OpDesc(
         fluid_op,
-        (var_inps, *(['X'] * len(var_inps))),
+        (inputs, *(['X'] * len(inputs))),
         ([var_sum], 'Out'),
         dict(),
     )
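Sum is one of the few direct mappings: the generated code is a single `layers.sums` call over the input list, while the desc side renames to the `sum` op with every input bound to the `X` slot. The emitted call, roughly (names hypothetical):

```python
# sketch of the generated call for an ONNX Sum over three inputs
y = layers.sums([a, b, c])
```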
@@ -1725,14 +1709,11 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
     """

     # I/O
-    val_input, val_repeats = inputs
-    val_output, = outputs
-    var_input = _make_var_name(val_input)
-    var_repeats = _make_var_name(val_repeats)
-    var_output = _make_var_name(val_output)
+    var_input, var_repeats = inputs
+    var_output, = outputs

     # interpretation
-    repeats = _const_weight_or_none(value_infos, val_repeats)
+    repeats = _const_weight_or_none(value_infos, var_repeats)
     assert repeats is not None, 'only const repeats is supported'
     fluid_op = 'expand'
     name_attr = ', name={}'.format(repr(name)) if name else ''
@@ -1764,10 +1745,8 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
     """

     # I/O
-    val_data, = inputs
-    val_transposed, = outputs
-    var_data = _make_var_name(val_data)
-    var_transposed = _make_var_name(val_transposed)
+    var_data, = inputs
+    var_transposed, = outputs

     # interpretation
     fluid_op = 'transpose'
...
@@ -6,6 +6,9 @@ Created on Fri Mar 22 11:22:46 2019
 @author: Macrobull
 """

+from __future__ import division
+
+import logging
 import numpy as np
 import torch
@@ -24,6 +27,8 @@ from typing import (
     Union,
 )

+logger = logging.getLogger(__name__)
+
 __all__ = [
     'export_data',
     'export_onnx_with_validation',
@@ -76,7 +81,7 @@ def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
         return str(obj)

     prefix_ = prefix + ('_' if prefix else '')
-    fp = open('{}.txt'.format(prefix or 'meta'), 'w')
+    fp = open('{}.txt'.format(prefix or 'meta'), mode='w')
     for key, value in state_dict.items():
         data = None
         if torch.is_tensor(value):
@@ -93,7 +98,7 @@ def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
 def export_onnx_with_validation(
-        model: torch.nn.Module,
+        model: torch.nn.Module,  # or JITScriptModule
         inputs: Sequence[Union[torch.Tensor, Sequence[object]]],
         export_basepath: Text,
         input_names: Optional[List[Text]] = None,
...
@@ -43,7 +43,8 @@ def fluid_prog_shape_infer(prog):
     import paddle.fluid as fluid

-    assert isinstance(prog, fluid.framework.Program)
+    assert isinstance(prog,
+                      fluid.framework.Program), 'prog is not a Program instance'

     logger.info('performing type-shape inference ...')
     for block in prog.blocks:
@@ -84,6 +85,8 @@ def validate(fluid_model_filename,
     inference the converted Paddle fluid model, validate with given golden data
     """

+    assert isinstance(fluid_model_filename, str)
+
     import numpy as np
     import paddle.fluid as fluid
@@ -153,7 +156,7 @@ def validate(fluid_model_filename,
         input_data = flatten_dict(input_data)
         output_data = flatten_dict(output_data)
         input_names = input_data.keys()
-        output_names = output_data.keys()
+        # output_names = output_data.keys()
         logger.info('with %d inputs and %d outputs', len(input_data),
                     len(output_data))
     else:
...
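`flatten_dict` itself is outside these hunks; the semantics `validate` relies on are just "collapse nested dicts of arrays into flat name-to-array pairs", something like the sketch below. This is an assumption for illustration, not the repository's helper, which may join keys differently:

```python
# a minimal sketch of the flattening validate() assumes; the actual
# helper lives elsewhere in onnx2fluid and may differ in detail
def flatten_dict(d, prefix=''):
    flat = {}
    for key, value in d.items():
        full_key = '{}.{}'.format(prefix, key) if prefix else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, full_key))
        else:
            flat[full_key] = value
    return flat
```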
@@ -16,7 +16,6 @@ from collections import OrderedDict as Dict
 logger = logging.getLogger(__name__)

 from . import symbolic
-from .symbolic import _make_var_name as make_var_name

 try:
     import paddle.fluid.proto.framework_pb2 as framework_pb2
@@ -63,7 +62,7 @@ def make_attr_name(name):
     assert name != '', 'name should not be empty'
-    for s in ' \\|/:-':  #
+    for s in ' \\|/:.-':  #
         name = name.replace(s, '_')
     if not name.startswith('_'):
         name = '_' + name
@@ -207,7 +206,7 @@ class Program(object):
         return desc

     def VarDesc(self,
-                var_name,
+                name,
                 persistable=False,
                 value_info=None,
                 remove_batch=None):
@@ -215,18 +214,16 @@ class Program(object):
         add VarDesc,
         """

-        assert var_name not in self.var_descs, 'var naming conflicted'
+        assert name not in self.var_descs, 'var naming conflicted'
         var_desc = framework_pb2.VarDesc()
-        var_desc.name = var_name
+        var_desc.name = name
         var_desc.persistable = persistable
         var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
-        self.var_descs[var_name] = var_desc
+        self.var_descs[name] = var_desc

         if value_info:
-            self.VarTypeShapeInfo(var_name,
-                                  value_info,
-                                  remove_batch=remove_batch)
+            self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch)

     def Op(self, domain, op_type, *args, **kwargs):
         """
@@ -260,19 +257,19 @@ class Program(object):
         else:
             self.code_mutable = code_mutable

-    def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
+    def VarTypeShapeInfo(self, name, value_info, remove_batch=None):
         """
         set value_info for var
         """

-        if var_name not in self.var_descs:
+        if name not in self.var_descs:
             return

         dtype = value_info.get('dtype', None)
         if dtype is None:
             return

-        var_desc = self.var_descs[var_name]
+        var_desc = self.var_descs[name]
         tensor_desc = var_desc.type.lod_tensor.tensor
         tensor_desc.data_type = self.Dtype(dtype)  # required
@@ -292,8 +289,7 @@ class Writer(object):
     fluid code and desc writter
     """

-#    CODE_INDENT = ' ' * 4
-    CODE_INDENT = '\t'
+    CODE_INDENT = ' ' * 4  # '\t'

     @staticmethod
     def header_code(func_name, info=''):
@@ -313,6 +309,7 @@ class Writer(object):
         codes.append('from paddle.fluid import initializer, layers')
         codes.append('')
         codes.append('')
+
         codes.append('def {}():'.format(func_name))
         return codes
@@ -342,24 +339,26 @@ class Writer(object):
         emit an ONNX weight into program
         """

-        if value_info.get('embeded_as', []):
-            var_names = value_info['embeded_as']
-            prog.Code('# parameter {} embeded as {}'.format(name, var_names))
-            for var_name in var_names:
-                prog.VarDesc(var_name, persistable=True, value_info=value_info)
+        if value_info.get('embedded_as', []):
+            embedded_names = value_info['embedded_as']
+            prog.Code('# parameter {} embedded as {}'.format(
+                name, embedded_names))
+            for embedded_name in embedded_names:
+                prog.VarDesc(embedded_name,
+                             persistable=True,
+                             value_info=value_info)
         else:
-            var_name = make_var_name(name)
             attr_name = make_attr_name(name)
-            prog.Code('# parameter {}: {}'.format(name, var_name))
+            prog.Code('# parameter {}'.format(name))
             prog.Code('{} = ParamAttr(name={})'  # , trainable=True
-                      .format(attr_name, repr(var_name)))
+                      .format(attr_name, repr(name)))
             prog.Code(
                 '{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
                 ', default_initializer=initializer.Constant(0))'  #, is_bias={}
-                .format(var_name, value_info['shape'],
+                .format(name, value_info['shape'],
                         repr(value_info['dtype'].name), repr(name),
                         attr_name))  #, value_info.get('is_bias', False)))
-            prog.VarDesc(var_name, persistable=True, value_info=value_info)
+            prog.VarDesc(name, persistable=True, value_info=value_info)

     @staticmethod
     def emit_inputs(prog, names, value_infos, remove_batch=None):
@@ -368,7 +367,6 @@ class Writer(object):
         """

         for idx, name in enumerate(names):
-            var_name = make_var_name(name)
             value_info = value_infos[name]
             shape = value_info['shape']
             if remove_batch is None:
@@ -377,13 +375,13 @@ class Writer(object):
             if remove_batch:
                 shape = shape[1:]

-            prog.Code('# input {}: {}'.format(name, var_name))
+            prog.Code('# input {}'.format(name))
             prog.Code((
                 '{} = layers.data(name={}, shape={}, dtype={}, '
                 'append_batch_size={})'  # , stop_gradient=True
             ).format(
-                var_name,
-                repr(var_name),
+                name,
+                repr(name),
                 shape,
                 repr(value_info['dtype'].name),
                 remove_batch,
@@ -391,12 +389,10 @@ class Writer(object):
             prog.OpDesc(
                 'feed',
                 (['feed'], 'X'),
-                ([var_name], 'Out'),
+                ([name], 'Out'),
                 {'col': idx},
             )
-            prog.VarDesc(var_name,
-                         value_info=value_info,
-                         remove_batch=remove_batch)
+            prog.VarDesc(name, value_info=value_info, remove_batch=remove_batch)

     @staticmethod
     def emit_outputs(prog, names):  #, value_infos
@@ -406,12 +402,11 @@ class Writer(object):
         code = 'return '
         for idx, name in enumerate(names):
-            var_name = make_var_name(name)
-            code += var_name + ', '
+            code += name + ', '
             prog.OpDesc(
                 'fetch',
-                ([var_name], 'X'),
+                ([name], 'X'),
                 (['fetch'], 'Out'),
                 {'col': idx},
             )
@@ -458,8 +453,7 @@ class Writer(object):
         for name, weight in weights.items():
             assert isinstance(weights, dict), 'dict type weights required'
-            var_name = make_var_name(name)
-            filename = os.path.join(save_dir, var_name)
+            filename = os.path.join(save_dir, name)
             Writer.write_weight(weight, filename)
             logger.debug('saved weight %s to %s', name, filename)
...
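With `_make_var_name` gone, generated code, feed/fetch descs and weight files all key on the original ONNX tensor names, which implicitly assumes those names are valid Python identifiers. For a float32 input named `x` of shape `[1, 3, 224, 224]` with `remove_batch=True`, `emit_inputs` now produces roughly the following (output reconstructed by hand from the templates above):

```python
# input x
x = layers.data(name='x', shape=[3, 224, 224], dtype='float32',
                append_batch_size=True)
```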
 -e .
 onnx>=1.4
-paddlepaddle
+paddlepaddle>=1.5