mace_tools.py 10.9 KB
Newer Older
1 2
#!/usr/bin/env python

3
# Must be run from the root dir of the libmace project, e.g.:
# python tools/mace_tools.py \
#     --config=tools/example.yaml \
#     --round=100 \
#     --mode=all

import argparse
10
import hashlib
11 12 13 14
import os
import shutil
import subprocess
import sys
15
import urllib
Y
yejianwu 已提交
16
import yaml
17 18 19

from ConfigParser import ConfigParser

Y
yejianwu 已提交
20

21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
def run_command(command):
  """Execute ``command`` through the shell and echo its captured output.

  Args:
    command: Shell command line to run (executed with ``shell=True``).

  Raises:
    Exception: if the command exits with a non-zero status.
  """
  print("Run command: {}".format(command))
  proc = subprocess.Popen(
      command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
  stdout_data, stderr_data = proc.communicate()

  # Only echo streams that actually produced something.
  for template, payload in (("Stdout msg:\n{}", stdout_data),
                            ("Stderr msg:\n{}", stderr_data)):
    if payload:
      print(template.format(payload))

  exit_code = proc.returncode
  if exit_code:
    raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
        exit_code, command))


Y
yejianwu 已提交
37
def get_global_runtime(configs):
  """Pick the single runtime the whole build must target.

  The ``runtime`` of every model in ``configs["models"]`` is lower-cased and
  the highest-priority one present wins: dsp, then gpu, then cpu.

  Args:
    configs: Parsed config mapping with a ``"models"`` mapping whose values
      each carry a ``"runtime"`` string.

  Returns:
    One of ``"dsp"``, ``"gpu"`` or ``"cpu"``.

  Raises:
    Exception: if no model uses any of those runtimes.
  """
  models = configs["models"]
  seen = set(models[name]["runtime"].lower() for name in models)

  for candidate in ("dsp", "gpu", "cpu"):
    if candidate in seen:
      return candidate
  raise Exception("Not found available RUNTIME in config files!")
54 55


Y
yejianwu 已提交
56 57 58
def generate_opencl_and_version_code():
  """Run tools/generate_opencl_and_version_code.sh from the project root."""
  run_command("bash tools/generate_opencl_and_version_code.sh")
59

60

Y
yejianwu 已提交
61 62
def clear_env(target_soc):
  """Run tools/clear_env.sh for the given target SoC.

  Args:
    target_soc: Target SoC identifier, passed as the script's only argument.
  """
  run_command("bash tools/clear_env.sh {}".format(target_soc))


