# Copyright 2019 The MACE Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import inspect
import hashlib
import filelock
import errno
import os
import sys
import shutil
import traceback


################################
# log
################################
class CMDColors:
    """ANSI escape sequences used to colorize terminal output."""
    PURPLE = '\033[95m'
    BLUE = '\033[94m'
    GREEN = '\033[92m'
    YELLOW = '\033[93m'
    RED = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'


def get_frame_info(level=2):
    """Return a ``"<filename>:<lineno>: "`` prefix for a frame on the stack.

    ``level`` indexes ``inspect.stack()`` counted from this function's own
    frame (index 0), so the default of 2 is the caller of whichever
    function called ``get_frame_info``.
    """
    caller_frame = inspect.stack()[level]
    info = inspect.getframeinfo(caller_frame[0])
    return info.filename + ':' + str(info.lineno) + ': '


class MaceLogger:
    """Static helpers that print (optionally colorized) messages to stdout."""

    @staticmethod
    def header(message):
        """Print ``message`` in purple."""
        print(CMDColors.PURPLE + str(message) + CMDColors.ENDC)

    @staticmethod
    def summary(message):
        """Print ``message`` in green."""
        print(CMDColors.GREEN + str(message) + CMDColors.ENDC)

    @staticmethod
    def info(message):
        """Print ``message`` prefixed with the caller's file and line."""
        print(get_frame_info() + str(message))

    @staticmethod
    def warning(message):
        """Print a yellow warning prefixed with the caller's file and line."""
        print(CMDColors.YELLOW + 'WARNING: ' + get_frame_info()
              + str(message) + CMDColors.ENDC)

    @staticmethod
    def error(message, level=2):
        """Print a red error message and terminate with exit status 1.

        ``level`` is forwarded to ``get_frame_info`` so wrappers can report
        their own caller's location. Raises ``SystemExit``.
        """
        print(CMDColors.RED + 'ERROR: ' + get_frame_info(level)
              + str(message) + CMDColors.ENDC)
        # sys.exit instead of the builtin exit(): the latter is injected by
        # the `site` module and is not guaranteed to exist (e.g. python -S).
        sys.exit(1)


def mace_check(condition, message):
    """Abort via ``MaceLogger.error`` when ``condition`` is falsy.

    The current stack is printed first; ``level=3`` makes the reported
    location the caller of ``mace_check`` rather than this function.
    """
    if not condition:
        for line in traceback.format_stack():
            print(line.strip())
        MaceLogger.error(message, level=3)


################################
# String Formatter
################################
class StringFormatter:
    """Helpers that render plain-text tables and banners."""

    @staticmethod
    def table(header, data, title, align="R"):
        """Render ``data`` as a text table and return it as a string.

        ``header`` is a sequence of column names, ``data`` a sequence of
        rows with exactly one element per column (checked with ``assert``),
        and ``title`` is centered above the header row. ``align`` selects
        the cell alignment: "R" (right), "L" (left) or "C" (center); any
        other value produces rows without cells.
        """
        column_size = len(header)
        # Each column is as wide as its longest cell plus one padding char.
        column_length = [len(str(ele)) + 1 for ele in header]
        for data_tuple in data:
            assert (len(data_tuple) == column_size)
            for i, ele in enumerate(data_tuple):
                column_length[i] = max(column_length[i], len(str(ele)) + 1)

        # One '|' between columns plus one on each side of the row.
        table_column_length = sum(column_length) + column_size + 1
        dash_line = '-' * table_column_length + '\n'
        header_line = '=' * table_column_length + '\n'

        pieces = [
            dash_line,
            str(title).center(table_column_length) + '\n',
            dash_line,
            '|' + '|'.join(str(name).center(width)
                           for name, width in zip(header, column_length))
            + '|\n',
            header_line,
        ]
        for data_tuple in data:
            cells = zip(data_tuple, column_length)
            if align == "R":
                row_list = [str(ele).rjust(width) for ele, width in cells]
            elif align == "L":
                row_list = [str(ele).ljust(width) for ele, width in cells]
            elif align == "C":
                row_list = [str(ele).center(width) for ele, width in cells]
            else:
                row_list = []
            pieces.append('|' + '|'.join(row_list) + "|\n" + dash_line)
        return "".join(pieces)

    @staticmethod
    def block(message):
        """Return ``message`` centered between two lines of asterisks."""
        text = str(message)
        line_length = 10 + len(text) + 10
        star_line = '*' * line_length + '\n'
        return star_line + text.center(line_length) + '\n' + star_line


