mace_tools.py 10.2 KB
Newer Older
1 2
#!/usr/bin/env python

3
# Must be run from the root directory of the libmace project.
4
# python tools/mace_tools.py \
Y
yejianwu 已提交
5
#     --config=tools/example.yaml \
6 7 8 9
#     --round=100 \
#     --mode=all

import argparse
10
import hashlib
11 12 13 14
import os
import shutil
import subprocess
import sys
15
import urllib
Y
yejianwu 已提交
16
import yaml
17 18 19

from ConfigParser import ConfigParser

L
liuqi 已提交
20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42
def run_command_real_time(command):
  """Run a shell command, echoing its output line by line as it is produced.

  Args:
    command: Shell command string, executed with shell=True.

  Raises:
    Exception: If the command exits with a non-zero status.
  """
  print("Run command: {}".format(command))
  # Merge stderr into stdout so only one pipe has to be drained. Reading one
  # pipe to EOF before the other deadlocks as soon as the child fills the
  # buffer of the pipe nobody is reading.
  process = subprocess.Popen(
      command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.STDOUT,
      universal_newlines=True)

  # readline() returns "" only at EOF; each line is echoed immediately.
  for line in iter(process.stdout.readline, ""):
    print(line.rstrip())
  process.stdout.close()
  ret_code = process.wait()

  if ret_code != 0:
    raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
        ret_code, command))

43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58
def run_command(command):
  """Execute `command` through the shell and echo what it printed.

  Output is captured in full and echoed once the process has finished. Each
  stream is echoed only when it is non-empty.

  Raises:
    Exception: when the command's exit status is non-zero.
  """
  print("Run command: {}".format(command))
  proc = subprocess.Popen(command, shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
  captured_out, captured_err = proc.communicate()

  for template, payload in (("Stdout msg:\n{}", captured_out),
                            ("Stderr msg:\n{}", captured_err)):
    if payload:
      print(template.format(payload))

  exit_code = proc.returncode
  if exit_code == 0:
    return
  raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
      exit_code, command))


Y
yejianwu 已提交
59
def get_libs(target_abi, configs):
  """Choose the libmace flavour for `target_abi` and hand it to the fetch script.

  The runtime is chosen by priority across all configured models, compared
  case-insensitively: dsp beats gpu, which beats cpu. The resulting name is
  passed to tools/download_and_link_lib.sh.

  Returns:
    The library name, "libmace-<target_abi>-<runtime>".

  Raises:
    Exception: if no model declares a dsp, gpu or cpu runtime, or if the
      script exits non-zero.
  """
  requested = [model["runtime"].lower()
               for model in configs["models"].values()]

  chosen = None
  for candidate in ("dsp", "gpu", "cpu"):
    if candidate in requested:
      chosen = candidate
      break
  if chosen is None:
    raise Exception("Not found available RUNTIME in config files!")

  libmace_name = "libmace-{}-{}".format(target_abi, chosen)
  run_command("bash tools/download_and_link_lib.sh " + libmace_name)
  return libmace_name

82 83 84 85 86 87 88 89 90 91 92 93 94 95 96

def clear_env():
  """Run tools/clear_env.sh (raises if the script exits non-zero)."""
  run_command("bash tools/clear_env.sh")


def generate_random_input(model_output_dir):
  """Invoke tools/validate_tools.sh for `model_output_dir` with the generate-data flag set to 1."""
  generate_data_flag = int(True)
  run_command("bash tools/validate_tools.sh {} {}".format(
      model_output_dir, generate_data_flag))


def generate_model_code():
  """Run tools/generate_model_code.sh, streaming its output as it runs."""
  script = "bash tools/generate_model_code.sh"
  run_command_real_time(script)
98 99


100 101 102
def build_mace_run(production_mode, model_output_dir, hexagon_mode):
  """Invoke tools/build_mace_run.sh; the two boolean flags are passed as 0/1."""
  script_args = (int(production_mode), model_output_dir, int(hexagon_mode))
  run_command("bash tools/build_mace_run.sh {} {} {}".format(*script_args))


