提交 b39b5399 编写于 作者: B Bin Li

fix: Use 1 element tensor as scalar for Eltwise

上级 37cbf203
...@@ -950,7 +950,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -950,7 +950,8 @@ class OnnxConverter(base_converter.ConverterInterface):
node.inputs[0] not in self._consts: node.inputs[0] not in self._consts:
const_name = node.inputs[1] const_name = node.inputs[1]
const_tensor = self._consts[const_name] const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0: dims = const_tensor.dims
if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1):
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32: if const_tensor.data_type == mace_pb2.DT_INT32:
...@@ -970,7 +971,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -970,7 +971,8 @@ class OnnxConverter(base_converter.ConverterInterface):
node.inputs[1] not in self._consts: node.inputs[1] not in self._consts:
const_name = node.inputs[0] const_name = node.inputs[0]
const_tensor = self._consts[const_name] const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0: dims = const_tensor.dims
if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1):
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32: if const_tensor.data_type == mace_pb2.DT_INT32:
......
...@@ -587,33 +587,38 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -587,33 +587,38 @@ class TensorflowConverter(base_converter.ConverterInterface):
EltwiseType.SUM, EltwiseType.PROD, EltwiseType.SUM, EltwiseType.PROD,
EltwiseType.MAX, EltwiseType.MIN] EltwiseType.MAX, EltwiseType.MIN]
if (len(tf_op.inputs) > 1 and if len(tf_op.inputs) > 1:
len(self.infer_tensor_shape(tf_op.inputs[1])) == 0 and shape = self.infer_tensor_shape(tf_op.inputs[1])
tf_op.inputs[1].op.type == TFOpType.Const.name): if (len(shape) == 0 or
scalar = tf_op.inputs[1].eval().astype(np.float32) (len(shape) == 1 and shape[0] == 1)) and \
value_arg = op.arg.add() tf_op.inputs[1].op.type == TFOpType.Const.name:
value_arg.name = MaceKeyword.mace_scalar_input_str scalar = tf_op.inputs[1].eval().astype(np.float32)
value_arg.f = scalar value_arg = op.arg.add()
self._skip_tensor.add(tf_op.inputs[1].name) value_arg.name = MaceKeyword.mace_scalar_input_str
value_index_arg = op.arg.add() value_arg.f = scalar
value_index_arg.name = \ self._skip_tensor.add(tf_op.inputs[1].name)
MaceKeyword.mace_scalar_input_index_str value_index_arg = op.arg.add()
value_index_arg.i = 1 value_index_arg.name = \
self._skip_tensor.add(tf_op.inputs[1].name) MaceKeyword.mace_scalar_input_index_str
del op.input[1] value_index_arg.i = 1
elif len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \ self._skip_tensor.add(tf_op.inputs[1].name)
tf_op.inputs[0].op.type == TFOpType.Const.name and \ del op.input[1]
is_commutative(type_arg.i): else:
scalar = tf_op.inputs[0].eval().astype(np.float32) shape = self.infer_tensor_shape(tf_op.inputs[0])
value_arg = op.arg.add() if (len(shape) == 0 or
value_arg.name = MaceKeyword.mace_scalar_input_str (len(shape) == 1 and shape[0] == 1)) and \
value_arg.f = scalar is_commutative(type_arg.i) and \
value_index_arg = op.arg.add() tf_op.inputs[0].op.type == TFOpType.Const.name:
value_index_arg.name = \ scalar = tf_op.inputs[0].eval().astype(np.float32)
MaceKeyword.mace_scalar_input_index_str value_arg = op.arg.add()
value_index_arg.i = 0 value_arg.name = MaceKeyword.mace_scalar_input_str
self._skip_tensor.add(tf_op.inputs[0].name) value_arg.f = scalar
del op.input[0] value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 0
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
except tf.errors.InvalidArgumentError: except tf.errors.InvalidArgumentError:
pass pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册