diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index 0ed14ae1f973d723c5c9c62eb51cf52fba6ce2c0..68f781a23dfc4fe5d09163b59422be15fec31f87 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -241,14 +241,14 @@ class OnnxNode(object): self.node_proto = node def print_info(self): - print "node: ", self.name - print " type: ", self.op_type - print " domain: ", self.domain - print " inputs: ", self.inputs - print " outputs: ", self.outputs - print " attrs:" + print("node: ", self.name) + print(" type: ", self.op_type) + print(" domain: ", self.domain) + print(" inputs: ", self.inputs) + print(" outputs: ", self.outputs) + print(" attrs:") for arg in self.attrs: - print " %s: %s" % (arg, self.attrs[arg]) + print(" %s: %s" % (arg, self.attrs[arg])) class OnnxTensor(object): @@ -378,11 +378,11 @@ class OnnxConverter(base_converter.ConverterInterface): opset_imp = onnx_model.opset_import polish_available = True - print "onnx model IR version: ", ir_version + print("onnx model IR version: ", ir_version) for imp in opset_imp: domain = imp.domain version = imp.version - print "constains ops domain: ", domain, "version:", version + print("constains ops domain: ", domain, "version:", version) if 'kaldi2onnx' in domain: polish_available = False self._data_format = DataFormat.DF_NONE @@ -397,11 +397,11 @@ class OnnxConverter(base_converter.ConverterInterface): @staticmethod def print_graph_info(graph): for value_info in graph.value_info: - print "value info:", value_info + print("value info:", value_info) for value_info in graph.input: - print "inputs info:", value_info + print("inputs info:", value_info) for value_info in graph.output: - print "outputs info:", value_info + print("outputs info:", value_info) def extract_shape_info(self, graph): def extract_value_info(shape_dict, value_info): @@ -674,6 +674,8 @@ class OnnxConverter(base_converter.ConverterInterface): op.type = MaceOp.DepthwiseConv2d.name else: op.type = MaceOp.Conv2D.name + mace_check(op.input[1] in self._consts, + "Mace does not support non-const filter convolution.") dilation_arg = op.arg.add() dilation_arg.name = MaceKeyword.mace_dilations_str diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index d440eeaa54edc4c3d0934d445c6140568b6ccffc..b5462a6ca1d6e6994a9722e1e8c0b61e9e90e377 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -963,7 +963,9 @@ class Transformer(base_converter.ConverterInterface): net = self._model for op in net.op: - if op.type == MaceOp.Conv2D.name: + if op.type == MaceOp.Conv2D.name \ + and len(op.input) >= 2 \ + and op.input[1] in self._consts: producer = self._producer[op.input[0]] input_shape = producer.output_shape[0].dims batch, height, width, channels = self.sort_feature_map_shape(