diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/helper.py b/mindspore/_extends/parallel_compile/tbe_compiler/helper.py new file mode 100644 index 0000000000000000000000000000000000000000..b5f77501879190f3a5f79da4c8713fcb7d74779d --- /dev/null +++ b/mindspore/_extends/parallel_compile/tbe_compiler/helper.py @@ -0,0 +1,114 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""tbe process""" +import sys +import os +from .common import get_args, get_build_in_impl_path, TBEException + +build_in_impl_path = get_build_in_impl_path() + + +def _op_select_format(kernel_info): + """ + call op's op_select_format to get op supported format + + Args: + kernel_info (dict): kernel info load by json string + + Returns: + op supported format + """ + try: + # import module + op_name = kernel_info['op_info']['name'] + impl_path = build_in_impl_path + custom_flag = False + if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: + op_impl_path = os.path.realpath(kernel_info['impl_path']) + if os.path.isfile(op_impl_path): + path, file_name = os.path.split(op_impl_path) + op_name, _ = os.path.splitext(file_name) + impl_path = path + custom_flag = True + if impl_path not in sys.path: + sys.path.insert(0, impl_path) + + if custom_flag: + op_module = __import__(op_name) + else: + op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) + # get function + if not hasattr(op_module, "op_select_format"): + return "" + op_func = getattr(op_module, "op_select_format", None) + + # call function + inputs_args = get_args(kernel_info['op_info'], 'inputs') + outputs_args = get_args(kernel_info['op_info'], 'outputs') + attrs_args = get_args(kernel_info['op_info'], 'attrs') + kernel_name = kernel_info['op_info']['kernel_name'] + ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + + except Exception as e: + raise TBEException(str(e)) + + return ret + + +def _check_supported(kernel_info): + """ + call op's check_supported to check supported or not + + Args: + kernel_info (dict): kernel info load by json string + + Returns: + bool: check result, true or false + """ + try: + # import module + op_name = kernel_info['op_info']['name'] + impl_path = build_in_impl_path + custom_flag = False + if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: + op_impl_path = os.path.realpath(kernel_info['impl_path']) + if os.path.isfile(op_impl_path): + path, file_name = os.path.split(op_impl_path) + op_name, _ = os.path.splitext(file_name) + impl_path = path + custom_flag = True + if impl_path not in sys.path: + sys.path.insert(0, impl_path) + + if custom_flag: + op_module = __import__(op_name) + else: + op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) + # get function + if not hasattr(op_module, "check_supported"): + return "" + op_func = getattr(op_module, "check_supported", None) + + # call function + inputs_args = get_args(kernel_info['op_info'], 'inputs') + outputs_args = get_args(kernel_info['op_info'], 'outputs') + attrs_args = get_args(kernel_info['op_info'], 'attrs') + kernel_name = kernel_info['op_info']['kernel_name'] + ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + + except Exception as e: + raise TBEException(str(e)) + + return ret diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py index 9b970b7242cead627ea67ec272c697238baecc68..2f73ced061ba9c7efafa7873289a6c76aa4601f1 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py @@ -19,10 +19,8 @@ import subprocess import sys import os import json -from .common import check_kernel_info, get_args, get_build_in_impl_path - -build_in_impl_path = get_build_in_impl_path() - +from .common import check_kernel_info, TBEException +from .helper import _op_select_format, _check_supported def create_tbe_parallel_compiler(): """ @@ -41,40 +39,17 @@ def op_select_format(op_json: str): op_json (str): json string of the op Returns: - op supported format + op supported format or exception message """ ret = "" - kernel_info = json.loads(op_json) - check_kernel_info(kernel_info) - - # import module - op_name = kernel_info['op_info']['name'] - impl_path = build_in_impl_path - custom_flag = False - if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: - op_impl_path = os.path.realpath(kernel_info['impl_path']) - if os.path.isfile(op_impl_path): - path, file_name = os.path.split(op_impl_path) - op_name, _ = os.path.splitext(file_name) - impl_path = path - custom_flag = True - sys.path.insert(0, impl_path) - - if custom_flag: - op_module = __import__(op_name) - else: - op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) - # get function - if not hasattr(op_module, "op_select_format"): - return "" - op_func = getattr(op_module, "op_select_format", None) - - # call function - inputs_args = get_args(kernel_info['op_info'], 'inputs') - outputs_args = get_args(kernel_info['op_info'], 'outputs') - attrs_args = get_args(kernel_info['op_info'], 'attrs') - kernel_name = kernel_info['op_info']['kernel_name'] - ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + try: + kernel_info = json.loads(op_json) + check_kernel_info(kernel_info) + ret = _op_select_format(kernel_info) + + except TBEException as e: + return "TBEException: " + str(e) + return ret @@ -86,40 +61,18 @@ def check_supported(op_json: str): op_json (str): json string of the op Returns: - true or false + bool: check result, true or false + str: exception message when catch an Exception """ ret = "" - kernel_info = json.loads(op_json) - check_kernel_info(kernel_info) - - # import module - op_name = kernel_info['op_info']['name'] - impl_path = build_in_impl_path - custom_flag = False - if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None: - op_impl_path = os.path.realpath(kernel_info['impl_path']) - if os.path.isfile(op_impl_path): - path, file_name = os.path.split(op_impl_path) - op_name, _ = os.path.splitext(file_name) - impl_path = path - custom_flag = True - sys.path.insert(0, impl_path) - - if custom_flag: - op_module = __import__(op_name) - else: - op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0) - # get function - if not hasattr(op_module, "check_supported"): - return "" - op_func = getattr(op_module, "check_supported", None) - - # call function - inputs_args = get_args(kernel_info['op_info'], 'inputs') - outputs_args = get_args(kernel_info['op_info'], 'outputs') - attrs_args = get_args(kernel_info['op_info'], 'attrs') - kernel_name = kernel_info['op_info']['kernel_name'] - ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name) + try: + kernel_info = json.loads(op_json) + check_kernel_info(kernel_info) + ret = _check_supported(kernel_info) + + except TBEException as e: + return "TBEException: " + str(e) + return ret @@ -149,12 +102,12 @@ class CompilerPool: """compiler pool""" def __init__(self): - processes = multiprocessing.cpu_count() + self.__processe_num = multiprocessing.cpu_count() # max_processes_num: Set the maximum number of concurrent processes for compiler max_processes_num = 16 - if processes > max_processes_num: - processes = max_processes_num - self.__pool = multiprocessing.Pool(processes=processes) + if self.__processe_num > max_processes_num: + self.__processe_num = max_processes_num + self.__pool = None self.__next_task_id = 1 self.__running_tasks = [] @@ -165,11 +118,10 @@ class CompilerPool: del self.__pool def exit(self): - return - # self.__pool.terminate() - # self.__pool.join() - # if self.__pool is not None: - # del self.__pool + if self.__pool is not None: + self.__pool.terminate() + self.__pool.join() + del self.__pool def start_compile_op(self, op_json): """ @@ -183,6 +135,8 @@ class CompilerPool: """ task_id = self.__next_task_id self.__next_task_id = self.__next_task_id + 1 + if self.__pool is None: + self.__pool = multiprocessing.Pool(processes=self.__processe_num) task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,)) self.__running_tasks.append((task_id, task_future)) return task_id diff --git a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc index ddb78a08df37c24afc20b1627cf9446ebe92ea7b..50fed77a9adb379df3c5d2ff6434c741a4cca0d0 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_adapter.cc @@ -98,7 +98,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) { *func_name = name_tmp; auto iter = tbe_func_adapter_map.find(*func_name); if (iter != tbe_func_adapter_map.end()) { - MS_LOG(INFO) << "map actual op fron me " << func_name << "to tbe op" << iter->second; + MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second; *func_name = iter->second; } } diff --git a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc index 25495d1e6871314850742a9eb6fee664b62f3fbc..1953fd0c72d64079c7be3c3883dffd974ee0711a 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_kernel_select.cc @@ -35,6 +35,8 @@ namespace kernel { constexpr auto kName = "name"; constexpr auto kDtype = "dtype"; constexpr auto kFormat = "format"; +constexpr auto kPrefixInput = "input"; +constexpr auto kPrefixOutput = "output"; const std::map DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"}, {"NHWC", "DefaultFormat"}, {"ND", "DefaultFormat"}, @@ -146,13 +148,13 @@ bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector input = std::make_shared(); MS_EXCEPTION_IF_NULL(input); input->set_name(json_obj[key_name].at(kName)); ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input); inputs->emplace_back(input); - } else if (key_name.find("output", 0) != std::string::npos) { + } else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) { std::shared_ptr output = std::make_shared(); MS_EXCEPTION_IF_NULL(output); output->set_name(json_obj[key_name].at(kName)); diff --git a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc b/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc index 1e60742fc4089badece3b66dd44890ca2111263e..7204fb7f960cb29c4caef6b19b6fb7efff878c2e 100644 --- a/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc +++ b/mindspore/ccsrc/kernel/tbe/tbe_python_funcs.cc @@ -26,6 +26,7 @@ constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_comp constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler"; constexpr auto kOpSelectFormatFunc = "op_select_format"; constexpr auto kCheckSupportedFunc = "check_supported"; +constexpr auto kTBEException = "TBEException"; PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr; PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr; @@ -133,6 +134,10 @@ std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) { char *pstr = nullptr; (void)PyArg_Parse(pRet, "s", &pstr); res_json_str = pstr; + if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) { + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str + << " ,function args:" << PyObjectToStr(pArg); + } return res_json_str; } @@ -167,7 +172,18 @@ bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) { MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], function args: " << PyObjectToStr(pArg); } - ret = PyObject_IsTrue(pRes) != 0; + if (PyBool_Check(pRes)) { + ret = PyObject_IsTrue(pRes) != 0; + } else { + char *pstr = nullptr; + (void)PyArg_Parse(pRes, "s", &pstr); + std::string res_str = pstr; + if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) { + MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str + << ", function args: " << PyObjectToStr(pArg); + } + } + return ret; }