def formatted_file_name(input_file_name, input_name):
    """Return ``input_file_name + '_' + input_name`` made filename-safe.

    Every non-alphanumeric character of ``input_name`` is replaced by '_'.
    """
    sanitized = ''.join(c if c.isalnum() else '_' for c in input_name)
    return input_file_name + '_' + sanitized


################################
# file
################################
def file_checksum(fname):
    """Return the hex SHA-256 digest of file ``fname``, read in 4 KiB chunks."""
    hash_func = hashlib.sha256()
    with open(fname, "rb") as f:
        for chunk in iter(lambda: f.read(4096), b""):
            hash_func.update(chunk)
    return hash_func.hexdigest()


def download_or_get_file(file, sha256_checksum, output_file):
    """Place ``file`` at ``output_file`` and return ``output_file``.

    An ``http://`` / ``https://`` source is downloaded unless
    ``output_file`` already exists with a matching SHA-256; any other
    source is treated as a local path and copied. When ``sha256_checksum``
    is truthy the result is verified with ``mace_check``, which exits the
    process on mismatch.
    """
    if file.startswith(("http://", "https://")):
        if not os.path.exists(output_file) or file_checksum(
                output_file) != sha256_checksum:
            MaceLogger.info("Downloading file %s to %s, please wait ..."
                            % (file, output_file))
            if sys.version_info >= (3, 0):
                import urllib.request
                response = urllib.request.urlopen(file)
                # Close both the response and the output file even when
                # reading or writing fails part-way through.
                try:
                    with open(output_file, "wb") as out_handle:
                        out_handle.write(response.read())
                finally:
                    response.close()
            else:
                import urllib
                urllib.urlretrieve(file, output_file)
            MaceLogger.info("Model downloaded successfully.")
    else:
        shutil.copyfile(file, output_file)

    if sha256_checksum:
        mace_check(file_checksum(output_file) == sha256_checksum,
                   "checksum validate failed")

    return output_file


def download_or_get_model(file, sha256_checksum, output_dir):
    """Fetch ``file`` into ``output_dir`` and return the resulting path.

    The target is named ``<basename of file>-<sha256_checksum>.pb``.
    """
    filename = os.path.basename(file)
    output_file = "%s/%s-%s.pb" % (output_dir, filename, sha256_checksum)
    download_or_get_file(file, sha256_checksum, output_file)
    return output_file


################################
# bazel commands
################################
class ABIType(object):
    """String constants naming the supported target ABIs."""
    armeabi_v7a = 'armeabi-v7a'
    arm64_v8a = 'arm64-v8a'
    arm64 = 'arm64'
    aarch64 = 'aarch64'
    armhf = 'armhf'
    host = 'host'


def abi_to_internal(abi):
    """Map ``abi`` to the ABI name used internally.

    ``armeabi-v7a`` and ``arm64-v8a`` are returned unchanged, ``arm64``
    becomes ``aarch64`` and ``armhf`` becomes ``armeabi-v7a``. Every other
    value (including ``host`` and ``aarch64``) yields ``None``.
    """
    if abi in [ABIType.armeabi_v7a, ABIType.arm64_v8a]:
        return abi
    if abi == ABIType.arm64:
        return ABIType.aarch64
    if abi == ABIType.armhf:
        return ABIType.armeabi_v7a
    return None


################################
# lock
################################
def device_lock(device_id, timeout=7200):
    """Return an (unacquired) ``filelock.FileLock`` for ``device_id``.

    The lock file lives at ``/tmp/device-lock-<device_id>`` with every '/'
    stripped from the id; ``timeout`` is passed through to ``FileLock``.
    """
    return filelock.FileLock(
        "/tmp/device-lock-%s" % device_id.replace("/", ""),
        timeout=timeout)


def is_device_locked(device_id):
    """Return True if the device lock cannot be acquired right now.

    The lock is attempted with a near-zero timeout and released again
    immediately when the attempt succeeds.
    """
    try:
        with device_lock(device_id, timeout=0.000001):
            return False
    except filelock.Timeout:
        return True


################################
# os
################################
def mkdir_p(path):
    """Create ``path`` and missing parents; an existing directory is fine.

    Any other ``OSError`` (including ``path`` existing as a non-directory)
    is re-raised.
    """
    try:
        os.makedirs(path)
    except OSError as exc:  # Python >2.5
        if exc.errno == errno.EEXIST and os.path.isdir(path):
            pass
        else:
            raise