提交 41107c86 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!73 fix matmul tuning and support all space tuning.

Merge pull request !73 from chenlei_autodiff/matmul_tiling
...@@ -15,22 +15,23 @@ ...@@ -15,22 +15,23 @@
"""AutoTuning job""" """AutoTuning job"""
import os import os
import json import json
import time
import datetime import datetime
import importlib import importlib
import logging import logging
import subprocess
import numpy as np import numpy as np
from collections import namedtuple from collections import namedtuple
from akg import composite from akg import composite
from akg.utils import kernel_exec as utils from akg.utils import kernel_exec as utils
from autotuning.runner import KernelRunner, error_time_list, error_time_string from autotuning.runner import KernelRunner, error_time_list, error_time_string
from autotuning.tuner import ModelBasedTuner from autotuning.tuner import ModelBasedTuner, Tuner
from autotuning.type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc from autotuning.type_definitions import ConvDesc, ConvBackpropDesc, MatmulCubeDesc
from autotuning.space_generators import get_space from autotuning.space_generators import get_space
from autotuning.space import ListConfigSpace from autotuning.space import ListConfigSpace
from autotuning.test_data_generators import gen_data from autotuning.test_data_generators import gen_data
logging.basicConfig(level=logging.DEBUG, logging.basicConfig(level=logging.DEBUG)
format='%(asctime)s - %(filename)s[line:%(lineno)d] - %(levelname)s: %(message)s')
logger = logging.getLogger('fuzz.tune.autotuning.job') logger = logging.getLogger('fuzz.tune.autotuning.job')
...@@ -92,11 +93,16 @@ def launch_json(debug_mode: bool = True, save_res: bool = False, json_input_dir= ...@@ -92,11 +93,16 @@ def launch_json(debug_mode: bool = True, save_res: bool = False, json_input_dir=
if save_res: if save_res:
save_tuning_result(key, "json", None, index_table, tuner) save_tuning_result(key, "json", None, index_table, tuner)
def jobs(op_type: str = 'add', desc=None, debug_mode: bool = True, def jobs(op_type: str = 'add', desc=None, debug_mode: bool = True, save_res: bool = False,
save_res: bool = False, insert_key='', conf_of_set_dim=""): all_space: bool = True, insert_key='', conf_of_set_dim=""):
"""AutoTuning jobs""" """AutoTuning jobs"""
iter_times = [3, 3, 3] if debug_mode else [80, 160, 320] iter_times = [3, 3, 3] if debug_mode else [80, 160, 320]
time_start_get_space = time.time()
index_table, space, key, expect, input_for_mod = get_space(op_type, desc) index_table, space, key, expect, input_for_mod = get_space(op_type, desc)
if all_space:
iter_times = [space.length, space.length, space.length]
time_end_get_space = time.time()
print("get space time: ", time_end_get_space - time_start_get_space)
print('space size:', space.length) print('space size:', space.length)
print('index table:', index_table) print('index table:', index_table)
key = key if insert_key == '' else insert_key key = key if insert_key == '' else insert_key
...@@ -121,12 +127,18 @@ def jobs(op_type: str = 'add', desc=None, debug_mode: bool = True, ...@@ -121,12 +127,18 @@ def jobs(op_type: str = 'add', desc=None, debug_mode: bool = True,
# available device numbers, normally is 8 or 1 # available device numbers, normally is 8 or 1
available_device_numbers = utils.get_available_devices_num() available_device_numbers = utils.get_available_devices_num()
tuner = ModelBasedTuner(runner, index_table, space, time_start_tuning = time.time()
n_parallel=available_device_numbers if is_truly_profiling else 1, if all_space:
plan_size=64, pre_model=None) tuner = Tuner(runner, index_table, space, n_parallel=available_device_numbers)
else:
tuner = ModelBasedTuner(runner, index_table, space,
n_parallel=available_device_numbers if is_truly_profiling else 1,
plan_size=64, pre_model=None)
least_try_times = iter_times[0 if space.length < 10 ** 4 else 1 if space.length < 10 ** 5 else 2] least_try_times = iter_times[0 if space.length < 10 ** 4 else 1 if space.length < 10 ** 5 else 2]
tuner.tune(least_try_times, output_file=op_type + ".log") tuner.tune(least_try_times, output_file=op_type + ".log")
time_end_tuning = time.time()
print("tuning time: ", time_end_tuning - time_start_tuning)
print_tuning_result(op_type, space, index_table, tuner, key) print_tuning_result(op_type, space, index_table, tuner, key)
if save_res: if save_res:
...@@ -231,46 +243,48 @@ def load_json_configs(op_type): ...@@ -231,46 +243,48 @@ def load_json_configs(op_type):
return {} return {}
return {} return {}
def read_shapes_from_file(debug_mode, save_res, conf_of_set_dim, op_type): def read_shapes_from_file(debug_mode, save_res, all_space, conf_of_set_dim, op_type):
"""read tuning shapes from file""" """read tuning shapes from file"""
file = importlib.import_module('autotuning.shapes.' + op_type) file = importlib.import_module('autotuning.shapes.' + op_type)
shapes = file.shapes shapes = file.shapes
for _, shp in enumerate(shapes): for _, shp in enumerate(shapes):
do_profiling(shp, debug_mode, save_res, op_type, conf_of_set_dim) do_profiling(shp, debug_mode, save_res, all_space, op_type, conf_of_set_dim)
def do_profiling(shp, debug_mode, save_res, op_type, conf_of_set_dim=None): def do_profiling(shp, debug_mode, save_res, all_space, op_type, conf_of_set_dim=None):
"""do profiling""" """do profiling"""
# remove undeleted JOB files for previous shapes
subprocess.run("rm -rf /var/log/npu/profiling/JOB*", shell=True)
if op_type == 'matmul': if op_type == 'matmul':
key = shp[2][0:-1] key = shp[2][0:-1]
logger.debug("start profiling: [%s]", str(key)) logger.debug("start profiling: [%s]", str(key))
desc = MatmulCubeDesc(*key) desc = MatmulCubeDesc(*key)
jobs(op_type, desc, debug_mode, save_res, key.__str__(), conf_of_set_dim) jobs(op_type, desc, debug_mode, save_res, all_space, key.__str__(), conf_of_set_dim)
logger.debug("end profiling: [%s]", str(key)) logger.debug("end profiling: [%s]", str(key))
elif op_type.startswith('conv_backprop'): elif op_type.startswith('conv_backprop'):
key = shp[2] key = shp[2]
logger.debug("start profiling: [%s]", str(key)) logger.debug("start profiling: [%s]", str(key))
desc = ConvBackpropDesc(*key) desc = ConvBackpropDesc(*key)
jobs(op_type, desc, debug_mode, save_res, key.__str__(), conf_of_set_dim) jobs(op_type, desc, debug_mode, save_res, all_space, key.__str__(), conf_of_set_dim)
logger.debug("end profiling: [%s]", str(key)) logger.debug("end profiling: [%s]", str(key))
elif op_type.startswith('conv'): elif op_type.startswith('conv'):
key = shp[2] key = shp[2]
logger.debug("start profiling: [%s]", str(key)) logger.debug("start profiling: [%s]", str(key))
desc = ConvDesc(*key) desc = ConvDesc(*key)
jobs(op_type, desc, debug_mode, save_res, key.__str__(), conf_of_set_dim) jobs(op_type, desc, debug_mode, save_res, all_space, key.__str__(), conf_of_set_dim)
logger.debug("end profiling: [%s]", str(key)) logger.debug("end profiling: [%s]", str(key))
else: else:
key = shp key = shp
logger.debug("start profiling: [%s]", str(key)) logger.debug("start profiling: [%s]", str(key))
desc = key desc = key
jobs(op_type, desc, debug_mode, save_res, conf_of_set_dim=conf_of_set_dim) jobs(op_type, desc, debug_mode, save_res, all_space, conf_of_set_dim=conf_of_set_dim)
logger.debug("end profiling: [%s]", str(key)) logger.debug("end profiling: [%s]", str(key))
def launch(op_type, debug_mode, save_res=False, desc=None): def launch(op_type, debug_mode, save_res=False, desc=None, all_space=False):
# get the existed tiling # get the existed tiling
conf_of_set_dim = load_json_configs(op_type) conf_of_set_dim = load_json_configs(op_type)
if desc is None: if desc is None:
read_shapes_from_file(debug_mode, save_res, conf_of_set_dim, op_type) read_shapes_from_file(debug_mode, save_res, all_space, conf_of_set_dim, op_type)
else: else:
shp = desc shp = desc
do_profiling(shp, debug_mode, save_res, op_type) do_profiling(shp, debug_mode, save_res, all_space, op_type)
...@@ -115,7 +115,8 @@ def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table, ...@@ -115,7 +115,8 @@ def gen_kernel_matmul_cube(op_desc: MatmulCubeDesc, _, index_table,
attrs = {'dim': dim_info, 'bypass': config.bypass} attrs = {'dim': dim_info, 'bypass': config.bypass}
return matmul_run.matmul_compile(op_desc.x_shape, op_desc.y_shape, op_desc.bias, op_desc.left_format, return matmul_run.matmul_compile(op_desc.x_shape, op_desc.y_shape, op_desc.bias, op_desc.left_format,
op_desc.right_format, op_desc.out_format, op_desc.adj_x, op_desc.adj_y, op_desc.right_format, op_desc.out_format, op_desc.adj_x, op_desc.adj_y,
op_desc.dtype, op_desc.out_dtype, kernel_name, attrs, gen_tiling_spaces) op_desc.dtype, op_desc.bias_dtype, op_desc.out_dtype, kernel_name,
attrs, tuning=gen_tiling_spaces)
def gen_kernel_conv_backprop_input(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropInputConfig = None, def gen_kernel_conv_backprop_input(op_desc: ConvBackpropDesc, _, index_table, config: ConvBackpropInputConfig = None,
......
...@@ -18,6 +18,7 @@ import multiprocessing ...@@ -18,6 +18,7 @@ import multiprocessing
import logging import logging
import os import os
import subprocess import subprocess
import time
from typing import NamedTuple from typing import NamedTuple
import numpy as np import numpy as np
from akg import composite from akg import composite
...@@ -86,8 +87,10 @@ class KernelRunner: ...@@ -86,8 +87,10 @@ class KernelRunner:
def run_one_kernel(self, run_times, idx, config, best_time=np.inf, is_auto=False): def run_one_kernel(self, run_times, idx, config, best_time=np.inf, is_auto=False):
"""Compile and execute a config of the operator on device""" """Compile and execute a config of the operator on device"""
time_one_kernel_start = time.time()
logger.debug('compile %dth kernel', idx) logger.debug('compile %dth kernel', idx)
try: try:
time_start_build = time.time()
if self.op_type == "json": if self.op_type == "json":
if is_auto: if is_auto:
mod = composite.build(self.op_desc) mod = composite.build(self.op_desc)
...@@ -105,6 +108,8 @@ class KernelRunner: ...@@ -105,6 +108,8 @@ class KernelRunner:
else: else:
mod = compile_kernel(self.op_type, self.op_desc, self.input_shape, self._index_table, mod = compile_kernel(self.op_type, self.op_desc, self.input_shape, self._index_table,
None if is_auto else config.input, idx) None if is_auto else config.input, idx)
time_end_build = time.time()
logger.debug("build module time: %f", time_end_build - time_start_build)
logger.debug('finished compile %dth kernel', idx) logger.debug('finished compile %dth kernel', idx)
except BaseException as e: except BaseException as e:
logger.debug("Compile Failed: [%s] : %s", "origin" if is_auto else str(config.input), str(e)) logger.debug("Compile Failed: [%s] : %s", "origin" if is_auto else str(config.input), str(e))
...@@ -127,6 +132,7 @@ class KernelRunner: ...@@ -127,6 +132,7 @@ class KernelRunner:
for _ in range(self.repeat_times): for _ in range(self.repeat_times):
stat_info = {} stat_info = {}
try: try:
time_start_launch = time.time()
if self.mod_output_param is not None: if self.mod_output_param is not None:
output, stat_info = utils.mod_launch(mod, list(self.input), self.mod_output_param, output, stat_info = utils.mod_launch(mod, list(self.input), self.mod_output_param,
tuning=True, device_id=device_id) tuning=True, device_id=device_id)
...@@ -144,18 +150,24 @@ class KernelRunner: ...@@ -144,18 +150,24 @@ class KernelRunner:
stat_info['run_time'] = precision_error_time stat_info['run_time'] = precision_error_time
logger.debug("Precision Error: [%s]", logger.debug("Precision Error: [%s]",
"origin" if config is None else str(config.input)) "origin" if config is None else str(config.input))
time_end_launch = time.time()
logger.debug("mod launch time: %f", time_end_launch - time_start_launch)
except BaseException as e: except BaseException as e:
logger.debug("Run Failed: [%s] : %s", str(config.input), str(e)) logger.debug("Run Failed: [%s] : %s", str(config.input), str(e))
stat_info['run_time'] = run_failed_time stat_info['run_time'] = run_failed_time
run_times[idx] = np.minimum(run_times[idx], stat_info['run_time']) run_times[idx] = np.minimum(run_times[idx], stat_info['run_time'])
finally: finally:
logger.debug('end of %dth kernel', idx) logger.debug('end of %dth kernel', idx)
time_one_kernel_end = time.time()
logger.debug('run one kernel time: %f', time_one_kernel_end - time_one_kernel_start)
return return
def run(self, configs, best_time=np.inf, is_auto_set_dim=False): def run(self, configs, best_time=np.inf, is_auto_set_dim=False, all_space=False):
"""Compile and execute a batch config of the operator on device""" """Compile and execute a batch config of the operator on device"""
start = time.time() start = time.time()
logger.setLevel(logging.DEBUG)
logger.debug("gen cce kernels batch: %d kernels", len(configs)) logger.debug("gen cce kernels batch: %d kernels", len(configs))
subprocess.run("rm -rf ./jobs/JOB*", shell=True)
process_jobs = [] process_jobs = []
run_times = multiprocessing.Manager().list(np.full((len(configs),), compile_fail_time)) run_times = multiprocessing.Manager().list(np.full((len(configs),), compile_fail_time))
for idx, config in enumerate(configs): for idx, config in enumerate(configs):
...@@ -173,6 +185,8 @@ class KernelRunner: ...@@ -173,6 +185,8 @@ class KernelRunner:
run_times[idx] = timeout_time run_times[idx] = timeout_time
p.terminate() p.terminate()
process_end = time.time()
logger.debug("process time: %f", process_end - start)
# clean the profiling directory # clean the profiling directory
tune_device = int(os.environ['DEVICE_ID']) tune_device = int(os.environ['DEVICE_ID'])
tune_num = int(os.environ['DEVICE_TOTAL_NUM']) tune_num = int(os.environ['DEVICE_TOTAL_NUM'])
...@@ -206,6 +220,7 @@ class KernelRunner: ...@@ -206,6 +220,7 @@ class KernelRunner:
job_file = p[0].decode('utf8').strip().split('/')[-2] job_file = p[0].decode('utf8').strip().split('/')[-2]
subprocess.run("rm -rf ./jobs/%s" % job_file, shell=True) subprocess.run("rm -rf ./jobs/%s" % job_file, shell=True)
end = time.time() end = time.time()
logger.debug("run kernels time: %f", end - start)
self.run_kernel_time += end - start self.run_kernel_time += end - start
for idx, config in enumerate(configs): for idx, config in enumerate(configs):
......
...@@ -161,6 +161,9 @@ class ListConfigSpace(ConfigSpace): ...@@ -161,6 +161,9 @@ class ListConfigSpace(ConfigSpace):
"""reset fetch state""" """reset fetch state"""
self.__fetch_pool = [i for i in range(len(self._configs))] self.__fetch_pool = [i for i in range(len(self._configs))]
def fetch_scope(self, start, end):
self.__fetch_pool = [i for i in range(start, end)]
def has_next(self) -> bool: def has_next(self) -> bool:
return len(self.__fetch_pool) > 0 return len(self.__fetch_pool) > 0
...@@ -172,6 +175,12 @@ class ListConfigSpace(ConfigSpace): ...@@ -172,6 +175,12 @@ class ListConfigSpace(ConfigSpace):
self.__fetch_pool.pop() self.__fetch_pool.pop()
return ret return ret
def fetch_next_index(self) -> int:
"""fetch next index of config"""
idx = len(self.__fetch_pool) - 1 + self.__fetch_pool[0]
self.__fetch_pool.pop()
return idx
def fetch_config(self) -> ConfigEntity: def fetch_config(self) -> ConfigEntity:
"""fetch a random config""" """fetch a random config"""
return self.get(self.fetch_index()) return self.get(self.fetch_index())
......
...@@ -107,10 +107,10 @@ def _gen_data_matmul_cube(op_desc: MatmulCubeDesc): ...@@ -107,10 +107,10 @@ def _gen_data_matmul_cube(op_desc: MatmulCubeDesc):
_, _, _, out_shape, k = matmul_run.get_converted_shapes(m, n, k, batch_tuple, op_desc.adj_x, op_desc.adj_y, _, _, _, out_shape, k = matmul_run.get_converted_shapes(m, n, k, batch_tuple, op_desc.adj_x, op_desc.adj_y,
op_desc.bias, op_desc.left_format, op_desc.right_format, op_desc.bias, op_desc.left_format, op_desc.right_format,
op_desc.out_format) op_desc.out_format)
m_x, m_y, bench_mark, bias_data = matmul_run.matmul_data(batch_tuple, m, k, n, op_desc.dtype, op_desc.out_dtype, m_x, m_y, bench_mark, bias_data = matmul_run.matmul_data(batch_tuple, m, k, n, op_desc.dtype, op_desc.bias_dtype,
op_desc.bias, op_desc.adj_x, op_desc.adj_y, op_desc.out_dtype, op_desc.bias, op_desc.adj_x,
op_desc.left_format, op_desc.right_format, op_desc.adj_y, op_desc.left_format,
op_desc.out_format) op_desc.right_format, op_desc.out_format)
out_data = np.full(out_shape, np.nan, op_desc.out_dtype) out_data = np.full(out_shape, np.nan, op_desc.out_dtype)
......
...@@ -93,7 +93,7 @@ class Tuner: ...@@ -93,7 +93,7 @@ class Tuner:
print('tuning time:', self._tuning_time, 'secs') print('tuning time:', self._tuning_time, 'secs')
def next_batch(self, batch_size: int, is_add_visited=True): def next_batch(self, batch_size: int, is_add_visited=True):
"""extract next batch""" """extract next batch with xgboost model"""
ret = [] ret = []
counter = 0 counter = 0
if not is_add_visited: if not is_add_visited:
...@@ -116,6 +116,17 @@ class Tuner: ...@@ -116,6 +116,17 @@ class Tuner:
counter += 1 counter += 1
return ret return ret
def next_config(self, batch_size: int):
"""extract next config orderly"""
ret = []
counter = 0
while counter < batch_size and self._space.has_next():
index = self._space.fetch_next_index()
ret.append(self._space.get(index))
self._visited.add(index)
counter += 1
return ret
def export_configs(self, configs: list, output_file: str, append: bool = True, desc=""): def export_configs(self, configs: list, output_file: str, append: bool = True, desc=""):
"""export configs""" """export configs"""
mode = "a" if append else "w" mode = "a" if append else "w"
...@@ -158,13 +169,13 @@ class Tuner: ...@@ -158,13 +169,13 @@ class Tuner:
while i < least_try_times: while i < least_try_times:
if not self._space.has_next(): if not self._space.has_next():
break break
configs = self.next_batch(min(self._n_parallel, least_try_times - i)) configs = self.next_config(min(self._n_parallel, least_try_times - i))
run_times = self._runner.run(configs, self._best_time) run_times = self._runner.run(configs, self._best_time)
results = [] results = []
for idx, conf in enumerate(configs): for idx, conf in enumerate(configs):
results.append((conf.input_id, run_times[idx])) results.append((conf.input_id, run_times[idx]))
# keep best config # keep best config
if self.best_time < run_times[idx]: if self.best_time > run_times[idx]:
self._best_time = run_times[idx] self._best_time = run_times[idx]
self._best_iter = i + idx self._best_iter = i + idx
self._best_config = conf self._best_config = conf
...@@ -224,6 +235,7 @@ class ModelBasedTuner(Tuner): ...@@ -224,6 +235,7 @@ class ModelBasedTuner(Tuner):
self.__least_try_times = least_try_times self.__least_try_times = least_try_times
self.__early_stopping = early_stopping self.__early_stopping = early_stopping
logger.setLevel(logging.DEBUG)
old_level = logger.level old_level = logger.level
i = 0 i = 0
error_ct = 0 error_ct = 0
......
...@@ -21,7 +21,7 @@ ConvDesc = namedtuple("ConvDesc", ['fmap_shape', 'filter_shape', 'pad', 'stride' ...@@ -21,7 +21,7 @@ ConvDesc = namedtuple("ConvDesc", ['fmap_shape', 'filter_shape', 'pad', 'stride'
ConvBackpropDesc = namedtuple("ConvBackpropDesc", ['fmap_shape', 'filter_shape', 'pad', 'stride', 'dilation']) ConvBackpropDesc = namedtuple("ConvBackpropDesc", ['fmap_shape', 'filter_shape', 'pad', 'stride', 'dilation'])
MatmulCubeDesc = namedtuple("MatmulCubeDesc", ["x_shape", "y_shape", "bias", "left_format", "right_format", MatmulCubeDesc = namedtuple("MatmulCubeDesc", ["x_shape", "y_shape", "bias", "left_format", "right_format",
"out_format", "adj_x", "adj_y", "dtype", "out_dtype"]) "out_format", "adj_x", "adj_y", "dtype", "bias_dtype", "out_dtype"])
# config param definitions # config param definitions
ConvConfig = namedtuple('ConvConfig', ['tile_h', 'tile_co', 'tile_m', 'tile_k', 'tile_n', 'tile_w', 'bypass']) ConvConfig = namedtuple('ConvConfig', ['tile_h', 'tile_co', 'tile_m', 'tile_k', 'tile_n', 'tile_w', 'bypass'])
......
...@@ -13,11 +13,16 @@ ...@@ -13,11 +13,16 @@
# limitations under the License. # limitations under the License.
"""test""" """test"""
import time
from autotuning.job import launch from autotuning.job import launch
from test_run.sub_run import sub_execute from test_run.sub_run import sub_execute
time_start = time.time()
op_type_ = 'sub' op_type_ = 'sub'
debug_mode_ = True debug_mode_ = True
save_res_ = True save_res_ = True
all_space_ = False
desc_ = ('024_sub_64_16_128_128_64_16_128_128_fp16', sub_execute, [(64, 16, 128, 128), (64, 16, 128, 1), 'float16']) desc_ = ('024_sub_64_16_128_128_64_16_128_128_fp16', sub_execute, [(64, 16, 128, 128), (64, 16, 128, 1), 'float16'])
launch(op_type=op_type_, debug_mode=debug_mode_, save_res=save_res_, desc=desc_) launch(op_type=op_type_, debug_mode=debug_mode_, save_res=save_res_, desc=desc_, all_space=all_space_)
time_end = time.time()
print("launch time: ", time_end - time_start)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册