Commit ebff566a
Repository: magicwindyyd/mindspore (fork of MindSpore/mindspore)
Authored Sep 04, 2020 by kswang

add group operation for executor

Parent: bc4c5afc

7 changed files with 148 additions and 82 deletions (+148 -82)
mindspore/ccsrc/backend/session/executor.cc              +35 -58
mindspore/ccsrc/backend/session/executor.h               +37  -8
mindspore/ccsrc/backend/session/session_basic.cc          +3  -3
mindspore/ccsrc/backend/session/session_basic.h           +2  -1
mindspore/ccsrc/frontend/parallel/group_manager.cc       +27 -11
mindspore/ccsrc/frontend/parallel/group_manager.h         +1  -0
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc    +43  -1
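
Read together, the hunks below make two changes: communicator-group creation and destruction now go through the per-device session executor instead of calling CommManager directly, and the executor's per-task-type condition variables collapse into the single sync_cond_var_. A condensed sketch of the new create path, assembled from the diffs (error handling elided; a reading aid, not a compilable unit):

  // frontend/parallel/group_manager.cc: resolve the executor for the current device
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  bool ret = executor->CreateCommGroup(group_name, ranks);  // blocks until the queued task has run

  // backend/session/executor.cc: the queued CreateCommGroupTask, run on the worker
  // thread, performs the call that GroupManager used to make inline:
  result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_);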
mindspore/ccsrc/backend/session/executor.cc

@@ -16,6 +16,7 @@
 #include "backend/session/executor.h"
 #include "runtime/device/kernel_runtime_manager.h"
 #include "backend/session/executor_manager.h"
+#include "utils/comm_manager.h"
 
 namespace mindspore {
 namespace session {
@@ -45,32 +46,6 @@ void UpdateOutputTensors(VectorRef *outputs,
     }
   }
 }
-
-BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
-  if (utils::isa<VectorRef>(base_ref)) {
-    auto ref_list = utils::cast<VectorRef>(base_ref);
-    py::tuple output_tensors(ref_list.size());
-    for (size_t i = 0; i < ref_list.size(); ++i) {
-      auto output = TransformBaseRefListToTuple(ref_list[i]);  // use pyObjectRef
-      if (utils::isa<tensor::TensorPtr>(output)) {
-        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
-        MS_EXCEPTION_IF_NULL(tensor_ptr);
-        output_tensors[i] = tensor_ptr;
-      } else if (utils::isa<PyObjectRef>(output)) {
-        py::object obj = utils::cast<PyObjectRef>(output).object_;
-        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
-        output_tensors[i] = tensor_tuple;
-      } else {
-        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
-      }
-    }
-    return output_tensors;  // turn tuple to py::object and store in PyObjectRef
-  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
-    return base_ref;
-  } else {
-    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
-  }
-}
 }  // namespace
 
 void CompileNodesTask::Run() {
   MS_EXCEPTION_IF_NULL(session_);
@@ -104,6 +79,10 @@ void RunOpTask::Run() {
   session_->RunOp(*op_run_info_, graph_info_, input_tensors_, &outputs_);
 }
 
+void CreateCommGroupTask::Run() { result_ = CommManager::GetInstance().CreateGroupSync(group_name_, ranks_); }
+
+void DestroyCommGroupTask::Run() { result_ = CommManager::GetInstance().DestroyGroup(group_name_); }
+
 Executor::Executor(const std::string &device_name, uint32_t device_id) {
   device_name_ = device_name;
   device_id_ = device_id;
@@ -141,22 +120,8 @@ void Executor::WorkerLoop() {
     } catch (const std::exception &e) {
      exception_ptr_ = std::current_exception();
    }
-    auto task_type = task->type_;
     task = nullptr;
-    if (task_type == kCompileNodes) {
-      compile_cond_var_.notify_all();
-    } else if (task_type == kCompileGraph) {
-      compile_cond_var_.notify_all();
-    } else if (task_type == kBuildGraph) {
-      build_cond_var_.notify_all();
-    } else if (task_type == kRunGraph) {
-      run_cond_var_.notify_all();
-    } else if (task_type == kBuildOp) {
-      build_op_cond_var_.notify_all();
-    } else if (task_type == kRunOp) {
-      run_op_cond_var_.notify_all();
-    }
     sync_cond_var_.notify_all();
   }
 }
@@ -206,7 +171,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, const AnfNodePtrL
   task->output_nodes_ = outputs;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  compile_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
   return task->graph_id_;
 }
@@ -219,7 +184,7 @@ GraphId Executor::CompileGraphAsync(const SessionPtr &session, NotNull<FuncGraph
   task->func_graph_ = func_graph;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  compile_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
   return task->graph_id_;
 }
@@ -232,7 +197,7 @@ void Executor::BuildGraphAsync(const SessionPtr &session, GraphId graphId) {
   task->graph_id_ = graphId;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  build_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
 }
@@ -258,7 +223,7 @@ void Executor::RunGraphAsync(const SessionPtr &session, const GraphId &graph_id,
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
   py::gil_scoped_release release;
-  run_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
 }
@@ -274,12 +239,12 @@ void Executor::BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, c
   task->tensors_mask_ = tensors_mask;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  build_op_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
 }
 
-py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                               const std::vector<tensor::TensorPtr> &input_tensors) {
+void Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                          const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
   CheckException();
   std::unique_lock<std::mutex> lock(task_mutex_);
   auto task = std::make_shared<RunOpTask>();
@@ -289,18 +254,30 @@ py::tuple Executor::RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info
   task->input_tensors_ = input_tensors;
   ready_tasks_.push(task);
   task_cond_var_.notify_all();
-  run_op_cond_var_.wait(lock);
+  sync_cond_var_.wait(lock);
   CheckException();
-  // Trans output to tuple
-  auto output_tensors = TransformBaseRefListToTuple(task->outputs_);
-  if (!utils::isa<PyObjectRef>(output_tensors) ||
-      !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
-    MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
-  }
-  py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
-  py::tuple tuple_tensors = py::cast<py::tuple>(tuple_obj);
-  return tuple_tensors;
+  *outputs = task->outputs_;
+}
+
+bool Executor::CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks) {
+  std::unique_lock<std::mutex> lock(task_mutex_);
+  auto task = std::make_shared<CreateCommGroupTask>();
+  task->group_name_ = group_name;
+  task->ranks_ = ranks;
+  ready_tasks_.push(task);
+  task_cond_var_.notify_all();
+  sync_cond_var_.wait(lock);
+  return task->result_;
+}
+
+bool Executor::DestroyCommGroup(const std::string &group_name) {
+  std::unique_lock<std::mutex> lock(task_mutex_);
+  auto task = std::make_shared<DestroyCommGroupTask>();
+  task->group_name_ = group_name;
+  ready_tasks_.push(task);
+  task_cond_var_.notify_all();
+  sync_cond_var_.wait(lock);
+  return task->result_;
+}
 
 void Executor::StopWorker() {
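
All of the *Async entry points above now park the caller on the one sync_cond_var_ that WorkerLoop() notifies after each task, rather than on a per-task-type condition variable. A self-contained toy illustration of that enqueue-and-wait scheme (plain standard C++, not MindSpore code; the real WorkerLoop also tracks exceptions and pending graph tasks):

#include <condition_variable>
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <thread>

int main() {
  std::mutex m;
  std::condition_variable task_cv;  // wakes the worker (cf. task_cond_var_)
  std::condition_variable sync_cv;  // wakes waiting callers (cf. sync_cond_var_)
  std::queue<std::function<void()>> tasks;
  bool done = false;
  bool stop = false;

  std::thread worker([&] {
    std::unique_lock<std::mutex> lk(m);
    while (!stop) {
      task_cv.wait(lk, [&] { return stop || !tasks.empty(); });
      while (!tasks.empty()) {
        auto task = std::move(tasks.front());
        tasks.pop();
        task();                // toy: runs under the lock
        done = true;
        sync_cv.notify_all();  // one shared CV serves every caller
      }
    }
  });

  {
    std::unique_lock<std::mutex> lk(m);
    tasks.push([] { std::cout << "create comm group\n"; });
    task_cv.notify_all();
    sync_cv.wait(lk, [&] { return done; });  // caller blocks until the worker ran it
  }

  {
    std::lock_guard<std::mutex> lk(m);
    stop = true;
  }
  task_cv.notify_all();
  worker.join();
  return 0;
}

Note that the toy waits with a predicate, while the commit's sync_cond_var_.wait(lock) calls have none, so they rely on the worker notifying only after the caller's task has finished.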
mindspore/ccsrc/backend/session/executor.h

@@ -32,10 +32,22 @@
 #include "ir/tensor.h"
 #include "utils/any.h"
 #include "utils/contract.h"
+#include "utils/comm_manager.h"
 
 namespace mindspore {
 namespace session {
-enum TaskType { kUnKnown, kExit, kCompileNodes, kCompileGraph, kBuildGraph, kBuildOp, kRunGraph, kRunOp };
+enum TaskType {
+  kUnKnown,
+  kExit,
+  kCompileNodes,
+  kCompileGraph,
+  kBuildGraph,
+  kBuildOp,
+  kRunGraph,
+  kRunOp,
+  kCreateCommGroup,
+  kDestroyCommGroup
+};
 
 class Task {
  public:
@@ -106,6 +118,25 @@ class RunOpTask : public Task {
   VectorRef outputs_;
 };
 
+class CreateCommGroupTask : public Task {
+ public:
+  CreateCommGroupTask() { type_ = kCreateCommGroup; }
+  ~CreateCommGroupTask() override = default;
+  void Run() override;
+  std::string group_name_;
+  std::vector<uint32_t> ranks_;
+  bool result_;
+};
+
+class DestroyCommGroupTask : public Task {
+ public:
+  DestroyCommGroupTask() { type_ = kDestroyCommGroup; }
+  ~DestroyCommGroupTask() override = default;
+  void Run() override;
+  std::string group_name_;
+  bool result_;
+};
+
 class ExitTask : public Task {
  public:
  ExitTask() { type_ = kExit; }
@@ -125,9 +156,11 @@ class Executor {
                     VectorRef *outputs);
   void BuildOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
                     const std::vector<tensor::TensorPtr> &input_tensors, const std::vector<int> &tensors_mask);
-  py::tuple RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                       const std::vector<tensor::TensorPtr> &input_tensors);
+  void RunOpAsync(const SessionPtr &session, OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                  const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs);
   void OnRunGraphFinished();
+  bool CreateCommGroup(const std::string &group_name, std::vector<uint32_t> ranks);
+  bool DestroyCommGroup(const std::string &group_name);
 
  private:
   void UpdateOutputTensors(VectorRef *outputs,
@@ -143,11 +176,7 @@ class Executor {
   std::mutex task_mutex_;
   std::mutex pending_task_mutex_;
   std::condition_variable task_cond_var_;
-  std::condition_variable compile_cond_var_;
-  std::condition_variable build_cond_var_;
-  std::condition_variable run_cond_var_;
-  std::condition_variable build_op_cond_var_;
-  std::condition_variable run_op_cond_var_;
   std::condition_variable sync_cond_var_;
   std::queue<std::shared_ptr<Task>> ready_tasks_;
   std::list<std::shared_ptr<RunGraphTask>> pending_tasks_;
   std::shared_ptr<std::thread> worker_;
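
The two new public methods give code outside the session layer a synchronous way to manage communication groups without touching CommManager. A hypothetical caller sketch (the device-name/device-id lookup mirrors group_manager.cc below; the group name and rank list are invented for illustration):

// Hypothetical usage; "nccl_sub_group" and the ranks are illustrative only.
auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
std::vector<uint32_t> ranks = {0, 1, 2, 3};
if (!executor->CreateCommGroup("nccl_sub_group", ranks)) {  // blocks until the task ran
  MS_LOG(ERROR) << "Create group failed";
}
// ... collective work on the group ...
if (!executor->DestroyCommGroup("nccl_sub_group")) {
  MS_LOG(ERROR) << "Destroy group failed";
}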
mindspore/ccsrc/backend/session/session_basic.cc

@@ -1344,10 +1344,10 @@ void SessionBasic::BuildOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_i
   executor_->BuildOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, tensors_mask);
 }
 
-py::tuple SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
-                                   const std::vector<tensor::TensorPtr> &input_tensors) {
+void SessionBasic::RunOpAsync(OpRunInfo *op_run_info, const GraphInfo &graph_info,
+                              const std::vector<tensor::TensorPtr> &input_tensors, VectorRef *outputs) {
   MS_EXCEPTION_IF_NULL(executor_);
-  return executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors);
+  executor_->RunOpAsync(shared_from_this(), op_run_info, graph_info, input_tensors, outputs);
 }
 
 void SessionBasic::RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
mindspore/ccsrc/backend/session/session_basic.h

@@ -90,7 +90,8 @@ class SessionBasic : public std::enable_shared_from_this<SessionBasic> {
   void RunGraphAsync(const GraphId &graph_id, const std::vector<tensor::TensorPtr> &inputs,
                      VectorRef *outputs);
   void BuildOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
                     const std::vector<int> &tensors_mask);
-  py::tuple RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors);
+  void RunOpAsync(OpRunInfo *, const GraphInfo &, const std::vector<tensor::TensorPtr> &input_tensors,
+                  VectorRef *outputs);
   virtual void RegisterSummaryCallBackFunc(const CallBackFunc &callback);
mindspore/ccsrc/frontend/parallel/group_manager.cc

@@ -15,12 +15,12 @@
  */
 
 #include "frontend/parallel/group_manager.h"
 #include <algorithm>
 #include <vector>
 #include "frontend/parallel/device_manager.h"
-#include "utils/comm_manager.h"
+#include "backend/session/executor_manager.h"
 #include "utils/ms_context.h"
 
 namespace mindspore {
 namespace parallel {
@@ -96,8 +96,14 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
   vector<uint32_t> ranks;
   (void)std::transform(std::begin(devices), std::end(devices), std::back_inserter(ranks),
                        [](const Device dev) { return (uint32_t)dev.rank(); });
-  // Create group through the CommManager interface
-  bool ret = CommManager::GetInstance().CreateGroupSync(group_name, ranks);
+  // Create group through the executor
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
+  MS_EXCEPTION_IF_NULL(executor);
+  bool ret = executor->CreateCommGroup(group_name, ranks);
   if (!ret) {
     MS_LOG(ERROR) << "Create group failed, group name is " << group_name;
     return Status::FAILED;
@@ -108,6 +114,20 @@ Status GroupManager::CreateGroup(const std::string &group_name, const std::vecto
   }
 }
 
+Status GroupManager::DestroyGroup(const std::string &group_name) {
+  auto context_ptr = MsContext::GetInstance();
+  MS_EXCEPTION_IF_NULL(context_ptr);
+  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
+  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
+  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
+  MS_EXCEPTION_IF_NULL(executor);
+  bool ret = executor->DestroyCommGroup(group_name);
+  if (!ret) {
+    return Status::FAILED;
+  }
+  return Status::SUCCESS;
+}
+
 Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
   std::string name = (*group).name();
   auto it = groups_.find(name);
@@ -116,18 +136,14 @@ Status GroupManager::DestroyGroup(mindspore::parallel::Group *const group) {
     return Status::FAILED;
   }
   (void)groups_.erase(it);
-  bool ret = CommManager::GetInstance().DestroyGroup(name);
-  if (!ret) {
-    return Status::FAILED;
-  }
-  return Status::SUCCESS;
+  return DestroyGroup(name);
 }
 
 Status GroupManager::DestroyAllGroups() {
   for (auto &it : groups_) {
     std::string name = it.first;
-    bool ret = CommManager::GetInstance().DestroyGroup(name);
-    if (!ret) {
+    auto ret = DestroyGroup(name);
+    if (ret != Status::SUCCESS) {
       return Status::FAILED;
     }
   }
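
Both public destroy paths, DestroyGroup(Group *) and DestroyAllGroups(), now funnel into the new private DestroyGroup(name) overload, so the executor round trip lives in one place. The context-to-executor lookup is still written out twice (in CreateGroup and in the private DestroyGroup); a hypothetical helper of the following shape would factor it out, assuming GetExecutor returns a shared executor pointer (a sketch, not part of this commit):

// Hypothetical refactor sketch, not in this commit.
static std::shared_ptr<session::Executor> GetCurrentExecutor() {
  auto context_ptr = MsContext::GetInstance();
  MS_EXCEPTION_IF_NULL(context_ptr);
  std::string device_name = context_ptr->get_param<std::string>(MS_CTX_DEVICE_TARGET);
  uint32_t device_id = context_ptr->get_param<uint32_t>(MS_CTX_DEVICE_ID);
  auto executor = session::ExecutorManager::Instance().GetExecutor(device_name, device_id);
  MS_EXCEPTION_IF_NULL(executor);
  return executor;
}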
mindspore/ccsrc/frontend/parallel/group_manager.h

@@ -65,6 +65,7 @@ class GroupManager {
   void Clear();
 
  private:
+  Status DestroyGroup(const std::string &group_name);
   // the key is group name (name_)
   std::map<std::string, Group> groups_;
   std::string world_group_;
mindspore/ccsrc/pipeline/pynative/pynative_execute.cc

@@ -19,18 +19,22 @@
 #include <typeinfo>
 #include <map>
 #include <set>
 #include <memory>
 #include <unordered_set>
+#include <algorithm>
 #include "debug/trace.h"
 #include "pybind_api/ir/tensor_py.h"
 #include "ir/param_info.h"
 #include "ir/anf.h"
+#include "ir/tensor.h"
+#include "utils/any.h"
 #include "utils/utils.h"
 #include "utils/ms_context.h"
 #include "utils/context/context_extends.h"
 #include "utils/config_manager.h"
 #include "utils/convert_utils_py.h"
+#include "utils/base_ref_extends.h"
 #include "frontend/operator/ops.h"
 #include "frontend/operator/composite/composite.h"
 #include "frontend/operator/composite/do_signature.h"
@@ -554,6 +558,32 @@ void EraseValueNodeTensor(const std::vector<int> &tensors_mask, std::vector<tens
   *input_tensors = new_input_tensors;
 }
 
+BaseRef TransformBaseRefListToTuple(const BaseRef &base_ref) {
+  if (utils::isa<VectorRef>(base_ref)) {
+    auto ref_list = utils::cast<VectorRef>(base_ref);
+    py::tuple output_tensors(ref_list.size());
+    for (size_t i = 0; i < ref_list.size(); ++i) {
+      auto output = TransformBaseRefListToTuple(ref_list[i]);
+      if (utils::isa<tensor::TensorPtr>(output)) {
+        auto tensor_ptr = utils::cast<tensor::TensorPtr>(output);
+        MS_EXCEPTION_IF_NULL(tensor_ptr);
+        output_tensors[i] = tensor_ptr;
+      } else if (utils::isa<PyObjectRef>(output)) {
+        py::object obj = utils::cast<PyObjectRef>(output).object_;
+        py::tuple tensor_tuple = py::cast<py::tuple>(obj);
+        output_tensors[i] = tensor_tuple;
+      } else {
+        MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+      }
+    }
+    return std::make_shared<PyObjectRef>(output_tensors);
+  } else if (utils::isa<tensor::TensorPtr>(base_ref)) {
+    return base_ref;
+  } else {
+    MS_LOG(EXCEPTION) << "The output is not a base ref list or a tensor!";
+  }
+}
+
 py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *status) {
   MS_EXCEPTION_IF_NULL(op_exec_info);
   MS_LOG(INFO) << "Start run op[" << op_exec_info->op_name << "] with backend policy ms";
@@ -577,7 +607,19 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
   std::string graph_info = GetSingleOpGraphInfo(op_exec_info, input_tensors);
   session->BuildOpAsync(op_exec_info.get(), graph_info, input_tensors, tensors_mask);
   EraseValueNodeTensor(tensors_mask, &input_tensors);
-  py::tuple result = session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors);
+  VectorRef outputs;
+  session->RunOpAsync(op_exec_info.get(), graph_info, input_tensors, &outputs);
+  // Trans output to tuple
+  auto output_tensors = TransformBaseRefListToTuple(outputs);
+  if (!utils::isa<PyObjectRef>(output_tensors) ||
+      !py::isinstance<py::tuple>(utils::cast<PyObjectRef>(output_tensors).object_)) {
+    MS_EXCEPTION(NotSupportError) << "The output tensors should be a tuple !";
+  }
+  py::object tuple_obj = utils::cast<PyObjectRef>(output_tensors).object_;
+  py::tuple result = py::cast<py::tuple>(tuple_obj);
   ms_context->set_param<bool>(MS_CTX_ENABLE_PYNATIVE_INFER, false);
   *status = PYNATIVE_SUCCESS;
   MS_LOG(INFO) << "End run op[" << op_exec_info->op_name << "] with backend policy ms";
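
TransformBaseRefListToTuple, moved here from executor.cc, recurses structurally: a nested VectorRef of tensors becomes a nested py::tuple, and each completed tuple is wrapped in a PyObjectRef so it can travel back up through the BaseRef return type. (The old executor.cc copy returned the tuple and relied on implicit conversion, per its inline comment; the moved version wraps explicitly with std::make_shared<PyObjectRef>.) A worked trace on a hypothetical nested output:

// Worked trace; the input and names are hypothetical, not from the commit.
// VectorRef outer = { tensor_a, VectorRef{ tensor_b, tensor_c } };
// TransformBaseRefListToTuple(outer)
//   -> recurses into VectorRef{ tensor_b, tensor_c }
//      -> returns a PyObjectRef holding py::tuple(tensor_b, tensor_c)
//   -> unwraps that PyObjectRef and stores the inner tuple in slot 1
//   -> returns a PyObjectRef holding py::tuple(tensor_a, (tensor_b, tensor_c))
// RunOpInMs then verifies utils::isa<PyObjectRef> and casts .object_ to py::tuple.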