#!/usr/bin/env python

# Must run at the root dir of the libmace project, e.g.:
# python tools/mace_tools.py \
#     --config=tools/example.yaml \
#     --round=100 \
#     --mode=all

import argparse
10
import hashlib
11 12 13 14
import os
import shutil
import subprocess
import sys
15
import urllib
Y
yejianwu 已提交
16
import yaml
17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35

from ConfigParser import ConfigParser

def run_command(command):
  """Run ``command`` through the shell and echo whatever it printed.

  Args:
    command: Complete shell command line; executed with ``shell=True``.

  Raises:
    Exception: if the command finishes with a non-zero exit status.
  """
  print("Run command: {}".format(command))
  proc = subprocess.Popen(
      command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  stdout_data, stderr_data = proc.communicate()

  # Only echo streams that actually produced output.
  if stdout_data:
    print("Stdout msg:\n{}".format(stdout_data))
  if stderr_data:
    print("Stderr msg:\n{}".format(stderr_data))

  exit_code = proc.returncode
  if exit_code == 0:
    return
  raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
      exit_code, command))


def get_global_runtime(configs):
  """Choose the single runtime used for every model in ``configs``.

  Each model's ``runtime`` string is lower-cased; the highest-priority one
  present wins, in the order dsp > gpu > cpu.

  Args:
    configs: Parsed config whose ``"models"`` mapping holds per-model dicts
      with a ``"runtime"`` string.

  Returns:
    ``"dsp"``, ``"gpu"`` or ``"cpu"``.

  Raises:
    Exception: if none of the models uses one of those runtimes.
  """
  models = configs["models"]
  runtimes = [models[name]["runtime"].lower() for name in models]
  for candidate in ("dsp", "gpu", "cpu"):
    if candidate in runtimes:
      return candidate
  raise Exception("Not found available RUNTIME in config files!")


def generate_opencl_and_version_code():
  """Run tools/generate_opencl_and_version_code.sh via run_command."""
  run_command("bash tools/generate_opencl_and_version_code.sh")

def clear_env(target_soc):
  """Run tools/clear_env.sh with ``target_soc`` as its only argument."""
  run_command("bash tools/clear_env.sh {}".format(target_soc))


def generate_random_input(target_soc, model_output_dir):
  """Run tools/validate_tools.sh with its generate-data flag set to 1.

  Args:
    target_soc: First script argument.
    model_output_dir: Second script argument.
  """
  generate_data_flag = int(True)  # validate_model passes 0 here instead
  command = "bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, generate_data_flag)
  run_command(command)


def generate_model_code():
  """Run tools/generate_model_code.sh via run_command."""
  run_command("bash tools/generate_model_code.sh")


def build_mace_run(production_mode, model_output_dir, hexagon_mode):
  """Run tools/build_mace_run.sh.

  Args:
    production_mode: Truthy/falsy; passed to the script as 1/0.
    model_output_dir: Passed through as the second argument.
    hexagon_mode: Truthy/falsy; passed to the script as 1/0.
  """
  script_args = (int(production_mode), model_output_dir, int(hexagon_mode))
  run_command("bash tools/build_mace_run.sh {} {} {}".format(*script_args))


def tuning_run(target_soc, model_output_dir, running_round, tuning, production_mode,
               restart_round, option_args=''):
  """Run tools/tuning_run.sh for one model output directory.

  Args:
    target_soc: First script argument.
    model_output_dir: Second script argument.
    running_round: Passed through unchanged.
    tuning: Truthy/falsy; passed as 1/0.
    production_mode: Truthy/falsy; passed as 1/0.
    restart_round: Passed through unchanged.
    option_args: Extra option string; wrapped in double quotes so the shell
      hands it to the script as a single word.
  """
  template = "bash tools/tuning_run.sh {} {} {} {} {} {} \"{}\""
  command = template.format(target_soc, model_output_dir, running_round,
                            int(tuning), int(production_mode), restart_round,
                            option_args)
  run_command(command)

def benchmark_model(model_output_dir, option_args=''):
  """Run tools/benchmark.sh for one model output directory.

  ``option_args`` is double-quoted so the shell passes it as one word.
  """
  run_command("bash tools/benchmark.sh {} \"{}\"".format(
      model_output_dir, option_args))

def run_model(target_soc, model_output_dir, running_round, restart_round, option_args):
  """Call tuning_run with tuning and production mode both disabled."""
  tuning_run(target_soc, model_output_dir, running_round,
             tuning=False, production_mode=False,
             restart_round=restart_round, option_args=option_args)


def generate_production_code(target_soc, model_output_dirs, pull_or_not):
  """Run tools/generate_production_code.sh over several model outputs.

  Args:
    target_soc: First script argument.
    model_output_dirs: Iterable of model output directories; the
      ``opencl_bin`` subdirectory of each is comma-joined into the second
      script argument.
    pull_or_not: Truthy/falsy; passed as 1/0.
  """
  cl_bin_dirs_str = ",".join(
      os.path.join(out_dir, "opencl_bin") for out_dir in model_output_dirs)
  run_command("bash tools/generate_production_code.sh {} {} {}".format(
      target_soc, cl_bin_dirs_str, int(pull_or_not)))


