Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
90dfbab3
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
90dfbab3
编写于
4月 26, 2020
作者:
M
mindspore-ci-bot
提交者:
Gitee
4月 26, 2020
浏览文件
操作
浏览文件
下载
差异文件
!501 Refactor PyNative Excution Get Input Tensors
Merge pull request !501 from chujinjin/abstract_input_tensor
上级
348b0ef5
13bf42ba
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
135 addition
and
133 deletion
+135
-133
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+116
-2
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+3
-3
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+1
-1
mindspore/ccsrc/session/gpu_session.cc
mindspore/ccsrc/session/gpu_session.cc
+2
-3
mindspore/ccsrc/session/gpu_session.h
mindspore/ccsrc/session/gpu_session.h
+1
-1
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+8
-121
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+4
-2
未找到文件。
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
90dfbab3
...
...
@@ -30,7 +30,8 @@
#include "pipeline/parse/data_converter.h"
#include "pipeline/static_analysis/prim.h"
#include "session/session_factory.h"
#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h"
#include "pynative/base.h"
#ifdef ENABLE_GE
...
...
@@ -188,6 +189,117 @@ py::object RunOpInVM(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
return
std
::
move
(
result
);
}
bool
RunOpConvertConstInputToAttr
(
const
py
::
object
&
input_object
,
size_t
input_index
,
const
PrimitivePtr
&
op_prim
,
const
std
::
unordered_set
<
size_t
>
&
input_attrs
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
auto
input_names_value
=
op_prim
->
GetAttr
(
kAttrInputNames
);
if
(
input_names_value
==
nullptr
)
{
return
false
;
}
auto
input_names_vec
=
GetValue
<
std
::
vector
<
std
::
string
>>
(
input_names_value
);
if
(
input_index
>=
input_names_vec
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The input index: "
<<
input_index
<<
" is large than the input names vector size!"
;
}
if
(
input_attrs
.
find
(
input_index
)
!=
input_attrs
.
end
())
{
ValuePtr
value
=
parse
::
data_converter
::
PyDataToValue
(
input_object
);
MS_EXCEPTION_IF_NULL
(
value
);
auto
input_name
=
input_names_vec
[
input_index
];
op_prim
->
set_attr
(
input_name
,
value
);
return
true
;
}
return
false
;
}
void
PlantTensorTupleToVector
(
const
py
::
tuple
&
tuple_inputs
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
for
(
const
auto
&
input_object
:
tuple_inputs
)
{
if
(
!
py
::
isinstance
<
tensor
::
Tensor
>
(
input_object
))
{
MS_LOG
(
EXCEPTION
)
<<
"The input object is not a tensor!"
;
}
auto
tensor
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input_object
);
MS_EXCEPTION_IF_NULL
(
tensor
);
input_tensor
->
push_back
(
tensor
);
}
op_prim
->
set_attr
(
kAttrDynInputSizes
,
MakeValue
(
std
::
vector
<
int
>
{
SizeToInt
(
tuple_inputs
.
size
())}));
}
void
ConvertValueTupleToTensor
(
const
py
::
object
&
input_object
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
input_tensor
);
ValuePtr
input_value
=
parse
::
data_converter
::
PyDataToValue
(
input_object
);
MS_EXCEPTION_IF_NULL
(
input_value
);
if
(
!
input_value
->
isa
<
ValueTuple
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The input object is not a value tuple!"
;
}
auto
value_tuple
=
input_value
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_tuple
);
tensor
::
TensorPtr
tensor_ptr
=
opt
::
CreateTupleTensor
(
value_tuple
);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
input_tensor
->
push_back
(
tensor_ptr
);
}
void
ConvertPyObjectToTensor
(
const
py
::
object
&
input_object
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
tensor
::
TensorPtr
tensor_ptr
=
nullptr
;
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
input_object
))
{
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input_object
);
}
else
if
(
py
::
isinstance
<
py
::
float_
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
float_
>
(
input_object
),
kFloat32
);
}
else
if
(
py
::
isinstance
<
py
::
int_
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
int_
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
list
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
list
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
array
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
array
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
tuple
>
(
input_object
))
{
auto
tuple_inputs
=
py
::
cast
<
py
::
tuple
>
(
input_object
);
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
tuple_inputs
[
0
]))
{
PlantTensorTupleToVector
(
tuple_inputs
,
op_prim
,
input_tensor
);
}
else
{
ConvertValueTupleToTensor
(
input_object
,
input_tensor
);
}
return
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Run op inputs type is invalid!"
;
}
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
input_tensor
->
push_back
(
tensor_ptr
);
}
void
ConstructInputTensor
(
const
OpExecInfoPtr
&
op_run_info
,
std
::
vector
<
bool
>
*
tensors_mask
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
tensors_mask
);
MS_EXCEPTION_IF_NULL
(
input_tensors
);
PrimitivePtr
op_prim
=
op_run_info
->
py_primitive
;
MS_EXCEPTION_IF_NULL
(
op_prim
);
if
(
op_run_info
->
op_inputs
.
size
()
!=
op_run_info
->
inputs_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Op input size "
<<
op_run_info
->
op_inputs
.
size
()
<<
" should be equal to op input mask size "
<<
op_run_info
->
inputs_mask
.
size
();
}
opt
::
ConstInputToAttrInfoRegister
reg
;
bool
reg_exist
=
opt
::
ConstInputToAttrInfoRegistry
::
Instance
().
GetRegisterByOpName
(
op_run_info
->
op_name
,
&
reg
);
size_t
input_num
=
op_run_info
->
op_inputs
.
size
();
MS_LOG
(
INFO
)
<<
"py input size: "
<<
input_num
;
for
(
size_t
index
=
0
;
index
<
input_num
;
++
index
)
{
// convert const input to attr
if
(
reg_exist
&&
RunOpConvertConstInputToAttr
(
op_run_info
->
op_inputs
[
index
],
index
,
op_prim
,
reg
.
GetConstInputAttrInfo
()))
{
continue
;
}
// convert const and tuple input to tensor
ConvertPyObjectToTensor
(
op_run_info
->
op_inputs
[
index
],
op_prim
,
input_tensors
);
// make tensors, weight : 1, data : 0
std
::
vector
<
bool
>
new_mask
(
input_tensors
->
size
()
-
tensors_mask
->
size
(),
py
::
cast
<
bool
>
(
op_run_info
->
inputs_mask
[
index
]));
tensors_mask
->
insert
(
tensors_mask
->
end
(),
new_mask
.
begin
(),
new_mask
.
end
());
}
}
py
::
object
RunOpInMs
(
const
OpExecInfoPtr
&
op_exec_info
,
PynativeStatusCode
*
status
)
{
MS_EXCEPTION_IF_NULL
(
op_exec_info
);
MS_LOG
(
INFO
)
<<
"Start run op["
<<
op_exec_info
->
op_name
<<
"] with backend policy ms"
;
...
...
@@ -204,7 +316,9 @@ py::object RunOpInMs(const OpExecInfoPtr &op_exec_info, PynativeStatusCode *stat
std
::
string
graph_info
=
GetSingleOpGraphInfo
(
op_exec_info
);
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
;
session
->
BuildOp
(
*
op_exec_info
,
graph_info
,
&
input_tensors
);
std
::
vector
<
bool
>
tensors_mask
;
ConstructInputTensor
(
op_exec_info
,
&
tensors_mask
,
&
input_tensors
);
session
->
BuildOp
(
*
op_exec_info
,
graph_info
,
input_tensors
,
tensors_mask
);
py
::
tuple
result
=
session
->
RunOp
(
*
op_exec_info
,
graph_info
,
input_tensors
);
ms_context
->
set_enable_pynative_infer
(
false
);
*
status
=
PYNATIVE_SUCCESS
;
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
90dfbab3
...
...
@@ -250,11 +250,11 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
}
void
AscendSession
::
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
input_tensors
);
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
)
{
MS_LOG
(
INFO
)
<<
"Build op "
<<
op_run_info
.
op_name
<<
" start !"
;
// construct graph include one op
auto
graph
=
ConstructSingleOpGraph
(
op_run_info
,
input_tensors
);
auto
graph
=
ConstructSingleOpGraph
(
op_run_info
,
input_tensors
,
tensors_mask
);
MS_EXCEPTION_IF_NULL
(
graph
);
opt
::
RunOpAscendBackendIRFusionOptimization
(
graph
);
// kernel select
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
90dfbab3
...
...
@@ -42,7 +42,7 @@ class AscendSession : public SessionBasic {
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
override
;
void
BuildGraph
(
GraphId
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
override
;
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
)
override
;
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
override
;
...
...
mindspore/ccsrc/session/gpu_session.cc
浏览文件 @
90dfbab3
...
...
@@ -133,10 +133,9 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
}
void
GPUSession
::
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
)
{
// Prepare the graph
MS_EXCEPTION_IF_NULL
(
input_tensors
);
auto
kernel_graph
=
ConstructSingleOpGraph
(
op_run_info
,
input_tensors
);
auto
kernel_graph
=
ConstructSingleOpGraph
(
op_run_info
,
input_tensors
,
tensors_mask
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
SelectKernel
(
kernel_graph
);
StartKernelRT
();
...
...
mindspore/ccsrc/session/gpu_session.h
浏览文件 @
90dfbab3
...
...
@@ -40,7 +40,7 @@ class GPUSession : public SessionBasic {
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
override
;
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
)
override
;
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
override
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
90dfbab3
...
...
@@ -180,115 +180,6 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return
ret
;
}
bool
RunOpConvertConstInputToAttr
(
const
py
::
object
&
input_object
,
size_t
input_index
,
const
PrimitivePtr
&
op_prim
,
const
std
::
unordered_set
<
size_t
>
&
input_attrs
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
auto
input_names_value
=
op_prim
->
GetAttr
(
kAttrInputNames
);
if
(
input_names_value
==
nullptr
)
{
return
false
;
}
auto
input_names_vec
=
GetValue
<
std
::
vector
<
std
::
string
>>
(
input_names_value
);
if
(
input_index
>=
input_names_vec
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The input index: "
<<
input_index
<<
" is large than the input names vector size!"
;
}
if
(
input_attrs
.
find
(
input_index
)
!=
input_attrs
.
end
())
{
ValuePtr
value
=
parse
::
data_converter
::
PyDataToValue
(
input_object
);
MS_EXCEPTION_IF_NULL
(
value
);
auto
input_name
=
input_names_vec
[
input_index
];
op_prim
->
set_attr
(
input_name
,
value
);
return
true
;
}
return
false
;
}
void
PlantTensorTupleToVector
(
const
py
::
tuple
&
tuple_inputs
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
for
(
const
auto
&
input_object
:
tuple_inputs
)
{
if
(
!
py
::
isinstance
<
tensor
::
Tensor
>
(
input_object
))
{
MS_LOG
(
EXCEPTION
)
<<
"The input object is not a tensor!"
;
}
auto
tensor
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input_object
);
MS_EXCEPTION_IF_NULL
(
tensor
);
input_tensor
->
push_back
(
tensor
);
}
op_prim
->
set_attr
(
kAttrDynInputSizes
,
MakeValue
(
std
::
vector
<
int
>
{
SizeToInt
(
tuple_inputs
.
size
())}));
}
void
ConvertValueTupleToTensor
(
const
py
::
object
&
input_object
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
input_tensor
);
ValuePtr
input_value
=
parse
::
data_converter
::
PyDataToValue
(
input_object
);
MS_EXCEPTION_IF_NULL
(
input_value
);
if
(
!
input_value
->
isa
<
ValueTuple
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The input object is not a value tuple!"
;
}
auto
value_tuple
=
input_value
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_tuple
);
tensor
::
TensorPtr
tensor_ptr
=
opt
::
CreateTupleTensor
(
value_tuple
);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
input_tensor
->
push_back
(
tensor_ptr
);
}
void
ConvertPyObjectToTensor
(
const
py
::
object
&
input_object
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
tensor
::
TensorPtr
tensor_ptr
=
nullptr
;
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
input_object
))
{
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input_object
);
}
else
if
(
py
::
isinstance
<
py
::
float_
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
float_
>
(
input_object
),
kFloat32
);
}
else
if
(
py
::
isinstance
<
py
::
int_
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
int_
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
list
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
list
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
array
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
array
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
tuple
>
(
input_object
))
{
auto
tuple_inputs
=
py
::
cast
<
py
::
tuple
>
(
input_object
);
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
tuple_inputs
[
0
]))
{
PlantTensorTupleToVector
(
tuple_inputs
,
op_prim
,
input_tensor
);
}
else
{
ConvertValueTupleToTensor
(
input_object
,
input_tensor
);
}
return
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Run op inputs type is invalid!"
;
}
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
input_tensor
->
push_back
(
tensor_ptr
);
}
void
ConvertInputPyobject
(
const
OpRunInfo
&
op_run_info
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
,
std
::
vector
<
bool
>
*
tensors_mask
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensors
);
MS_EXCEPTION_IF_NULL
(
tensors_mask
);
if
(
op_run_info
.
op_inputs
.
size
()
!=
op_run_info
.
inputs_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Op input size "
<<
op_run_info
.
op_inputs
.
size
()
<<
" should be equal to op input mask size "
<<
op_run_info
.
inputs_mask
.
size
();
}
opt
::
ConstInputToAttrInfoRegister
reg
;
bool
reg_exist
=
opt
::
ConstInputToAttrInfoRegistry
::
Instance
().
GetRegisterByOpName
(
op_run_info
.
op_name
,
&
reg
);
size_t
input_num
=
op_run_info
.
op_inputs
.
size
();
MS_LOG
(
INFO
)
<<
"py input size: "
<<
input_num
;
for
(
size_t
index
=
0
;
index
<
input_num
;
++
index
)
{
// convert const input to attr
if
(
reg_exist
&&
RunOpConvertConstInputToAttr
(
op_run_info
.
op_inputs
[
index
],
index
,
op_prim
,
reg
.
GetConstInputAttrInfo
()))
{
continue
;
}
// convert const and tuple input to tensor
ConvertPyObjectToTensor
(
op_run_info
.
op_inputs
[
index
],
op_prim
,
input_tensors
);
// make tensors, weight : 1, data : 0
std
::
vector
<
bool
>
new_mask
(
input_tensors
->
size
()
-
tensors_mask
->
size
(),
py
::
cast
<
bool
>
(
op_run_info
.
inputs_mask
[
index
]));
tensors_mask
->
insert
(
tensors_mask
->
end
(),
new_mask
.
begin
(),
new_mask
.
end
());
}
}
ValueNodePtr
CreateNewValueNode
(
const
AnfNodePtr
&
anf
,
KernelGraph
*
graph
)
{
auto
value_node
=
anf
->
cast
<
ValueNodePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_node
);
...
...
@@ -747,26 +638,22 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
}
std
::
shared_ptr
<
KernelGraph
>
SessionBasic
::
ConstructSingleOpGraph
(
const
OpRunInfo
&
op_run_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
input_tensors
);
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
)
{
auto
graph
=
std
::
make_shared
<
KernelGraph
>
();
std
::
vector
<
AnfNodePtr
>
inputs
;
// set input[0]
PrimitivePtr
op_prim
=
op_run_info
.
py_primitive
;
if
(
op_prim
==
nullptr
)
{
op_prim
=
std
::
make_shared
<
Primitive
>
(
op_run_info
.
op_name
);
}
MS_EXCEPTION_IF_NULL
(
op_prim
);
inputs
.
push_back
(
std
::
make_shared
<
ValueNode
>
(
op_prim
));
// set input parameter
std
::
vector
<
bool
>
tensors_mask
;
ConvertInputPyobject
(
op_run_info
,
op_prim
,
input_tensors
,
&
tensors_mask
);
MS_LOG
(
INFO
)
<<
"Input tensor size: "
<<
input_tensors
->
size
();
if
(
input_tensors
->
size
()
!=
tensors_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Input tensors size "
<<
input_tensors
->
size
()
<<
" should be equal to tensors mask size "
MS_LOG
(
INFO
)
<<
"Input tensor size: "
<<
input_tensors
.
size
();
if
(
input_tensors
.
size
()
!=
tensors_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Input tensors size "
<<
input_tensors
.
size
()
<<
" should be equal to tensors mask size "
<<
tensors_mask
.
size
();
}
for
(
size_t
i
=
0
;
i
<
input_tensors
->
size
();
++
i
)
{
auto
parameter
=
ConstructRunOpParameter
(
graph
,
input_tensors
->
at
(
i
),
tensors_mask
[
i
]);
for
(
size_t
i
=
0
;
i
<
input_tensors
.
size
();
++
i
)
{
auto
parameter
=
ConstructRunOpParameter
(
graph
,
input_tensors
.
at
(
i
),
tensors_mask
[
i
]);
inputs
.
push_back
(
parameter
);
graph
->
MutableInputs
()
->
push_back
(
parameter
);
}
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
90dfbab3
...
...
@@ -61,7 +61,8 @@ class SessionBasic {
virtual
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
=
0
;
virtual
void
BuildOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{}
virtual
void
BuildOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
)
{}
virtual
py
::
tuple
RunOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
return
py
::
tuple
();
...
...
@@ -99,7 +100,8 @@ class SessionBasic {
CNodePtr
ConstructOutput
(
const
AnfNodePtrList
&
outputs
,
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
);
// create a single run op graph
std
::
shared_ptr
<
KernelGraph
>
ConstructSingleOpGraph
(
const
OpRunInfo
&
op_run_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
);
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
,
const
std
::
vector
<
bool
>
&
tensors_mask
);
// trans BaseRef list to py::tuple
BaseRef
TransformBaseRefListToTuple
(
const
BaseRef
&
base_ref
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录