Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
MindSpore
mindspore
提交
ee1510da
M
mindspore
项目概览
MindSpore
/
mindspore
通知
35
Star
15
Fork
15
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ee1510da
编写于
7月 14, 2020
作者:
H
He Wei
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Eliminate circular dependency between 'ir' and 'device/kernel'
上级
c99cc0df
变更
22
隐藏空白更改
内联
并排
Showing
22 changed file
with
154 addition
and
92 deletion
+154
-92
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc
...e/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc
+1
-1
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
...e/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
+1
-1
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc
...ptimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc
+1
-1
mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc
...ackend/optimizer/pass/common_subexpression_elimination.cc
+2
-2
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
+27
-27
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
+2
-0
mindspore/ccsrc/backend/session/gpu_session.cc
mindspore/ccsrc/backend/session/gpu_session.cc
+1
-1
mindspore/ccsrc/backend/session/session_basic.cc
mindspore/ccsrc/backend/session/session_basic.cc
+7
-6
mindspore/ccsrc/debug/anf_ir_dump.cc
mindspore/ccsrc/debug/anf_ir_dump.cc
+2
-2
mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
...ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
+1
-2
mindspore/ccsrc/runtime/device/device_address.h
mindspore/ccsrc/runtime/device/device_address.h
+2
-6
mindspore/ccsrc/runtime/device/kernel_info.h
mindspore/ccsrc/runtime/device/kernel_info.h
+3
-1
mindspore/ccsrc/runtime/device/kernel_runtime.cc
mindspore/ccsrc/runtime/device/kernel_runtime.cc
+4
-2
mindspore/core/ir/anf.h
mindspore/core/ir/anf.h
+2
-7
mindspore/core/ir/device_sync.h
mindspore/core/ir/device_sync.h
+38
-0
mindspore/core/ir/kernel_info_dev.h
mindspore/core/ir/kernel_info_dev.h
+32
-0
mindspore/core/ir/tensor.cc
mindspore/core/ir/tensor.cc
+7
-7
mindspore/core/ir/tensor.h
mindspore/core/ir/tensor.h
+4
-6
mindspore/core/ir/tensor_py.cc
mindspore/core/ir/tensor_py.cc
+0
-1
mindspore/core/ir/tensor_py.h
mindspore/core/ir/tensor_py.h
+0
-2
tests/ut/cpp/session/anf_runtime_algorithm_test.cc
tests/ut/cpp/session/anf_runtime_algorithm_test.cc
+16
-16
tests/ut/cpp/session/kernel_graph_test.cc
tests/ut/cpp/session/kernel_graph_test.cc
+1
-1
未找到文件。
mindspore/ccsrc/backend/kernel_compiler/cpu/cpu_kernel_factory.cc
浏览文件 @
ee1510da
...
...
@@ -38,7 +38,7 @@ void CPUKernelFactory::Register(const std::string &kernel_name, const KernelAttr
}
std
::
shared_ptr
<
CPUKernel
>
CPUKernelFactory
::
Create
(
const
std
::
string
&
kernel_name
,
const
CNodePtr
&
apply_kernel
)
{
auto
kernel_info
=
apply_kernel
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
apply_kernel
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
const
KernelBuildInfo
*
kernel_build_Info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
kernel_build_Info
);
...
...
mindspore/ccsrc/backend/kernel_compiler/gpu/gpu_kernel_factory.cc
浏览文件 @
ee1510da
...
...
@@ -137,7 +137,7 @@ std::pair<bool, size_t> GpuKernelFactory::GpuKernelAttrCheck(const std::string &
}
GpuKernel
*
GpuKernelFactory
::
Create
(
const
std
::
string
&
kernel_name
,
const
CNodePtr
&
apply_kernel
)
{
auto
kernel_info
=
apply_kernel
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
apply_kernel
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
const
KernelBuildInfo
*
kernel_build_Info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
kernel_build_Info
);
...
...
mindspore/ccsrc/backend/optimizer/ascend/ir_fusion/parameter_and_transop_fusion.cc
浏览文件 @
ee1510da
...
...
@@ -63,7 +63,7 @@ const AnfNodePtr ParamTransRoad(const FuncGraphPtr &func_graph, const AnfNodePtr
kernel
::
KernelBuildInfoPtr
GetKernelBuildInfo
(
const
CNodePtr
&
cast
,
const
string
&
format
,
TypeId
input_type
,
TypeId
output_type
)
{
MS_EXCEPTION_IF_NULL
(
cast
);
auto
kernel_info
=
cast
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
cast
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
cast_build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
cast_build_info
);
...
...
mindspore/ccsrc/backend/optimizer/pass/common_subexpression_elimination.cc
浏览文件 @
ee1510da
...
...
@@ -23,8 +23,8 @@ namespace {
bool
CheckEqualKernelBuildInfo
(
const
AnfNodePtr
&
main
,
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
main
);
MS_EXCEPTION_IF_NULL
(
node
);
auto
main_kernel_info
=
main
->
kernel_info
(
);
auto
node_kernel_info
=
node
->
kernel_info
(
);
auto
main_kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
main
->
kernel_info
()
);
auto
node_kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
if
(
main_kernel_info
==
nullptr
&&
node_kernel_info
==
nullptr
)
{
return
true
;
}
...
...
mindspore/ccsrc/backend/session/anf_runtime_algorithm.cc
浏览文件 @
ee1510da
...
...
@@ -338,7 +338,7 @@ std::string AnfRuntimeAlgorithm::GetOutputFormat(const AnfNodePtr &node, size_t
if
(
!
AnfAlgo
::
IsRealKernel
(
node
))
{
return
AnfAlgo
::
GetPrevNodeOutputFormat
(
node
,
output_idx
);
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -360,7 +360,7 @@ std::string AnfRuntimeAlgorithm::GetInputFormat(const AnfNodePtr &node, size_t i
if
(
!
IsRealKernel
(
node
))
{
GetPrevNodeOutputFormat
(
node
,
input_idx
);
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -467,7 +467,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetInputReshapeType(const AnfNode
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputReshapeType
(
node
,
input_idx
);
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -486,7 +486,7 @@ std::vector<kernel::Axis> AnfRuntimeAlgorithm::GetOutputReshapeType(const AnfNod
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputReshapeType
(
node
,
output_idx
);
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -546,7 +546,7 @@ TypeId AnfRuntimeAlgorithm::GetOutputDeviceDataType(const AnfNodePtr &node, size
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputDeviceDataType
(
node
,
output_idx
);
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -567,7 +567,7 @@ TypeId AnfRuntimeAlgorithm::GetInputDeviceDataType(const AnfNodePtr &node, size_
if
(
!
IsRealKernel
(
node
))
{
return
GetPrevNodeOutputDeviceDataType
(
node
,
0
);
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -597,7 +597,7 @@ const DeviceAddress *AnfRuntimeAlgorithm::GetOutputAddr(const AnfNodePtr &node,
MS_LOG
(
EXCEPTION
)
<<
node
->
DebugString
()
<<
"Invalid nop node"
;
}
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
addr
=
kernel_info
->
GetOutputAddr
(
output_idx
);
if
(
addr
==
nullptr
)
{
...
...
@@ -619,7 +619,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetMutableOutputAddr(const AnfNodePtr &nod
MS_LOG
(
EXCEPTION
)
<<
node
->
DebugString
()
<<
"Invalid nop node."
;
}
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
addr
=
kernel_info
->
GetMutableOutputAddr
(
output_idx
);
if
(
addr
==
nullptr
)
{
...
...
@@ -636,7 +636,7 @@ bool AnfRuntimeAlgorithm::OutputAddrExist(const AnfNodePtr &node, size_t output_
MS_LOG
(
EXCEPTION
)
<<
"The index ["
<<
output_idx
<<
"] is out of range of the node's output size [ "
<<
GetOutputTensorNum
(
node
)
<<
"#node:[ "
<<
node
->
DebugString
()
<<
"]"
;
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
OutputAddrExist
(
output_idx
);
}
...
...
@@ -656,7 +656,7 @@ DeviceAddressPtr AnfRuntimeAlgorithm::GetPrevNodeMutableOutputAddr(const AnfNode
// set output device addr of anf_node
void
AnfRuntimeAlgorithm
::
SetOutputAddr
(
const
DeviceAddressPtr
&
addr
,
size_t
output_idx
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
if
(
!
kernel_info
->
SetOutputAddr
(
addr
,
output_idx
))
{
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
"set adr"
<<
output_idx
<<
" fail"
;
...
...
@@ -666,7 +666,7 @@ void AnfRuntimeAlgorithm::SetOutputAddr(const DeviceAddressPtr &addr, size_t out
// set workspace device addr of anf_node
void
AnfRuntimeAlgorithm
::
SetWorkspaceAddr
(
const
DeviceAddressPtr
&
addr
,
size_t
output_idx
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
if
(
!
kernel_info
->
SetWorkspaceAddr
(
addr
,
output_idx
))
{
MS_LOG
(
EXCEPTION
)
<<
"Node "
<<
node
->
DebugString
()
<<
"set adr"
<<
output_idx
<<
" fail"
;
...
...
@@ -676,7 +676,7 @@ void AnfRuntimeAlgorithm::SetWorkspaceAddr(const DeviceAddressPtr &addr, size_t
// get workspace device addr of anf_node
DeviceAddress
*
AnfRuntimeAlgorithm
::
GetWorkspaceAddr
(
const
AnfNodePtr
&
node
,
size_t
output_idx
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
addr
=
kernel_info
->
GetWorkspaceAddr
(
output_idx
);
if
(
addr
==
nullptr
)
{
...
...
@@ -720,7 +720,7 @@ void AnfRuntimeAlgorithm::CopyAbstract(const AnfNodePtr &from_node, AnfNode *to_
kernel
::
OpPattern
AnfRuntimeAlgorithm
::
GetOpPattern
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
// select_kernel_build_info() has checked whether return pointer is null
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
@@ -731,7 +731,7 @@ kernel::OpPattern AnfRuntimeAlgorithm::GetOpPattern(const AnfNodePtr &node) {
// get KernelBuildType of node, such as ATT,RT,FWK and so on
KernelType
AnfRuntimeAlgorithm
::
GetKernelType
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
// select_kernel_build_info() has checked whether return pointer is null
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
...
...
@@ -741,7 +741,7 @@ KernelType AnfRuntimeAlgorithm::GetKernelType(const AnfNodePtr &node) {
kernel
::
Processor
AnfRuntimeAlgorithm
::
GetProcessor
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -750,7 +750,7 @@ kernel::Processor AnfRuntimeAlgorithm::GetProcessor(const AnfNodePtr &node) {
kernel
::
FusionType
AnfRuntimeAlgorithm
::
GetFusionType
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
auto
build_info
=
kernel_info
->
select_kernel_build_info
();
MS_EXCEPTION_IF_NULL
(
build_info
);
...
...
@@ -760,7 +760,7 @@ kernel::FusionType AnfRuntimeAlgorithm::GetFusionType(const AnfNodePtr &node) {
// set select kernel_build_info
void
AnfRuntimeAlgorithm
::
SetSelectKernelBuildInfo
(
const
KernelBuildInfoPtr
&
select_kernel_build_info
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
set_select_kernel_build_info
(
select_kernel_build_info
);
}
...
...
@@ -768,7 +768,7 @@ void AnfRuntimeAlgorithm::SetSelectKernelBuildInfo(const KernelBuildInfoPtr &sel
// get select kernel_build_info
KernelBuildInfoPtr
AnfRuntimeAlgorithm
::
GetSelectKernelBuildInfo
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
GetMutableSelectKernelBuildInfo
();
}
...
...
@@ -776,7 +776,7 @@ KernelBuildInfoPtr AnfRuntimeAlgorithm::GetSelectKernelBuildInfo(const AnfNodePt
// get kernelMode
KernelMod
*
AnfRuntimeAlgorithm
::
GetKernelMod
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
MutableKernelMod
();
}
...
...
@@ -784,7 +784,7 @@ KernelMod *AnfRuntimeAlgorithm::GetKernelMod(const AnfNodePtr &node) {
// set kernel mod
void
AnfRuntimeAlgorithm
::
SetKernelMod
(
const
KernelModPtr
&
kernel_mod
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
kernel_info
->
set_kernel_mod
(
kernel_mod
);
}
...
...
@@ -850,42 +850,42 @@ bool AnfRuntimeAlgorithm::IsParameterWeight(const ParameterPtr &node) {
void
AnfRuntimeAlgorithm
::
SetStreamId
(
uint32_t
stream_id
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
kernel_info
->
set_stream_id
(
stream_id
);
}
uint32_t
AnfRuntimeAlgorithm
::
GetStreamId
(
const
AnfNodePtr
&
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
stream_id
();
}
void
AnfRuntimeAlgorithm
::
SetStreamDistinctionLabel
(
uint32_t
stream_label
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
kernel_info
->
set_stream_distinction_label
(
stream_label
);
}
uint32_t
AnfRuntimeAlgorithm
::
GetStreamDistinctionLabel
(
const
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
const
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
stream_distinction_label
();
}
void
AnfRuntimeAlgorithm
::
SetGraphId
(
uint32_t
graph_id
,
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
kernel_info
->
set_graph_id
(
graph_id
);
}
uint32_t
AnfRuntimeAlgorithm
::
GetGraphId
(
const
AnfNode
*
node
)
{
MS_EXCEPTION_IF_NULL
(
node
);
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
const
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
graph_id
();
}
...
...
@@ -913,7 +913,7 @@ bool AnfRuntimeAlgorithm::IsFeatureMapOutput(const AnfNodePtr &node) {
if
(
node
->
isa
<
ValueNode
>
())
{
return
false
;
}
auto
kernel_info
=
node
->
kernel_info
(
);
auto
kernel_info
=
dynamic_cast
<
const
device
::
KernelInfo
*>
(
node
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
kernel_info
);
return
kernel_info
->
is_feature_map
();
}
...
...
mindspore/ccsrc/backend/session/anf_runtime_algorithm.h
浏览文件 @
ee1510da
...
...
@@ -38,6 +38,8 @@ namespace mindspore {
namespace
session
{
using
AnfVisitFuncion
=
std
::
function
<
Any
(
const
AnfNodePtr
&
node
,
int
index
)
>
;
using
KernelWithIndex
=
std
::
pair
<
AnfNodePtr
,
size_t
>
;
using
DeviceAddress
=
device
::
DeviceAddress
;
using
DeviceAddressPtr
=
device
::
DeviceAddressPtr
;
class
AnfRuntimeAlgorithm
{
public:
// get input_anf_node's real kernel by recurse
...
...
mindspore/ccsrc/backend/session/gpu_session.cc
浏览文件 @
ee1510da
...
...
@@ -121,7 +121,7 @@ void GPUSession::LoadInputData(const std::shared_ptr<KernelGraph> &kernel_graph,
if
(
input_node
->
isa
<
Parameter
>
()
&&
AnfAlgo
::
OutputAddrExist
(
input_node
,
0
))
{
auto
pk_node
=
input_node
->
cast
<
ParameterPtr
>
();
auto
device_address
=
AnfAlgo
::
GetMutableOutputAddr
(
pk_node
,
0
);
auto
tensor_address
=
tensor
->
device_address
(
);
auto
tensor_address
=
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
tensor
->
device_address
()
);
bool
need_sync
=
false
;
if
(
ms_context
->
enable_pynative_infer
())
{
if
(
tensor_address
==
nullptr
||
tensor_address
!=
device_address
)
{
...
...
mindspore/ccsrc/backend/session/session_basic.cc
浏览文件 @
ee1510da
...
...
@@ -230,13 +230,14 @@ ParameterPtr ConstructRunOpParameter(const std::shared_ptr<KernelGraph> &graph,
// set the kernel info of parameter
auto
kernel_build_info_builder
=
std
::
make_shared
<
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
>
();
MS_EXCEPTION_IF_NULL
(
input_tensor
);
if
(
input_tensor
->
device_address
().
get
()
==
nullptr
)
{
auto
device_address
=
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
input_tensor
->
device_address
());
if
(
device_address
==
nullptr
)
{
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
kOpFormat_DEFAULT
});
TypeId
param_init_data_type
=
AnfAlgo
::
IsParameterWeight
(
param
)
?
kTypeUnknown
:
input_tensor
->
data_type
();
kernel_build_info_builder
->
SetOutputsDeviceType
(
std
::
vector
<
TypeId
>
{
param_init_data_type
});
}
else
{
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
input_tensor
->
device_address
()
->
format
()});
kernel_build_info_builder
->
SetOutputsDeviceType
(
std
::
vector
<
TypeId
>
{
input_tensor
->
device_address
()
->
type_id
()});
kernel_build_info_builder
->
SetOutputsFormat
(
std
::
vector
<
std
::
string
>
{
device_address
->
format
()});
kernel_build_info_builder
->
SetOutputsDeviceType
(
std
::
vector
<
TypeId
>
{
device_address
->
type_id
()});
}
AnfAlgo
::
SetSelectKernelBuildInfo
(
kernel_build_info_builder
->
Build
(),
param
.
get
());
// construct abstract of parameter
...
...
@@ -319,7 +320,7 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
if
(
ref_real_node
->
isa
<
CNode
>
()
&&
node_graph
->
IsInternalOutput
(
ref_real_node
)
&&
node_graph
->
IsFinalOutputKernel
(
ref_real_node
))
{
auto
kernel_info
=
ref_real_node
->
kernel_info
();
if
(
kernel_info
==
nullptr
||
kernel_info
->
select_kernel_build_info
()
==
nullptr
)
{
if
(
kernel_info
==
nullptr
||
!
kernel_info
->
has_build_info
()
)
{
MS_LOG
(
INFO
)
<<
"No kernel info"
;
return
;
}
...
...
@@ -330,9 +331,9 @@ void SessionBasic::InitInternalOutputParameter(const AnfNodePtr &out_node, const
}
auto
format
=
AnfAlgo
::
GetOutputFormat
(
ref_real_node
,
ref_real_node_index
);
auto
type
=
AnfAlgo
::
GetOutputDeviceDataType
(
ref_real_node
,
ref_real_node_index
);
parameter
->
set_kernel_info
(
std
::
make_shared
<
device
::
KernelInfo
>
());
auto
d_kernel_info
=
parameter
->
kernel_info
();
auto
d_kernel_info
=
std
::
make_shared
<
device
::
KernelInfo
>
();
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
parameter
->
set_kernel_info
(
d_kernel_info
);
kernel
::
KernelBuildInfo
::
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsDeviceType
({
type
});
builder
.
SetOutputsFormat
({
format
});
...
...
mindspore/ccsrc/debug/anf_ir_dump.cc
浏览文件 @
ee1510da
...
...
@@ -128,7 +128,7 @@ void DumpKernelInfo(const CNodePtr &node, const std::shared_ptr<SubGraphIRInfo>
return
;
}
auto
kernel_info
=
node
->
kernel_info
();
if
(
kernel_info
==
nullptr
||
kernel_info
->
select_kernel_build_info
()
==
nullptr
)
{
if
(
kernel_info
==
nullptr
||
!
kernel_info
->
has_build_info
()
)
{
return
;
}
...
...
@@ -179,7 +179,7 @@ void DumpParams(const FuncGraphPtr &graph, std::ostringstream &buffer, OrderedMa
// print parameters' type and shape
PrintNodeOutputType
(
buffer
,
p
);
auto
kernel_info
=
p
->
kernel_info
();
if
(
kernel_info
!=
nullptr
&&
kernel_info
->
select_kernel_build_info
()
!=
nullptr
)
{
if
(
kernel_info
!=
nullptr
&&
kernel_info
->
has_build_info
()
)
{
buffer
<<
" : "
;
auto
type
=
AnfAlgo
::
GetOutputDeviceDataType
(
p
,
0
);
auto
format
=
AnfAlgo
::
GetOutputFormat
(
p
,
0
);
...
...
mindspore/ccsrc/runtime/device/ascend/kernel_select_graph_kernel.cc
浏览文件 @
ee1510da
...
...
@@ -362,8 +362,7 @@ void CheckFormatsAndDtypes(const CNodePtr &kernel_node, const std::vector<AnfNod
continue
;
}
for
(
auto
&
node_user
:
iter
->
second
)
{
if
(
node_user
.
first
->
kernel_info
()
==
nullptr
||
node_user
.
first
->
kernel_info
()
->
select_kernel_build_info
()
==
nullptr
)
{
if
(
node_user
.
first
->
kernel_info
()
==
nullptr
||
!
node_user
.
first
->
kernel_info
()
->
has_build_info
())
{
// maybe not a real kernel.
continue
;
}
...
...
mindspore/ccsrc/runtime/device/device_address.h
浏览文件 @
ee1510da
...
...
@@ -21,8 +21,7 @@
#include <vector>
#include <memory>
#include "ir/dtype.h"
using
std
::
string
;
#include "ir/device_sync.h"
namespace
mindspore
{
namespace
device
{
...
...
@@ -51,15 +50,12 @@ namespace device {
enum
class
DeviceAddressStatus
{
kInDevice
,
kInHost
,
kInDeviceToHost
,
kInHostToDevice
};
enum
class
DeviceAddressType
{
kUnknown
,
kAscend
,
kCPU
,
kGPU
};
class
DeviceAddress
{
class
DeviceAddress
:
public
mindspore
::
DeviceSync
{
public:
explicit
DeviceAddress
(
void
*
ptr
,
size_t
size
)
:
ptr_
(
ptr
),
size_
(
size
)
{}
explicit
DeviceAddress
(
void
*
ptr
,
size_t
size
,
const
string
&
format
,
TypeId
type_id
)
:
ptr_
(
ptr
),
size_
(
size
),
format_
(
format
),
type_id_
(
type_id
)
{}
virtual
~
DeviceAddress
()
{
ptr_
=
nullptr
;
}
virtual
bool
SyncDeviceToHost
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
TypeId
type
,
void
*
host_ptr
)
const
=
0
;
virtual
bool
SyncHostToDevice
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
TypeId
type
,
const
void
*
host_ptr
)
const
=
0
;
const
void
*
GetPtr
()
const
{
return
ptr_
;
}
size_t
GetSize
()
const
{
return
size_
;
}
std
::
string
format
()
const
{
return
format_
;
}
...
...
mindspore/ccsrc/runtime/device/kernel_info.h
浏览文件 @
ee1510da
...
...
@@ -19,6 +19,7 @@
#include <vector>
#include <memory>
#include "ir/kernel_info_dev.h"
#include "backend/kernel_compiler/kernel_build_info.h"
#include "runtime/device/ascend/ascend_device_address.h"
#include "backend/kernel_compiler/kernel.h"
...
...
@@ -27,7 +28,7 @@ namespace mindspore {
const
uint32_t
kInvalidGraphId
=
UINT32_MAX
;
const
uint32_t
kInvalidDistincLabel
=
UINT32_MAX
;
namespace
device
{
class
KernelInfo
{
class
KernelInfo
:
public
KernelInfoDevice
{
public:
KernelInfo
()
{
kernel_mod_
=
nullptr
;
...
...
@@ -41,6 +42,7 @@ class KernelInfo {
}
virtual
~
KernelInfo
()
=
default
;
bool
has_build_info
()
const
override
{
return
select_kernel_build_info
()
!=
nullptr
;
}
const
kernel
::
KernelBuildInfo
*
select_kernel_build_info
()
const
;
kernel
::
KernelBuildInfoPtr
GetMutableSelectKernelBuildInfo
()
const
;
void
set_select_kernel_build_info
(
const
kernel
::
KernelBuildInfoPtr
&
select_kernel_build_info
)
{
...
...
mindspore/ccsrc/runtime/device/kernel_runtime.cc
浏览文件 @
ee1510da
...
...
@@ -214,8 +214,10 @@ void KernelRuntime::RunOpAssignInputMemory(const std::vector<tensor::TensorPtr>
auto
output_size
=
AnfAlgo
::
GetOutputTensorNum
(
item
);
for
(
size_t
index
=
0
;
index
<
output_size
;
index
++
)
{
MS_EXCEPTION_IF_NULL
(
input_tensors
[
input_index
]);
if
(
input_tensors
[
input_index
]
->
device_address
().
get
()
!=
nullptr
)
{
AnfAlgo
::
SetOutputAddr
(
input_tensors
[
input_index
]
->
device_address
(),
index
,
item
.
get
());
auto
output_address
=
std
::
dynamic_pointer_cast
<
device
::
DeviceAddress
>
(
input_tensors
[
input_index
]
->
device_address
());
if
(
output_address
!=
nullptr
)
{
AnfAlgo
::
SetOutputAddr
(
output_address
,
index
,
item
.
get
());
continue
;
}
TypeId
output_type_id
=
AnfAlgo
::
GetOutputDeviceDataType
(
item
,
index
);
...
...
mindspore/core/ir/anf.h
浏览文件 @
ee1510da
...
...
@@ -27,8 +27,9 @@
#include <utility>
#include "base/base.h"
#include "
debug/info
.h"
#include "
ir/kernel_info_dev
.h"
#include "ir/scope.h"
#include "debug/info.h"
// A MindSpore ANF IR defined here.
// with BNF followed:
...
...
@@ -71,12 +72,6 @@ class BaseRef;
class
Var
;
using
VarPtr
=
std
::
shared_ptr
<
Var
>
;
namespace
device
{
class
KernelInfo
;
}
// namespace device
using
KernelInfoDevice
=
device
::
KernelInfo
;
using
KernelInfoDevicePtr
=
std
::
shared_ptr
<
KernelInfoDevice
>
;
class
AnfVisitor
;
class
ParamValue
;
...
...
mindspore/core/ir/device_sync.h
0 → 100644
浏览文件 @
ee1510da
/**
* Copyright 2020 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_
#define MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_
#include <vector>
#include <memory>
#include <string>
#include "ir/dtype/type.h"
using
std
::
string
;
namespace
mindspore
{
// Interface for data synchornize between device and host.
class
DeviceSync
{
public:
virtual
bool
SyncDeviceToHost
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
TypeId
type
,
void
*
host_ptr
)
const
=
0
;
virtual
bool
SyncHostToDevice
(
const
std
::
vector
<
int
>
&
shape
,
size_t
size
,
TypeId
type
,
const
void
*
host_ptr
)
const
=
0
;
};
using
DeviceSyncPtr
=
std
::
shared_ptr
<
DeviceSync
>
;
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_DEVICE_SYNC_H_
mindspore/core/ir/kernel_info_dev.h
0 → 100644
浏览文件 @
ee1510da
/**
* Copyright 2019 Huawei Technologies Co., Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_
#define MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_
#include <memory>
namespace
mindspore
{
// Interface for device kernel program information.
class
KernelInfoDevice
{
public:
// If kernel program was built and build info is set.
virtual
bool
has_build_info
()
const
=
0
;
};
using
KernelInfoDevicePtr
=
std
::
shared_ptr
<
KernelInfoDevice
>
;
}
// namespace mindspore
#endif // MINDSPORE_CCSRC_IR_KERNEL_INFO_DEV_H_
mindspore/core/ir/tensor.cc
浏览文件 @
ee1510da
...
...
@@ -326,7 +326,7 @@ Tensor::Tensor(const Tensor &tensor)
data_
(
tensor
.
data_
),
dirty_
(
tensor
.
dirty_
),
id_
(
tensor
.
id_
),
device_
address_
(
tensor
.
device_address
_
)
{}
device_
sync_
(
tensor
.
device_sync
_
)
{}
Tensor
::
Tensor
(
const
Tensor
&
tensor
,
TypeId
data_type
)
:
MetaTensor
(
data_type
,
tensor
.
shape_
),
...
...
@@ -334,7 +334,7 @@ Tensor::Tensor(const Tensor &tensor, TypeId data_type)
data_
(
MakeTensorData
(
data_type
,
tensor
.
shape_
,
tensor
.
data_
->
data
(),
tensor
.
data_type_
)),
dirty_
(
tensor
.
dirty_
),
id_
(
tensor
.
id_
),
device_
address_
(
tensor
.
device_address
_
)
{}
device_
sync_
(
tensor
.
device_sync
_
)
{}
Tensor
::
Tensor
(
TypeId
data_type
,
const
std
::
vector
<
int
>
&
shape
,
TensorDataPtr
data
)
:
MetaTensor
(
data_type
,
shape
),
data_
(
std
::
move
(
data
)),
id_
(
MakeId
())
{}
...
...
@@ -379,10 +379,10 @@ bool Tensor::ValueEqual(const Tensor &tensor) const {
Tensor
&
Tensor
::
AssignValue
(
const
Tensor
&
tensor
)
{
if
(
this
!=
&
tensor
)
{
MetaTensor
::
operator
=
(
tensor
);
dirty_
=
tensor
.
is_dirty
()
;
device_
address_
=
tensor
.
device_address
()
;
dirty_
=
tensor
.
dirty_
;
device_
sync_
=
tensor
.
device_sync_
;
data_
=
tensor
.
data_
;
id_
=
tensor
.
id
()
;
id_
=
tensor
.
id
_
;
}
return
*
this
;
}
...
...
@@ -425,8 +425,8 @@ std::string Tensor::ToStringRepr() const {
}
void
Tensor
::
data_sync
()
const
{
if
(
device_
address
_
!=
nullptr
)
{
if
(
!
device_
address
_
->
SyncDeviceToHost
(
shape
(),
static_cast
<
size_t
>
(
data
().
nbytes
()),
data_type
(),
data_c
()))
{
if
(
device_
sync
_
!=
nullptr
)
{
if
(
!
device_
sync
_
->
SyncDeviceToHost
(
shape
(),
static_cast
<
size_t
>
(
data
().
nbytes
()),
data_type
(),
data_c
()))
{
MS_LOG
(
EXCEPTION
)
<<
"SyncDeviceToHost when asnumpy."
;
}
}
...
...
mindspore/core/ir/tensor.h
浏览文件 @
ee1510da
...
...
@@ -23,15 +23,13 @@
#include <numeric>
#include "Eigen/Core"
#include "
runtime/device/device_address
.h"
#include "
ir/device_sync
.h"
#include "ir/meta_tensor.h"
#include "include/ms_tensor.h"
#include "utils/log_adapter.h"
using
float16
=
Eigen
::
half
;
using
mindspore
::
device
::
DeviceAddress
;
using
DeviceAddressPtr
=
std
::
shared_ptr
<
mindspore
::
device
::
DeviceAddress
>
;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of MindSpore project.
...
...
@@ -222,8 +220,8 @@ class Tensor : public MetaTensor {
bool
is_dirty
()
const
{
return
dirty_
;
}
void
set_dirty
(
const
bool
dirty
)
{
dirty_
=
dirty
;
}
Device
AddressPtr
device_address
()
const
{
return
device_address
_
;
}
void
set_device_address
(
const
Device
AddressPtr
&
device_address
)
{
device_address_
=
device_address
;
}
Device
SyncPtr
device_address
()
const
{
return
device_sync
_
;
}
void
set_device_address
(
const
Device
SyncPtr
&
device_sync
)
{
device_sync_
=
device_sync
;
}
std
::
string
id
()
const
{
return
id_
;
}
...
...
@@ -234,7 +232,7 @@ class Tensor : public MetaTensor {
TensorDataPtr
data_
{
nullptr
};
bool
dirty_
{
true
};
std
::
string
id_
{
""
};
Device
AddressPtr
device_address
_
{
nullptr
};
Device
SyncPtr
device_sync
_
{
nullptr
};
};
using
TensorPtr
=
std
::
shared_ptr
<
Tensor
>
;
using
TensorPtrList
=
std
::
vector
<
std
::
shared_ptr
<
Tensor
>>
;
...
...
mindspore/core/ir/tensor_py.cc
浏览文件 @
ee1510da
...
...
@@ -22,7 +22,6 @@
#include <sstream>
#include <string>
#include "runtime/device/device_address.h"
#include "pybind_api/api_register.h"
#include "pybind_api/export_flags.h"
#include "abstract/abstract_value.h"
...
...
mindspore/core/ir/tensor_py.h
浏览文件 @
ee1510da
...
...
@@ -81,8 +81,6 @@ struct type_caster<float16> : public npy_scalar_caster<float16> {
}
// namespace detail
}
// namespace pybind11
using
mindspore
::
device
::
DeviceAddress
;
using
DeviceAddressPtr
=
std
::
shared_ptr
<
mindspore
::
device
::
DeviceAddress
>
;
// brief mindspore namespace.
//
// mindspore namespace is the top level namespace of Mindsporeession project.
...
...
tests/ut/cpp/session/anf_runtime_algorithm_test.cc
浏览文件 @
ee1510da
...
...
@@ -255,7 +255,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputFormat) {
AnfAlgo
::
SetOutputInferTypeAndShape
({
kNumberTypeFloat32
,
kNumberTypeFloat32
},
{
shape
,
shape
},
add
.
get
());
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsDeviceType
({
kFloat32
->
type_id
(),
kFloat16
->
type_id
()});
...
...
@@ -274,7 +274,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputFormat) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsDeviceType
({
kFloat32
->
type_id
(),
kFloat16
->
type_id
()});
...
...
@@ -293,7 +293,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputFormat) {
auto
pre_add
=
kernel_graph
->
NewCNode
(
pre_node_inputs
);
MS_EXCEPTION_IF_NULL
(
pre_add
);
pre_add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
pre_add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
pre_add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsDeviceType
({
kFloat32
->
type_id
()});
...
...
@@ -373,7 +373,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceShape) {
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_abstract
(
tuple_abstract
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsFormat
({
kOpFormat_NCHW
,
kOpFormat_NCHW
,
kOpFormat_NHWC
,
kOpFormat_FRAC_NZ
});
...
...
@@ -404,7 +404,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceShape) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsFormat
({
kOpFormat_NCHW
,
kOpFormat_NCHW
,
kOpFormat_NHWC
});
...
...
@@ -457,7 +457,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputDeviceDataTypeTest) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsDeviceType
({
kFloat32
->
type_id
()});
...
...
@@ -474,7 +474,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetInputDeviceDataTypeTest) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetInputsDeviceType
({
kFloat32
->
type_id
(),
kFloat16
->
type_id
()});
...
...
@@ -492,7 +492,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputDeviceDataType) {
auto
pre_add
=
kernel_graph
->
NewCNode
(
pre_add_inputs
);
MS_EXCEPTION_IF_NULL
(
pre_add
);
pre_add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
pre_add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
pre_add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetOutputsDeviceType
({
kFloat32
->
type_id
()});
...
...
@@ -513,7 +513,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetOutputAddr) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
int
*
addr
=
nullptr
;
auto
device_address
=
std
::
make_shared
<
AscendDeviceAddress
>
(
addr
,
1
);
...
...
@@ -528,7 +528,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetPrevNodeOutputAddr) {
auto
pre_add
=
kernel_graph
->
NewCNode
(
pre_add_inputs
);
MS_EXCEPTION_IF_NULL
(
pre_add
);
pre_add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
pre_add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
pre_add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
int
*
addr
=
nullptr
;
auto
device_address
=
std
::
make_shared
<
AscendDeviceAddress
>
(
addr
,
1
);
...
...
@@ -561,7 +561,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetWorkspaceAddr) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
int
*
addr
=
nullptr
;
auto
device_address
=
std
::
make_shared
<
AscendDeviceAddress
>
(
addr
,
1
);
...
...
@@ -643,7 +643,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelType) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetKernelType
(
AKG_KERNEL
);
...
...
@@ -659,7 +659,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetProcessor) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetProcessor
(
kernel
::
AICORE
);
...
...
@@ -675,7 +675,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetFusionType) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
KernelBuildInfoBuilder
builder
;
builder
.
SetFusionType
(
kernel
::
CONVLUTION
);
...
...
@@ -703,7 +703,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetKernelMod) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
d_kernel_info
->
set_kernel_mod
(
nullptr
);
EXPECT_EQ
(
AnfAlgo
::
GetKernelMod
(
add
),
nullptr
);
...
...
@@ -779,7 +779,7 @@ TEST_F(AnfRuntimeAlgorithmTest, GetStreamId) {
auto
add
=
kernel_graph
->
NewCNode
(
inputs
);
MS_EXCEPTION_IF_NULL
(
add
);
add
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
d_kernel_info
=
add
->
kernel_info
(
);
auto
d_kernel_info
=
dynamic_cast
<
KernelInfo
*>
(
add
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
d_kernel_info
);
d_kernel_info
->
set_stream_id
(
0
);
EXPECT_EQ
(
AnfAlgo
::
GetStreamId
(
add
),
0
);
...
...
tests/ut/cpp/session/kernel_graph_test.cc
浏览文件 @
ee1510da
...
...
@@ -42,7 +42,7 @@ TEST_F(KernelGraphTest, NewValueNode) {
auto
x_abstract
=
std
::
make_shared
<
abstract
::
AbstractTensor
>
(
kFloat32
,
shape
);
add_value
->
set_abstract
(
x_abstract
);
add_value
->
set_kernel_info
(
std
::
make_shared
<
KernelInfo
>
());
auto
mutable_kernel_info
=
add_value
->
kernel_info
(
);
auto
mutable_kernel_info
=
dynamic_cast
<
device
::
KernelInfo
*>
(
add_value
->
kernel_info
()
);
MS_EXCEPTION_IF_NULL
(
mutable_kernel_info
);
std
::
shared_ptr
<
KernelBuildInfoBuilder
>
builder
=
std
::
make_shared
<
KernelBuildInfoBuilder
>
();
builder
->
SetOutputsFormat
({
kOpFormat_FRAC_Z
});
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录