Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
26cd0bb5
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 2 年 前同步成功
通知
2325
Star
20933
Fork
5424
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
26cd0bb5
编写于
6月 29, 2017
作者:
L
liaogang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
ENH: count allocated fallback size for performance
上级
464886bf
变更
3
显示空白变更内容
内联
并排
Showing
3 changed file
with
39 addition
and
20 deletion
+39
-20
paddle/memory/detail/system_allocator.cc
paddle/memory/detail/system_allocator.cc
+35
-17
paddle/memory/detail/system_allocator.h
paddle/memory/detail/system_allocator.h
+2
-1
python/paddle/trainer_config_helpers/networks.py
python/paddle/trainer_config_helpers/networks.py
+2
-2
未找到文件。
paddle/memory/detail/system_allocator.cc
浏览文件 @
26cd0bb5
...
@@ -39,22 +39,22 @@ void* CPUAllocator::Alloc(size_t& index, size_t size) {
...
@@ -39,22 +39,22 @@ void* CPUAllocator::Alloc(size_t& index, size_t size) {
// pointer shall not be dereferenced -- so we make it nullptr.
// pointer shall not be dereferenced -- so we make it nullptr.
if
(
size
<=
0
)
return
nullptr
;
if
(
size
<=
0
)
return
nullptr
;
if
(
FLAGS_use_pinned_memory
)
{
index
=
0
;
// unlock memory
void
*
p
=
malloc
(
size
);
void
*
p
=
malloc
(
size
);
if
(
p
!=
nullptr
)
{
if
(
p
!=
nullptr
)
{
mlock
(
p
,
size
);
if
(
FLAGS_use_pinned_memory
)
{
index
=
1
;
mlock
(
p
,
size
);
// lock memory
}
}
}
}
void
*
p
=
malloc
(
size
);
if
(
p
!=
nullptr
&&
FLAGS_use_pinned_memory
)
{
mlock
(
p
,
size
);
}
return
p
;
return
p
;
}
}
void
CPUAllocator
::
Free
(
void
*
p
,
size_t
size
,
size_t
index
)
{
void
CPUAllocator
::
Free
(
void
*
p
,
size_t
size
,
size_t
index
)
{
if
(
p
!=
nullptr
&&
FLAGS_use_pinned_memory
)
{
if
(
p
!=
nullptr
&&
index
==
1
)
{
munlock
(
p
,
size
);
munlock
(
p
,
size
);
}
}
free
(
p
);
free
(
p
);
...
@@ -73,26 +73,34 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
...
@@ -73,26 +73,34 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
// Reserve memory for page tables, etc.
// Reserve memory for page tables, etc.
size_t
reserving
=
capacity
-
paddle
::
platform
::
GpuMaxAllocSize
();
size_t
reserving
=
capacity
-
paddle
::
platform
::
GpuMaxAllocSize
();
size_t
remaining
=
available
>
reserving
?
available
-
reserving
:
0
;
size_t
usable
=
available
>
reserving
?
available
-
reserving
:
0
;
// If remaining size no less than expected size, using general
// If remaining size no less than expected size, using general
// cudaMalloc to allocate GPU memory.
// cudaMalloc to allocate GPU memory.
void
*
p
=
0
;
void
*
p
=
0
;
if
(
size
<=
remaining
)
{
if
(
size
<=
usable
)
{
cudaError_t
result
=
cudaMalloc
(
&
p
,
size
);
cudaError_t
result
=
cudaMalloc
(
&
p
,
size
);
if
(
result
==
cudaSuccess
)
{
if
(
result
==
cudaSuccess
)
{
index
=
0
;
index
=
0
;
total
_alloc_size_
+=
size
;
gpu
_alloc_size_
+=
size
;
return
p
;
return
p
;
}
}
}
}
// If remaining size less than expected size or cudaMalloc failed,
// If remaining size less than expected size or cudaMalloc failed,
// cudaMallocHost will be considered as a fallback allocator.
// cudaMallocHost will be considered as a fallback allocator.
//
// NOTE: here, we use GpuMaxAllocSize() as the maximum memory size
// of host fallback allocation. Allocates too much would reduce
// the amount of memory available to the underlying system for paging.
usable
=
paddle
::
platform
::
GpuMaxAllocSize
()
-
fallback_alloc_size_
;
if
(
size
>
usable
)
return
nullptr
;
cudaError_t
result
=
cudaMallocHost
(
&
p
,
size
);
cudaError_t
result
=
cudaMallocHost
(
&
p
,
size
);
if
(
result
==
cudaSuccess
)
{
if
(
result
==
cudaSuccess
)
{
index
=
1
;
index
=
1
;
total
_alloc_size_
+=
size
;
fallback
_alloc_size_
+=
size
;
return
p
;
return
p
;
}
}
...
@@ -100,16 +108,26 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
...
@@ -100,16 +108,26 @@ void* GPUAllocator::Alloc(size_t& index, size_t size) {
}
}
void
GPUAllocator
::
Free
(
void
*
p
,
size_t
size
,
size_t
index
)
{
void
GPUAllocator
::
Free
(
void
*
p
,
size_t
size
,
size_t
index
)
{
cudaError_t
err
;
if
(
index
==
0
)
{
PADDLE_ASSERT
(
gpu_alloc_size_
>=
size
);
gpu_alloc_size_
-=
size
;
err
=
cudaFree
(
p
);
}
else
{
PADDLE_ASSERT
(
fallback_alloc_size_
>=
size
);
fallback_alloc_size_
-=
size
;
err
=
cudaFreeHost
(
p
);
}
// Purposefully allow cudaErrorCudartUnloading, because
// Purposefully allow cudaErrorCudartUnloading, because
// that is returned if you ever call cudaFree after the
// that is returned if you ever call cudaFree after the
// driver has already shutdown. This happens only if the
// driver has already shutdown. This happens only if the
// process is terminating, in which case we don't care if
// process is terminating, in which case we don't care if
// cudaFree succeeds.
// cudaFree succeeds.
PADDLE_ASSERT
(
total_alloc_size_
>=
size
);
total_alloc_size_
-=
size
;
cudaError_t
err
=
index
==
1
?
cudaFreeHost
(
p
)
:
cudaFree
(
p
);
if
(
err
!=
cudaErrorCudartUnloading
)
{
if
(
err
!=
cudaErrorCudartUnloading
)
{
platform
::
throw_on_error
(
err
,
"cudaFree{Host} failed"
);
platform
::
throw_on_error
(
err
,
"cudaFree{Host} failed in GPUAllocator::Free."
);
}
}
}
}
...
...
paddle/memory/detail/system_allocator.h
浏览文件 @
26cd0bb5
...
@@ -47,7 +47,8 @@ class GPUAllocator : public SystemAllocator {
...
@@ -47,7 +47,8 @@ class GPUAllocator : public SystemAllocator {
virtual
void
Free
(
void
*
p
,
size_t
size
,
size_t
index
);
virtual
void
Free
(
void
*
p
,
size_t
size
,
size_t
index
);
private:
private:
size_t
total_alloc_size_
=
0
;
size_t
gpu_alloc_size_
=
0
;
size_t
fallback_alloc_size_
=
0
;
};
};
#endif // PADDLE_ONLY_CPU
#endif // PADDLE_ONLY_CPU
...
...
python/paddle/trainer_config_helpers/networks.py
浏览文件 @
26cd0bb5
...
@@ -1381,7 +1381,7 @@ def inputs(layers, *args):
...
@@ -1381,7 +1381,7 @@ def inputs(layers, *args):
if
len
(
args
)
!=
0
:
if
len
(
args
)
!=
0
:
layers
.
extend
(
args
)
layers
.
extend
(
args
)
Inputs
(
*
[
l
.
name
for
l
in
layers
])
Inputs
(
*
[
l
.
name
for
l
in
layers
])
def
outputs
(
layers
,
*
args
):
def
outputs
(
layers
,
*
args
):
...
@@ -1424,7 +1424,7 @@ def outputs(layers, *args):
...
@@ -1424,7 +1424,7 @@ def outputs(layers, *args):
assert
len
(
layers
)
>
0
assert
len
(
layers
)
>
0
if
HasInputsSet
():
# input already set
if
HasInputsSet
():
# input already set
Outputs
(
*
[
l
.
name
for
l
in
layers
])
Outputs
(
*
[
l
.
name
for
l
in
layers
])
return
# just return outputs.
return
# just return outputs.
if
len
(
layers
)
!=
1
:
if
len
(
layers
)
!=
1
:
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录