提交 22e0a0ba 编写于 作者: Z Zhang Qinghua

Decouple ME and AKG for GPU.

上级 dd68fb16
......@@ -14,52 +14,21 @@
# ============================================================================
"""kernel build server"""
import os
import sys
import time
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class TbeBuilder:
    """Facade that forwards build requests to the TBE parallel-compile process."""

    def __init__(self):
        # Object returned by the factory; every method below delegates to it.
        self.tbe_builder = create_tbe_parallel_process()

    def start(self, json):
        """Pass one op description to ``start_compile_op`` and return its result."""
        backend = self.tbe_builder
        return backend.start_compile_op(json)

    def wait(self):
        """Return whatever ``wait_one`` reports."""
        backend = self.tbe_builder
        return backend.wait_one()

    def reset(self):
        """Delegate to ``reset_task_info`` on the wrapped object."""
        backend = self.tbe_builder
        backend.reset_task_info()

    def exit(self):
        """Delegate to ``exit`` on the wrapped object."""
        backend = self.tbe_builder
        backend.exit()
class AkgBuilder:
    """Facade over the AKG parallel-compile process, created lazily by ``create``."""

    def __init__(self):
        # Nothing is set up here; ``akg_builder`` only exists after ``create``.
        pass

    def create(self, process_num, waitime):
        """Build the wrapped process object from the given process count and wait time."""
        process = create_akg_parallel_process(process_num, waitime)
        self.akg_builder = process

    def accept_json(self, json):
        """Hand one json description to the wrapped object and return its answer."""
        backend = self.akg_builder
        return backend.accept_json(json)

    def compile(self):
        """Run ``compile`` on the wrapped object and return its result."""
        backend = self.akg_builder
        return backend.compile()
class Messager:
'''Messager'''
def __init__(self):
logger.info('[TRACE]', 'Messager init...')
def __init__(self, fdin, fdout):
self.fdin = fdin
self.fdout = fdout
self.fin = os.fdopen(fdin, "r")
self.fout = os.fdopen(fdout, "w")
self.message = ''
self.tbe_builder = TbeBuilder()
self.akg_builder = AkgBuilder()
def __del__(self):
os.close(self.fdin)
os.close(self.fdout)
def get_message(self):
"""
......@@ -72,7 +41,7 @@ class Messager:
# Not read by input() anymore
res = self.fin.readline()
if not res:
logger.info('[TRACE]', "read <empty>")
logger.debug('[TRACE]', "read nothing...")
self.exit()
if res[len(res) - 1] == '\n':
res = res[0:len(res)-1]
......@@ -82,7 +51,7 @@ class Messager:
self.exit()
finally:
pass
if self.message == '' or self.message == 'FIN':
if self.message == '' or self.message == 'FINISH':
self.send_ack()
self.exit()
return self.message
......@@ -123,76 +92,6 @@ class Messager:
else:
self.send_res('ERR')
def handle(self):
"""
Communicate with remote
"""
arg = self.get_message()
if arg == 'TBE/START':
self.send_ack()
json = self.get_message()
res = self.tbe_builder.start(json)
self.send_res(res)
elif arg == 'TBE/WAIT':
self.send_ack()
task_id, res, pre = self.tbe_builder.wait()
logger.debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre))
if self.get_message() != 'CONT':
self.send_ack(False)
self.exit()
self.send_res(task_id)
if self.get_message() != 'CONT':
self.send_ack(False)
self.exit()
self.send_res(res)
if self.get_message() != 'CONT':
self.send_ack(False)
self.exit()
self.send_res(pre)
elif arg == 'TBE/RESET':
self.tbe_builder.reset()
self.send_ack()
elif arg == 'AKG/START':
self.send_ack()
process_num_str = self.get_message()
self.send_ack()
wait_time_str = self.get_message()
self.akg_builder.create(int(process_num_str), int(wait_time_str))
self.send_ack()
elif arg == 'AKG/DATA':
self.send_ack()
while True:
req = self.get_message()
if req.startswith('{'):
self.akg_builder.accept_json(req)
self.send_ack()
elif req == 'AKG/WAIT':
res = self.akg_builder.compile()
self.send_res(res)
break
else:
self.send_ack(False)
break
elif arg == 'FORMAT':
self.send_ack()
json = self.get_message()
self.send_res(op_select_format(json))
elif arg == 'SUPPORT':
self.send_ack()
json = self.get_message()
logger.debug('[SUPPORT]', json)
try:
res = check_supported(json)
except json.decoder.JSONDecodeError:
self.send_ack(False)
self.exit()
finally:
pass
self.send_res(res)
else:
self.send_ack(False)
self.exit()
def loop(self):
"""
Messaging loop
......@@ -200,20 +99,26 @@ class Messager:
while True:
self.handle()
def run(self):
self.loop()
def handle(self):
"""
An interface that communicates with the remote side.
Note:
All subclasses should override this interface.
"""
raise NotImplementedError
def exit(self):
os.close(self.fdin)
os.close(self.fdout)
self.tbe_builder.reset()
self.tbe_builder.exit()
logger.info('[TRACE]', 'Messager Exit...')
exit()
"""
An interface that handles the procedure before exit.
def run(self, fdin, fdout):
self.fdin = fdin
self.fdout = fdout
self.fin = os.fdopen(fdin, "r")
self.fout = os.fdopen(fdout, "w")
self.loop()
Note:
All subclasses should override this interface.
"""
raise NotImplementedError
class Logger:
"""
......@@ -265,9 +170,5 @@ class DummyLogger:
logger = DummyLogger()
if __name__ == '__main__':
if len(sys.argv) != 3:
raise Exception('Incorrect argv: {}'.format(sys.argv))
logger.debug('[TRACE]', 'argv: ' + str(sys.argv))
messager = Messager()
messager.run(int(sys.argv[1]), int(sys.argv[2]))
def get_logger():
    """Return the module-level ``logger`` object shared by the build servers."""
    return logger
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""kernel build server for ascend"""
import sys
from mindspore._extends.remote.kernel_build_server import Messager, get_logger
from mindspore._extends.parallel_compile.tbe_compiler.tbe_process import create_tbe_parallel_process, op_select_format, check_supported
from mindspore._extends.parallel_compile.akg_compiler.akg_process import create_akg_parallel_process
class TbeBuilder:
    """Facade that forwards build requests to the TBE parallel-compile process."""

    def __init__(self):
        # Object returned by the factory; every method below delegates to it.
        self.tbe_builder = create_tbe_parallel_process()

    def start(self, json):
        """Pass one op description to ``start_compile_op`` and return its result."""
        backend = self.tbe_builder
        return backend.start_compile_op(json)

    def wait(self):
        """Return whatever ``wait_one`` reports."""
        backend = self.tbe_builder
        return backend.wait_one()

    def reset(self):
        """Delegate to ``reset_task_info`` on the wrapped object."""
        backend = self.tbe_builder
        backend.reset_task_info()

    def exit(self):
        """Delegate to ``exit`` on the wrapped object."""
        backend = self.tbe_builder
        backend.exit()
