提交 9de9dd82 编写于 作者: C chenqiyou

feat(tools): add two model evaluation tools

上级 d9a46ea4
#!/usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
purpose: use to test whether a model have good parallelism, if a model have good
parallelism it will get high performance improvement.
"""
import argparse
import logging
import os
import re
import subprocess
# test device
device = {
"name": "hwmt40p",
"login_name": "hwmt40p-K9000-maliG78",
"ip": "box86.br.megvii-inc.com",
"port": 2200,
"thread_number": 3,
}
class SshConnector:
"""imp ssh control master connector"""
ip = None
port = None
login_name = None
def setup(self, login_name, ip, port):
self.ip = ip
self.login_name = login_name
self.port = port
def copy(self, src_list, dst_dir):
assert isinstance(src_list, list), "code issue happened!!"
assert isinstance(dst_dir, str), "code issue happened!!"
for src in src_list:
cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
self.port, src, self.login_name, self.ip, dst_dir
)
logging.debug("ssh run cmd: {}".format(cmd))
subprocess.check_call(cmd, shell=True)
def cmd(self, cmd):
output = ""
assert isinstance(cmd, list), "code issue happened!!"
for sub_cmd in cmd:
p_cmd = 'ssh -p {} {}@{} "{}" '.format(
self.port, self.login_name, self.ip, sub_cmd
)
logging.debug("ssh run cmd: {}".format(p_cmd))
output = output + subprocess.check_output(p_cmd, shell=True).decode("utf-8")
return output
def get_finally_bench_resulut_from_log(raw_log) -> float:
# raw_log --> avg_time=23.331ms -->23.331ms
h = re.findall(r"avg_time=.*ms ", raw_log)[-1][9:]
# to 23.331
h = h[: h.find("ms")]
# to float
h = float(h)
return h
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model_file", help="model file", required=True)
parser.add_argument(
"--load_and_run_file", help="path for load_and_run", required=True
)
args = parser.parse_args()
# init device
ssh = SshConnector()
ssh.setup(device["login_name"], device["ip"], device["port"])
# create test dir
workspace = "model_parallelism_test"
ssh.cmd(["mkdir -p {}".format(workspace)])
# copy load_and_run_file
ssh.copy([args.load_and_run_file], workspace)
# call test
model_file = args.model_file
# copy model file
ssh.copy([args.model_file], workspace)
m = model_file.split('\\')[-1]
# run single thread
result = []
thread_number = [1, 2, 4]
for b in thread_number :
cmd = []
cmd1 = "cd {} && ./load_and_run {} -multithread {} --fast-run --fast_run_algo_policy fastrun.cache --iter 1 --warmup-iter 1 --no-sanity-check --weight-preprocess".format(
workspace, m, b
)
cmd2 = "cd {} && ./load_and_run {} -multithread {} --fast_run_algo_policy fastrun.cache --iter 20 --warmup-iter 5 --no-sanity-check --weight-preprocess ".format(
workspace, m, b
)
cmd.append(cmd1)
cmd.append(cmd2)
raw_log = ssh.cmd(cmd)
# logging.debug(raw_log)
ret = get_finally_bench_resulut_from_log(raw_log)
logging.debug("model: {} with backend: {} result is: {}".format(m, b, ret))
result.append(ret)
thread_2 = result[0]/result[1]
thread_4 = result[0]/result[2]
if thread_2 > 1.6 or thread_4 > 3.0:
print("model: {} can has good parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
else:
print("model: {} can has bad parallelism. 2 thread is {}, 4 thread is {}".format(m, thread_2, thread_4))
if __name__ == "__main__":
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
main()
#!/usr/bin/env python3
# MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
#
# Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT ARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
"""
purpose: use to test whether a model contain dynamic operator, if no dynamic
operator the model is static, other wise the model is dynamic.
"""
import argparse
import logging
import os
import re
import subprocess
# test device
device = {
"name": "hwmt40p",
"login_name": "hwmt40p-K9000-maliG78",
"ip": "box86.br.megvii-inc.com",
"port": 2200,
"thread_number": 3,
}
class SshConnector:
"""imp ssh control master connector"""
ip = None
port = None
login_name = None
def setup(self, login_name, ip, port):
self.ip = ip
self.login_name = login_name
self.port = port
def copy(self, src_list, dst_dir):
assert isinstance(src_list, list), "code issue happened!!"
assert isinstance(dst_dir, str), "code issue happened!!"
for src in src_list:
cmd = 'rsync --progress -a -e "ssh -p {}" {} {}@{}:{}'.format(
self.port, src, self.login_name, self.ip, dst_dir
)
logging.debug("ssh run cmd: {}".format(cmd))
subprocess.check_call(cmd, shell=True)
def cmd(self, cmd):
assert isinstance(cmd, list), "code issue happened!!"
try:
for sub_cmd in cmd:
p_cmd = 'ssh -p {} {}@{} "{}" '.format(
self.port, self.login_name, self.ip, sub_cmd
)
logging.debug("ssh run cmd: {}".format(p_cmd))
subprocess.check_call(p_cmd, shell=True)
except:
raise
def main():
parser = argparse.ArgumentParser(formatter_class=argparse.RawTextHelpFormatter)
parser.add_argument("--model_file", help="megengine model", required=True)
parser.add_argument(
"--load_and_run_file", help="path for load_and_run", required=True
)
args = parser.parse_args()
assert os.path.isfile(
args.model_file
), "invalid args for models_file, need a file for model"
assert os.path.isfile(args.load_and_run_file), "invalid args for load_and_run_file"
# init device
ssh = SshConnector()
ssh.setup(device["login_name"], device["ip"], device["port"])
# create test dir
workspace = "model_static_evaluation_workspace"
ssh.cmd(["mkdir -p {}".format(workspace)])
# copy load_and_run_file
ssh.copy([args.load_and_run_file], workspace)
model_file = args.model_file
# copy model file
ssh.copy([model_file], workspace)
m = model_file.split('\\')[-1]
# run single thread
cmd = "cd {} && ./load_and_run {} --fast-run --record-comp-seq --iter 1 --warmup-iter 1".format(
workspace, m
)
try:
raw_log = ssh.cmd([cmd])
except:
print("model: {} is not static model, it has dynamic operator.".format(m))
raise
print("model: {} is static model.".format(m))
if __name__ == "__main__":
LOG_FORMAT = "%(asctime)s - %(levelname)s - %(message)s"
DATE_FORMAT = "%Y/%m/%d %H:%M:%S"
logging.basicConfig(level=logging.DEBUG, format=LOG_FORMAT, datefmt=DATE_FORMAT)
main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册