#!/usr/bin/env python

# Must be run from the root directory of the libmace project. Example:
# python tools/mace_tools.py \
#     --config=tools/example.yaml \
#     --round=100 \
#     --mode=all

import argparse
import hashlib
import os
import shutil
import subprocess
import sys
import urllib
from ConfigParser import ConfigParser

import yaml

def run_command(command):
  """Run `command` through the shell and echo its captured output.

  Stdout and stderr are captured in full and printed (each only when
  non-empty) after the process has exited.

  Args:
    command: Shell command line; executed with shell=True.

  Raises:
    Exception: if the command exits with a non-zero status.
  """
  print("Run command: {}".format(command))
  proc = subprocess.Popen(command,
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
  stdout_data, stderr_data = proc.communicate()

  if stdout_data:
    print("Stdout msg:\n{}".format(stdout_data))
  if stderr_data:
    print("Stderr msg:\n{}".format(stderr_data))

  exit_code = proc.returncode
  if exit_code == 0:
    return
  raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
      exit_code, command))


def get_global_runtime(configs):
  """Return the highest-priority runtime requested by any configured model.

  Runtimes are compared case-insensitively and ranked dsp > gpu > cpu.

  Args:
    configs: Parsed config; every entry of configs["models"] must carry a
      "runtime" string.

  Returns:
    One of "dsp", "gpu" or "cpu".

  Raises:
    Exception: if no model requests any of those runtimes.
  """
  requested = [configs["models"][name]["runtime"].lower()
               for name in configs["models"]]
  for candidate in ("dsp", "gpu", "cpu"):
    if candidate in requested:
      return candidate
  raise Exception("Not found available RUNTIME in config files!")


def generate_opencl_and_version_code():
  """Run tools/generate_opencl_and_version_code.sh through run_command."""
  run_command("bash tools/generate_opencl_and_version_code.sh")

def clear_env():
  """Run tools/clear_env.sh through run_command."""
  run_command("bash tools/clear_env.sh")


def generate_random_input(model_output_dir):
  """Generate random input data for one model.

  Calls tools/validate_tools.sh with the model output dir and a trailing
  flag of 1 (validate_model calls the same script with 0).

  Args:
    model_output_dir: Directory forwarded as the script's first argument.
  """
  generate_flag = 1  # 1 selects data generation in validate_tools.sh.
  run_command("bash tools/validate_tools.sh {} {}".format(
      model_output_dir, generate_flag))


def generate_model_code():
  """Run tools/generate_model_code.sh through run_command."""
  run_command("bash tools/generate_model_code.sh")


def build_mace_run(production_mode, model_output_dir, hexagon_mode):
  """Build mace_run via tools/build_mace_run.sh.

  Args:
    production_mode: Truthy to build in production mode (forwarded as 1/0).
    model_output_dir: Model output directory, forwarded as-is.
    hexagon_mode: Truthy to enable Hexagon mode (forwarded as 1/0);
      build_mace_run_prod sets it when the global runtime is "dsp".
  """
  script_args = (int(production_mode), model_output_dir, int(hexagon_mode))
  run_command("bash tools/build_mace_run.sh {} {} {}".format(*script_args))


def tuning_run(model_output_dir, running_round, tuning, production_mode,
               restart_round):
  """Invoke tools/tuning_run.sh for one model.

  Args:
    model_output_dir: Model output directory.
    running_round: Running round count, forwarded as-is.
    tuning: Truthy to enable tuning; forwarded as 1/0.
    production_mode: Truthy for production mode; forwarded as 1/0.
    restart_round: Restart round count, forwarded as-is.
  """
  template = "bash tools/tuning_run.sh {} {} {} {} {}"
  run_command(template.format(model_output_dir,
                              running_round,
                              int(tuning),
                              int(production_mode),
                              restart_round))

def benchmark_model(model_output_dir):
  """Benchmark one model by running tools/benchmark.sh <model_output_dir>."""
  run_command("bash tools/benchmark.sh {}".format(model_output_dir))

def run_model(model_output_dir, running_round, restart_round):
  """Run an already-built model with tuning off, in non-production mode."""
  tuning_run(model_output_dir,
             running_round=running_round,
             tuning=False,
             production_mode=False,
             restart_round=restart_round)


def generate_production_code(model_output_dirs, pull_or_not):
  """Run tools/generate_production_code.sh over one or more models.

  Args:
    model_output_dirs: Model output directories; the "opencl_bin"
      subdirectory of each is passed to the script as a comma-separated list.
    pull_or_not: Flag forwarded to the script as 1/0.
  """
  joined_bin_dirs = ",".join(
      [os.path.join(out_dir, "opencl_bin") for out_dir in model_output_dirs])
  run_command("bash tools/generate_production_code.sh {} {}".format(
      joined_bin_dirs, int(pull_or_not)))