class AkgBuilder:
    """Facade over the AKG parallel-compile process, created lazily by ``create``."""

    def __init__(self):
        # Nothing is set up here; ``akg_builder`` only exists after ``create``.
        pass

    def create(self, process_num, waitime):
        """Build the wrapped process object from the given process count and wait time."""
        process = create_akg_parallel_process(process_num, waitime)
        self.akg_builder = process

    def accept_json(self, json):
        """Hand one json description to the wrapped object and return its answer."""
        backend = self.akg_builder
        return backend.accept_json(json)

    def compile(self):
        """Run ``compile`` on the wrapped object and return its result."""
        backend = self.akg_builder
        return backend.compile()
class AscendMessager(Messager):
    '''
    Ascend Messager
    It works as a server, communicating with c++ client.
    '''

    def __init__(self, fdin, fdout):
        """Initialise the base messager on the two descriptors and create both builders."""
        super().__init__(fdin, fdout)
        get_logger().info('[TRACE]', 'Ascend Messager init...')
        self.tbe_builder = TbeBuilder()
        self.akg_builder = AkgBuilder()

    def handle(self):
        """
        Communicate with remote client.
        Reference protocol between them at PR#3821 and PR#3935

        Reads one command with ``get_message`` and dispatches on it:

        - ``TBE/START``: ack, read a json message, start the TBE build and
          send back the result of ``start``.
        - ``TBE/WAIT``: ack, wait for one TBE task, then send the task id,
          the result and the pre-build result one by one, each only after the
          client sent ``CONTINUE``.
        - ``TBE/RESET``: reset the TBE builder, then ack.
        - ``AKG/START``: ack, read the process number and the wait time (each
          converted with ``int``), create the AKG builder, then ack.
        - ``AKG/DATA``: ack, then accept json messages (those starting with
          ``{``) until ``AKG/WAIT`` arrives, at which point the compile result
          is sent; any other message gets a negative ack and ends the exchange.
        - ``FORMAT``: ack, read a json message and send ``op_select_format``'s result.
        - ``SUPPORT``: ack, read a json message and send ``check_supported``'s result.

        Any other command gets a negative ack and terminates the server.
        """
        # Imported under its own name because the file does not import ``json``;
        # the original code referenced ``json.decoder`` on a local *string*
        # called ``json``, which could never resolve to the exception class.
        from json import JSONDecodeError

        arg = self.get_message()
        if arg == 'TBE/START':
            self.send_ack()
            op_json = self.get_message()
            res = self.tbe_builder.start(op_json)
            self.send_res(res)
        elif arg == 'TBE/WAIT':
            self.send_ack()
            task_id, res, pre = self.tbe_builder.wait()
            get_logger().debug('[TRACE]', str(task_id) + '/' + str(res) + '/' + str(pre))
            # Each of the three values is released only after an explicit CONTINUE.
            for item in (task_id, res, pre):
                if self.get_message() != 'CONTINUE':
                    self.send_ack(False)
                    self.exit()
                self.send_res(item)
        elif arg == 'TBE/RESET':
            self.tbe_builder.reset()
            self.send_ack()
        elif arg == 'AKG/START':
            self.send_ack()
            process_num_str = self.get_message()
            self.send_ack()
            wait_time_str = self.get_message()
            self.akg_builder.create(int(process_num_str), int(wait_time_str))
            self.send_ack()
        elif arg == 'AKG/DATA':
            self.send_ack()
            while True:
                req = self.get_message()
                if req.startswith('{'):
                    self.akg_builder.accept_json(req)
                    self.send_ack()
                elif req == 'AKG/WAIT':
                    res = self.akg_builder.compile()
                    self.send_res(res)
                    break
                else:
                    self.send_ack(False)
                    break
        elif arg == 'FORMAT':
            self.send_ack()
            op_json = self.get_message()
            self.send_res(op_select_format(op_json))
        elif arg == 'SUPPORT':
            self.send_ack()
            op_json = self.get_message()
            get_logger().debug('[SUPPORT]', op_json)
            try:
                res = check_supported(op_json)
            except JSONDecodeError:
                self.send_ack(False)
                self.exit()
            else:
                self.send_res(res)
        else:
            self.send_ack(False)
            self.exit()

    def exit(self):
        """Reset and shut down the TBE builder, log, then terminate the process.

        Raises:
            SystemExit: always, via ``sys.exit()``.
        """
        self.tbe_builder.reset()
        self.tbe_builder.exit()
        get_logger().info('[TRACE]', 'Ascend Messager Exit...')
        # ``sys.exit`` instead of the ``exit`` helper that the ``site`` module
        # injects, which is meant for interactive use and may be absent.
        sys.exit()
