提交 2cf5ee88 编写于 作者: 李超

Merge branch 'scalar' into 'master'

fix: Use 1 element tensor as scalar for Eltwise

See merge request applied-machine-learning/sysml/mace!1308
...@@ -950,7 +950,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -950,7 +950,8 @@ class OnnxConverter(base_converter.ConverterInterface):
node.inputs[0] not in self._consts: node.inputs[0] not in self._consts:
const_name = node.inputs[1] const_name = node.inputs[1]
const_tensor = self._consts[const_name] const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0: dims = const_tensor.dims
if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1):
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32: if const_tensor.data_type == mace_pb2.DT_INT32:
...@@ -970,7 +971,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -970,7 +971,8 @@ class OnnxConverter(base_converter.ConverterInterface):
node.inputs[1] not in self._consts: node.inputs[1] not in self._consts:
const_name = node.inputs[0] const_name = node.inputs[0]
const_tensor = self._consts[const_name] const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0: dims = const_tensor.dims
if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1):
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
if const_tensor.data_type == mace_pb2.DT_INT32: if const_tensor.data_type == mace_pb2.DT_INT32:
......
...@@ -587,9 +587,11 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -587,9 +587,11 @@ class TensorflowConverter(base_converter.ConverterInterface):
EltwiseType.SUM, EltwiseType.PROD, EltwiseType.SUM, EltwiseType.PROD,
EltwiseType.MAX, EltwiseType.MIN] EltwiseType.MAX, EltwiseType.MIN]
if (len(tf_op.inputs) > 1 and if len(tf_op.inputs) > 1:
len(self.infer_tensor_shape(tf_op.inputs[1])) == 0 and shape = self.infer_tensor_shape(tf_op.inputs[1])
tf_op.inputs[1].op.type == TFOpType.Const.name): if (len(shape) == 0 or
(len(shape) == 1 and shape[0] == 1)) and \
tf_op.inputs[1].op.type == TFOpType.Const.name:
scalar = tf_op.inputs[1].eval().astype(np.float32) scalar = tf_op.inputs[1].eval().astype(np.float32)
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
...@@ -601,9 +603,12 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -601,9 +603,12 @@ class TensorflowConverter(base_converter.ConverterInterface):
value_index_arg.i = 1 value_index_arg.i = 1
self._skip_tensor.add(tf_op.inputs[1].name) self._skip_tensor.add(tf_op.inputs[1].name)
del op.input[1] del op.input[1]
elif len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \ else:
tf_op.inputs[0].op.type == TFOpType.Const.name and \ shape = self.infer_tensor_shape(tf_op.inputs[0])
is_commutative(type_arg.i): if (len(shape) == 0 or
(len(shape) == 1 and shape[0] == 1)) and \
is_commutative(type_arg.i) and \
tf_op.inputs[0].op.type == TFOpType.Const.name:
scalar = tf_op.inputs[0].eval().astype(np.float32) scalar = tf_op.inputs[0].eval().astype(np.float32)
value_arg = op.arg.add() value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str value_arg.name = MaceKeyword.mace_scalar_input_str
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册