提交 b39b5399 编写于 作者: B Bin Li

fix: Use 1 element tensor as scalar for Eltwise

上级 37cbf203
......@@ -950,7 +950,8 @@ class OnnxConverter(base_converter.ConverterInterface):
node.inputs[0] not in self._consts:
const_name = node.inputs[1]
const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0:
dims = const_tensor.dims
if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1):
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32:
......@@ -970,7 +971,8 @@ class OnnxConverter(base_converter.ConverterInterface):
node.inputs[1] not in self._consts:
const_name = node.inputs[0]
const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0:
dims = const_tensor.dims
if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1):
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32:
......
......@@ -587,33 +587,38 @@ class TensorflowConverter(base_converter.ConverterInterface):
EltwiseType.SUM, EltwiseType.PROD,
EltwiseType.MAX, EltwiseType.MIN]
if (len(tf_op.inputs) > 1 and
len(self.infer_tensor_shape(tf_op.inputs[1])) == 0 and
tf_op.inputs[1].op.type == TFOpType.Const.name):
scalar = tf_op.inputs[1].eval().astype(np.float32)
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = scalar
self._skip_tensor.add(tf_op.inputs[1].name)
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 1
self._skip_tensor.add(tf_op.inputs[1].name)
del op.input[1]
elif len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \
tf_op.inputs[0].op.type == TFOpType.Const.name and \
is_commutative(type_arg.i):
scalar = tf_op.inputs[0].eval().astype(np.float32)
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = scalar
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 0
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
if len(tf_op.inputs) > 1:
shape = self.infer_tensor_shape(tf_op.inputs[1])
if (len(shape) == 0 or
(len(shape) == 1 and shape[0] == 1)) and \
tf_op.inputs[1].op.type == TFOpType.Const.name:
scalar = tf_op.inputs[1].eval().astype(np.float32)
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = scalar
self._skip_tensor.add(tf_op.inputs[1].name)
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 1
self._skip_tensor.add(tf_op.inputs[1].name)
del op.input[1]
else:
shape = self.infer_tensor_shape(tf_op.inputs[0])
if (len(shape) == 0 or
(len(shape) == 1 and shape[0] == 1)) and \
is_commutative(type_arg.i) and \
tf_op.inputs[0].op.type == TFOpType.Const.name:
scalar = tf_op.inputs[0].eval().astype(np.float32)
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = scalar
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 0
self._skip_tensor.add(tf_op.inputs[0].name)
del op.input[0]
except tf.errors.InvalidArgumentError:
pass
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册