#!/usr/bin/env python

# Must run at root dir of libmace project.
# python tools/mace_tools.py \
#     --config=models/config \
#     --round=100 \
#     --mode=all

import argparse
import os
import shutil
import subprocess
import sys
import yaml

from ConfigParser import ConfigParser

# Key string "TF_MODEL_FILE_DIR". Nothing in the visible part of this file
# reads it -- presumably a leftover config/environment key (TODO confirm).
tf_model_file_dir_key = "TF_MODEL_FILE_DIR"


def run_command(command):
  """Runs a shell command, echoes its captured output, and checks the exit code.

  Args:
    command: Command line as a single string; it is executed with shell=True.

  Raises:
    Exception: If the process finishes with a non-zero return code.
  """
  print("Run command: {}".format(command))
  process = subprocess.Popen(
      command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  # communicate() blocks until the process exits, so the captured output is
  # only printed once the command has finished.
  captured_out, captured_err = process.communicate()

  streams = (("Stdout msg:\n{}", captured_out),
             ("Stderr msg:\n{}", captured_err))
  for template, payload in streams:
    if payload:
      print(template.format(payload))

  exit_code = process.returncode
  if exit_code == 0:
    return
  raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
      exit_code, command))


Y
yejianwu 已提交
37
def get_libs(target_abi, configs):
  """Selects the libmace variant for the configured models and links it.

  The "runtime" value of every model in configs["models"] is lower-cased and
  collected; the library runtime is then chosen by priority dsp > gpu > cpu.
  The chosen library name is handed to tools/download_and_link_lib.sh.

  Args:
    target_abi: ABI string that becomes part of the library name.
    configs: Parsed config mapping. configs["models"] maps each model name to
      a mapping that contains a "runtime" string.

  Returns:
    The library name, formatted as "libmace-<target_abi>-<runtime>".

  Raises:
    Exception: If no model uses the dsp, gpu or cpu runtime, or if the
      download script exits with a non-zero code.
  """
  models = configs["models"]
  runtimes = set(models[name]["runtime"].lower() for name in models)

  # Highest-priority runtime present among the models wins.
  for candidate in ("dsp", "gpu", "cpu"):
    if candidate in runtimes:
      global_runtime = candidate
      break
  else:
    raise Exception("Not found available RUNTIME in config files!")

  libmace_name = "libmace-{}-{}".format(target_abi, global_runtime)

  command = "bash tools/download_and_link_lib.sh " + libmace_name
  run_command(command)

  return libmace_name

60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77

def clear_env():
  """Runs tools/clear_env.sh through run_command."""
  run_command("bash tools/clear_env.sh")


def generate_random_input(model_output_dir):
  """Runs tools/validate_tools.sh with the generate-data flag switched on.

  Args:
    model_output_dir: Directory passed as the script's first argument.
  """
  # The second script argument is 1 here; validate_model() calls the same
  # script with 0.
  generate_flag = int(True)
  run_command("bash tools/validate_tools.sh {} {}".format(
      model_output_dir, generate_flag))


def generate_model_code():
  """Runs tools/generate_model_code.sh through run_command."""
  run_command("bash tools/generate_model_code.sh")


78 79 80
def build_mace_run(production_mode, model_output_dir, hexagon_mode):
  """Runs tools/build_mace_run.sh for one model.

  Args:
    production_mode: Truthy/falsy flag, passed to the script as 1 or 0.
    model_output_dir: Directory passed as the script's second argument.
    hexagon_mode: Truthy/falsy flag, passed to the script as 1 or 0.
  """
  command = "bash tools/build_mace_run.sh {} {} {}".format(
      int(production_mode), model_output_dir, int(hexagon_mode))
  run_command(command)


def tuning_run(model_output_dir, running_round, tuning, production_mode):
  """Runs tools/tuning_run.sh for one model.

  Args:
    model_output_dir: Directory passed as the script's first argument.
    running_round: Round count, passed through unchanged.
    tuning: Truthy/falsy flag, passed to the script as 1 or 0.
    production_mode: Truthy/falsy flag, passed to the script as 1 or 0.
  """
  script_args = (model_output_dir, running_round, int(tuning),
                 int(production_mode))
  run_command("bash tools/tuning_run.sh {} {} {} {}".format(*script_args))

