# Copyright 2018 Xiaomi, Inc. All rights reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import argparse import sys import numpy as np import re # Validation Flow: # 1. Generate input data # python generate_data.py \ # --input_node input_node \ # --input_shape 1,64,64,3 \ # --input_file input_file # def generate_data(name, shape): np.random.seed() data = np.random.random(shape) * 2 - 1 input_file_name = FLAGS.input_file + "_" + re.sub('[^0-9a-zA-Z]+', '_', name) print 'Generate input file: ', input_file_name data.astype(np.float32).tofile(input_file_name) def main(unused_args): input_names = [name for name in FLAGS.input_node.split(',')] input_shapes = [shape for shape in FLAGS.input_shape.split(':')] assert len(input_names) == len(input_shapes) for i in range(len(input_names)): shape = [int(x) for x in input_shapes[i].split(',')] generate_data(input_names[i], shape) print "Generate input file done." def parse_args(): """Parses command line arguments.""" parser = argparse.ArgumentParser() parser.register("type", "bool", lambda v: v.lower() == "true") parser.add_argument( "--input_file", type=str, default="", help="input file.") parser.add_argument( "--input_node", type=str, default="input_node", help="input node") parser.add_argument( "--input_shape", type=str, default="1,64,64,3", help="input shape.") return parser.parse_known_args() if __name__ == '__main__': FLAGS, unparsed = parse_args() main(unused_args=[sys.argv[0]] + unparsed)