提交 9d147284 编写于 作者: M Macrobull

better type-shape inference support

上级 74feae9b
......@@ -17,8 +17,8 @@ def make_var_name(name):
make a valid variable name in Python code
"""
if name == '':
return '_'
assert name
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:-': #
......
......@@ -20,8 +20,8 @@ def make_var_name(name):
make a valid variable name in Python code
"""
if name == '':
return '_'
assert name
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:-': #
......
......@@ -2,14 +2,17 @@
# setopt SH_WORD_SPLIT # if zsh
# alias python="python3" # if ...
# alias http_get="wget -c" # if no aria2
alias http_get="aria2c -c -s8 -x8"
base_url="https://s3.amazonaws.com/download.onnx/models/opset_9/"
convert_cmd="python -m onnx2fluid"
validate_cmd="$convert_cmd.validation"
convert_flags="-e -o /tmp/export/"
validate_flags1="/tmp/export/model.py"
validate_flags2="/tmp/export/__model__"
# alias http_get="wget -c" # if no aria2
alias http_get="aria2c -c -s8 -x8"
# alias python="python3" # if ...
validate_flags3="/tmp/export/__model__ -i"
bvlc_alexnet()
......@@ -23,21 +26,23 @@ bvlc_alexnet()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz
do
echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
$validate_cmd $validate_flags1 -t "$npz"
$validate_cmd $validate_flags2 -t "$npz"
done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -53,14 +58,15 @@ bvlc_googlenet()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -76,14 +82,15 @@ bvlc_reference_caffenet()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -99,14 +106,15 @@ bvlc_reference_rcnn_ilsvrc13()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 fc-rcnn_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -122,21 +130,23 @@ densenet121()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz
do
echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 fc6_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
$validate_cmd $validate_flags1 -t "$npz"
$validate_cmd $validate_flags2 -t "$npz"
done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 fc6_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -152,14 +162,15 @@ emotion_ferplus()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" Input3 Plus692_Output_0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -175,21 +186,23 @@ inception_v1()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz
do
echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
$validate_cmd $validate_flags1 -t "$npz"
$validate_cmd $validate_flags2 -t "$npz"
done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -205,21 +218,23 @@ inception_v2()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz
do
echo "converting $npz ..."
python convert_data_npz.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
$validate_cmd $validate_flags1 -t "$npz"
$validate_cmd $validate_flags2 -t "$npz"
done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -235,14 +250,15 @@ mobilenet()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data mobilenetv20_output_flatten0_reshape0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -258,14 +274,15 @@ resnet18()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data resnetv15_dense0_fwd
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -281,21 +298,23 @@ resnet50()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for npz in "$bn_tar/"*.npz
do
echo "converting $npz ..."
python convert_data_npz.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
$validate_cmd $validate_flags1 -t "$npz"
$validate_cmd $validate_flags2 -t "$npz"
done
$validate_cmd $validate_flags3 -t "$npz"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -311,14 +330,15 @@ resnet100_arcface()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data fc1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -334,14 +354,15 @@ resnet101_duc()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data seg_loss
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -357,14 +378,15 @@ resnet152()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data resnetv27_dense0_fwd
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -380,14 +402,15 @@ shufflenet()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -403,14 +426,15 @@ squeezenet()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 softmaxout_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -426,14 +450,15 @@ squeezenet1v1()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data squeezenet0_flatten0_reshape0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -450,14 +475,15 @@ ssd()
mkdir "$bn_tar"
tar xf "$fn_tar" -C "$bn_tar/"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" image bboxes,labels,scores
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -473,14 +499,15 @@ tiny_yolov2()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" image grid
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -496,14 +523,15 @@ vgg16bn()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -y
$convert_cmd $convert_flags "$fn_model" -y
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" data vgg0_dense2_fwd
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -519,14 +547,15 @@ vgg19()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" data_0 prob_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -535,21 +564,22 @@ yolov3()
{
bn_tar="yolov3"
fn_tar="$bn_tar.tar.gz"
fn_model="$bn_tar/model.onnx"
fn_model="$bn_tar/yolov3.onnx"
http_get "https://onnxzoo.blob.core.windows.net/models/opset_10/yolov3/$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model" -x #
$convert_cmd $convert_flags "$fn_model" -x #
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir ..."
python convert_data_pb.py "$pb_dir" input_1:01,image_shape:01 yolonms_layer_1/ExpandDims_1:0,yolonms_layer_1/ExpandDims_3:0,yolonms_layer_1/concat_2:0
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......@@ -565,14 +595,15 @@ zfnet512()
echo "extracting ..."
tar xf "$fn_tar"
python -m onnx2fluid $convert_flags "$fn_model"
$convert_cmd $convert_flags "$fn_model"
for pb_dir in "$bn_tar/"*/
do
echo "converting $pb_dir"
python convert_data_pb.py "$pb_dir" gpu_0/data_0 gpu_0/softmax_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags1 -t $(dirname "$pb_dir/x").npz
$validate_cmd $validate_flags2 -t $(dirname "$pb_dir/x").npz
done
$validate_cmd $validate_flags3 -t $(dirname "$pb_dir/x").npz
rm -rf "$bn_tar/"
}
......
......@@ -61,11 +61,14 @@ def main(**kwargs):
passed = True
golden_data_filename = kwargs.pop('test_data', '')
infer_inputs = kwargs.pop('infer_inputs', None)
if golden_data_filename or infer_inputs is not None:
save_inference_model = infer_inputs is not None
if golden_data_filename or save_inference_model:
from .validation import validate
save_inference_model = infer_inputs is not None
inference_input_names = infer_inputs and infer_inputs.split(',')
if save_inference_model:
inference_input_names = infer_inputs.split(',')
else:
inference_input_names = None
logger.info('starting validation on desc ...')
passed &= validate(shutil.os.path.join(save_dir, '__model__'),
......
......@@ -23,7 +23,7 @@ def make_var_name(name):
"""
if name == '':
return '_'
return ''
if name[0].isdigit():
return 'var_' + name
for s in ' \\|/:.-':
......@@ -170,7 +170,7 @@ def convert(onnx_model_filename,
for var_name, var_desc in fluid_program.var_descs.items():
if not var_desc.type.lod_tensor.HasField('tensor'):
bad_vars.append(var_name)
if len(bad_vars) > 0:
if bad_vars:
logger.warning('type-shape not infered for var %s ...',
', '.join(bad_vars[:5]))
logger.warning('this causes little problem for PaddlePaddle, '
......
......@@ -99,6 +99,9 @@ def get_attribute_value2(attr):
elif attr.type == onnx.AttributeProto.STRING:
value = attr.s
value = value.decode() if isinstance(value, bytes) else value
elif attr.type == onnx.AttributeProto.STRINGS:
value = attr.strings
value = [s.decode() if isinstance(s, bytes) else s for s in value]
else:
value = get_attribute_value(attr)
return value
......@@ -161,12 +164,12 @@ def node_topo(nodes, topo='default'):
for node_idx, degree in enumerate(node_in_degrees):
if degree == 0:
queue.append(node_idx)
while len(queue) > 0:
while queue:
node_idx = queue.pop(0)
node_topo.append(node_idx)
for val_name in nodes[node_idx].output:
output_refs[val_name].remove(node_idx)
if len(output_refs[val_name]) > 0:
if output_refs[val_name]:
continue
output_refs.pop(val_name)
if val_name not in input_refs:
......@@ -186,12 +189,12 @@ def node_topo(nodes, topo='default'):
for node_idx, degree in enumerate(node_out_degrees):
if degree == 0:
queue.append(node_idx)
while len(queue) > 0:
while queue:
node_idx = queue.pop(0)
node_topo.append(node_idx)
for val_name in nodes[node_idx].input:
input_refs[val_name].remove(node_idx)
if len(input_refs[val_name]) > 0:
if input_refs[val_name]:
continue
input_refs.pop(val_name)
if val_name not in output_refs:
......@@ -210,7 +213,10 @@ def node_iter(nodes, indices=None):
generator for ONNX node graph with given indices
"""
for index in indices or range(len(nodes)):
if indices is None:
indices = range(len(nodes))
for index in indices:
node = nodes[index]
name = node.name
domain = node.domain
......@@ -221,9 +227,11 @@ def node_iter(nodes, indices=None):
if name == '':
name = 'op_' + str(index)
else: # make_op_name
for s in ' \\|/:-': #
name = name.replace(s, '_')
# else: # make_op_name
# for s in ' \\|/:-': #
# name = name.replace(s, '_')
if domain == '':
domain = DEFAULT_OP_DOMAIN
......@@ -356,10 +364,11 @@ def polish_and_save(model_filename,
run polish_model and save
"""
if save_filename is None:
save_filename = model_filename.replace('.onnx', suffix + '.onnx')
model = onnx.load(model_filename)
model = polish_model(model, *args, **kwargs)
save_filename = save_filename or model_filename.replace(
'.onnx', suffix + '.onnx')
onnx.save(model, save_filename)
logger.info('polished model saved to: %s', save_filename)
return save_filename
......@@ -495,7 +504,7 @@ def optimize_model_cast(model):
for node_idx, node in enumerate(nodes):
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
continue
if not (node.op_type == 'Cast'):
if node.op_type != 'Cast':
continue
attrs = node_attrs(node)
output_dtype = TENSOR_TYPE_TO_NP_TYPE[attrs['to']]
......@@ -551,7 +560,7 @@ def optimize_model_slice(model):
node = nodes[node_idx]
if not (node.domain == DEFAULT_OP_DOMAIN or node.domain == ''):
return chain
if not node.op_type == 'Slice':
if node.op_type != 'Slice':
return chain
chain.append(node_idx)
output_name = node.output[0]
......@@ -585,10 +594,10 @@ def optimize_model_slice(model):
nodes_to_remove = []
for node_idx in range(len(nodes)):
slice_chain = build_slice_node_chain(node_idx)
if len(slice_chain) == 0:
if not slice_chain:
continue
merged_slice = merge_slice(slice_chain)
if len(merged_slice) > 0 and len(slice_chain) == 1: # no need to merge
if merged_slice and len(slice_chain) == 1: # no need to merge
continue
attrs = {'axes': [], 'starts': [], 'ends': []}
......@@ -602,12 +611,11 @@ def optimize_model_slice(model):
output_name = last_node.output[0]
processed = -1
if output_name in input_refs: # 0, [1...]
new_input_name = first_node.output[0] if len(
merged_slice) > 0 else input_name
new_input_name = first_node.output[0] if merged_slice else input_name
processed = skip_node_forward(ret_nodes, output_name,
new_input_name, input_refs)
if processed > 0:
if len(merged_slice) > 0:
if merged_slice:
remain_idx = slice_chain[0]
remove_chain = slice_chain[1:]
slice_node = ret_nodes[remain_idx]
......@@ -621,12 +629,11 @@ def optimize_model_slice(model):
remove_chain = slice_chain
if processed < 0 and input_name in output_refs:
new_output_name = last_node.input[0] if len(
merged_slice) > 0 else output_name
new_output_name = last_node.input[0] if merged_slice else output_name
processed = skip_node_backward(ret_nodes, input_name,
new_output_name, output_refs)
if processed > 0:
if len(merged_slice) > 0:
if merged_slice:
remain_idx = slice_chain[-1]
remove_chain = slice_chain[:-1]
slice_node = ret_nodes[remain_idx]
......@@ -641,7 +648,7 @@ def optimize_model_slice(model):
if processed > 0:
nodes_to_remove.extend(remove_chain)
if len(merged_slice) == 0:
if not merged_slice:
logger.debug('skip slice chain %s -> %s -> %s', input_name,
slice_chain, output_name)
elif processed < 0: # NEVERFIX: not merge standalone slice chain
......
此差异已折叠。
......@@ -119,7 +119,7 @@ def export_onnx_with_validation(
return list(map(tensors_to_arrays, tensors))
def zip_dict(
keys: Union[Iterable[Any], None],
keys: Optional[Iterable[Any]],
values: Sequence[Union[Any, Sequence[Any]]],
) -> MyDict[Text, Union[object, MyDict[Text, object]]]:
keys = keys or range(len(values))
......
......@@ -160,7 +160,8 @@ def validate(fluid_model_filename,
logger.info('with %d inputs and %d outputs', len(input_data),
len(output_data))
elif save_inference_model:
assert inference_input_names, 'input names required for type-shape inference'
assert inference_input_names is not None, (
'input names required for type-shape inference')
input_names = inference_input_names
logger.info('using input names: %s', ', '.join(input_names))
......@@ -178,7 +179,7 @@ def validate(fluid_model_filename,
fluid.io.load_inference_model(fluid_model_dir, exe)
logger.info('model re-load passed')
if not golden_data_filename:
if golden_data_filename == '':
return True
# execute
......
......@@ -133,7 +133,7 @@ class Program(object):
od_attr.type = framework_pb2.STRING
od_attr.s = value
elif isinstance(value, list):
if len(value) > 0: # TODO: test all items
if value: # TODO: test all items
if isinstance(value[0],
bool): # bool.mro() = [bool, int, object]
od_attr.type = framework_pb2.BOOLEANS
......@@ -183,23 +183,16 @@ class Program(object):
if self.code_mutable:
self.codes.append(code)
def OpDesc(self,
op_type,
input_key_vals=None,
output_key_vals=None,
attrs=None):
def OpDesc(self, op_type, input_key_vals, output_key_vals, attrs):
"""
add OpDesc
"""
desc = framework_pb2.OpDesc()
desc.type = op_type
if input_key_vals:
desc.inputs.extend(self.OpDescVars(*input_key_vals))
if output_key_vals:
desc.outputs.extend(self.OpDescVars(*output_key_vals))
if attrs:
desc.attrs.extend(self.OpDescAttrs(attrs))
desc.inputs.extend(self.OpDescVars(*input_key_vals))
desc.outputs.extend(self.OpDescVars(*output_key_vals))
desc.attrs.extend(self.OpDescAttrs(attrs))
self.op_descs.append(desc)
return desc
......@@ -212,7 +205,7 @@ class Program(object):
add VarDesc,
"""
assert name not in self.var_descs, 'var naming conflicted'
assert name not in self.var_descs, 'var name {} conflicts'.format(name)
var_desc = framework_pb2.VarDesc()
var_desc.name = name
......@@ -220,10 +213,10 @@ class Program(object):
var_desc.type.type = framework_pb2.VarType.LOD_TENSOR
self.var_descs[name] = var_desc
if value_info:
if value_info is not None:
self.VarTypeShapeInfo(name, value_info, remove_batch=remove_batch)
def Op(self, domain, op_type, *args, **kwargs):
def Op(self, domain, op_type, inputs, outputs, attrs, *args, **kwargs):
"""
convert an ONNX op and add it to program
"""
......@@ -232,15 +225,17 @@ class Program(object):
raise ValueError('only default domain supported')
if op_type in symbolic.DEFAULT_OP_MAPPING:
symbolic._default(self, op_type, *args, **kwargs)
symbolic._default(self, op_type, inputs, outputs, attrs, *args,
**kwargs)
elif hasattr(symbolic, op_type):
fn = getattr(symbolic, op_type)
fn(self, *args, **kwargs)
fn(self, inputs, outputs, attrs, *args, **kwargs)
else:
raise ValueError('conversion for {}::{} not supported'.format(
domain, op_type))
def IntermediateOp(self, domain, op_type, *args, **kwargs):
def IntermediateOp(self, domain, op_type, inputs, outputs, attrs, *args,
**kwargs):
"""
convert an intermediate ONNX op declaring in desc program only
"""
......@@ -248,7 +243,7 @@ class Program(object):
code_mutable = self.code_mutable
self.code_mutable = False
try:
self.Op(domain, op_type, *args, **kwargs)
self.Op(domain, op_type, inputs, outputs, attrs, *args, **kwargs)
except BaseException as e:
self.code_mutable = code_mutable
raise e
......@@ -272,14 +267,15 @@ class Program(object):
tensor_desc.data_type = self.Dtype(dtype) # required
shape = value_info.get('shape', None)
if shape is not None:
tensor_desc.dims.extend(shape)
if len(shape) > 0: # skip scalars
if remove_batch is None:
remove_batch = value_info.get('remove_batch',
False) #not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
if not shape: # None or scalars
return
tensor_desc.dims.extend(shape)
if remove_batch is None:
remove_batch = value_info.get('remove_batch',
False) #not persistable)
if remove_batch:
tensor_desc.dims[0] = -1
class Writer(object):
......@@ -337,8 +333,8 @@ class Writer(object):
emit an ONNX weight into program
"""
if value_info.get('embedded_as', []):
embedded_names = value_info['embedded_as']
embedded_names = value_info.get('embedded_as', [])
if embedded_names:
prog.Code('# parameter {} embedded as {}'.format(
name, embedded_names))
for embedded_name in embedded_names:
......@@ -431,7 +427,8 @@ class Writer(object):
assert lod is None or isinstance(lod,
list), 'lod should be None or list'
lod = lod or [0]
if lod is None:
lod = [0]
tensor_desc = framework_pb2.VarType.TensorDesc()
tensor_desc.data_type = Program.Dtype(weight.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册