Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
63320f72
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
63320f72
编写于
2月 05, 2018
作者:
D
dzhwinter
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"add some interfaces"
上级
6f28084b
变更
3
隐藏空白更改
内联
并排
Showing
3 changed file
with
103 addition
and
39 deletion
+103
-39
paddle/framework/lod_tensor.h
paddle/framework/lod_tensor.h
+21
-1
paddle/framework/mixed_vector.h
paddle/framework/mixed_vector.h
+64
-38
paddle/memory/memory.h
paddle/memory/memory.h
+18
-0
未找到文件。
paddle/framework/lod_tensor.h
浏览文件 @
63320f72
...
@@ -48,12 +48,26 @@ namespace framework {
...
@@ -48,12 +48,26 @@ namespace framework {
*/
*/
struct
LoD
:
public
std
::
vector
<
Vector
<
size_t
>>
{
struct
LoD
:
public
std
::
vector
<
Vector
<
size_t
>>
{
using
std
::
vector
<
Vector
<
size_t
>>::
vector
;
using
std
::
vector
<
Vector
<
size_t
>>::
vector
;
platform
::
Place
place
()
const
{
if
(
this
->
size
()
==
0
)
{
// Not Initialze Yet.
return
platform
::
CPUPlace
();
}
else
{
return
this
->
front
().
place
();
}
}
void
CopyFromCUDA
()
{
void
CopyFromCUDA
()
{
for
(
auto
it
=
this
->
begin
();
it
!=
this
->
end
();
++
it
)
{
for
(
auto
it
=
this
->
begin
();
it
!=
this
->
end
();
++
it
)
{
it
->
CopyFromCUDA
();
it
->
CopyFromCUDA
();
}
}
}
}
void
CopyToPeer
(
platform
::
Place
place
)
{
for
(
auto
it
=
this
->
begin
();
it
!=
this
->
end
();
++
it
)
{
it
->
mutable_data
(
place
);
}
}
};
};
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoD
&
lod
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoD
&
lod
);
...
@@ -115,7 +129,13 @@ class LoDTensor : public Tensor {
...
@@ -115,7 +129,13 @@ class LoDTensor : public Tensor {
explicit
LoDTensor
(
const
LoD
&
lod
)
:
lod_
(
lod
)
{}
explicit
LoDTensor
(
const
LoD
&
lod
)
:
lod_
(
lod
)
{}
void
set_lod
(
const
LoD
&
lod
)
{
lod_
=
lod
;
}
void
set_lod
(
const
LoD
&
lod
)
{
lod_
=
lod
;
if
(
holder_
!=
nullptr
&&
platform
::
is_same_place
(
holder_
->
place
(),
lod
.
place
()))
{
lod_
.
CopyToPeer
(
holder_
->
place
());
}
}
const
LoD
&
lod
()
const
{
return
lod_
;
}
const
LoD
&
lod
()
const
{
return
lod_
;
}
...
...
paddle/framework/mixed_vector.h
浏览文件 @
63320f72
...
@@ -40,14 +40,15 @@ class Vector : public std::vector<T> {
...
@@ -40,14 +40,15 @@ class Vector : public std::vector<T> {
Vector
()
{}
Vector
()
{}
Vector
(
const
std
::
vector
<
T
>
&
v
)
:
std
::
vector
<
T
>
(
v
)
{}
// NOLINT
Vector
(
const
std
::
vector
<
T
>
&
v
)
:
std
::
vector
<
T
>
(
v
)
{}
// NOLINT
virtual
~
Vector
()
{
inline
platform
::
Place
place
()
const
{
return
place_
;
}
#ifdef PADDLE_WITH_CUDA
if
(
cuda_ptr_
!=
nullptr
)
{
memory
::
Free
<
platform
::
CUDAPlace
>
(
place_
,
cuda_ptr_
);
}
#endif
}
/*! Return a pointer to constant memory block. */
inline
const
T
*
data
(
platform
::
Place
place
)
const
;
/*! Return a pointer to mutable memory block. */
inline
T
*
mutable_data
(
platform
::
Place
place
);
// TODO(dzhwinter): below interfaces should be removed
/* Get device vector */
/* Get device vector */
T
*
cuda_data
()
{
T
*
cuda_data
()
{
CopyToCUDA
();
CopyToCUDA
();
...
@@ -68,25 +69,71 @@ class Vector : public std::vector<T> {
...
@@ -68,25 +69,71 @@ class Vector : public std::vector<T> {
void
CopyToPeer
(
platform
::
Place
);
void
CopyToPeer
(
platform
::
Place
);
private:
private:
void
*
cuda_ptr_
=
nullptr
;
std
::
shared_ptr
<
void
>
cuda_ptr_
;
size_t
cuda_size_
=
0
;
// device vector numel
size_t
cuda_size_
=
0
;
// device vector numel
platform
::
CUDAPlace
place_
;
platform
::
CUDAPlace
place_
;
};
};
template
<
typename
T
>
template
<
typename
T
>
void
Vector
<
T
>::
CopyToCUDA
()
{
inline
const
T
*
Vector
<
T
>::
data
(
platform
::
Place
place
)
const
{
if
(
platform
::
is_cpu_place
(
place
))
{
return
std
::
vector
<
T
>::
data
();
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
cuda_ptr_
==
nullptr
)
{
return
nullptr
;
}
if
(
platform
::
is_same_place
(
place
,
place_
))
{
return
static_cast
<
const
T
*>
(
cuda_ptr_
.
get
());
}
else
{
PADDLE_THROW
(
"Unmatched place. Please use `mutable_data` copy lod to the target "
"Place first."
);
}
}
else
{
PADDLE_THROW
(
"Unsupport Place."
);
}
}
template
<
typename
T
>
inline
T
*
Vector
<
T
>::
mutable_data
(
platform
::
Place
place
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
return
std
::
vector
<
T
>::
data
();
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
!
platform
::
is_same_place
(
place
,
place_
))
{
place_
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
);
}
#ifdef PADDLE_WITH_CUDA
#ifdef PADDLE_WITH_CUDA
if
(
cuda_size_
<
this
->
size
())
{
if
(
cuda_size_
<
this
->
size
()
||
cuda_ptr_
==
nullptr
)
{
if
(
cuda_ptr_
!=
nullptr
)
{
cuda_ptr_
.
reset
(
memory
::
Free
<
platform
::
CUDAPlace
>
(
place_
,
cuda_ptr_
);
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
place_
,
this
->
size
()
*
sizeof
(
T
)),
memory
::
PlainDeleter
<
void
,
platform
::
CUDAPlace
>
(
place_
));
}
}
cuda_ptr_
=
cuda_size_
=
this
->
size
();
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
place_
,
this
->
size
()
*
sizeof
(
T
));
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
place_
,
cuda_ptr_
.
get
(),
platform
::
CPUPlace
(),
static_cast
<
const
void
*>
(
this
->
data
()),
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
return
static_cast
<
T
*>
(
cuda_ptr_
.
get
());
#endif
}
else
{
PADDLE_THROW
(
"Unsupport Place."
);
}
}
template
<
typename
T
>
void
Vector
<
T
>::
CopyToCUDA
()
{
#ifdef PADDLE_WITH_CUDA
if
(
cuda_size_
<
this
->
size
()
||
cuda_ptr_
==
nullptr
)
{
cuda_ptr_
.
reset
(
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
this
->
size
()
*
sizeof
(
T
)),
memory
::
PlainDeleter
<
void
,
platform
::
CUDAPlace
>
(
place_
));
}
}
cuda_size_
=
this
->
size
();
cuda_size_
=
this
->
size
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
place_
,
cuda_ptr_
,
platform
::
CPUPlace
(),
memory
::
Copy
(
place_
,
cuda_ptr_
.
get
()
,
platform
::
CPUPlace
(),
static_cast
<
const
void
*>
(
this
->
data
()),
static_cast
<
const
void
*>
(
this
->
data
()),
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
ctx
->
Wait
();
...
@@ -104,32 +151,11 @@ void Vector<T>::CopyFromCUDA() {
...
@@ -104,32 +151,11 @@ void Vector<T>::CopyFromCUDA() {
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
platform
::
CPUPlace
(),
static_cast
<
void
*>
(
this
->
data
()),
place_
,
memory
::
Copy
(
platform
::
CPUPlace
(),
static_cast
<
void
*>
(
this
->
data
()),
place_
,
static_cast
<
const
void
*>
(
cuda_ptr_
),
this
->
size
()
*
sizeof
(
T
),
static_cast
<
const
void
*>
(
cuda_ptr_
.
get
()),
ctx
->
stream
());
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
#endif
}
template
<
typename
T
>
void
Vector
<
T
>::
CopyToPeer
(
platform
::
Place
peer_place
)
{
#ifdef PADDLE_WITH_CUDA
auto
*
ctx
=
platform
::
DeviceContextPool
::
Instance
().
GetByPlace
(
place_
);
void
*
peer_cuda_ptr
=
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
boost
::
get
<
platform
::
CUDAPlace
>
(
peer_place
),
this
->
size
()
*
sizeof
(
T
));
memory
::
Copy
(
boost
::
get
<
platform
::
CUDAPlace
>
(
peer_place
),
peer_cuda_ptr
,
place_
,
cuda_ptr_
,
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
ctx
->
Wait
();
memory
::
Free
<
platform
::
CUDAPlace
>
(
place_
,
cuda_ptr_
);
place_
=
boost
::
get
<
platform
::
CUDAPlace
>
(
peer_place
);
cuda_ptr_
=
peer_cuda_ptr
;
#endif
#endif
}
}
template
class
Vector
<
int
>;
template
class
Vector
<
unsigned
>;
template
class
Vector
<
size_t
>;
template
class
Vector
<
int64_t
>;
}
// namespace framework
}
// namespace framework
}
// namespace paddle
}
// namespace paddle
paddle/memory/memory.h
浏览文件 @
63320f72
...
@@ -81,5 +81,23 @@ class PODDeleter {
...
@@ -81,5 +81,23 @@ class PODDeleter {
Place
place_
;
Place
place_
;
};
};
/**
* \brief Free memory block in one place does not meet POD
*
* \note In some cases, custom deleter is used to
* deallocate the memory automatically for
* std::unique_ptr<T> in tensor.h.
*
*/
template
<
typename
T
,
typename
Place
>
class
PlainDeleter
{
public:
explicit
PlainDeleter
(
Place
place
)
:
place_
(
place
)
{}
void
operator
()(
T
*
ptr
)
{
Free
(
place_
,
reinterpret_cast
<
void
*>
(
ptr
));
}
private:
Place
place_
;
};
}
// namespace memory
}
// namespace memory
}
// namespace paddle
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录