提交 0154bdeb 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3935 Decouple ME and AKG for Ascend

Merge pull request !3935 from ZhangQinghua/master
...@@ -12,3 +12,8 @@ ...@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""
...@@ -12,13 +12,12 @@ ...@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================ # ============================================================================
"""Providing multi process compile with json""" """akg process"""
import os import os
import subprocess import subprocess
import sys import sys
from multiprocessing import Pool, cpu_count from multiprocessing import Pool, cpu_count
def _compile_akg_task(*json_strs): def _compile_akg_task(*json_strs):
""" """
compile func called in single process compile func called in single process
...@@ -34,38 +33,56 @@ def _compile_akg_task(*json_strs): ...@@ -34,38 +33,56 @@ def _compile_akg_task(*json_strs):
if res.returncode != 0: if res.returncode != 0:
raise ValueError("Failed, args: {}!".format(json_str)) raise ValueError("Failed, args: {}!".format(json_str))
def create_akg_parallel_process(process_num, wait_time):
def compile_akg_kernel_parallel(json_infos, process, waitime):
""" """
compile kernel use multi processes create AkgParallelCompiler object
Parameters:
json_infos: list. list contain kernel info(task id and json str)
process: int. processes num
waittime: int. max time the function blocked
Returns: Returns:
True for all compile success, False for some failed. AkgParallelCompiler
""" """
if not isinstance(json_infos, list): return AkgProcess(process_num, wait_time)
raise ValueError("json_infos must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
if process == 0 and json_infos: class AkgProcess:
process = 1 """akg kernel parallel process"""
cpu_proc_num = cpu_count() def __init__(self, process_num, wait_time):
"""
Args:
process_num: int. processes number
waittime: int. max time the function blocked
"""
if not isinstance(process_num, int):
raise ValueError("process number must be a num")
if not isinstance(wait_time, int):
raise ValueError("wait time must be a num")
if process_num == 0:
process_num = 1
max_proc_num = 16 max_proc_num = 16
process = min([cpu_proc_num, max_proc_num, process]) self.process_num = min([cpu_count(), max_proc_num, process_num])
self.args = [[] for _ in range(self.process_num)]
args = [[] for _ in range(process)] self.wait_time = wait_time
for p, info in enumerate(json_infos): self.argc = 0
args[p % process].append(info)
with Pool(processes=process) as pool: def compile(self):
res = pool.starmap_async(_compile_akg_task, args) """
res.get(timeout=waitime) compile kernel by multi processes
Return:
True for all compile success, False for some failed.
"""
if self.argc == 0:
raise ValueError("json must be not null")
with Pool(processes=self.process_num) as pool:
res = pool.starmap_async(_compile_akg_task, self.args)
res.get(timeout=self.wait_time)
return True return True
def accept_json(self, json):
"""
accept json data before compile
Args:
json: str. kernel info.
"""
if not isinstance(json, str):
raise ValueError("json must be a str")
self.args[self.argc % self.process_num].append(json)
self.argc += 1
...@@ -22,14 +22,14 @@ import json ...@@ -22,14 +22,14 @@ import json
from .common import check_kernel_info, TBEException from .common import check_kernel_info, TBEException
from .helper import _op_select_format, _check_supported from .helper import _op_select_format, _check_supported
def create_tbe_parallel_compiler(): def create_tbe_parallel_process():
""" """
create TBEParallelCompiler object create TBEParallelCompiler object
Returns: Returns:
TBEParallelCompiler TBEParallelCompiler
""" """
return compile_pool return tbe_process
def op_select_format(op_json: str): def op_select_format(op_json: str):
""" """
...@@ -98,8 +98,8 @@ def run_compiler(op_json): ...@@ -98,8 +98,8 @@ def run_compiler(op_json):
except subprocess.CalledProcessError as e: except subprocess.CalledProcessError as e:
return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
class CompilerPool: class TbeProcess:
"""compiler pool""" """tbe process"""
def __init__(self): def __init__(self):
self.__processe_num = multiprocessing.cpu_count() self.__processe_num = multiprocessing.cpu_count()
...@@ -168,5 +168,4 @@ class CompilerPool: ...@@ -168,5 +168,4 @@ class CompilerPool:
if self.__running_tasks: if self.__running_tasks:
self.__running_tasks.clear() self.__running_tasks.clear()
tbe_process = TbeProcess()
compile_pool = CompilerPool()
...@@ -16,13 +16,14 @@ ...@@ -16,13 +16,14 @@
import os import os
import sys import sys
import time import time
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_compiler, op_select_format, check_supported from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class TbeBuilder: class TbeBuilder:
"""Tbe building wrapper""" """Tbe building wrapper"""
def __init__(self): def __init__(self):
self.tbe_builder = create_tbe_parallel_compiler() self.tbe_builder = create_tbe_parallel_process()
def start(self, json): def start(self, json):
return self.tbe_builder.start_compile_op(json) return self.tbe_builder.start_compile_op(json)
...@@ -36,6 +37,21 @@ class TbeBuilder: ...@@ -36,6 +37,21 @@ class TbeBuilder:
def exit(self): def exit(self):
self.tbe_builder.exit() self.tbe_builder.exit()
class AkgBuilder:
"""Akg building wrapper"""
def __init__(self):
pass
def create(self, process_num, waitime):
self.akg_builder = create_akg_parallel_process(process_num, waitime)
def accept_json(self, json):
return self.akg_builder.accept_json(json)
def compile(self):
return self.akg_builder.compile()
class Messager: class Messager:
'''Messager''' '''Messager'''
...@@ -43,6 +59,7 @@ class Messager: ...@@ -43,6 +59,7 @@ class Messager:
logger.info('[TRACE]', 'Messager init...') logger.info('[TRACE]', 'Messager init...')
self.message = '' self.message = ''
self.tbe_builder = TbeBuilder() self.tbe_builder = TbeBuilder()
self.akg_builder = AkgBuilder()
def get_message(self): def get_message(self):
""" """
...@@ -111,12 +128,12 @@ class Messager: ...@@ -111,12 +128,12 @@ class Messager:
Communicate with remote Communicate with remote
""" """
arg = self.get_message() arg = self.get_message()
if arg == 'START': if arg == 'TBE/START':
self.send_ack() self.send_ack()
json = self.get_message() json = self.get_message()
res = self.tbe_builder.start(json) res = self.tbe_builder.start(json)
self.send_res(res) self.send_res(res)
elif arg == 'WAIT': elif arg == 'TBE/WAIT':
self.send_ack() self.send_ack()
task_id, res, pre = self.tbe_builder.wait() task_id, res, pre = self.tbe_builder.wait()
logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre)) logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre))
...@@ -132,9 +149,30 @@ class Messager: ...@@ -132,9 +149,30 @@ class Messager:
self.send_ack(False) self.send_ack(False)
self.exit() self.exit()
self.send_res(pre) self.send_res(pre)
elif arg == 'RESET': elif arg == 'TBE/RESET':
self.tbe_builder.reset() self.tbe_builder.reset()
self.send_ack() self.send_ack()
elif arg == 'AKG/START':
self.send_ack()
process_num_str = self.get_message()
self.send_ack()
wait_time_str = self.get_message()
self.akg_builder.create(int(process_num_str), int(wait_time_str))
self.send_ack()
elif arg == 'AKG/DATA':
self.send_ack()
while True:
req = self.get_message()
if req.startswith('{'):
self.akg_builder.accept_json(req)
self.send_ack()
elif req == 'AKG/WAIT':
res = self.akg_builder.compile()
self.send_res(res)
break
else:
self.send_ack(False)
break
elif arg == 'FORMAT': elif arg == 'FORMAT':
self.send_ack() self.send_ack()
json = self.get_message() json = self.get_message()
...@@ -180,7 +218,7 @@ class Messager: ...@@ -180,7 +218,7 @@ class Messager:
class Logger: class Logger:
""" """
Replace dummy 'logger' to output log as below: Replace dummy 'logger' to output log as below:
logger = Logger("remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log") logger = Logger(0, True, "remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log")
""" """
def __init__(self, level=1, dumpfile=False, filename='Logger.log'): def __init__(self, level=1, dumpfile=False, filename='Logger.log'):
""" """
...@@ -225,7 +263,7 @@ class DummyLogger: ...@@ -225,7 +263,7 @@ class DummyLogger:
def info(self, tag, msg): def info(self, tag, msg):
pass pass
logger = Logger() logger = DummyLogger()
if __name__ == '__main__': if __name__ == '__main__':
if len(sys.argv) != 3: if len(sys.argv) != 3:
......
...@@ -23,7 +23,6 @@ ...@@ -23,7 +23,6 @@
#include <unordered_set> #include <unordered_set>
#include <utility> #include <utility>
#include <vector> #include <vector>
#include <Python.h>
#include "ir/dtype.h" #include "ir/dtype.h"
#include "ir/func_graph.h" #include "ir/func_graph.h"
#include "backend/kernel_compiler/kernel.h" #include "backend/kernel_compiler/kernel.h"
...@@ -32,10 +31,10 @@ ...@@ -32,10 +31,10 @@
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h" #include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h" #include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/anf_runtime_algorithm.h" #include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
namespace mindspore { namespace mindspore {
namespace kernel { namespace kernel {
constexpr int32_t PARALLEL_ARGS_SIZE = 3;
constexpr int32_t PROCESS_NUM = 16; constexpr int32_t PROCESS_NUM = 16;
constexpr int32_t TIME_OUT = 300; constexpr int32_t TIME_OUT = 300;
...@@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type"; ...@@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type";
constexpr auto kInputDesc = "input_desc"; constexpr auto kInputDesc = "input_desc";
constexpr auto kOutputDesc = "output_desc"; constexpr auto kOutputDesc = "output_desc";
constexpr auto kTensorName = "tensor_name"; constexpr auto kTensorName = "tensor_name";
constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel";
constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler";
namespace { namespace {
void UpdateTensorNameInJson(const std::vector<AnfNodePtr> &anf_nodes, void UpdateTensorNameInJson(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) { std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
...@@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf ...@@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf
return true; return true;
} }
void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) {
MS_EXCEPTION_IF_NULL(p_args);
*p_args = PyTuple_New(PARALLEL_ARGS_SIZE);
PyObject *arg1 = PyList_New(kernel_jsons.size());
for (int i = 0; i < PyList_Size(arg1); ++i) {
PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str()));
}
PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM);
PyObject *arg3 = Py_BuildValue("i", TIME_OUT);
(void)PyTuple_SetItem(*p_args, 0, arg1);
(void)PyTuple_SetItem(*p_args, 1, arg2);
(void)PyTuple_SetItem(*p_args, 2, arg3);
}
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) { bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args); auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args);
if (jsons.empty()) { if (jsons.empty()) {
return true; return true;
} }
// Try to call python method to compile nodes parallely. // Start building in AKG
PyObject *p_module = nullptr; if (!KernelBuildClient::Instance().AkgStart(PROCESS_NUM, TIME_OUT)) {
PyObject *p_func = nullptr; MS_LOG(ERROR) << "Akg start failed.";
PyObject *p_arg = nullptr;
PyObject *p_res = nullptr;
p_module = PyImport_ImportModule(kMultiProcModule);
if (p_module == nullptr) {
MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "].";
return false; return false;
} }
if (!KernelBuildClient::Instance().AkgSendData(jsons)) {
p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc); MS_LOG(ERROR) << "Akg send data failed.";
GenParallelCompileFuncArgs(jsons, &p_arg);
MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size()
<< " Akg kernels parallelly.";
p_res = PyEval_CallObject(p_func, p_arg);
if (p_res == nullptr) {
PyErr_Print();
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
return false; return false;
} }
if (PyObject_IsTrue(p_res) != 1) { if (!KernelBuildClient::Instance().AkgWait()) {
PyErr_Print(); MS_LOG(ERROR) << "Akg compile failed.";
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
return false; return false;
} }
......
...@@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s ...@@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
} }
int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) { int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) {
return KernelBuildClient::Instance().Start(kernel_json.dump()); return KernelBuildClient::Instance().TbeStart(kernel_json.dump());
} }
bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) { bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) {
MS_EXCEPTION_IF_NULL(task_id); MS_EXCEPTION_IF_NULL(task_id);
return KernelBuildClient::Instance().Wait(task_id, task_result, pre_build_result); return KernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result);
} }
void ParallelBuildManager::ResetTaskInfo() { void ParallelBuildManager::ResetTaskInfo() {
...@@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() { ...@@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() {
} }
task_map_.clear(); task_map_.clear();
same_op_list_.clear(); same_op_list_.clear();
KernelBuildClient::Instance().Reset(); KernelBuildClient::Instance().TbeReset();
} }
} // namespace kernel } // namespace kernel
} // namespace mindspore } // namespace mindspore
...@@ -29,58 +29,106 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) { ...@@ -29,58 +29,106 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
} }
} }
int KernelBuildClient::Start(const std::string &json) { int KernelBuildClient::TbeStart(const std::string &json) {
// Start compiling.. // Start compiling..
std::string res = SendRequest(kSTART); auto res = SendRequest(kTbeStart);
if (res != kACK) { if (res != kAck) {
MS_LOG(ERROR) << "START failed, res: " << res; MS_LOG(ERROR) << "START failed, res: " << res;
return -1; return -1;
} }
// Send the json data. // Send the json data.
res = SendRequest(json); res = SendRequest(json);
if (res == kFAILED) { if (res == kFailed) {
MS_LOG(ERROR) << "START send data failed, res: " << res; MS_LOG(ERROR) << "TBE/START responds failed, res: " << res;
return -1; return -1;
} }
// Return task id. // Return task id.
return std::stoi(res); return std::stoi(res);
} }
bool KernelBuildClient::Wait(int *task_id, std::string *task_result, std::string *pre_build_result) { bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
// Start waiting.. // Start waiting..
std::string res = SendRequest(kWAIT); auto res = SendRequest(kTbeWait);
if (res != kACK) { if (res != kAck) {
MS_LOG(ERROR) << "WAIT failed, res: " << res; MS_LOG(ERROR) << "TBE/WAIT failed, res: " << res;
return false; return false;
} }
// Request task id. // Request task id.
*task_id = std::stoi(SendRequest(kCONT)); *task_id = std::stoi(SendRequest(kCont));
// Requst task result. // Requst task result.
*task_result = SendRequest(kCONT); *task_result = SendRequest(kCont);
// Request prebuild result. // Request prebuild result.
*pre_build_result = SendRequest(kCONT); *pre_build_result = SendRequest(kCont);
return true; return true;
} }
void KernelBuildClient::Reset() { void KernelBuildClient::TbeReset() {
// Start compiling.. // Start compiling..
std::string res = SendRequest(kRESET); auto res = SendRequest(kTbeReset);
if (res != kACK) { if (res != kAck) {
MS_LOG(EXCEPTION) << "RESET response is: " << res; MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res;
} }
} }
bool KernelBuildClient::AkgStart(int process_num, int wait_time) {
// Start compiling..
auto res = SendRequest(kAkgStart);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/START failed, res: " << res;
return false;
}
std::string process_num_str = std::to_string(process_num);
res = SendRequest(process_num_str);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/START(process_num) responds failed, res: " << res;
return false;
}
std::string wait_time_str = std::to_string(wait_time);
res = SendRequest(wait_time_str);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/START(wait_time) responds failed, res: " << res;
return false;
}
return true;
}
bool KernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
auto res = SendRequest(kAkgData);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/DATA failed, res: " << res;
return false;
}
for (auto &json : jsons) {
res = SendRequest(json);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/DATA.. responds failed, res: " << res << ", when sending [" << json << "]";
return false;
}
}
return true;
}
// Fetch the result of AKG compiling.
bool KernelBuildClient::AkgWait() {
auto res = SendRequest(kAkgWait);
if (res != kTrue) {
MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res;
return false;
}
return true;
}
std::string KernelBuildClient::SelectFormat(const std::string &json) { std::string KernelBuildClient::SelectFormat(const std::string &json) {
// Start compiling.. // Start compiling..
std::string res = SendRequest(kFORMAT); auto res = SendRequest(kFormat);
if (res != kACK) { if (res != kAck) {
MS_LOG(ERROR) << "FORMAT failed, res: " << res; MS_LOG(ERROR) << "FORMAT failed, res: " << res;
return ""; return "";
} }
// Send the json data. // Send the json data.
res = SendRequest(json); res = SendRequest(json);
if (res == kERR) { if (res == kErr) {
MS_LOG(ERROR) << "FORMAT send data failed, res: " << res; MS_LOG(ERROR) << "FORMAT responds failed, res: " << res;
return ""; return "";
} }
return res; return res;
...@@ -88,15 +136,15 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) { ...@@ -88,15 +136,15 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) {
bool KernelBuildClient::CheckSupported(const std::string &json) { bool KernelBuildClient::CheckSupported(const std::string &json) {
// Checking support.. // Checking support..
std::string res = SendRequest(kSUPPORT); auto res = SendRequest(kSupport);
if (res != kACK) { if (res != kAck) {
MS_LOG(ERROR) << "SUPPORT failed, res: " << res; MS_LOG(ERROR) << "SUPPORT failed, res: " << res;
return false; return false;
} }
// Send the json data. // Send the json data.
res = SendRequest(json); res = SendRequest(json);
if (res != kTRUE) { if (res != kTrue) {
MS_LOG(ERROR) << "SUPPORT send data failed, res: " << res; MS_LOG(INFO) << "SUPPORT responds failed, res: " << res;
return false; return false;
} }
return true; return true;
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ #ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_ #define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#include <vector>
#include <string> #include <string>
#include <cstring> #include <cstring>
#include <cstdlib> #include <cstdlib>
...@@ -43,23 +44,26 @@ class KernelBuildClient { ...@@ -43,23 +44,26 @@ class KernelBuildClient {
"\""; "\"";
// Receive the response from server // Receive the response from server
constexpr inline static auto kACK = "ACK"; constexpr inline static auto kAck = "ACK";
constexpr inline static auto kERR = "ERR"; constexpr inline static auto kErr = "ERR";
constexpr inline static auto kFAILED = "-1"; constexpr inline static auto kFailed = "-1";
// Send Finish request to server // Send Finish request to server
constexpr inline static auto kFIN = "FIN"; constexpr inline static auto kFin = "FIN";
// Send building request to server // Send building request to server
constexpr inline static auto kSTART = "START"; constexpr inline static auto kTbeStart = "TBE/START";
constexpr inline static auto kWAIT = "WAIT"; constexpr inline static auto kTbeWait = "TBE/WAIT";
constexpr inline static auto kCONT = "CONT"; constexpr inline static auto kCont = "CONT";
constexpr inline static auto kSUCCESS = "Success"; constexpr inline static auto kSuccess = "Success";
constexpr inline static auto kRESET = "RESET"; constexpr inline static auto kTbeReset = "TBE/RESET";
constexpr inline static auto kAkgStart = "AKG/START";
constexpr inline static auto kAkgData = "AKG/DATA";
constexpr inline static auto kAkgWait = "AKG/WAIT";
// Send server info. query to server // Send server info. query to server
constexpr inline static auto kFORMAT = "FORMAT"; constexpr inline static auto kFormat = "FORMAT";
constexpr inline static auto kSUPPORT = "SUPPORT"; constexpr inline static auto kSupport = "SUPPORT";
constexpr inline static auto kTRUE = "True"; constexpr inline static auto kTrue = "True";
// Revert \n, \r, [space]. // Revert \n, \r, [space].
constexpr inline static auto kLF = "[LF]"; constexpr inline static auto kLF = "[LF]";
...@@ -67,7 +71,7 @@ class KernelBuildClient { ...@@ -67,7 +71,7 @@ class KernelBuildClient {
constexpr inline static auto kSP = "[SP]"; constexpr inline static auto kSP = "[SP]";
// The TAG as prefix of real command from remote. // The TAG as prefix of real command from remote.
constexpr inline static auto kTAG = "[~]"; constexpr inline static auto kTag = "[~]";
constexpr inline static int kBufferSize = 4096; constexpr inline static int kBufferSize = 4096;
constexpr inline static unsigned int kTimeOutSeconds = 20; constexpr inline static unsigned int kTimeOutSeconds = 20;
...@@ -87,7 +91,7 @@ class KernelBuildClient { ...@@ -87,7 +91,7 @@ class KernelBuildClient {
std::string result; std::string result;
char buf[kBufferSize]; char buf[kBufferSize];
while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) { while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) {
if (std::strncmp(buf, kTAG, std::strlen(kTAG)) == 0) { if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) {
start = true; start = true;
} }
// Filter with 'kTAG' and '\n' // Filter with 'kTAG' and '\n'
...@@ -105,7 +109,7 @@ class KernelBuildClient { ...@@ -105,7 +109,7 @@ class KernelBuildClient {
if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) { if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) {
MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}"; MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}";
} }
result = result.substr(strlen(kTAG)); result = result.substr(strlen(kTag));
MS_LOG(DEBUG) << "result: " << result; MS_LOG(DEBUG) << "result: " << result;
return result; return result;
} }
...@@ -115,7 +119,7 @@ class KernelBuildClient { ...@@ -115,7 +119,7 @@ class KernelBuildClient {
// Exception's thrown if open failed // Exception's thrown if open failed
if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) { if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) {
dp_->SetTimeOutSeconds(kTimeOutSeconds); dp_->SetTimeOutSeconds(kTimeOutSeconds);
dp_->SetTimeOutCallback([this]() { SendRequest(kFIN); }); dp_->SetTimeOutCallback([this]() { SendRequest(kFin); });
init_ = true; init_ = true;
} }
} }
...@@ -146,13 +150,13 @@ class KernelBuildClient { ...@@ -146,13 +150,13 @@ class KernelBuildClient {
std::string res; std::string res;
*dp_ >> res; *dp_ >> res;
// Filter out the interference // Filter out the interference
auto start = res.find(kTAG); auto start = res.find(kTag);
if (start == std::string::npos) { if (start == std::string::npos) {
MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res; MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res;
} }
res = res.substr(start + std::strlen(kTAG), res.size() - start); res = res.substr(start + std::strlen(kTag), res.size() - start);
// Revert the line feed and space // Revert the line feed and space
if (res != kSUCCESS && res != kACK && res != kERR && res != kTRUE) { if (res != kSuccess && res != kAck && res != kErr && res != kTrue) {
ReplaceStr(&res, kLF, '\n'); ReplaceStr(&res, kLF, '\n');
ReplaceStr(&res, kSP, ' '); ReplaceStr(&res, kSP, ' ');
} }
...@@ -164,10 +168,15 @@ class KernelBuildClient { ...@@ -164,10 +168,15 @@ class KernelBuildClient {
std::string SelectFormat(const std::string &json); std::string SelectFormat(const std::string &json);
bool CheckSupported(const std::string &json); bool CheckSupported(const std::string &json);
// Run building. // Run TBE building.
int Start(const std::string &json); int TbeStart(const std::string &json);
bool Wait(int *task_id, std::string *task_result, std::string *pre_build_result); bool TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result);
void Reset(); void TbeReset();
// Run AKG building.
bool AkgStart(int process_num, int wait_time);
bool AkgSendData(const std::vector<std::string> &jsons);
bool AkgWait();
KernelBuildClient(const KernelBuildClient &) = delete; KernelBuildClient(const KernelBuildClient &) = delete;
KernelBuildClient &operator=(const KernelBuildClient &) = delete; KernelBuildClient &operator=(const KernelBuildClient &) = delete;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册