Y
yejianwu 已提交
66
def generate_random_input(target_soc, model_output_dir):
  """Invoke tools/validate_tools.sh in its generate-data mode.

  Passes ``1`` as the script's generate-data flag; ``validate_model`` calls
  the same script with ``0``.

  Args:
    target_soc: Target SoC identifier.
    model_output_dir: Output directory for the model being processed.
  """
  generate_flag = int(True)
  run_command("bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, generate_flag))


def generate_model_code():
  """Run tools/generate_model_code.sh.

  The script takes no arguments here; presumably it reads the per-model
  environment variables exported by ``main`` — confirm in the script.
  """
  run_command("bash tools/generate_model_code.sh")
76 77


78 79 80
def build_mace_run(production_mode, model_output_dir, hexagon_mode):
  """Build mace_run via tools/build_mace_run.sh.

  Args:
    production_mode: Truthy to request a production build (passed as 0/1).
    model_output_dir: Output directory for this model.
    hexagon_mode: Truthy when the global runtime is "dsp" (passed as 0/1).
  """
  script_args = (int(production_mode), model_output_dir, int(hexagon_mode))
  run_command("bash tools/build_mace_run.sh {} {} {}".format(*script_args))


Y
yejianwu 已提交
84 85 86 87 88 89 90
def tuning_run(target_soc,
               model_output_dir,
               running_round,
               tuning,
               production_mode,
               restart_round,
               option_args=''):
  """Run the model on the device through tools/tuning_run.sh.

  Args:
    target_soc: Target SoC identifier.
    model_output_dir: Output directory for this model.
    running_round: Number of running rounds.
    tuning: Truthy to enable OpenCL tuning (passed as 0/1).
    production_mode: Truthy for the production binary (passed as 0/1).
    restart_round: Number of restart rounds.
    option_args: Extra options, passed as one double-quoted argument.
  """
  script_args = (target_soc, model_output_dir, running_round, int(tuning),
                 int(production_mode), restart_round, option_args)
  run_command("bash tools/tuning_run.sh {} {} {} {} {} {} \"{}\"".format(
      *script_args))

Y
yejianwu 已提交
96

97 98 99
def benchmark_model(target_soc, model_output_dir, option_args=''):
  """Run tools/benchmark.sh for one model.

  Args:
    target_soc: Target SoC identifier.
    model_output_dir: Output directory for this model.
    option_args: Extra options, passed as one double-quoted argument.
  """
  template = "bash tools/benchmark.sh {} {} \"{}\""
  run_command(template.format(target_soc, model_output_dir, option_args))
101

Y
yejianwu 已提交
102

Y
yejianwu 已提交
103 104 105 106
def run_model(target_soc, model_output_dir, running_round, restart_round,
              option_args):
  """Run the model via ``tuning_run`` with tuning and production mode off."""
  tuning_run(
      target_soc,
      model_output_dir,
      running_round=running_round,
      tuning=False,
      production_mode=False,
      restart_round=restart_round,
      option_args=option_args)
107 108


Y
yejianwu 已提交
109
def generate_production_code(target_soc, model_output_dirs, pull_or_not):
  """Run tools/generate_production_code.sh over several model outputs.

  Args:
    target_soc: Target SoC identifier.
    model_output_dirs: Per-model output directories; the ``opencl_bin``
      subdirectory of each is passed as one comma-separated list.
    pull_or_not: Truthy to set the script's pull flag (passed as 0/1).
  """
  opencl_dirs = ",".join(
      [os.path.join(out_dir, "opencl_bin") for out_dir in model_output_dirs])
  run_command("bash tools/generate_production_code.sh {} {} {}".format(
      target_soc, opencl_dirs, int(pull_or_not)))


Y
yejianwu 已提交
119
def build_mace_run_prod(target_soc, model_output_dir, tuning, global_runtime):
  """Build mace_run, run a tuning pass, then rebuild in production mode.

  Args:
    target_soc: Target SoC identifier.
    model_output_dir: Output directory for this model.
    tuning: Whether the tuning pass tunes OpenCL params.
    global_runtime: Runtime chosen by ``get_global_runtime``; "dsp" turns on
      hexagon mode for both builds.
  """
  hexagon_mode = global_runtime == "dsp"

  # Stage 1: non-production build followed by a single tuning pass.
  build_mace_run(False, model_output_dir, hexagon_mode)
  tuning_run(
      target_soc,
      model_output_dir,
      running_round=0,
      tuning=tuning,
      production_mode=False,
      restart_round=1)

  # Stage 2: regenerate production code (pull flag set) and rebuild.
  generate_production_code(target_soc, [model_output_dir], True)
  build_mace_run(True, model_output_dir, hexagon_mode)
139 140


141 142 143 144
def build_run_throughput_test(target_soc, run_seconds, merged_lib_file,
                              model_input_dir):
  """Run tools/build_run_throughput_test.sh.

  Args:
    target_soc: Target SoC identifier.
    run_seconds: How long the throughput test runs, in seconds.
    merged_lib_file: Path to the merged static library to link against.
    model_input_dir: Directory passed to the script as the model input dir.
  """
  script_args = [target_soc, run_seconds, merged_lib_file, model_input_dir]
  run_command("bash tools/build_run_throughput_test.sh {} {} {} {}".format(
      *script_args))


Y
fix run  
yejianwu 已提交
148
def validate_model(target_soc, model_output_dir):
  """Invoke tools/validate_tools.sh in its validation (non-generate) mode.

  Passes ``0`` as the script's generate-data flag; ``generate_random_input``
  calls the same script with ``1``.

  Args:
    target_soc: Target SoC identifier.
    model_output_dir: Output directory for the model being validated.
  """
  run_command("bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, int(False)))


def build_production_code():
  """Run tools/build_production_code.sh (no arguments)."""
  run_command("bash tools/build_production_code.sh")


Y
yejianwu 已提交
160
def merge_libs_and_tuning_results(target_soc, output_dir, model_output_dirs):
  """Produce production code for all models, build it, and merge the libs.

  Args:
    target_soc: Target SoC identifier.
    output_dir: Destination directory for the merged output.
    model_output_dirs: Per-model output directories to merge.
  """
  # Pull flag is off here, unlike build_mace_run_prod.
  generate_production_code(target_soc, model_output_dirs, False)
  build_production_code()
  run_command("bash tools/merge_libs.sh {} {} {}".format(
      target_soc, output_dir, ",".join(model_output_dirs)))


def parse_model_configs(config_path=None):
  """Load the YAML model config.

  Args:
    config_path: Path of the YAML file. Defaults to ``FLAGS.config`` (the
      module-level flags set under ``__main__``), so existing callers that
      pass nothing keep working.

  Returns:
    The parsed YAML document. ``main`` reads keys such as ``"models"``,
    ``"target_socs"`` and ``"target_abis"`` from it.
  """
  if config_path is None:
    config_path = FLAGS.config
  with open(config_path) as f:
    # The config is plain data, so safe_load is sufficient. yaml.load without
    # an explicit Loader can construct arbitrary objects and is rejected
    # outright (TypeError) by PyYAML >= 6.
    return yaml.safe_load(f)
175 176 177 178 179 180 181


def parse_args(argv=None):
  """Parses command line arguments.

  Args:
    argv: Optional list of argument strings, excluding the program name.
      When None (the default), argparse reads ``sys.argv[1:]``.

  Returns:
    A ``(flags, unknown_args)`` tuple as returned by
    ``ArgumentParser.parse_known_args``. ``main`` forwards the ``--``-prefixed
    unknown arguments to the run/benchmark scripts as ``option_args``.
  """
  parser = argparse.ArgumentParser()
  # Custom "bool" type: only the (case-insensitive) string "true" is True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--config",
      type=str,
      # NOTE(review): the usage example points at tools/*.yaml, while this
      # default says "./tool/config"; confirm the intended default path.
      default="./tool/config",
      help="The global config file of models.")
  parser.add_argument(
      "--output_dir", type=str, default="build", help="The output dir.")
  parser.add_argument(
      "--round", type=int, default=1, help="The model running round.")
  parser.add_argument(
      "--run_seconds",
      type=int,
      default=10,
      help="The model throughput test running seconds.")
  parser.add_argument(
      "--restart_round", type=int, default=1, help="The model restart round.")
  # argparse runs string defaults through `type`, so the effective default
  # here is the boolean True.
  parser.add_argument(
      "--tuning", type="bool", default="true", help="Tune opencl params.")
  parser.add_argument(
      "--mode",
      type=str,
      default="all",
      # "benchmark" is a mode handled by main() and must be listed here too.
      help="[build|run|validate|merge|all|benchmark|throughput_test].")
  return parser.parse_known_args(argv)


def main(unused_args):
  """Drive the build/run/validate/merge/benchmark/throughput_test pipeline.

  Loads the YAML config named by ``FLAGS.config`` and, for every
  (target_soc, target_abi) pair, processes each configured model according
  to ``FLAGS.mode``. Parameters reach the bash tools through environment
  variables.

  Args:
    unused_args: Arguments argparse did not recognise, with ``sys.argv[0]``
      first. The ``--``-prefixed ones are forwarded to the run/benchmark
      scripts as ``option_args``.
  """
  configs = parse_model_configs()

  if FLAGS.mode == "validate":
    FLAGS.round = 1
    FLAGS.restart_round = 1

  os.environ["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  os.environ["VLOG_LEVEL"] = str(configs["vlog_level"])
  os.environ["PROJECT_NAME"] = os.path.splitext(os.path.basename(
      FLAGS.config))[0]
  project_output_dir = os.path.join(FLAGS.output_dir,
                                    os.environ["PROJECT_NAME"])

  if FLAGS.mode in ("build", "all"):
    # Remove previous output dirs
    if not os.path.exists(FLAGS.output_dir):
      os.makedirs(FLAGS.output_dir)
    elif os.path.exists(project_output_dir):
      # Check the directory that is actually removed. The old check looked at
      # a hard-coded "libmace" dir, so rmtree failed when the project dir was
      # missing and stale project output was never cleaned otherwise.
      shutil.rmtree(project_output_dir)
      os.makedirs(project_output_dir)

  generate_opencl_and_version_code()
  option_args = ' '.join([arg for arg in unused_args if arg.startswith('--')])

  # Depends only on configs, so compute it once rather than per (soc, abi).
  global_runtime = get_global_runtime(configs)

  for target_soc in configs["target_socs"]:
    for target_abi in configs["target_abis"]:
      # Transfer params by environment
      os.environ["TARGET_ABI"] = target_abi
      model_output_dirs = []
      for model_name in configs["models"]:
        # Transfer params by environment
        os.environ["MODEL_TAG"] = model_name
        print("======================= %s =======================" %
              model_name)
        model_config = configs["models"][model_name]
        skip_validation = model_config["skip_validation"]

        # Export every model setting as an upper-cased env var; node lists
        # are joined with "," and shape lists with ":".
        for key in model_config:
          value = model_config[key]
          if key in ['input_nodes', 'output_nodes'] and isinstance(value,
                                                                   list):
            os.environ[key.upper()] = ",".join(value)
          elif key in ['input_shapes', 'output_shapes'] and isinstance(
              value, list):
            os.environ[key.upper()] = ":".join(value)
          else:
            os.environ[key.upper()] = str(value)

        model_file_path = model_config["model_file_path"]
        # Digest of the configured model path keys the per-model output dir.
        model_path_digest = hashlib.md5(model_file_path).hexdigest()
        model_output_dir = "%s/%s/%s/%s/%s" % (FLAGS.output_dir, model_name,
                                               model_path_digest, target_soc,
                                               target_abi)
        model_output_dirs.append(model_output_dir)

        if FLAGS.mode in ("build", "all"):
          if os.path.exists(model_output_dir):
            shutil.rmtree(model_output_dir)
          os.makedirs(model_output_dir)
          clear_env(target_soc)

        # Support http:// and https://
        if model_file_path.startswith(("http://", "https://")):
          os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb"
          urllib.urlretrieve(model_file_path, os.environ["MODEL_FILE_PATH"])

        if model_config["platform"] == "caffe" and (
            model_config["weight_file_path"].startswith(("http://",
                                                         "https://"))):
          os.environ[
              "WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
          urllib.urlretrieve(model_config["weight_file_path"],
                             os.environ["WEIGHT_FILE_PATH"])

        if FLAGS.mode in ("build", "run", "validate", "benchmark", "all"):
          generate_random_input(target_soc, model_output_dir)

        if FLAGS.mode in ("build", "all"):
          generate_model_code()
          build_mace_run_prod(target_soc, model_output_dir, FLAGS.tuning,
                              global_runtime)

        if FLAGS.mode in ("run", "validate", "all"):
          run_model(target_soc, model_output_dir, FLAGS.round,
                    FLAGS.restart_round, option_args)

        if FLAGS.mode == "benchmark":
          benchmark_model(target_soc, model_output_dir, option_args)

        if FLAGS.mode == "validate" or (FLAGS.mode == "all" and
                                        skip_validation == 0):
          validate_model(target_soc, model_output_dir)

      if FLAGS.mode in ("build", "merge", "all"):
        merge_libs_and_tuning_results(target_soc, project_output_dir,
                                      model_output_dirs)

      if FLAGS.mode == "throughput_test":
        merged_lib_file = FLAGS.output_dir + "/%s/%s/libmace_%s.%s.a" % (
            os.environ["PROJECT_NAME"], target_abi,
            os.environ["PROJECT_NAME"], target_soc)
        generate_random_input(target_soc, FLAGS.output_dir)
        # Publish each model's tag under "<RUNTIME>_MODEL_TAG".
        for model_name in configs["models"]:
          runtime = configs["models"][model_name]["runtime"]
          os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
        build_run_throughput_test(target_soc, FLAGS.run_seconds,
                                  merged_lib_file, FLAGS.output_dir)
315

316

Y
yejianwu 已提交
317
if __name__ == "__main__":
  # FLAGS must be a module-level global: main() and parse_model_configs()
  # read it directly.
  FLAGS, unparsed = parse_args()
  forwarded = [sys.argv[0]]
  forwarded.extend(unparsed)
  main(unused_args=forwarded)