Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
magicwindyyd
mindspore
提交
3a61d646
M
mindspore
项目概览
magicwindyyd
/
mindspore
与 Fork 源项目一致
Fork自
MindSpore / mindspore
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
M
mindspore
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
3a61d646
编写于
8月 03, 2020
作者:
L
lvliang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
decoupling-the-interface-of-mallocing-mem
上级
72d2fc74
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
24 addition
and
40 deletion
+24
-40
mindspore/ccsrc/runtime/device/kernel_runtime.cc
mindspore/ccsrc/runtime/device/kernel_runtime.cc
+8
-10
mindspore/ccsrc/runtime/device/memory_manager.cc
mindspore/ccsrc/runtime/device/memory_manager.cc
+13
-26
mindspore/ccsrc/runtime/device/memory_manager.h
mindspore/ccsrc/runtime/device/memory_manager.h
+3
-4
未找到文件。
mindspore/ccsrc/runtime/device/kernel_runtime.cc
浏览文件 @
3a61d646
...
...
@@ -411,7 +411,7 @@ void KernelRuntime::AssignStaticMemoryInput(const session::KernelGraph *graph) {
}
auto
tensor_size
=
CountNodeDeviceMemorySize
(
item
,
index
);
auto
address
=
CreateDeviceAddress
(
nullptr
,
tensor_size
,
AnfAlgo
::
GetOutputFormat
(
item
,
index
),
output_type_id
);
if
(
mem_manager_
->
MallocMem
(
address
,
kStaticMem
,
tensor_size
)
==
nullptr
)
{
if
(
mem_manager_
->
MallocMem
(
kStaticMem
,
tensor_size
,
address
)
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
tensor_size
;
}
AnfAlgo
::
SetOutputAddr
(
address
,
index
,
item
.
get
());
...
...
@@ -517,7 +517,7 @@ void KernelRuntime::AssignCommunicationNodeOutputMem(MemType type, const AnfNode
auto
address
=
CreateDeviceAddress
(
nullptr
,
output_sizes
[
j
],
output_format
,
output_type
);
MS_EXCEPTION_IF_NULL
(
address
);
if
(
output_ptr
==
nullptr
)
{
output_ptr
=
mem_manager_
->
Malloc
Mem
(
address
,
type
,
total_size
,
std
::
pair
<
AnfNodePtr
,
size_t
>
(
node
,
0
)
);
output_ptr
=
mem_manager_
->
Malloc
OutputMem
(
node
,
0
,
type
,
total_size
,
address
);
MS_EXCEPTION_IF_NULL
(
output_ptr
);
}
else
{
address
->
set_ptr
(
output_ptr
);
...
...
@@ -565,8 +565,7 @@ void KernelRuntime::AssignCommunicationNodeInputMem(MemType type, const AnfNodeP
if
(
addr_size
.
empty
())
{
return
;
}
uint8_t
*
input_ptr
=
mem_manager_
->
MallocMem
(
addr_size
[
0
].
first
,
type
,
total_size
,
std
::
pair
<
AnfNodePtr
,
size_t
>
(
node
,
0
));
uint8_t
*
input_ptr
=
mem_manager_
->
MallocOutputMem
(
node
,
0
,
type
,
total_size
,
addr_size
[
0
].
first
);
for
(
const
auto
&
iter
:
addr_size
)
{
MS_EXCEPTION_IF_NULL
(
iter
.
first
);
iter
.
first
->
set_ptr
(
input_ptr
);
...
...
@@ -600,8 +599,7 @@ void KernelRuntime::AssignNodeOutputMem(MemType type, const AnfNodePtr &node, in
auto
output_type
=
AnfAlgo
::
GetOutputDeviceDataType
(
node
,
i
);
auto
device_address
=
CreateDeviceAddress
(
nullptr
,
output_sizes
[
i
],
output_format
,
output_type
);
MS_EXCEPTION_IF_NULL
(
device_address
);
uint8_t
*
ptr
=
mem_manager_
->
MallocMem
(
device_address
,
type
,
output_sizes
[
i
],
std
::
pair
<
AnfNodePtr
,
size_t
>
(
node
,
i
));
uint8_t
*
ptr
=
mem_manager_
->
MallocOutputMem
(
node
,
i
,
type
,
output_sizes
[
i
],
device_address
);
MS_EXCEPTION_IF_NULL
(
ptr
);
device_address
->
set_host_shape
(
trans
::
GetRuntimePaddingShape
(
node
,
i
));
AnfAlgo
::
SetOutputAddr
(
device_address
,
i
,
node
.
get
());
...
...
@@ -639,7 +637,7 @@ void KernelRuntime::AssignValueNodeTensor(const ValueNodePtr &value_node, const
MS_EXCEPTION_IF_NULL
(
address
);
if
(
ms_context
->
enable_pynative_infer
()
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
node_size
))
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address from memory pool when tensor size is: "
<<
node_size
;
}
else
if
(
mem_manager_
->
MallocMem
(
address
,
kStaticMem
,
node_size
)
==
nullptr
)
{
}
else
if
(
mem_manager_
->
MallocMem
(
kStaticMem
,
node_size
,
address
)
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
node_size
;
}
AnfAlgo
::
SetOutputAddr
(
address
,
output_idx
,
value_node
.
get
());
...
...
@@ -675,7 +673,7 @@ void KernelRuntime::AssignStaticMemoryValueNode(session::KernelGraph *graph) {
MS_EXCEPTION_IF_NULL
(
address
);
if
(
ms_context
->
enable_pynative_infer
()
&&
!
mem_manager_
->
MallocMemFromMemPool
(
address
,
tensor_size
))
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address from memory pool when tensor size is: "
<<
tensor_size
;
}
else
if
(
mem_manager_
->
MallocMem
(
address
,
kStaticMem
,
tensor_size
)
==
nullptr
)
{
}
else
if
(
mem_manager_
->
MallocMem
(
kStaticMem
,
tensor_size
,
address
)
==
nullptr
)
{
MS_LOG
(
EXCEPTION
)
<<
"Cannot alloc address when flag is: "
<<
kStaticMem
<<
", tensor size is: "
<<
tensor_size
;
}
AnfAlgo
::
SetOutputAddr
(
address
,
0
,
value_node
.
get
());
...
...
@@ -859,8 +857,8 @@ DeviceAddressPtr KernelRuntime::AssignSingleOpLaunchMemory(size_t size, const st
auto
device_address
=
CreateDeviceAddress
(
nullptr
,
size
,
format
,
type
);
MS_EXCEPTION_IF_NULL
(
device_address
);
MS_EXCEPTION_IF_NULL
(
mem_manager_
);
auto
base_ptr
=
mem_manager_
->
MallocMem
(
kDynamicMem
,
size
);
device_address
->
set_ptr
(
base_ptr
);
auto
base_ptr
=
mem_manager_
->
MallocMem
(
kDynamicMem
,
size
,
device_address
);
MS_EXCEPTION_IF_NULL
(
base_ptr
);
return
device_address
;
}
...
...
mindspore/ccsrc/runtime/device/memory_manager.cc
浏览文件 @
3a61d646
...
...
@@ -45,8 +45,10 @@ void MemoryManager::MallocReusedDynamicMem(const session::KernelGraph *graph) {
mem_reuse_util_ptr_
->
set_mem_base
(
base_ptr
);
}
uint8_t
*
MemoryManager
::
MallocOutputMem
(
const
AnfNodePtr
&
node
,
size_t
index
,
MemType
type
,
size_t
size
)
{
uint8_t
*
MemoryManager
::
MallocOutputMem
(
const
AnfNodePtr
&
node
,
size_t
index
,
MemType
type
,
size_t
size
,
const
DeviceAddressPtr
&
address
)
{
MS_EXCEPTION_IF_NULL
(
node
);
MS_EXCEPTION_IF_NULL
(
address
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
uint8_t
*
ptr
=
nullptr
;
...
...
@@ -57,23 +59,30 @@ uint8_t *MemoryManager::MallocOutputMem(const AnfNodePtr &node, size_t index, Me
}
if
(
type
==
kStaticMem
)
{
ptr
=
MallocStaticMem
(
size
,
communication_mem
);
address
->
from_mem_pool_
=
true
;
if
(
communication_mem
)
{
address
->
communication_ptr_
=
ptr
-
kMemAlignSize
;
}
}
else
if
(
type
==
kReuseDynamicCommMem
)
{
MS_EXCEPTION_IF_NULL
(
mem_reuse_util_ptr_
);
ptr
=
mem_reuse_util_ptr_
->
GetNodeOutputPtr
(
node
,
index
);
}
else
{
ptr
=
MallocDynamicMem
(
size
,
communication_mem
);
}
address
->
ptr_
=
ptr
;
return
ptr
;
}
if
(
type
==
kStaticMem
)
{
ptr
=
MallocStaticMem
(
size
,
false
);
address
->
from_mem_pool_
=
true
;
}
else
if
(
type
==
kDynamicMem
)
{
ptr
=
MallocDynamicMem
(
size
,
false
);
}
else
if
(
type
==
kReuseDynamicMem
)
{
MS_EXCEPTION_IF_NULL
(
mem_reuse_util_ptr_
);
ptr
=
mem_reuse_util_ptr_
->
GetNodeOutputPtr
(
node
,
index
);
}
address
->
ptr_
=
ptr
;
return
ptr
;
}
...
...
@@ -85,38 +94,16 @@ uint8_t *MemoryManager::MallocWorkSpaceMem(const AnfNodePtr &node, size_t index,
return
MallocDynamicMem
(
size
,
false
);
}
uint8_t
*
MemoryManager
::
MallocMem
(
MemType
type
,
size_t
size
)
{
uint8_t
*
MemoryManager
::
MallocMem
(
MemType
type
,
size_t
size
,
const
DeviceAddressPtr
&
address
)
{
MS_EXCEPTION_IF_NULL
(
address
);
uint8_t
*
ptr
=
nullptr
;
if
(
type
==
kStaticMem
)
{
ptr
=
MallocStaticMem
(
size
,
false
);
address
->
from_mem_pool_
=
true
;
}
else
if
(
type
==
kDynamicMem
)
{
ptr
=
MallocDynamicMem
(
size
,
false
);
}
return
ptr
;
}
uint8_t
*
MemoryManager
::
MallocMem
(
const
DeviceAddressPtr
&
address
,
MemType
flag
,
size_t
size
,
const
session
::
KernelWithIndex
&
node_with_index
)
{
MS_EXCEPTION_IF_NULL
(
address
);
auto
context_ptr
=
MsContext
::
GetInstance
();
MS_EXCEPTION_IF_NULL
(
context_ptr
);
uint8_t
*
ptr
=
nullptr
;
if
(
node_with_index
.
first
!=
nullptr
)
{
ptr
=
MallocOutputMem
(
node_with_index
.
first
,
node_with_index
.
second
,
flag
,
size
);
MS_EXCEPTION_IF_NULL
(
ptr
);
if
(
AnfAlgo
::
IsCommunicationOp
(
node_with_index
.
first
)
&&
context_ptr
->
enable_hccl
())
{
address
->
communication_ptr_
=
ptr
-
kMemAlignSize
;
}
}
else
{
ptr
=
MallocMem
(
flag
,
size
);
MS_EXCEPTION_IF_NULL
(
ptr
);
}
address
->
ptr_
=
ptr
;
if
(
flag
==
kStaticMem
)
{
address
->
from_mem_pool_
=
true
;
}
return
ptr
;
}
...
...
mindspore/ccsrc/runtime/device/memory_manager.h
浏览文件 @
3a61d646
...
...
@@ -41,11 +41,10 @@ class MemoryManager {
}
void
MallocReusedDynamicMem
(
const
session
::
KernelGraph
*
graph
);
uint8_t
*
MallocOutputMem
(
const
AnfNodePtr
&
node
,
size_t
index
,
MemType
type
,
size_t
size
);
uint8_t
*
MallocOutputMem
(
const
AnfNodePtr
&
node
,
size_t
index
,
MemType
type
,
size_t
size
,
const
DeviceAddressPtr
&
address
);
uint8_t
*
MallocWorkSpaceMem
(
const
AnfNodePtr
&
node
,
size_t
index
,
MemType
type
,
size_t
size
);
uint8_t
*
MallocMem
(
const
DeviceAddressPtr
&
address
,
MemType
flag
,
size_t
size
,
const
session
::
KernelWithIndex
&
node_with_index
=
std
::
pair
<
AnfNodePtr
,
size_t
>
(
nullptr
,
0
));
virtual
uint8_t
*
MallocMem
(
MemType
type
,
size_t
size
);
virtual
uint8_t
*
MallocMem
(
MemType
type
,
size_t
size
,
const
DeviceAddressPtr
&
address
);
virtual
bool
MallocMemFromMemPool
(
const
DeviceAddressPtr
address
,
size_t
size
);
virtual
void
*
MallocMemFromMemPool
(
size_t
size
);
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录