李寅 已提交
106 107 108
def tuning_run(model_output_dir, running_round, tuning, production_mode,
               restart_round):
  """Invoke tools/tuning_run.sh.

  Positional script arguments, in order: output dir, running round, tuning
  flag (0/1), production flag (0/1), restart round.
  """
  script_args = (model_output_dir, running_round, int(tuning),
                 int(production_mode), restart_round)
  run_command("bash tools/tuning_run.sh {} {} {} {} {}".format(*script_args))

Y
yejianwu 已提交
111

112 113 114
def benchmark_model(model_output_dir):
  """Run tools/benchmark.sh against `model_output_dir`."""
  benchmark_cmd = "bash tools/benchmark.sh {}".format(model_output_dir)
  run_command(benchmark_cmd)
115

Y
yejianwu 已提交
116

李寅 已提交
117 118
def run_model(model_output_dir, running_round, restart_round):
  """Run the model through tuning_run with tuning and production mode off."""
  tuning_run(model_output_dir, running_round,
             tuning=False, production_mode=False,
             restart_round=restart_round)
119 120 121 122 123 124 125 126 127 128 129 130


def generate_production_code(model_output_dirs, pull_or_not):
  """Invoke tools/generate_production_code.sh over each dir's opencl_bin.

  Args:
    model_output_dirs: iterable of per-model output directories; each
      contributes "<dir>/opencl_bin" to a comma-separated list.
    pull_or_not: passed to the script as 1 when truthy, otherwise 0.
  """
  opencl_bin_dirs = ",".join(
      os.path.join(out_dir, "opencl_bin") for out_dir in model_output_dirs)
  run_command("bash tools/generate_production_code.sh {} {}".format(
      opencl_bin_dirs, int(pull_or_not)))


131 132 133 134 135 136
def build_mace_run_prod(model_output_dir, tuning, libmace_name):
  """Build mace_run, run one tuning pass, then rebuild it in production mode.

  Hexagon mode is enabled when `libmace_name` contains "dsp". Steps:
    1. Build a non-production mace_run.
    2. Run it with running_round=0 and restart_round=1, passing `tuning`.
    3. Generate production code for `model_output_dir` with the pull flag set.
    4. Rebuild mace_run in production mode.
  """
  hexagon_mode = "dsp" in libmace_name

  build_mace_run(False, model_output_dir, hexagon_mode)
  tuning_run(model_output_dir, running_round=0, tuning=tuning,
             production_mode=False, restart_round=1)

  generate_production_code([model_output_dir], True)
  build_mace_run(True, model_output_dir, hexagon_mode)
150 151


152 153 154 155 156 157
def build_run_throughput_test(run_seconds, merged_lib_file, model_input_dir):
  """Invoke tools/build_run_throughput_test.sh with duration, merged lib and input dir."""
  script_args = (run_seconds, merged_lib_file, model_input_dir)
  run_command(
      "bash tools/build_run_throughput_test.sh {} {} {}".format(*script_args))


158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181
def validate_model(model_output_dir):
  """Invoke tools/validate_tools.sh for `model_output_dir` with the generate-data flag set to 0."""
  generate_data_flag = int(False)
  validate_cmd = "bash tools/validate_tools.sh {} {}".format(
      model_output_dir, generate_data_flag)
  run_command(validate_cmd)


def build_production_code():
  """Run tools/build_production_code.sh (raises if it exits non-zero)."""
  script = "bash tools/build_production_code.sh"
  run_command(script)


def merge_libs_and_tuning_results(output_dir, model_output_dirs):
  """Generate and build production code for all models, then run merge_libs.sh.

  Production code is generated with the pull flag cleared (0). The model
  output dirs are passed to tools/merge_libs.sh as one comma-separated list.
  """
  generate_production_code(model_output_dirs, False)
  build_production_code()

  merge_cmd = "bash tools/merge_libs.sh {} {}".format(
      output_dir, ",".join(model_output_dirs))
  run_command(merge_cmd)


