From 20812e12e49f34dbf31c57369246e877a64105b4 Mon Sep 17 00:00:00 2001 From: liutuo Date: Tue, 26 Mar 2019 17:37:07 +0800 Subject: [PATCH] check non-const filter in convolution for onnx model --- .../tools/converter_tool/onnx_converter.py | 26 ++++++++++--------- .../tools/converter_tool/transformer.py | 4 ++- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/mace/python/tools/converter_tool/onnx_converter.py b/mace/python/tools/converter_tool/onnx_converter.py index 0ed14ae1..68f781a2 100644 --- a/mace/python/tools/converter_tool/onnx_converter.py +++ b/mace/python/tools/converter_tool/onnx_converter.py @@ -241,14 +241,14 @@ class OnnxNode(object): self.node_proto = node def print_info(self): - print "node: ", self.name - print " type: ", self.op_type - print " domain: ", self.domain - print " inputs: ", self.inputs - print " outputs: ", self.outputs - print " attrs:" + print("node: ", self.name) + print(" type: ", self.op_type) + print(" domain: ", self.domain) + print(" inputs: ", self.inputs) + print(" outputs: ", self.outputs) + print(" attrs:") for arg in self.attrs: - print " %s: %s" % (arg, self.attrs[arg]) + print(" %s: %s" % (arg, self.attrs[arg])) class OnnxTensor(object): @@ -378,11 +378,11 @@ class OnnxConverter(base_converter.ConverterInterface): opset_imp = onnx_model.opset_import polish_available = True - print "onnx model IR version: ", ir_version + print("onnx model IR version: ", ir_version) for imp in opset_imp: domain = imp.domain version = imp.version - print "constains ops domain: ", domain, "version:", version + print("constains ops domain: ", domain, "version:", version) if 'kaldi2onnx' in domain: polish_available = False self._data_format = DataFormat.DF_NONE @@ -397,11 +397,11 @@ class OnnxConverter(base_converter.ConverterInterface): @staticmethod def print_graph_info(graph): for value_info in graph.value_info: - print "value info:", value_info + print("value info:", value_info) for value_info in graph.input: - print "inputs info:", value_info + print("inputs info:", value_info) for value_info in graph.output: - print "outputs info:", value_info + print("outputs info:", value_info) def extract_shape_info(self, graph): def extract_value_info(shape_dict, value_info): @@ -674,6 +674,8 @@ class OnnxConverter(base_converter.ConverterInterface): op.type = MaceOp.DepthwiseConv2d.name else: op.type = MaceOp.Conv2D.name + mace_check(op.input[1] in self._consts, + "Mace does not support non-const filter convolution.") dilation_arg = op.arg.add() dilation_arg.name = MaceKeyword.mace_dilations_str diff --git a/mace/python/tools/converter_tool/transformer.py b/mace/python/tools/converter_tool/transformer.py index d440eeaa..b5462a6c 100644 --- a/mace/python/tools/converter_tool/transformer.py +++ b/mace/python/tools/converter_tool/transformer.py @@ -963,7 +963,9 @@ class Transformer(base_converter.ConverterInterface): net = self._model for op in net.op: - if op.type == MaceOp.Conv2D.name: + if op.type == MaceOp.Conv2D.name \ + and len(op.input) >= 2 \ + and op.input[1] in self._consts: producer = self._producer[op.input[0]] input_shape = producer.output_shape[0].dims batch, height, width, channels = self.sort_feature_map_shape( -- GitLab