#!/usr/bin/env python

# Must be run from the root directory of the libmace project, e.g.:
# python tools/mace_tools.py \
#     --config=tools/example.yaml \
#     --round=100 \
#     --mode=all

import argparse
import hashlib
import os
import re
import shutil
import subprocess
import sys
import urllib
from ConfigParser import ConfigParser

import yaml

import adb_tools


def run_command(command):
  """Run a shell command, echo its output, and fail on a non-zero exit.

  Args:
    command: Command line string, executed through the shell.

  Raises:
    Exception: If the command exits with a non-zero status.
  """
  print("Run command: {}".format(command))
  proc = subprocess.Popen(command,
                          shell=True,
                          stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
  stdout_data, stderr_data = proc.communicate()

  # Echo whichever captured streams are non-empty, stdout first.
  for label, data in (("Stdout", stdout_data), ("Stderr", stderr_data)):
    if data:
      print("{} msg:\n{}".format(label, data))

  status = proc.returncode
  if status != 0:
    raise Exception("Exit not 0 from bash with code: {}, command: {}".format(
        status, command))


def get_global_runtime(configs):
  """Pick the single runtime the whole build must support.

  Collects the lower-cased "runtime" of every entry in configs["models"] and
  returns the highest-priority one present: "dsp", then "gpu", then "cpu".

  Raises:
    Exception: If none of the models uses one of those runtimes.
  """
  models = configs["models"]
  runtimes = set(models[name]["runtime"].lower() for name in models)

  for candidate in ("dsp", "gpu", "cpu"):
    if candidate in runtimes:
      return candidate
  raise Exception("Not found available RUNTIME in config files!")


def generate_opencl_and_version_code():
  """Run tools/generate_opencl_and_version_code.sh."""
  run_command("bash tools/generate_opencl_and_version_code.sh")


def clear_env(target_soc):
  """Run tools/clear_env.sh for the given target SoC."""
  run_command("bash tools/clear_env.sh {}".format(target_soc))


def input_file_name(input_name):
  """Return the file name used for the model input called `input_name`.

  The result is $INPUT_FILE_NAME, an underscore, then `input_name` with every
  run of characters outside [0-9a-zA-Z] collapsed to a single underscore.
  """
  prefix = os.environ['INPUT_FILE_NAME']
  suffix = re.sub('[^0-9a-zA-Z]+', '_', input_name)
  return prefix + '_' + suffix


def generate_random_input(target_soc, model_output_dir,
                          input_names, input_files):
  """Generate random model inputs, then overlay any user-supplied input files.

  Runs tools/validate_tools.sh with its generate-data flag set to 1. After
  that, each (input name, input file) pair is placed at
  `model_output_dir/<input_file_name(name)>`. Files starting with http:// or
  https:// are downloaded; anything else is copied as a local path. Pairs
  whose file is None are skipped.

  Args:
    target_soc: Target SoC passed to the script.
    model_output_dir: Directory that receives the input files.
    input_names: One input name, or a list of them.
    input_files: None, one file path/URL, or a list of them (list entries
      may be None). When non-empty it must pair one-to-one with input_names.

  Raises:
    Exception: If input files are given and their count differs from the
      number of input names (or if the script fails, via run_command).
  """
  generate_data_or_not = True
  command = "bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, int(generate_data_or_not))
  run_command(command)

  # A YAML key with no value (e.g. a bare `input_files:`) loads as None.
  # Treat that as "no files" rather than a one-element [None] list, which
  # could spuriously fail the length check below.
  if input_files is None:
    input_file_list = []
  elif isinstance(input_files, list):
    input_file_list = list(input_files)
  else:
    input_file_list = [input_files]
  if not input_file_list:
    return

  if isinstance(input_names, list):
    input_name_list = list(input_names)
  else:
    input_name_list = [input_names]
  if len(input_file_list) != len(input_name_list):
    raise Exception(
        'If input_files set, the input files should match the input names.')

  for input_name, input_file in zip(input_name_list, input_file_list):
    if input_file is None:
      continue
    dst_input_file = model_output_dir + '/' + input_file_name(input_name)
    if input_file.startswith(("http://", "https://")):
      # urllib.urlretrieve is the Python 2 API, matching the rest of the file.
      urllib.urlretrieve(input_file, dst_input_file)
    else:
      print('Copy input data: %s' % dst_input_file)
      shutil.copy(input_file, dst_input_file)


def generate_model_code():
  """Run tools/generate_model_code.sh."""
  script = "tools/generate_model_code.sh"
  run_command("bash " + script)


def build_mace_run(production_mode, model_output_dir, hexagon_mode):
  """Run tools/build_mace_run.sh.

  Args:
    production_mode: Production-mode flag, passed to the script via int().
    model_output_dir: Model output directory passed to the script.
    hexagon_mode: Hexagon (DSP) flag, passed to the script via int().
  """
  script_args = (int(production_mode), model_output_dir, int(hexagon_mode))
  run_command("bash tools/build_mace_run.sh {} {} {}".format(*script_args))


def tuning_run(target_soc,
               model_output_dir,
               running_round,
               tuning,
               production_mode,
               restart_round,
               option_args=''):
  """Run tools/tuning_run.sh for one model on the target SoC.

  Args:
    target_soc: Target SoC passed to the script.
    model_output_dir: Model output directory passed to the script.
    running_round: Number of running rounds.
    tuning: Tuning flag, passed to the script via int().
    production_mode: Production-mode flag, passed to the script via int().
    restart_round: Number of restart rounds.
    option_args: Extra options, passed as one double-quoted argument.
  """
  positional = [
      target_soc,
      model_output_dir,
      running_round,
      int(tuning),
      int(production_mode),
      restart_round,
  ]
  template = "bash tools/tuning_run.sh {} {} {} {} {} {} \"{}\""
  run_command(template.format(*(positional + [option_args])))


def benchmark_model(target_soc, model_output_dir, option_args=''):
  """Run tools/benchmark.sh for one model.

  option_args is passed to the script as a single double-quoted argument.
  """
  template = 'bash tools/benchmark.sh {} {} "{}"'
  run_command(template.format(target_soc, model_output_dir, option_args))


def run_model(target_soc, model_output_dir, running_round, restart_round,
              option_args):
  """Run the model via tuning_run with tuning and production mode off."""
  tuning_run(target_soc,
             model_output_dir,
             running_round,
             tuning=False,
             production_mode=False,
             restart_round=restart_round,
             option_args=option_args)


def generate_production_code(target_soc, model_output_dirs, pull_or_not):
  """Run tools/generate_production_code.sh over several model output dirs.

  The script receives the comma-joined `<dir>/opencl_bin` paths of every
  entry in model_output_dirs, plus pull_or_not converted via int().
  """
  cl_bin_dirs_str = ",".join(
      os.path.join(output_dir, "opencl_bin") for output_dir in model_output_dirs)
  run_command("bash tools/generate_production_code.sh {} {} {}".format(
      target_soc, cl_bin_dirs_str, int(pull_or_not)))


def build_mace_run_prod(target_soc, model_output_dir, tuning, global_runtime):
  """Build a production mace_run for one model.

  Steps:
    1. Build a non-production mace_run and run it once through tuning_run
       (running_round=0, restart_round=1, tuning as requested).
    2. Call generate_production_code for this model with pull_or_not=True.
    3. Rebuild mace_run in production mode.
  Hexagon mode is enabled only when global_runtime is "dsp".
  """
  hexagon_mode = global_runtime == "dsp"

  build_mace_run(False, model_output_dir, hexagon_mode)
  tuning_run(target_soc,
             model_output_dir,
             running_round=0,
             tuning=tuning,
             production_mode=False,
             restart_round=1)

  generate_production_code(target_soc, [model_output_dir], True)
  build_mace_run(True, model_output_dir, hexagon_mode)


def build_run_throughput_test(target_soc, run_seconds, merged_lib_file,
                              model_input_dir):
  """Run tools/build_run_throughput_test.sh with the given parameters."""
  script = "tools/build_run_throughput_test.sh"
  run_command("bash {} {} {} {} {}".format(
      script, target_soc, run_seconds, merged_lib_file, model_input_dir))


def validate_model(target_soc, model_output_dir):
  """Run tools/validate_tools.sh with its generate-data flag set to 0."""
  generate_data = False
  run_command("bash tools/validate_tools.sh {} {} {}".format(
      target_soc, model_output_dir, int(generate_data)))


def build_production_code():
  """Run tools/build_production_code.sh."""
  script = "tools/build_production_code.sh"
  run_command("bash " + script)


def merge_libs_and_tuning_results(target_soc, output_dir, model_output_dirs):
  """Merge the per-model libraries and tuning results for one SoC.

  Calls generate_production_code with pull_or_not=False, builds the
  production code, then runs tools/merge_libs.sh with the comma-joined
  model output dirs.
  """
  generate_production_code(target_soc, model_output_dirs, False)
  build_production_code()

  joined_dirs = ",".join(model_output_dirs)
  run_command("bash tools/merge_libs.sh {} {} {}".format(
      target_soc, output_dir, joined_dirs))


def packaging_lib_file(output_dir):
  """Run tools/packaging_lib.sh on output_dir."""
  run_command("bash tools/packaging_lib.sh {}".format(output_dir))


def parse_model_configs():
  """Load the YAML model config file named by FLAGS.config.

  Returns:
    The parsed YAML document.
  """
  with open(FLAGS.config) as f:
    # safe_load: the config only needs plain YAML types. yaml.load() without
    # an explicit Loader can construct arbitrary Python objects, is
    # deprecated since PyYAML 5.1 and is a TypeError in PyYAML >= 6.
    return yaml.safe_load(f)


def parse_args():
  """Parses command line arguments.

  Returns:
    A (namespace, unknown_args) tuple from ArgumentParser.parse_known_args().
    Options not declared here are returned in unknown_args rather than
    causing an error.
  """
  parser = argparse.ArgumentParser()
  # "bool" type: only the string "true" (case-insensitive) maps to True.
  parser.register("type", "bool", lambda v: v.lower() == "true")
  parser.add_argument(
      "--config",
      type=str,
      # NOTE(review): the example config lives under tools/ (see the file
      # header), not tool/; this default looks stale -- confirm.
      default="./tool/config",
      help="The global config file of models.")
  parser.add_argument(
      "--output_dir", type=str, default="build", help="The output dir.")
  parser.add_argument(
      "--round", type=int, default=1, help="The model running round.")
  parser.add_argument(
      "--run_seconds",
      type=int,
      default=10,
      help="The model throughput test running seconds.")
  parser.add_argument(
      "--restart_round", type=int, default=1, help="The model restart round.")
  # String default: argparse runs it through the "bool" type, giving True.
  parser.add_argument(
      "--tuning", type="bool", default="true", help="Tune opencl params.")
  parser.add_argument(
      "--mode",
      type=str,
      default="all",
      # Keep in sync with the modes main() dispatches on.
      help="[build|run|validate|merge|benchmark|all|throughput_test].")
  parser.add_argument(
      "--socs",
      type=str,
      default="all",
      help="SoCs to build, comma separated list (getprop ro.board.platform)")
  return parser.parse_known_args()


def set_environment(configs):
  """Export global settings as environment variables for the helper scripts.

  Sets EMBED_MODEL_DATA and VLOG_LEVEL from configs, PROJECT_NAME to the base
  name of FLAGS.config without its extension, and fixed INPUT_FILE_NAME /
  OUTPUT_FILE_NAME prefixes.
  """
  env = os.environ
  env["EMBED_MODEL_DATA"] = str(configs["embed_model_data"])
  env["VLOG_LEVEL"] = str(configs["vlog_level"])
  config_file_name = os.path.basename(FLAGS.config)
  env["PROJECT_NAME"] = os.path.splitext(config_file_name)[0]
  env['INPUT_FILE_NAME'] = "model_input"
  env['OUTPUT_FILE_NAME'] = "model_out"


def _prepare_output_dir():
  """Create FLAGS.output_dir, or reset this project's directory inside it."""
  project_dir = os.path.join(FLAGS.output_dir, os.environ["PROJECT_NAME"])
  if not os.path.exists(FLAGS.output_dir):
    os.makedirs(FLAGS.output_dir)
  elif os.path.exists(os.path.join(FLAGS.output_dir, "libmace")):
    # NOTE(review): the guard tests "libmace" but the directory that gets
    # reset is the project directory; kept as-is. The project directory may
    # not exist yet, so only remove it when present (rmtree raises otherwise).
    if os.path.exists(project_dir):
      shutil.rmtree(project_dir)
    os.makedirs(project_dir)


def _select_target_socs(configs):
  """Return the set of SoCs to process.

  Starts from the SoCs reported by adb_tools.adb_get_all_socs(), then
  intersects with configs["target_socs"] when that key is set, and with
  --socs unless it is "all". Exits with status 1 if a SoC requested through
  --socs is not in the resulting set.
  """
  target_socs = set(adb_tools.adb_get_all_socs())

  # configs is a dict, so the key has to be looked up; hasattr() never
  # finds dict keys.
  config_socs = configs.get("target_socs")
  if config_socs:
    if not isinstance(config_socs, list):
      config_socs = [config_socs]
    target_socs &= set(config_socs)

  if FLAGS.socs != "all":
    socs = set(FLAGS.socs.split(','))
    target_socs &= socs
    missing_socs = socs.difference(target_socs)
    if missing_socs:
      print("Error: devices with SoCs are not connected %s" % missing_socs)
      sys.exit(1)
  return target_socs


def _export_model_config(model_config):
  """Export every model_config entry as an upper-cased environment variable.

  Node-name lists are joined with ",", shape lists with ":"; any other
  value is converted with str().
  """
  for key, value in model_config.items():
    if key in ('input_nodes', 'output_nodes') and isinstance(value, list):
      os.environ[key.upper()] = ",".join(value)
    elif key in ('input_shapes', 'output_shapes') and isinstance(value, list):
      os.environ[key.upper()] = ":".join(value)
    else:
      os.environ[key.upper()] = str(value)


def _is_url(path):
  """Return True if `path` starts with http:// or https://."""
  return path.startswith(("http://", "https://"))


def _download_model_files(model_config, model_output_dir):
  """Download remote model/weight files and repoint the env vars to them.

  A URL model_file_path is saved as model_output_dir/model.pb and
  MODEL_FILE_PATH is updated. For caffe models, a URL weight_file_path is
  saved as model_output_dir/model.caffemodel and WEIGHT_FILE_PATH is updated.
  """
  if _is_url(model_config["model_file_path"]):
    os.environ["MODEL_FILE_PATH"] = model_output_dir + "/model.pb"
    urllib.urlretrieve(model_config["model_file_path"],
                       os.environ["MODEL_FILE_PATH"])

  if model_config["platform"] == "caffe" and _is_url(
      model_config["weight_file_path"]):
    os.environ["WEIGHT_FILE_PATH"] = model_output_dir + "/model.caffemodel"
    urllib.urlretrieve(model_config["weight_file_path"],
                       os.environ["WEIGHT_FILE_PATH"])


def main(unused_args):
  """Build, run, validate, benchmark, merge and/or throughput-test models.

  FLAGS.mode selects which stages run. For each selected SoC and each ABI in
  configs["target_abis"], every model in configs["models"] gets its own
  output directory. Per-model settings reach the helper scripts through
  environment variables.

  Args:
    unused_args: argv[0] plus the arguments parse_args() did not recognise;
      those starting with "--" are forwarded to the run/benchmark scripts.
  """
  configs = parse_model_configs()

  if FLAGS.mode == "validate":
    FLAGS.round = 1
    FLAGS.restart_round = 1

  set_environment(configs)

  if FLAGS.mode in ("build", "all"):
    # Remove previous output dirs
    _prepare_output_dir()

  generate_opencl_and_version_code()
  option_args = ' '.join([arg for arg in unused_args if arg.startswith('--')])

  target_socs = _select_target_socs(configs)
  # Depends only on configs, so compute it once instead of once per ABI.
  global_runtime = get_global_runtime(configs)

  for target_soc in target_socs:
    for target_abi in configs["target_abis"]:
      # Transfer params by environment
      os.environ["TARGET_ABI"] = target_abi
      model_output_dirs = []
      for model_name in configs["models"]:
        # Transfer params by environment
        os.environ["MODEL_TAG"] = model_name
        print('======================= %s =======================' %
              model_name)
        model_config = configs["models"][model_name]
        skip_validation = model_config.get("skip_validation", 0)
        input_file_list = model_config.get("input_files", [])
        _export_model_config(model_config)

        md5 = hashlib.md5()
        md5.update(model_config["model_file_path"])
        model_path_digest = md5.hexdigest()
        model_output_dir = "%s/%s/%s/%s/%s/%s/%s" % (
            FLAGS.output_dir, os.environ["PROJECT_NAME"], "build", model_name,
            model_path_digest, target_soc, target_abi)
        model_output_dirs.append(model_output_dir)

        if FLAGS.mode in ("build", "all"):
          if os.path.exists(model_output_dir):
            shutil.rmtree(model_output_dir)
          os.makedirs(model_output_dir)
          clear_env(target_soc)

        # Support http:// and https://
        _download_model_files(model_config, model_output_dir)

        if FLAGS.mode in ("build", "run", "validate", "benchmark", "all"):
          generate_random_input(target_soc, model_output_dir,
                                model_config['input_nodes'], input_file_list)

        if FLAGS.mode in ("build", "all"):
          generate_model_code()
          build_mace_run_prod(target_soc, model_output_dir, FLAGS.tuning,
                              global_runtime)

        if FLAGS.mode in ("run", "validate", "all"):
          run_model(target_soc, model_output_dir, FLAGS.round,
                    FLAGS.restart_round, option_args)

        if FLAGS.mode == "benchmark":
          benchmark_model(target_soc, model_output_dir, option_args)

        if FLAGS.mode == "validate" or (FLAGS.mode == "all" and
                                        skip_validation == 0):
          validate_model(target_soc, model_output_dir)

      if FLAGS.mode in ("build", "merge", "all"):
        merge_libs_and_tuning_results(
            target_soc, FLAGS.output_dir + "/" + os.environ["PROJECT_NAME"],
            model_output_dirs)

      if FLAGS.mode == "throughput_test":
        merged_lib_file = FLAGS.output_dir + "/%s/%s/libmace_%s.%s.a" % (
            os.environ["PROJECT_NAME"], target_abi,
            os.environ["PROJECT_NAME"], target_soc)
        generate_random_input(target_soc, FLAGS.output_dir, [], [])
        # Export one <RUNTIME>_MODEL_TAG per runtime; a later model with the
        # same runtime overwrites an earlier one.
        for model_name in configs["models"]:
          runtime = configs["models"][model_name]["runtime"]
          os.environ["%s_MODEL_TAG" % runtime.upper()] = model_name
        build_run_throughput_test(target_soc, FLAGS.run_seconds,
                                  merged_lib_file, FLAGS.output_dir)

  if FLAGS.mode in ("build", "all"):
    packaging_lib_file(FLAGS.output_dir)


if __name__ == "__main__":
  # FLAGS is bound at module level so main(), parse_model_configs() and
  # set_environment() can read it as a global.
  FLAGS, unparsed = parse_args()
  program_name = sys.argv[0]
  main(unused_args=[program_name] + unparsed)