提交 a5cbc20d 编写于 作者: 卢旭辉

Merge branch 'fix-onnx-converter-datatype' into 'master'

fix onnx converter datatype

See merge request !1217
@@ -191,9 +191,9 @@ OnnxOpType = Enum('OnnxOpType',
 onnx_attr_translator = {
     "axis": lambda x: int(x),
     "axes": lambda x: [int(a) for a in x],
-    "dtype": lambda x: data_type.onnx2tf(x),
+    "dtype": lambda x: onnx_dtype(x),
     "keepdims": lambda x: bool(x),
-    "to": lambda x: data_type.onnx2tf(x),
+    "to": lambda x: onnx_dtype(x),
 }
@@ -567,11 +567,7 @@ class OnnxConverter(base_converter.ConverterInterface):
             tensor.data_type = mace_pb2.DT_FLOAT
             tensor.float_data.extend(
                 onnx_tensor.astype(np.float32).flat)
-        elif data_type == np.int32:
-            tensor.data_type = mace_pb2.DT_INT32
-            tensor.int32_data.extend(
-                onnx_tensor.astype(np.int32).flat)
-        elif data_type == np.int64:
+        elif data_type == np.int64 or data_type == np.int32:
             tensor.data_type = mace_pb2.DT_INT32
             tensor.int32_data.extend(
                 onnx_tensor.astype(np.int32).flat)
@@ -668,9 +664,9 @@ class OnnxConverter(base_converter.ConverterInterface):
         if 'to' in node.attrs:
             dtype = node.attrs['to']
-            if dtype == TensorProto.FLOAT:
+            if dtype == np.float32 or dtype == np.float64:
                 op.output_type.extend([self._option.data_type])
-            elif dtype == TensorProto.INT:
+            elif dtype == np.int64 or dtype == np.int32:
                 op.output_type.extend([mace_pb2.DT_INT32])
             else:
                 mace_check(False, "data type %s not supported" % dtype)
@@ -959,7 +955,14 @@ class OnnxConverter(base_converter.ConverterInterface):
         if len(const_tensor.dims) == 0:
             value_arg = op.arg.add()
             value_arg.name = MaceKeyword.mace_scalar_input_str
-            value_arg.f = const_tensor.float_data[0]
+            if const_tensor.data_type == mace_pb2.DT_INT32:
+                value_arg.f = float(const_tensor.int32_data[0])
+            elif const_tensor.data_type == mace_pb2.DT_FLOAT:
+                value_arg.f = const_tensor.float_data[0]
+            else:
+                mace_check(False,
+                           "Does not support param's data type %s"
+                           % const_tensor.data_type)
             value_index_arg = op.arg.add()
             value_index_arg.name = \
                 MaceKeyword.mace_scalar_input_index_str
@@ -972,7 +975,14 @@ class OnnxConverter(base_converter.ConverterInterface):
         if len(const_tensor.dims) == 0:
             value_arg = op.arg.add()
             value_arg.name = MaceKeyword.mace_scalar_input_str
-            value_arg.f = const_tensor.float_data[0]
+            if const_tensor.data_type == mace_pb2.DT_INT32:
+                value_arg.f = float(const_tensor.int32_data[0])
+            elif const_tensor.data_type == mace_pb2.DT_FLOAT:
+                value_arg.f = const_tensor.float_data[0]
+            else:
+                mace_check(False,
+                           "Does not support param's data type %s"
+                           % const_tensor.data_type)
             value_index_arg = op.arg.add()
             value_index_arg.name = \
                 MaceKeyword.mace_scalar_input_index_str
......
@@ -253,7 +253,7 @@ class ShapeInference(object):
             aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats  # noqa
             num_prior = len(aspect_ratio) * len(min_size) + len(max_size)
-            output_shape[2] = num_prior * input_h * input_w * 4
+            output_shape[2] = int(num_prior * input_h * input_w * 4)
             self.add_output_shape(op, [output_shape])

     def infer_shape_reshape(self, op):
@@ -275,7 +275,7 @@ class ShapeInference(object):
                     output_shape[i] = dim[i]
                     product *= dim[i]
             if idx != -1:
-                output_shape[idx] = input_size / product
+                output_shape[idx] = int(input_size / product)
             self.add_output_shape(op, [output_shape])
         else:
             output_shape = []
......
@@ -1440,15 +1440,15 @@ class Transformer(base_converter.ConverterInterface):
                     arg.i = 1
                 elif arg.i == 3:
                     arg.i = 2
-            producer = self._producer[op.input[0]]
-            input_shape = producer.output_shape[0].dims
-            if producer.type == MaceOp.FullyConnected.name and \
-                    len(input_shape) == 2:
-                axis_arg = ConverterUtil.get_arg(
-                    op, MaceKeyword.mace_axis_str)
-                if axis_arg.i == 1:
-                    axis_arg.i = 3
+            if op.input[0] in self._producer:
+                producer = self._producer[op.input[0]]
+                input_shape = producer.output_shape[0].dims
+                if (producer.type == MaceOp.FullyConnected.name
+                        and len(input_shape) == 2):
+                    axis_arg = ConverterUtil.get_arg(
+                        op, MaceKeyword.mace_axis_str)
+                    if axis_arg.i == 1:
+                        axis_arg.i = 3
         elif op.type == MaceOp.Squeeze.name:
             for arg in op.arg:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册