Commit 2a82fdeb, authored by Macrobull

fix name and add ops

Parent 9828c2c7
@@ -17,13 +17,13 @@ onnx2fluid supports converting ONNX models into PaddlePaddle models for inference ...
 Tested and working with the following environment:
 * python 3.5+
-* onnx == 1.4.0
-* paddlepaddle == 1.3.0 (optional, only for validation)
+* onnx == 1.4.1
+* paddlepaddle == 1.5.0 (optional, only for validation)
 With [Anaconda](https://docs.anaconda.com/anaconda/install):
 ``` shell
 conda install -c conda-forge onnx
-pip install paddlepaddle==1.3.0
+pip install paddlepaddle==1.5.0
 ```
 ## Try it out
@@ -49,6 +49,8 @@ onnx2fluid sample_1.onnx -t sample_1.npz
 ## Usage
+Some operators of **ONNX opset 9+** are currently supported, corresponding to PyTorch **1.0/1.1 (stable opset)**; for more compatibility information see the [ONNX documentation](https://github.com/onnx/onnx/blob/master/docs/Operators.md).
 onnx2fluid:
 ```shell
@@ -79,5 +81,5 @@ onnx2fluid.validate [-d] [-t test_data.npz] [-p 1e-3] /path/to/onnx/model.onnx
 ## References
-* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.4/api_cn/layers_cn.html)
-* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.4/api_guides/low_level/inference.html#id4)
+* PaddlePaddle [operators](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_cn/layers_cn.html)
+* PaddlePaddle [loading an inference model](http://www.paddlepaddle.org/documentation/docs/zh/1.5/api_guides/low_level/inference.html#id4)
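Once converted, the output directory can be loaded back through the fluid inference API referenced above. A minimal sketch, assuming the converted output follows the standard fluid inference-model layout; `sample_1.paddle` and the input shape `(2, 3, 4)` are placeholders:

```python
import numpy as np
import paddle.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# 'sample_1.paddle' is a placeholder for the converter's output directory.
program, feed_names, fetch_targets = fluid.io.load_inference_model(
    'sample_1.paddle', exe)
outputs = exe.run(program,
                  feed={feed_names[0]: np.random.rand(2, 3, 4).astype('float32')},
                  fetch_list=fetch_targets)
print([o.shape for o in outputs])
```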
@@ -20,50 +20,56 @@ from onnx2fluid.torch_export_helper import export_onnx_with_validation
 prefix = 'sample_'
 idx = 0
-######### example: RNN ########
-#
-#class Model(nn.Module):
-#    def __init__(self):
-#        super(Model, self).__init__()
-#        self.rnn = nn.RNN(4, 6, 2)
-#
-#    def forward(self, x):
-#        y = x
-#        y, h = self.rnn(y)
-#        return y
-#
-#
-#model = Model()
-#model.eval()
-#xb = torch.rand((2, 3, 4))
-#yp = model(xb)
-#idx += 1
-#print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
-######### example: random ########
-#
-#class Model(nn.Module):
-#    def __init__(self):
-#        super(Model, self).__init__()
-#
-#    def forward(self, x):
-#        y = torch.rand((2, 3))  # + torch.rand_like(xb)
-#        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
-#        return y
-#
-#model = Model()
-#model.eval()
-#xb = torch.rand((2, 3))
-#yp = model(xb)
-#idx += 1
-#print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
+######## example: RNN ########
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+        self.gru = nn.GRU(4, 5, 3)
+        self.lstm = nn.LSTM(5, 6, 2)
+
+    def forward(self, x):
+        y = x
+        y, h = self.gru(y)
+        y, h = self.lstm(y)
+        return y
+
+
+model = Model()
+model.eval()
+xb = torch.rand((2, 3, 4))
+yp = model(xb)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb],
+                            prefix + str(idx), ['x'], ['y'],
+                            verbose=True,
+                            training=False)
+
+######## example: random ########
+class Model(nn.Module):
+    def __init__(self):
+        super(Model, self).__init__()
+
+    def forward(self, x):
+        y = torch.rand((2, 3))  # + torch.rand_like(xb)
+        y = y + torch.randn((2, 3))  # + torch.randn_like(xb)
+        return y
+
+
+model = Model()
+model.eval()
+xb = torch.rand((2, 3))
+yp = model(xb)
+idx += 1
+print('index: ', idx)
+export_onnx_with_validation(model, [xb],
+                            prefix + str(idx), ['x'], ['y'],
+                            verbose=True,
+                            training=False)
 ######## example: fc ########
@@ -175,7 +181,7 @@ class Model(nn.Module):
         super(Model, self).__init__()
         self.conv = nn.Conv2d(3, 8, 3)
         self.batch_norm = nn.BatchNorm2d(8)
-        self.pool = nn.AdaptiveAvgPool2d(2)
+        self.pool = nn.AdaptiveAvgPool2d(1)
     def forward(self, x):
         y = x
@@ -215,9 +221,10 @@ export_onnx_with_validation(model, [xb],
 #yp = model(xb)
 #idx += 1
 #print('index: ', idx)
-#export_onnx_with_validation(model, [xb], prefix + str(idx),
-#                            ['x'], ['y'],
-#                            verbose=True, training=False)
+#export_onnx_with_validation(
+#    model, [xb], prefix + str(idx),
+#    ['x'], ['y'],
+#    verbose=True, training=False)
 ######## example: empty ########
......
@@ -24,14 +24,14 @@ bvlc_alexnet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
    do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1
@@ -54,7 +54,7 @@ bvlc_googlenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1
@@ -77,7 +77,7 @@ bvlc_reference_caffenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1
@@ -100,7 +100,7 @@ bvlc_reference_rcnn_ilsvrc13()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1
@@ -123,14 +123,14 @@ densenet121()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 fc6_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc6_1
@@ -153,7 +153,7 @@ emotion_ferplus()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0
@@ -176,14 +176,14 @@ inception_v1()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1
@@ -206,14 +206,14 @@ inception_v2()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
    do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1
@@ -236,7 +236,7 @@ mobilenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0
@@ -259,7 +259,7 @@ resnet18()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd
@@ -282,14 +282,14 @@ resnet50()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for npz in "$bn_tar"/*.npz
+    for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
@@ -312,7 +312,7 @@ resnet100_arcface()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data fc1
@@ -335,7 +335,7 @@ resnet101_duc()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data seg_loss
@@ -358,7 +358,7 @@ resnet152()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd
@@ -381,7 +381,7 @@ shufflenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
@@ -404,7 +404,7 @@ squeezenet()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 softmaxout_1
@@ -427,7 +427,7 @@ squeezenet1v1()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0
@@ -448,10 +448,10 @@ ssd()
     rm -rf "$bn_tar/"
     echo "extracting ..."
     mkdir "$bn_tar"
-    tar xf "$fn_tar" -C "$bn_tar"/
+    tar xf "$fn_tar" -C "$bn_tar/"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" image bboxes,labels,scores
@@ -474,7 +474,7 @@ tiny_yolov2()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" image grid
@@ -497,7 +497,7 @@ vgg16bn()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -y
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd
@@ -520,7 +520,7 @@ vgg19()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1
@@ -543,7 +543,7 @@ yolov3()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model" -x #
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
@@ -566,7 +566,7 @@ zfnet512()
     tar xf "$fn_tar"
     python -m onnx2fluid $convert_flags "$fn_model"
-    for pb_dir in "$bn_tar"/*/
+    for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
......
@@ -17,6 +17,22 @@ __all__ = [
 DEFAULT_ONNX_OPSET_VERSION = 9
+def make_var_name(name):
+    """
+    make a valid variable name in Python code and filename in filesystem
+    """
+    if name == '':
+        return '_'
+    if name[0].isdigit():
+        return 'var_' + name
+    for s in ' \\|/:.-':
+        name = name.replace(s, '_')
+    if name.startswith('_'):
+        name = 'var' + name
+    return name
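A few hand-worked examples of the sanitization rule implemented by `make_var_name` above; the inputs are hypothetical ONNX value names, in the style of the `gpu_0/data_0` names used by the model-zoo script:

```python
assert make_var_name('') == '_'                          # empty name maps to '_'
assert make_var_name('7_out') == 'var_7_out'             # leading digit gets a 'var_' prefix
assert make_var_name('gpu_0/data_0') == 'gpu_0_data_0'   # ' \|/:.-' characters become '_'
assert make_var_name('_scale') == 'var_scale'            # a leading '_' is prefixed with 'var'
```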
 def convert(onnx_model_filename,
             save_dir,
             model_basename='model.py',
@@ -30,6 +46,12 @@ def convert(onnx_model_filename,
     convert an ONNX model to Paddle fluid Python code and desc pb
     """
+    assert isinstance(onnx_model_filename, str)
+    assert isinstance(save_dir, str)
+    assert isinstance(model_basename, str)
+    assert isinstance(model_func_name, str)
+    assert onnx_opset_version is None or isinstance(onnx_opset_version, int)
     import onnx
     from onnx.checker import ValidationError
@@ -41,7 +63,6 @@ def convert(onnx_model_filename,
     from .onnx_utils import inferred_model_value_info
     from .onnx_utils import polish_model
     from .writer import Program, Writer
-    from .writer import make_var_name
     logger = logging.getLogger('convert')
@@ -88,17 +109,21 @@ def convert(onnx_model_filename,
     fluid_writer = Writer()
     # model components
-    # graph_name = onnx_graph.name
-    graph_inputs = [value.name for value in onnx_graph.input]
-    graph_outputs = [value.name for value in onnx_graph.output]
-    graph_params = []
-    graph_value_infos = inferred_model_value_info(onnx_model)
+    inp_vars = [make_var_name(value.name) for value in onnx_graph.input]
+    out_vars = [make_var_name(value.name) for value in onnx_graph.output]
+    par_vars = []
+    value_infos = inferred_model_value_info(onnx_model)
+    value_infos = {
+        make_var_name(key): value
+        for key, value in value_infos.items()
+    }
     # prepare additional value_info
     # for weights
     for name, weight in graph_weights(onnx_graph):
-        value_info = graph_value_infos[name]
-        value_info['embeded_as'] = []
+        var_name = make_var_name(name)
+        value_info = value_infos[var_name]
+        value_info['embedded_as'] = []
         value_info['get_weight'] = (lambda w: lambda: w.tolist())(
             weight)  # lazy getter
@@ -108,19 +133,23 @@ def convert(onnx_model_filename,
     topo = 'forward'
     for name, domain, op_type, inputs, outputs, attrs in graph_ops(onnx_graph,
                                                                    topo=topo):
-        logger.debug('translating op %s %s::%s ...', name, domain, op_type)
+        op_name = make_var_name(name)
+        inputs = [make_var_name(val) for val in inputs]
+        outputs = [make_var_name(val) for val in outputs]
+        logger.debug('translating op %s(%s) %s::%s ...', name, op_name, domain,
+                     op_type)
         if domain == DEFAULT_OP_DOMAIN:
             domain = ''
         try:
             fluid_writer.emit_op(
                 fluid_program,
-                name,
+                op_name,
                 domain,
                 op_type,
                 inputs,
                 outputs,
                 attrs,
-                graph_value_infos,
+                value_infos,
                 embed_params=embed_params,
             )
         except BaseException as e:
@@ -133,17 +162,16 @@ def convert(onnx_model_filename,
                 len(fluid_program.op_descs))
     # type-shape info copy
-    for name, value_info in graph_value_infos.items():
-        var_name = make_var_name(name)
+    for var_name, value_info in value_infos.items():
         fluid_program.VarTypeShapeInfo(var_name, value_info,
                                        remove_batch=False)  #
-    bad_var_names = []
+    bad_vars = []
     for var_name, var_desc in fluid_program.var_descs.items():
         if not var_desc.type.lod_tensor.HasField('tensor'):
-            bad_var_names.append(var_name)
-    if len(bad_var_names) > 0:
+            bad_vars.append(var_name)
+    if len(bad_vars) > 0:
         logger.warning('type-shape not infered for var %s ...',
-                       ', '.join(bad_var_names[:5]))
+                       ', '.join(bad_vars[:5]))
         logger.warning('this causes little problem for PaddlePaddle, '
                        'but Paddle Mobile may not infer correctly')
         logger.warning('please consider running validation with -i '
@@ -151,40 +179,41 @@ def convert(onnx_model_filename,
     # weight writer
     for name, weight in graph_weights(onnx_graph):
-        graph_params.append(name)
-        value_info = graph_value_infos[name]
-        var_names = value_info.get('embeded_as', [])
-        if var_names:
-            if len(var_names) > 1:
+        var_name = make_var_name(name)
+        par_vars.append(var_name)
+        value_info = value_infos[var_name]
+        embedded_names = value_info.get('embedded_as', [])
+        if embedded_names:
+            if len(embedded_names) > 1:
                 logger.info(
                     'weight %s is shared between ops, more disk space will be consumed',
                     name)
             logger.debug('saving weight %s(%s[%d], %dB) as %s ...', name,
-                         weight.dtype, weight.size, weight.nbytes, var_names)
-            for var_name in var_names:  # multiple references
+                         weight.dtype, weight.size, weight.nbytes,
+                         embedded_names)
+            for embedded_name in embedded_names:  # multiple references
                 fluid_writer.write_weight(
-                    weight, shutil.os.path.join(save_dir, var_name))
+                    weight, shutil.os.path.join(save_dir, embedded_name))
         else:
             logger.debug('saving weight %s(%s[%d], %dB) to %s ...', name,
-                         weight.dtype, weight.size, weight.nbytes,
-                         make_var_name(name))
-            fluid_writer.write_weight(
-                weight, shutil.os.path.join(save_dir, make_var_name(name)))
-        fluid_writer.emit_param(fluid_program, name, value_info)
+                         weight.dtype, weight.size, weight.nbytes, var_name)
+            fluid_writer.write_weight(weight,
+                                      shutil.os.path.join(save_dir, var_name))
+        fluid_writer.emit_param(fluid_program, var_name, value_info)
     param_codes = fluid_program.codes
     fluid_program.codes = []
-    logger.info('%d weights converted', len(graph_params))
+    logger.info('%d weights converted', len(par_vars))
     # input writer
     external_inputs = []
-    for name in graph_inputs:
-        if name not in graph_params:
-            value_info = graph_value_infos[name]
+    for var_name in inp_vars:
+        if var_name not in par_vars:
+            value_info = value_infos[var_name]
             assert value_info['external']
-            external_inputs.append(name)
+            external_inputs.append(var_name)
     fluid_writer.emit_inputs(fluid_program,
                              external_inputs,
-                             graph_value_infos,
+                             value_infos,
                              remove_batch=False)  # TODO:
     input_codes = fluid_program.codes
     fluid_program.codes = []
@@ -192,11 +221,11 @@ def convert(onnx_model_filename,
     # output writer
     external_outputs = []
-    for name in graph_outputs:
-        if name not in graph_params:
-            value_info = graph_value_infos[name]
+    for var_name in out_vars:
+        if var_name not in par_vars:
+            value_info = value_infos[var_name]
             assert value_info['external']
-            external_outputs.append(name)
+            external_outputs.append(var_name)
     fluid_writer.emit_outputs(fluid_program, external_outputs)
     output_codes = [''] + fluid_program.codes  # add an empty line
     fluid_program.codes = []
@@ -204,10 +233,18 @@ def convert(onnx_model_filename,
     # code generation
     header_codes = fluid_writer.header_code(
-        model_func_name, 'From: {}'.format(onnx_model_filename))
+        model_func_name,
+        'From: {}'.format(onnx_model_filename),
+    )
     code_filename = shutil.os.path.join(save_dir, model_basename)
-    fluid_writer.write_code_file(code_filename, header_codes, input_codes,
-                                 param_codes, op_codes, output_codes)
+    fluid_writer.write_code_file(
+        code_filename,
+        header_codes,
+        input_codes,
+        param_codes,
+        op_codes,
+        output_codes,
+    )
     logger.info('code saved to %s, factory function: %s', code_filename,
                 model_func_name)
......
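For orientation, a sketch of driving the converter from Python instead of the `python -m onnx2fluid` entry point; the module path and the remaining keyword defaults are assumptions, and `sample_1.onnx` is the toy model from the README:

```python
# Assumed module path; the command-line tool wraps a call like this one.
from onnx2fluid.conversion import convert

convert('sample_1.onnx',             # source ONNX model
        'sample_1.paddle/',          # save_dir for model.py, desc pb and weight files
        model_basename='model.py')   # file that receives the factory function
```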
@@ -87,6 +87,9 @@ def get_attribute_value2(attr):
     get_attribute_value enhanced
     """
+    assert isinstance(
+        attr, onnx.AttributeProto), 'attr is not a AttributeProto instance'
     if attr.type == onnx.AttributeProto.TENSOR:
         dtype = np.dtype(TENSOR_TYPE_TO_NP_TYPE[attr.t.data_type])
         data = attr.t.raw_data
@@ -106,6 +109,9 @@ def tensor_dtype(tensor):
     get ONNX tensor in np.dtype
     """
+    assert isinstance(
+        tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
     return TENSOR_TYPE_TO_NP_TYPE[tensor.type.tensor_type.elem_type]
@@ -114,6 +120,9 @@ def tensor_shape(tensor):
     get ONNX tensor shape
     """
+    assert isinstance(
+        tensor, onnx.ValueInfoProto), 'tensor is not a ValueInfoProto instance'
     return tuple([dim.dim_value for dim in tensor.type.tensor_type.shape.dim])
@@ -122,6 +131,8 @@ def node_attrs(node):
     convert ONNX node attributes to dict
     """
+    assert isinstance(node, onnx.NodeProto), 'node is not a NodeProto instance'
     return {attr.name: get_attribute_value2(attr)
             for attr in node.attribute}  # dict
@@ -224,9 +235,8 @@ def graph_ops(graph, topo='default'):
     generator for ONNX node graph with given topology
     """
-    if not isinstance(graph, onnx.GraphProto):
-        logger.error('graph is not a GraphProto instance')
-        return
+    assert isinstance(graph,
+                      onnx.GraphProto), 'graph is not a GraphProto instance'
     return node_iter(graph.node, node_topo(graph.node, topo))
@@ -236,9 +246,8 @@ def graph_weights(graph):
     generator for weights of an ONNX model
     """
-    if not isinstance(graph, onnx.GraphProto):
-        logger.error('graph is not a GraphProto instance')
-        return
+    assert isinstance(graph,
+                      onnx.GraphProto), 'graph is not a GraphProto instance'
     for initializer in graph.initializer:
         name = initializer.name
@@ -251,6 +260,9 @@ def inferred_model_value_info(model):
     collect value/type info for an ONNX model
     """
+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
     model = infer_shapes(model)
     graph = model.graph
     value_info = Dict()
@@ -353,6 +365,10 @@ def optimize_model_skip_op_for_inference(model, op_list=None):
     """
     skip ops can be bypassed for inference
     """
+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
     if op_list is None:
         op_list = ('Dropout', 'Identity')
@@ -415,6 +431,9 @@ def optimize_model_strip_initializer(model, keep_input_only=True):
     strip weights for inference
     """
+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
     out_names = [val.name for val in model.graph.output]
@@ -456,6 +475,9 @@ def optimize_model_cast(model):
     strip cascade and unecessary onnx::Cast-9:
     """
+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
     value_info = inferred_model_value_info(model)
@@ -513,6 +535,9 @@ def optimize_model_slice(model):
     strip cascade and unecessary onnx::Slice-1:9
     """
+    assert isinstance(model,
+                      onnx.ModelProto), 'model is not a ModelProto instance'
     nodes = model.graph.node
     input_refs, output_refs = build_value_refs(nodes)
......
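A small sketch of how these helpers compose; the package-level import path mirrors the `from .onnx_utils import ...` lines in the converter, and `sample_1.onnx` again stands in for any ONNX model:

```python
import onnx

from onnx2fluid.onnx_utils import inferred_model_value_info, polish_model

model = onnx.load('sample_1.onnx')
model = polish_model(model)                      # model cleanup pass used by convert()
value_infos = inferred_model_value_info(model)   # name -> dtype/shape/external flags
print(len(value_infos), 'values with inferred type and shape')
```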
This diff is collapsed.
@@ -6,6 +6,9 @@ Created on Fri Mar 22 11:22:46 2019
 @author: Macrobull
 """
+from __future__ import division
+
+import logging
 import numpy as np
 import torch
@@ -24,6 +27,8 @@ from typing import (
     Union,
 )
+logger = logging.getLogger(__name__)
 __all__ = [
     'export_data',
     'export_onnx_with_validation',
@@ -76,7 +81,7 @@ def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
         return str(obj)
     prefix_ = prefix + ('_' if prefix else '')
-    fp = open('{}.txt'.format(prefix or 'meta'), 'w')
+    fp = open('{}.txt'.format(prefix or 'meta'), mode='w')
     for key, value in state_dict.items():
         data = None
         if torch.is_tensor(value):
@@ -93,7 +98,7 @@ def export_data(state_dict: Mapping[Text, Any], prefix: Text = '') -> None:
 def export_onnx_with_validation(
-        model: torch.nn.Module,
+        model: torch.nn.Module,  # or JITScriptModule
         inputs: Sequence[Union[torch.Tensor, Sequence[object]]],
         export_basepath: Text,
         input_names: Optional[List[Text]] = None,
......
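The call pattern for `export_onnx_with_validation`, mirroring its use in the example generator earlier in this commit; the tiny `Linear` model and the `sample_1` basepath are placeholders, and the full signature is truncated above:

```python
import torch
from onnx2fluid.torch_export_helper import export_onnx_with_validation

model = torch.nn.Linear(4, 5).eval()
xb = torch.rand((2, 3, 4))

# Expected to write 'sample_1.onnx' plus golden input/output data keyed by
# 'x' and 'y', matching the sample file names used in the README.
export_onnx_with_validation(model, [xb],
                            'sample_1', ['x'], ['y'],
                            verbose=True,
                            training=False)
```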
@@ -43,7 +43,8 @@ def fluid_prog_shape_infer(prog):
     import paddle.fluid as fluid
-    assert isinstance(prog, fluid.framework.Program)
+    assert isinstance(prog,
+                      fluid.framework.Program), 'prog is not a Program instance'
     logger.info('performing type-shape inference ...')
     for block in prog.blocks:
@@ -84,6 +85,8 @@ def validate(fluid_model_filename,
     inference the converted Paddle fluid model, validate with given golden data
     """
+    assert isinstance(fluid_model_filename, str)
     import numpy as np
     import paddle.fluid as fluid
@@ -153,7 +156,7 @@ def validate(fluid_model_filename,
         input_data = flatten_dict(input_data)
         output_data = flatten_dict(output_data)
         input_names = input_data.keys()
-        output_names = output_data.keys()
+        # output_names = output_data.keys()
         logger.info('with %d inputs and %d outputs', len(input_data),
                     len(output_data))
     else:
......
@@ -16,7 +16,6 @@ from collections import OrderedDict as Dict
 logger = logging.getLogger(__name__)
 from . import symbolic
-from .symbolic import _make_var_name as make_var_name
 try:
     import paddle.fluid.proto.framework_pb2 as framework_pb2
@@ -63,7 +62,7 @@ def make_attr_name(name):
     assert name != '', 'name should not be empty'
-    for s in ' \\|/:-':  #
+    for s in ' \\|/:.-':  #
         name = name.replace(s, '_')
     if not name.startswith('_'):
         name = '_' + name
@@ -207,7 +206,7 @@ class Program(object):
         return desc
     def VarDesc(self,
-                var_name,
+                name,
                 persistable=False,
                 value_info=None,
                 remove_batch=None):
@@ -215,18 +214,16 @@ class Program(object):
         add VarDesc,
         """
-        assert var_name not in self.var_descs, 'var naming conflicted'
+        assert name not in self.var_descs, 'var naming conflicted'
         var_desc = framework_pb2.VarDesc()
-        var_desc.name = var_name
+        var_desc.name = name
         var_desc.persistable = persistable
         var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
-        self.var_descs[var_name] = var_desc
+        self.var_descs[name] = var_desc
         if value_info:
-            self.VarTypeShapeInfo(var_name,
-                                  value_info,
-                                  remove_batch=remove_batch)
+            self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch)
     def Op(self, domain, op_type, *args, **kwargs):
         """
@@ -260,19 +257,19 @@ class Program(object):
         else:
             self.code_mutable = code_mutable
-    def VarTypeShapeInfo(self, var_name, value_info, remove_batch=None):
+    def VarTypeShapeInfo(self, name, value_info, remove_batch=None):
         """
         set value_info for var
         """
-        if var_name not in self.var_descs:
+        if name not in self.var_descs:
             return
         dtype = value_info.get('dtype', None)
         if dtype is None:
             return
-        var_desc = self.var_descs[var_name]
+        var_desc = self.var_descs[name]
         tensor_desc = var_desc.type.lod_tensor.tensor
         tensor_desc.data_type = self.Dtype(dtype)  # required
@@ -292,8 +289,7 @@ class Writer(object):
     fluid code and desc writter
     """
-    # CODE_INDENT = ' ' * 4
-    CODE_INDENT = '\t'
+    CODE_INDENT = ' ' * 4  # '\t'
     @staticmethod
     def header_code(func_name, info=''):
@@ -313,6 +309,7 @@ class Writer(object):
         codes.append('from paddle.fluid import initializer, layers')
         codes.append('')
         codes.append('')
         codes.append('def {}():'.format(func_name))
         return codes
@@ -342,24 +339,26 @@ class Writer(object):
         emit an ONNX weight into program
         """
-        if value_info.get('embeded_as', []):
-            var_names = value_info['embeded_as']
-            prog.Code('# parameter {} embeded as {}'.format(name, var_names))
-            for var_name in var_names:
-                prog.VarDesc(var_name, persistable=True, value_info=value_info)
+        if value_info.get('embedded_as', []):
+            embedded_names = value_info['embedded_as']
+            prog.Code('# parameter {} embedded as {}'.format(
+                name, embedded_names))
+            for embedded_name in embedded_names:
+                prog.VarDesc(embedded_name,
+                             persistable=True,
+                             value_info=value_info)
         else:
-            var_name = make_var_name(name)
             attr_name = make_attr_name(name)
-            prog.Code('# parameter {}: {}'.format(name, var_name))
+            prog.Code('# parameter {}'.format(name))
             prog.Code('{} = ParamAttr(name={})'  # , trainable=True
-                      .format(attr_name, repr(var_name)))
+                      .format(attr_name, repr(name)))
             prog.Code(
                 '{} = layers.create_parameter(shape={}, dtype={}, name={}, attr={}'
                 ', default_initializer=initializer.Constant(0))'  #, is_bias={}
-                .format(var_name, value_info['shape'],
+                .format(name, value_info['shape'],
                         repr(value_info['dtype'].name), repr(name),
                         attr_name))  #, value_info.get('is_bias', False)))
-            prog.VarDesc(var_name, persistable=True, value_info=value_info)
+            prog.VarDesc(name, persistable=True, value_info=value_info)
     @staticmethod
     def emit_inputs(prog, names, value_infos, remove_batch=None):
@@ -368,7 +367,6 @@ class Writer(object):
         """
         for idx, name in enumerate(names):
-            var_name = make_var_name(name)
             value_info = value_infos[name]
             shape = value_info['shape']
             if remove_batch is None:
@@ -377,13 +375,13 @@ class Writer(object):
             if remove_batch:
                 shape = shape[1:]
-            prog.Code('# input {}: {}'.format(name, var_name))
+            prog.Code('# input {}'.format(name))
             prog.Code((
                 '{} = layers.data(name={}, shape={}, dtype={}, '
                 'append_batch_size={})'  # , stop_gradient=True
             ).format(
-                var_name,
-                repr(var_name),
+                name,
+                repr(name),
                 shape,
                 repr(value_info['dtype'].name),
                 remove_batch,
@@ -391,12 +389,10 @@ class Writer(object):
             prog.OpDesc(
                 'feed',
                 (['feed'], 'X'),
-                ([var_name], 'Out'),
+                ([name], 'Out'),
                 {'col': idx},
             )
-            prog.VarDesc(var_name,
-                         value_info=value_info,
-                         remove_batch=remove_batch)
+            prog.VarDesc(name, value_info=value_info, remove_batch=remove_batch)
     @staticmethod
     def emit_outputs(prog, names):  #, value_infos
@@ -406,12 +402,11 @@ class Writer(object):
         code = 'return '
         for idx, name in enumerate(names):
-            var_name = make_var_name(name)
-            code += var_name + ', '
+            code += name + ', '
             prog.OpDesc(
                 'fetch',
-                ([var_name], 'X'),
+                ([name], 'X'),
                 (['fetch'], 'Out'),
                 {'col': idx},
             )
@@ -458,8 +453,7 @@ class Writer(object):
         for name, weight in weights.items():
             assert isinstance(weights, dict), 'dict type weights required'
-            var_name = make_var_name(name)
-            filename = os.path.join(save_dir, var_name)
+            filename = os.path.join(save_dir, name)
             Writer.write_weight(weight, filename)
             logger.debug('saved weight %s to %s', name, filename)
......
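For reference, the kind of line the `emit_inputs` format string above produces for a graph input; the name `data_0` and the shape are made up for illustration, and `remove_batch=False` maps to `append_batch_size=False` in the generated code:

```python
from paddle.fluid import layers

# Illustrative generated line for an input named 'data_0' with an assumed
# shape of [1, 3, 224, 224]; only the format string comes from emit_inputs.
data_0 = layers.data(name='data_0', shape=[1, 3, 224, 224], dtype='float32',
                     append_batch_size=False)
```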
 -e .
 onnx>=1.4
-paddlepaddle
+paddlepaddle>=1.5