提交 e9d3a319 编写于 作者: L liutuo

add onnx scalar elt

上级 ece5761c
......@@ -941,6 +941,33 @@ class OnnxConverter(base_converter.ConverterInterface):
coeff_arg = op.arg.add()
coeff_arg.name = MaceKeyword.mace_coeff_str
coeff_arg.floats.extend([min_value, max_value])
elif len(node.inputs) == 2:
if node.inputs[1] in self._consts and \
node.inputs[0] not in self._consts:
const_name = node.inputs[1]
const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 1
del op.input[1]
elif node.inputs[0] in self._consts and \
node.inputs[1] not in self._consts:
const_name = node.inputs[0]
const_tensor = self._consts[const_name]
if len(const_tensor.dims) == 0:
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_scalar_input_str
value_arg.f = const_tensor.float_data[0]
value_index_arg = op.arg.add()
value_index_arg.name = \
MaceKeyword.mace_scalar_input_index_str
value_index_arg.i = 0
del op.input[0]
@staticmethod
def copy_node_attr(op, node, attr_name, dtype=AttributeType.INT,
......
......@@ -22,12 +22,12 @@ from onnx import optimizer
def main():
if len(sys.argv) != 3:
print "Usage: python onnx_optimizer.py model.onnx model_opt.onnx"
print("Usage: python onnx_optimizer.py model.onnx model_opt.onnx")
sys.exit(0)
in_path = sys.argv[1]
out_path = sys.argv[2]
original_model = onnx.load(in_path)
print "Start optimize ONNX model for inference:"
print("Start optimize ONNX model for inference:")
passes = ['eliminate_identity',
'fuse_consecutive_squeezes',
'fuse_consecutive_transposes',
......@@ -35,15 +35,14 @@ def main():
'eliminate_nop_transpose',
'eliminate_unused_initializer',
'extract_constant_to_initializer',
'fuse_add_bias_into_conv',
'fuse_bn_into_conv',
'fuse_transpose_into_gemm']
for i in range(len(passes)):
print i, ".", passes[i]
print("%s.%s" % (i, passes[i]))
optimized_model = optimizer.optimize(original_model, passes)
onnx.save_model(optimized_model, out_path)
print "Optimize Finished!"
print "Please check new model in:", out_path
print("Optimize Finished!")
print("Please check new model in:", out_path)
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册