Y
yejianwu 已提交
89

90 91 92
def benchmark_model(model_output_dir):
  """Runs tools/benchmark.sh with model_output_dir as its only argument."""
  run_command("bash tools/benchmark.sh {}".format(model_output_dir))
93

Y
yejianwu 已提交
94

95 96 97 98 99 100 101 102 103 104 105 106 107 108
def run_model(model_output_dir, running_round):
  """Runs the model via tuning_run with tuning and production mode both off.

  Args:
    model_output_dir: Directory forwarded to tuning_run.
    running_round: Round count forwarded to tuning_run.
  """
  tuning_run(
      model_output_dir, running_round, tuning=False, production_mode=False)


def generate_production_code(model_output_dirs, pull_or_not):
  """Runs tools/generate_production_code.sh over the models' opencl_bin dirs.

  Args:
    model_output_dirs: Iterable of model output directories; "opencl_bin" is
      appended to each and the results are joined with commas.
    pull_or_not: Truthy/falsy flag, passed to the script as 1 or 0.
  """
  joined_bin_dirs = ",".join(
      os.path.join(out_dir, "opencl_bin") for out_dir in model_output_dirs)
  run_command("bash tools/generate_production_code.sh {} {}".format(
      joined_bin_dirs, int(pull_or_not)))


109 110 111 112 113 114
def build_mace_run_prod(model_output_dir, tuning, libmace_name):
  """Builds mace_run twice: a non-production pass, then a production pass.

  Steps, in order:
    1. Build with production mode off.
    2. Call tuning_run with running_round=0 and the given tuning flag.
    3. Generate production code for this model with pull_or_not=True.
    4. Build again with production mode on.

  Args:
    model_output_dir: Output directory of the model being built.
    tuning: Flag forwarded to tuning_run.
    libmace_name: Library name; hexagon mode is enabled when it contains
      the substring "dsp".
  """
  hexagon_mode = "dsp" in libmace_name

  production_or_not = False
  build_mace_run(production_or_not, model_output_dir, hexagon_mode)
  tuning_run(
      model_output_dir,
      running_round=0,
      tuning=tuning,
      production_mode=production_or_not)

  production_or_not = True
  pull_or_not = True
  generate_production_code([model_output_dir], pull_or_not)
  build_mace_run(production_or_not, model_output_dir, hexagon_mode)
127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152


def validate_model(model_output_dir):
  """Runs tools/validate_tools.sh with the generate-data flag switched off.

  Args:
    model_output_dir: Directory passed as the script's first argument.
  """
  # The second script argument is 0 here; generate_random_input() calls the
  # same script with 1.
  generate_flag = int(False)
  run_command("bash tools/validate_tools.sh {} {}".format(
      model_output_dir, generate_flag))


def build_production_code():
  """Runs tools/build_production_code.sh through run_command."""
  run_command("bash tools/build_production_code.sh")


def merge_libs_and_tuning_results(output_dir, model_output_dirs):
  """Regenerates and builds production code, then runs tools/merge_libs.sh.

  Args:
    output_dir: Directory passed as the merge script's first argument.
    model_output_dirs: List of model output directories; passed to the merge
      script as one comma-separated argument.
  """
  # pull_or_not is False for this call (build_mace_run_prod passes True).
  generate_production_code(model_output_dirs, False)
  build_production_code()

  merge_command = "bash tools/merge_libs.sh {} {}".format(
      output_dir, ",".join(model_output_dirs))
  run_command(merge_command)


def parse_model_configs():
  """Loads the YAML config file named by FLAGS.config.

  Returns:
    The parsed YAML document (main() indexes it as a mapping).
  """
  with open(FLAGS.config) as f:
    # safe_load instead of yaml.load(f): a bare yaml.load without a Loader can
    # construct arbitrary Python objects and is rejected by PyYAML >= 6.
    # NOTE(review): configs relying on python-specific YAML tags would no
    # longer load -- confirm none exist.
    return yaml.safe_load(f)
