Commit 9d147284 authored by Macrobull

better type-shape inference support

Parent 74feae9b
@@ -17,8 +17,8 @@ def make_var_name(name):
     make a valid variable name in Python code
     """
-    if name == '':
-        return '_'
+    assert name
     if name[0].isdigit():
         return 'var_' + name
     for s in ' \\|/:-': #
...
@@ -20,8 +20,8 @@ def make_var_name(name):
     make a valid variable name in Python code
     """
-    if name == '':
-        return '_'
+    assert name
    if name[0].isdigit():
        return 'var_' + name
    for s in ' \\|/:-': #
...
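Both copies of `make_var_name` above now fail fast on an empty name instead of silently mapping it to `_`. A minimal, hedged sketch of the patched prologue, for reference only; everything beyond the lines visible in the hunks is an assumption:

```python
def make_var_name(name):
    """make a valid variable name in Python code (sketch of the patched logic)"""
    assert name                        # empty names are now rejected up front
    if name[0].isdigit():              # identifiers cannot start with a digit
        return 'var_' + name
    for s in ' \\|/:-':                # characters illegal in identifiers
        name = name.replace(s, '_')
    return name                        # the real helper may post-process further
```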
@@ -2,14 +2,17 @@
 # setopt SH_WORD_SPLIT # if zsh
+# alias python="python3" # if ...
+# alias http_get="wget -c" # if no aria2
+alias http_get="aria2c -c -s8 -x8"
 base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
+convert_cmd="python -m onnx2fluid"
+validate_cmd="$convert_cmd.validation"
 convert_flags="-e -o /tmp/export/"
 validate_flags1="/tmp/export/model.py"
 validate_flags2="/tmp/export/__model__"
+validate_flags3="/tmp/export/__model__ -i"
-# alias http_get="wget -c" # if no aria2
-alias http_get="aria2c -c -s8 -x8"
-# alias python="python3" # if ...
 bvlc_alexnet()
@@ -23,21 +26,23 @@ bvlc_alexnet()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
-        python -m onnx2fluid.validation $validate_flags1 -t "$npz"
-        python -m onnx2fluid.validation $validate_flags2 -t "$npz"
+        $validate_cmd $validate_flags1 -t "$npz"
+        $validate_cmd $validate_flags2 -t "$npz"
     done
+    $validate_cmd $validate_flags3 -t "$npz"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -53,14 +58,15 @@ bvlc_googlenet()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -76,14 +82,15 @@ bvlc_reference_caffenet()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -99,14 +106,15 @@ bvlc_reference_rcnn_ilsvrc13()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -122,21 +130,23 @@ densenet121()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 fc6_1 -s
-        python -m onnx2fluid.validation $validate_flags1 -t "$npz"
-        python -m onnx2fluid.validation $validate_flags2 -t "$npz"
+        $validate_cmd $validate_flags1 -t "$npz"
+        $validate_cmd $validate_flags2 -t "$npz"
     done
+    $validate_cmd $validate_flags3 -t "$npz"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 fc6_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -152,14 +162,15 @@ emotion_ferplus()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -175,21 +186,23 @@ inception_v1()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
-        python -m onnx2fluid.validation $validate_flags1 -t "$npz"
-        python -m onnx2fluid.validation $validate_flags2 -t "$npz"
+        $validate_cmd $validate_flags1 -t "$npz"
+        $validate_cmd $validate_flags2 -t "$npz"
     done
+    $validate_cmd $validate_flags3 -t "$npz"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -205,21 +218,23 @@ inception_v2()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for npz in "$bn_tar/"*.npz
    do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" data_0 prob_1 -s
-        python -m onnx2fluid.validation $validate_flags1 -t "$npz"
-        python -m onnx2fluid.validation $validate_flags2 -t "$npz"
+        $validate_cmd $validate_flags1 -t "$npz"
+        $validate_cmd $validate_flags2 -t "$npz"
     done
+    $validate_cmd $validate_flags3 -t "$npz"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data_0 prob_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -235,14 +250,15 @@ mobilenet()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -258,14 +274,15 @@ resnet18()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -281,21 +298,23 @@ resnet50()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for npz in "$bn_tar/"*.npz
     do
         echo "converting $npz ..."
         python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
-        python -m onnx2fluid.validation $validate_flags1 -t "$npz"
-        python -m onnx2fluid.validation $validate_flags2 -t "$npz"
+        $validate_cmd $validate_flags1 -t "$npz"
+        $validate_cmd $validate_flags2 -t "$npz"
     done
+    $validate_cmd $validate_flags3 -t "$npz"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -311,14 +330,15 @@ resnet100_arcface()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data fc1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -334,14 +354,15 @@ resnet101_duc()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data seg_loss
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -357,14 +378,15 @@ resnet152()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -380,14 +402,15 @@ shufflenet()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -403,14 +426,15 @@ squeezenet()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 softmaxout_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -426,14 +450,15 @@ squeezenet1v1()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -450,14 +475,15 @@ ssd()
     mkdir "$bn_tar"
     tar xf "$fn_tar" -C "$bn_tar/"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" image bboxes,labels,scores
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -473,14 +499,15 @@ tiny_yolov2()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" image grid
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -496,14 +523,15 @@ vgg16bn()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -y
+    $convert_cmd $convert_flags "$fn_model" -y
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -519,14 +547,15 @@ vgg19()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" data_0 prob_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -535,21 +564,22 @@ yolov3()
 {
     bn_tar="yolov3"
     fn_tar="$bn_tar.tar.gz"
-    fn_model="$bn_tar/model.onnx"
+    fn_model="$bn_tar/yolov3.onnx"
     http_get "https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/$fn_tar"
     rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model" -x #
+    $convert_cmd $convert_flags "$fn_model" -x #
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir ..."
         python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
@@ -565,14 +595,15 @@ zfnet512()
     echo "extracting ..."
     tar xf "$fn_tar"
-    python -m onnx2fluid $convert_flags "$fn_model"
+    $convert_cmd $convert_flags "$fn_model"
     for pb_dir in "$bn_tar/"*/
     do
         echo "converting $pb_dir"
         python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
+        $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
     done
+    $validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
     rm -rf "$bn_tar/"
 }
...
@@ -61,11 +61,14 @@ def main(**kwargs):
     passed = True
     golden_data_filename = kwargs.pop('test_data', '')
     infer_inputs = kwargs.pop('infer_inputs', None)
-    if golden_data_filename or infer_inputs is not None:
+    save_inference_model = infer_inputs is not None
+    if golden_data_filename or save_inference_model:
         from .validation import validate
-        save_inference_model = infer_inputs is not None
-        inference_input_names = infer_inputs and infer_inputs.split(',')
+        if save_inference_model:
+            inference_input_names = infer_inputs.split(',')
+        else:
+            inference_input_names = None
         logger.info('starting validation on desc ...')
         passed &= validate(shutil.os.path.join(save_dir, '__model__'),
...
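The reshuffled block above decides up front whether validation runs and only splits input names when `infer_inputs` was actually given. A small hedged restatement of just that decision logic; the helper name is illustrative and the trailing `validate(...)` call is elided:

```python
def plan_validation(kwargs):
    """Return (run_validation, save_inference_model, inference_input_names)."""
    golden_data_filename = kwargs.pop('test_data', '')
    infer_inputs = kwargs.pop('infer_inputs', None)

    # inference-model export is requested iff infer_inputs was passed at all
    save_inference_model = infer_inputs is not None
    run_validation = bool(golden_data_filename) or save_inference_model

    if save_inference_model:
        inference_input_names = infer_inputs.split(',')  # 'a,b,c' -> ['a', 'b', 'c']
    else:
        inference_input_names = None
    return run_validation, save_inference_model, inference_input_names


print(plan_validation({'infer_inputs': 'data_0,label_0'}))
# (True, True, ['data_0', 'label_0'])
```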
@@ -23,7 +23,7 @@ def make_var_name(name):
     """
     if name == '':
-        return '_'
+        return ''
     if name[0].isdigit():
         return 'var_' + name
     for s in ' \\|/:.-':
@@ -170,7 +170,7 @@ def convert(onnx_model_filename,
     for var_name, var_desc in fluid_program.var_descs.items():
         if not var_desc.type.lod_tensor.HasField('tensor'):
             bad_vars.append(var_name)
-    if len(bad_vars) > 0:
+    if bad_vars:
         logger.warning('type-shape not infered for var %s ...',
                        ', '.join(bad_vars[:5]))
         logger.warning('this causes little problem for PaddlePaddle, '
...
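`if bad_vars:` is the idiomatic form of the old length test; the loop it guards reports every variable whose tensor type/shape was never inferred. A hedged, standalone sketch of that reporting step, assuming `var_descs` maps names to `framework_pb2.VarDesc` messages as the hunk does:

```python
import logging

logger = logging.getLogger('onnx2fluid.conversion')


def report_uninferred_vars(var_descs):
    """Collect and warn about vars lacking an inferred LoDTensor type/shape (sketch)."""
    bad_vars = [
        var_name for var_name, var_desc in var_descs.items()
        if not var_desc.type.lod_tensor.HasField('tensor')
    ]
    if bad_vars:  # truthiness instead of len(bad_vars) > 0
        logger.warning('type-shape not inferred for var %s ...',
                       ', '.join(bad_vars[:5]))  # further warnings elided in the diff
    return bad_vars
```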
@@ -99,6 +99,9 @@ def get_attribute_value2(attr):
     elif attr.type == onnx.AttributeProto.STRING:
         value = attr.s
         value = value.decode() if isinstance(value, bytes) else value
+    elif attr.type == onnx.AttributeProto.STRINGS:
+        value = attr.strings
+        value = [s.decode() if isinstance(s, bytes) else s for s in value]
     else:
         value = get_attribute_value(attr)
     return value
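The new `STRINGS` branch mirrors the single-`STRING` case: every element of a string-list attribute is decoded from bytes. A hedged sketch that exercises only the fields the hunk touches, falling back to `onnx.helper.get_attribute_value` for other attribute types:

```python
import onnx
from onnx.helper import get_attribute_value, make_attribute


def get_attribute_value2(attr):
    """Fetch an ONNX attribute value, decoding bytes to str (simplified sketch)."""
    if attr.type == onnx.AttributeProto.STRING:
        value = attr.s
        value = value.decode() if isinstance(value, bytes) else value
    elif attr.type == onnx.AttributeProto.STRINGS:
        # new in this commit: string lists get the same bytes -> str treatment
        value = [s.decode() if isinstance(s, bytes) else s for s in attr.strings]
    else:
        value = get_attribute_value(attr)
    return value


attr = make_attribute('names', ['foo', 'bar'])       # stored as bytes internally
assert get_attribute_value2(attr) == ['foo', 'bar']
```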
@@ -161,12 +164,12 @@ def node_topo(nodes, topo='default'):
         for node_idx, degree in enumerate(node_in_degrees):
             if degree == 0:
                 queue.append(node_idx)
-        while len(queue) > 0:
+        while queue:
             node_idx = queue.pop(0)
             node_topo.append(node_idx)
             for val_name in nodes[node_idx].output:
                 output_refs[val_name].remove(node_idx)
-                if len(output_refs[val_name]) > 0:
+                if output_refs[val_name]:
                     continue
                 output_refs.pop(val_name)
                 if val_name not in input_refs:
@@ -186,12 +189,12 @@ def node_topo(nodes, topo='default'):
         for node_idx, degree in enumerate(node_out_degrees):
             if degree == 0:
                 queue.append(node_idx)
-        while len(queue) > 0:
+        while queue:
             node_idx = queue.pop(0)
             node_topo.append(node_idx)
             for val_name in nodes[node_idx].input:
                 input_refs[val_name].remove(node_idx)
-                if len(input_refs[val_name]) > 0:
+                if input_refs[val_name]:
                     continue
                 input_refs.pop(val_name)
                 if val_name not in output_refs:
@@ -210,7 +213,10 @@ def node_iter(nodes, indices=None):
     generator for ONNX node graph with given indices
     """
-    for index in indices or range(len(nodes)):
+    if indices is None:
+        indices = range(len(nodes))
+    for index in indices:
         node = nodes[index]
         name = node.name
         domain = node.domain
@@ -221,9 +227,11 @@ def node_iter(nodes, indices=None):
         if name == '':
             name = 'op_' + str(index)
-        else: # make_op_name
-            for s in ' \\|/:-': #
-                name = name.replace(s, '_')
+        # else: # make_op_name
+        # for s in ' \\|/:-': #
+        # name = name.replace(s, '_')
         if domain == '':
             domain = DEFAULT_OP_DOMAIN
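Switching from `for index in indices or range(len(nodes))` to an explicit `None` check changes behavior for an empty `indices` list: it now yields nothing instead of silently iterating over every node. A hedged sketch of the generator's selection and naming fallback; the constant's value and the shape of the yielded tuple are assumptions, and the name-sanitizing `else` branch stays commented out as in the hunk:

```python
DEFAULT_OP_DOMAIN = 'ai.onnx'  # assumption: matches the module-level constant


def node_iter(nodes, indices=None):
    """Yield (name, domain, op_type, node) for the selected graph nodes (sketch)."""
    if indices is None:              # default only when the argument is omitted,
        indices = range(len(nodes))  # so an explicit empty list yields nothing
    for index in indices:
        node = nodes[index]
        name = node.name or 'op_' + str(index)     # anonymous ops get an index-based name
        domain = node.domain or DEFAULT_OP_DOMAIN  # empty domain means the default one
        yield name, domain, node.op_type, node
```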
@@ -356,10 +364,11 @@ def polish_and_save(model_filename,
     run polish_model and save
     """
+    if save_filename is None:
+        save_filename = model_filename.replace('.onnx', suffix + '.onnx')
     model = onnx.load(model_filename)
     model = polish_model(model, *args, **kwargs)
-    save_filename = save_filename or model_filename.replace(
-        '.onnx', suffix + '.onnx')
     onnx.save(model, save_filename)
     logger.info('polished model saved to: %s', save_filename)
     return save_filename
@@ -495,7 +504,7 @@ def optimize_model_cast(model):
     for node_idx, node in enumerate(nodes):
         if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
             continue
-        if not (node.op_type == 'Cast'):
+        if node.op_type != 'Cast':
             continue
         attrs = node_attrs(node)
         output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
@@ -551,7 +560,7 @@ def optimize_model_slice(model):
             node = nodes[node_idx]
             if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
                 return chain
-            if not node.op_type == 'Slice':
+            if node.op_type != 'Slice':
                 return chain
             chain.append(node_idx)
             output_name = node.output[0]
@@ -585,10 +594,10 @@ def optimize_model_slice(model):
     nodes_to_remove = []
     for node_idx in range(len(nodes)):
         slice_chain = build_slice_node_chain(node_idx)
-        if len(slice_chain) == 0:
+        if not slice_chain:
             continue
         merged_slice = merge_slice(slice_chain)
-        if len(merged_slice) > 0 and len(slice_chain) == 1:  # no need to merge
+        if merged_slice and len(slice_chain) == 1:  # no need to merge
             continue
         attrs = {'axes': [], 'starts': [], 'ends': []}
@@ -602,12 +611,11 @@ def optimize_model_slice(model):
         output_name = last_node.output[0]
         processed = -1
         if output_name in input_refs:  # 0, [1...]
-            new_input_name = first_node.output[0] if len(
-                merged_slice) > 0 else input_name
+            new_input_name = first_node.output[0] if merged_slice else input_name
             processed = skip_node_forward(ret_nodes, output_name,
                                           new_input_name, input_refs)
             if processed > 0:
-                if len(merged_slice) > 0:
+                if merged_slice:
                     remain_idx = slice_chain[0]
                     remove_chain = slice_chain[1:]
                     slice_node = ret_nodes[remain_idx]
@@ -621,12 +629,11 @@ def optimize_model_slice(model):
                     remove_chain = slice_chain
         if processed < 0 and input_name in output_refs:
-            new_output_name = last_node.input[0] if len(
-                merged_slice) > 0 else output_name
+            new_output_name = last_node.input[0] if merged_slice else output_name
             processed = skip_node_backward(ret_nodes, input_name,
                                            new_output_name, output_refs)
             if processed > 0:
-                if len(merged_slice) > 0:
+                if merged_slice:
                     remain_idx = slice_chain[-1]
                     remove_chain = slice_chain[:-1]
                     slice_node = ret_nodes[remain_idx]
@@ -641,7 +648,7 @@ def optimize_model_slice(model):
         if processed > 0:
             nodes_to_remove.extend(remove_chain)
-            if len(merged_slice) == 0:
+            if not merged_slice:
                 logger.debug('skip slice chain %s -> %s -> %s', input_name,
                              slice_chain, output_name)
         elif processed < 0:  # NEVERFIX: not merge standalone slice chain
...
This diff is collapsed.
@@ -119,7 +119,7 @@ def export_onnx_with_validation(
         return list(map(tensors_to_arrays, tensors))
     def zip_dict(
-            keys: Union[Iterable[Any], None],
+            keys: Optional[Iterable[Any]],
             values: Sequence[Union[Any, Sequence[Any]]],
     ) -> MyDict[Text, Union[object, MyDict[Text, object]]]:
         keys = keys or range(len(values))
...
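`Optional[Iterable[Any]]` only restates what the body already allows: a `None` (or otherwise falsy) `keys` falls back to positional indices. A tiny, hedged flat-dict sketch of that pairing; the real helper also nests dictionaries for sequence values, which is omitted here:

```python
from typing import Any, Dict, Iterable, Optional, Sequence


def zip_dict(keys: Optional[Iterable[Any]],
             values: Sequence[Any]) -> Dict[str, Any]:
    """Pair names with values, defaulting keys to 0..N-1 (flat sketch only)."""
    keys = keys or range(len(values))
    return {str(key): value for key, value in zip(keys, values)}


assert zip_dict(None, ['a', 'b']) == {'0': 'a', '1': 'b'}
assert zip_dict(['x', 'y'], [1, 2]) == {'x': 1, 'y': 2}
```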
@@ -160,7 +160,8 @@ def validate(fluid_model_filename,
         logger.info('with %d inputs and %d outputs', len(input_data),
                     len(output_data))
     elif save_inference_model:
-        assert inference_input_names, 'input names required for type-shape inference'
+        assert inference_input_names is not None, (
+            'input names required for type-shape inference')
         input_names = inference_input_names
         logger.info('using input names: %s', ', '.join(input_names))
@@ -178,7 +179,7 @@ def validate(fluid_model_filename,
         fluid.io.load_inference_model(fluid_model_dir, exe)
     logger.info('model re-load passed')
-    if not golden_data_filename:
+    if golden_data_filename == '':
         return True
     # execute
...
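Two guards became more explicit: the assert now tests for `None` rather than truthiness, and the early return fires only on an empty test-data path. A hedged illustration of why `is not None` matters here; the helper name is illustrative:

```python
def check_inference_inputs(inference_input_names):
    """Reject only the 'argument never given' case, not an explicit empty list (sketch)."""
    assert inference_input_names is not None, (
        'input names required for type-shape inference')
    return list(inference_input_names)


print(check_inference_inputs(['data_0']))  # ['data_0']
print(check_inference_inputs([]))          # [] -- passes now; a bare truthiness assert would raise
```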
@@ -133,7 +133,7 @@ class Program(object):
                 od_attr.type = framework_pb2.STRING
                 od_attr.s = value
             elif isinstance(value, list):
-                if len(value) > 0:  # TODO: test all items
+                if value:  # TODO: test all items
                     if isinstance(value[0],
                                   bool):  # bool.mro() = [bool, int, object]
                         od_attr.type = framework_pb2.BOOLEANS
@@ -183,23 +183,16 @@ class Program(object):
         if self.code_mutable:
             self.codes.append(code)
-    def OpDesc(self,
-               op_type,
-               input_key_vals=None,
-               output_key_vals=None,
-               attrs=None):
+    def OpDesc(self, op_type, input_key_vals, output_key_vals, attrs):
         """
         add OpDesc
         """
         desc = framework_pb2.OpDesc()
         desc.type = op_type
-        if input_key_vals:
-            desc.inputs.extend(self.OpDescVars(*input_key_vals))
-        if output_key_vals:
-            desc.outputs.extend(self.OpDescVars(*output_key_vals))
-        if attrs:
-            desc.attrs.extend(self.OpDescAttrs(attrs))
+        desc.inputs.extend(self.OpDescVars(*input_key_vals))
+        desc.outputs.extend(self.OpDescVars(*output_key_vals))
+        desc.attrs.extend(self.OpDescAttrs(attrs))
         self.op_descs.append(desc)
         return desc
@@ -212,7 +205,7 @@ class Program(object):
         add VarDesc,
         """
-        assert name not in self.var_descs, 'var naming conflicted'
+        assert name not in self.var_descs, 'var name {} conflicts'.format(name)
         var_desc = framework_pb2.VarDesc()
         var_desc.name = name
@@ -220,10 +213,10 @@ class Program(object):
         var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
         self.var_descs[name] = var_desc
-        if value_info:
+        if value_info is not None:
             self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch)
-    def Op(self, domain, op_type, *args, **kwargs):
+    def Op(self, domain, op_type, inputs, outputs, attrs, *args, **kwargs):
         """
         convert an ONNX op and add it to program
         """
@@ -232,15 +225,17 @@ class Program(object):
             raise ValueError('only default domain supported')
         if op_type in symbolic.DEFAULT_OP_MAPPING:
-            symbolic._default(self, op_type, *args, **kwargs)
+            symbolic._default(self, op_type, inputs, outputs, attrs, *args,
+                              **kwargs)
         elif hasattr(symbolic, op_type):
             fn = getattr(symbolic, op_type)
-            fn(self, *args, **kwargs)
+            fn(self, inputs, outputs, attrs, *args, **kwargs)
         else:
             raise ValueError('conversion for {}::{} not supported'.format(
                 domain, op_type))
-    def IntermediateOp(self, domain, op_type, *args, **kwargs):
+    def IntermediateOp(self, domain, op_type, inputs, outputs, attrs, *args,
+                       **kwargs):
         """
         convert an intermediate ONNX op declaring in desc program only
         """
@@ -248,7 +243,7 @@ class Program(object):
         code_mutable = self.code_mutable
         self.code_mutable = False
         try:
-            self.Op(domain, op_type, *args, **kwargs)
+            self.Op(domain, op_type, inputs, outputs, attrs, *args, **kwargs)
         except BaseException as e:
             self.code_mutable = code_mutable
             raise e
@@ -272,14 +267,15 @@ class Program(object):
         tensor_desc.data_type = self.Dtype(dtype)  # required
         shape = value_info.get('shape', None)
-        if shape is not None:
-            tensor_desc.dims.extend(shape)
-            if len(shape) > 0:  # skip scalars
-                if remove_batch is None:
-                    remove_batch = value_info.get('remove_batch',
-                                                  False)  #not persistable)
-                if remove_batch:
-                    tensor_desc.dims[0] = -1
+        if not shape:  # None or scalars
+            return
+        tensor_desc.dims.extend(shape)
+        if remove_batch is None:
+            remove_batch = value_info.get('remove_batch',
+                                          False)  #not persistable)
+        if remove_batch:
+            tensor_desc.dims[0] = -1
 class Writer(object):
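The rewritten tail of `VarTypeShapeInfo` returns early for unknown shapes and scalars, then optionally marks the batch dimension as dynamic. A hedged, framework-free sketch of that shape-filling rule, with a plain list standing in for the protobuf `dims` field:

```python
def fill_dims(value_info, remove_batch=None):
    """Return the dims a VarDesc would receive, or None for scalars/unknown shapes (sketch)."""
    shape = value_info.get('shape', None)
    if not shape:                 # None or scalars: leave the tensor desc untouched
        return None
    dims = list(shape)
    if remove_batch is None:      # fall back to the per-value hint
        remove_batch = value_info.get('remove_batch', False)
    if remove_batch:
        dims[0] = -1              # -1 marks a dynamic batch dimension
    return dims


assert fill_dims({'shape': [1, 3, 224, 224], 'remove_batch': True}) == [-1, 3, 224, 224]
assert fill_dims({'shape': []}) is None  # scalar: skipped entirely
```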
@@ -337,8 +333,8 @@ class Writer(object):
         emit an ONNX weight into program
         """
-        if value_info.get('embedded_as', []):
-            embedded_names = value_info['embedded_as']
+        embedded_names = value_info.get('embedded_as', [])
+        if embedded_names:
             prog.Code('# parameter {} embedded as {}'.format(
                 name, embedded_names))
             for embedded_name in embedded_names:
@@ -431,7 +427,8 @@ class Writer(object):
         assert lod is None or isinstance(lod,
                                          list), 'lod should be None or list'
-        lod = lod or [0]
+        if lod is None:
+            lod = [0]
         tensor_desc = framework_pb2.VarType.TensorDesc()
         tensor_desc.data_type = Program.Dtype(weight.dtype)
...