提交 9d147284 编写于 作者: M Macrobull

better type-shape inference support

上级 74feae9b
...@@ -17,8 +17,8 @@ def make_var_name(name): ...@@ -17,8 +17,8 @@ def make_var_name(name):
make a valid variable name in Python code make a valid variable name in Python code
""" """
if name == '': assert name
return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:-': # for s in ' \\|/:-': #
......
...@@ -20,8 +20,8 @@ def make_var_name(name): ...@@ -20,8 +20,8 @@ def make_var_name(name):
make a valid variable name in Python code make a valid variable name in Python code
""" """
if name == '': assert name
return '_'
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:-': # for s in ' \\|/:-': #
......
...@@ -2,14 +2,17 @@ ...@@ -2,14 +2,17 @@
# setopt SH_WORD_SPLIT # if zsh # setopt SH_WORD_SPLIT # if zsh
# alias python="python3" # if ...
# alias http_get="wget -c" # if no aria2
alias http_get="aria2c -c -s8 -x8"
base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/" base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
convert_cmd="python -m onnx2fluid"
validate_cmd="$convert_cmd.validation"
convert_flags="-e -o /tmp/export/" convert_flags="-e -o /tmp/export/"
validate_flags1="/tmp/export/model.py" validate_flags1="/tmp/export/model.py"
validate_flags2="/tmp/export/__model__" validate_flags2="/tmp/export/__model__"
validate_flags3="/tmp/export/__model__ -i"
# alias http_get="wget -c" # if no aria2
alias http_get="aria2c -c -s8 -x8"
# alias python="python3" # if ...
bvlc_alexnet() bvlc_alexnet()
...@@ -23,21 +26,23 @@ bvlc_alexnet() ...@@ -23,21 +26,23 @@ bvlc_alexnet()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz for npz in "$bn_tar/"*.npz
do do
echo "converting $npz ..." echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 prob_1 -s python convert_data_npz.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz" $validate_cmd $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz" $validate_cmd $validate_flags2 -t "$npz"
done done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data_0 prob_1 python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -53,14 +58,15 @@ bvlc_googlenet() ...@@ -53,14 +58,15 @@ bvlc_googlenet()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 prob_1 python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -76,14 +82,15 @@ bvlc_reference_caffenet() ...@@ -76,14 +82,15 @@ bvlc_reference_caffenet()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 prob_1 python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -99,14 +106,15 @@ bvlc_reference_rcnn_ilsvrc13() ...@@ -99,14 +106,15 @@ bvlc_reference_rcnn_ilsvrc13()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1 python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -122,21 +130,23 @@ densenet121() ...@@ -122,21 +130,23 @@ densenet121()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz for npz in "$bn_tar/"*.npz
do do
echo "converting $npz ..." echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 fc6_1 -s python convert_data_npz.py "$npz" data_0 fc6_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz" $validate_cmd $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz" $validate_cmd $validate_flags2 -t "$npz"
done done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 fc6_1 python convert_data_pb.py "$pb_dir" data_0 fc6_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -152,14 +162,15 @@ emotion_ferplus() ...@@ -152,14 +162,15 @@ emotion_ferplus()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0 python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -175,21 +186,23 @@ inception_v1() ...@@ -175,21 +186,23 @@ inception_v1()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz for npz in "$bn_tar/"*.npz
do do
echo "converting $npz ..." echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 prob_1 -s python convert_data_npz.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz" $validate_cmd $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz" $validate_cmd $validate_flags2 -t "$npz"
done done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data_0 prob_1 python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -205,21 +218,23 @@ inception_v2() ...@@ -205,21 +218,23 @@ inception_v2()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz for npz in "$bn_tar/"*.npz
do do
echo "converting $npz ..." echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 prob_1 -s python convert_data_npz.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz" $validate_cmd $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz" $validate_cmd $validate_flags2 -t "$npz"
done done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data_0 prob_1 python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -235,14 +250,15 @@ mobilenet() ...@@ -235,14 +250,15 @@ mobilenet()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0 python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -258,14 +274,15 @@ resnet18() ...@@ -258,14 +274,15 @@ resnet18()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -281,21 +298,23 @@ resnet50() ...@@ -281,21 +298,23 @@ resnet50()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz for npz in "$bn_tar/"*.npz
do do
echo "converting $npz ..." echo "converting $npz ..."
python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz" $validate_cmd $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz" $validate_cmd $validate_flags2 -t "$npz"
done done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1 python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -311,14 +330,15 @@ resnet100_arcface() ...@@ -311,14 +330,15 @@ resnet100_arcface()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data fc1 python convert_data_pb.py "$pb_dir" data fc1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -334,14 +354,15 @@ resnet101_duc() ...@@ -334,14 +354,15 @@ resnet101_duc()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data seg_loss python convert_data_pb.py "$pb_dir" data seg_loss
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -357,14 +378,15 @@ resnet152() ...@@ -357,14 +378,15 @@ resnet152()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -380,14 +402,15 @@ shufflenet() ...@@ -380,14 +402,15 @@ shufflenet()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1 python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -403,14 +426,15 @@ squeezenet() ...@@ -403,14 +426,15 @@ squeezenet()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 softmaxout_1 python convert_data_pb.py "$pb_dir" data_0 softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -426,14 +450,15 @@ squeezenet1v1() ...@@ -426,14 +450,15 @@ squeezenet1v1()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0 python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -450,14 +475,15 @@ ssd() ...@@ -450,14 +475,15 @@ ssd()
mkdir "$bn_tar" mkdir "$bn_tar"
tar xf "$fn_tar" -C "$bn_tar/" tar xf "$fn_tar" -C "$bn_tar/"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" image bboxes,labels,scores python convert_data_pb.py "$pb_dir" image bboxes,labels,scores
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -473,14 +499,15 @@ tiny_yolov2() ...@@ -473,14 +499,15 @@ tiny_yolov2()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" image grid python convert_data_pb.py "$pb_dir" image grid
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -496,14 +523,15 @@ vgg16bn() ...@@ -496,14 +523,15 @@ vgg16bn()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y $convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -519,14 +547,15 @@ vgg19() ...@@ -519,14 +547,15 @@ vgg19()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 prob_1 python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -535,21 +564,22 @@ yolov3() ...@@ -535,21 +564,22 @@ yolov3()
{ {
bn_tar="yolov3" bn_tar="yolov3"
fn_tar="$bn_tar.tar.gz" fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx" fn_model="$bn_tar/yolov3.onnx"
http_get "https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/$fn_tar" http_get "https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/$fn_tar"
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -x # $convert_cmd $convert_flags "$fn_model" -x #
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir ..." echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0 python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
...@@ -565,14 +595,15 @@ zfnet512() ...@@ -565,14 +595,15 @@ zfnet512()
echo "extracting ..." echo "extracting ..."
tar xf "$fn_tar" tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" $convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/ for pb_dir in "$bn_tar/"*/
do do
echo "converting $pb_dir" echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1 python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz $validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/" rm -rf "$bn_tar/"
} }
......
...@@ -61,11 +61,14 @@ def main(**kwargs): ...@@ -61,11 +61,14 @@ def main(**kwargs):
passed = True passed = True
golden_data_filename = kwargs.pop('test_data', '') golden_data_filename = kwargs.pop('test_data', '')
infer_inputs = kwargs.pop('infer_inputs', None) infer_inputs = kwargs.pop('infer_inputs', None)
if golden_data_filename or infer_inputs is not None: save_inference_model = infer_inputs is not None
if golden_data_filename or save_inference_model:
from .validation import validate from .validation import validate
save_inference_model = infer_inputs is not None if save_inference_model:
inference_input_names = infer_inputs and infer_inputs.split(',') inference_input_names = infer_inputs.split(',')
else:
inference_input_names = None
logger.info('starting validation on desc ...') logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'), passed &= validate(shutil.os.path.join(save_dir, '__model__'),
......
...@@ -23,7 +23,7 @@ def make_var_name(name): ...@@ -23,7 +23,7 @@ def make_var_name(name):
""" """
if name == '': if name == '':
return '_' return ''
if name[0].isdigit(): if name[0].isdigit():
return 'var_' + name return 'var_' + name
for s in ' \\|/:.-': for s in ' \\|/:.-':
...@@ -170,7 +170,7 @@ def convert(onnx_model_filename, ...@@ -170,7 +170,7 @@ def convert(onnx_model_filename,
for var_name, var_desc in fluid_program.var_descs.items(): for var_name, var_desc in fluid_program.var_descs.items():
if not var_desc.type.lod_tensor.HasField('tensor'): if not var_desc.type.lod_tensor.HasField('tensor'):
bad_vars.append(var_name) bad_vars.append(var_name)
if len(bad_vars) > 0: if bad_vars:
logger.warning('type-shape not infered for var %s ...', logger.warning('type-shape not infered for var %s ...',
', '.join(bad_vars[:5])) ', '.join(bad_vars[:5]))
logger.warning('this causes little problem for PaddlePaddle, ' logger.warning('this causes little problem for PaddlePaddle, '
......
...@@ -99,6 +99,9 @@ def get_attribute_value2(attr): ...@@ -99,6 +99,9 @@ def get_attribute_value2(attr):
elif attr.type == onnx.AttributeProto.STRING: elif attr.type == onnx.AttributeProto.STRING:
value = attr.s value = attr.s
value = value.decode() if isinstance(value, bytes) else value value = value.decode() if isinstance(value, bytes) else value
elif attr.type == onnx.AttributeProto.STRINGS:
value = attr.strings
value = [s.decode() if isinstance(s, bytes) else s for s in value]
else: else:
value = get_attribute_value(attr) value = get_attribute_value(attr)
return value return value
...@@ -161,12 +164,12 @@ def node_topo(nodes, topo='default'): ...@@ -161,12 +164,12 @@ def node_topo(nodes, topo='default'):
for node_idx, degree in enumerate(node_in_degrees): for node_idx, degree in enumerate(node_in_degrees):
if degree == 0: if degree == 0:
queue.append(node_idx) queue.append(node_idx)
while len(queue) > 0: while queue:
node_idx = queue.pop(0) node_idx = queue.pop(0)
node_topo.append(node_idx) node_topo.append(node_idx)
for val_name in nodes[node_idx].output: for val_name in nodes[node_idx].output:
output_refs[val_name].remove(node_idx) output_refs[val_name].remove(node_idx)
if len(output_refs[val_name]) > 0: if output_refs[val_name]:
continue continue
output_refs.pop(val_name) output_refs.pop(val_name)
if val_name not in input_refs: if val_name not in input_refs:
...@@ -186,12 +189,12 @@ def node_topo(nodes, topo='default'): ...@@ -186,12 +189,12 @@ def node_topo(nodes, topo='default'):
for node_idx, degree in enumerate(node_out_degrees): for node_idx, degree in enumerate(node_out_degrees):
if degree == 0: if degree == 0:
queue.append(node_idx) queue.append(node_idx)
while len(queue) > 0: while queue:
node_idx = queue.pop(0) node_idx = queue.pop(0)
node_topo.append(node_idx) node_topo.append(node_idx)
for val_name in nodes[node_idx].input: for val_name in nodes[node_idx].input:
input_refs[val_name].remove(node_idx) input_refs[val_name].remove(node_idx)
if len(input_refs[val_name]) > 0: if input_refs[val_name]:
continue continue
input_refs.pop(val_name) input_refs.pop(val_name)
if val_name not in output_refs: if val_name not in output_refs:
...@@ -210,7 +213,10 @@ def node_iter(nodes, indices=None): ...@@ -210,7 +213,10 @@ def node_iter(nodes, indices=None):
generator for ONNX node graph with given indices generator for ONNX node graph with given indices
""" """
for index in indices or range(len(nodes)): if indices is None:
indices = range(len(nodes))
for index in indices:
node = nodes[index] node = nodes[index]
name = node.name name = node.name
domain = node.domain domain = node.domain
...@@ -221,9 +227,11 @@ def node_iter(nodes, indices=None): ...@@ -221,9 +227,11 @@ def node_iter(nodes, indices=None):
if name == '': if name == '':
name = 'op_' + str(index) name = 'op_' + str(index)
else: # make_op_name
for s in ' \\|/:-': #
name = name.replace(s, '_') # else: # make_op_name
# for s in ' \\|/:-': #
# name = name.replace(s, '_')
if domain == '': if domain == '':
domain = DEFAULT_OP_DOMAIN domain = DEFAULT_OP_DOMAIN
...@@ -356,10 +364,11 @@ def polish_and_save(model_filename, ...@@ -356,10 +364,11 @@ def polish_and_save(model_filename,
run polish_model and save run polish_model and save
""" """
if save_filename is None:
save_filename = model_filename.replace('.onnx', suffix + '.onnx')
model = onnx.load(model_filename) model = onnx.load(model_filename)
model = polish_model(model, *args, **kwargs) model = polish_model(model, *args, **kwargs)
save_filename = save_filename or model_filename.replace(
'.onnx', suffix + '.onnx')
onnx.save(model, save_filename) onnx.save(model, save_filename)
logger.info('polished model saved to: %s', save_filename) logger.info('polished model saved to: %s', save_filename)
return save_filename return save_filename
...@@ -495,7 +504,7 @@ def optimize_model_cast(model): ...@@ -495,7 +504,7 @@ def optimize_model_cast(model):
for node_idx, node in enumerate(nodes): for node_idx, node in enumerate(nodes):
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''): if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
continue continue
if not (node.op_type == 'Cast'): if node.op_type != 'Cast':
continue continue
attrs = node_attrs(node) attrs = node_attrs(node)
output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']] output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
...@@ -551,7 +560,7 @@ def optimize_model_slice(model): ...@@ -551,7 +560,7 @@ def optimize_model_slice(model):
node = nodes[node_idx] node = nodes[node_idx]
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''): if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
return chain return chain
if not node.op_type == 'Slice': if node.op_type != 'Slice':
return chain return chain
chain.append(node_idx) chain.append(node_idx)
output_name = node.output[0] output_name = node.output[0]
...@@ -585,10 +594,10 @@ def optimize_model_slice(model): ...@@ -585,10 +594,10 @@ def optimize_model_slice(model):
nodes_to_remove = [] nodes_to_remove = []
for node_idx in range(len(nodes)): for node_idx in range(len(nodes)):
slice_chain = build_slice_node_chain(node_idx) slice_chain = build_slice_node_chain(node_idx)
if len(slice_chain) == 0: if not slice_chain:
continue continue
merged_slice = merge_slice(slice_chain) merged_slice = merge_slice(slice_chain)
if len(merged_slice) > 0 and len(slice_chain) == 1: # no need to merge if merged_slice and len(slice_chain) == 1: # no need to merge
continue continue
attrs = {'axes': [], 'starts': [], 'ends': []} attrs = {'axes': [], 'starts': [], 'ends': []}
...@@ -602,12 +611,11 @@ def optimize_model_slice(model): ...@@ -602,12 +611,11 @@ def optimize_model_slice(model):
output_name = last_node.output[0] output_name = last_node.output[0]
processed = -1 processed = -1
if output_name in input_refs: # 0, [1...] if output_name in input_refs: # 0, [1...]
new_input_name = first_node.output[0] if len( new_input_name = first_node.output[0] if merged_slice else input_name
merged_slice) > 0 else input_name
processed = skip_node_forward(ret_nodes, output_name, processed = skip_node_forward(ret_nodes, output_name,
new_input_name, input_refs) new_input_name, input_refs)
if processed > 0: if processed > 0:
if len(merged_slice) > 0: if merged_slice:
remain_idx = slice_chain[0] remain_idx = slice_chain[0]
remove_chain = slice_chain[1:] remove_chain = slice_chain[1:]
slice_node = ret_nodes[remain_idx] slice_node = ret_nodes[remain_idx]
...@@ -621,12 +629,11 @@ def optimize_model_slice(model): ...@@ -621,12 +629,11 @@ def optimize_model_slice(model):
remove_chain = slice_chain remove_chain = slice_chain
if processed < 0 and input_name in output_refs: if processed < 0 and input_name in output_refs:
new_output_name = last_node.input[0] if len( new_output_name = last_node.input[0] if merged_slice else output_name
merged_slice) > 0 else output_name
processed = skip_node_backward(ret_nodes, input_name, processed = skip_node_backward(ret_nodes, input_name,
new_output_name, output_refs) new_output_name, output_refs)
if processed > 0: if processed > 0:
if len(merged_slice) > 0: if merged_slice:
remain_idx = slice_chain[-1] remain_idx = slice_chain[-1]
remove_chain = slice_chain[:-1] remove_chain = slice_chain[:-1]
slice_node = ret_nodes[remain_idx] slice_node = ret_nodes[remain_idx]
...@@ -641,7 +648,7 @@ def optimize_model_slice(model): ...@@ -641,7 +648,7 @@ def optimize_model_slice(model):
if processed > 0: if processed > 0:
nodes_to_remove.extend(remove_chain) nodes_to_remove.extend(remove_chain)
if len(merged_slice) == 0: if not merged_slice:
logger.debug('skip slice chain %s -> %s -> %s', input_name, logger.debug('skip slice chain %s -> %s -> %s', input_name,
slice_chain, output_name) slice_chain, output_name)
elif processed < 0: # NEVERFIX: not merge standalone slice chain elif processed < 0: # NEVERFIX: not merge standalone slice chain
......
...@@ -44,50 +44,62 @@ DEFAULT_OP_MAPPING = { ...@@ -44,50 +44,62 @@ DEFAULT_OP_MAPPING = {
## nil ops ## ## nil ops ##
'RandomUniform': 'RandomUniform':
['uniform_random', [], ['Out'], dict(high='max', low='min'), ['uniform_random', [], ['Out'], dict(high='max', low='min'),
dict(), None, None, False], # TODO: add dtype support dict(max=1., min=0., seed=0), None, None, False], # TODO: add dtype support
'RandomNormal': 'RandomNormal':
['gaussian_random', [], ['Out'], dict(scale='std'), ['gaussian_random', [], ['Out'], dict(scale='std'),
dict(), None, None, False], # TODO: add dtype support dict(mean=0., std=1., seed=0), None, None, False], # TODO: add dtype support
## unary ops ## ## unary ops ##
'Abs': ['abs', ['X'], ['Out']], 'Abs': ['abs', ['X'], ['Out']],
'Acos': ['acos', ['X'], ['Out']], 'Acos': ['acos', ['X'], ['Out']],
'Asin': ['asin', ['X'], ['Out']], 'Asin': ['asin', ['X'], ['Out']],
'Atan': ['atan', ['X'], ['Out']], 'Atan': ['atan', ['X'], ['Out']],
'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims='')], 'ArgMax': ['argmax', ['X'], ['Out'], dict(keepdims=''), dict(axis=0)],
'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims='')], 'ArgMin': ['argmin', ['X'], ['Out'], dict(keepdims=''), dict(axis=0)],
'Ceil': ['ceil', ['X'], ['Out']], 'Ceil': ['ceil', ['X'], ['Out']],
'Clip': ['clip', ['X'], ['Out']], # attrs bypassed 'Clip':
['clip', ['X'], ['Out'], dict(), dict(
min=(_np.array([255, 255, 127, 255], dtype=_np.uint8).view(_np.float32)),
max=(_np.array([255, 255, 127, 127], dtype=_np.uint8).view(_np.float32)),
)],
'Cos': ['cos', ['X'], ['Out']], 'Cos': ['cos', ['X'], ['Out']],
'Elu': ['elu', ['X'], ['Out']], 'Elu': ['elu', ['X'], ['Out'], dict(), dict(alpha=1.)],
'Exp': ['exp', ['X'], ['Out']], 'Exp': ['exp', ['X'], ['Out']],
'Flatten': ['flatten', ['X'], ['Out']], # attrs bypassed, FIXME: emit flatten2 'Flatten': ['flatten', ['X'], ['Out'], dict(), dict(axis=1)], # FIXME: emit flatten2
'Floor': ['floor', ['X'], ['Out']], 'Floor': ['floor', ['X'], ['Out']],
'Gather': ['gather', ['X'], ['Out'], dict(axis='')], 'Gather': ['gather', ['X', "Index"], ['Out'], dict(axis='')],
'HardSigmoid': ['hard_sigmoid', ['X'], ['Out'], dict(alpha='slope', beta='offset')], 'HardSigmoid':
['hard_sigmoid', ['X'], ['Out'], dict(alpha='slope', beta='offset'),
dict(slope=.2, offset=.5)],
'Identity': ['assign', ['X'], ['Out']], 'Identity': ['assign', ['X'], ['Out']],
'LeakyRelu': ['leaky_relu', ['X'], ['Out']], 'LeakyRelu': ['leaky_relu', ['X'], ['Out'], dict(), dict(alpha=.01)],
'Log': ['log', ['X'], ['Out']], 'Log': ['log', ['X'], ['Out']],
'LRN': ['lrn', ['X'], ['Out', 'MidOut'], dict(size='n', bias='k')], # 'LRN':
['lrn', ['X'], ['Out', 'MidOut'], dict(size='n', bias='k'),
dict(n=5, k=1., alpha=1e-4, beta=.75)], #
'Reciprocal': ['reciprocal', ['X'], ['Out']], 'Reciprocal': ['reciprocal', ['X'], ['Out']],
'Relu': ['relu', ['X'], ['Out']], 'Relu': ['relu', ['X'], ['Out']],
'Round': ['round', ['X'], ['Out']], 'Round': ['round', ['X'], ['Out']],
'Selu': ['selu', ['X'], ['Out'], dict(gamma='scale')], 'Selu':
'Shape': ['shape', ['X'], ['Out']], # FIXME: out is int64 vs int32 ['selu', ['X'], ['Out'], dict(gamma='scale'), dict(
scale=1.0507009873554804934193349852946,
alpha=1.6732632423543772848170429916717,
)],
'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', labmd='')], 'Shrink': ['softshrink', ['X'], ['Out'], dict(bias='', labmd='')],
'Sigmoid': ['sigmoid', ['X'], ['Out']], 'Sigmoid': ['sigmoid', ['X'], ['Out']],
'Sign': ['sign', ['X'], ['Out']], 'Sign': ['sign', ['X'], ['Out']],
'Sin': ['sin', ['X'], ['Out']], 'Sin': ['sin', ['X'], ['Out']],
'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2 'Squeeze': ['squeeze', ['X'], ['Out']], # FIXME: emit squeeze2
# FIXME: default axis = -1, reshape required before and after # FIXME: default axis = -1, reshape required before and after
'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')], 'Softmax': ['softmax', ['X'], ['Out'], dict(axis=''), dict(axis=-1)],
'Softplus': ['softplus', ['X'], ['Out']], 'Softplus': ['softplus', ['X'], ['Out']],
'Softsign': ['softsign', ['X'], ['Out']], 'Softsign': ['softsign', ['X'], ['Out']],
'SpaceToDepth': ['space_to_depth', ['X'], ['Out']], 'SpaceToDepth': ['space_to_depth', ['X'], ['Out']],
'Sqrt': ['sqrt', ['X'], ['Out']], 'Sqrt': ['sqrt', ['X'], ['Out']],
'Tanh': ['tanh', ['X'], ['Out']], 'Tanh': ['tanh', ['X'], ['Out']],
'ThresholdedRelu': ['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold')], 'ThresholdedRelu':
['thresholded_relu', ['X'], ['Out'], dict(alpha='threshold'), dict(alpha=1.)],
#'Transpose': ['transpose', ['X'], ['Out']], #'Transpose': ['transpose', ['X'], ['Out']],
'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit unsqueeze2 'Unsqueeze': ['unsqueeze', ['X'], ['Out']], # FIXME: emit unsqueeze2
## binary ops ## ## binary ops ##
'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], 'Add': ['elementwise_add', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
#'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')], #'AffineGrid': ['affine_grid', ['Theta'], ['Output'], dict(size='out_shape')],
...@@ -103,20 +115,31 @@ DEFAULT_OP_MAPPING = { ...@@ -103,20 +115,31 @@ DEFAULT_OP_MAPPING = {
'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], 'Mul': ['elementwise_mul', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Not': ['logical_not', ['X', 'Y'], ['Out']], 'Not': ['logical_not', ['X', 'Y'], ['Out']],
'OneHot': # assuming values=[0, 1], axis=-1 and drop them 'OneHot': # assuming values=[0, 1], axis=-1 and drop them
['one_hot', ['Input', 'Depth'], ['Out'], dict(axis=''), dict(), ['one_hot', ['Input', 'depth_tensor'], ['Out'], dict(axis=''), dict(),
[0, 1], None, False], [0, 1], None, False],
'Or': ['logical_or', ['X', 'Y'], ['Out']], 'Or': ['logical_or', ['X', 'Y'], ['Out']],
'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent 'Pow': ['elementwise_pow', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], # TODO: pow for scalar exponent
'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)], 'Sub': ['elementwise_sub', ['X', 'Y'], ['Out'], dict(), dict(axis=-1)],
'Xor': ['logical_xor', ['X', 'Y'], ['Out']], 'Xor': ['logical_xor', ['X', 'Y'], ['Out']],
# reduce ops # reduce ops
'ReduceMax': ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], # TODO: fix reduce_all ?
'ReduceMean': ['reduce_mean', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], 'ReduceMax':
'ReduceMin': ['reduce_min', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], ['reduce_max', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim'),
'ReduceProd': ['reduce_prod', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], dict(keep_dim=1)],
'ReduceSum': ['reduce_sum', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim')], 'ReduceMean':
['reduce_mean', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)],
'ReduceMin':
['reduce_min', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)],
'ReduceProd':
['reduce_prod', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)],
'ReduceSum':
['reduce_sum', ['X'], ['Out'], dict(axes='dim', keepdims='keep_dim'),
dict(keep_dim=1)],
# other ops # other ops
'Scatter': ['scatter', ['X', 'Index', 'Updates'], ['Out']], 'Scatter': ['scatter', ['X', 'Ids', 'Updates'], ['Out'], dict(), dict(overwrite=True)],
'TopK': ['topk', ['X', 'K'], ['Out', 'Indices']], 'TopK': ['topk', ['X', 'K'], ['Out', 'Indices']],
} }
...@@ -133,7 +156,7 @@ DEFAULT_IOA_CONSTRAINTS = { ...@@ -133,7 +156,7 @@ DEFAULT_IOA_CONSTRAINTS = {
(lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 supported'), (lambda i, o, a: a.get('axis', 0) == 0, 'only axis = 0 supported'),
], ],
'Shrink': [ 'Shrink': [
(lambda i, o, a: a.get('bias', 0) == a.get('lambd', 0.5), (lambda i, o, a: a.get('bias', 0) == a.get('lambd', .5),
'only SoftShrink with bias = lambd supported'), 'only SoftShrink with bias = lambd supported'),
], ],
# 'Softmax': # 'Softmax':
...@@ -231,9 +254,12 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -231,9 +254,12 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
fluid_attrs.update(mapped_attrs) # as new attrs fluid_attrs.update(mapped_attrs) # as new attrs
var_inps = list(map(inputs.__getitem__, var_inps = list(map(inputs.__getitem__,
input_perm)) if input_perm else inputs input_perm)) if input_perm is not None else inputs
var_outs = list(map(outputs.__getitem__, var_outs = list(map(outputs.__getitem__,
output_perm)) if output_perm else outputs output_perm)) if output_perm is not None else outputs
for var_name in var_inps + var_outs:
assert var_name
arg_name = ', name={}'.format( arg_name = ', name={}'.format(
repr(name)) if fill_name_field and name else '' repr(name)) if fill_name_field and name else ''
arg_attrs = [ arg_attrs = [
...@@ -252,7 +278,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -252,7 +278,7 @@ def _default(prog, op_type, inputs, outputs, attrs, *args, name='', **kwargs):
num_vars = len(var_outs) num_vars = len(var_outs)
num_args = len(fluid_output_args) num_args = len(fluid_output_args)
if num_vars < num_args: if num_vars < num_args:
assert fill_name_field, 'name required to name dummy output variables' assert fill_name_field and name, 'name required to name dummy output variables'
for idx_out in range(num_vars, num_args): for idx_out in range(num_vars, num_args):
var_out = name + '.' + fluid_output_args[idx_out] # dummy output var_out = name + '.' + fluid_output_args[idx_out] # dummy output
var_outs.append(var_out) var_outs.append(var_out)
...@@ -267,6 +293,7 @@ def _assign(prog, mapping): ...@@ -267,6 +293,7 @@ def _assign(prog, mapping):
fluid_op = 'assign' fluid_op = 'assign'
for var_dst, var_src in mapping.items(): for var_dst, var_src in mapping.items():
assert var_dst and var_src
prog.Code('{} = {} # assign'.format(var_dst, var_src)) prog.Code('{} = {} # assign'.format(var_dst, var_src))
# prog.Code('{} = layers.{}({})' # prog.Code('{} = layers.{}({})'
# .format(var_dst, # .format(var_dst,
...@@ -282,18 +309,17 @@ def _assign(prog, mapping): ...@@ -282,18 +309,17 @@ def _assign(prog, mapping):
) )
def _zeros_like(prog, var_ref, var_out, value_infos): def _zeros_like(prog, var_ref, var_out):
prog.Op( prog.Op(
'', '',
'Sub', 'Sub',
[var_ref, var_ref], [var_ref, var_ref],
[var_out], [var_out],
{'axis': 0}, {'axis': 0},
value_infos,
) )
def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE def _pad_if_asymmetric(prog, pads, var_input, value_infos, scope): # pads: SSEE
assert len(pads) & 1 == 0 assert len(pads) & 1 == 0
ndims = len(pads) // 2 ndims = len(pads) // 2
symmetric = True symmetric = True
...@@ -304,7 +330,8 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE ...@@ -304,7 +330,8 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
if symmetric: if symmetric:
return pads[:ndims], var_input return pads[:ndims], var_input
var_padded = var_input + '_pad' # explicit variable assert scope
var_padded = scope + '_pad' # explicit variable
prog.Op( prog.Op(
'', '',
'Pad', 'Pad',
...@@ -316,7 +343,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE ...@@ -316,7 +343,7 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
'pads': pads, 'pads': pads,
}, },
value_infos=value_infos, value_infos=value_infos,
name=(var_input + '/pad'), name=(scope + '/pad'),
) )
return [0] * ndims, var_padded return [0] * ndims, var_padded
...@@ -324,7 +351,8 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE ...@@ -324,7 +351,8 @@ def _pad_if_asymmetric(prog, pads, var_input, value_infos): # pads: SSEE
def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''): def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
# I/O # I/O
var_x, = inputs var_x, = inputs
var_y, var_indices, = (outputs + [None] * 1)[:2] var_y, var_indices, = (outputs + [''] * 1)[:2]
assert var_x and var_y
# interpretation # interpretation
pool_size = attrs['output_size'] # required pool_size = attrs['output_size'] # required
...@@ -359,21 +387,21 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''): ...@@ -359,21 +387,21 @@ def _adaptive_pool(prog, pool_type, inputs, outputs, attrs, name=''):
(['X'], [var_x]), (['X'], [var_x]),
(['Out', 'Indices'], [var_y] + ([var_indices] if var_indices else [])), (['Out', 'Indices'], [var_y] + ([var_indices] if var_indices else [])),
{ {
'global_pooling': False,
'adaptive': True, 'adaptive': True,
'require_index': bool(var_indices),
'pooling_type': pool_type, 'pooling_type': pool_type,
'ksize': pool_size, 'ksize': pool_size,
# unused # unused
# 'exclusive': True, # 'exclusive': True,
# 'global_pooling': False,
}, },
) )
def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): def _global_pool(prog, pool_type, inputs, outputs, value_infos, name=''):
# I/O # I/O
var_x, = inputs var_x, = inputs
var_y, = outputs var_y, = outputs
assert var_x and var_y
# interpretation # interpretation
input_shape = _shape_or_none(value_infos, var_x) input_shape = _shape_or_none(value_infos, var_x)
...@@ -406,10 +434,10 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -406,10 +434,10 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
(['Out'], [var_y]), (['Out'], [var_y]),
{ {
'global_pooling': True, 'global_pooling': True,
'adaptive': False,
'pooling_type': pool_type, 'pooling_type': pool_type,
'ksize': [-1, -1],
# unused # unused
'adaptive': False,
'ksize': [-1, -1],
'strides': [-1, -1], 'strides': [-1, -1],
'paddings': [0, 0], 'paddings': [0, 0],
'ceil_mode': False, 'ceil_mode': False,
...@@ -417,10 +445,11 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -417,10 +445,11 @@ def _global_pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
) )
def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name):
# I/O # I/O
var_x, = inputs var_x, = inputs
var_y, var_indices, = (outputs + [None] * 1)[:2] var_y, var_indices, = (outputs + [''] * 1)[:2]
assert name and var_x and var_y
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -436,8 +465,8 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -436,8 +465,8 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
strides = attrs.get('strides', [1] * poolnd) # optional strides = attrs.get('strides', [1] * poolnd) # optional
ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional ceil_mode = bool(attrs.get('ceil_mode', 0)) # optional
pads = attrs.get('pads', [0] * (poolnd * 2)) # optional pads = attrs.get('pads', [0] * (poolnd * 2)) # optional
paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos) paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos, name)
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name))
# generation # generation
prog.Code('{} = layers.{}({}, exclusive=True' prog.Code('{} = layers.{}({}, exclusive=True'
...@@ -467,23 +496,23 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''): ...@@ -467,23 +496,23 @@ def _pool(prog, pool_type, inputs, outputs, attrs, value_infos, name=''):
(['Out', 'Indices'], [var_y] + ([var_indices] if var_indices else [])), (['Out', 'Indices'], [var_y] + ([var_indices] if var_indices else [])),
{ {
'global_pooling': False, 'global_pooling': False,
'adaptive': False,
'require_index': bool(var_indices),
'pooling_type': pool_type, 'pooling_type': pool_type,
'ksize': pool_size, 'ksize': pool_size,
'strides': strides, 'strides': strides,
'paddings': paddings, 'paddings': paddings,
'ceil_mode': ceil_mode, 'ceil_mode': ceil_mode,
# unused # unused
'exclusive': True, 'adaptive': False,
# 'exclusive': True,
}, },
) )
def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): def _roi_pool(prog, fluid_op, inputs, outputs, attrs, name):
# I/O # I/O
var_x, var_rois, = inputs var_x, var_rois, = inputs
var_y, = outputs var_y, = outputs
assert name and var_x and var_rois and var_y
# interpretation # interpretation
spatial_scale = attrs['spatial_scale'] # required spatial_scale = attrs['spatial_scale'] # required
...@@ -526,7 +555,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name): ...@@ -526,7 +555,7 @@ def _roi_pool(prog, fluid_op, inputs, outputs, attrs, value_infos, name):
prog.VarDesc(var_argmax) prog.VarDesc(var_argmax)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(['X', 'Rois'], [var_x, var_rois]), (['X', 'ROIs'], [var_x, var_rois]),
(['Out', 'Argmax'], [var_y] + ([var_argmax] if is_max_pool else [])), (['Out', 'Argmax'], [var_y] + ([var_argmax] if is_max_pool else [])),
od_attrs, od_attrs,
) )
...@@ -536,6 +565,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): ...@@ -536,6 +565,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
# I/O # I/O
var_x, var_scales, = inputs var_x, var_scales, = inputs
var_y, = outputs var_y, = outputs
assert var_x and var_scales and var_y
# interpretation # interpretation
# output shape # output shape
...@@ -551,7 +581,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''): ...@@ -551,7 +581,7 @@ def _interpolate(prog, inputs, outputs, attrs, value_infos, name=''):
1] == 1, 'only scale on (NC)HW supported' 1] == 1, 'only scale on (NC)HW supported'
assert scales[2] == scales[ assert scales[2] == scales[
3], 'only aspect-ratio-invariant scale supported' 3], 'only aspect-ratio-invariant scale supported'
scale = scales and scales[2] scale = scales[2]
# try input shape # try input shape
if scale is None: if scale is None:
assert out_shape_, 'neither scales nor output shape available' assert out_shape_, 'neither scales nor output shape available'
...@@ -618,6 +648,7 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -618,6 +648,7 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
# I/O # I/O
var_theta, = inputs var_theta, = inputs
var_grid, = outputs var_grid, = outputs
assert var_theta and var_grid
# interpretation # interpretation
fluid_op = 'affine_grid' fluid_op = 'affine_grid'
...@@ -644,19 +675,13 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -644,19 +675,13 @@ def AffineGrid(prog, inputs, outputs, attrs, *args, name='', **kwargs):
) )
def AveragePool(prog, def AveragePool(prog, inputs, outputs, attrs, value_infos, name, *args,
inputs,
outputs,
attrs,
value_infos,
name='',
*args,
**kwargs): **kwargs):
""" """
onnx::AveragePool-10: onnx::AveragePool-10:
""" """
return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, name=name) return _pool(prog, 'avg', inputs, outputs, attrs, value_infos, name)
def BatchNormalization(prog, def BatchNormalization(prog,
...@@ -675,9 +700,10 @@ def BatchNormalization(prog, ...@@ -675,9 +700,10 @@ def BatchNormalization(prog,
# I/O # I/O
var_x, var_scale, var_b, var_mean, var_var, = inputs var_x, var_scale, var_b, var_mean, var_var, = inputs
var_y, var_mean_, var_var_, var_saved_mean, var_saved_variance, = ( var_y, var_mean_, var_var_, var_saved_mean, var_saved_variance, = (
outputs + [None] * 4)[:5] outputs + [''] * 4)[:5]
assert var_saved_mean or (name != '') assert var_x and var_scale and var_b and var_mean and var_var and var_y
assert var_saved_variance or (name != '') assert var_saved_mean or name
assert var_saved_variance or name
var_saved_mean = var_saved_mean or (name + '.saved_mean') # dummy output var_saved_mean = var_saved_mean or (name + '.saved_mean') # dummy output
var_saved_variance = var_saved_variance or (name + '.saved_variance' var_saved_variance = var_saved_variance or (name + '.saved_variance'
) # dummy output ) # dummy output
...@@ -696,7 +722,7 @@ def BatchNormalization(prog, ...@@ -696,7 +722,7 @@ def BatchNormalization(prog,
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
embed_params &= embeddable embed_params &= embeddable
if embed_params: if embed_params:
assert name != '' assert name
embedded_scale = name + '.w_0' embedded_scale = name + '.w_0'
embedded_b = name + '.b_0' embedded_b = name + '.b_0'
embedded_mean = name + '.w_1' embedded_mean = name + '.w_1'
...@@ -744,6 +770,7 @@ def BatchNormalization(prog, ...@@ -744,6 +770,7 @@ def BatchNormalization(prog,
'epsilon': epsilon, 'epsilon': epsilon,
'is_test': 1, 'is_test': 1,
# unused # unused
'data_layout': 'NCHW',
}, },
) )
...@@ -756,6 +783,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -756,6 +783,7 @@ def Cast(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
# I/O # I/O
var_input, = inputs var_input, = inputs
var_output, = outputs var_output, = outputs
assert var_input and var_output
# interpretation # interpretation
dtype = attrs['to'] # required dtype = attrs['to'] # required
...@@ -799,6 +827,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -799,6 +827,7 @@ def Concat(prog, inputs, outputs, attrs, *args, name='', **kwargs):
# I/O # I/O
var_ret, = outputs var_ret, = outputs
assert var_ret
# interpretation # interpretation
fluid_op = 'concat' fluid_op = 'concat'
...@@ -833,6 +862,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -833,6 +862,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
# I/O # I/O
assert len(inputs) == 0, 'constant op accept no inputs' assert len(inputs) == 0, 'constant op accept no inputs'
var_output, = outputs var_output, = outputs
assert var_output
# interpretation # interpretation
value = attrs['value'] # required value = attrs['value'] # required
...@@ -852,7 +882,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -852,7 +882,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'using value as 1-D tensor may lead to fails', outputs, var_output) 'using value as 1-D tensor may lead to fails', outputs, var_output)
# generation # generation
if len(shape) == 0 or value.size == 1: # scalar or 1-size if not shape or value.size == 1: # scalar or 1-size
shape = [1] # WORKAROUND: bad scalar support shape = [1] # WORKAROUND: bad scalar support
value = value.tolist()[0] value = value.tolist()[0]
fluid_op = 'fill_constant' fluid_op = 'fill_constant'
...@@ -890,6 +920,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -890,6 +920,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
# I/O # I/O
var_shape, = inputs var_shape, = inputs
var_output, = outputs var_output, = outputs
assert var_shape and var_output
shape = _const_weight_or_none(value_infos, var_shape) shape = _const_weight_or_none(value_infos, var_shape)
if shape is None: if shape is None:
...@@ -908,7 +939,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -908,7 +939,7 @@ def ConstantOfShape(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[], [],
outputs, outputs,
attrs, attrs,
value_infos, value_infos=value_infos,
) )
...@@ -926,8 +957,9 @@ def Conv(prog, ...@@ -926,8 +957,9 @@ def Conv(prog,
""" """
# I/O # I/O
var_x, var_w, var_b, = (inputs + [None] * 1)[:3] var_x, var_w, var_b, = (inputs + [''] * 1)[:3]
var_y, = outputs var_y, = outputs
assert name and var_x and var_w and var_y
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -945,7 +977,7 @@ def Conv(prog, ...@@ -945,7 +977,7 @@ def Conv(prog,
strides = attrs.get('strides', [1] * convnd) # optional strides = attrs.get('strides', [1] * convnd) # optional
dilations = attrs.get('dilations', [1] * convnd) # optional dilations = attrs.get('dilations', [1] * convnd) # optional
pads = attrs.get('pads', [0] * (convnd * 2)) # optional pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos) paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos, name)
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
embeddable = _check_embeddable(value_infos, embeddable = _check_embeddable(value_infos,
*([var_w] + ([var_b] if var_b else []))) *([var_w] + ([var_b] if var_b else [])))
...@@ -995,7 +1027,7 @@ def Conv(prog, ...@@ -995,7 +1027,7 @@ def Conv(prog,
var_conv = (name + '.conv') if var_b else var_y # hidden variable var_conv = (name + '.conv') if var_b else var_y # hidden variable
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(['Input', 'Filter'], [var_x, var_w]), # , 'Bias', 'ResidualData' (['Input', 'Filter'], [var_x, var_w]),
(['Output'], [var_conv]), (['Output'], [var_conv]),
{ {
'strides': strides, 'strides': strides,
...@@ -1012,7 +1044,6 @@ def Conv(prog, ...@@ -1012,7 +1044,6 @@ def Conv(prog,
[var_conv, var_b], # [var_conv, var_b], #
[var_y], [var_y],
{'axis': 1}, {'axis': 1},
value_infos=value_infos,
name=(name + '/bias'), name=(name + '/bias'),
) )
else: else:
...@@ -1033,8 +1064,9 @@ def ConvTranspose(prog, ...@@ -1033,8 +1064,9 @@ def ConvTranspose(prog,
""" """
# I/O # I/O
var_x, var_w, var_b, = (inputs + [None] * 1)[:3] var_x, var_w, var_b, = (inputs + [''] * 1)[:3]
var_y, = outputs var_y, = outputs
assert name and var_x and var_w and var_y
# interpretation # interpretation
assert attrs.get( assert attrs.get(
...@@ -1056,7 +1088,7 @@ def ConvTranspose(prog, ...@@ -1056,7 +1088,7 @@ def ConvTranspose(prog,
dilations = attrs.get('dilations', [1] * convnd) # optional dilations = attrs.get('dilations', [1] * convnd) # optional
output_size = attrs.get('output_shape', []) # optional output_size = attrs.get('output_shape', []) # optional
pads = attrs.get('pads', [0] * (convnd * 2)) # optional pads = attrs.get('pads', [0] * (convnd * 2)) # optional
paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos) paddings, var_x = _pad_if_asymmetric(prog, pads, var_x, value_infos, name)
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
embeddable = _check_embeddable(value_infos, embeddable = _check_embeddable(value_infos,
*([var_w] + ([var_b] if var_b else []))) *([var_w] + ([var_b] if var_b else [])))
...@@ -1109,7 +1141,7 @@ def ConvTranspose(prog, ...@@ -1109,7 +1141,7 @@ def ConvTranspose(prog,
var_conv = (name + '.conv') if var_b else var_y # hidden variable var_conv = (name + '.conv') if var_b else var_y # hidden variable
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(['Input', 'Filter'], [var_x, var_w]), # , 'Bias', 'ResidualData' (['Input', 'Filter'], [var_x, var_w]),
(['Output'], [var_conv]), (['Output'], [var_conv]),
{ {
'strides': strides, 'strides': strides,
...@@ -1128,7 +1160,6 @@ def ConvTranspose(prog, ...@@ -1128,7 +1160,6 @@ def ConvTranspose(prog,
[var_conv, var_b], # [var_conv, var_b], #
[var_y], [var_y],
{'axis': 1}, {'axis': 1},
value_infos=value_infos,
name=(name + '/bias'), name=(name + '/bias'),
) )
else: else:
...@@ -1143,6 +1174,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1143,6 +1174,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
# due to fluid fc don't support transposed weight, we use matmul + ew_add # due to fluid fc don't support transposed weight, we use matmul + ew_add
var_a, var_b, var_c, = inputs var_a, var_b, var_c, = inputs
var_y, = outputs var_y, = outputs
assert name and var_a and var_b and var_c and var_y
alpha = attrs.get('alpha', 1.) # optional alpha = attrs.get('alpha', 1.) # optional
beta = attrs.get('beta', 1.) # optional beta = attrs.get('beta', 1.) # optional
...@@ -1160,7 +1192,6 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1160,7 +1192,6 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'transpose_y': trans_b, 'transpose_y': trans_b,
'alpha': alpha, 'alpha': alpha,
}, },
value_infos=value_infos,
name=(name + '/mm'), name=(name + '/mm'),
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
...@@ -1176,7 +1207,6 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1176,7 +1207,6 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_mm, var_c], [var_mm, var_c],
[var_y], [var_y],
{'axis': 1}, {'axis': 1},
value_infos=value_infos,
name=(name + '/bias'), name=(name + '/bias'),
) )
else: else:
...@@ -1198,8 +1228,6 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1198,8 +1228,6 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[], [],
[var_beta], [var_beta],
{'value': beta}, {'value': beta},
value_infos=value_infos,
name=var_beta,
) )
prog.Op( prog.Op(
'', '',
...@@ -1207,8 +1235,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1207,8 +1235,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_c, var_beta], [var_c, var_beta],
[var_vm], [var_vm],
dict(), dict(),
value_infos=value_infos, name=(name + '.beta/scale'),
name=(var_beta + '/scale'),
) )
prog.Op( prog.Op(
'', '',
...@@ -1223,7 +1250,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1223,7 +1250,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
def GlobalAveragePool(prog, def GlobalAveragePool(prog,
inputs, inputs,
outputs, outputs,
attrs, attrs_,
value_infos, value_infos,
name='', name='',
*args, *args,
...@@ -1232,19 +1259,13 @@ def GlobalAveragePool(prog, ...@@ -1232,19 +1259,13 @@ def GlobalAveragePool(prog,
onnx::GlobalAveragePool-1: onnx::GlobalAveragePool-1:
""" """
return _global_pool(prog, return _global_pool(prog, 'avg', inputs, outputs, value_infos, name=name)
'avg',
inputs,
outputs,
attrs,
value_infos,
name=name)
def GlobalMaxPool(prog, def GlobalMaxPool(prog,
inputs, inputs,
outputs, outputs,
attrs, attrs_,
value_infos, value_infos,
name='', name='',
*args, *args,
...@@ -1253,25 +1274,20 @@ def GlobalMaxPool(prog, ...@@ -1253,25 +1274,20 @@ def GlobalMaxPool(prog,
onnx::GlobalMaxPool-1: onnx::GlobalMaxPool-1:
""" """
return _global_pool(prog, return _global_pool(prog, 'max', inputs, outputs, value_infos, name=name)
'max',
inputs,
outputs,
attrs,
value_infos,
name=name)
def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): def GRU(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
""" """
onnx::GRU-7: onnx::GRU-7:
""" """
var_x, var_w, var_r, var_b, var_len, var_xh, = (inputs + [None] * 3)[:6] var_x, var_w, var_r, var_b, var_len, var_xh, = (inputs + [''] * 3)[:6]
var_y, var_yh, = (outputs + [None] * 2)[:2] var_y, var_yh, = (outputs + [''] * 2)[:2]
var_gate = var_y + '.gate' # dummy output assert name and var_x and var_w and var_r # and (var_y or var_yh)
var_reset = var_y + '.reset' # dummy output var_gate = name + '.gate' # dummy output
var_hidden = var_y + '.hidden' # dummy output, # var_yh var_reset = name + '.reset' # dummy output
var_hidden = name + '.hidden' # dummy output, # var_yh
# interpretation # interpretation
x_shape = _shape_or_none(value_infos, var_x) x_shape = _shape_or_none(value_infos, var_x)
...@@ -1279,19 +1295,19 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1279,19 +1295,19 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
assert x_shape[1] == 1, 'only X with batch_size = 1 supported' assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert 'clip' not in attrs, 'clipping not supported' assert 'clip' not in attrs, 'clipping not supported'
hidden_size = attrs.get('hidden_size', None) # optional hidden_size = attrs.get('hidden_size', None) # optional
if not hidden_size: if hidden_size is None:
r_shape = _shape_or_none(value_infos, var_r) r_shape = _shape_or_none(value_infos, var_r)
if r_shape: if r_shape:
hidden_size = r_shape[-1] hidden_size = r_shape[-1]
if not hidden_size: if hidden_size is None:
w_shape = _shape_or_none(value_infos, var_w) w_shape = _shape_or_none(value_infos, var_w)
if w_shape: if w_shape:
hidden_size = w_shape[-2] // 3 hidden_size = w_shape[-2] // 3
if not hidden_size and var_b: if hidden_size is None and var_b:
b_shape = _shape_or_none(value_infos, var_b) b_shape = _shape_or_none(value_infos, var_b)
if b_shape: if b_shape:
hidden_size = b_shape[-1] // 6 hidden_size = b_shape[-1] // 6
if not hidden_size and var_xh: if hidden_size is None and var_xh:
xh_shape = _shape_or_none(value_infos, var_xh) xh_shape = _shape_or_none(value_infos, var_xh)
if xh_shape: if xh_shape:
hidden_size = xh_shape[-1] hidden_size = xh_shape[-1]
...@@ -1313,26 +1329,26 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1313,26 +1329,26 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
# generation # generation
var_x0 = var_x + '_0' # explicit variable var_x0 = name + '_x0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_x], [var_x],
[var_x0], [var_x0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_x + '/index'), name=(name + '.x/index'),
) )
var_w0 = var_w + '_0' # explicit variable var_w0 = name + '_w0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_w], [var_w],
[var_w0], [var_w0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_w + '/index'), name=(name + '.w/index'),
) )
var_fc = var_x0 + '_fc' var_fc = name + '_fc'
var_mm = (var_x0 + '_mm') if var_b else var_fc var_mm = (name + '_mm') if var_b else var_fc
prog.Op( prog.Op(
'', '',
'MatMul', 'MatMul',
...@@ -1342,35 +1358,34 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1342,35 +1358,34 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'transpose_x': 0, 'transpose_x': 0,
'transpose_y': 1, 'transpose_y': 1,
}, },
value_infos=value_infos, name=(name + '/mm'),
name=(var_x0 + '/mm'),
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
prog.OpDescAttrs({ prog.OpDescAttrs({
'transpose_X': 0, 'transpose_X': 0,
'transpose_Y': 1, 'transpose_Y': 1,
})) # f**k you API })) # f**k you API
var_r0 = var_r + '_0' # explicit variable var_r0 = name + '_r0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_r], [var_r],
[var_r0], [var_r0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_r + '/index'), name=(name + '.r/index'),
) )
var_r0t = var_r0 + '_t' # explicit variable var_r0t = name + '_r0t' # explicit variable
prog.Op( prog.Op(
'', '',
'Transpose', 'Transpose',
[var_r0], [var_r0],
[var_r0t], [var_r0t],
{'perm': [1, 0]}, # transpose OI->IO {'perm': [1, 0]}, # transpose OI->IO
name=(var_r0 + '/transpose'), name=(name + '.r0/transpose'),
) )
if var_b: if var_b:
var_bi = var_b + '_i' # explicit variable var_bi = name + '_bi' # explicit variable
var_bh = var_b + '_h' # explicit variable var_bh = name + '_bh' # explicit variable
prog.Op( prog.Op(
'', '',
'Split', 'Split',
...@@ -1380,17 +1395,17 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1380,17 +1395,17 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'axis': 1, # split on x 'axis': 1, # split on x
'split': [hidden_size * 3, hidden_size * 3], 'split': [hidden_size * 3, hidden_size * 3],
}, },
name=(var_b + '/split'), name=(name + '.b/split'),
) )
# squeeze bi so Gemm Add can be performed on axis=1 exaclty # squeeze bi so Gemm Add can be performed on axis=1 exaclty
var_bi0 = var_bi + '_0' # explicit variable var_bi0 = name + '_bi0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_bi], [var_bi],
[var_bi0], [var_bi0],
{'axes': [0]}, # slice on d {'axes': [0]}, # slice on d
name=(var_bi + '/index'), name=(name + '.bi/index'),
) )
prog.Op( prog.Op(
'', '',
...@@ -1398,19 +1413,19 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1398,19 +1413,19 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
[var_mm, var_bi0], [var_mm, var_bi0],
[var_fc], [var_fc],
{'axis': 1}, # {'axis': 1}, #
name=(var_x0 + '/bias'), name=(name + '.i/bias'),
) )
if var_xh: if var_xh:
var_xh0 = var_xh + '_0' # explicit variable var_xh0 = name + '_xh0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_xh], [var_xh],
[var_xh0], [var_xh0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_xh + '/index'), name=(name + '.xh/index'),
) )
var_y00 = var_y + '_00' # explicit variable var_y00 = name + '_y00' # explicit variable #
prog.Code('{} = layers.{}({}, {}, origin_mode=True' prog.Code('{} = layers.{}({}, {}, origin_mode=True'
', h_0={}' ', h_0={}'
', is_reverse={}' ', is_reverse={}'
...@@ -1447,13 +1462,23 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -1447,13 +1462,23 @@ def GRU(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
'origin_mode': True, 'origin_mode': True,
}, },
) )
if var_y:
prog.Op( prog.Op(
'', '',
'Unsqueeze', 'Unsqueeze',
[var_y00], [var_y00],
[var_y], [var_y],
{'axes': [1, 1]}, # extrude on dn {'axes': [1, 1]}, # extrude on dn
name=(var_y + '/reshape'), name=(name + '.y/reshape'),
)
if var_yh:
prog.Op(
'',
'Unsqueeze',
[var_y00], #
[var_yh], #
{'axes': [1, 1]}, # extrude on dn
name=(name + '.yh/reshape'),
) )
...@@ -1462,9 +1487,10 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1462,9 +1487,10 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
onnx::LSTM-7: onnx::LSTM-7:
""" """
var_x, var_w, var_r, var_b, var_len, var_xh, var_xc, var_p, = ( var_x, var_w, var_r, var_b, var_len, var_xh, var_xc, var_p, = (inputs +
inputs + [None] * 5)[:8] [''] * 5)[:8]
var_y, var_yh, var_yc, = (outputs + [None] * 3)[:3] var_y, var_yh, var_yc, = (outputs + [''] * 3)[:3]
assert name and var_x and var_w and var_r # and (var_y or var_yh or var_yc)
var_gate = name + '.gate' var_gate = name + '.gate'
var_pre = name + '.pre' var_pre = name + '.pre'
...@@ -1474,27 +1500,27 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1474,27 +1500,27 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
assert x_shape[1] == 1, 'only X with batch_size = 1 supported' assert x_shape[1] == 1, 'only X with batch_size = 1 supported'
assert 'clip' not in attrs, 'clipping not supported' assert 'clip' not in attrs, 'clipping not supported'
hidden_size = attrs.get('hidden_size', None) # optional hidden_size = attrs.get('hidden_size', None) # optional
if not hidden_size: if hidden_size is None:
r_shape = _shape_or_none(value_infos, var_r) r_shape = _shape_or_none(value_infos, var_r)
if r_shape: if r_shape:
hidden_size = r_shape[-1] hidden_size = r_shape[-1]
if not hidden_size: if hidden_size is None:
w_shape = _shape_or_none(value_infos, var_w) w_shape = _shape_or_none(value_infos, var_w)
if w_shape: if w_shape:
hidden_size = w_shape[-2] // 4 hidden_size = w_shape[-2] // 4
if not hidden_size and var_b: if hidden_size is None and var_b:
b_shape = _shape_or_none(value_infos, var_b) b_shape = _shape_or_none(value_infos, var_b)
if b_shape: if b_shape:
hidden_size = b_shape[-1] // 8 hidden_size = b_shape[-1] // 8
if not hidden_size and var_xh: if hidden_size is None and var_xh:
xh_shape = _shape_or_none(value_infos, var_xh) xh_shape = _shape_or_none(value_infos, var_xh)
if xh_shape: if xh_shape:
hidden_size = xh_shape[-1] hidden_size = xh_shape[-1]
if not hidden_size and var_xc: if hidden_size is None and var_xc:
xc_shape = _shape_or_none(value_infos, var_xc) xc_shape = _shape_or_none(value_infos, var_xc)
if xc_shape: if xc_shape:
hidden_size = xc_shape[-1] hidden_size = xc_shape[-1]
if not hidden_size and var_p: if hidden_size is None and var_p:
p_shape = _shape_or_none(value_infos, var_p) p_shape = _shape_or_none(value_infos, var_p)
if p_shape: if p_shape:
hidden_size = p_shape[-1] // 3 hidden_size = p_shape[-1] // 3
...@@ -1520,26 +1546,26 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1520,26 +1546,26 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
# generation # generation
var_x0 = var_x + '_0' # explicit variable var_x0 = name + '_x0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_x], [var_x],
[var_x0], [var_x0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_x + '/index'), name=(name + '.x/index'),
) )
var_w0 = var_w + '_0' # explicit variable var_w0 = name + '_w0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_w], [var_w],
[var_w0], [var_w0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_w + '/index'), name=(name + '.w/index'),
) )
var_fc = var_x0 + '_fc' var_fc = name + '_fc'
var_mm = (var_x0 + '_mm') if var_b else var_fc var_mm = (name + '_mm') if var_b else var_fc
prog.Op( prog.Op(
'', '',
'MatMul', 'MatMul',
...@@ -1549,7 +1575,6 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1549,7 +1575,6 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'transpose_x': 0, 'transpose_x': 0,
'transpose_y': 1, 'transpose_y': 1,
}, },
value_infos=value_infos,
name=(name + '/mm'), name=(name + '/mm'),
) )
prog.op_descs[-1].attrs.extend( prog.op_descs[-1].attrs.extend(
...@@ -1557,27 +1582,27 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1557,27 +1582,27 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'transpose_X': 0, 'transpose_X': 0,
'transpose_Y': 1, 'transpose_Y': 1,
})) # f**k you API })) # f**k you API
var_r0 = var_r + '_0' # explicit variable var_r0 = name + '_r0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_r], [var_r],
[var_r0], [var_r0],
{'axes': [0]}, # index on d {'axes': [0]}, # index on d
name=(var_r + '/index'), name=(name + '.r/index'),
) )
var_r0t = var_r0 + '_t' # explicit variable var_r0t = name + '_r0t' # explicit variable
prog.Op( prog.Op(
'', '',
'Transpose', 'Transpose',
[var_r0], [var_r0],
[var_r0t], [var_r0t],
{'perm': [1, 0]}, # transpose OI->IO {'perm': [1, 0]}, # transpose OI->IO
name=(var_r0 + '/transpose'), name=(name + '.r0/transpose'),
) )
if var_b: if var_b:
var_bi = var_b + '_i' # explicit variable var_bi = name + '_bi' # explicit variable
var_bh = var_b + '_h' # explicit variable var_bh = name + '_bh' # explicit variable
prog.Op( prog.Op(
'', '',
'Split', 'Split',
...@@ -1587,17 +1612,17 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1587,17 +1612,17 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'axis': 1, # split on x 'axis': 1, # split on x
'split': [hidden_size * 4, hidden_size * 4], 'split': [hidden_size * 4, hidden_size * 4],
}, },
name=(var_b + '/split'), name=(name + '.b/split'),
) )
# squeeze bi so Gemm Add can be performed on axis=1 exaclty # squeeze bi so Gemm Add can be performed on axis=1 exaclty
var_bi0 = var_bi + '_0' # explicit variable var_bi0 = name + '_bi0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_bi], [var_bi],
[var_bi0], [var_bi0],
{'axes': [0]}, # slice on d {'axes': [0]}, # slice on d
name=(var_bi + '/index'), name=(name + '.bi/index'),
) )
prog.Op( prog.Op(
'', '',
...@@ -1605,44 +1630,44 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1605,44 +1630,44 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_mm, var_bi0], [var_mm, var_bi0],
[var_fc], [var_fc],
{'axis': 1}, # {'axis': 1}, #
name=(name + '/bias'), name=(name + '.i/bias'),
) )
if var_xh: if var_xh:
var_xh0 = var_xh + '_0' # explicit variable var_xh0 = name + '_xh0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_xh], [var_xh],
[var_xh0], [var_xh0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_xh + '/index'), name=(name + '.xh/index'),
) )
if var_xc: if var_xc:
var_xc0 = var_xc + '_0' # explicit variable var_xc0 = name + '_xc0' # explicit variable
prog.Op( prog.Op(
'', '',
'Squeeze', 'Squeeze',
[var_xc], [var_xc],
[var_xc0], [var_xc0],
{'axes': [1]}, # index on n {'axes': [1]}, # index on n
name=(var_xc + '/index'), name=(name + '.xc/index'),
) )
var_bhp = var_p var_bhp = var_p
if var_b: if var_b:
if var_p: if var_p:
var_bhp = var_bh + '_p' # explicit variable var_bhp = name + '_bhp' # explicit variable
prog.Op( prog.Op(
'', '',
'Concat', 'Concat',
[var_bh, var_p], [var_bh, var_p],
[var_bhp], [var_bhp],
{'axes': [1]}, # cat on x {'axis': [1]}, # cat on x
name=(name + '/concat'), name=(name + '/concat'),
) )
else: else:
var_bhp = var_bh var_bhp = var_bh
var_yh0 = var_yh + '_0' # explicit variable var_yh0 = name + '_yh0' # explicit variable
var_yc0 = var_yc + '_0' # explicit variable var_yc0 = name + '_yc0' # explicit variable
prog.Code('{}, {} = layers.{}({}, {}' prog.Code('{}, {} = layers.{}({}, {}'
', h_0={}' ', h_0={}'
', c_0={}' ', c_0={}'
...@@ -1690,14 +1715,23 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1690,14 +1715,23 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'candidate_activation': candidate_activation, 'candidate_activation': candidate_activation,
}, },
) )
# if var_yh: if var_y:
prog.Op(
'',
'Unsqueeze',
[var_yh0], #
[var_y], # var_y
{'axes': [1, 1]}, # extrude on dn
name=(name + '.y/reshape'),
)
if var_yh:
prog.Op( prog.Op(
'', '',
'Unsqueeze', 'Unsqueeze',
[var_yh0], [var_yh0],
[var_y], # var_yh [var_yh], # var_yh
{'axes': [1, 1]}, # extrude on dn {'axes': [1, 1]}, # extrude on dn
name=(var_y + '/reshape'), name=(name + '.yh/reshape'),
) )
if var_yc: if var_yc:
prog.Op( prog.Op(
...@@ -1706,26 +1740,24 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1706,26 +1740,24 @@ def LSTM(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
[var_yc0], [var_yc0],
[var_yc], [var_yc],
{'axes': [1, 1]}, # extrude on dn {'axes': [1, 1]}, # extrude on dn
name=(var_yc + '/reshape'), name=(name + '.yc/reshape'),
) )
def MaxPool(prog, inputs, outputs, attrs, value_infos, name='', *args, def MaxPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
**kwargs):
""" """
onnx::MaxPool-10: onnx::MaxPool-10:
""" """
return _pool(prog, 'max', inputs, outputs, attrs, value_infos, name=name) return _pool(prog, 'max', inputs, outputs, attrs, value_infos, name)
def MaxRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, def MaxRoiPool(prog, inputs, outputs, attrs, name, *args, **kwargs):
**kwargs):
""" """
onnx::MaxRoiPool-1: onnx::MaxRoiPool-1:
""" """
_roi_pool(prog, 'roi_pool', inputs, outputs, attrs, value_infos, name) _roi_pool(prog, 'roi_pool', inputs, outputs, attrs, name)
def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
...@@ -1736,6 +1768,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1736,6 +1768,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
# I/O # I/O
var_data, = inputs var_data, = inputs
var_output, = outputs var_output, = outputs
assert var_data and var_output
# interpretation # interpretation
pads = attrs['pads'] # required pads = attrs['pads'] # required
...@@ -1783,7 +1816,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1783,7 +1816,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(['X'], [var_data]), (['X', 'Paddings'], [var_data]), #
(['Out'], [var_output]), (['Out'], [var_output]),
od_attrs, od_attrs,
) )
...@@ -1792,7 +1825,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1792,7 +1825,7 @@ def Pad(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
def PRelu(prog, def PRelu(prog,
inputs, inputs,
outputs, outputs,
attrs, attrs_,
value_infos, value_infos,
name='', name='',
embed_params=False, embed_params=False,
...@@ -1805,12 +1838,13 @@ def PRelu(prog, ...@@ -1805,12 +1838,13 @@ def PRelu(prog,
# I/O # I/O
var_x, var_slope, = inputs var_x, var_slope, = inputs
var_y, = outputs var_y, = outputs
assert name and var_x and var_slope and var_y
# interpretation # interpretation
mode = 'channel' mode = 'channel'
slope_shape = _shape_or_none(value_infos, var_slope) slope_shape = _shape_or_none(value_infos, var_slope)
if slope_shape is not None: if slope_shape is not None:
if len(slope_shape) == 0: if not slope_shape:
mode = 'all' mode = 'all'
elif len(slope_shape) >= 2: elif len(slope_shape) >= 2:
if slope_shape[1] != _np.product( if slope_shape[1] != _np.product(
...@@ -1825,7 +1859,7 @@ def PRelu(prog, ...@@ -1825,7 +1859,7 @@ def PRelu(prog,
_logger.warning('broken Python code will be generated') _logger.warning('broken Python code will be generated')
embed_params &= embeddable embed_params &= embeddable
if embed_params: if embed_params:
assert name != '' assert name
embedded_slope = name + '.w_0' embedded_slope = name + '.w_0'
value_infos[var_slope]['embedded_as'].append(embedded_slope) value_infos[var_slope]['embedded_as'].append(embedded_slope)
var_slope = embedded_slope var_slope = embedded_slope
...@@ -1854,15 +1888,15 @@ def PRelu(prog, ...@@ -1854,15 +1888,15 @@ def PRelu(prog,
) )
def PsRoiPool(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): def PsRoiPool(prog, inputs, outputs, attrs, name, *args, **kwargs):
""" """
caffe2::PsRoiPool caffe2::PsRoiPool
""" """
_roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, value_infos, name) _roi_pool(prog, 'psroi_pool', inputs, outputs, attrs, name)
def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): def Reshape(prog, inputs, outputs, attrs_, value_infos, name, *args, **kwargs):
""" """
onnx::Reshape-5: onnx::Reshape-5:
""" """
...@@ -1870,10 +1904,12 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1870,10 +1904,12 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
# I/O # I/O
var_data, var_shape, = inputs var_data, var_shape, = inputs
var_reshaped, = outputs var_reshaped, = outputs
assert name and var_data and var_shape and var_reshaped
# interpretation # interpretation
shape = _const_weight_or_none(value_infos, var_shape) shape = _const_weight_or_none(value_infos, var_shape)
is_const_shape = shape and 'const_value' in value_infos[var_shape] is_const_shape = shape is not None and 'const_value' in value_infos[
var_shape]
if shape is None: if shape is None:
shape = _shape_or_none(value_infos, var_reshaped) shape = _shape_or_none(value_infos, var_reshaped)
...@@ -1898,8 +1934,9 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1898,8 +1934,9 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
name_attr = ', name={}'.format(repr(name)) name_attr = ', name={}'.format(repr(name))
# generation # generation
var_shape_int32 = var_shape + ('_int32' if shape_dtype != _np.int32 else '' var_shape_i32 = (
) # explicit variable name + '_shape_i32'
) if shape_dtype != _np.int32 else var_shape # explicit variable
prog.Code('# shape: {} = {} # const as literal'.format(var_shape, shape)) prog.Code('# shape: {} = {} # const as literal'.format(var_shape, shape))
if is_const_shape: if is_const_shape:
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
...@@ -1918,9 +1955,16 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1918,9 +1955,16 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
'', '',
'Cast', 'Cast',
[var_shape], [var_shape],
[var_shape_int32], [var_shape_i32],
{'to': _np.dtype('int32')}, # use np.dtype {'to': _np.dtype('int32')}, # use np.dtype
value_infos=value_infos, value_infos={
var_shape: {
'dtype': shape_dtype
},
var_shape_i32: {
'dtype': _np.dtype('int32')
},
},
name=(name + '/cast'), name=(name + '/cast'),
) )
prog.Code('{} = layers.{}({}' prog.Code('{} = layers.{}({}'
...@@ -1932,7 +1976,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1932,7 +1976,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
var_data, var_data,
# attrs # attrs
shape, shape,
var_shape_int32, var_shape_i32,
name_attr, name_attr,
)) ))
fluid_op = 'reshape2' fluid_op = 'reshape2'
...@@ -1941,7 +1985,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): ...@@ -1941,7 +1985,7 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
prog.VarDesc(var_xshape) prog.VarDesc(var_xshape)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(['X', 'Shape'], [var_data, var_shape_int32]), (['X', 'Shape', 'ShapeTensor'], [var_data, var_shape_i32]), #
(['Out', 'XShape'], [var_reshaped, var_xshape]), (['Out', 'XShape'], [var_reshaped, var_xshape]),
{'shape': shape}, {'shape': shape},
) )
...@@ -1955,44 +1999,57 @@ def Resize(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -1955,44 +1999,57 @@ def Resize(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
return _interpolate(prog, inputs, outputs, attrs, value_infos, name=name) return _interpolate(prog, inputs, outputs, attrs, value_infos, name=name)
def RoiAlign(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs): def RoiAlign(prog, inputs, outputs, attrs, name, *args, **kwargs):
""" """
caffe2::RoiAlign caffe2::RoiAlign
""" """
_roi_pool(prog, 'roi_align', inputs, outputs, attrs, value_infos, name) _roi_pool(prog, 'roi_align', inputs, outputs, attrs, name)
#def Shape( def Shape(prog, inputs, outputs, attrs_, name, **kwargs):
# prog, inputs, outputs, attrs, value_infos, """
# *args, **kwargs): onnx::Shape-1:
# """ """
# onnx::ConstantOfShape-1:
# """ # I/O
# var_data, = inputs
# # I/O var_shape, = outputs
# val_data, = inputs assert name and var_data and var_shape
# val_shape, = outputs
# var_data = _make_var_name(val_data) # interpretation
# var_shape = _make_var_name(val_shape) fluid_op = 'shape'
# var_shape_i64 = name + '_shape_i64'
# # interpretation
# fluid_op = 'shape' # generation
## value_infos[val_shape]['remove_batch'] = False prog.Code('{} = layers.{}({})'.format(
# var_shape_i64,
# # generation fluid_op,
# prog.Code('{} = layers.{}({})' var_data,
# .format(var_shape, # attrs
# fluid_op, ))
# var_data, prog.VarDesc(var_shape_i64)
# # attrs prog.OpDesc(
# )) fluid_op,
# prog.VarDesc(var_shape) # , _value_info_or_none(value_infos, val_shape)) (['Input'], [var_data]),
# prog.OpDesc(fluid_op, (['Out'], [var_shape_i64]),
# ([var_data], 'X'), )
# ([var_shape], 'Out'), prog.Op(
# dict(), '',
# ) 'Cast',
[var_shape_i64],
[var_shape],
{'to': _np.dtype('int32')}, # use np.dtype
value_infos={
var_shape: {
'dtype': _np.dtype('int32')
},
var_shape_i64: {
'dtype': _np.dtype('int64')
},
},
name=(name + '/cast'),
)
def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
...@@ -2003,6 +2060,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs): ...@@ -2003,6 +2060,7 @@ def Slice(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
# I/O # I/O
var_data, = inputs var_data, = inputs
var_output, = outputs var_output, = outputs
assert var_data and var_output
# interpretation # interpretation
fluid_op = 'slice' fluid_op = 'slice'
...@@ -2059,6 +2117,7 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -2059,6 +2117,7 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
# I/O # I/O
var_input, = inputs var_input, = inputs
assert var_input
# interpretation # interpretation
fluid_op = 'split' fluid_op = 'split'
...@@ -2093,13 +2152,14 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -2093,13 +2152,14 @@ def Split(prog, inputs, outputs, attrs, *args, name='', **kwargs):
) )
def Sum(prog, inputs, outputs, *args, **kwargs): def Sum(prog, inputs, outputs, attrs_, *args, **kwargs):
""" """
onnx::Sum-8: onnx::Sum-8:
""" """
# I/O # I/O
var_sum, = outputs var_sum, = outputs
assert var_sum
# interpretation # interpretation
fluid_op = 'sums' fluid_op = 'sums'
...@@ -2121,7 +2181,7 @@ def Sum(prog, inputs, outputs, *args, **kwargs): ...@@ -2121,7 +2181,7 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
) )
def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): def Tile(prog, inputs, outputs, attrs_, value_infos, name='', *args, **kwargs):
""" """
onnx::Tile-1: onnx::Tile-1:
""" """
...@@ -2129,10 +2189,11 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -2129,10 +2189,11 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
# I/O # I/O
var_input, var_repeats, = inputs var_input, var_repeats, = inputs
var_output, = outputs var_output, = outputs
assert var_input and var_repeats and var_output
# interpretation # interpretation
repeats = _const_weight_or_none(value_infos, var_repeats) repeats = _const_weight_or_none(value_infos, var_repeats)
assert repeats is not None, 'only const repeats supported' assert repeats is not None, 'only const repeats supported' # if contain_tensor(expand_times)
fluid_op = 'expand' fluid_op = 'expand'
name_attr = ', name={}'.format(repr(name)) if name else '' name_attr = ', name={}'.format(repr(name)) if name else ''
...@@ -2152,13 +2213,13 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs): ...@@ -2152,13 +2213,13 @@ def Tile(prog, inputs, outputs, attrs, value_infos, name='', *args, **kwargs):
prog.VarDesc(var_output) prog.VarDesc(var_output)
prog.OpDesc( prog.OpDesc(
fluid_op, fluid_op,
(['X'], [var_input]), (['X', 'expand_times_tensor'], [var_input]), # TODO
(['Out'], [var_output]), (['Out'], [var_output]),
{'expand_times': repeats}, {'expand_times': repeats},
) )
def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs): def Transpose(prog, inputs, outputs, attrs, name, *args, **kwargs):
""" """
onnx::Transpose-1: onnx::Transpose-1:
""" """
...@@ -2166,6 +2227,7 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs): ...@@ -2166,6 +2227,7 @@ def Transpose(prog, inputs, outputs, attrs, *args, name='', **kwargs):
# I/O # I/O
var_data, = inputs var_data, = inputs
var_transposed, = outputs var_transposed, = outputs
assert name and var_data and var_transposed
# interpretation # interpretation
fluid_op = 'transpose' fluid_op = 'transpose'
......
...@@ -119,7 +119,7 @@ def export_onnx_with_validation( ...@@ -119,7 +119,7 @@ def export_onnx_with_validation(
return list(map(tensors_to_arrays, tensors)) return list(map(tensors_to_arrays, tensors))
def zip_dict( def zip_dict(
keys: Union[Iterable[Any], None], keys: Optional[Iterable[Any]],
values: Sequence[Union[Any, Sequence[Any]]], values: Sequence[Union[Any, Sequence[Any]]],
) -> MyDict[Text, Union[object, MyDict[Text, object]]]: ) -> MyDict[Text, Union[object, MyDict[Text, object]]]:
keys = keys or range(len(values)) keys = keys or range(len(values))
......
...@@ -160,7 +160,8 @@ def validate(fluid_model_filename, ...@@ -160,7 +160,8 @@ def validate(fluid_model_filename,
logger.info('with %d inputs and %d outputs', len(input_data), logger.info('with %d inputs and %d outputs', len(input_data),
len(output_data)) len(output_data))
elif save_inference_model: elif save_inference_model:
assert inference_input_names, 'input names required for type-shape inference' assert inference_input_names is not None, (
'input names required for type-shape inference')
input_names = inference_input_names input_names = inference_input_names
logger.info('using input names: %s', ', '.join(input_names)) logger.info('using input names: %s', ', '.join(input_names))
...@@ -178,7 +179,7 @@ def validate(fluid_model_filename, ...@@ -178,7 +179,7 @@ def validate(fluid_model_filename,
fluid.io.load_inference_model(fluid_model_dir, exe) fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed') logger.info('model re-load passed')
if not golden_data_filename: if golden_data_filename == '':
return True return True
# execute # execute
......
...@@ -133,7 +133,7 @@ class Program(object): ...@@ -133,7 +133,7 @@ class Program(object):
od_attr.type = framework_pb2.STRING od_attr.type = framework_pb2.STRING
od_attr.s = value od_attr.s = value
elif isinstance(value, list): elif isinstance(value, list):
if len(value) > 0: # TODO: test all items if value: # TODO: test all items
if isinstance(value[0], if isinstance(value[0],
bool): # bool.mro() = [bool, int, object] bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEANS od_attr.type = framework_pb2.BOOLEANS
...@@ -183,22 +183,15 @@ class Program(object): ...@@ -183,22 +183,15 @@ class Program(object):
if self.code_mutable: if self.code_mutable:
self.codes.append(code) self.codes.append(code)
def OpDesc(self, def OpDesc(self, op_type, input_key_vals, output_key_vals, attrs):
op_type,
input_key_vals=None,
output_key_vals=None,
attrs=None):
""" """
add OpDesc add OpDesc
""" """
desc = framework_pb2.OpDesc() desc = framework_pb2.OpDesc()
desc.type = op_type desc.type = op_type
if input_key_vals:
desc.inputs.extend(self.OpDescVars(*input_key_vals)) desc.inputs.extend(self.OpDescVars(*input_key_vals))
if output_key_vals:
desc.outputs.extend(self.OpDescVars(*output_key_vals)) desc.outputs.extend(self.OpDescVars(*output_key_vals))
if attrs:
desc.attrs.extend(self.OpDescAttrs(attrs)) desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc) self.op_descs.append(desc)
return desc return desc
...@@ -212,7 +205,7 @@ class Program(object): ...@@ -212,7 +205,7 @@ class Program(object):
add VarDesc, add VarDesc,
""" """
assert name not in self.var_descs, 'var naming conflicted' assert name not in self.var_descs, 'var name {} conflicts'.format(name)
var_desc = framework_pb2.VarDesc() var_desc = framework_pb2.VarDesc()
var_desc.name = name var_desc.name = name
...@@ -220,10 +213,10 @@ class Program(object): ...@@ -220,10 +213,10 @@ class Program(object):
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[name] = var_desc self.var_descs[name] = var_desc
if value_info: if value_info is not None:
self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch) self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch)
def Op(self, domain, op_type, *args, **kwargs): def Op(self, domain, op_type, inputs, outputs, attrs, *args, **kwargs):
""" """
convert an ONNX op and add it to program convert an ONNX op and add it to program
""" """
...@@ -232,15 +225,17 @@ class Program(object): ...@@ -232,15 +225,17 @@ class Program(object):
raise ValueError('only default domain supported') raise ValueError('only default domain supported')
if op_type in symbolic.DEFAULT_OP_MAPPING: if op_type in symbolic.DEFAULT_OP_MAPPING:
symbolic._default(self, op_type, *args, **kwargs) symbolic._default(self, op_type, inputs, outputs, attrs, *args,
**kwargs)
elif hasattr(symbolic, op_type): elif hasattr(symbolic, op_type):
fn = getattr(symbolic, op_type) fn = getattr(symbolic, op_type)
fn(self, *args, **kwargs) fn(self, inputs, outputs, attrs, *args, **kwargs)
else: else:
raise ValueError('conversion for {}::{} not supported'.format( raise ValueError('conversion for {}::{} not supported'.format(
domain, op_type)) domain, op_type))
def IntermediateOp(self, domain, op_type, *args, **kwargs): def IntermediateOp(self, domain, op_type, inputs, outputs, attrs, *args,
**kwargs):
""" """
convert an intermediate ONNX op declaring in desc program only convert an intermediate ONNX op declaring in desc program only
""" """
...@@ -248,7 +243,7 @@ class Program(object): ...@@ -248,7 +243,7 @@ class Program(object):
code_mutable = self.code_mutable code_mutable = self.code_mutable
self.code_mutable = False self.code_mutable = False
try: try:
self.Op(domain, op_type, *args, **kwargs) self.Op(domain, op_type, inputs, outputs, attrs, *args, **kwargs)
except BaseException as e: except BaseException as e:
self.code_mutable = code_mutable self.code_mutable = code_mutable
raise e raise e
...@@ -272,9 +267,10 @@ class Program(object): ...@@ -272,9 +267,10 @@ class Program(object):
tensor_desc.data_type = self.Dtype(dtype) # required tensor_desc.data_type = self.Dtype(dtype) # required
shape = value_info.get('shape', None) shape = value_info.get('shape', None)
if shape is not None: if not shape: # None or scalars
return
tensor_desc.dims.extend(shape) tensor_desc.dims.extend(shape)
if len(shape) > 0: # skip scalars
if remove_batch is None: if remove_batch is None:
remove_batch = value_info.get('remove_batch', remove_batch = value_info.get('remove_batch',
False) #not persistable) False) #not persistable)
...@@ -337,8 +333,8 @@ class Writer(object): ...@@ -337,8 +333,8 @@ class Writer(object):
emit an ONNX weight into program emit an ONNX weight into program
""" """
if value_info.get('embedded_as', []): embedded_names = value_info.get('embedded_as', [])
embedded_names = value_info['embedded_as'] if embedded_names:
prog.Code('# parameter {} embedded as {}'.format( prog.Code('# parameter {} embedded as {}'.format(
name, embedded_names)) name, embedded_names))
for embedded_name in embedded_names: for embedded_name in embedded_names:
...@@ -431,7 +427,8 @@ class Writer(object): ...@@ -431,7 +427,8 @@ class Writer(object):
assert lod is None or isinstance(lod, assert lod is None or isinstance(lod,
list), 'lod should be None or list' list), 'lod should be None or list'
lod = lod or [0] if lod is None:
lod = [0]
tensor_desc = framework_pb2.VarType.TensorDesc() tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype) tensor_desc.data_type = Program.Dtype(weight.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册