156 157 158 159 160 161 162


def parse_args():
  """Parses command line arguments.

  Returns:
    A (namespace, unparsed) pair as produced by parse_known_args(): the
    namespace holds config, output_dir, round, tuning and mode; unparsed is
    the list of arguments that were not recognised.
  """
  parser = argparse.ArgumentParser()
  # Custom "bool" type: only the string "true" (any case) yields True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--config",
      type=str,
      default="./tool/config",
      help="The global config file of models.")
  parser.add_argument(
      "--output_dir", type=str, default="./build/", help="The output dir.")
  parser.add_argument(
      "--round", type=int, default=1, help="The model running round.")
  # The string default goes through the "bool" converter, so it ends up True.
  parser.add_argument(
      "--tuning", type="bool", default="true", help="Tune opencl params.")
  # "benchmark" is listed because main() handles that mode as well.
  parser.add_argument(
      "--mode",
      type=str,
      default="all",
      help="[build|run|validate|merge|benchmark|all].")
  return parser.parse_known_args()


def main(unused_args):
  """Drives the build/run/validate/benchmark/merge pipeline for every model.

  Reads the module-level FLAGS (mode, output_dir, config, round, tuning) and
  the YAML config. Parameters are handed to the shell scripts through
  environment variables: TARGET_ABI, EMBED_MODEL_DATA, PROJECT_NAME,
  MODEL_TAG, plus one upper-cased variable per key of each model's config.

  Args:
    unused_args: Ignored.
  """
  configs = parse_model_configs()

  mode = FLAGS.mode
  building = mode in ("build", "all")

  if building:
    # Remove previous output dirs
    if not os.path.exists(FLAGS.output_dir):
      os.makedirs(FLAGS.output_dir)
    else:
      libmace_dir = os.path.join(FLAGS.output_dir, "libmace")
      if os.path.exists(libmace_dir):
        shutil.rmtree(libmace_dir)

  if mode == "validate":
    FLAGS.round = 1

  target_abi = configs["target_abi"]
  libmace_name = get_libs(target_abi, configs)
  # Transfer params by environment
  os.environ["TARGET_ABI"] = target_abi
  os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  os.environ["PROJECT_NAME"] = os.path.splitext(FLAGS.config)[0]

  model_output_dirs = []
  for model_name in configs["models"]:
    # Transfer params by environment
    os.environ["MODEL_TAG"] = model_name
    model_config = configs["models"][model_name]
    for key in model_config:
      os.environ[key.upper()] = str(model_config[key])

    # Plain concatenation on purpose: os.path.join would discard the output
    # dir prefix if tf_model_file_path were absolute, and this directory is
    # rmtree'd below.
    model_output_dir = (
        FLAGS.output_dir + "/" + target_abi + "/" +
        os.path.splitext(model_config["tf_model_file_path"])[0])
    model_output_dirs.append(model_output_dir)

    if building:
      if os.path.exists(model_output_dir):
        shutil.rmtree(model_output_dir)
      os.makedirs(model_output_dir)
      clear_env()

    if mode in ("build", "run", "validate", "all"):
      generate_random_input(model_output_dir)

    if building:
      generate_model_code()
      build_mace_run_prod(model_output_dir, FLAGS.tuning, libmace_name)

    if mode in ("run", "validate", "all"):
      run_model(model_output_dir, FLAGS.round)

    if mode == "benchmark":
      benchmark_model(model_output_dir)

    if mode in ("validate", "all"):
      validate_model(model_output_dir)

  if mode in ("build", "merge", "all"):
    merge_libs_and_tuning_results(FLAGS.output_dir + "/" + target_abi,
                                  model_output_dirs)
235 236


Y
yejianwu 已提交
237
if __name__ == "__main__":
  # FLAGS is deliberately module-global: parse_model_configs() and main()
  # read it directly.
  FLAGS, unparsed = parse_args()
  main(unused_args=[sys.argv[0]] + unparsed)