提交 0154bdeb 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!3935 Decouple ME and AKG for Ascend

Merge pull request !3935 from ZhangQinghua/master
......@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""
......@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""Providing multi process compile with json"""
"""akg process"""
import os
import subprocess
import sys
from multiprocessing import Pool, cpu_count
def _compile_akg_task(*json_strs):
"""
compile func called in single process
......@@ -34,38 +33,56 @@ def _compile_akg_task(*json_strs):
if res.returncode != 0:
raise ValueError("Failed, args: {}!".format(json_str))
def compile_akg_kernel_parallel(json_infos, process, waitime):
def create_akg_parallel_process(process_num, wait_time):
"""
compile kernel use multi processes
Parameters:
json_infos: list. list contain kernel info(task id and json str)
process: int. processes num
waittime: int. max time the function blocked
create AkgProcess object
Returns:
True for all compile success, False for some failed.
AkgProcess
"""
if not isinstance(json_infos, list):
raise ValueError("json_infos must be a list")
if not isinstance(process, int):
raise ValueError("process must be a num")
if not isinstance(waitime, int):
raise ValueError("waittime must be a num")
return AkgProcess(process_num, wait_time)
if process == 0 and json_infos:
process = 1
class AkgProcess:
"""akg kernel parallel process"""
cpu_proc_num = cpu_count()
max_proc_num = 16
process = min([cpu_proc_num, max_proc_num, process])
def __init__(self, process_num, wait_time):
"""
Args:
process_num: int. processes number
wait_time: int. max time the function blocked
"""
if not isinstance(process_num, int):
raise ValueError("process number must be a num")
if not isinstance(wait_time, int):
raise ValueError("wait time must be a num")
if process_num == 0:
process_num = 1
max_proc_num = 16
self.process_num = min([cpu_count(), max_proc_num, process_num])
self.args = [[] for _ in range(self.process_num)]
self.wait_time = wait_time
self.argc = 0
args = [[] for _ in range(process)]
for p, info in enumerate(json_infos):
args[p % process].append(info)
def compile(self):
"""
compile kernel by multi processes
Return:
True for all compile success, False for some failed.
"""
if self.argc == 0:
raise ValueError("json must be not null")
with Pool(processes=self.process_num) as pool:
res = pool.starmap_async(_compile_akg_task, self.args)
res.get(timeout=self.wait_time)
return True
with Pool(processes=process) as pool:
res = pool.starmap_async(_compile_akg_task, args)
res.get(timeout=waitime)
return True
def accept_json(self, json):
"""
accept json data before compile
Args:
json: str. kernel info.
"""
if not isinstance(json, str):
raise ValueError("json must be a str")
self.args[self.argc % self.process_num].append(json)
self.argc += 1
......@@ -22,14 +22,14 @@ import json
from .common import check_kernel_info, TBEException
from .helper import _op_select_format, _check_supported
def create_tbe_parallel_compiler():
def create_tbe_parallel_process():
"""
create TbeProcess object
Returns:
TbeProcess
"""
return compile_pool
return tbe_process
def op_select_format(op_json: str):
"""
......@@ -98,8 +98,8 @@ def run_compiler(op_json):
except subprocess.CalledProcessError as e:
return "TBEException", "PreCompileProcessFailed:\n" + e.stdout + "\n" + e.stderr + "\ninput_args: " + op_json
class CompilerPool:
"""compiler pool"""
class TbeProcess:
"""tbe process"""
def __init__(self):
self.__processe_num = multiprocessing.cpu_count()
......@@ -168,5 +168,4 @@ class CompilerPool:
if self.__running_tasks:
self.__running_tasks.clear()
compile_pool = CompilerPool()
tbe_process = TbeProcess()
......@@ -16,13 +16,14 @@
import os
import sys
import time
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_compiler, op_select_format, check_supported
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class TbeBuilder:
"""Tbe building wrapper"""
def __init__(self):
self.tbe_builder = create_tbe_parallel_compiler()
self.tbe_builder = create_tbe_parallel_process()
def start(self, json):
return self.tbe_builder.start_compile_op(json)
......@@ -36,6 +37,21 @@ class TbeBuilder:
def exit(self):
self.tbe_builder.exit()
class AkgBuilder:
    """Akg building wrapper.

    Holds the object returned by ``create_akg_parallel_process`` and forwards
    ``accept_json`` / ``compile`` calls to it. The wrapped object only exists
    after ``create`` has been called.
    """

    def __init__(self):
        # Declared here so the attribute always exists; ``create`` installs the
        # real process object.
        self.akg_builder = None

    def _require_builder(self):
        """Return the wrapped process, failing clearly if ``create`` was never called."""
        if self.akg_builder is None:
            # AttributeError is the exception type callers already got when the
            # attribute was missing; only the message becomes explicit.
            raise AttributeError("AkgBuilder.create() must be called before use")
        return self.akg_builder

    def create(self, process_num, waitime):
        """Create the wrapped process object.

        Args:
            process_num: forwarded to ``create_akg_parallel_process``.
            waitime: forwarded to ``create_akg_parallel_process``.
        """
        self.akg_builder = create_akg_parallel_process(process_num, waitime)

    def accept_json(self, json):
        """Forward one kernel json to the wrapped process and return its result.

        Raises:
            AttributeError: if ``create`` has not been called yet.
        """
        return self._require_builder().accept_json(json)

    def compile(self):
        """Run ``compile`` on the wrapped process and return its result.

        Raises:
            AttributeError: if ``create`` has not been called yet.
        """
        return self._require_builder().compile()
class Messager:
'''Messager'''
......@@ -43,6 +59,7 @@ class Messager:
logger.info('[TRACE]', 'Messager init...')
self.message = ''
self.tbe_builder = TbeBuilder()
self.akg_builder = AkgBuilder()
def get_message(self):
"""
......@@ -111,12 +128,12 @@ class Messager:
Communicate with remote
"""
arg = self.get_message()
if arg == 'START':
if arg == 'TBE/START':
self.send_ack()
json = self.get_message()
res = self.tbe_builder.start(json)
self.send_res(res)
elif arg == 'WAIT':
elif arg == 'TBE/WAIT':
self.send_ack()
task_id, res, pre = self.tbe_builder.wait()
logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre))
......@@ -132,9 +149,30 @@ class Messager:
self.send_ack(False)
self.exit()
self.send_res(pre)
elif arg == 'RESET':
elif arg == 'TBE/RESET':
self.tbe_builder.reset()
self.send_ack()
elif arg == 'AKG/START':
self.send_ack()
process_num_str = self.get_message()
self.send_ack()
wait_time_str = self.get_message()
self.akg_builder.create(int(process_num_str), int(wait_time_str))
self.send_ack()
elif arg == 'AKG/DATA':
self.send_ack()
while True:
req = self.get_message()
if req.startswith('{'):
self.akg_builder.accept_json(req)
self.send_ack()
elif req == 'AKG/WAIT':
res = self.akg_builder.compile()
self.send_res(res)
break
else:
self.send_ack(False)
break
elif arg == 'FORMAT':
self.send_ack()
json = self.get_message()
......@@ -180,7 +218,7 @@ class Messager:
class Logger:
"""
Replace dummy 'logger' to output log as below:
logger = Logger("remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log")
logger = Logger(0, True, "remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log")
"""
def __init__(self, level=1, dumpfile=False, filename='Logger.log'):
"""
......@@ -225,7 +263,7 @@ class DummyLogger:
def info(self, tag, msg):
pass
logger = Logger()
logger = DummyLogger()
if __name__ == '__main__':
if len(sys.argv) != 3:
......
......@@ -23,7 +23,6 @@
#include <unordered_set>
#include <utility>
#include <vector>
#include <Python.h>
#include "ir/dtype.h"
#include "ir/func_graph.h"
#include "backend/kernel_compiler/kernel.h"
......@@ -32,10 +31,10 @@
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
namespace mindspore {
namespace kernel {
constexpr int32_t PARALLEL_ARGS_SIZE = 3;
constexpr int32_t PROCESS_NUM = 16;
constexpr int32_t TIME_OUT = 300;
......@@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type";
constexpr auto kInputDesc = "input_desc";
constexpr auto kOutputDesc = "output_desc";
constexpr auto kTensorName = "tensor_name";
constexpr auto kCompileAkgKernelParallelFunc = "compile_akg_kernel_parallel";
constexpr auto kMultiProcModule = "mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler";
namespace {
void UpdateTensorNameInJson(const std::vector<AnfNodePtr> &anf_nodes,
std::map<AnfNodePtr, nlohmann::json> *node_json_map) {
......@@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf
return true;
}
void GenParallelCompileFuncArgs(const std::vector<std::string> &kernel_jsons, PyObject **p_args) {
MS_EXCEPTION_IF_NULL(p_args);
*p_args = PyTuple_New(PARALLEL_ARGS_SIZE);
PyObject *arg1 = PyList_New(kernel_jsons.size());
for (int i = 0; i < PyList_Size(arg1); ++i) {
PyList_SetItem(arg1, i, Py_BuildValue("s", kernel_jsons[i].c_str()));
}
PyObject *arg2 = Py_BuildValue("i", PROCESS_NUM);
PyObject *arg3 = Py_BuildValue("i", TIME_OUT);
(void)PyTuple_SetItem(*p_args, 0, arg1);
(void)PyTuple_SetItem(*p_args, 1, arg2);
(void)PyTuple_SetItem(*p_args, 2, arg3);
}
bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfNodePtr>> &build_args) {
auto [jsons, repeat_nodes] = PreProcessJsonForBuild(build_args);
if (jsons.empty()) {
return true;
}
// Try to call python method to compile nodes parallely.
PyObject *p_module = nullptr;
PyObject *p_func = nullptr;
PyObject *p_arg = nullptr;
PyObject *p_res = nullptr;
p_module = PyImport_ImportModule(kMultiProcModule);
if (p_module == nullptr) {
MS_LOG(ERROR) << "Failed to import [" << kMultiProcModule << "].";
// Start building in AKG
if (!KernelBuildClient::Instance().AkgStart(PROCESS_NUM, TIME_OUT)) {
MS_LOG(ERROR) << "Akg start failed.";
return false;
}
p_func = PyObject_GetAttrString(p_module, kCompileAkgKernelParallelFunc);
GenParallelCompileFuncArgs(jsons, &p_arg);
MS_LOG(DEBUG) << "Call function [" << kCompileAkgKernelParallelFunc << "], try to compile " << jsons.size()
<< " Akg kernels parallelly.";
p_res = PyEval_CallObject(p_func, p_arg);
if (p_res == nullptr) {
PyErr_Print();
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
if (!KernelBuildClient::Instance().AkgSendData(jsons)) {
MS_LOG(ERROR) << "Akg send data failed.";
return false;
}
if (PyObject_IsTrue(p_res) != 1) {
PyErr_Print();
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileAkgKernelParallelFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(p_arg) << ").";
if (!KernelBuildClient::Instance().AkgWait()) {
MS_LOG(ERROR) << "Akg compile failed.";
return false;
}
......
......@@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
}
int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) {
return KernelBuildClient::Instance().Start(kernel_json.dump());
return KernelBuildClient::Instance().TbeStart(kernel_json.dump());
}
bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) {
MS_EXCEPTION_IF_NULL(task_id);
return KernelBuildClient::Instance().Wait(task_id, task_result, pre_build_result);
return KernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result);
}
void ParallelBuildManager::ResetTaskInfo() {
......@@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() {
}
task_map_.clear();
same_op_list_.clear();
KernelBuildClient::Instance().Reset();
KernelBuildClient::Instance().TbeReset();
}
} // namespace kernel
} // namespace mindspore
......@@ -29,58 +29,106 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
}
}
int KernelBuildClient::Start(const std::string &json) {
int KernelBuildClient::TbeStart(const std::string &json) {
// Start compiling..
std::string res = SendRequest(kSTART);
if (res != kACK) {
auto res = SendRequest(kTbeStart);
if (res != kAck) {
MS_LOG(ERROR) << "START failed, res: " << res;
return -1;
}
// Send the json data.
res = SendRequest(json);
if (res == kFAILED) {
MS_LOG(ERROR) << "START send data failed, res: " << res;
if (res == kFailed) {
MS_LOG(ERROR) << "TBE/START responds failed, res: " << res;
return -1;
}
// Return task id.
return std::stoi(res);
}
bool KernelBuildClient::Wait(int *task_id, std::string *task_result, std::string *pre_build_result) {
bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
// Start waiting..
std::string res = SendRequest(kWAIT);
if (res != kACK) {
MS_LOG(ERROR) << "WAIT failed, res: " << res;
auto res = SendRequest(kTbeWait);
if (res != kAck) {
MS_LOG(ERROR) << "TBE/WAIT failed, res: " << res;
return false;
}
// Request task id.
*task_id = std::stoi(SendRequest(kCONT));
*task_id = std::stoi(SendRequest(kCont));
// Request task result.
*task_result = SendRequest(kCONT);
*task_result = SendRequest(kCont);
// Request prebuild result.
*pre_build_result = SendRequest(kCONT);
*pre_build_result = SendRequest(kCont);
return true;
}
void KernelBuildClient::Reset() {
void KernelBuildClient::TbeReset() {
// Start compiling..
std::string res = SendRequest(kRESET);
if (res != kACK) {
MS_LOG(EXCEPTION) << "RESET response is: " << res;
auto res = SendRequest(kTbeReset);
if (res != kAck) {
MS_LOG(EXCEPTION) << "TBE/RESET response is: " << res;
}
}
bool KernelBuildClient::AkgStart(int process_num, int wait_time) {
// Start compiling..
auto res = SendRequest(kAkgStart);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/START failed, res: " << res;
return false;
}
std::string process_num_str = std::to_string(process_num);
res = SendRequest(process_num_str);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/START(process_num) responds failed, res: " << res;
return false;
}
std::string wait_time_str = std::to_string(wait_time);
res = SendRequest(wait_time_str);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/START(wait_time) responds failed, res: " << res;
return false;
}
return true;
}
bool KernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
auto res = SendRequest(kAkgData);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/DATA failed, res: " << res;
return false;
}
for (auto &json : jsons) {
res = SendRequest(json);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/DATA.. responds failed, res: " << res << ", when sending [" << json << "]";
return false;
}
}
return true;
}
// Fetch the result of AKG compiling.
bool KernelBuildClient::AkgWait() {
auto res = SendRequest(kAkgWait);
if (res != kTrue) {
MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res;
return false;
}
return true;
}
std::string KernelBuildClient::SelectFormat(const std::string &json) {
// Start compiling..
std::string res = SendRequest(kFORMAT);
if (res != kACK) {
auto res = SendRequest(kFormat);
if (res != kAck) {
MS_LOG(ERROR) << "FORMAT failed, res: " << res;
return "";
}
// Send the json data.
res = SendRequest(json);
if (res == kERR) {
MS_LOG(ERROR) << "FORMAT send data failed, res: " << res;
if (res == kErr) {
MS_LOG(ERROR) << "FORMAT responds failed, res: " << res;
return "";
}
return res;
......@@ -88,15 +136,15 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) {
bool KernelBuildClient::CheckSupported(const std::string &json) {
// Checking support..
std::string res = SendRequest(kSUPPORT);
if (res != kACK) {
auto res = SendRequest(kSupport);
if (res != kAck) {
MS_LOG(ERROR) << "SUPPORT failed, res: " << res;
return false;
}
// Send the json data.
res = SendRequest(json);
if (res != kTRUE) {
MS_LOG(ERROR) << "SUPPORT send data failed, res: " << res;
if (res != kTrue) {
MS_LOG(INFO) << "SUPPORT responds failed, res: " << res;
return false;
}
return true;
......
......@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#include <vector>
#include <string>
#include <cstring>
#include <cstdlib>
......@@ -43,23 +44,26 @@ class KernelBuildClient {
"\"";
// Receive the response from server
constexpr inline static auto kACK = "ACK";
constexpr inline static auto kERR = "ERR";
constexpr inline static auto kFAILED = "-1";
constexpr inline static auto kAck = "ACK";
constexpr inline static auto kErr = "ERR";
constexpr inline static auto kFailed = "-1";
// Send Finish request to server
constexpr inline static auto kFIN = "FIN";
constexpr inline static auto kFin = "FIN";
// Send building request to server
constexpr inline static auto kSTART = "START";
constexpr inline static auto kWAIT = "WAIT";
constexpr inline static auto kCONT = "CONT";
constexpr inline static auto kSUCCESS = "Success";
constexpr inline static auto kRESET = "RESET";
constexpr inline static auto kTbeStart = "TBE/START";
constexpr inline static auto kTbeWait = "TBE/WAIT";
constexpr inline static auto kCont = "CONT";
constexpr inline static auto kSuccess = "Success";
constexpr inline static auto kTbeReset = "TBE/RESET";
constexpr inline static auto kAkgStart = "AKG/START";
constexpr inline static auto kAkgData = "AKG/DATA";
constexpr inline static auto kAkgWait = "AKG/WAIT";
// Send server info. query to server
constexpr inline static auto kFORMAT = "FORMAT";
constexpr inline static auto kSUPPORT = "SUPPORT";
constexpr inline static auto kTRUE = "True";
constexpr inline static auto kFormat = "FORMAT";
constexpr inline static auto kSupport = "SUPPORT";
constexpr inline static auto kTrue = "True";
// Revert \n, \r, [space].
constexpr inline static auto kLF = "[LF]";
......@@ -67,7 +71,7 @@ class KernelBuildClient {
constexpr inline static auto kSP = "[SP]";
// The TAG as prefix of real command from remote.
constexpr inline static auto kTAG = "[~]";
constexpr inline static auto kTag = "[~]";
constexpr inline static int kBufferSize = 4096;
constexpr inline static unsigned int kTimeOutSeconds = 20;
......@@ -87,7 +91,7 @@ class KernelBuildClient {
std::string result;
char buf[kBufferSize];
while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) {
if (std::strncmp(buf, kTAG, std::strlen(kTAG)) == 0) {
if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) {
start = true;
}
// Filter with 'kTag' and '\n'
......@@ -105,7 +109,7 @@ class KernelBuildClient {
if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) {
MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}";
}
result = result.substr(strlen(kTAG));
result = result.substr(strlen(kTag));
MS_LOG(DEBUG) << "result: " << result;
return result;
}
......@@ -115,7 +119,7 @@ class KernelBuildClient {
// Exception's thrown if open failed
if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) {
dp_->SetTimeOutSeconds(kTimeOutSeconds);
dp_->SetTimeOutCallback([this]() { SendRequest(kFIN); });
dp_->SetTimeOutCallback([this]() { SendRequest(kFin); });
init_ = true;
}
}
......@@ -146,13 +150,13 @@ class KernelBuildClient {
std::string res;
*dp_ >> res;
// Filter out the interference
auto start = res.find(kTAG);
auto start = res.find(kTag);
if (start == std::string::npos) {
MS_LOG(EXCEPTION) << "Response seems incorrect, res: " << res;
}
res = res.substr(start + std::strlen(kTAG), res.size() - start);
res = res.substr(start + std::strlen(kTag), res.size() - start);
// Revert the line feed and space
if (res != kSUCCESS && res != kACK && res != kERR && res != kTRUE) {
if (res != kSuccess && res != kAck && res != kErr && res != kTrue) {
ReplaceStr(&res, kLF, '\n');
ReplaceStr(&res, kSP, ' ');
}
......@@ -164,10 +168,15 @@ class KernelBuildClient {
std::string SelectFormat(const std::string &json);
bool CheckSupported(const std::string &json);
// Run building.
int Start(const std::string &json);
bool Wait(int *task_id, std::string *task_result, std::string *pre_build_result);
void Reset();
// Run TBE building.
int TbeStart(const std::string &json);
bool TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result);
void TbeReset();
// Run AKG building.
bool AkgStart(int process_num, int wait_time);
bool AkgSendData(const std::vector<std::string> &jsons);
bool AkgWait();
KernelBuildClient(const KernelBuildClient &) = delete;
KernelBuildClient &operator=(const KernelBuildClient &) = delete;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册