Commit 66f55b89 authored by M Macrobull

try shape = [1, -1] in Reshape

Parent a2189f02
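The gist of the change, as a hedged standalone sketch (the helper name fallback_reshape_shape and the example tensor are illustrative, not part of the onnx2fluid API): when the target shape of an ONNX Reshape is neither a constant nor deducible from the output's value_info, the converter no longer asserts but falls back to [1, -1], i.e. it flattens the data into a single row and logs a warning. This mirrors the symbolic.py hunk further down.

import logging

import numpy as np

_logger = logging.getLogger('onnx2fluid')


def fallback_reshape_shape(shape, name, inputs, outputs):
    # Illustrative only: mirrors the fallback added to symbolic.Reshape below.
    if shape is None:
        shape = [1, -1]  # who knows
        _logger.warning(
            'in %s(%s -> Reshape -> %s): '
            'input "shape" not inferred, use [1, -1] as dummy value, '
            'the behavior of Paddle fluid maybe undefined',
            name, inputs, outputs)
    return shape


# A made-up (2, 3, 4) tensor reshaped with the dummy shape becomes (1, 24).
data = np.arange(24).reshape(2, 3, 4)
print(data.reshape([1, -1]).shape)  # (1, 24)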
@@ -31,11 +31,17 @@ def _make_var_name(name):
 fn = sys.argv[1]
 input_names = sys.argv[2].split(':')
 output_name = sys.argv[3].split(':')
+squeeze_data = len(sys.argv) > 4
 data = np.load(fn, encoding='bytes')
 input_data = data['inputs']
 output_data = data['outputs']
+while squeeze_data and input_data.ndim > 4 and input_data.shape[0] == 1:
+    input_data = input_data.squeeze(0)
+while squeeze_data and output_data.ndim > 2 and output_data.shape[0] == 1:
+    output_data = output_data.squeeze(0)
 inputs = Dict(zip(map(_make_var_name, input_names), [input_data]))
 outputs = Dict(zip(map(_make_var_name, output_name), [output_data]))
......
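The new optional fourth CLI argument (the -s flag passed by the validation script below) is what turns squeeze_data on. A minimal sketch of its intended effect on npz test data, with made-up shapes:

import numpy as np

# Hypothetical blobs as stored in a model zoo .npz test case:
input_data = np.zeros((1, 1, 3, 224, 224))  # 5-D input with a redundant leading batch dim
output_data = np.zeros((1, 1, 1000))        # 3-D output with a redundant leading batch dim

squeeze_data = True  # i.e. a fourth argument such as -s was passed
while squeeze_data and input_data.ndim > 4 and input_data.shape[0] == 1:
    input_data = input_data.squeeze(0)
while squeeze_data and output_data.ndim > 2 and output_data.shape[0] == 1:
    output_data = output_data.squeeze(0)

print(input_data.shape)   # (1, 3, 224, 224) -- inputs squeezed down to at most 4-D
print(output_data.shape)  # (1, 1000)        -- outputs squeezed down to at most 2-D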
@@ -36,6 +36,7 @@ def _make_var_name(name):
 data_dir = os.path.dirname(sys.argv[1])
 input_names = sys.argv[2].split(':')
 output_name = sys.argv[3].split(':')
+squeeze_data = len(sys.argv) > 4
 # Load inputs
 inputs = []
@@ -43,7 +44,10 @@ for fn in glob(os.path.join(data_dir, 'input_*.pb')):
     tensor = onnx.TensorProto()
     with open(fn, 'rb') as f:
         tensor.ParseFromString(f.read())
-    inputs.append(numpy_helper.to_array(tensor))
+    tensor = numpy_helper.to_array(tensor)
+    while squeeze_data and tensor.ndim > 4 and tensor.shape[0] == 1:
+        tensor = tensor.squeeze(0)
+    inputs.append(tensor)
 # Load outputs
 outputs = []
@@ -51,7 +55,10 @@ for fn in glob(os.path.join(data_dir, 'output_*.pb')):
     tensor = onnx.TensorProto()
     with open(fn, 'rb') as f:
         tensor.ParseFromString(f.read())
-    outputs.append(numpy_helper.to_array(tensor))
+    tensor = numpy_helper.to_array(tensor)
+    while squeeze_data and tensor.ndim > 2 and tensor.shape[0] == 1:
+        tensor = tensor.squeeze(0)
+    outputs.append(tensor)
 inputs = Dict(zip(map(_make_var_name, input_names), inputs))
 outputs = Dict(zip(map(_make_var_name, output_name), outputs))
......
@@ -18,6 +18,7 @@ bvlc_alexnet()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -25,7 +26,7 @@ bvlc_alexnet()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" data_0 prob_1
+        python convert_data_npz_0.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
@@ -45,6 +46,7 @@ bvlc_googlenet()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -65,6 +67,7 @@ bvlc_reference_caffenet()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -85,6 +88,7 @@ bvlc_reference_rcnn_ilsvrc13()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -93,8 +97,8 @@ bvlc_reference_rcnn_ilsvrc13()
     do
         echo "converting $pb_dir"
         python convert_data_pb_0.py "$pb_dir" data_0 fc-rcnn_1
-        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
-        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
+        python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz -p 0
+        python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz -p 0
     done
 }
@@ -105,6 +109,7 @@ inception_v1()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -112,7 +117,7 @@ inception_v1()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" data_0 prob_1
+        python convert_data_npz_0.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
@@ -132,6 +137,7 @@ inception_v2()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -139,7 +145,7 @@ inception_v2()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" data_0 prob_1
+        python convert_data_npz_0.py "$npz" data_0 prob_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
@@ -159,6 +165,7 @@ resnet50()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -166,7 +173,7 @@ resnet50()
     for npz in "$bn_tar"/*.npz
     do
         echo "converting $npz ..."
-        python convert_data_npz_0.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1
+        python convert_data_npz_0.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
         python -m onnx2fluid.validation $validate_flags1 -t "$npz"
         python -m onnx2fluid.validation $validate_flags2 -t "$npz"
     done
@@ -186,6 +193,7 @@ shufflenet()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -206,6 +214,7 @@ squeezenet()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -226,6 +235,7 @@ tiny_yolov2()
     fn_model="$bn_tar/model.onnx"
     http_get "https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -246,6 +256,7 @@ vgg19()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -266,6 +277,7 @@ zfnet512()
     fn_model="$bn_tar/model.onnx"
     http_get "$base_url$fn_tar"
+    rm -rf "$bn_tar/"
     echo "extracting ..."
     tar xf "$fn_tar"
@@ -288,7 +300,7 @@ inception_v1
 inception_v2
 resnet50
 shufflenet
-squeezenet
-tiny_yolov2 # not supported
+squeezenet # softmax bug
+# tiny_yolov2 # not supported
 vgg19
 zfnet512
@@ -49,7 +49,8 @@ def main(**kwargs):
     basepath, _ = shutil.os.path.splitext(filename)
     save_dir = kwargs.get('output_dir', '')
     # model.onnx -> model/
-    save_dir = (save_dir.rstrip('/') if save_dir else basepath) + '/'
+    save_dir = (save_dir.rstrip(shutil.os.sep)
+                if save_dir else basepath) + shutil.os.sep
     model_basename = DEFAULT_MODEL_MODULE + '.py'
     model_func_name = DEFAULT_MODEL_FUNC
     embed_params = kwargs.get('embed_params', False)
@@ -109,7 +110,7 @@ def main(**kwargs):
     # create zip file
     if archive is not None:
         if archive == '':
-            archive = save_dir.rstrip('/') + '.zip'
+            archive = save_dir.rstrip(shutil.os.sep) + '.zip'
         logger.info('compressing file to %s ...', archive)
         shutil.sys.stderr.write('\n')
         shutil.sys.stderr.flush()
......
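The conversion.py hunks above only swap the hard-coded '/' for the platform path separator (shutil imports os, so shutil.os.sep is the same object as os.sep). A small sketch of the resulting naming logic, with made-up values:

import os

filename = 'model.onnx'  # hypothetical input model
output_dir = ''          # empty: derive the output directory from the model name

basepath, _ = os.path.splitext(filename)
# model.onnx -> model/ (model\ on Windows, now that os.sep is used)
save_dir = (output_dir.rstrip(os.sep) if output_dir else basepath) + os.sep
archive = save_dir.rstrip(os.sep) + '.zip'

print(save_dir)  # model/ on POSIX
print(archive)   # model.zip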
@@ -69,6 +69,7 @@ DEFAULT_OP_MAPPING = {
     'Sin': ['sin', ['X'], ['Out']],
     'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2
     'Softplus': ['softplus', ['X'], ['Out']],
+    # FIXME: default axis = -1, reshape required before and after
     'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')],
     'Softsign': ['softsign', ['X'], ['Out']],
     'Sqrt': ['sqrt', ['X'], ['Out']],
@@ -799,7 +800,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
         shape = list(value.shape)
         _logger.warning(
             'in (Constant -> %s): '
-            'shape of %s not inferred, '
+            'attribute "shape" of %s not inferred, '
             'using value as 1-D tensor may lead to fails', outputs, val_output)
     # generation
@@ -1152,7 +1153,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
             vm_dtype = np.dtype('float32')
             _logger.warning(
                 'in %s(%s -> Gemm -> %s): '
-                'beta seems to be an interger, '
+                'attribute "beta" seems to be an interger, '
                 'however dtype can not be inferred, '
                 'still use float32', name, inputs, outputs)
         beta = np.dtype(vm_dtype).type(beta)
@@ -1432,9 +1433,17 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
     is_const_shape = shape and 'const_value' in value_infos[val_shape]
     if shape is None:
         shape = _shape_or_none(value_infos, val_reshaped)
-    assert shape is not None, (
-        'given shape is neither const value nor deductible from output, '
-        'this is not supported')
+    # assert shape is not None, ('given shape is neither const value nor deductible from output, '
+    #                            'this is not supported')
+    if shape is None:
+        shape = [1, -1]  # who knows
+        _logger.warning(
+            'in %s(%s -> Reshape -> %s): '
+            'input "shape" not inferred, use [1, -1] as dummy value, '
+            'the behavior of Paddle fluid maybe undefined', name, inputs,
+            outputs)
     fluid_op = 'reshape'
     name_attr = ', name={}'.format(repr(name)) if name else ''
@@ -1574,6 +1583,7 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
         '[' + ', '.join(var_inps) + ']',
         # attrs
     ))
+    fluid_op = 'sum'
     prog.VarDesc(var_sum)
     prog.OpDesc(
         fluid_op,
......