def build_mace_run_prod(target_soc, model_output_dir, tuning, global_runtime):
  """Build mace_run, run it once, then rebuild it in production mode.

  Steps:
    1. Build the non-production binary (hexagon mode iff the runtime is
       "dsp").
    2. Run it via tuning_run with zero running rounds, one restart round and
       the caller's ``tuning`` flag.
    3. Generate production code for this single output dir with the pull
       flag set, then rebuild in production mode.
  """
  hexagon_mode = global_runtime == "dsp"

  build_mace_run(False, model_output_dir, hexagon_mode)
  tuning_run(
      target_soc,
      model_output_dir,
      running_round=0,
      tuning=tuning,
      production_mode=False,
      restart_round=1)

  generate_production_code(target_soc, [model_output_dir], True)
  build_mace_run(True, model_output_dir, hexagon_mode)


def build_run_throughput_test(run_seconds, merged_lib_file, model_input_dir):
  """Run tools/build_run_throughput_test.sh with the arguments in order."""
  command = ("bash tools/build_run_throughput_test.sh {} {} {}"
             .format(run_seconds, merged_lib_file, model_input_dir))
  run_command(command)


def validate_model(target_soc, model_output_dir):
  """Run tools/validate_tools.sh with its generate-data flag set to 0."""
  run_command("bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, int(False)))


def build_production_code():
  """Run tools/build_production_code.sh via run_command."""
  run_command("bash tools/build_production_code.sh")


def merge_libs_and_tuning_results(target_soc, output_dir, model_output_dirs):
  """Generate and build production code, then run tools/merge_libs.sh.

  Args:
    target_soc: Passed to both generate_production_code and merge_libs.sh.
    output_dir: Destination argument for merge_libs.sh.
    model_output_dirs: List of per-model output directories; comma-joined
      for merge_libs.sh.
  """
  # Pull flag off (0) here, unlike build_mace_run_prod which sets it.
  generate_production_code(target_soc, model_output_dirs, False)
  build_production_code()

  joined_dirs = ",".join(model_output_dirs)
  run_command("bash tools/merge_libs.sh {} {} {}".format(
      target_soc, output_dir, joined_dirs))


def parse_model_configs():
  """Load the YAML model config file named by ``FLAGS.config``.

  ``yaml.safe_load`` is used instead of a bare ``yaml.load(f)``. The bare
  form is unsafe on arbitrary input, and PyYAML >= 6 rejects it with a
  TypeError because no Loader is given. This script only reads plain
  mappings, lists and scalars from the config.

  Returns:
    The parsed YAML document (the rest of the script indexes it as a dict).
  """
  with open(FLAGS.config) as f:
    return yaml.safe_load(f)


def parse_args():
  """Parse command line arguments.

  Returns:
    The ``(namespace, remaining_args)`` pair from
    ``ArgumentParser.parse_known_args``. Unrecognised arguments are
    returned instead of causing an error.
  """
  parser = argparse.ArgumentParser()
  # "bool" type: only the (case-insensitive) string "true" maps to True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  # NOTE(review): the usage example uses tools/example.yaml; confirm that
  # "./tool/config" (not "./tools/...") is the intended default.
  parser.add_argument(
      "--config",
      type=str,
      default="./tool/config",
      help="The global config file of models.")
  parser.add_argument(
      "--output_dir", type=str, default="build", help="The output dir.")
  parser.add_argument(
      "--round", type=int, default=1, help="The model running round.")
  parser.add_argument("--run_seconds", type=int, default=10,
                      help="The model throughput test running seconds.")
  parser.add_argument(
      "--restart_round", type=int, default=1, help="The model restart round.")
  # argparse passes string defaults through `type`, so this default is True.
  parser.add_argument(
      "--tuning", type="bool", default="true", help="Tune opencl params.")
  # Lists every mode main() dispatches on, including "benchmark".
  parser.add_argument(
      "--mode", type=str, default="all",
      help="[build|run|validate|benchmark|merge|all|throughput_test].")
  return parser.parse_known_args()


def _reset_project_output_dir(output_dir, project_name):
  """Create ``output_dir`` if missing; otherwise recreate its project subdir.

  The project subdirectory is removed only when it exists. The original
  code checked for ``output_dir/libmace`` but then removed
  ``output_dir/<project_name>``. That cleaned the wrong tree and crashed in
  ``rmtree`` whenever "libmace" existed without the project directory.
  """
  project_dir = os.path.join(output_dir, project_name)
  if not os.path.exists(output_dir):
    os.makedirs(output_dir)
  elif os.path.exists(project_dir):
    shutil.rmtree(project_dir)
    os.makedirs(project_dir)


def _export_model_env(model_config):
  """Export every model_config entry as an upper-cased environment variable.

  List-valued node names are comma-joined and list-valued shapes are
  colon-joined; everything else is exported via ``str()``.
  """
  for key, value in model_config.items():
    if key in ("input_nodes", "output_nodes") and isinstance(value, list):
      os.environ[key.upper()] = ",".join(value)
    elif key in ("input_shapes", "output_shapes") and isinstance(value, list):
      os.environ[key.upper()] = ":".join(value)
    else:
      os.environ[key.upper()] = str(value)


def _fetch_remote_model_files(model_config, model_output_dir):
  """Download http(s) model/weight files into ``model_output_dir``.

  On download, MODEL_FILE_PATH / WEIGHT_FILE_PATH in the environment are
  pointed at the local copy. Weights are only considered for the "caffe"
  platform.
  """
  remote_prefixes = ("http://", "https://")
  if model_config["model_file_path"].startswith(remote_prefixes):
    os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb"
    urllib.urlretrieve(model_config["model_file_path"],
                       os.environ["MODEL_FILE_PATH"])

  if (model_config["platform"] == "caffe" and
      model_config["weight_file_path"].startswith(remote_prefixes)):
    os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
    urllib.urlretrieve(model_config["weight_file_path"],
                       os.environ["WEIGHT_FILE_PATH"])


def main(unused_args):
  """Drive the build/run/validate/benchmark/merge/throughput_test pipeline.

  Args:
    unused_args: Leftover command-line arguments; those starting with "--"
      are forwarded to run_model / benchmark_model as one option string.
  """
  configs = parse_model_configs()

  # Validation always uses a single run round and a single restart.
  if FLAGS.mode == "validate":
    FLAGS.round = 1
    FLAGS.restart_round = 1

  project_name = os.path.splitext(os.path.basename(FLAGS.config))[0]
  os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  os.environ["VLOG_LEVEL"] = str(configs["vlog_level"])
  os.environ["PROJECT_NAME"] = project_name

  if FLAGS.mode in ("build", "all"):
    # Remove previous output dirs
    _reset_project_output_dir(FLAGS.output_dir, project_name)

  generate_opencl_and_version_code()
  option_args = ' '.join([arg for arg in unused_args if arg.startswith('--')])

  # Depends only on configs, so compute it once instead of per abi/soc.
  global_runtime = get_global_runtime(configs)

  for target_abi in configs["target_abis"]:
    for target_soc in configs["target_socs"]:
      # Transfer params by environment
      os.environ["TARGET_ABI"] = target_abi
      model_output_dirs = []
      for model_name in configs["models"]:
        # Transfer params by environment
        os.environ["MODEL_TAG"] = model_name
        print("======================= %s =======================" % model_name)
        model_config = configs["models"][model_name]
        _export_model_env(model_config)

        model_path_digest = hashlib.md5(
            model_config["model_file_path"]).hexdigest()
        model_output_dir = "%s/%s/%s/%s/%s" % (
            FLAGS.output_dir, model_name, model_path_digest, target_soc,
            target_abi)
        model_output_dirs.append(model_output_dir)

        if FLAGS.mode in ("build", "all"):
          if os.path.exists(model_output_dir):
            shutil.rmtree(model_output_dir)
          os.makedirs(model_output_dir)
          clear_env(target_soc)

        # Support http:// and https://
        _fetch_remote_model_files(model_config, model_output_dir)

        if FLAGS.mode in ("build", "run", "validate", "benchmark", "all"):
          generate_random_input(target_soc, model_output_dir)

        if FLAGS.mode in ("build", "all"):
          generate_model_code()
          build_mace_run_prod(target_soc, model_output_dir, FLAGS.tuning,
                              global_runtime)

        if FLAGS.mode in ("run", "validate", "all"):
          run_model(target_soc, model_output_dir, FLAGS.round,
                    FLAGS.restart_round, option_args)

        if FLAGS.mode == "benchmark":
          benchmark_model(model_output_dir, option_args)

        if FLAGS.mode in ("validate", "all"):
          validate_model(target_soc, model_output_dir)

      if FLAGS.mode in ("build", "merge", "all"):
        merge_libs_and_tuning_results(
            target_soc, FLAGS.output_dir + "/" + os.environ["PROJECT_NAME"],
            model_output_dirs)

      if FLAGS.mode == "throughput_test":
        merged_lib_file = FLAGS.output_dir + "/%s/libmace/lib/libmace_%s.a" % (
            configs["target_abis"][0], os.environ["PROJECT_NAME"])
        generate_random_input(target_soc, FLAGS.output_dir)
        for model_name in configs["models"]:
          runtime = configs["models"][model_name]["runtime"]
          os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
        build_run_throughput_test(FLAGS.run_seconds, merged_lib_file,
                                  FLAGS.output_dir)

if __name__ == "__main__":
  FLAGS, unparsed = parse_args()
  # argv[0] first, then whatever argparse did not recognise.
  forwarded_args = [sys.argv[0]] + unparsed
  main(unused_args=forwarded_args)