def build_mace_run_prod(model_output_dir, tuning, global_runtime):
  """Build a production mace_run binary for one model.

  Sequence:
    1. Build a non-production mace_run and do a tuning run with
       running_round=0 and restart_round=1.
    2. Generate production code from this model's opencl_bin directory
       (pull_or_not=True) -- presumably the output of step 1; confirm
       against tools/tuning_run.sh.
    3. Rebuild mace_run in production mode.

  Args:
    model_output_dir: Output directory of the model being built.
    tuning: Truthy to enable tuning during step 1.
    global_runtime: Result of get_global_runtime; "dsp" enables Hexagon mode.
  """
  hexagon_mode = (global_runtime == "dsp")

  # Step 1: development build plus tuning pass.
  build_mace_run(False, model_output_dir, hexagon_mode)
  tuning_run(model_output_dir,
             running_round=0,
             tuning=tuning,
             production_mode=False,
             restart_round=1)

  # Steps 2-3: production code generation and production build.
  generate_production_code([model_output_dir], True)
  build_mace_run(True, model_output_dir, hexagon_mode)


def build_run_throughput_test(run_seconds, merged_lib_file, model_input_dir):
  """Build and run the throughput test via tools/build_run_throughput_test.sh.

  Args:
    run_seconds: Test duration in seconds, forwarded as-is.
    merged_lib_file: Path of the merged libmace archive (main passes
      <output_dir>/<abi>/libmace/lib/libmace_<project>.a).
    model_input_dir: Directory forwarded as the third script argument.
  """
  template = "bash tools/build_run_throughput_test.sh {} {} {}"
  run_command(template.format(run_seconds, merged_lib_file, model_input_dir))


def validate_model(model_output_dir):
  """Run tools/validate_tools.sh for one model in its non-generating mode.

  Same script as generate_random_input, but with a trailing flag of 0
  (do not generate data).
  """
  generate_flag = 0  # 0 = validate only; generate_random_input passes 1.
  run_command("bash tools/validate_tools.sh {} {}".format(
      model_output_dir, generate_flag))


def build_production_code():
  """Run tools/build_production_code.sh through run_command."""
  run_command("bash tools/build_production_code.sh")


def merge_libs_and_tuning_results(output_dir, model_output_dirs):
  """Merge the per-model libraries into `output_dir`.

  First regenerates production code from all models' opencl_bin dirs without
  pulling (pull_or_not=False) and builds it. Then runs tools/merge_libs.sh
  with `output_dir` and the comma-joined model output dirs.
  """
  generate_production_code(model_output_dirs, False)
  build_production_code()
  run_command("bash tools/merge_libs.sh {} {}".format(
      output_dir, ",".join(model_output_dirs)))


def parse_model_configs():
  """Load the YAML model config named by the --config flag.

  Reads the module-level FLAGS assigned in the __main__ block.

  Returns:
    The parsed YAML document. main() indexes it as a mapping with keys such
    as "models", "target_abis", "embed_model_data" and "vlog_level".
  """
  with open(FLAGS.config) as f:
    # safe_load: main() only reads plain mappings, lists and scalars from the
    # config. yaml.load() without an explicit Loader can construct arbitrary
    # Python objects, is deprecated since PyYAML 5.1, and raises TypeError on
    # PyYAML 6.
    return yaml.safe_load(f)


def parse_args():
  """Parses command line arguments.

  Returns:
    The (namespace, unparsed_args) tuple from
    ArgumentParser.parse_known_args(). Unrecognized arguments are returned
    rather than rejected.
  """
  parser = argparse.ArgumentParser()
  # Custom "bool" type: "true" (any case) -> True, any other string -> False.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--config",
      type=str,
      # NOTE(review): this default does not match the documented usage
      # (tools/example.yaml); confirm which path is intended.
      default="./tool/config",
      help="The global config file of models.")
  parser.add_argument(
      "--output_dir", type=str, default="build", help="The output dir.")
  parser.add_argument(
      "--round", type=int, default=1, help="The model running round.")
  parser.add_argument(
      "--run_seconds",
      type=int,
      default=10,
      help="The model throughput test running seconds.")
  parser.add_argument(
      "--restart_round", type=int, default=1, help="The model restart round.")
  # argparse runs string defaults through `type`, so "true" becomes True.
  parser.add_argument(
      "--tuning", type="bool", default="true", help="Tune opencl params.")
  # main() also implements "benchmark"; list it so --help is accurate.
  parser.add_argument(
      "--mode",
      type=str,
      default="all",
      help="[build|run|validate|merge|all|benchmark|throughput_test].")
  return parser.parse_known_args()


def _export_model_config(model_config):
  """Export every key of one model's config as an upper-cased env variable.

  input_nodes/output_nodes lists are joined with ",", input_shapes/
  output_shapes lists with ":", and every other value is exported as
  str(value).
  """
  for key, value in model_config.items():
    if key in ("input_nodes", "output_nodes") and isinstance(value, list):
      os.environ[key.upper()] = ",".join(value)
    elif key in ("input_shapes", "output_shapes") and isinstance(value, list):
      os.environ[key.upper()] = ":".join(value)
    else:
      os.environ[key.upper()] = str(value)


