Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
ffe8b5d3
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
ffe8b5d3
编写于
4月 13, 2020
作者:
L
lvliang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
pynative-add-op-supported
上级
936bae7b
变更
14
显示空白变更内容
内联
并排
Showing
14 changed file
with
240 addition
and
210 deletion
+240
-210
mindspore/ccsrc/pre_activate/common/helper.cc
mindspore/ccsrc/pre_activate/common/helper.cc
+57
-0
mindspore/ccsrc/pre_activate/common/helper.h
mindspore/ccsrc/pre_activate/common/helper.h
+5
-0
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
...e/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
+34
-0
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.h
...re/ccsrc/pre_activate/pass/const_input_to_attr_registry.h
+1
-1
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc
...re/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc
+0
-32
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h
...ore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h
+1
-4
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc
.../pre_activate/pass/convert_const_input_to_tensor_input.cc
+1
-57
mindspore/ccsrc/pynative/pynative_execute.cc
mindspore/ccsrc/pynative/pynative_execute.cc
+3
-2
mindspore/ccsrc/session/ascend_session.cc
mindspore/ccsrc/session/ascend_session.cc
+6
-6
mindspore/ccsrc/session/ascend_session.h
mindspore/ccsrc/session/ascend_session.h
+4
-2
mindspore/ccsrc/session/gpu_session.cc
mindspore/ccsrc/session/gpu_session.cc
+6
-6
mindspore/ccsrc/session/gpu_session.h
mindspore/ccsrc/session/gpu_session.h
+4
-2
mindspore/ccsrc/session/session_basic.cc
mindspore/ccsrc/session/session_basic.cc
+112
-92
mindspore/ccsrc/session/session_basic.h
mindspore/ccsrc/session/session_basic.h
+6
-6
未找到文件。
mindspore/ccsrc/pre_activate/common/helper.cc
浏览文件 @
ffe8b5d3
...
...
@@ -28,6 +28,7 @@
namespace
mindspore
{
namespace
opt
{
constexpr
size_t
kType32Len
=
4
;
std
::
vector
<
int
>
Convert2Int
(
const
std
::
vector
<
size_t
>
&
v
)
{
std
::
vector
<
int
>
result
;
(
void
)
std
::
transform
(
v
.
begin
(),
v
.
end
(),
std
::
back_inserter
(
result
),
SizeToInt
);
...
...
@@ -264,6 +265,62 @@ void CreateMultipleOutputsOfAnfNode(const FuncGraphPtr &func_graph, const AnfNod
}
}
template
<
typename
T
>
tensor
::
TensorPtr
CreateTensorWithValueTuple
(
const
ValueTuplePtr
&
value_tuple_ptr
,
const
TypePtr
&
type_ptr
,
size_t
data_length
)
{
MS_EXCEPTION_IF_NULL
(
value_tuple_ptr
);
MS_EXCEPTION_IF_NULL
(
type_ptr
);
std
::
vector
<
T
>
values
;
for
(
const
auto
&
v
:
value_tuple_ptr
->
value
())
{
MS_EXCEPTION_IF_NULL
(
v
);
if
(
v
->
isa
<
Scalar
>
())
{
ScalarPtr
scalar
=
v
->
cast
<
ScalarPtr
>
();
values
.
push_back
(
GetValue
<
T
>
(
scalar
));
}
else
{
MS_LOG
(
WARNING
)
<<
"The value "
<<
v
<<
"of tuple is not a scalar"
;
return
nullptr
;
}
}
std
::
vector
<
int
>
tensor_shape
=
{
SizeToInt
(
values
.
size
())};
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_ptr
->
type_id
(),
tensor_shape
);
MS_EXCEPTION_IF_NULL
(
tensor
);
tensor
::
DeviceInfo
device_info
{
kOpFormat_DEFAULT
,
type_ptr
};
tensor
->
set_device_info
(
device_info
);
auto
data_ptr
=
tensor
->
data_c
(
true
);
MS_EXCEPTION_IF_NULL
(
data_ptr
);
auto
elem_num
=
values
.
size
()
*
data_length
;
auto
ret_code
=
memcpy_s
(
data_ptr
,
static_cast
<
size_t
>
(
tensor
->
data
().
nbytes
()),
values
.
data
(),
elem_num
);
if
(
ret_code
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to copy data into Tensor."
;
}
return
tensor
;
}
tensor
::
TensorPtr
CreateTupleTensor
(
const
ValueTuplePtr
&
value_tuple
)
{
MS_EXCEPTION_IF_NULL
(
value_tuple
);
tensor
::
TensorPtr
tensor
=
nullptr
;
ValuePtr
v
=
*
(
value_tuple
->
value
().
begin
());
MS_EXCEPTION_IF_NULL
(
v
);
// Currently we only deal with the scalar tuple
if
(
!
v
->
isa
<
Scalar
>
())
{
MS_LOG
(
WARNING
)
<<
"The value "
<<
v
<<
"of tuple is not a scalar"
;
return
nullptr
;
}
ScalarPtr
scalar
=
v
->
cast
<
ScalarPtr
>
();
MS_EXCEPTION_IF_NULL
(
scalar
);
if
(
scalar
->
isa
<
IntergerImm
>
())
{
tensor
=
CreateTensorWithValueTuple
<
int
>
(
value_tuple
,
kInt32
,
kType32Len
);
}
else
if
(
scalar
->
isa
<
FloatImm
>
())
{
tensor
=
CreateTensorWithValueTuple
<
float
>
(
value_tuple
,
kFloat32
,
kType32Len
);
}
else
{
auto
type
=
scalar
->
type
();
auto
type_str
=
(
type
==
nullptr
)
?
"nullptr"
:
type
->
ToString
();
MS_LOG
(
ERROR
)
<<
"Invalid scalar type: "
<<
type_str
;
return
nullptr
;
}
return
tensor
;
}
bool
IsNopNode
(
const
AnfNodePtr
&
node
)
{
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
...
...
mindspore/ccsrc/pre_activate/common/helper.h
浏览文件 @
ffe8b5d3
...
...
@@ -135,6 +135,11 @@ void CreateOutputsOfFusedBn3(const FuncGraphPtr &graph, const AnfNodePtr &data_i
void
CreateMultipleOutputsOfAnfNode
(
const
FuncGraphPtr
&
kernel_graph
,
const
AnfNodePtr
&
anf_node_ptr
,
size_t
output_num
,
std
::
vector
<
AnfNodePtr
>
*
outputs
);
tensor
::
TensorPtr
CreateTensorWithValueTuple
(
const
ValueTuplePtr
&
value_tuple_ptr
,
const
TypePtr
&
type_ptr
,
size_t
data_length
);
tensor
::
TensorPtr
CreateTupleTensor
(
const
ValueTuplePtr
&
value_tuple
);
bool
IsNopNode
(
const
AnfNodePtr
&
node
);
void
HideNopNode
(
session
::
KernelGraph
*
const
graph
);
...
...
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.cc
浏览文件 @
ffe8b5d3
...
...
@@ -17,10 +17,44 @@
#include <utility>
#include "utils/utils.h"
#include "utils/log_adapter.h"
#include "operator/ops.h"
namespace
mindspore
{
namespace
opt
{
ConstInputToAttrInfoRegistry
::
ConstInputToAttrInfoRegistry
()
{
Register
(
prim
::
kPrimCast
->
name
(),
{
1
});
Register
(
prim
::
kPrimConv2DBackpropInput
->
name
(),
{
2
});
Register
(
prim
::
kPrimConv2DBackpropFilter
->
name
(),
{
2
});
Register
(
prim
::
kPrimReshape
->
name
(),
{
1
});
Register
(
prim
::
kPrimReduceMax
->
name
(),
{
1
});
Register
(
prim
::
kPrimReduceMin
->
name
(),
{
1
});
Register
(
prim
::
kPrimReduceSum
->
name
(),
{
1
});
Register
(
prim
::
kPrimReduceMean
->
name
(),
{
1
});
Register
(
prim
::
kPrimGatherV2
->
name
(),
{
2
});
Register
(
prim
::
kPrimTranspose
->
name
(),
{
1
});
Register
(
prim
::
kPrimUnsortedSegmentSum
->
name
(),
{
2
});
Register
(
prim
::
kPrimOneHot
->
name
(),
{
1
});
Register
(
kUnsortedSegmentProdOpName
,
{
2
});
Register
(
kUnsortedSegmentMinOpName
,
{
2
});
Register
(
kSimpleMeanGradOpName
,
{
1
});
Register
(
kMeanGradOpName
,
{
1
});
Register
(
kSliceOpName
,
{
1
,
2
});
Register
(
kSliceGradOpName
,
{
2
,
3
});
Register
(
kTileOpName
,
{
1
});
Register
(
kScatterNdOpName
,
{
2
});
Register
(
kStridedSliceAssignOpName
,
{
1
,
2
,
3
});
Register
(
kStridedSliceOpName
,
{
1
,
2
,
3
});
Register
(
kStridedSliceGradOpName
,
{
1
,
2
,
3
,
4
});
Register
(
kFlattenGradOpName
,
{
1
});
Register
(
kExpandDimsOpName
,
{
1
});
Register
(
kSplitOpName
,
{
0
});
Register
(
kTopKOpName
,
{
1
});
Register
(
kSparseApplyAdagradOpName
,
{
2
});
Register
(
kResizeNearestNeighborGrad
,
{
1
});
}
ConstInputToAttrInfoRegistry
&
ConstInputToAttrInfoRegistry
::
Instance
()
{
static
ConstInputToAttrInfoRegistry
instance
;
return
instance
;
...
...
mindspore/ccsrc/pre_activate/pass/const_input_to_attr_registry.h
浏览文件 @
ffe8b5d3
...
...
@@ -54,7 +54,7 @@ class ConstInputToAttrInfoRegistry {
bool
GetRegisterByOpName
(
const
std
::
string
&
op_name
,
ConstInputToAttrInfoRegister
*
reg
)
const
;
private:
ConstInputToAttrInfoRegistry
()
=
default
;
ConstInputToAttrInfoRegistry
();
~
ConstInputToAttrInfoRegistry
()
=
default
;
DISABLE_COPY_AND_ASSIGN
(
ConstInputToAttrInfoRegistry
)
std
::
unordered_map
<
std
::
string
,
ConstInputToAttrInfoRegister
>
op_input_to_attr_map_
;
...
...
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.cc
浏览文件 @
ffe8b5d3
...
...
@@ -87,37 +87,5 @@ const AnfNodePtr ConvertConstInputToAttr::Process(const FuncGraphPtr &, const An
ConstInputToAttr
(
cnode
,
reg
.
GetConstInputAttrInfo
());
return
cnode
;
}
void
ConvertConstInputToAttr
::
Init
()
{
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimCast
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimConv2DBackpropInput
->
name
(),
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimConv2DBackpropFilter
->
name
(),
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimReshape
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimReduceMax
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimReduceMin
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimReduceSum
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimReduceMean
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimGatherV2
->
name
(),
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimTranspose
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimUnsortedSegmentSum
->
name
(),
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
prim
::
kPrimOneHot
->
name
(),
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kUnsortedSegmentProdOpName
,
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kUnsortedSegmentMinOpName
,
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kSimpleMeanGradOpName
,
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kMeanGradOpName
,
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kSliceOpName
,
{
1
,
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kSliceGradOpName
,
{
2
,
3
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kTileOpName
,
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kScatterNdOpName
,
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kStridedSliceAssignOpName
,
{
1
,
2
,
3
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kStridedSliceOpName
,
{
1
,
2
,
3
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kStridedSliceGradOpName
,
{
1
,
2
,
3
,
4
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kFlattenGradOpName
,
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kExpandDimsOpName
,
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kSplitOpName
,
{
0
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kTopKOpName
,
{
1
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kSparseApplyAdagradOpName
,
{
2
});
ConstInputToAttrInfoRegistry
::
Instance
().
Register
(
kResizeNearestNeighborGrad
,
{
1
});
}
}
// namespace opt
}
// namespace mindspore
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_attr.h
浏览文件 @
ffe8b5d3
...
...
@@ -27,14 +27,11 @@ namespace opt {
class
ConvertConstInputToAttr
:
public
PatternProcessPass
{
public:
explicit
ConvertConstInputToAttr
(
bool
multigraph
=
true
)
:
PatternProcessPass
(
"convert_const_input_to_attr"
,
multigraph
)
{
Init
();
}
:
PatternProcessPass
(
"convert_const_input_to_attr"
,
multigraph
)
{}
~
ConvertConstInputToAttr
()
override
=
default
;
const
AnfNodePtr
Process
(
const
FuncGraphPtr
&
,
const
AnfNodePtr
&
,
const
EquivPtr
&
)
const
override
;
private:
void
Init
();
std
::
unordered_map
<
std
::
string
,
std
::
unordered_set
<
size_t
>>
op_input_attr_map_
;
};
}
// namespace opt
...
...
mindspore/ccsrc/pre_activate/pass/convert_const_input_to_tensor_input.cc
浏览文件 @
ffe8b5d3
...
...
@@ -19,69 +19,13 @@
#include <memory>
#include "utils/graph_utils.h"
#include "pre_activate/common/helper.h"
#include "session/anf_runtime_algorithm.h"
#include "session/kernel_graph.h"
namespace
mindspore
{
namespace
opt
{
namespace
{
constexpr
size_t
kType32Len
=
4
;
template
<
typename
T
>
tensor
::
TensorPtr
CreateTensorWithValueTuple
(
const
ValueTuplePtr
&
value_tuple_ptr
,
const
TypePtr
&
type_ptr
,
size_t
data_length
)
{
MS_EXCEPTION_IF_NULL
(
value_tuple_ptr
);
MS_EXCEPTION_IF_NULL
(
type_ptr
);
std
::
vector
<
T
>
values
;
for
(
const
auto
&
v
:
value_tuple_ptr
->
value
())
{
MS_EXCEPTION_IF_NULL
(
v
);
if
(
v
->
isa
<
Scalar
>
())
{
ScalarPtr
scalar
=
v
->
cast
<
ScalarPtr
>
();
values
.
push_back
(
GetValue
<
T
>
(
scalar
));
}
else
{
MS_LOG
(
WARNING
)
<<
"The value "
<<
v
<<
"of tuple is not a scalar"
;
return
nullptr
;
}
}
std
::
vector
<
int
>
tensor_shape
=
{
SizeToInt
(
values
.
size
())};
tensor
::
TensorPtr
tensor
=
std
::
make_shared
<
tensor
::
Tensor
>
(
type_ptr
->
type_id
(),
tensor_shape
);
MS_EXCEPTION_IF_NULL
(
tensor
);
tensor
::
DeviceInfo
device_info
{
kOpFormat_DEFAULT
,
type_ptr
};
tensor
->
set_device_info
(
device_info
);
auto
data_ptr
=
tensor
->
data_c
(
true
);
MS_EXCEPTION_IF_NULL
(
data_ptr
);
auto
elem_num
=
values
.
size
()
*
data_length
;
auto
ret_code
=
memcpy_s
(
data_ptr
,
static_cast
<
size_t
>
(
tensor
->
data
().
nbytes
()),
values
.
data
(),
elem_num
);
if
(
ret_code
!=
0
)
{
MS_LOG
(
EXCEPTION
)
<<
"Failed to copy data into Tensor."
;
}
return
tensor
;
}
tensor
::
TensorPtr
CreateTupleTensor
(
const
ValueTuplePtr
&
value_tuple
)
{
MS_EXCEPTION_IF_NULL
(
value_tuple
);
tensor
::
TensorPtr
tensor
=
nullptr
;
ValuePtr
v
=
*
(
value_tuple
->
value
().
begin
());
MS_EXCEPTION_IF_NULL
(
v
);
// Currently we only deal with the scalar tuple
if
(
!
v
->
isa
<
Scalar
>
())
{
MS_LOG
(
WARNING
)
<<
"The value "
<<
v
<<
"of tuple is not a scalar"
;
return
nullptr
;
}
ScalarPtr
scalar
=
v
->
cast
<
ScalarPtr
>
();
MS_EXCEPTION_IF_NULL
(
scalar
);
if
(
scalar
->
isa
<
IntergerImm
>
())
{
tensor
=
CreateTensorWithValueTuple
<
int
>
(
value_tuple
,
kInt32
,
kType32Len
);
}
else
if
(
scalar
->
isa
<
FloatImm
>
())
{
tensor
=
CreateTensorWithValueTuple
<
float
>
(
value_tuple
,
kFloat32
,
kType32Len
);
}
else
{
auto
type
=
scalar
->
type
();
auto
type_str
=
(
type
==
nullptr
)
?
"nullptr"
:
type
->
ToString
();
MS_LOG
(
ERROR
)
<<
"Invalid scalar type: "
<<
type_str
;
return
nullptr
;
}
return
tensor
;
}
AnfNodePtr
CreateTensorInput
(
const
KernelGraphPtr
&
kernel_graph
,
const
AnfNodePtr
&
input_node
)
{
MS_EXCEPTION_IF_NULL
(
input_node
);
auto
value_node
=
input_node
->
cast
<
ValueNodePtr
>
();
...
...
mindspore/ccsrc/pynative/pynative_execute.cc
浏览文件 @
ffe8b5d3
...
...
@@ -158,8 +158,9 @@ py::object RunOpInMs(const OpExecInfoPtr& op_exec_info, PynativeStatusCode* stat
session
->
Init
(
ms_context
->
device_id
());
std
::
string
graph_info
=
GetSingleOpGraphInfo
(
op_exec_info
);
session
->
BuildOp
(
*
op_exec_info
,
graph_info
);
py
::
tuple
result
=
session
->
RunOp
(
*
op_exec_info
,
graph_info
);
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
;
session
->
BuildOp
(
*
op_exec_info
,
graph_info
,
&
input_tensors
);
py
::
tuple
result
=
session
->
RunOp
(
*
op_exec_info
,
graph_info
,
input_tensors
);
ms_context
->
set_enable_pynative_infer
(
false
);
*
status
=
PYNATIVE_SUCCESS
;
return
result
;
...
...
mindspore/ccsrc/session/ascend_session.cc
浏览文件 @
ffe8b5d3
...
...
@@ -204,10 +204,12 @@ void AscendSession::RunOpExecTask(const std::shared_ptr<KernelGraph> &kernel_gra
MS_LOG
(
INFO
)
<<
"Finish!"
;
}
void
AscendSession
::
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
{
void
AscendSession
::
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
input_tensors
);
MS_LOG
(
INFO
)
<<
"Build op "
<<
op_run_info
.
op_name
<<
" start !"
;
// construct graph include one op
auto
graph
=
ConstructSingleOpGraph
(
op_run_info
);
auto
graph
=
ConstructSingleOpGraph
(
op_run_info
,
input_tensors
);
MS_EXCEPTION_IF_NULL
(
graph
);
opt
::
RunOpAscendBackendIRFusionOptimization
(
graph
);
// kernel select
...
...
@@ -222,14 +224,12 @@ void AscendSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph
run_op_graphs_
[
graph_info
]
=
graph
;
}
py
::
tuple
AscendSession
::
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
{
py
::
tuple
AscendSession
::
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
auto
graph
=
run_op_graphs_
[
graph_info
];
MS_EXCEPTION_IF_NULL
(
graph
);
MS_LOG
(
INFO
)
<<
"Run op "
<<
op_run_info
.
op_name
<<
" start!"
;
// malloc mem
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
=
{};
std
::
vector
<
bool
>
tensors_mask
=
{};
ToTensorPtr
(
op_run_info
,
&
input_tensors
,
&
tensors_mask
);
RunOpMemoryAlloc
(
input_tensors
,
graph
.
get
());
// load input data to device
LoadInputData
(
graph
,
input_tensors
);
...
...
mindspore/ccsrc/session/ascend_session.h
浏览文件 @
ffe8b5d3
...
...
@@ -41,8 +41,10 @@ class AscendSession : public SessionBasic {
GraphId
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
override
;
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
override
;
void
BuildGraph
(
GraphId
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
override
;
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
override
;
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
override
;
// set parameters of final graph
GraphId
SetFinalGraphInput
(
const
std
::
vector
<
AnfNodePtr
>
&
args
)
override
;
...
...
mindspore/ccsrc/session/gpu_session.cc
浏览文件 @
ffe8b5d3
...
...
@@ -132,9 +132,11 @@ void GPUSession::RunGraph(const GraphId &graph_id, const std::vector<tensor::Ten
}
}
void
GPUSession
::
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
{
void
GPUSession
::
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
// Prepare the graph
auto
kernel_graph
=
ConstructSingleOpGraph
(
op_run_info
);
MS_EXCEPTION_IF_NULL
(
input_tensors
);
auto
kernel_graph
=
ConstructSingleOpGraph
(
op_run_info
,
input_tensors
);
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
SelectKernel
(
kernel_graph
);
StartKernelRT
();
...
...
@@ -142,12 +144,10 @@ void GPUSession::BuildOp(const OpRunInfo &op_run_info, const GraphInfo &graph_in
run_op_graphs_
[
graph_info
]
=
kernel_graph
;
}
py
::
tuple
GPUSession
::
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
{
py
::
tuple
GPUSession
::
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
auto
kernel_graph
=
run_op_graphs_
[
graph_info
];
MS_EXCEPTION_IF_NULL
(
kernel_graph
);
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
=
{};
std
::
vector
<
bool
>
tensors_mask
=
{};
ToTensorPtr
(
op_run_info
,
&
input_tensors
,
&
tensors_mask
);
RunOpAllocateMemory
(
input_tensors
,
kernel_graph
.
get
());
// Execute the computation
LoadInputData
(
kernel_graph
,
input_tensors
);
...
...
mindspore/ccsrc/session/gpu_session.h
浏览文件 @
ffe8b5d3
...
...
@@ -39,8 +39,10 @@ class GPUSession : public SessionBasic {
GraphId
CompileGraph
(
const
AnfNodePtrList
&
lst
,
const
AnfNodePtrList
&
outputs
)
override
;
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
override
;
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
)
override
;
void
BuildOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
override
;
py
::
tuple
RunOp
(
const
OpRunInfo
&
op_run_info
,
const
GraphInfo
&
graph_info
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
override
;
private:
void
SelectKernel
(
const
std
::
shared_ptr
<
KernelGraph
>
&
kernel_graph
)
const
;
...
...
mindspore/ccsrc/session/session_basic.cc
浏览文件 @
ffe8b5d3
...
...
@@ -17,6 +17,7 @@
#include <utility>
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include "pipeline/parse/data_converter.h"
#include "ir/manager.h"
#include "operator/ops.h"
...
...
@@ -26,6 +27,7 @@
#include "session/anf_runtime_algorithm.h"
#include "kernel/oplib/oplib.h"
#include "pre_activate/common/common_backend_optimization.h"
#include "pre_activate/pass/const_input_to_attr_registry.h"
#include "pre_activate/common/helper.h"
#include "common/utils.h"
#include "ir/dtype.h"
...
...
@@ -178,56 +180,113 @@ BaseRef CreatTupleForOutput(const AnfNodePtr &anf, const KernelGraph &graph,
return
ret
;
}
std
::
string
FindOpInputParameterType
(
const
std
::
string
&
op_name
,
kernel
::
OpImplyType
implyType
,
size_t
index
)
{
std
::
string
para_type
;
auto
op_info
=
kernel
::
OpLib
::
FindOp
(
op_name
,
implyType
);
if
(
op_info
==
nullptr
)
{
return
para_type
;
bool
RunOpConvertConstInputToAttr
(
const
py
::
object
&
input_object
,
size_t
input_index
,
const
PrimitivePtr
&
op_prim
,
const
std
::
unordered_set
<
size_t
>
&
input_attrs
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
auto
input_names_value
=
op_prim
->
GetAttr
(
kAttrInputNames
);
if
(
input_names_value
==
nullptr
)
{
return
false
;
}
auto
input_names_vec
=
GetValue
<
std
::
vector
<
std
::
string
>>
(
input_names_value
);
if
(
input_index
>=
input_names_vec
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"The input index: "
<<
input_index
<<
" is large than the input names vector size!"
;
}
auto
op_inputs_info_vec
=
op_info
->
inputs_ptr
();
if
(
index
>=
op_inputs_info_vec
.
size
())
{
return
para_type
;
if
(
input_attrs
.
find
(
input_index
)
!=
input_attrs
.
end
())
{
ValuePtr
value
=
parse
::
data_converter
::
PyDataToValue
(
input_object
);
MS_EXCEPTION_IF_NULL
(
value
);
auto
input_name
=
input_names_vec
[
input_index
];
op_prim
->
set_attr
(
input_name
,
value
);
return
true
;
}
auto
op_io_info
=
op_inputs_info_vec
[
index
];
MS_EXCEPTION_IF_NULL
(
op_io_info
);
para_type
=
op_io_info
->
param_type
();
return
para_type
;
return
false
;
}
void
RunOpConvertConstInputToAttr
(
const
OpRunInfo
&
op_run_info
,
const
std
::
shared_ptr
<
CNode
>
&
cnode
)
{
MS_EXCEPTION_IF_NULL
(
cnode
);
auto
op_inputs
=
op_run_info
.
op_inputs
;
// get input names vector from attrs
auto
primitive
=
AnfAlgo
::
GetCNodePrimitive
(
cnode
);
MS_EXCEPTION_IF_NULL
(
primitive
);
auto
input_names_value
=
primitive
->
GetAttr
(
kAttrInputNames
);
if
(
input_names_value
==
nullptr
)
{
void
PlantTensorTupleToVector
(
const
py
::
tuple
&
tuple_inputs
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
for
(
const
auto
&
input_object
:
tuple_inputs
)
{
if
(
!
py
::
isinstance
<
tensor
::
Tensor
>
(
input_object
))
{
MS_LOG
(
EXCEPTION
)
<<
"The input object is not a tensor!"
;
}
auto
tensor
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input_object
);
MS_EXCEPTION_IF_NULL
(
tensor
);
input_tensor
->
push_back
(
tensor
);
}
op_prim
->
set_attr
(
kAttrDynInputSizes
,
MakeValue
(
std
::
vector
<
int
>
{
SizeToInt
(
tuple_inputs
.
size
())}));
}
void
ConvertValueTupleToTensor
(
const
py
::
object
&
input_object
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
input_tensor
);
ValuePtr
input_value
=
parse
::
data_converter
::
PyDataToValue
(
input_object
);
MS_EXCEPTION_IF_NULL
(
input_value
);
if
(
!
input_value
->
isa
<
ValueTuple
>
())
{
MS_LOG
(
EXCEPTION
)
<<
"The input object is not a value tuple!"
;
}
auto
value_tuple
=
input_value
->
cast
<
ValueTuplePtr
>
();
MS_EXCEPTION_IF_NULL
(
value_tuple
);
tensor
::
TensorPtr
tensor_ptr
=
nullptr
;
tensor_ptr
=
opt
::
CreateTupleTensor
(
value_tuple
);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
input_tensor
->
push_back
(
tensor_ptr
);
}
void
ConvertPyObjectToTensor
(
const
py
::
object
&
input_object
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensor
);
tensor
::
TensorPtr
tensor_ptr
=
nullptr
;
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
input_object
))
{
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
input_object
);
}
else
if
(
py
::
isinstance
<
py
::
float_
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
float_
>
(
input_object
),
kFloat32
);
}
else
if
(
py
::
isinstance
<
py
::
int_
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
int_
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
list
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
list
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
array
>
(
input_object
))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
array
>
(
input_object
),
nullptr
);
}
else
if
(
py
::
isinstance
<
py
::
tuple
>
(
input_object
))
{
auto
tuple_inputs
=
py
::
cast
<
py
::
tuple
>
(
input_object
);
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
tuple_inputs
[
0
]))
{
PlantTensorTupleToVector
(
tuple_inputs
,
op_prim
,
input_tensor
);
}
else
{
ConvertValueTupleToTensor
(
input_object
,
input_tensor
);
}
return
;
}
else
{
MS_LOG
(
EXCEPTION
)
<<
"Run op inputs type is invalid!"
;
}
auto
input_names_vec
=
GetValue
<
std
::
vector
<
std
::
string
>>
(
input_names_value
);
// convert const input to attr
size_t
input_num
=
op_inputs
.
size
();
if
(
input_num
!=
input_names_vec
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"input name number "
<<
input_names_vec
.
size
()
<<
"is not equal to input value number "
<<
input_num
;
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
input_tensor
->
push_back
(
tensor_ptr
);
}
void
ConvertInputPyobject
(
const
OpRunInfo
&
op_run_info
,
const
PrimitivePtr
&
op_prim
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
,
std
::
vector
<
bool
>
*
tensors_mask
)
{
MS_EXCEPTION_IF_NULL
(
op_prim
);
MS_EXCEPTION_IF_NULL
(
input_tensors
);
MS_EXCEPTION_IF_NULL
(
tensors_mask
);
if
(
op_run_info
.
op_inputs
.
size
()
!=
op_run_info
.
inputs_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Op input size "
<<
op_run_info
.
op_inputs
.
size
()
<<
" should be equal to op input mask size "
<<
op_run_info
.
inputs_mask
.
size
();
}
opt
::
ConstInputToAttrInfoRegister
reg
;
bool
reg_exist
=
opt
::
ConstInputToAttrInfoRegistry
::
Instance
().
GetRegisterByOpName
(
op_run_info
.
op_name
,
&
reg
);
size_t
input_num
=
op_run_info
.
op_inputs
.
size
();
MS_LOG
(
INFO
)
<<
"py input size: "
<<
input_num
;
for
(
size_t
index
=
0
;
index
<
input_num
;
++
index
)
{
// skip tensor
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
op_inputs
[
index
]))
{
continue
;
}
// convert to attr
auto
para_type
=
FindOpInputParameterType
(
op_run_info
.
op_name
,
kernel
::
OpImplyType
::
kTBE
,
index
);
if
(
!
para_type
.
empty
()
&&
para_type
==
kAttrDynInput
)
{
auto
tuple_inputs
=
py
::
cast
<
py
::
tuple
>
(
op_inputs
[
index
]);
primitive
->
set_attr
(
kAttrDynInputSizes
,
MakeValue
(
std
::
vector
<
int
>
{
SizeToInt
(
tuple_inputs
.
size
())}));
// convert const input to attr
if
(
reg_exist
&&
RunOpConvertConstInputToAttr
(
op_run_info
.
op_inputs
[
index
],
index
,
op_prim
,
reg
.
GetConstInputAttrInfo
()))
{
continue
;
}
ValuePtr
value
=
parse
::
data_converter
::
PyDataToValue
(
op_inputs
[
index
]);
MS_EXCEPTION_IF_NULL
(
value
);
auto
input_name
=
input_names_vec
[
index
];
// set the input node as attr of the cnode, key is name of input node,value is input node's value
primitive
->
set_attr
(
input_name
,
value
);
// convert const and tuple input to tensor
ConvertPyObjectToTensor
(
op_run_info
.
op_inputs
[
index
],
op_prim
,
input_tensors
);
// make tensors, weight : 1, data : 0
std
::
vector
<
bool
>
new_mask
(
input_tensors
->
size
()
-
tensors_mask
->
size
(),
py
::
cast
<
bool
>
(
op_run_info
.
inputs_mask
[
index
]));
tensors_mask
->
insert
(
tensors_mask
->
end
(),
new_mask
.
begin
(),
new_mask
.
end
());
}
}
...
...
@@ -638,40 +697,6 @@ void SessionBasic::Summary(KernelGraph *graph) {
summary_callback_
(
0
,
params_list
);
}
void
SessionBasic
::
ToTensorPtr
(
const
OpRunInfo
&
op_run_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
inputs
,
std
::
vector
<
bool
>
*
tensor_mask
)
{
MS_EXCEPTION_IF_NULL
(
inputs
);
MS_EXCEPTION_IF_NULL
(
tensor_mask
);
if
(
op_run_info
.
op_inputs
.
size
()
!=
op_run_info
.
inputs_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Op input size "
<<
op_run_info
.
op_inputs
.
size
()
<<
" should be equal to op input mask size "
<<
op_run_info
.
inputs_mask
.
size
();
}
size_t
input_num
=
op_run_info
.
op_inputs
.
size
();
// get tensors from op_inputs
for
(
size_t
i
=
0
;
i
<
input_num
;
++
i
)
{
tensor
::
TensorPtr
tensor_ptr
=
nullptr
;
auto
param_type
=
FindOpInputParameterType
(
op_run_info
.
op_name
,
kernel
::
OpImplyType
::
kTBE
,
i
);
if
(
py
::
isinstance
<
tensor
::
Tensor
>
(
op_run_info
.
op_inputs
[
i
]))
{
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
op_run_info
.
op_inputs
[
i
]);
}
else
if
(
!
param_type
.
empty
()
&&
param_type
==
kAttrDynInput
)
{
auto
tuple_inputs
=
py
::
cast
<
py
::
tuple
>
(
op_run_info
.
op_inputs
[
i
]);
for
(
auto
&&
tuple_input
:
tuple_inputs
)
{
tensor_ptr
=
py
::
cast
<
tensor
::
TensorPtr
>
(
tuple_input
);
MS_EXCEPTION_IF_NULL
(
tensor_ptr
);
inputs
->
push_back
(
tensor_ptr
);
tensor_mask
->
push_back
(
py
::
cast
<
bool
>
(
op_run_info
.
inputs_mask
[
i
]));
}
continue
;
}
else
if
(
op_run_info
.
op_name
==
kApplyMomentumOpName
&&
py
::
isinstance
<
py
::
float_
>
(
op_run_info
.
op_inputs
[
i
]))
{
tensor_ptr
=
std
::
make_shared
<
tensor
::
Tensor
>
(
py
::
cast
<
py
::
float_
>
(
op_run_info
.
op_inputs
[
i
]),
kFloat32
);
}
if
(
tensor_ptr
!=
nullptr
)
{
inputs
->
push_back
(
tensor_ptr
);
tensor_mask
->
push_back
(
py
::
cast
<
bool
>
(
op_run_info
.
inputs_mask
[
i
]));
}
}
}
CNodePtr
SessionBasic
::
ConstructOutput
(
const
AnfNodePtrList
&
outputs
,
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
)
{
MS_EXCEPTION_IF_NULL
(
graph
);
std
::
vector
<
AnfNodePtr
>
output_args
;
...
...
@@ -724,30 +749,27 @@ void SessionBasic::CreateOutputNode(const CNodePtr &cnode, const std::shared_ptr
MS_LOG
(
INFO
)
<<
"Finish!"
;
}
std
::
shared_ptr
<
KernelGraph
>
SessionBasic
::
ConstructSingleOpGraph
(
const
OpRunInfo
&
op_run_info
)
{
std
::
shared_ptr
<
KernelGraph
>
SessionBasic
::
ConstructSingleOpGraph
(
const
OpRunInfo
&
op_run_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{
MS_EXCEPTION_IF_NULL
(
input_tensors
);
auto
graph
=
std
::
make_shared
<
KernelGraph
>
();
std
::
vector
<
AnfNodePtr
>
inputs
;
if
(
op_run_info
.
op_inputs
.
size
()
!=
op_run_info
.
inputs_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"op_run_info inputs.size"
<<
op_run_info
.
op_inputs
.
size
()
<<
" should be equal to parameter_mask.size "
<<
op_run_info
.
inputs_mask
.
size
();
}
// set input[0]
if
(
op_run_info
.
py_primitive
==
nullptr
)
{
inputs
.
push_back
(
std
::
make_shared
<
ValueNode
>
(
std
::
make_shared
<
Primitive
>
(
op_run_info
.
op_name
)));
}
else
{
inputs
.
push_back
(
std
::
make_shared
<
ValueNode
>
(
op_run_info
.
py_primitive
));
PrimitivePtr
op_prim
=
op_run_info
.
py_primitive
;
if
(
op_prim
==
nullptr
)
{
op_prim
=
std
::
make_shared
<
Primitive
>
(
op_run_info
.
op_name
);
}
inputs
.
push_back
(
std
::
make_shared
<
ValueNode
>
(
op_prim
));
// set input parameter
std
::
vector
<
tensor
::
TensorPtr
>
input_tensors
;
std
::
vector
<
bool
>
tensors_mask
;
ToTensorPtr
(
op_run_info
,
&
input_tensors
,
&
tensors_mask
);
MS_LOG
(
INFO
)
<<
"Input tensor size
"
<<
input_tensors
.
size
();
if
(
input_tensors
.
size
()
!=
tensors_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Input tensors size "
<<
input_tensors
.
size
()
<<
" should be equal to tensors mask size "
ConvertInputPyobject
(
op_run_info
,
op_prim
,
input_tensors
,
&
tensors_mask
);
MS_LOG
(
INFO
)
<<
"Input tensor size
: "
<<
input_tensors
->
size
();
if
(
input_tensors
->
size
()
!=
tensors_mask
.
size
())
{
MS_LOG
(
EXCEPTION
)
<<
"Input tensors size "
<<
input_tensors
->
size
()
<<
" should be equal to tensors mask size "
<<
tensors_mask
.
size
();
}
for
(
size_t
i
=
0
;
i
<
input_tensors
.
size
();
++
i
)
{
auto
parameter
=
ConstructRunOpParameter
(
graph
,
input_tensors
[
i
]
,
tensors_mask
[
i
]);
for
(
size_t
i
=
0
;
i
<
input_tensors
->
size
();
++
i
)
{
auto
parameter
=
ConstructRunOpParameter
(
graph
,
input_tensors
->
at
(
i
)
,
tensors_mask
[
i
]);
inputs
.
push_back
(
parameter
);
graph
->
MutableInputs
()
->
push_back
(
parameter
);
}
...
...
@@ -756,8 +778,6 @@ std::shared_ptr<KernelGraph> SessionBasic::ConstructSingleOpGraph(const OpRunInf
MS_EXCEPTION_IF_NULL
(
cnode
);
// set abstract,which include inferred shapes and types
cnode
->
set_abstract
(
op_run_info
.
abstract
);
// set const input to attr if value is not a tensor,such as scalar or tuple
RunOpConvertConstInputToAttr
(
op_run_info
,
cnode
);
// set execution order
std
::
vector
<
CNodePtr
>
exe_order
=
{
cnode
};
graph
->
set_execution_order
(
exe_order
);
...
...
mindspore/ccsrc/session/session_basic.h
浏览文件 @
ffe8b5d3
...
...
@@ -61,9 +61,11 @@ class SessionBasic {
virtual
void
RunGraph
(
const
GraphId
&
graph_id
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
inputs
,
VectorRef
*
outputs
)
=
0
;
virtual
void
BuildOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
)
{}
virtual
void
BuildOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensors
)
{}
virtual
py
::
tuple
RunOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
)
{
return
py
::
tuple
();
}
virtual
py
::
tuple
RunOp
(
const
OpRunInfo
&
,
const
GraphInfo
&
,
const
std
::
vector
<
tensor
::
TensorPtr
>
&
input_tensors
)
{
return
py
::
tuple
();
}
virtual
void
RegisterSummaryCallBackFunc
(
const
CallBackFunc
&
callback
);
...
...
@@ -96,10 +98,8 @@ class SessionBasic {
void
CreateOutputNode
(
const
CNodePtr
&
cnode
,
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
);
CNodePtr
ConstructOutput
(
const
AnfNodePtrList
&
outputs
,
const
std
::
shared_ptr
<
KernelGraph
>
&
graph
);
// create a single run op graph
std
::
shared_ptr
<
KernelGraph
>
ConstructSingleOpGraph
(
const
OpRunInfo
&
op_run_info
);
// get tensors from op inputs
void
ToTensorPtr
(
const
OpRunInfo
&
op_run_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
inputs
,
std
::
vector
<
bool
>
*
tensor_mask
);
std
::
shared_ptr
<
KernelGraph
>
ConstructSingleOpGraph
(
const
OpRunInfo
&
op_run_info
,
std
::
vector
<
tensor
::
TensorPtr
>
*
input_tensor
);
// trans BaseRef list to py::tuple
BaseRef
TransformBaseRefListToTuple
(
const
BaseRef
&
base_ref
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录