Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
0154bdeb
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
0154bdeb
编写于
8月 05, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
8月 05, 2020
浏览文件
操作
浏览文件
下载
差异文件
!3935 Decouple ME and AKG for Ascend
Merge pull request !3935 from ZhangQinghua/master
上级
61d24efb
89f34017
变更
8
显示空白变更内容
内联
并排
Showing
8 changed file
with
259 addition
and
106 deletion
+259
-106
mindspore/_extends/parallel_compile/akg_compiler/__init__.py
mindspore/_extends/parallel_compile/akg_compiler/__init__.py
+5
-0
mindspore/_extends/parallel_compile/akg_compiler/akg_process.py
...ore/_extends/parallel_compile/akg_compiler/akg_process.py
+88
-0
mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py
...ore/_extends/parallel_compile/tbe_compiler/tbe_process.py
+5
-6
mindspore/_extends/remote/kernel_build_server.py
mindspore/_extends/remote/kernel_build_server.py
+45
-7
mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
...end/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
+9
-43
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc
.../backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc
+3
-3
mindspore/ccsrc/backend/session/kernel_build_client.cc
mindspore/ccsrc/backend/session/kernel_build_client.cc
+72
-24
mindspore/ccsrc/backend/session/kernel_build_client.h
mindspore/ccsrc/backend/session/kernel_build_client.h
+32
-23
未找到文件。
mindspore/_extends/parallel_compile/akg_compiler/__init__.py
浏览文件 @
0154bdeb
...
@@ -12,3 +12,8 @@
...
@@ -12,3 +12,8 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""
Extension functions.
Python functions that will be called in the c++ parts of MindSpore.
"""
mindspore/_extends/parallel_compile/akg_compiler/
multi_process_compiler
.py
→
mindspore/_extends/parallel_compile/akg_compiler/
akg_process
.py
浏览文件 @
0154bdeb
...
@@ -12,13 +12,12 @@
...
@@ -12,13 +12,12 @@
# See the License for the specific language governing permissions and
# See the License for the specific language governing permissions and
# limitations under the License.
# limitations under the License.
# ============================================================================
# ============================================================================
"""
Providing multi process compile with json
"""
"""
akg process
"""
import
os
import
os
import
subprocess
import
subprocess
import
sys
import
sys
from
multiprocessing
import
Pool
,
cpu_count
from
multiprocessing
import
Pool
,
cpu_count
def
_compile_akg_task
(
*
json_strs
):
def
_compile_akg_task
(
*
json_strs
):
"""
"""
compile func called in single process
compile func called in single process
...
@@ -34,38 +33,56 @@ def _compile_akg_task(*json_strs):
...
@@ -34,38 +33,56 @@ def _compile_akg_task(*json_strs):
if
res
.
returncode
!=
0
:
if
res
.
returncode
!=
0
:
raise
ValueError
(
"Failed, args: {}!"
.
format
(
json_str
))
raise
ValueError
(
"Failed, args: {}!"
.
format
(
json_str
))
def
create_akg_parallel_process
(
process_num
,
wait_time
):
def
compile_akg_kernel_parallel
(
json_infos
,
process
,
waitime
):
"""
"""
compile kernel use multi processes
create AkgParallelCompiler object
Parameters:
json_infos: list. list contain kernel info(task id and json str)
process: int. processes num
waittime: int. max time the function blocked
Returns:
Returns:
True for all compile success, False for some failed.
AkgParallelCompiler
"""
"""
if
not
isinstance
(
json_infos
,
list
):
return
AkgProcess
(
process_num
,
wait_time
)
raise
ValueError
(
"json_infos must be a list"
)
if
not
isinstance
(
process
,
int
):
raise
ValueError
(
"process must be a num"
)
if
not
isinstance
(
waitime
,
int
):
raise
ValueError
(
"waittime must be a num"
)
if
process
==
0
and
json_info
s
:
class
AkgProces
s
:
process
=
1
"""akg kernel parallel process"""
cpu_proc_num
=
cpu_count
()
def
__init__
(
self
,
process_num
,
wait_time
):
"""
Args:
process_num: int. processes number
waittime: int. max time the function blocked
"""
if
not
isinstance
(
process_num
,
int
):
raise
ValueError
(
"process number must be a num"
)
if
not
isinstance
(
wait_time
,
int
):
raise
ValueError
(
"wait time must be a num"
)
if
process_num
==
0
:
process_num
=
1
max_proc_num
=
16
max_proc_num
=
16
process
=
min
([
cpu_proc_num
,
max_proc_num
,
process
])
self
.
process_num
=
min
([
cpu_count
(),
max_proc_num
,
process_num
])
self
.
args
=
[[]
for
_
in
range
(
self
.
process_num
)]
args
=
[[]
for
_
in
range
(
process
)]
self
.
wait_time
=
wait_time
for
p
,
info
in
enumerate
(
json_infos
):
self
.
argc
=
0
args
[
p
%
process
].
append
(
info
)
with
Pool
(
processes
=
process
)
as
pool
:
def
compile
(
self
):
res
=
pool
.
starmap_async
(
_compile_akg_task
,
args
)
"""
res
.
get
(
timeout
=
waitime
)
compile kernel by multi processes
Return:
True for all compile success, False for some failed.
"""
if
self
.
argc
==
0
:
raise
ValueError
(
"json must be not null"
)
with
Pool
(
processes
=
self
.
process_num
)
as
pool
:
res
=
pool
.
starmap_async
(
_compile_akg_task
,
self
.
args
)
res
.
get
(
timeout
=
self
.
wait_time
)
return
True
return
True
def
accept_json
(
self
,
json
):
"""
accept json data before compile
Args:
json: str. kernel info.
"""
if
not
isinstance
(
json
,
str
):
raise
ValueError
(
"json must be a str"
)
self
.
args
[
self
.
argc
%
self
.
process_num
].
append
(
json
)
self
.
argc
+=
1
mindspore/_extends/parallel_compile/tbe_compiler/tbe_process.py
浏览文件 @
0154bdeb
...
@@ -22,14 +22,14 @@ import json
...
@@ -22,14 +22,14 @@ import json
from
.common
import
check_kernel_info
,
TBEException
from
.common
import
check_kernel_info
,
TBEException
from
.helper
import
_op_select_format
,
_check_supported
from
.helper
import
_op_select_format
,
_check_supported
def
create_tbe_parallel_
compiler
():
def
create_tbe_parallel_
process
():
"""
"""
create TBEParallelCompiler object
create TBEParallelCompiler object
Returns:
Returns:
TBEParallelCompiler
TBEParallelCompiler
"""
"""
return
compile_pool
return
tbe_process
def
op_select_format
(
op_json
:
str
):
def
op_select_format
(
op_json
:
str
):
"""
"""
...
@@ -98,8 +98,8 @@ def run_compiler(op_json):
...
@@ -98,8 +98,8 @@ def run_compiler(op_json):
except
subprocess
.
CalledProcessError
as
e
:
except
subprocess
.
CalledProcessError
as
e
:
return
"TBEException"
,
"PreCompileProcessFailed:
\n
"
+
e
.
stdout
+
"
\n
"
+
e
.
stderr
+
"
\n
input_args: "
+
op_json
return
"TBEException"
,
"PreCompileProcessFailed:
\n
"
+
e
.
stdout
+
"
\n
"
+
e
.
stderr
+
"
\n
input_args: "
+
op_json
class
CompilerPool
:
class
TbeProcess
:
"""
compiler pool
"""
"""
tbe process
"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
__processe_num
=
multiprocessing
.
cpu_count
()
self
.
__processe_num
=
multiprocessing
.
cpu_count
()
...
@@ -168,5 +168,4 @@ class CompilerPool:
...
@@ -168,5 +168,4 @@ class CompilerPool:
if
self
.
__running_tasks
:
if
self
.
__running_tasks
:
self
.
__running_tasks
.
clear
()
self
.
__running_tasks
.
clear
()
tbe_process
=
TbeProcess
()
compile_pool
=
CompilerPool
()
mindspore/_extends/remote/kernel_build_server.py
浏览文件 @
0154bdeb
...
@@ -16,13 +16,14 @@
...
@@ -16,13 +16,14 @@
import
os
import
os
import
sys
import
sys
import
time
import
time
from
mindspore._extends.parallel_compile.tbe_compiler.tbe_process
import
create_tbe_parallel_compiler
,
op_select_format
,
check_supported
from
mindspore._extends.parallel_compile.tbe_compiler.tbe_process
import
create_tbe_parallel_process
,
op_select_format
,
check_supported
from
mindspore._extends.parallel_compile.akg_compiler.akg_process
import
create_akg_parallel_process
class
TbeBuilder
:
class
TbeBuilder
:
"""Tbe building wrapper"""
"""Tbe building wrapper"""
def
__init__
(
self
):
def
__init__
(
self
):
self
.
tbe_builder
=
create_tbe_parallel_
compiler
()
self
.
tbe_builder
=
create_tbe_parallel_
process
()
def
start
(
self
,
json
):
def
start
(
self
,
json
):
return
self
.
tbe_builder
.
start_compile_op
(
json
)
return
self
.
tbe_builder
.
start_compile_op
(
json
)
...
@@ -36,6 +37,21 @@ class TbeBuilder:
...
@@ -36,6 +37,21 @@ class TbeBuilder:
def
exit
(
self
):
def
exit
(
self
):
self
.
tbe_builder
.
exit
()
self
.
tbe_builder
.
exit
()
class
AkgBuilder
:
"""Akg building wrapper"""
def
__init__
(
self
):
pass
def
create
(
self
,
process_num
,
waitime
):
self
.
akg_builder
=
create_akg_parallel_process
(
process_num
,
waitime
)
def
accept_json
(
self
,
json
):
return
self
.
akg_builder
.
accept_json
(
json
)
def
compile
(
self
):
return
self
.
akg_builder
.
compile
()
class
Messager
:
class
Messager
:
'''Messager'''
'''Messager'''
...
@@ -43,6 +59,7 @@ class Messager:
...
@@ -43,6 +59,7 @@ class Messager:
logger
.
info
(
'[TRACE]'
,
'Messager init...'
)
logger
.
info
(
'[TRACE]'
,
'Messager init...'
)
self
.
message
=
''
self
.
message
=
''
self
.
tbe_builder
=
TbeBuilder
()
self
.
tbe_builder
=
TbeBuilder
()
self
.
akg_builder
=
AkgBuilder
()
def
get_message
(
self
):
def
get_message
(
self
):
"""
"""
...
@@ -111,12 +128,12 @@ class Messager:
...
@@ -111,12 +128,12 @@ class Messager:
Communicate with remote
Communicate with remote
"""
"""
arg
=
self
.
get_message
()
arg
=
self
.
get_message
()
if
arg
==
'START'
:
if
arg
==
'
TBE/
START'
:
self
.
send_ack
()
self
.
send_ack
()
json
=
self
.
get_message
()
json
=
self
.
get_message
()
res
=
self
.
tbe_builder
.
start
(
json
)
res
=
self
.
tbe_builder
.
start
(
json
)
self
.
send_res
(
res
)
self
.
send_res
(
res
)
elif
arg
==
'WAIT'
:
elif
arg
==
'
TBE/
WAIT'
:
self
.
send_ack
()
self
.
send_ack
()
task_id
,
res
,
pre
=
self
.
tbe_builder
.
wait
()
task_id
,
res
,
pre
=
self
.
tbe_builder
.
wait
()
logger
.
debug
(
'[TRACE]'
,
str
(
task_id
)
+
'/'
+
str
(
res
)
+
'/'
+
str
(
pre
))
logger
.
debug
(
'[TRACE]'
,
str
(
task_id
)
+
'/'
+
str
(
res
)
+
'/'
+
str
(
pre
))
...
@@ -132,9 +149,30 @@ class Messager:
...
@@ -132,9 +149,30 @@ class Messager:
self
.
send_ack
(
False
)
self
.
send_ack
(
False
)
self
.
exit
()
self
.
exit
()
self
.
send_res
(
pre
)
self
.
send_res
(
pre
)
elif
arg
==
'RESET'
:
elif
arg
==
'
TBE/
RESET'
:
self
.
tbe_builder
.
reset
()
self
.
tbe_builder
.
reset
()
self
.
send_ack
()
self
.
send_ack
()
elif
arg
==
'AKG/START'
:
self
.
send_ack
()
process_num_str
=
self
.
get_message
()
self
.
send_ack
()
wait_time_str
=
self
.
get_message
()
self
.
akg_builder
.
create
(
int
(
process_num_str
),
int
(
wait_time_str
))
self
.
send_ack
()
elif
arg
==
'AKG/DATA'
:
self
.
send_ack
()
while
True
:
req
=
self
.
get_message
()
if
req
.
startswith
(
'{'
):
self
.
akg_builder
.
accept_json
(
req
)
self
.
send_ack
()
elif
req
==
'AKG/WAIT'
:
res
=
self
.
akg_builder
.
compile
()
self
.
send_res
(
res
)
break
else
:
self
.
send_ack
(
False
)
break
elif
arg
==
'FORMAT'
:
elif
arg
==
'FORMAT'
:
self
.
send_ack
()
self
.
send_ack
()
json
=
self
.
get_message
()
json
=
self
.
get_message
()
...
@@ -180,7 +218,7 @@ class Messager:
...
@@ -180,7 +218,7 @@ class Messager:
class
Logger
:
class
Logger
:
"""
"""
Replace dummy 'logger' to output log as below:
Replace dummy 'logger' to output log as below:
logger = Logger("remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log")
logger = Logger(
0, True,
"remote_kernel_build_" + time.strftime("%Y_%m_%d_%H_%M_%S", time.localtime()) + ".log")
"""
"""
def
__init__
(
self
,
level
=
1
,
dumpfile
=
False
,
filename
=
'Logger.log'
):
def
__init__
(
self
,
level
=
1
,
dumpfile
=
False
,
filename
=
'Logger.log'
):
"""
"""
...
@@ -225,7 +263,7 @@ class DummyLogger:
...
@@ -225,7 +263,7 @@ class DummyLogger:
def
info
(
self
,
tag
,
msg
):
def
info
(
self
,
tag
,
msg
):
pass
pass
logger
=
Logger
()
logger
=
Dummy
Logger
()
if
__name__
==
'__main__'
:
if
__name__
==
'__main__'
:
if
len
(
sys
.
argv
)
!=
3
:
if
len
(
sys
.
argv
)
!=
3
:
...
...
mindspore/ccsrc/backend/kernel_compiler/akg/ascend/akg_ascend_kernel_build.cc
浏览文件 @
0154bdeb
...
@@ -23,7 +23,6 @@
...
@@ -23,7 +23,6 @@
#include <unordered_set>
#include <unordered_set>
#include <utility>
#include <utility>
#include <vector>
#include <vector>
#include <Python.h>
#include "ir/dtype.h"
#include "ir/dtype.h"
#include "ir/func_graph.h"
#include "ir/func_graph.h"
#include "backend/kernel_compiler/kernel.h"
#include "backend/kernel_compiler/kernel.h"
...
@@ -32,10 +31,10 @@
...
@@ -32,10 +31,10 @@
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
#include "backend/kernel_compiler/akg/ascend/akg_ascend_kernel_mod.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/kernel_compiler/akg/akg_kernel_attrs_process.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/anf_runtime_algorithm.h"
#include "backend/session/kernel_build_client.h"
namespace
mindspore
{
namespace
mindspore
{
namespace
kernel
{
namespace
kernel
{
constexpr
int32_t
PARALLEL_ARGS_SIZE
=
3
;
constexpr
int32_t
PROCESS_NUM
=
16
;
constexpr
int32_t
PROCESS_NUM
=
16
;
constexpr
int32_t
TIME_OUT
=
300
;
constexpr
int32_t
TIME_OUT
=
300
;
...
@@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type";
...
@@ -45,8 +44,7 @@ constexpr auto kDataType = "data_type";
constexpr
auto
kInputDesc
=
"input_desc"
;
constexpr
auto
kInputDesc
=
"input_desc"
;
constexpr
auto
kOutputDesc
=
"output_desc"
;
constexpr
auto
kOutputDesc
=
"output_desc"
;
constexpr
auto
kTensorName
=
"tensor_name"
;
constexpr
auto
kTensorName
=
"tensor_name"
;
constexpr
auto
kCompileAkgKernelParallelFunc
=
"compile_akg_kernel_parallel"
;
constexpr
auto
kMultiProcModule
=
"mindspore._extends.parallel_compile.akg_compiler.multi_process_compiler"
;
namespace
{
namespace
{
void
UpdateTensorNameInJson
(
const
std
::
vector
<
AnfNodePtr
>
&
anf_nodes
,
void
UpdateTensorNameInJson
(
const
std
::
vector
<
AnfNodePtr
>
&
anf_nodes
,
std
::
map
<
AnfNodePtr
,
nlohmann
::
json
>
*
node_json_map
)
{
std
::
map
<
AnfNodePtr
,
nlohmann
::
json
>
*
node_json_map
)
{
...
@@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf
...
@@ -319,55 +317,23 @@ bool AkgAscendKernelBuilder::CollectFusedJson(const std::vector<AnfNodePtr> &anf
return
true
;
return
true
;
}
}
void
GenParallelCompileFuncArgs
(
const
std
::
vector
<
std
::
string
>
&
kernel_jsons
,
PyObject
**
p_args
)
{
MS_EXCEPTION_IF_NULL
(
p_args
);
*
p_args
=
PyTuple_New
(
PARALLEL_ARGS_SIZE
);
PyObject
*
arg1
=
PyList_New
(
kernel_jsons
.
size
());
for
(
int
i
=
0
;
i
<
PyList_Size
(
arg1
);
++
i
)
{
PyList_SetItem
(
arg1
,
i
,
Py_BuildValue
(
"s"
,
kernel_jsons
[
i
].
c_str
()));
}
PyObject
*
arg2
=
Py_BuildValue
(
"i"
,
PROCESS_NUM
);
PyObject
*
arg3
=
Py_BuildValue
(
"i"
,
TIME_OUT
);
(
void
)
PyTuple_SetItem
(
*
p_args
,
0
,
arg1
);
(
void
)
PyTuple_SetItem
(
*
p_args
,
1
,
arg2
);
(
void
)
PyTuple_SetItem
(
*
p_args
,
2
,
arg3
);
}
bool
AkgOpParallelBuild
(
const
std
::
vector
<
std
::
pair
<
AkgAscendKernelBuilder
,
AnfNodePtr
>>
&
build_args
)
{
bool
AkgOpParallelBuild
(
const
std
::
vector
<
std
::
pair
<
AkgAscendKernelBuilder
,
AnfNodePtr
>>
&
build_args
)
{
auto
[
jsons
,
repeat_nodes
]
=
PreProcessJsonForBuild
(
build_args
);
auto
[
jsons
,
repeat_nodes
]
=
PreProcessJsonForBuild
(
build_args
);
if
(
jsons
.
empty
())
{
if
(
jsons
.
empty
())
{
return
true
;
return
true
;
}
}
// Try to call python method to compile nodes parallely.
// Start building in AKG
PyObject
*
p_module
=
nullptr
;
if
(
!
KernelBuildClient
::
Instance
().
AkgStart
(
PROCESS_NUM
,
TIME_OUT
))
{
PyObject
*
p_func
=
nullptr
;
MS_LOG
(
ERROR
)
<<
"Akg start failed."
;
PyObject
*
p_arg
=
nullptr
;
PyObject
*
p_res
=
nullptr
;
p_module
=
PyImport_ImportModule
(
kMultiProcModule
);
if
(
p_module
==
nullptr
)
{
MS_LOG
(
ERROR
)
<<
"Failed to import ["
<<
kMultiProcModule
<<
"]."
;
return
false
;
return
false
;
}
}
if
(
!
KernelBuildClient
::
Instance
().
AkgSendData
(
jsons
))
{
p_func
=
PyObject_GetAttrString
(
p_module
,
kCompileAkgKernelParallelFunc
);
MS_LOG
(
ERROR
)
<<
"Akg send data failed."
;
GenParallelCompileFuncArgs
(
jsons
,
&
p_arg
);
MS_LOG
(
DEBUG
)
<<
"Call function ["
<<
kCompileAkgKernelParallelFunc
<<
"], try to compile "
<<
jsons
.
size
()
<<
" Akg kernels parallelly."
;
p_res
=
PyEval_CallObject
(
p_func
,
p_arg
);
if
(
p_res
==
nullptr
)
{
PyErr_Print
();
MS_LOG
(
ERROR
)
<<
"No ret got, failed to call function ["
<<
kCompileAkgKernelParallelFunc
<<
"], args:
\n
("
<<
AkgKernelBuild
::
PyObjectToStr
(
p_arg
)
<<
")."
;
return
false
;
return
false
;
}
}
if
(
PyObject_IsTrue
(
p_res
)
!=
1
)
{
if
(
!
KernelBuildClient
::
Instance
().
AkgWait
())
{
PyErr_Print
();
MS_LOG
(
ERROR
)
<<
"Akg compile failed."
;
MS_LOG
(
ERROR
)
<<
"Illegal ret, failed to call function ["
<<
kCompileAkgKernelParallelFunc
<<
"], args:
\n
("
<<
AkgKernelBuild
::
PyObjectToStr
(
p_arg
)
<<
")."
;
return
false
;
return
false
;
}
}
...
...
mindspore/ccsrc/backend/kernel_compiler/tbe/tbe_kernel_parallel_build.cc
浏览文件 @
0154bdeb
...
@@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
...
@@ -272,12 +272,12 @@ KernelModPtr ParallelBuildManager::GenKernelMod(const string &json_name, const s
}
}
int
ParallelBuildManager
::
StartCompileOp
(
const
nlohmann
::
json
&
kernel_json
)
{
int
ParallelBuildManager
::
StartCompileOp
(
const
nlohmann
::
json
&
kernel_json
)
{
return
KernelBuildClient
::
Instance
().
Start
(
kernel_json
.
dump
());
return
KernelBuildClient
::
Instance
().
Tbe
Start
(
kernel_json
.
dump
());
}
}
bool
ParallelBuildManager
::
WaitOne
(
int
*
task_id
,
std
::
string
*
task_result
,
std
::
string
*
pre_build_result
)
{
bool
ParallelBuildManager
::
WaitOne
(
int
*
task_id
,
std
::
string
*
task_result
,
std
::
string
*
pre_build_result
)
{
MS_EXCEPTION_IF_NULL
(
task_id
);
MS_EXCEPTION_IF_NULL
(
task_id
);
return
KernelBuildClient
::
Instance
().
Wait
(
task_id
,
task_result
,
pre_build_result
);
return
KernelBuildClient
::
Instance
().
Tbe
Wait
(
task_id
,
task_result
,
pre_build_result
);
}
}
void
ParallelBuildManager
::
ResetTaskInfo
()
{
void
ParallelBuildManager
::
ResetTaskInfo
()
{
...
@@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() {
...
@@ -287,7 +287,7 @@ void ParallelBuildManager::ResetTaskInfo() {
}
}
task_map_
.
clear
();
task_map_
.
clear
();
same_op_list_
.
clear
();
same_op_list_
.
clear
();
KernelBuildClient
::
Instance
().
Reset
();
KernelBuildClient
::
Instance
().
Tbe
Reset
();
}
}
}
// namespace kernel
}
// namespace kernel
}
// namespace mindspore
}
// namespace mindspore
mindspore/ccsrc/backend/session/kernel_build_client.cc
浏览文件 @
0154bdeb
...
@@ -29,58 +29,106 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
...
@@ -29,58 +29,106 @@ void ReplaceStr(std::string *dest, const std::string &replace, char new_char) {
}
}
}
}
int
KernelBuildClient
::
Start
(
const
std
::
string
&
json
)
{
int
KernelBuildClient
::
Tbe
Start
(
const
std
::
string
&
json
)
{
// Start compiling..
// Start compiling..
std
::
string
res
=
SendRequest
(
kSTART
);
auto
res
=
SendRequest
(
kTbeStart
);
if
(
res
!=
kA
CK
)
{
if
(
res
!=
kA
ck
)
{
MS_LOG
(
ERROR
)
<<
"START failed, res: "
<<
res
;
MS_LOG
(
ERROR
)
<<
"START failed, res: "
<<
res
;
return
-
1
;
return
-
1
;
}
}
// Send the json data.
// Send the json data.
res
=
SendRequest
(
json
);
res
=
SendRequest
(
json
);
if
(
res
==
kF
AILED
)
{
if
(
res
==
kF
ailed
)
{
MS_LOG
(
ERROR
)
<<
"
START send data
failed, res: "
<<
res
;
MS_LOG
(
ERROR
)
<<
"
TBE/START responds
failed, res: "
<<
res
;
return
-
1
;
return
-
1
;
}
}
// Return task id.
// Return task id.
return
std
::
stoi
(
res
);
return
std
::
stoi
(
res
);
}
}
bool
KernelBuildClient
::
Wait
(
int
*
task_id
,
std
::
string
*
task_result
,
std
::
string
*
pre_build_result
)
{
bool
KernelBuildClient
::
Tbe
Wait
(
int
*
task_id
,
std
::
string
*
task_result
,
std
::
string
*
pre_build_result
)
{
// Start waiting..
// Start waiting..
std
::
string
res
=
SendRequest
(
kWAIT
);
auto
res
=
SendRequest
(
kTbeWait
);
if
(
res
!=
kA
CK
)
{
if
(
res
!=
kA
ck
)
{
MS_LOG
(
ERROR
)
<<
"WAIT failed, res: "
<<
res
;
MS_LOG
(
ERROR
)
<<
"
TBE/
WAIT failed, res: "
<<
res
;
return
false
;
return
false
;
}
}
// Request task id.
// Request task id.
*
task_id
=
std
::
stoi
(
SendRequest
(
kC
ONT
));
*
task_id
=
std
::
stoi
(
SendRequest
(
kC
ont
));
// Requst task result.
// Requst task result.
*
task_result
=
SendRequest
(
kC
ONT
);
*
task_result
=
SendRequest
(
kC
ont
);
// Request prebuild result.
// Request prebuild result.
*
pre_build_result
=
SendRequest
(
kC
ONT
);
*
pre_build_result
=
SendRequest
(
kC
ont
);
return
true
;
return
true
;
}
}
void
KernelBuildClient
::
Reset
()
{
void
KernelBuildClient
::
Tbe
Reset
()
{
// Start compiling..
// Start compiling..
std
::
string
res
=
SendRequest
(
kRESET
);
auto
res
=
SendRequest
(
kTbeReset
);
if
(
res
!=
kA
CK
)
{
if
(
res
!=
kA
ck
)
{
MS_LOG
(
EXCEPTION
)
<<
"RESET response is: "
<<
res
;
MS_LOG
(
EXCEPTION
)
<<
"
TBE/
RESET response is: "
<<
res
;
}
}
}
}
bool
KernelBuildClient
::
AkgStart
(
int
process_num
,
int
wait_time
)
{
// Start compiling..
auto
res
=
SendRequest
(
kAkgStart
);
if
(
res
!=
kAck
)
{
MS_LOG
(
ERROR
)
<<
"AKG/START failed, res: "
<<
res
;
return
false
;
}
std
::
string
process_num_str
=
std
::
to_string
(
process_num
);
res
=
SendRequest
(
process_num_str
);
if
(
res
!=
kAck
)
{
MS_LOG
(
ERROR
)
<<
"AKG/START(process_num) responds failed, res: "
<<
res
;
return
false
;
}
std
::
string
wait_time_str
=
std
::
to_string
(
wait_time
);
res
=
SendRequest
(
wait_time_str
);
if
(
res
!=
kAck
)
{
MS_LOG
(
ERROR
)
<<
"AKG/START(wait_time) responds failed, res: "
<<
res
;
return
false
;
}
return
true
;
}
bool
KernelBuildClient
::
AkgSendData
(
const
std
::
vector
<
std
::
string
>
&
jsons
)
{
auto
res
=
SendRequest
(
kAkgData
);
if
(
res
!=
kAck
)
{
MS_LOG
(
ERROR
)
<<
"AKG/DATA failed, res: "
<<
res
;
return
false
;
}
for
(
auto
&
json
:
jsons
)
{
res
=
SendRequest
(
json
);
if
(
res
!=
kAck
)
{
MS_LOG
(
ERROR
)
<<
"AKG/DATA.. responds failed, res: "
<<
res
<<
", when sending ["
<<
json
<<
"]"
;
return
false
;
}
}
return
true
;
}
// Fetch the result of AKG compiling.
bool
KernelBuildClient
::
AkgWait
()
{
auto
res
=
SendRequest
(
kAkgWait
);
if
(
res
!=
kTrue
)
{
MS_LOG
(
ERROR
)
<<
"AKG/WAIT failed, res: "
<<
res
;
return
false
;
}
return
true
;
}
std
::
string
KernelBuildClient
::
SelectFormat
(
const
std
::
string
&
json
)
{
std
::
string
KernelBuildClient
::
SelectFormat
(
const
std
::
string
&
json
)
{
// Start compiling..
// Start compiling..
std
::
string
res
=
SendRequest
(
kFORMAT
);
auto
res
=
SendRequest
(
kFormat
);
if
(
res
!=
kA
CK
)
{
if
(
res
!=
kA
ck
)
{
MS_LOG
(
ERROR
)
<<
"FORMAT failed, res: "
<<
res
;
MS_LOG
(
ERROR
)
<<
"FORMAT failed, res: "
<<
res
;
return
""
;
return
""
;
}
}
// Send the json data.
// Send the json data.
res
=
SendRequest
(
json
);
res
=
SendRequest
(
json
);
if
(
res
==
kE
RR
)
{
if
(
res
==
kE
rr
)
{
MS_LOG
(
ERROR
)
<<
"FORMAT
send data
failed, res: "
<<
res
;
MS_LOG
(
ERROR
)
<<
"FORMAT
responds
failed, res: "
<<
res
;
return
""
;
return
""
;
}
}
return
res
;
return
res
;
...
@@ -88,15 +136,15 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) {
...
@@ -88,15 +136,15 @@ std::string KernelBuildClient::SelectFormat(const std::string &json) {
bool
KernelBuildClient
::
CheckSupported
(
const
std
::
string
&
json
)
{
bool
KernelBuildClient
::
CheckSupported
(
const
std
::
string
&
json
)
{
// Checking support..
// Checking support..
std
::
string
res
=
SendRequest
(
kSUPPORT
);
auto
res
=
SendRequest
(
kSupport
);
if
(
res
!=
kA
CK
)
{
if
(
res
!=
kA
ck
)
{
MS_LOG
(
ERROR
)
<<
"SUPPORT failed, res: "
<<
res
;
MS_LOG
(
ERROR
)
<<
"SUPPORT failed, res: "
<<
res
;
return
false
;
return
false
;
}
}
// Send the json data.
// Send the json data.
res
=
SendRequest
(
json
);
res
=
SendRequest
(
json
);
if
(
res
!=
kT
RUE
)
{
if
(
res
!=
kT
rue
)
{
MS_LOG
(
ERROR
)
<<
"SUPPORT send data
failed, res: "
<<
res
;
MS_LOG
(
INFO
)
<<
"SUPPORT responds
failed, res: "
<<
res
;
return
false
;
return
false
;
}
}
return
true
;
return
true
;
...
...
mindspore/ccsrc/backend/session/kernel_build_client.h
浏览文件 @
0154bdeb
...
@@ -17,6 +17,7 @@
...
@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#ifndef MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#define MINDSPORE_CCSRC_BACKEND_SESSION_KERNEL_BUILD_CLIENT_H_
#include <vector>
#include <string>
#include <string>
#include <cstring>
#include <cstring>
#include <cstdlib>
#include <cstdlib>
...
@@ -43,23 +44,26 @@ class KernelBuildClient {
...
@@ -43,23 +44,26 @@ class KernelBuildClient {
"
\"
"
;
"
\"
"
;
// Receive the response from server
// Receive the response from server
constexpr
inline
static
auto
kA
CK
=
"ACK"
;
constexpr
inline
static
auto
kA
ck
=
"ACK"
;
constexpr
inline
static
auto
kE
RR
=
"ERR"
;
constexpr
inline
static
auto
kE
rr
=
"ERR"
;
constexpr
inline
static
auto
kF
AILED
=
"-1"
;
constexpr
inline
static
auto
kF
ailed
=
"-1"
;
// Send Finish request to server
// Send Finish request to server
constexpr
inline
static
auto
kF
IN
=
"FIN"
;
constexpr
inline
static
auto
kF
in
=
"FIN"
;
// Send building request to server
// Send building request to server
constexpr
inline
static
auto
kSTART
=
"START"
;
constexpr
inline
static
auto
kTbeStart
=
"TBE/START"
;
constexpr
inline
static
auto
kWAIT
=
"WAIT"
;
constexpr
inline
static
auto
kTbeWait
=
"TBE/WAIT"
;
constexpr
inline
static
auto
kCONT
=
"CONT"
;
constexpr
inline
static
auto
kCont
=
"CONT"
;
constexpr
inline
static
auto
kSUCCESS
=
"Success"
;
constexpr
inline
static
auto
kSuccess
=
"Success"
;
constexpr
inline
static
auto
kRESET
=
"RESET"
;
constexpr
inline
static
auto
kTbeReset
=
"TBE/RESET"
;
constexpr
inline
static
auto
kAkgStart
=
"AKG/START"
;
constexpr
inline
static
auto
kAkgData
=
"AKG/DATA"
;
constexpr
inline
static
auto
kAkgWait
=
"AKG/WAIT"
;
// Send server info. query to server
// Send server info. query to server
constexpr
inline
static
auto
kF
ORMAT
=
"FORMAT"
;
constexpr
inline
static
auto
kF
ormat
=
"FORMAT"
;
constexpr
inline
static
auto
kS
UPPORT
=
"SUPPORT"
;
constexpr
inline
static
auto
kS
upport
=
"SUPPORT"
;
constexpr
inline
static
auto
kT
RUE
=
"True"
;
constexpr
inline
static
auto
kT
rue
=
"True"
;
// Revert \n, \r, [space].
// Revert \n, \r, [space].
constexpr
inline
static
auto
kLF
=
"[LF]"
;
constexpr
inline
static
auto
kLF
=
"[LF]"
;
...
@@ -67,7 +71,7 @@ class KernelBuildClient {
...
@@ -67,7 +71,7 @@ class KernelBuildClient {
constexpr
inline
static
auto
kSP
=
"[SP]"
;
constexpr
inline
static
auto
kSP
=
"[SP]"
;
// The TAG as prefix of real command from remote.
// The TAG as prefix of real command from remote.
constexpr
inline
static
auto
kT
AG
=
"[~]"
;
constexpr
inline
static
auto
kT
ag
=
"[~]"
;
constexpr
inline
static
int
kBufferSize
=
4096
;
constexpr
inline
static
int
kBufferSize
=
4096
;
constexpr
inline
static
unsigned
int
kTimeOutSeconds
=
20
;
constexpr
inline
static
unsigned
int
kTimeOutSeconds
=
20
;
...
@@ -87,7 +91,7 @@ class KernelBuildClient {
...
@@ -87,7 +91,7 @@ class KernelBuildClient {
std
::
string
result
;
std
::
string
result
;
char
buf
[
kBufferSize
];
char
buf
[
kBufferSize
];
while
(
std
::
fgets
(
buf
,
sizeof
(
buf
),
fpipe
)
!=
nullptr
)
{
while
(
std
::
fgets
(
buf
,
sizeof
(
buf
),
fpipe
)
!=
nullptr
)
{
if
(
std
::
strncmp
(
buf
,
kT
AG
,
std
::
strlen
(
kTAG
))
==
0
)
{
if
(
std
::
strncmp
(
buf
,
kT
ag
,
std
::
strlen
(
kTag
))
==
0
)
{
start
=
true
;
start
=
true
;
}
}
// Filter with 'kTAG' and '\n'
// Filter with 'kTAG' and '\n'
...
@@ -105,7 +109,7 @@ class KernelBuildClient {
...
@@ -105,7 +109,7 @@ class KernelBuildClient {
if
(
result
.
empty
()
||
result
.
rfind
(
py_suffix
)
!=
(
result
.
length
()
-
py_suffix
.
length
()))
{
if
(
result
.
empty
()
||
result
.
rfind
(
py_suffix
)
!=
(
result
.
length
()
-
py_suffix
.
length
()))
{
MS_LOG
(
EXCEPTION
)
<<
"py file seems incorrect, result: {"
<<
result
<<
"}"
;
MS_LOG
(
EXCEPTION
)
<<
"py file seems incorrect, result: {"
<<
result
<<
"}"
;
}
}
result
=
result
.
substr
(
strlen
(
kT
AG
));
result
=
result
.
substr
(
strlen
(
kT
ag
));
MS_LOG
(
DEBUG
)
<<
"result: "
<<
result
;
MS_LOG
(
DEBUG
)
<<
"result: "
<<
result
;
return
result
;
return
result
;
}
}
...
@@ -115,7 +119,7 @@ class KernelBuildClient {
...
@@ -115,7 +119,7 @@ class KernelBuildClient {
// Exception's thrown if open failed
// Exception's thrown if open failed
if
(
dp_
->
Open
({
kEnv
,
GetScriptPath
()},
true
)
!=
-
1
)
{
if
(
dp_
->
Open
({
kEnv
,
GetScriptPath
()},
true
)
!=
-
1
)
{
dp_
->
SetTimeOutSeconds
(
kTimeOutSeconds
);
dp_
->
SetTimeOutSeconds
(
kTimeOutSeconds
);
dp_
->
SetTimeOutCallback
([
this
]()
{
SendRequest
(
kF
IN
);
});
dp_
->
SetTimeOutCallback
([
this
]()
{
SendRequest
(
kF
in
);
});
init_
=
true
;
init_
=
true
;
}
}
}
}
...
@@ -146,13 +150,13 @@ class KernelBuildClient {
...
@@ -146,13 +150,13 @@ class KernelBuildClient {
std
::
string
res
;
std
::
string
res
;
*
dp_
>>
res
;
*
dp_
>>
res
;
// Filter out the interference
// Filter out the interference
auto
start
=
res
.
find
(
kT
AG
);
auto
start
=
res
.
find
(
kT
ag
);
if
(
start
==
std
::
string
::
npos
)
{
if
(
start
==
std
::
string
::
npos
)
{
MS_LOG
(
EXCEPTION
)
<<
"Response seems incorrect, res: "
<<
res
;
MS_LOG
(
EXCEPTION
)
<<
"Response seems incorrect, res: "
<<
res
;
}
}
res
=
res
.
substr
(
start
+
std
::
strlen
(
kT
AG
),
res
.
size
()
-
start
);
res
=
res
.
substr
(
start
+
std
::
strlen
(
kT
ag
),
res
.
size
()
-
start
);
// Revert the line feed and space
// Revert the line feed and space
if
(
res
!=
kS
UCCESS
&&
res
!=
kACK
&&
res
!=
kERR
&&
res
!=
kTRUE
)
{
if
(
res
!=
kS
uccess
&&
res
!=
kAck
&&
res
!=
kErr
&&
res
!=
kTrue
)
{
ReplaceStr
(
&
res
,
kLF
,
'\n'
);
ReplaceStr
(
&
res
,
kLF
,
'\n'
);
ReplaceStr
(
&
res
,
kSP
,
' '
);
ReplaceStr
(
&
res
,
kSP
,
' '
);
}
}
...
@@ -164,10 +168,15 @@ class KernelBuildClient {
...
@@ -164,10 +168,15 @@ class KernelBuildClient {
std
::
string
SelectFormat
(
const
std
::
string
&
json
);
std
::
string
SelectFormat
(
const
std
::
string
&
json
);
bool
CheckSupported
(
const
std
::
string
&
json
);
bool
CheckSupported
(
const
std
::
string
&
json
);
// Run building.
// Run TBE building.
int
Start
(
const
std
::
string
&
json
);
int
TbeStart
(
const
std
::
string
&
json
);
bool
Wait
(
int
*
task_id
,
std
::
string
*
task_result
,
std
::
string
*
pre_build_result
);
bool
TbeWait
(
int
*
task_id
,
std
::
string
*
task_result
,
std
::
string
*
pre_build_result
);
void
Reset
();
void
TbeReset
();
// Run AKG building.
bool
AkgStart
(
int
process_num
,
int
wait_time
);
bool
AkgSendData
(
const
std
::
vector
<
std
::
string
>
&
jsons
);
bool
AkgWait
();
KernelBuildClient
(
const
KernelBuildClient
&
)
=
delete
;
KernelBuildClient
(
const
KernelBuildClient
&
)
=
delete
;
KernelBuildClient
&
operator
=
(
const
KernelBuildClient
&
)
=
delete
;
KernelBuildClient
&
operator
=
(
const
KernelBuildClient
&
)
=
delete
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录