def _is_remote(path):
  """Return True if `path` is an http:// or https:// URL."""
  return path.startswith(("http://", "https://"))


def _download(url, local_path):
  """Download `url` to `local_path`; works on both Python 2 and Python 3."""
  try:
    from urllib.request import urlretrieve  # Python 3
  except ImportError:
    from urllib import urlretrieve  # Python 2
  urlretrieve(url, local_path)


def _model_output_dir(output_dir, model_name, model_file_path, target_abi):
  """Return <output_dir>/<model_name>/<md5 of model_file_path>/<target_abi>."""
  md5 = hashlib.md5()
  # hashlib needs bytes on Python 3; on Python 2 this also avoids an implicit
  # ASCII encode of unicode paths. ASCII paths hash exactly as before.
  md5.update(model_file_path.encode("utf-8"))
  return "%s/%s/%s/%s" % (output_dir, model_name, md5.hexdigest(), target_abi)


def main(unused_args):
  """Drive the MACE build/run/validate/merge pipeline for all models.

  The work done is selected by FLAGS.mode (the module-level FLAGS assigned
  in the __main__ block): build, run, validate, benchmark, merge, all or
  throughput_test. Per-ABI and per-model settings reach the tools/*.sh
  scripts through environment variables.

  Args:
    unused_args: Ignored.
  """
  configs = parse_model_configs()
  mode = FLAGS.mode
  is_build = mode in ("build", "all")

  if is_build:
    # Remove the previously merged library (or create the output dir).
    libmace_dir = os.path.join(FLAGS.output_dir, "libmace")
    if not os.path.exists(FLAGS.output_dir):
      os.makedirs(FLAGS.output_dir)
    elif os.path.exists(libmace_dir):
      shutil.rmtree(libmace_dir)

  if mode == "validate":
    FLAGS.round = 1
    FLAGS.restart_round = 1

  os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  os.environ["VLOG_LEVEL"] = str(configs["vlog_level"])
  os.environ["PROJECT_NAME"] = os.path.splitext(
      os.path.basename(FLAGS.config))[0]

  generate_opencl_and_version_code()

  for target_abi in configs["target_abis"]:
    global_runtime = get_global_runtime(configs)
    # Parameters reach the shell scripts through the environment.
    os.environ["TARGET_ABI"] = target_abi
    model_output_dirs = []
    for model_name in configs["models"]:
      os.environ["MODEL_TAG"] = model_name
      # %-formatting keeps the exact output of the former Python 2 print
      # statement while also running on Python 3.
      print("%s %s %s" % ('=======================', model_name,
                          '======================='))
      model_config = configs["models"][model_name]
      # NOTE(review): variables exported for one model are not cleared before
      # the next, so keys absent from a later model keep earlier values.
      _export_model_config(model_config)

      model_output_dir = _model_output_dir(FLAGS.output_dir, model_name,
                                           model_config["model_file_path"],
                                           target_abi)
      model_output_dirs.append(model_output_dir)

      if is_build:
        if os.path.exists(model_output_dir):
          shutil.rmtree(model_output_dir)
        os.makedirs(model_output_dir)
        clear_env()

      # Support http:// and https:// locations by downloading into the model
      # output dir and pointing the env variable at the local copy.
      if _is_remote(model_config["model_file_path"]):
        os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb"
        _download(model_config["model_file_path"],
                  os.environ["MODEL_FILE_PATH"])

      if model_config["platform"] == "caffe" and _is_remote(
          model_config["weight_file_path"]):
        os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
        _download(model_config["weight_file_path"],
                  os.environ["WEIGHT_FILE_PATH"])

      if mode in ("build", "run", "validate", "benchmark", "all"):
        generate_random_input(model_output_dir)

      if is_build:
        generate_model_code()
        build_mace_run_prod(model_output_dir, FLAGS.tuning, global_runtime)

      if mode in ("run", "validate", "all"):
        run_model(model_output_dir, FLAGS.round, FLAGS.restart_round)

      if mode == "benchmark":
        benchmark_model(model_output_dir)

      if mode in ("validate", "all"):
        validate_model(model_output_dir)

    if mode in ("build", "merge", "all"):
      merge_libs_and_tuning_results(FLAGS.output_dir + "/" + target_abi,
                                    model_output_dirs)

  if mode == "throughput_test":
    merged_lib_file = FLAGS.output_dir + "/%s/libmace/lib/libmace_%s.a" % (
        configs["target_abis"][0], os.environ["PROJECT_NAME"])
    generate_random_input(FLAGS.output_dir)
    for model_name in configs["models"]:
      runtime = configs["models"][model_name]["runtime"]
      os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
    build_run_throughput_test(FLAGS.run_seconds, merged_lib_file,
                              FLAGS.output_dir)

if __name__ == "__main__":
  # FLAGS becomes a module-level global read by parse_model_configs() and
  # main(); unrecognized arguments are forwarded after the program name.
  FLAGS, unparsed = parse_args()
  forwarded_args = [sys.argv[0]]
  forwarded_args.extend(unparsed)
  main(unused_args=forwarded_args)