diff --git a/mace/python/tools/converter.py b/mace/python/tools/converter.py index 164b54e45595891752a6302a8883b9271e2bc7aa..fe378b1411d3f5a64f39ac333743e2d889204eb2 100644 --- a/mace/python/tools/converter.py +++ b/mace/python/tools/converter.py @@ -105,7 +105,7 @@ def main(unused_args): for i in xrange(len(input_node_names)): input_node = cvt.NodeInfo() input_node.name = input_node_names[i] - input_node.shape = parse_int_array_from_str(FLAGS.input_shape) + input_node.shape = parse_int_array_from_str(input_node_shapes[i]) option.add_input_node(input_node) output_node_names = FLAGS.output_node.split(',')