if __name__ == '__main__':
    # Exactly two extra arguments are expected; both are converted with int()
    # and passed to the messager as fdin and fdout.
    argv = sys.argv
    if len(argv) != 3:
        raise Exception('Incorrect argv: {}'.format(argv))
    get_logger().debug('[TRACE]', 'argv: ' + str(argv))
    fd_in = int(argv[1])
    fd_out = int(argv[2])
    messager = AscendMessager(fd_in, fd_out)
    messager.run()
# Copyright 2020 Huawei Technologies Co., Ltd
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ============================================================================
"""kernel build server for gpu"""
import os
import sys
from mindspore._extends.remote.kernel_build_server import Messager, get_logger
from mindspore._extends.parallel_compile.akg_compiler.compiler import run_compiler as akg_compile_single
class GpuMessager(Messager):
    '''
    GPU Messager
    It works as a server, communicating with c++ client.
    '''

    def __init__(self, fdin, fdout):
        """Initialise the base messager on the two descriptors and log the start-up."""
        super().__init__(fdin, fdout)
        get_logger().info('[TRACE]', 'GPU Messager init...')

    def handle(self):
        """
        Communicate with remote client.
        Reference protocol between them at PR#4063

        Reads one command with ``get_message`` and dispatches on it:

        - ``AKG/PID``: send this process' pid (``os.getpid()``) as the result.
        - ``AKG/COMPILE``: ack, read a json message and compile it with
          ``akg_compile_single``; ack on success, negative ack followed by
          ``exit`` when the compiler raises ``ValueError``.

        Any other command gets a negative ack and terminates the server.
        """
        arg = self.get_message()
        if arg == 'AKG/PID':
            self.send_res(os.getpid())
        elif arg == 'AKG/COMPILE':
            self.send_ack()
            op_json = self.get_message()
            try:
                akg_compile_single(op_json)
            except ValueError:
                self.send_ack(False)
                self.exit()
            else:
                # Positive ack only when the compile call returned normally.
                self.send_ack()
        else:
            self.send_ack(False)
            self.exit()

    def exit(self):
        """Log the shutdown and terminate the process.

        Raises:
            SystemExit: always, via ``sys.exit()``.
        """
        get_logger().info('[TRACE]', 'GPU Messager Exit...')
        # ``sys.exit`` instead of the ``exit`` helper that the ``site`` module
        # injects, which is meant for interactive use and may be absent.
        sys.exit()
