提交 20812e12 编写于 作者: L liutuo

check non-const filter in convolution for onnx model

上级 91bd5c11
...@@ -241,14 +241,14 @@ class OnnxNode(object): ...@@ -241,14 +241,14 @@ class OnnxNode(object):
self.node_proto = node self.node_proto = node
def print_info(self): def print_info(self):
print "node: ", self.name print("node: ", self.name)
print " type: ", self.op_type print(" type: ", self.op_type)
print " domain: ", self.domain print(" domain: ", self.domain)
print " inputs: ", self.inputs print(" inputs: ", self.inputs)
print " outputs: ", self.outputs print(" outputs: ", self.outputs)
print " attrs:" print(" attrs:")
for arg in self.attrs: for arg in self.attrs:
print " %s: %s" % (arg, self.attrs[arg]) print(" %s: %s" % (arg, self.attrs[arg]))
class OnnxTensor(object): class OnnxTensor(object):
...@@ -378,11 +378,11 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -378,11 +378,11 @@ class OnnxConverter(base_converter.ConverterInterface):
opset_imp = onnx_model.opset_import opset_imp = onnx_model.opset_import
polish_available = True polish_available = True
print "onnx model IR version: ", ir_version print("onnx model IR version: ", ir_version)
for imp in opset_imp: for imp in opset_imp:
domain = imp.domain domain = imp.domain
version = imp.version version = imp.version
print "constains ops domain: ", domain, "version:", version print("constains ops domain: ", domain, "version:", version)
if 'kaldi2onnx' in domain: if 'kaldi2onnx' in domain:
polish_available = False polish_available = False
self._data_format = DataFormat.DF_NONE self._data_format = DataFormat.DF_NONE
...@@ -397,11 +397,11 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -397,11 +397,11 @@ class OnnxConverter(base_converter.ConverterInterface):
@staticmethod @staticmethod
def print_graph_info(graph): def print_graph_info(graph):
for value_info in graph.value_info: for value_info in graph.value_info:
print "value info:", value_info print("value info:", value_info)
for value_info in graph.input: for value_info in graph.input:
print "inputs info:", value_info print("inputs info:", value_info)
for value_info in graph.output: for value_info in graph.output:
print "outputs info:", value_info print("outputs info:", value_info)
def extract_shape_info(self, graph): def extract_shape_info(self, graph):
def extract_value_info(shape_dict, value_info): def extract_value_info(shape_dict, value_info):
...@@ -674,6 +674,8 @@ class OnnxConverter(base_converter.ConverterInterface): ...@@ -674,6 +674,8 @@ class OnnxConverter(base_converter.ConverterInterface):
op.type = MaceOp.DepthwiseConv2d.name op.type = MaceOp.DepthwiseConv2d.name
else: else:
op.type = MaceOp.Conv2D.name op.type = MaceOp.Conv2D.name
mace_check(op.input[1] in self._consts,
"Mace does not support non-const filter convolution.")
dilation_arg = op.arg.add() dilation_arg = op.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str dilation_arg.name = MaceKeyword.mace_dilations_str
......
...@@ -963,7 +963,9 @@ class Transformer(base_converter.ConverterInterface): ...@@ -963,7 +963,9 @@ class Transformer(base_converter.ConverterInterface):
net = self._model net = self._model
for op in net.op: for op in net.op:
if op.type == MaceOp.Conv2D.name: if op.type == MaceOp.Conv2D.name \
and len(op.input) >= 2 \
and op.input[1] in self._consts:
producer = self._producer[op.input[0]] producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims input_shape = producer.output_shape[0].dims
batch, height, width, channels = self.sort_feature_map_shape( batch, height, width, channels = self.sort_feature_map_shape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册