提交 2f51ac3b 编写于 作者: L liutuo

Fix ONNX converter data type handling

上级 f9c99209
......@@ -191,9 +191,9 @@ OnnxOpType = Enum('OnnxOpType',
onnx_attr_translator = {
"axis": lambda x: int(x),
"axes": lambda x: [int(a) for a in x],
"dtype": lambda x: data_type.onnx2tf(x),
"dtype": lambda x: onnx_dtype(x),
"keepdims": lambda x: bool(x),
"to": lambda x: data_type.onnx2tf(x),
"to": lambda x: onnx_dtype(x),
}
......@@ -567,11 +567,7 @@ class OnnxConverter(base_converter.ConverterInterface):
tensor.data_type = mace_pb2.DT_FLOAT
tensor.float_data.extend(
onnx_tensor.astype(np.float32).flat)
elif data_type == np.int32:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
elif data_type == np.int64:
elif data_type == np.int64 or data_type == np.int32:
tensor.data_type = mace_pb2.DT_INT32
tensor.int32_data.extend(
onnx_tensor.astype(np.int32).flat)
......@@ -668,9 +664,9 @@ class OnnxConverter(base_converter.ConverterInterface):
if 'to' in node.attrs:
dtype = node.attrs['to']
if dtype == TensorProto.FLOAT:
if dtype == np.float32 or dtype == np.float64:
op.output_type.extend([self._option.data_type])
elif dtype == TensorProto.INT:
elif dtype == np.int64 or dtype == np.int32:
op.output_type.extend([mace_pb2.DT_INT32])
else:
mace_check(False, "data type %s not supported" % dtype)
......@@ -959,7 +955,14 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = const_tensor.float_data[0]
if const_tensor.data_type == mace_pb2.DT_INT32:
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0]
else:
mace_check(False,
"Does not support param's data type %s"
% const_tensor.data_type)
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
......@@ -972,7 +975,14 @@ class OnnxConverter(base_converter.ConverterInterface):
if len(const_tensor.dims) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = const_tensor.float_data[0]
if const_tensor.data_type == mace_pb2.DT_INT32:
value_arg.f = float(const_tensor.int32_data[0])
elif const_tensor.data_type == mace_pb2.DT_FLOAT:
value_arg.f = const_tensor.float_data[0]
else:
mace_check(False,
"Does not support param's data type %s"
% const_tensor.data_type)
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
......
......@@ -253,7 +253,7 @@ class ShapeInference(object):
aspect_ratio = ConverterUtil.get_arg(op, MaceKeyword.mace_aspect_ratio_str).floats # noqa
num_prior = len(aspect_ratio) * len(min_size) + len(max_size)
output_shape[2] = num_prior * input_h * input_w * 4
output_shape[2] = int(num_prior * input_h * input_w * 4)
self.add_output_shape(op, [output_shape])
def infer_shape_reshape(self, op):
......@@ -275,7 +275,7 @@ class ShapeInference(object):
output_shape[i] = dim[i]
product *= dim[i]
if idx != -1:
output_shape[idx] = input_size / product
output_shape[idx] = int(input_size / product)
self.add_output_shape(op, [output_shape])
else:
output_shape = []
......
......@@ -1440,15 +1440,15 @@ class Transformer(base_converter.ConverterInterface):
arg.i = 1
elif arg.i == 3:
arg.i = 2
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \
len(input_shape) == 2:
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1:
axis_arg.i = 3
if op.input[0] in self._producer:
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if (producer.type == MaceOp.FullyConnected.name
and len(input_shape) == 2):
axis_arg = ConverterUtil.get_arg(
op, MaceKeyword.mace_axis_str)
if axis_arg.i == 1:
axis_arg.i = 3
elif op.type == MaceOp.Squeeze.name:
for arg in op.arg:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册