diff --git a/tools/python/transform/onnx_converter.py b/tools/python/transform/onnx_converter.py index 57e50dca16394de4411ffa67da7eb3a29c79e7fa..8dd0dbd1756b64930da52ea31b9da02135168753 100644 --- a/tools/python/transform/onnx_converter.py +++ b/tools/python/transform/onnx_converter.py @@ -950,7 +950,8 @@ class OnnxConverter(base_converter.ConverterInterface): node.inputs[0] not in self._consts: const_name = node.inputs[1] const_tensor = self._consts[const_name] - if len(const_tensor.dims) == 0: + dims = const_tensor.dims + if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1): value_arg = op.arg.add() value_arg.name = MaceKeyword.mace_scalar_input_str if const_tensor.data_type == mace_pb2.DT_INT32: @@ -970,7 +971,8 @@ class OnnxConverter(base_converter.ConverterInterface): node.inputs[1] not in self._consts: const_name = node.inputs[0] const_tensor = self._consts[const_name] - if len(const_tensor.dims) == 0: + dims = const_tensor.dims + if len(dims) == 0 or (len(dims) == 1 and dims[0] == 1): value_arg = op.arg.add() value_arg.name = MaceKeyword.mace_scalar_input_str if const_tensor.data_type == mace_pb2.DT_INT32: diff --git a/tools/python/transform/tensorflow_converter.py b/tools/python/transform/tensorflow_converter.py index cfd24ace00945869a7e5a935257cd1c926916a68..90ab048fcb83c8a84ec8bbb8cb48b503921ed864 100644 --- a/tools/python/transform/tensorflow_converter.py +++ b/tools/python/transform/tensorflow_converter.py @@ -587,33 +587,38 @@ class TensorflowConverter(base_converter.ConverterInterface): EltwiseType.SUM, EltwiseType.PROD, EltwiseType.MAX, EltwiseType.MIN] - if (len(tf_op.inputs) > 1 and - len(self.infer_tensor_shape(tf_op.inputs[1])) == 0 and - tf_op.inputs[1].op.type == TFOpType.Const.name): - scalar = tf_op.inputs[1].eval().astype(np.float32) - value_arg = op.arg.add() - value_arg.name = MaceKeyword.mace_scalar_input_str - value_arg.f = scalar - self._skip_tensor.add(tf_op.inputs[1].name) - value_index_arg = op.arg.add() - value_index_arg.name = \ - MaceKeyword.mace_scalar_input_index_str - value_index_arg.i = 1 - self._skip_tensor.add(tf_op.inputs[1].name) - del op.input[1] - elif len(self.infer_tensor_shape(tf_op.inputs[0])) == 0 and \ - tf_op.inputs[0].op.type == TFOpType.Const.name and \ - is_commutative(type_arg.i): - scalar = tf_op.inputs[0].eval().astype(np.float32) - value_arg = op.arg.add() - value_arg.name = MaceKeyword.mace_scalar_input_str - value_arg.f = scalar - value_index_arg = op.arg.add() - value_index_arg.name = \ - MaceKeyword.mace_scalar_input_index_str - value_index_arg.i = 0 - self._skip_tensor.add(tf_op.inputs[0].name) - del op.input[0] + if len(tf_op.inputs) > 1: + shape = self.infer_tensor_shape(tf_op.inputs[1]) + if (len(shape) == 0 or + (len(shape) == 1 and shape[0] == 1)) and \ + tf_op.inputs[1].op.type == TFOpType.Const.name: + scalar = tf_op.inputs[1].eval().astype(np.float32) + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = scalar + self._skip_tensor.add(tf_op.inputs[1].name) + value_index_arg = op.arg.add() + value_index_arg.name = \ + MaceKeyword.mace_scalar_input_index_str + value_index_arg.i = 1 + self._skip_tensor.add(tf_op.inputs[1].name) + del op.input[1] + else: + shape = self.infer_tensor_shape(tf_op.inputs[0]) + if (len(shape) == 0 or + (len(shape) == 1 and shape[0] == 1)) and \ + is_commutative(type_arg.i) and \ + tf_op.inputs[0].op.type == TFOpType.Const.name: + scalar = tf_op.inputs[0].eval().astype(np.float32) + value_arg = op.arg.add() + value_arg.name = MaceKeyword.mace_scalar_input_str + value_arg.f = scalar + value_index_arg = op.arg.add() + value_index_arg.name = \ + MaceKeyword.mace_scalar_input_index_str + value_index_arg.i = 0 + self._skip_tensor.add(tf_op.inputs[0].name) + del op.input[0] except tf.errors.InvalidArgumentError: pass