提交 5a00d8cb 编写于 作者: W wenchunjiang

This fixes an issue about mindspore process cannot exit when calling python...

This fixes an issue about mindspore process cannot exit when calling python api op_select_format failed in select kernel steps.
Previously function op_select_format and check_supported raise an exception directly on the tbe_process python side, but we don't deal with the exception, and raise an exeception on c++ side to frontend ME, that will cause some conflict when recycle resource on ME and tbe_process python interpreter.
This changes adding try...catch in function op_select_format and check_supported on the python side, and return the Exception string to c++ side, so that we can raise an exception to frontend ME and ME will deal with resouce clearning and exit.
上级 c24252b2
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""tbe process"""
import sys
import os
from .common import get_args, get_build_in_impl_path, TBEException
build_in_impl_path = get_build_in_impl_path()
def _op_select_format(kernel_info):
"""
call op's op_select_format to get op supported format
Args:
kernel_info (dict): kernel info load by json string
Returns:
op supported format
"""
try:
# import module
op_name = kernel_info['op_info']['name']
impl_path = build_in_impl_path
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
op_impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(op_impl_path):
path, file_name = os.path.split(op_impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
if impl_path not in sys.path:
sys.path.insert(0, impl_path)
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function
if not hasattr(op_module, "op_select_format"):
return ""
op_func = getattr(op_module, "op_select_format", None)
# call function
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
except Exception as e:
raise TBEException(str(e))
return ret
def _check_supported(kernel_info):
"""
call op's check_supported to check supported or not
Args:
kernel_info (dict): kernel info load by json string
Returns:
bool: check result, true or false
"""
try:
# import module
op_name = kernel_info['op_info']['name']
impl_path = build_in_impl_path
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
op_impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(op_impl_path):
path, file_name = os.path.split(op_impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
if impl_path not in sys.path:
sys.path.insert(0, impl_path)
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function
if not hasattr(op_module, "check_supported"):
return ""
op_func = getattr(op_module, "check_supported", None)
# call function
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
except Exception as e:
raise TBEException(str(e))
return ret
......@@ -19,10 +19,8 @@ import subprocess
import sys
import os
import json
from .common import check_kernel_info, get_args, get_build_in_impl_path
build_in_impl_path = get_build_in_impl_path()
from .common import check_kernel_info, TBEException
from .helper import _op_select_format, _check_supported
def create_tbe_parallel_compiler():
"""
......@@ -41,40 +39,17 @@ def op_select_format(op_json: str):
op_json (str): json string of the op
Returns:
op supported format
op supported format or exception message
"""
ret = ""
kernel_info = json.loads(op_json)
check_kernel_info(kernel_info)
# import module
op_name = kernel_info['op_info']['name']
impl_path = build_in_impl_path
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
op_impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(op_impl_path):
path, file_name = os.path.split(op_impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
sys.path.insert(0, impl_path)
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function
if not hasattr(op_module, "op_select_format"):
return ""
op_func = getattr(op_module, "op_select_format", None)
# call function
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
try:
kernel_info = json.loads(op_json)
check_kernel_info(kernel_info)
ret = _op_select_format(kernel_info)
except TBEException as e:
return "TBEException: " + str(e)
return ret
......@@ -86,40 +61,18 @@ def check_supported(op_json: str):
op_json (str): json string of the op
Returns:
true or false
bool: check result, true or false
str: exception message when catch an Exception
"""
ret = ""
kernel_info = json.loads(op_json)
check_kernel_info(kernel_info)
# import module
op_name = kernel_info['op_info']['name']
impl_path = build_in_impl_path
custom_flag = False
if 'impl_path' in kernel_info and kernel_info['impl_path'] is not None:
op_impl_path = os.path.realpath(kernel_info['impl_path'])
if os.path.isfile(op_impl_path):
path, file_name = os.path.split(op_impl_path)
op_name, _ = os.path.splitext(file_name)
impl_path = path
custom_flag = True
sys.path.insert(0, impl_path)
if custom_flag:
op_module = __import__(op_name)
else:
op_module = __import__("impl." + op_name, globals(), locals(), [op_name], 0)
# get function
if not hasattr(op_module, "check_supported"):
return ""
op_func = getattr(op_module, "check_supported", None)
# call function
inputs_args = get_args(kernel_info['op_info'], 'inputs')
outputs_args = get_args(kernel_info['op_info'], 'outputs')
attrs_args = get_args(kernel_info['op_info'], 'attrs')
kernel_name = kernel_info['op_info']['kernel_name']
ret = op_func(*inputs_args, *outputs_args, *attrs_args, kernel_name=kernel_name)
try:
kernel_info = json.loads(op_json)
check_kernel_info(kernel_info)
ret = _check_supported(kernel_info)
except TBEException as e:
return "TBEException: " + str(e)
return ret
......@@ -149,12 +102,12 @@ class CompilerPool:
"""compiler pool"""
def __init__(self):
processes = multiprocessing.cpu_count()
self.__processe_num = multiprocessing.cpu_count()
# max_processes_num: Set the maximum number of concurrent processes for compiler
max_processes_num = 16
if processes > max_processes_num:
processes = max_processes_num
self.__pool = multiprocessing.Pool(processes=processes)
if self.__processe_num > max_processes_num:
self.__processe_num = max_processes_num
self.__pool = None
self.__next_task_id = 1
self.__running_tasks = []
......@@ -165,11 +118,10 @@ class CompilerPool:
del self.__pool
def exit(self):
return
# self.__pool.terminate()
# self.__pool.join()
# if self.__pool is not None:
# del self.__pool
if self.__pool is not None:
self.__pool.terminate()
self.__pool.join()
del self.__pool
def start_compile_op(self, op_json):
"""
......@@ -183,6 +135,8 @@ class CompilerPool:
"""
task_id = self.__next_task_id
self.__next_task_id = self.__next_task_id + 1
if self.__pool is None:
self.__pool = multiprocessing.Pool(processes=self.__processe_num)
task_future = self.__pool.apply_async(func=run_compiler, args=(op_json,))
self.__running_tasks.append((task_id, task_future))
return task_id
......
......@@ -98,7 +98,7 @@ void TbeAdapter::NormalizeFuncName(std::string *func_name) {
*func_name = name_tmp;
auto iter = tbe_func_adapter_map.find(*func_name);
if (iter != tbe_func_adapter_map.end()) {
MS_LOG(INFO) << "map actual op fron me " << func_name << "to tbe op" << iter->second;
MS_LOG(INFO) << "map actual op from me " << func_name << "to tbe op" << iter->second;
*func_name = iter->second;
}
}
......
......@@ -35,6 +35,8 @@ namespace kernel {
constexpr auto kName = "name";
constexpr auto kDtype = "dtype";
constexpr auto kFormat = "format";
constexpr auto kPrefixInput = "input";
constexpr auto kPrefixOutput = "output";
const std::map<std::string, std::string> DYNAMIC_FORMAT_MAP = {{"NCHW", "DefaultFormat"},
{"NHWC", "DefaultFormat"},
{"ND", "DefaultFormat"},
......@@ -146,13 +148,13 @@ bool ParseDynamicFormatJson(const std::string &jsonStr, std::vector<std::shared_
if (!CheckJsonItemValidity(json_obj, key_name, keys)) {
return false;
}
if (key_name.find("input", 0) != std::string::npos) {
if (key_name.compare(0, strlen(kPrefixInput), kPrefixInput) == 0) {
std::shared_ptr<OpIOInfo> input = std::make_shared<OpIOInfo>();
MS_EXCEPTION_IF_NULL(input);
input->set_name(json_obj[key_name].at(kName));
ConvertFormatDtype(json_obj[key_name].at(kFormat), json_obj[key_name].at(kDtype), input);
inputs->emplace_back(input);
} else if (key_name.find("output", 0) != std::string::npos) {
} else if (key_name.compare(0, strlen(kPrefixOutput), kPrefixOutput) == 0) {
std::shared_ptr<OpIOInfo> output = std::make_shared<OpIOInfo>();
MS_EXCEPTION_IF_NULL(output);
output->set_name(json_obj[key_name].at(kName));
......
......@@ -26,6 +26,7 @@ constexpr auto kTbeProcessModule = "mindspore._extends.parallel_compile.tbe_comp
constexpr auto kCreateTbeParallelCompilerFunc = "create_tbe_parallel_compiler";
constexpr auto kOpSelectFormatFunc = "op_select_format";
constexpr auto kCheckSupportedFunc = "check_supported";
constexpr auto kTBEException = "TBEException";
PyObject *TbePythonFuncs::pCreateTbeParallelCompilerFunc_ = nullptr;
PyObject *TbePythonFuncs::pTbeCompiler_ = nullptr;
......@@ -133,6 +134,10 @@ std::string TbePythonFuncs::OpSelectFormat(const nlohmann::json &kernel_json) {
char *pstr = nullptr;
(void)PyArg_Parse(pRet, "s", &pstr);
res_json_str = pstr;
if (res_json_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kOpSelectFormatFunc << "], " << res_json_str
<< " ,function args:" << PyObjectToStr(pArg);
}
return res_json_str;
}
......@@ -167,7 +172,18 @@ bool TbePythonFuncs::CheckSupported(const nlohmann::json &kernel_json) {
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc
<< "], function args: " << PyObjectToStr(pArg);
}
ret = PyObject_IsTrue(pRes) != 0;
if (PyBool_Check(pRes)) {
ret = PyObject_IsTrue(pRes) != 0;
} else {
char *pstr = nullptr;
(void)PyArg_Parse(pRes, "s", &pstr);
std::string res_str = pstr;
if (res_str.compare(0, strlen(kTBEException), kTBEException) == 0) {
MS_EXCEPTION(ArgumentError) << "Failed to call function [" << kCheckSupportedFunc << "], " << res_str
<< ", function args: " << PyObjectToStr(pArg);
}
}
return ret;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册