if __name__ == '__main__':
    # Exactly two extra arguments are expected; both are converted with int()
    # and passed to the messager as fdin and fdout.
    argv = sys.argv
    if len(argv) != 3:
        raise Exception('Incorrect argv: {}'.format(argv))
    get_logger().debug('[TRACE]', 'argv: ' + str(argv))
    fd_in = int(argv[1])
    fd_out = int(argv[2])
    messager = GpuMessager(fd_in, fd_out)
    messager.run()
......@@ -15,7 +15,6 @@
*/
#include "backend/kernel_compiler/akg/akg_kernel_build.h"
#include <Python.h>
#include <sys/types.h>
#include <signal.h>
#include <unistd.h>
......@@ -37,13 +36,10 @@
#include "utils/utils.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/kernel_build_client.h"
namespace mindspore {
namespace kernel {
constexpr int ME_MAX_KERNEL_NAME_LENGTH = 200;
constexpr int32_t ARGS_SIZE = 1;
constexpr auto kCompileWithJsonFunc = "compilewithjson";
// json key
constexpr auto kOpDesc = "op_desc";
constexpr auto kInputDesc = "input_desc";
......@@ -70,25 +66,6 @@ std::string Vector2Str(const std::vector<T> &inputs) {
}
} // namespace
// Converts a Python object to std::string through str(obj).
// Returns an empty string (after logging an error) when PyObj is null or when
// no character data could be obtained from the string form.
std::string AkgKernelBuild::PyObjectToStr(PyObject *const PyObj) {
  std::string str_res;
  if (PyObj == nullptr) {
    MS_LOG(ERROR) << "Input parameter is nullptr.";
    return str_res;
  }
  // PyObject_Str hands back a new reference; it has to be released on every path below.
  PyObject *strArgs = PyObject_Str(PyObj);
  char *pChar = nullptr;
  if (strArgs != nullptr) {
    (void)PyArg_Parse(strArgs, "s", &pChar);
  }
  if (pChar == nullptr) {
    MS_LOG(ERROR) << "pChar is nullptr.";
    Py_XDECREF(strArgs);
    return str_res;
  }
  // Copy first: pChar points into storage owned by strArgs, so the reference
  // may only be dropped once the characters live in str_res.
  str_res = pChar;
  Py_DECREF(strArgs);
  return str_res;
}
std::string GetTensorName(const nlohmann::json &node_json, const std::string &tag,
const std::pair<size_t, size_t> &position) {
if (node_json.count(tag) == 0) {
......@@ -528,32 +505,11 @@ KernelPackPtr AkgKernelBuild::OpBuild(const std::string &node_json, const AnfNod
return cached_kernel_pack;
}
PyObject *pModule = nullptr;
PyObject *pFunc = nullptr;
PyObject *pArg = nullptr;
PyObject *pRes = nullptr;
pModule = PyImport_ImportModule(kAkgModule);
if (pModule == nullptr) {
MS_LOG(ERROR) << "Failed to import [" << kAkgModule << "].";
return nullptr;
}
pFunc = PyObject_GetAttrString(pModule, kCompileWithJsonFunc);
pArg = PyTuple_New(ARGS_SIZE);
(void)PyTuple_SetItem(pArg, 0, Py_BuildValue("s", node_json.c_str()));
(void)alarm(AUTODIFF_COMPILE_OVERTIME);
pRes = PyEval_CallObject(pFunc, pArg);
auto res = GpuKernelBuildClient::Instance().AkgCompileSingle(node_json);
(void)alarm(0);
if (pRes == nullptr) {
MS_LOG(ERROR) << "No ret got, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(pArg) << ").";
return nullptr;
}
if (PyObject_IsTrue(pRes) != 1) {
MS_LOG(ERROR) << "Illegal ret, failed to call function [" << kCompileWithJsonFunc << "], args:\n("
<< AkgKernelBuild::PyObjectToStr(pArg) << ").";
if (!res) {
MS_LOG(ERROR) << "Akg compile failed, json: " << node_json;
return nullptr;
}
......
......@@ -324,15 +324,15 @@ bool AkgOpParallelBuild(const std::vector<std::pair<AkgAscendKernelBuilder, AnfN
}
// Start building in AKG
if (!KernelBuildClient::Instance().AkgStart(PROCESS_NUM, TIME_OUT)) {
if (!AscendKernelBuildClient::Instance().AkgStart(PROCESS_NUM, TIME_OUT)) {
MS_LOG(ERROR) << "Akg start failed.";
return false;
}
if (!KernelBuildClient::Instance().AkgSendData(jsons)) {
if (!AscendKernelBuildClient::Instance().AkgSendData(jsons)) {
MS_LOG(ERROR) << "Akg send data failed.";
return false;
}
if (!KernelBuildClient::Instance().AkgWait()) {
if (!AscendKernelBuildClient::Instance().AkgWait()) {
MS_LOG(ERROR) << "Akg compile failed.";
return false;
}
......
......@@ -74,8 +74,12 @@ const std::unordered_map<std::string, FusionType> fusion_type_maps = {
{"SEGMENT", FusionType::SEGMENT}, {"OPAQUE", FusionType::OPAQUE},
};
void KernelMeta::Initialize() {
kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
void KernelMeta::Initialize(int pid) {
if (pid == -1) {
kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(getpid()) + "/";
} else {
kernel_meta_path_ = std::string(kGpuKernelMeta) + "_" + std::to_string(pid) + "/";
}
// remove old kernel cache
RemoveKernelCache();
......
......@@ -40,7 +40,6 @@ constexpr auto kProcessorCuda = "cuda";
constexpr auto kJsonSuffix = ".json";
constexpr auto kInfoSuffix = ".info";
constexpr unsigned int AUTODIFF_COMPILE_OVERTIME = 600;
constexpr auto kAkgModule = "akg.ms";
constexpr auto kArgDataformat = "data_format";
const std::vector<std::string> support_devices = {"aicore", "aicpu", "cuda"};
......@@ -54,7 +53,7 @@ using KernelMetaPtr = std::shared_ptr<KernelMetaInfo>;
class KernelMeta {
public:
KernelMeta() = default;
void Initialize();
void Initialize(int pid);
void RemoveKernelCache();
std::string Search(const std::string &kernel_name) const;
bool Insert(const std::string &kernel_name, const std::string &kernel_json);
......
......@@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
}
// Serialises kernel_json and submits it to the Ascend build server through TbeStart,
// returning TbeStart's result unchanged.
int ParallelBuildManager::StartCompileOp(const nlohmann::json &kernel_json) {
  // Only the AscendKernelBuildClient call is kept; the preceding
  // KernelBuildClient::Instance() return was a stale duplicate that made this one unreachable.
  return AscendKernelBuildClient::Instance().TbeStart(kernel_json.dump());
}
// Waits for one TBE task on the Ascend build server and stores its id, result and
// pre-build result through the three output pointers. Returns TbeWait's result.
bool ParallelBuildManager::WaitOne(int *task_id, std::string *task_result, std::string *pre_build_result) {
  // TbeWait writes through all three pointers, so each one is validated up front.
  MS_EXCEPTION_IF_NULL(task_id);
  MS_EXCEPTION_IF_NULL(task_result);
  MS_EXCEPTION_IF_NULL(pre_build_result);
  // Only the AscendKernelBuildClient call is kept; the preceding
  // KernelBuildClient::Instance() return was a stale duplicate that made this one unreachable.
  return AscendKernelBuildClient::Instance().TbeWait(task_id, task_result, pre_build_result);
}
void ParallelBuildManager::ResetTaskInfo() {
......@@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() {
}
task_map_.clear();
same_op_list_.clear();
KernelBuildClient::Instance().TbeReset();
AscendKernelBuildClient::Instance().TbeReset();
}
} // namespace kernel
} // namespace mindspore
......@@ -312,7 +312,7 @@ bool TbeKernelSelect::TbeCheckSupported(
if (!ret) {
MS_LOG(EXCEPTION) << "Gen tbe single kernel json for check support failed.";
}
ret = KernelBuildClient::Instance().CheckSupported(kernel_json.dump());
ret = AscendKernelBuildClient::Instance().CheckSupported(kernel_json.dump());
AnfAlgo::SetSelectKernelBuildInfo(kernel_build_info_tmp, cnode_ptr_.get());
return ret;
}
......@@ -486,7 +486,7 @@ std::string TbeKernelSelect::OpSelectFormat() {
if (!ret) {
MS_LOG(EXCEPTION) << "GenTbeSingleKernelJson failed.";
}
res_json_str = KernelBuildClient::Instance().SelectFormat(kernel_json.dump());
res_json_str = AscendKernelBuildClient::Instance().SelectFormat(kernel_json.dump());
if (res_json_str.empty()) {
MS_LOG(EXCEPTION) << "op select format error.";
}
......
......@@ -29,7 +29,7 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
}
}
int KernelBuildClient::TbeStart(const std::string &json) {
int AscendKernelBuildClient::TbeStart(const std::string &json) {
// Start compiling..
auto res = SendRequest(kTbeStart);
if (res != kAck) {
......@@ -46,7 +46,7 @@ int KernelBuildClient::TbeStart(const std::string &json) {
return std::stoi(res);
}
bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
bool AscendKernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::string *pre_build_result) {
// Start waiting..
auto res = SendRequest(kTbeWait);
if (res != kAck) {
......@@ -54,15 +54,15 @@ bool KernelBuildClient::TbeWait(int *task_id, std::string *task_result, std::str
return false;
}
// Request task id.
*task_id = std::stoi(SendRequest(kCont));
*task_id = std::stoi(SendRequest(kContinue));
// Request task result.
*task_result = SendRequest(kCont);
*task_result = SendRequest(kContinue);
// Request prebuild result.
*pre_build_result = SendRequest(kCont);
*pre_build_result = SendRequest(kContinue);
return true;
}
void KernelBuildClient::TbeReset() {
void AscendKernelBuildClient::TbeReset() {
// Start compiling..
auto res = SendRequest(kTbeReset);
if (res != kAck) {
......@@ -70,7 +70,7 @@ void KernelBuildClient::TbeReset() {
}
}
bool KernelBuildClient::AkgStart(int process_num, int wait_time) {
bool AscendKernelBuildClient::AkgStart(int process_num, int wait_time) {
// Start compiling..
auto res = SendRequest(kAkgStart);
if (res != kAck) {
......@@ -92,7 +92,7 @@ bool KernelBuildClient::AkgStart(int process_num, int wait_time) {
return true;
}
bool KernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
bool AscendKernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
auto res = SendRequest(kAkgData);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/DATA failed, res: " << res;
......@@ -109,7 +109,7 @@ bool KernelBuildClient::AkgSendData(const std::vector<std::string> &jsons) {
}
// Fetch the result of AKG compiling.
bool KernelBuildClient::AkgWait() {
bool AscendKernelBuildClient::AkgWait() {
auto res = SendRequest(kAkgWait);
if (res != kTrue) {
MS_LOG(ERROR) << "AKG/WAIT failed, res: " << res;
......@@ -118,7 +118,7 @@ bool KernelBuildClient::AkgWait() {
return true;
}
std::string KernelBuildClient::SelectFormat(const std::string &json) {
std::string AscendKernelBuildClient::SelectFormat(const std::string &json) {
// Start compiling..
auto res = SendRequest(kFormat);
if (res != kAck) {
......@@ -134,7 +134,7 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) {
return res;
}
bool KernelBuildClient::CheckSupported(const std::string &json) {
bool AscendKernelBuildClient::CheckSupported(const std::string &json) {
// Checking support..
auto res = SendRequest(kSupport);
if (res != kAck) {
......@@ -149,5 +149,29 @@ bool KernelBuildClient::CheckSupported(const std::string &json) {
}
return true;
}
int GpuKernelBuildClient::AkgGetPid() {
auto res = SendRequest(kAkgPid);
if (res == kErr) {
MS_LOG(ERROR) << "AKG/PID failed, res: " << res;
return -1;
}
return std::stoi(res);
}
bool GpuKernelBuildClient::AkgCompileSingle(const std::string json) {
auto res = SendRequest(kAkgCompileOp);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/COMPILE failed, res: " << res;
return false;
}
// Send single json data.
res = SendRequest(json);
if (res != kAck) {
MS_LOG(ERROR) << "AKG/COMPILE responds failed, res: " << res;
return false;
}
return true;
}
} // namespace kernel
} // namespace mindspore
......@@ -29,97 +29,37 @@
namespace mindspore {
namespace kernel {
void ReplaceStr(std::string *dest, const std::string &replace, char new_char);
constexpr inline static int kBufferSize = 4096;
// The TAG as prefix of real command from remote.
constexpr inline static auto kTag = "[~]";
class KernelBuildClient {
public:
// Server configure
constexpr inline static auto kEnv = "python";
constexpr inline static auto kGetPathScript =
"-c "
"\""
"import pkgutil;"
"path = pkgutil"
".get_loader(\\\"mindspore._extends.remote.kernel_build_server\\\")" // Server module name
".get_filename();"
"print('[~]' + path)"
"\"";
// Send Finish request to server
constexpr inline static auto kFinish = "FINISH";
// Receive the response from server
constexpr inline static auto kAck = "ACK";
constexpr inline static auto kErr = "ERR";
constexpr inline static auto kFailed = "-1";
// Send Finish request to server
constexpr inline static auto kFin = "FIN";
// Send building request to server
constexpr inline static auto kTbeStart = "TBE/START";
constexpr inline static auto kTbeWait = "TBE/WAIT";
constexpr inline static auto kCont = "CONT";
constexpr inline static auto kSuccess = "Success";
constexpr inline static auto kTbeReset = "TBE/RESET";
constexpr inline static auto kAkgStart = "AKG/START";
constexpr inline static auto kAkgData = "AKG/DATA";
constexpr inline static auto kAkgWait = "AKG/WAIT";
// Send server info. query to server
constexpr inline static auto kFormat = "FORMAT";
constexpr inline static auto kSupport = "SUPPORT";
constexpr inline static auto kTrue = "True";
constexpr inline static auto kSuccess = "Success";
// Revert \n, \r, [space].
constexpr inline static auto kLF = "[LF]";
constexpr inline static auto kCR = "[CR]";
constexpr inline static auto kSP = "[SP]";
// The TAG as prefix of real command from remote.
constexpr inline static auto kTag = "[~]";
constexpr inline static int kBufferSize = 4096;
constexpr inline static unsigned int kTimeOutSeconds = 350;
static KernelBuildClient &Instance() {
static KernelBuildClient instance;
return instance;
}
std::string GetScriptPath() {
std::string cmd = kEnv;
(void)cmd.append(1, ' ').append(kGetPathScript);
FILE *fpipe = popen(cmd.c_str(), "r");
if (fpipe == nullptr) {
MS_LOG(EXCEPTION) << "popen failed, " << strerror(errno) << "(" << errno << ")";
}
bool start = false;
std::string result;
char buf[kBufferSize];
while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) {
if (std::strncmp(buf, kTag, std::strlen(kTag)) == 0) {
start = true;
}
// Filter with 'kTAG' and '\n'
if (start) {
auto size = std::strlen(buf);
bool line_end = buf[size - 1] == '\n';
result.append(buf, line_end ? size - 1 : size);
if (line_end) {
break;
}
}
}
pclose(fpipe);
const std::string py_suffix = ".py";
if (result.empty() || result.rfind(py_suffix) != (result.length() - py_suffix.length())) {
MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}";
}
result = result.substr(strlen(kTag));
MS_LOG(DEBUG) << "result: " << result;
return result;
}
virtual std::string GetEnv() = 0;
virtual std::string GetScript() = 0;
void Open() {
if (!init_) {
// Exception's thrown if open failed
if (dp_->Open({kEnv, GetScriptPath()}, true) != -1) {
if (dp_->Open({GetEnv(), GetScript()}, true) != -1) {
dp_->SetTimeOutSeconds(kTimeOutSeconds);
dp_->SetTimeOutCallback([this]() { SendRequest(kFin); });
dp_->SetTimeOutCallback([this]() { SendRequest(kFinish); });
init_ = true;
}
}
......@@ -164,6 +104,88 @@ class KernelBuildClient {
return res;
}
protected:
KernelBuildClient() : init_(false), dp_(std::make_shared<DuplexPipe>()) {}
virtual ~KernelBuildClient() = default;
private:
bool init_;
std::shared_ptr<DuplexPipe> dp_;
};
// Runs "<cmd_env> <cmd_script>" through popen and extracts the script path it prints.
// Output is ignored until a line starting with kTag is seen; text is then collected
// up to the first newline. The collected text must end in ".py"; the leading kTag is
// stripped before returning. Failures are reported through MS_LOG(EXCEPTION).
static inline std::string GetScriptFilePath(const std::string &cmd_env, const std::string &cmd_script) {
  std::string cmd = cmd_env;
  (void)cmd.append(1, ' ').append(cmd_script);
  FILE *fpipe = popen(cmd.c_str(), "r");
  if (fpipe == nullptr) {
    MS_LOG(EXCEPTION) << "popen failed, " << strerror(errno) << "(" << errno << ")";
  }

  const auto tag_len = std::strlen(kTag);
  bool start = false;
  std::string result;
  char buf[kBufferSize];
  while (std::fgets(buf, sizeof(buf), fpipe) != nullptr) {
    if (std::strncmp(buf, kTag, tag_len) == 0) {
      start = true;
    }
    // Filter with 'kTag' and '\n'
    if (start) {
      auto size = std::strlen(buf);
      // Guard size == 0 so buf[size - 1] never indexes before the buffer.
      bool line_end = size > 0 && buf[size - 1] == '\n';
      result.append(buf, line_end ? size - 1 : size);
      if (line_end) {
        break;
      }
    }
  }
  pclose(fpipe);

  const std::string py_suffix = ".py";
  // Explicit length check instead of relying on unsigned wrap-around of
  // result.length() - py_suffix.length(); the text must hold the tag plus the suffix.
  const bool long_enough = result.length() >= tag_len + py_suffix.length();
  if (!long_enough ||
      result.compare(result.length() - py_suffix.length(), py_suffix.length(), py_suffix) != 0) {
    MS_LOG(EXCEPTION) << "py file seems incorrect, result: {" << result << "}";
  }
  result = result.substr(tag_len);
  MS_LOG(DEBUG) << "result: " << result;
  return result;
}
class AscendKernelBuildClient : public KernelBuildClient {
public:
// Server configure
constexpr inline static auto kEnv = "python";
constexpr inline static auto kGetPathScript =
"-c "
"\""
"import pkgutil;"
"path = pkgutil"
".get_loader(\\\"mindspore._extends.remote.kernel_build_server_ascend\\\")" // Server module name
".get_filename();"
"print('[~]' + path)"
"\"";
// Receive the response from server
constexpr inline static auto kFailed = "-1";
// Send building request to server
constexpr inline static auto kContinue = "CONTINUE"; // More transactions to be continued
constexpr inline static auto kTbeStart = "TBE/START";
constexpr inline static auto kTbeWait = "TBE/WAIT";
constexpr inline static auto kTbeReset = "TBE/RESET";
constexpr inline static auto kAkgStart = "AKG/START";
constexpr inline static auto kAkgData = "AKG/DATA";
constexpr inline static auto kAkgWait = "AKG/WAIT";
// Send server info. query to server
constexpr inline static auto kFormat = "FORMAT";
constexpr inline static auto kSupport = "SUPPORT";
static AscendKernelBuildClient &Instance() {
static AscendKernelBuildClient instance;
return instance;
}
std::string GetEnv() override { return kEnv; }
std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); }
// Before building.
std::string SelectFormat(const std::string &json);
bool CheckSupported(const std::string &json);
......@@ -177,19 +199,60 @@ class KernelBuildClient {
bool AkgStart(int process_num, int wait_time);
bool AkgSendData(const std::vector<std::string> &jsons);
bool AkgWait();
bool AkgCompileSingle(const std::string json);
KernelBuildClient(const KernelBuildClient &) = delete;
KernelBuildClient &operator=(const KernelBuildClient &) = delete;
AscendKernelBuildClient(const AscendKernelBuildClient &) = delete;
AscendKernelBuildClient &operator=(const AscendKernelBuildClient &) = delete;
KernelBuildClient(KernelBuildClient &&) = delete;
KernelBuildClient &operator=(KernelBuildClient &&) = delete;
AscendKernelBuildClient(AscendKernelBuildClient &&) = delete;
AscendKernelBuildClient &operator=(AscendKernelBuildClient &&) = delete;
private:
KernelBuildClient() : init_(false), dp_(std::make_shared<DuplexPipe>()) { Open(); }
~KernelBuildClient() { Close(); }
AscendKernelBuildClient() { Open(); }
~AscendKernelBuildClient() override { Close(); }
};
bool init_;
std::shared_ptr<DuplexPipe> dp_;
// Singleton client for the GPU kernel build server
// (python module mindspore._extends.remote.kernel_build_server_gpu).
class GpuKernelBuildClient : public KernelBuildClient {
 public:
  // The instance is neither copyable nor movable.
  GpuKernelBuildClient(const GpuKernelBuildClient &) = delete;
  GpuKernelBuildClient(GpuKernelBuildClient &&) = delete;
  GpuKernelBuildClient &operator=(const GpuKernelBuildClient &) = delete;
  GpuKernelBuildClient &operator=(GpuKernelBuildClient &&) = delete;

  // Server configuration: the interpreter, and a one-liner that prints the
  // server script's file name prefixed with '[~]'.
  constexpr inline static auto kEnv = "python";
  constexpr inline static auto kGetPathScript =
    "-c "
    "\""
    "import pkgutil;"
    "path = pkgutil"
    ".get_loader(\\\"mindspore._extends.remote.kernel_build_server_gpu\\\")"  // Server module name
    ".get_filename();"
    "print('[~]' + path)"
    "\"";

  // Requests understood by the server.
  constexpr inline static auto kAkgPid = "AKG/PID";
  constexpr inline static auto kAkgCompileOp = "AKG/COMPILE";  // Compile a single op

  // Function-local static instance, constructed on first use.
  static GpuKernelBuildClient &Instance() {
    static GpuKernelBuildClient instance;
    return instance;
  }

  // Command pieces handed to the base class.
  std::string GetEnv() override { return kEnv; }
  std::string GetScript() override { return GetScriptFilePath(kEnv, kGetPathScript); }

  // Fetch pid(pid_t) from remote.
  int AkgGetPid();
  // Run AKG building.
  bool AkgCompileSingle(const std::string json);

 private:
  // Construction calls Open(); destruction calls Close().
  GpuKernelBuildClient() { Open(); }
  ~GpuKernelBuildClient() override { Close(); }
};
} // namespace kernel
} // namespace mindspore
......
......@@ -21,13 +21,16 @@
#include "backend/kernel_compiler/gpu/gpu_kernel_factory.h"
#include "frontend/operator/ops.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
namespace mindspore {
namespace device {
namespace gpu {
void GpuBuild(const KernelGraphPtr &kernel_graph) {
kernel::KernelMeta *bin_map = kernel::KernelMeta::GetInstance();
MS_EXCEPTION_IF_NULL(bin_map);
bin_map->Initialize();
auto pid = mindspore::kernel::GpuKernelBuildClient::Instance().AkgGetPid();
bin_map->Initialize(pid);
MS_EXCEPTION_IF_NULL(kernel_graph);
auto kernels = kernel_graph->execution_order();
for (const auto &kernel : kernels) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册