提交 761c7489 编写于 作者: 刘琦

Merge branch 'check-non-const-filter' into 'master'

check non-const filter in convolution for onnx model

See merge request !1032
......@@ -241,14 +241,14 @@ class OnnxNode(object):
self.node_proto = node
def print_info(self):
print "node: ", self.name
print " type: ", self.op_type
print " domain: ", self.domain
print " inputs: ", self.inputs
print " outputs: ", self.outputs
print " attrs:"
print("node: ", self.name)
print(" type: ", self.op_type)
print(" domain: ", self.domain)
print(" inputs: ", self.inputs)
print(" outputs: ", self.outputs)
print(" attrs:")
for arg in self.attrs:
print " %s: %s" % (arg, self.attrs[arg])
print(" %s: %s" % (arg, self.attrs[arg]))
class OnnxTensor(object):
......@@ -378,11 +378,11 @@ class OnnxConverter(base_converter.ConverterInterface):
opset_imp = onnx_model.opset_import
polish_available = True
print "onnx model IR version: ", ir_version
print("onnx model IR version: ", ir_version)
for imp in opset_imp:
domain = imp.domain
version = imp.version
print "constains ops domain: ", domain, "version:", version
print("constains ops domain: ", domain, "version:", version)
if 'kaldi2onnx' in domain:
polish_available = False
self._data_format = DataFormat.DF_NONE
......@@ -397,11 +397,11 @@ class OnnxConverter(base_converter.ConverterInterface):
@staticmethod
def print_graph_info(graph):
for value_info in graph.value_info:
print "value info:", value_info
print("value info:", value_info)
for value_info in graph.input:
print "inputs info:", value_info
print("inputs info:", value_info)
for value_info in graph.output:
print "outputs info:", value_info
print("outputs info:", value_info)
def extract_shape_info(self, graph):
def extract_value_info(shape_dict, value_info):
......@@ -674,6 +674,8 @@ class OnnxConverter(base_converter.ConverterInterface):
op.type = MaceOp.DepthwiseConv2d.name
else:
op.type = MaceOp.Conv2D.name
mace_check(op.input[1] in self._consts,
"Mace does not support non-const filter convolution.")
dilation_arg = op.arg.add()
dilation_arg.name = MaceKeyword.mace_dilations_str
......
......@@ -963,7 +963,9 @@ class Transformer(base_converter.ConverterInterface):
net = self._model
for op in net.op:
if op.type == MaceOp.Conv2D.name:
if op.type == MaceOp.Conv2D.name \
and len(op.input) >= 2 \
and op.input[1] in self._consts:
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
batch, height, width, channels = self.sort_feature_map_shape(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册