diff --git a/mindspore/_extends/parallel_compile/akg_compiler/__init__.py b/mindspore/_extends/parallel_compile/akg_compiler/__init__.py index e30774307ca2107b3a81c071ad33c042ef924790..c336f0dafc7f92c96e45a9d357e70a88f44767ad 100644 --- a/mindspore/_extends/parallel_compile/akg_compiler/__init__.py +++ b/mindspore/_extends/parallel_compile/akg_compiler/__init__.py @@ -12,3 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================ +""" +Extension functions. + +Python functions that will be called in the c++ parts of MindSpore. +""" diff --git a/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py b/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py new file mode 100644 index 0000000000000000000000000000000000000000..757008a022b4d7293ff8d21e3fba228bf5129a36 --- /dev/null +++ b/mindspore/_extends/parallel_compile/akg_compiler/akg_process.py @@ -0,0 +1,88 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================ +"""akg process""" +import os +import subprocess +import sys +from multiprocessing import Pool, cpu_count + +def _compile_akg_task(*json_strs): + """ + compile func called in single process + + Parameters: + json_strs: list. List contains multiple kernel infos, suitable for json compile api. + """ + akg_compiler = os.path.join(os.path.split( + os.path.realpath(__file__))[0], "compiler.py") + for json_str in json_strs: + res = subprocess.run( + [sys.executable, akg_compiler, json_str], text=True) + if res.returncode != 0: + raise ValueError("Failed, args: {}!".format(json_str)) + +def create_akg_parallel_process(process_num, wait_time): + """ + create AkgParallelCompiler object + + Returns: + AkgParallelCompiler + """ + return AkgProcess(process_num, wait_time) + +class AkgProcess: + """akg kernel parallel process""" + + def __init__(self, process_num, wait_time): + """ + Args: + process_num: int. processes number + waittime: int. max time the function blocked + """ + if not isinstance(process_num, int): + raise ValueError("process number must be a num") + if not isinstance(wait_time, int): + raise ValueError("wait time must be a num") + if process_num == 0: + process_num = 1 + max_proc_num = 16 + self.process_num = min([cpu_count(), max_proc_num, process_num]) + self.args = [[] for _ in range(self.process_num)] + self.wait_time = wait_time + self.argc = 0 + + def compile(self): + """ + compile kernel by multi processes + Return: + True for all compile success, False for some failed. + """ + if self.argc == 0: + raise ValueError("json must be not null") + with Pool(processes=self.process_num) as pool: + res = pool.starmap_async(_compile_akg_task, self.args) + res.get(timeout=self.wait_time) + return True + + def accept_json(self, json): + """ + accept json data before compile + Args: + json: str. kernel info. + """ + if not isinstance(json, str): + raise ValueError("json must be a str") + self.args[self.argc % self.process_num].append(json) + self.argc += 1 diff --git a/mindspore/_extends/parallel_compile/akg_compiler/multi_process_compiler.py b/mindspore/_extends/parallel_compile/akg_compiler/multi_process_compiler.py deleted file mode 100644 index ffe9c85dc39fea82a69fff5961189eba74b79c65..0000000000000000000000000000000000000000 --- a/mindspore/_extends/parallel_compile/akg_compiler/multi_process_compiler.py +++ /dev/null @@ -1,71 +0,0 @@ -# Copyright 2020 Huawei Technologies Co., Ltd -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================ -"""Providing multi process compile with json""" -import os -import subprocess -import sys -from multiprocessing import Pool, cpu_count - - -def _compile_akg_task(*json_strs): - """ - compile func called in single process - - Parameters: - json_strs: list. List contains multiple kernel infos, suitable for json compile api. - """ - akg_compiler = os.path.join(os.path.split( - os.path.realpath(__file__))[0], "compiler.py") - for json_str in json_strs: - res = subprocess.run( - [sys.executable, akg_compiler, json_str], text=True) - if res.returncode != 0: - raise ValueError("Failed, args: {}!".format(json_str)) - - -def compile_akg_kernel_parallel(json_infos, process, waitime): - """ - compile kernel use multi processes - - Parameters: - json_infos: list. list contain kernel info(task id and json str) - process: int. processes num - waittime: int. max time the function blocked - - Returns: - True for all compile success, False for some failed. - """ - if not isinstance(json_infos, list): - raise ValueError("json_infos must be a list") - if not isinstance(process, int): - raise ValueError("process must be a num") - if not isinstance(waitime, int): - raise ValueError("waittime must be a num") - - if process == 0 and json_infos: - process = 1 - - cpu_proc_num = cpu_count() - max_proc_num = 16 - process = min([cpu_proc_num, max_proc_num, process]) - - args = [[] for _ in range(process)] - for p, info in enumerate(json_infos): - args[p % process].append(info) - - with Pool(processes=process) as pool: - res = pool.starmap_async(_compile_akg_task, args) - res.get(timeout=waitime) - return True diff --git a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py index 80b50c45a9272006724cdd7cac6f852b4a903e41..12bdd8ea38e0690f32a56317d4655c1f417c2ee8 100644 --- a/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py +++ b/mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py @@ -22,14 +22,14 @@ import json from .common import check_kernel_info, TBEException from .helper import _op_select_format, _check_supported -def create_tbe_parallel_compiler(): +def create_tbe_parallel_process(): """ create TBEParallelCompiler object Returns: TBEParallelCompiler """ - return compile_pool + return tbe_process def op_select_format(op_json: str): """ @@ -98,8 +98,8 @@ def run_compiler(op_json): except subprocess.CalledProcessError as e: return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json -class CompilerPool: - """compiler pool""" +class TbeProcess: + """tbe process""" def __init__(self): self.__processe_num = multiprocessing.cpu_count() @@ -168,5 +168,4 @@ class CompilerPool: if self.__running_tasks: self.__running_tasks.clear() - -compile_pool = CompilerPool() +tbe_process = TbeProcess() diff --git a/mindspore/_extends/remote/kernel_build_server.py b/mindspore/_extends/remote/kernel_build_server.py index 8167e672cbd83689949e9e199aa3bf92e853eed8..c3c07beb4fefb922ab73e60a5902d727f35d6932 100644 --- a/mindspore/_extends/remote/kernel_build_server.py +++ b/mindspore/_extends/remote/kernel_build_server.py @@ -16,13 +16,14 @@ import os import sys import time -from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_compiler, op_select_format, check_supported +from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported +from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process class TbeBuilder: """Tbe building wrapper""" def __init__(self): - self.tbe_builder = create_tbe_parallel_compiler() + self.tbe_builder = create_tbe_parallel_process() def start(self, json): return self.tbe_builder.start_compile_op(json) @@ -36,6 +37,21 @@ class TbeBuilder: def exit(self): self.tbe_builder.exit() +class AkgBuilder: + """Akg building wrapper""" + + def __init__(self): + pass + + def create(self, process_num, waitime): + self.akg_builder = create_akg_parallel_process(process_num, waitime) + + def accept_json(self, json): + return self.akg_builder.accept_json(json) + + def compile(self): + return self.akg_builder.compile() + class Messager: '''Messager''' @@ -43,6 +59,7 @@ class Messager: logger.info('[TRACE]', 'Messager init...') self.message = '' self.tbe_builder = TbeBuilder() + self.akg_builder = AkgBuilder() def get_message(self): """ @@ -111,12 +128,12 @@ class Messager: Communicate with remote """ arg = self.get_message() - if arg == 'START': + if arg == 'TBE/START': self.send_ack() json = self.get_message() res = self.tbe_builder.start(json) self.send_res(res) - elif arg == 'WAIT': + elif arg == 'TBE/WAIT': self.send_ack() task_id, res, pre = self.tbe_builder.wait() logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre)) @@ -132,9 +149,30 @@ class Messager: self.send_ack(False) self.exit() self.send_res(pre) - elif arg == 'RESET': + elif arg == 'TBE/RESET': self.tbe_builder.reset() self.send_ack() + elif arg == 'AKG/START': + self.send_ack() + process_num_str = self.get_message() + self.send_ack() + wait_time_str = self.get_message() + self.akg_builder.create(int(process_num_str), int(wait_time_str)) + self.send_ack() + elif arg == 'AKG/DATA': + self.send_ack() + while True: + req = self.get_message() + if req.startswith('{'): + self.akg_builder.accept_json(req) + self.send_ack() + elif req == 'AKG/WAIT': + res = self.akg_builder.compile() + self.send_res(res) + break + else: + self.send_ack(False) + break elif arg == 'FORMAT': self.send_ack() json = self.get_message() @@ -180,7 +218,7 @@ class Messager: class Logger: """ Replace dummy 'logger' to output log as below: - logger = Logger("remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log") + logger = Logger(0, True, "remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log") """ def __init__(self, level=1, dumpfile=False, filename='Logger.log'): """ @@ -225,7 +263,7 @@ class DummyLogger: def info(self, tag, msg): pass -logger = Logger() +logger = DummyLogger() if __name__ == '__main__': if len(sys.argv) != 3: diff --git a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc index d698c89bc941f2dc121037845bb525dc3fdfe1be..8a7c02e79028418626bc95f19d1e2ed31204e5fb 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc @@ -23,7 +23,6 @@ #include #include #include -#include #include "ir/dtype.h" #include "ir/func_graph.h" #include "backend/kernel_compiler/kernel.h" @@ -32,10 +31,10 @@ #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" #include "backend/session/anf_runtime_algorithm.h" +#include "backend/session/kernel_build_client.h" namespace mindspore { namespace kernel { -constexpr int32_t PARALLEL_ARGS_SIZE = 3; constexpr int32_t PROCESS_NUM = 16; constexpr int32_t TIME_OUT = 300; @@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type"; constexpr auto kInputDesc = "input_desc"; constexpr auto kOutputDesc = "output_desc"; constexpr auto kTensorName = "tensor_name"; -constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel"; -constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"; + namespace { void UpdateTensorNameInJson(const std::vector &anf_nodes, std::map *node_json_map) { @@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector &anf return true; } -void GenParallelCompileFuncArgs(const std::vector &kernel_jsons, PyObject **p_args) { - MS_EXCEPTION_IF_NULL(p_args); - *p_args = PyTuple_New(PARALLEL_ARGS_SIZE); - - PyObject *arg1 = PyList_New(kernel_jsons.size()); - for (int i = 0; i < PyList_Size(arg1); ++i) { - PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str())); - } - PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM); - PyObject *arg3 = Py_BuildValue("i", TIME_OUT); - - (void)PyTuple_SetItem(*p_args, 0, arg1); - (void)PyTuple_SetItem(*p_args, 1, arg2); - (void)PyTuple_SetItem(*p_args, 2, arg3); -} - bool AkgOpParallelBuild(const std::vector> &build_args) { auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); if (jsons.empty()) { return true; } - // Try to call python method to compile nodes parallely. - PyObject *p_module = nullptr; - PyObject *p_func = nullptr; - PyObject *p_arg = nullptr; - PyObject *p_res = nullptr; - - p_module = PyImport_ImportModule(kMultiProcModule); - if (p_module == nullptr) { - MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "]."; + // Start building in AKG + if (!KernelBuildClient::Instance().AkgStart(PROCESS_NUM, TIME_OUT)) { + MS_LOG(ERROR) << "Akg start failed."; return false; } - - p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); - GenParallelCompileFuncArgs(jsons, &p_arg); - MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size() - << " Akg kernels parallelly."; - p_res = PyEval_CallObject(p_func, p_arg); - if (p_res == nullptr) { - PyErr_Print(); - MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + if (!KernelBuildClient::Instance().AkgSendData(jsons)) { + MS_LOG(ERROR) << "Akg send data failed."; return false; } - if (PyObject_IsTrue(p_res) != 1) { - PyErr_Print(); - MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n(" - << AkgKernelBuild::PyObjectToStr(p_arg) << ")."; + if (!KernelBuildClient::Instance().AkgWait()) { + MS_LOG(ERROR) << "Akg compile failed."; return false; } diff --git a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc index f694142763c264481ec177b0fb373fafadb53ba5..793eeabae178873948c1498ac766b7ddfdc6df7f 100644 --- a/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc +++ b/mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc @@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s } int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) { - return KernelBuildClient::Instance().Start(kernel_json.dump()); + return KernelBuildClient::Instance().TbeStart(kernel_json.dump()); } bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) { MS_EXCEPTION_IF_NULL(task_id); - return KernelBuildClient::Instance().Wait(task_id, task_result, pre_build_result); + return KernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result); } void ParallelBuildManager::ResetTaskInfo() { @@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() { } task_map_.clear(); same_op_list_.clear(); - KernelBuildClient::Instance().Reset(); + KernelBuildClient::Instance().TbeReset(); } } // namespace kernel } // namespace mindspore diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.cc b/mindspore/ccsrc/backend/session/kernel_build_client.cc index 97f55cb1713bae51929684ed32d2a5a5f00806f3..847d24096a3d6fbac23f09e6fb167bae2e4e5395 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.cc +++ b/mindspore/ccsrc/backend/session/kernel_build_client.cc @@ -29,58 +29,106 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { } } -int KernelBuildClient::Start(const std::string &json) { +int KernelBuildClient::TbeStart(const std::string &json) { // Start compiling.. - std::string res = SendRequest(kSTART); - if (res != kACK) { + auto res = SendRequest(kTbeStart); + if (res != kAck) { MS_LOG(ERROR) << "START failed, res: " << res; return -1; } // Send the json data. res = SendRequest(json); - if (res == kFAILED) { - MS_LOG(ERROR) << "START send data failed, res: " << res; + if (res == kFailed) { + MS_LOG(ERROR) << "TBE/START responds failed, res: " << res; return -1; } // Return task id. return std::stoi(res); } -bool KernelBuildClient::Wait(int *task_id, std::string *task_result, std::string *pre_build_result) { +bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) { // Start waiting.. - std::string res = SendRequest(kWAIT); - if (res != kACK) { - MS_LOG(ERROR) << "WAIT failed, res: " << res; + auto res = SendRequest(kTbeWait); + if (res != kAck) { + MS_LOG(ERROR) << "TBE/WAIT failed, res: " << res; return false; } // Request task id. - *task_id = std::stoi(SendRequest(kCONT)); + *task_id = std::stoi(SendRequest(kCont)); // Requst task result. - *task_result = SendRequest(kCONT); + *task_result = SendRequest(kCont); // Request prebuild result. - *pre_build_result = SendRequest(kCONT); + *pre_build_result = SendRequest(kCont); return true; } -void KernelBuildClient::Reset() { +void KernelBuildClient::TbeReset() { // Start compiling.. - std::string res = SendRequest(kRESET); - if (res != kACK) { - MS_LOG(EXCEPTION) << "RESET response is: " << res; + auto res = SendRequest(kTbeReset); + if (res != kAck) { + MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res; } } +bool KernelBuildClient::AkgStart(int process_num, int wait_time) { + // Start compiling.. + auto res = SendRequest(kAkgStart); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/START failed, res: " << res; + return false; + } + std::string process_num_str = std::to_string(process_num); + res = SendRequest(process_num_str); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/START(process_num) responds failed, res: " << res; + return false; + } + std::string wait_time_str = std::to_string(wait_time); + res = SendRequest(wait_time_str); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/START(wait_time) responds failed, res: " << res; + return false; + } + return true; +} + +bool KernelBuildClient::AkgSendData(const std::vector &jsons) { + auto res = SendRequest(kAkgData); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/DATA failed, res: " << res; + return false; + } + for (auto &json : jsons) { + res = SendRequest(json); + if (res != kAck) { + MS_LOG(ERROR) << "AKG/DATA.. responds failed, res: " << res << ", when sending [" << json << "]"; + return false; + } + } + return true; +} + +// Fetch the result of AKG compiling. +bool KernelBuildClient::AkgWait() { + auto res = SendRequest(kAkgWait); + if (res != kTrue) { + MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res; + return false; + } + return true; +} + std::string KernelBuildClient::SelectFormat(const std::string &json) { // Start compiling.. - std::string res = SendRequest(kFORMAT); - if (res != kACK) { + auto res = SendRequest(kFormat); + if (res != kAck) { MS_LOG(ERROR) << "FORMAT failed, res: " << res; return ""; } // Send the json data. res = SendRequest(json); - if (res == kERR) { - MS_LOG(ERROR) << "FORMAT send data failed, res: " << res; + if (res == kErr) { + MS_LOG(ERROR) << "FORMAT responds failed, res: " << res; return ""; } return res; @@ -88,15 +136,15 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) { bool KernelBuildClient::CheckSupported(const std::string &json) { // Checking support.. - std::string res = SendRequest(kSUPPORT); - if (res != kACK) { + auto res = SendRequest(kSupport); + if (res != kAck) { MS_LOG(ERROR) << "SUPPORT failed, res: " << res; return false; } // Send the json data. res = SendRequest(json); - if (res != kTRUE) { - MS_LOG(ERROR) << "SUPPORT send data failed, res: " << res; + if (res != kTrue) { + MS_LOG(INFO) << "SUPPORT responds failed, res: " << res; return false; } return true; diff --git a/mindspore/ccsrc/backend/session/kernel_build_client.h b/mindspore/ccsrc/backend/session/kernel_build_client.h index 8b38d53e62d36abacc64b24c4f4435f23e238516..21a3a7cfaaac42885c0a77e3411704e692329d81 100644 --- a/mindspore/ccsrc/backend/session/kernel_build_client.h +++ b/mindspore/ccsrc/backend/session/kernel_build_client.h @@ -17,6 +17,7 @@ #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ +#include #include #include #include @@ -43,23 +44,26 @@ class KernelBuildClient { "\""; // Receive the response from server - constexpr inline static auto kACK = "ACK"; - constexpr inline static auto kERR = "ERR"; - constexpr inline static auto kFAILED = "-1"; + constexpr inline static auto kAck = "ACK"; + constexpr inline static auto kErr = "ERR"; + constexpr inline static auto kFailed = "-1"; // Send Finish request to server - constexpr inline static auto kFIN = "FIN"; + constexpr inline static auto kFin = "FIN"; // Send building request to server - constexpr inline static auto kSTART = "START"; - constexpr inline static auto kWAIT = "WAIT"; - constexpr inline static auto kCONT = "CONT"; - constexpr inline static auto kSUCCESS = "Success"; - constexpr inline static auto kRESET = "RESET"; + constexpr inline static auto kTbeStart = "TBE/START"; + constexpr inline static auto kTbeWait = "TBE/WAIT"; + constexpr inline static auto kCont = "CONT"; + constexpr inline static auto kSuccess = "Success"; + constexpr inline static auto kTbeReset = "TBE/RESET"; + constexpr inline static auto kAkgStart = "AKG/START"; + constexpr inline static auto kAkgData = "AKG/DATA"; + constexpr inline static auto kAkgWait = "AKG/WAIT"; // Send server info. query to server - constexpr inline static auto kFORMAT = "FORMAT"; - constexpr inline static auto kSUPPORT = "SUPPORT"; - constexpr inline static auto kTRUE = "True"; + constexpr inline static auto kFormat = "FORMAT"; + constexpr inline static auto kSupport = "SUPPORT"; + constexpr inline static auto kTrue = "True"; // Revert \n, \r, [space]. constexpr inline static auto kLF = "[LF]"; @@ -67,7 +71,7 @@ class KernelBuildClient { constexpr inline static auto kSP = "[SP]"; // The TAG as prefix of real command from remote. - constexpr inline static auto kTAG = "[~]"; + constexpr inline static auto kTag = "[~]"; constexpr inline static int kBufferSize = 4096; constexpr inline static unsigned int kTimeOutSeconds = 20; @@ -87,7 +91,7 @@ class KernelBuildClient { std::string result; char buf[kBufferSize]; while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) { - if (std::strncmp(buf, kTAG, std::strlen(kTAG)) == 0) { + if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) { start = true; } // Filter with 'kTAG' and '\n' @@ -105,7 +109,7 @@ class KernelBuildClient { if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) { MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}"; } - result = result.substr(strlen(kTAG)); + result = result.substr(strlen(kTag)); MS_LOG(DEBUG) << "result: " << result; return result; } @@ -115,7 +119,7 @@ class KernelBuildClient { // Exception's thrown if open failed if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) { dp_->SetTimeOutSeconds(kTimeOutSeconds); - dp_->SetTimeOutCallback([this]() { SendRequest(kFIN); }); + dp_->SetTimeOutCallback([this]() { SendRequest(kFin); }); init_ = true; } } @@ -146,13 +150,13 @@ class KernelBuildClient { std::string res; *dp_ >> res; // Filter out the interference - auto start = res.find(kTAG); + auto start = res.find(kTag); if (start == std::string::npos) { MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res; } - res = res.substr(start + std::strlen(kTAG), res.size() - start); + res = res.substr(start + std::strlen(kTag), res.size() - start); // Revert the line feed and space - if (res != kSUCCESS && res != kACK && res != kERR && res != kTRUE) { + if (res != kSuccess && res != kAck && res != kErr && res != kTrue) { ReplaceStr(&res, kLF, '\n'); ReplaceStr(&res, kSP, ' '); } @@ -164,10 +168,15 @@ class KernelBuildClient { std::string SelectFormat(const std::string &json); bool CheckSupported(const std::string &json); - // Run building. - int Start(const std::string &json); - bool Wait(int *task_id, std::string *task_result, std::string *pre_build_result); - void Reset(); + // Run TBE building. + int TbeStart(const std::string &json); + bool TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result); + void TbeReset(); + + // Run AKG building. + bool AkgStart(int process_num, int wait_time); + bool AkgSendData(const std::vector &jsons); + bool AkgWait(); KernelBuildClient(const KernelBuildClient &) = delete; KernelBuildClient &operator=(const KernelBuildClient &) = delete;