提交 5e02dc0e 编写于 作者: 李寅

Merge branch 'check_model_configs' into 'master'

Check model configs

See merge request !458
......@@ -32,8 +32,8 @@
#define CHECK_OUT_OF_RANGE_FOR_IMAGE2D(image, coord)
#endif
#define READ_IMAGET(image, coord, value) \
CMD_TYPE(read_image, CMD_DATA_TYPE)(image, coord, value)
#define READ_IMAGET(image, sampler, coord) \
CMD_TYPE(read_image, CMD_DATA_TYPE)(image, sampler, coord)
#define WRITE_IMAGET(image, coord, value) \
CHECK_OUT_OF_RANGE_FOR_IMAGE2D(image, coord) \
CMD_TYPE(write_image, CMD_DATA_TYPE)(image, coord, value);
......
......@@ -175,7 +175,10 @@ def parse_args():
parser.add_argument(
"--platform", type=str, default="tensorflow", help="tensorflow/caffe")
parser.add_argument(
"--embed_model_data", type=str2bool, default=True, help="input shape.")
"--embed_model_data",
type=str2bool,
default=True,
help="embed model data.")
return parser.parse_known_args()
......
......@@ -336,8 +336,74 @@ def str_to_caffe_env_type(v):
def parse_model_configs():
print("============== Load and Parse configs ==============")
with open(FLAGS.config) as f:
configs = yaml.load(f)
target_abis = configs.get("target_abis", [])
if not isinstance(target_abis, list) or not target_abis:
print("CONFIG ERROR:")
print("target_abis list is needed!")
print("For example: 'target_abis: [armeabi-v7a, arm64-v8a]'")
exit(1)
embed_model_data = configs.get("embed_model_data", "")
if embed_model_data == "" or not isinstance(embed_model_data, int) or \
embed_model_data < 0 or embed_model_data > 1:
print("CONFIG ERROR:")
print("embed_model_data must be integer in range [0, 1]")
exit(1)
model_names = configs.get("models", "")
if not model_names:
print("CONFIG ERROR:")
print("models attribute not found in config file")
exit(1)
for model_name in model_names:
model_config = configs["models"][model_name]
platform = model_config.get("platform", "")
if platform == "" or platform not in ["tensorflow", "caffe"]:
print("CONFIG ERROR:")
print("'platform' must be 'tensorflow' or 'caffe'")
exit(1)
for key in ["model_file_path", "model_sha256_checksum",
"runtime"]:
value = model_config.get(key, "")
if value == "":
print("CONFIG ERROR:")
print("'%s' is necessary" % key)
exit(1)
for key in ["input_nodes", "input_shapes", "output_nodes",
"output_shapes"]:
value = model_config.get(key, "")
if value == "":
print("CONFIG ERROR:")
print("'%s' is necessary" % key)
exit(1)
if not isinstance(value, list):
model_config[key] = [value]
for key in ["limit_opencl_kernel_time", "dsp_mode", "obfuscate",
"fast_conv"]:
value = model_config.get(key, "")
if value == "":
model_config[key] = 0
print("'%s' for %s is set to default value: 0" %
(key, model_name))
validation_inputs_data = model_config.get("validation_inputs_data",
[])
model_config["validation_inputs_data"] = validation_inputs_data
if not isinstance(validation_inputs_data, list):
model_config["validation_inputs_data"] = [
validation_inputs_data]
weight_file_path = model_config.get("weight_file_path", "")
model_config["weight_file_path"] = weight_file_path
print("Parse model configs successfully!\n")
return configs
......@@ -434,16 +500,10 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
for model_name in configs["models"]:
print '===================', model_name, '==================='
model_config = configs["models"][model_name]
input_file_list = model_config.get("validation_inputs_data",
[])
input_file_list = model_config["validation_inputs_data"]
data_type, device_type = get_data_and_device_type(
model_config["runtime"])
for key in ["input_nodes", "output_nodes", "input_shapes",
"output_shapes"]:
if not isinstance(model_config[key], list):
model_config[key] = [model_config[key]]
# Create model build directory
model_path_digest = md5sum(model_config["model_file_path"])
......@@ -472,7 +532,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
model_file_path, weight_file_path = get_model_files(
model_config["model_file_path"],
model_output_dir,
model_config.get("weight_file_path", ""))
model_config["weight_file_path"])
if FLAGS.mode == "build" or FLAGS.mode == "run" or \
FLAGS.mode == "validate" or \
......@@ -604,8 +664,7 @@ def process_models(project_name, configs, embed_model_data, vlog_level,
if os.path.exists(throughput_test_output_dir):
sh.rm("-rf", throughput_test_output_dir)
os.makedirs(throughput_test_output_dir)
input_file_list = model_config.get("validation_inputs_data",
[])
input_file_list = model_config["validation_inputs_data"]
sh_commands.gen_random_input(throughput_test_output_dir,
first_model["input_nodes"],
first_model["input_shapes"],
......@@ -654,7 +713,7 @@ def main(unused_args):
target_socs = get_target_socs(configs)
embed_model_data = configs.get("embed_model_data", 1)
embed_model_data = configs["embed_model_data"]
vlog_level = FLAGS.vlog_level
phone_data_dir = "/data/local/tmp/mace_run/"
for target_abi in configs["target_abis"]:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册