def parse_model_configs():
  """Load the YAML config file named by the --config flag.

  Reads the module-global FLAGS, which is assigned in the __main__ block.

  Returns:
    The parsed YAML document. main() indexes it as a mapping with keys such
    as "models", "target_abis", "embed_model_data" and "vlog_level".
  """
  with open(FLAGS.config) as f:
    # safe_load: the config is plain data. yaml.load() without an explicit
    # Loader can construct arbitrary Python objects, has warned since
    # PyYAML 5.1 and raises TypeError since PyYAML 6.0.
    return yaml.safe_load(f)
185 186 187 188 189 190 191


def parse_args():
  """Parses command line arguments.

  Returns:
    A (flags, unparsed) tuple from parse_known_args(): the namespace of
    recognised options and the list of arguments argparse did not recognise
    (unknown flags are returned, not rejected).
  """
  parser = argparse.ArgumentParser()
  # Custom "bool" type: only the string "true" (any case) maps to True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--config",
      type=str,
      # NOTE(review): the usage example in the file header passes a
      # tools/*.yaml file, while this default is ./tool/config; confirm the
      # intended default path.
      default="./tool/config",
      help="The global config file of models.")
  parser.add_argument(
      "--output_dir", type=str, default="build", help="The output dir.")
  parser.add_argument(
      "--round", type=int, default=1, help="The model running round.")
  parser.add_argument(
      "--run_seconds", type=int, default=10,
      help="The model throughput test running seconds.")
  parser.add_argument(
      "--restart_round", type=int, default=1, help="The model restart round.")
  # argparse applies `type` to string defaults, so this default becomes True.
  parser.add_argument(
      "--tuning", type="bool", default="true", help="Tune opencl params.")
  # main() also handles "benchmark"; list it so --help is accurate.
  parser.add_argument(
      "--mode", type=str, default="all",
      help="[build|run|validate|benchmark|merge|all|throughput_test].")
  return parser.parse_known_args()


def _export_model_env(model_name, model_config):
  """Export one model's settings as environment variables for tools/*.sh.

  MODEL_TAG is set to `model_name`, then every config key is exported under
  its upper-cased name. List-valued input/output node names are joined with
  ",", list-valued input/output shapes with ":", and any other value is
  converted with str().
  """
  os.environ["MODEL_TAG"] = model_name
  for key in model_config:
    value = model_config[key]
    env_key = key.upper()
    if key in ("input_nodes", "output_nodes") and isinstance(value, list):
      os.environ[env_key] = ",".join(value)
    elif key in ("input_shapes", "output_shapes") and isinstance(value, list):
      os.environ[env_key] = ":".join(value)
    else:
      os.environ[env_key] = str(value)


def _is_remote_path(path):
  """Return True when `path` is an http:// or https:// URL."""
  return path.startswith(("http://", "https://"))


def _fetch_remote_model_files(model_config, model_output_dir):
  """Download http(s) model/weight files into `model_output_dir`.

  MODEL_FILE_PATH / WEIGHT_FILE_PATH in the environment are repointed at the
  downloaded copies. Weight files are only considered for caffe models.
  """
  if _is_remote_path(model_config["model_file_path"]):
    local_model = model_output_dir + "/model.pb"
    os.environ["MODEL_FILE_PATH"] = local_model
    urllib.urlretrieve(model_config["model_file_path"], local_model)

  if (model_config["platform"] == "caffe" and
      _is_remote_path(model_config["weight_file_path"])):
    local_weight = model_output_dir + "/model.caffemodel"
    os.environ["WEIGHT_FILE_PATH"] = local_weight
    urllib.urlretrieve(model_config["weight_file_path"], local_weight)


