提交 66f55b89 编写于 作者: M Macrobull

try shape = [1, -1] in Reshape

上级 a2189f02
......@@ -31,11 +31,17 @@ def _make_var_name(name):
fn = sys.argv[1]
input_names = sys.argv[2].split(':')
output_name = sys.argv[3].split(':')
squeeze_data = len(sys.argv) > 4
data = np.load(fn, encoding='bytes')
input_data = data['inputs']
output_data = data['outputs']
while squeeze_data and input_data.ndim > 4 and input_data.shape[0] == 1:
input_data = input_data.squeeze(0)
while squeeze_data and output_data.ndim > 2 and output_data.shape[0] == 1:
output_data = output_data.squeeze(0)
inputs = Dict(zip(map(_make_var_name, input_names), [input_data]))
outputs = Dict(zip(map(_make_var_name, output_name), [output_data]))
......
......@@ -36,6 +36,7 @@ def _make_var_name(name):
data_dir = os.path.dirname(sys.argv[1])
input_names = sys.argv[2].split(':')
output_name = sys.argv[3].split(':')
squeeze_data = len(sys.argv) > 4
# Load inputs
inputs = []
......@@ -43,7 +44,10 @@ for fn in glob(os.path.join(data_dir, 'input_*.pb')):
tensor = onnx.TensorProto()
with open(fn, 'rb') as f:
tensor.ParseFromString(f.read())
inputs.append(numpy_helper.to_array(tensor))
tensor = numpy_helper.to_array(tensor)
while squeeze_data and tensor.ndim > 4 and tensor.shape[0] == 1:
tensor = tensor.squeeze(0)
inputs.append(tensor)
# Load outputs
outputs = []
......@@ -51,7 +55,10 @@ for fn in glob(os.path.join(data_dir, 'output_*.pb')):
tensor = onnx.TensorProto()
with open(fn, 'rb') as f:
tensor.ParseFromString(f.read())
outputs.append(numpy_helper.to_array(tensor))
tensor = numpy_helper.to_array(tensor)
while squeeze_data and tensor.ndim > 2 and tensor.shape[0] == 1:
tensor = tensor.squeeze(0)
outputs.append(tensor)
inputs = Dict(zip(map(_make_var_name, input_names), inputs))
outputs = Dict(zip(map(_make_var_name, output_name), outputs))
......
......@@ -18,6 +18,7 @@ bvlc_alexnet()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -25,7 +26,7 @@ bvlc_alexnet()
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" data_0 prob_1
python convert_data_npz_0.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
......@@ -45,6 +46,7 @@ bvlc_googlenet()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -65,6 +67,7 @@ bvlc_reference_caffenet()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -85,6 +88,7 @@ bvlc_reference_rcnn_ilsvrc13()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -93,8 +97,8 @@ bvlc_reference_rcnn_ilsvrc13()
do
echo "converting $pb_dir"
python convert_data_pb_0.py "$pb_dir" data_0 fc-rcnn_1
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz
python -m onnx2fluid.validation $validate_flags1 -t $(dirname "$pb_dir/x").npz -p 0
python -m onnx2fluid.validation $validate_flags2 -t $(dirname "$pb_dir/x").npz -p 0
done
}
......@@ -105,6 +109,7 @@ inception_v1()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -112,7 +117,7 @@ inception_v1()
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" data_0 prob_1
python convert_data_npz_0.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
......@@ -132,6 +137,7 @@ inception_v2()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -139,7 +145,7 @@ inception_v2()
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" data_0 prob_1
python convert_data_npz_0.py "$npz" data_0 prob_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
......@@ -159,6 +165,7 @@ resnet50()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -166,7 +173,7 @@ resnet50()
for npz in "$bn_tar"/*.npz
do
echo "converting $npz ..."
python convert_data_npz_0.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1
python convert_data_npz_0.py "$npz" gpu_0/data_0 gpu_0/softmaxout_1 -s
python -m onnx2fluid.validation $validate_flags1 -t "$npz"
python -m onnx2fluid.validation $validate_flags2 -t "$npz"
done
......@@ -186,6 +193,7 @@ shufflenet()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -206,6 +214,7 @@ squeezenet()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -226,6 +235,7 @@ tiny_yolov2()
fn_model="$bn_tar/model.onnx"
http_get "https://onnxzoo.blob.core.windows.net/models/opset_8/tiny_yolov2/$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -246,6 +256,7 @@ vgg19()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -266,6 +277,7 @@ zfnet512()
fn_model="$bn_tar/model.onnx"
http_get "$base_url$fn_tar"
rm -rf "$bn_tar/"
echo "extracting ..."
tar xf "$fn_tar"
......@@ -288,7 +300,7 @@ inception_v1
inception_v2
resnet50
shufflenet
squeezenet
tiny_yolov2 # not supported
squeezenet # softmax bug
# tiny_yolov2 # not supported
vgg19
zfnet512
......@@ -49,7 +49,8 @@ def main(**kwargs):
basepath, _ = shutil.os.path.splitext(filename)
save_dir = kwargs.get('output_dir', '')
# model.onnx -> model/
save_dir = (save_dir.rstrip('/') if save_dir else basepath) + '/'
save_dir = (save_dir.rstrip(shutil.os.sep)
if save_dir else basepath) + shutil.os.sep
model_basename = DEFAULT_MODEL_MODULE + '.py'
model_func_name = DEFAULT_MODEL_FUNC
embed_params = kwargs.get('embed_params', False)
......@@ -109,7 +110,7 @@ def main(**kwargs):
# create zip file
if archive is not None:
if archive == '':
archive = save_dir.rstrip('/') + '.zip'
archive = save_dir.rstrip(shutil.os.sep) + '.zip'
logger.info('compressing file to %s ...', archive)
shutil.sys.stderr.write('\n')
shutil.sys.stderr.flush()
......
......@@ -69,6 +69,7 @@ DEFAULT_OP_MAPPING = {
'Sin': ['sin', ['X'], ['Out']],
'Squeeze': ['squeeze', ['X'], ['Out']], # attrs bypassed, FIXME: emit squeeze2
'Softplus': ['softplus', ['X'], ['Out']],
# FIXME: default axis = -1, reshape required before and after
'Softmax': ['softmax', ['X'], ['Out'], dict(axis='')],
'Softsign': ['softsign', ['X'], ['Out']],
'Sqrt': ['sqrt', ['X'], ['Out']],
......@@ -799,7 +800,7 @@ def Constant(prog, inputs, outputs, attrs, value_infos, *args, **kwargs):
shape = list(value.shape)
_logger.warning(
'in (Constant -> %s): '
'shape of %s not inferred, '
'attribute "shape" of %s not inferred, '
'using value as 1-D tensor may lead to fails', outputs, val_output)
# generation
......@@ -1152,7 +1153,7 @@ def Gemm(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
vm_dtype = np.dtype('float32')
_logger.warning(
'in %s(%s -> Gemm -> %s): '
'beta seems to be an interger, '
'attribute "beta" seems to be an interger, '
'however dtype can not be inferred, '
'still use float32', name, inputs, outputs)
beta = np.dtype(vm_dtype).type(beta)
......@@ -1432,9 +1433,17 @@ def Reshape(prog, inputs, outputs, attrs, value_infos, name, *args, **kwargs):
is_const_shape = shape and 'const_value' in value_infos[val_shape]
if shape is None:
shape = _shape_or_none(value_infos, val_reshaped)
assert shape is not None, (
'given shape is neither const value nor deductible from output, '
'this is not supported')
# assert shape is not None, ('given shape is neither const value nor deductible from output, '
# 'this is not supported')
if shape is None:
shape = [1, -1] # who knows
_logger.warning(
'in %s(%s -> Reshape -> %s): '
'input "shape" not inferred, use [1, -1] as dummy value, '
'the behavior of Paddle fluid maybe undefined', name, inputs,
outputs)
fluid_op = 'reshape'
name_attr = ', name={}'.format(repr(name)) if name else ''
......@@ -1574,6 +1583,7 @@ def Sum(prog, inputs, outputs, *args, **kwargs):
'[' + ', '.join(var_inps) + ']',
# attrs
))
fluid_op = 'sum'
prog.VarDesc(var_sum)
prog.OpDesc(
fluid_op,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册