def main(unused_args):
  """Drive the build / run / validate / benchmark / merge / throughput_test flow.

  The steps performed depend on FLAGS.mode. FLAGS is a module global set from
  parse_args() in the __main__ block. Parameters reach the tools/*.sh scripts
  through environment variables.

  Args:
    unused_args: leftover command-line arguments; not used.
  """
  configs = parse_model_configs()

  if FLAGS.mode in ("build", "all"):
    # Make sure the output dir exists, and drop any merged libmace/ left over
    # from a previous build.
    libmace_dir = os.path.join(FLAGS.output_dir, "libmace")
    if not os.path.exists(FLAGS.output_dir):
      os.makedirs(FLAGS.output_dir)
    elif os.path.exists(libmace_dir):
      shutil.rmtree(libmace_dir)

  if FLAGS.mode == "validate":
    # Validation runs each model a single time.
    FLAGS.round = 1
    FLAGS.restart_round = 1

  # Settings shared by every model, passed to the shell scripts via env.
  os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  os.environ["VLOG_LEVEL"] = str(configs["vlog_level"])
  os.environ["PROJECT_NAME"] = os.path.splitext(
      os.path.basename(FLAGS.config))[0]

  for target_abi in configs["target_abis"]:
    libmace_name = get_libs(target_abi, configs)
    os.environ["TARGET_ABI"] = target_abi
    model_output_dirs = []

    for model_name in configs["models"]:
      model_config = configs["models"][model_name]
      print("======================= {} =======================".format(
          model_name))
      _export_model_env(model_name, model_config)

      # The output dir path embeds an MD5 digest of model_file_path, so
      # different model files for the same name/ABI map to different dirs.
      md5 = hashlib.md5()
      md5.update(model_config["model_file_path"])
      model_output_dir = "%s/%s/%s/%s" % (
          FLAGS.output_dir, model_name, md5.hexdigest(), target_abi)
      model_output_dirs.append(model_output_dir)

      if FLAGS.mode in ("build", "all"):
        # Start every build from an empty output dir.
        if os.path.exists(model_output_dir):
          shutil.rmtree(model_output_dir)
        os.makedirs(model_output_dir)
        clear_env()

      _fetch_remote_model_files(model_config, model_output_dir)

      if FLAGS.mode in ("build", "run", "validate", "benchmark", "all"):
        generate_random_input(model_output_dir)

      if FLAGS.mode in ("build", "all"):
        generate_model_code()
        build_mace_run_prod(model_output_dir, FLAGS.tuning, libmace_name)

      if FLAGS.mode in ("run", "validate", "all"):
        run_model(model_output_dir, FLAGS.round, FLAGS.restart_round)

      if FLAGS.mode == "benchmark":
        benchmark_model(model_output_dir)

      if FLAGS.mode in ("validate", "all"):
        validate_model(model_output_dir)

    if FLAGS.mode in ("build", "merge", "all"):
      merge_libs_and_tuning_results(FLAGS.output_dir + "/" + target_abi,
                                    model_output_dirs)

  if FLAGS.mode == "throughput_test":
    merged_lib_file = FLAGS.output_dir + "/%s/libmace/lib/libmace_%s.a" % (
        configs["target_abis"][0], os.environ["PROJECT_NAME"])
    generate_random_input(FLAGS.output_dir)
    # Export one tag per runtime, e.g. GPU_MODEL_TAG=<model name>.
    for model_name in configs["models"]:
      runtime = configs["models"][model_name]["runtime"]
      os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
    build_run_throughput_test(FLAGS.run_seconds, merged_lib_file,
                              FLAGS.output_dir)

304

Y
yejianwu 已提交
305
if __name__ == "__main__":
  # FLAGS must be a module-level global: parse_model_configs() and main()
  # read it directly instead of receiving it as an argument.
  FLAGS, unparsed = parse_args()
  forwarded_argv = [sys.argv[0]]
  forwarded_argv.extend(unparsed)
  main(unused_args=forwarded_argv)