Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
机器未来
Paddle
提交
ef1aba39
P
Paddle
项目概览
机器未来
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
ef1aba39
编写于
2月 08, 2018
作者:
Y
Yu Yang
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Rewrite mixed_vector.h
上级
b1869f16
变更
24
隐藏空白更改
内联
并排
Showing
24 changed file
with
316 addition
and
268 deletion
+316
-268
.gitignore
.gitignore
+6
-0
cmake/cuda.cmake
cmake/cuda.cmake
+2
-1
paddle/framework/lod_tensor.h
paddle/framework/lod_tensor.h
+1
-23
paddle/framework/lod_tensor_test.cu
paddle/framework/lod_tensor_test.cu
+4
-5
paddle/framework/mixed_vector.h
paddle/framework/mixed_vector.h
+260
-139
paddle/framework/mixed_vector_test.cu
paddle/framework/mixed_vector_test.cu
+0
-59
paddle/framework/tensor.h
paddle/framework/tensor.h
+4
-0
paddle/framework/tensor_impl.h
paddle/framework/tensor_impl.h
+1
-1
paddle/operators/adagrad_op.cu
paddle/operators/adagrad_op.cu
+3
-3
paddle/operators/adam_op.h
paddle/operators/adam_op.h
+1
-1
paddle/operators/ctc_align_op.cu
paddle/operators/ctc_align_op.cu
+3
-2
paddle/operators/lookup_table_op.cu
paddle/operators/lookup_table_op.cu
+3
-1
paddle/operators/math/selected_rows_functor.cc
paddle/operators/math/selected_rows_functor.cc
+1
-1
paddle/operators/math/selected_rows_functor.cu
paddle/operators/math/selected_rows_functor.cu
+9
-6
paddle/operators/math/sequence2batch.cu
paddle/operators/math/sequence2batch.cu
+2
-2
paddle/operators/math/sequence_padding.cu
paddle/operators/math/sequence_padding.cu
+4
-4
paddle/operators/math/sequence_pooling.cu
paddle/operators/math/sequence_pooling.cu
+2
-1
paddle/operators/math/sequence_scale.cu
paddle/operators/math/sequence_scale.cu
+2
-1
paddle/operators/parallel_do_op.cc
paddle/operators/parallel_do_op.cc
+0
-9
paddle/operators/row_conv_op.cu
paddle/operators/row_conv_op.cu
+2
-2
paddle/operators/sequence_erase_op.cu
paddle/operators/sequence_erase_op.cu
+1
-2
paddle/operators/sgd_op.cu
paddle/operators/sgd_op.cu
+2
-2
paddle/operators/target_assign_op.h
paddle/operators/target_assign_op.h
+2
-2
paddle/testing/paddle_gtest_main.cc
paddle/testing/paddle_gtest_main.cc
+1
-1
未找到文件。
.gitignore
浏览文件 @
ef1aba39
paddle/operators/check_t.save
paddle/operators/check_tensor.ls
paddle/operators/tensor.save
python/paddle/v2/fluid/tests/book/image_classification_resnet.inference.model/
python/paddle/v2/fluid/tests/book/image_classification_vgg.inference.model/
python/paddle/v2/fluid/tests/book/label_semantic_roles.inference.model/
*.DS_Store
build/
build_doc/
...
...
cmake/cuda.cmake
浏览文件 @
ef1aba39
...
...
@@ -181,7 +181,8 @@ elseif(CMAKE_BUILD_TYPE STREQUAL "Release")
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"RelWithDebInfo"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELWITHDEBINFO
}
)
elseif
(
CMAKE_BUILD_TYPE STREQUAL
"MinSizeRel"
)
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_MINSIZEREL
}
)
# nvcc 9 does not support -Os. Use Release flags instead
list
(
APPEND CUDA_NVCC_FLAGS
${
CMAKE_CXX_FLAGS_RELEASE
}
)
endif
()
mark_as_advanced
(
CUDA_BUILD_CUBIN CUDA_BUILD_EMULATION CUDA_VERBOSE_BUILD
)
...
...
paddle/framework/lod_tensor.h
浏览文件 @
ef1aba39
...
...
@@ -46,29 +46,7 @@ namespace framework {
* 0 2 4 7
* 0 2 5 7 10 12 15 20
*/
struct
LoD
:
public
std
::
vector
<
Vector
<
size_t
>>
{
using
std
::
vector
<
Vector
<
size_t
>>::
vector
;
platform
::
Place
place
()
const
{
if
(
this
->
size
()
==
0
)
{
// Not Initialze Yet.
return
platform
::
CPUPlace
();
}
else
{
return
this
->
front
().
place
();
}
}
void
CopyFromCUDA
()
{
for
(
auto
it
=
this
->
begin
();
it
!=
this
->
end
();
++
it
)
{
it
->
CopyFromCUDA
();
}
}
void
CopyToPeer
(
platform
::
Place
place
)
{
for
(
auto
it
=
this
->
begin
();
it
!=
this
->
end
();
++
it
)
{
it
->
CopyToPeer
(
place
);
}
}
};
using
LoD
=
std
::
vector
<
Vector
<
size_t
>>
;
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoD
&
lod
);
std
::
ostream
&
operator
<<
(
std
::
ostream
&
os
,
const
LoDTensor
&
t
);
...
...
paddle/framework/lod_tensor_test.cu
浏览文件 @
ef1aba39
...
...
@@ -20,6 +20,7 @@
#include "paddle/platform/assert.h"
#include <gtest/gtest.h>
#include <paddle/platform/place.h>
__global__
void
test
(
size_t
*
a
,
int
size
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
size
;
...
...
@@ -36,10 +37,9 @@ TEST(LoD, data) {
lod
.
push_back
(
std
::
vector
<
size_t
>
({
0
,
1
,
6
,
8
,
10
,
11
}));
auto
&
v
=
lod
[
0
];
test
<<<
1
,
1
>>>
(
v
.
cuda_data
(),
v
.
size
());
paddle
::
platform
::
CUDAPlace
gpu
(
0
);
test
<<<
1
,
1
>>>
(
v
.
CUDAMutableData
(
gpu
),
v
.
size
());
cudaDeviceSynchronize
();
v
.
CopyFromCUDA
();
for
(
size_t
i
=
0
;
i
<
v
.
size
();
++
i
)
{
EXPECT_EQ
(
v
[
i
],
i
*
2
);
}
...
...
@@ -63,9 +63,8 @@ TEST(LoDTensor, LoDInGPU) {
auto
lod
=
lod_tensor
.
lod
();
test
<<<
1
,
8
>>>
(
lod
[
0
].
cuda_data
(
),
lod
[
0
].
size
());
test
<<<
1
,
8
>>>
(
lod
[
0
].
CUDAMutableData
(
place
),
lod
[
0
].
size
());
cudaDeviceSynchronize
();
lod
.
CopyFromCUDA
();
for
(
size_t
i
=
0
;
i
<
src_lod
[
0
].
size
();
++
i
)
{
EXPECT_EQ
(
lod
[
0
].
data
()[
i
],
src_lod
[
0
].
data
()[
i
]
*
2
);
...
...
paddle/framework/mixed_vector.h
浏览文件 @
ef1aba39
...
...
@@ -17,176 +17,297 @@
#include <initializer_list>
#include <vector>
#include "paddle/memory/memcpy.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/enforce.h"
#include "paddle/platform/place.h"
#include "paddle/framework/tensor.h"
#include "paddle/framework/tensor_util.h"
#include "glog/logging.h"
namespace
paddle
{
namespace
framework
{
/**
* @brief Vector support both cpu and gpu.
* host vector lifetime is same with Vector
* device vector is lazily malloc and modified.
*/
template
<
typename
T
>
class
Vector
:
public
std
::
vector
<
T
>
{
class
Vector
{
public:
using
std
::
vector
<
T
>::
vector
;
using
value_type
=
T
;
Vector
()
{
size_
=
0
;
flag_
=
kDataInCPU
;
}
explicit
Vector
(
size_t
count
,
const
T
&
value
=
T
())
{
resize
(
count
);
T
*
ptr
=
begin
();
for
(
size_t
i
=
0
;
i
<
count
;
++
i
)
{
ptr
[
i
]
=
value
;
}
}
Vector
(
std
::
initializer_list
<
T
>
init
)
{
InitByIter
(
init
.
size
(),
init
.
begin
(),
init
.
end
());
}
template
<
typename
U
>
Vector
(
const
std
::
vector
<
U
>&
dat
)
{
// NOLINT
InitByIter
(
dat
.
size
(),
dat
.
begin
(),
dat
.
end
());
}
Vector
(
const
Vector
<
T
>&
other
)
{
this
->
operator
=
(
other
);
}
Vector
<
T
>&
operator
=
(
const
Vector
<
T
>&
other
)
{
if
(
other
.
size
()
!=
0
)
{
this
->
InitByIter
(
other
.
size
(),
other
.
begin
(),
other
.
end
());
}
else
{
size_
=
0
;
flag_
=
kDataInCPU
;
}
return
*
this
;
}
Vector
(
Vector
<
T
>&&
other
)
{
this
->
size_
=
other
.
size_
;
this
->
flag_
=
other
.
flag_
;
if
(
other
.
cuda_vec_
.
capacity
())
{
this
->
cuda_vec_
.
ShareDataWith
(
other
.
cuda_vec_
);
}
if
(
other
.
cpu_vec_
.
capacity
())
{
this
->
cpu_vec_
.
ShareDataWith
(
other
.
cpu_vec_
);
}
}
Vector
()
{}
Vector
(
const
std
::
vector
<
T
>
&
v
)
:
std
::
vector
<
T
>
(
v
)
{}
// NOLINT
T
&
operator
[](
size_t
i
)
{
MutableCPU
();
return
const_cast
<
T
*>
(
cpu_vec_
.
data
<
T
>
())[
i
];
}
const
T
&
operator
[](
size_t
i
)
const
{
ImmutableCPU
();
return
cpu_vec_
.
data
<
T
>
()[
i
];
}
size_t
size
()
const
{
return
size_
;
}
T
*
begin
()
{
return
&
this
->
operator
[](
0
);
}
T
*
end
()
{
return
&
this
->
operator
[](
size
());
}
T
&
front
()
{
return
*
begin
();
}
T
&
back
()
{
auto
it
=
end
();
--
it
;
return
*
it
;
}
const
T
*
begin
()
const
{
return
&
this
->
operator
[](
0
);
}
const
T
*
end
()
const
{
return
&
this
->
operator
[](
size
());
}
inline
platform
::
Place
place
()
const
{
return
place_
;
}
const
T
&
back
()
const
{
auto
it
=
end
();
--
it
;
return
*
it
;
}
const
T
&
front
()
const
{
return
*
begin
();
}
template
<
typename
Iter
>
void
assign
(
Iter
begin
,
Iter
end
)
{
InitByIter
(
end
-
begin
,
begin
,
end
);
}
T
*
data
()
{
return
begin
();
}
/*! Return a pointer to constant memory block. */
inline
const
T
*
data
(
platform
::
Place
place
)
const
;
const
T
*
data
()
const
{
return
begin
();
}
/*! Return a pointer to mutable memory block. */
inline
T
*
mutable_data
(
platform
::
Place
place
);
void
push_back
(
T
elem
)
{
if
(
size_
+
1
>
capacity
())
{
reserve
((
size_
+
1
)
<<
1
);
}
*
end
()
=
elem
;
++
size_
;
}
// TODO(dzhwinter): below interfaces should be removed
/* Get device vector */
T
*
cuda_data
()
{
CopyToCUDA
();
PADDLE_ENFORCE_NOT_NULL
(
cuda_ptr_
,
"No data or Insufficient CUDA memory to allocation"
);
return
static_cast
<
T
*>
(
cuda_ptr_
.
get
());
void
resize
(
size_t
size
)
{
if
(
size
+
1
<
capacity
())
{
size_
=
size
;
}
else
{
MutableCPU
();
Tensor
cpu_tensor
;
platform
::
Place
cpu
=
platform
::
CPUPlace
();
T
*
ptr
=
cpu_tensor
.
mutable_data
<
T
>
(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
size
)}),
cpu
);
const
T
*
old_ptr
=
cpu_vec_
.
capacity
()
==
0
?
nullptr
:
cpu_vec_
.
data
<
T
>
();
if
(
old_ptr
!=
nullptr
)
{
std
::
copy
(
old_ptr
,
old_ptr
+
size_
,
ptr
);
}
size_
=
size
;
cpu_vec_
.
ShareDataWith
(
cpu_tensor
);
}
}
/* Get host vector */
T
*
data
()
{
return
std
::
vector
<
T
>::
data
();
}
const
T
*
data
()
const
{
return
std
::
vector
<
T
>::
data
();
}
const
T
*
CUDAData
(
platform
::
Place
place
)
const
{
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
place
),
"CUDA Data must on CUDA place"
);
ImmutableCUDA
(
place
);
return
cuda_vec_
.
data
<
T
>
();
}
T
*
data
(
const
platform
::
Place
&
place
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
T
*
CUDAMutableData
(
platform
::
Place
place
)
{
const
T
*
ptr
=
CUDAData
(
place
);
flag_
=
kDirty
|
kDataInCUDA
;
return
const_cast
<
T
*>
(
ptr
);
}
template
<
typename
It
>
void
Extend
(
It
begin
,
It
end
)
{
size_t
pre_size
=
size_
;
resize
(
pre_size
+
(
end
-
begin
));
T
*
ptr
=
this
->
begin
()
+
pre_size
;
for
(;
begin
<
end
;
++
begin
,
++
ptr
)
{
*
ptr
=
*
begin
;
}
}
void
clear
()
{
size_
=
0
;
flag_
=
kDirty
|
kDataInCPU
;
}
size_t
capacity
()
const
{
return
cpu_vec_
.
capacity
()
/
SizeOfType
(
typeid
(
T
));
}
void
reserve
(
size_t
size
)
{
size_t
pre_size
=
size_
;
resize
(
size
);
resize
(
pre_size
);
}
const
T
*
Data
(
platform
::
Place
place
)
const
{
if
(
platform
::
is_gpu_place
(
place
))
{
return
CUDAData
(
place
);
}
else
{
return
data
();
}
}
T
*
MutableData
(
platform
::
Place
place
)
{
if
(
platform
::
is_gpu_place
(
place
))
{
return
CUDAMutableData
(
place
);
}
else
{
return
cuda_
data
();
return
data
();
}
}
/* Synchronize host vector to device vector */
void
CopyToCUDA
();
/* Synchronize device vector to host vector */
void
CopyFromCUDA
();
/* Switch device vector location */
void
CopyToPeer
(
platform
::
Place
);
operator
std
::
vector
<
T
>
()
const
{
std
::
vector
<
T
>
result
;
result
.
resize
(
size
());
std
::
copy
(
begin
(),
end
(),
result
.
begin
());
return
result
;
}
bool
operator
==
(
const
Vector
<
T
>&
other
)
const
{
if
(
size
()
!=
other
.
size
())
return
false
;
for
(
auto
it1
=
begin
(),
it2
=
other
.
begin
();
it1
<
end
();
++
it1
,
++
it2
)
{
if
(
*
it1
!=
*
it2
)
{
return
false
;
}
}
return
true
;
}
private:
std
::
shared_ptr
<
void
>
cuda_ptr_
;
size_t
cuda_size_
=
0
;
// device vector numel
platform
::
CUDAPlace
place_
;
};
template
<
typename
Iter
>
void
InitByIter
(
size_t
size
,
Iter
begin
,
Iter
end
)
{
platform
::
Place
cpu
=
platform
::
CPUPlace
();
T
*
ptr
=
this
->
cpu_vec_
.
template
mutable_data
<
T
>(
framework
::
make_ddim
({
static_cast
<
int64_t
>
(
size
)}),
cpu
);
for
(
size_t
i
=
0
;
i
<
size
;
++
i
)
{
*
ptr
++
=
*
begin
++
;
}
flag_
=
kDataInCPU
|
kDirty
;
size_
=
size
;
}
template
<
typename
T
>
inline
const
T
*
Vector
<
T
>::
data
(
platform
::
Place
place
)
const
{
if
(
platform
::
is_cpu_place
(
place
)
)
{
return
std
::
vector
<
T
>::
data
();
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
cuda_ptr_
==
nullptr
)
{
return
nullptr
;
enum
DataFlag
{
kDataInCPU
=
0x01
,
kDataInCUDA
=
0x02
,
kDirty
=
0x10
};
void
MutableCPU
(
)
{
if
(
IsInCUDA
()
&&
IsDirty
())
{
// COPY GPU Data To CPU
Copy
(
cuda_vec_
,
platform
::
CPUPlace
(),
&
cpu_vec_
);
WaitPlace
(
cuda_vec_
.
place
())
;
}
if
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
)
==
place_
)
{
return
static_cast
<
const
T
*>
(
cuda_ptr_
.
get
());
flag_
=
kDirty
|
kDataInCPU
;
}
void
ImmutableCUDA
(
platform
::
Place
place
)
const
{
if
(
IsDirty
())
{
if
(
IsInCPU
())
{
Copy
(
cpu_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
cuda_vec_
);
WaitPlace
(
place
);
UnsetFlag
(
kDirty
);
SetFlag
(
kDataInCUDA
);
}
else
if
(
IsInCUDA
()
&&
!
(
place
==
cuda_vec_
.
place
()))
{
framework
::
Tensor
tmp
;
Copy
(
cuda_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
tmp
);
WaitPlace
(
cuda_vec_
.
place
());
cuda_vec_
.
ShareDataWith
(
tmp
);
// Still dirty
}
else
{
// Dirty && DataInCUDA && Device is same
// Do nothing
}
}
else
{
PADDLE_THROW
(
"Unmatched place. Please use `mutable_data` copy lod to the target "
"Place first."
);
if
(
!
IsInCUDA
())
{
// Even data is not dirty. However, data is not in CUDA. Copy data.
Copy
(
cpu_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
cuda_vec_
);
WaitPlace
(
place
);
SetFlag
(
kDataInCUDA
);
}
else
if
(
!
(
place
==
cuda_vec_
.
place
()))
{
framework
::
Tensor
tmp
;
Copy
(
cuda_vec_
,
boost
::
get
<
platform
::
CUDAPlace
>
(
place
),
&
tmp
);
WaitPlace
(
cuda_vec_
.
place
());
cuda_vec_
.
ShareDataWith
(
tmp
);
}
else
{
// Not Dirty && DataInCUDA && Device is same
// Do nothing.
}
}
}
else
{
PADDLE_THROW
(
"Unsupport Place."
);
}
}
template
<
typename
T
>
inline
T
*
Vector
<
T
>::
mutable_data
(
platform
::
Place
place
)
{
if
(
platform
::
is_cpu_place
(
place
))
{
return
std
::
vector
<
T
>::
data
();
}
else
if
(
platform
::
is_gpu_place
(
place
))
{
if
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
)
!=
place_
)
{
place_
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
);
void
ImmutableCPU
()
const
{
if
(
IsDirty
()
&&
!
IsInCPU
())
{
// If data has been changed in CUDA, or CPU has no data.
Copy
(
cuda_vec_
,
platform
::
CPUPlace
(),
&
cpu_vec_
);
WaitPlace
(
cuda_vec_
.
place
());
UnsetFlag
(
kDirty
);
}
#ifdef PADDLE_WITH_CUDA
if
(
cuda_size_
<
this
->
size
()
||
cuda_ptr_
==
nullptr
)
{
cuda_ptr_
.
reset
(
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
place_
,
this
->
size
()
*
sizeof
(
T
)),
memory
::
PlainDeleter
<
void
,
platform
::
CUDAPlace
>
(
place_
));
}
cuda_size_
=
this
->
size
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
place_
,
cuda_ptr_
.
get
(),
platform
::
CPUPlace
(),
static_cast
<
const
void
*>
(
this
->
data
()),
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
return
static_cast
<
T
*>
(
cuda_ptr_
.
get
());
#else
return
nullptr
;
#endif
}
else
{
PADDLE_THROW
(
"Unsupport Place."
);
}
}
SetFlag
(
kDataInCPU
);
}
template
<
typename
T
>
void
Vector
<
T
>::
CopyToCUDA
()
{
#ifdef PADDLE_WITH_CUDA
if
(
cuda_size_
<
this
->
size
()
||
cuda_ptr_
==
nullptr
)
{
cuda_ptr_
.
reset
(
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
place_
,
this
->
size
()
*
sizeof
(
T
)),
memory
::
PlainDeleter
<
void
,
platform
::
CUDAPlace
>
(
place_
));
}
cuda_size_
=
this
->
size
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
place_
,
cuda_ptr_
.
get
(),
platform
::
CPUPlace
(),
static_cast
<
const
void
*>
(
this
->
data
()),
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
#endif
}
void
UnsetFlag
(
int
flag
)
const
{
flag_
&=
~
flag
;
}
void
SetFlag
(
int
flag
)
const
{
flag_
|=
flag
;
}
template
<
typename
T
>
void
Vector
<
T
>::
CopyFromCUDA
()
{
#ifdef PADDLE_WITH_CUDA
if
(
cuda_ptr_
==
nullptr
)
{
LOG
(
WARNING
)
<<
"No uncommitted cuda data."
;
return
;
}
this
->
resize
(
cuda_size_
);
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
platform
::
CPUPlace
(),
static_cast
<
void
*>
(
this
->
data
()),
place_
,
static_cast
<
const
void
*>
(
cuda_ptr_
.
get
()),
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
#endif
}
bool
IsDirty
()
const
{
return
flag_
&
kDirty
;
}
template
<
typename
T
>
void
Vector
<
T
>::
CopyToPeer
(
platform
::
Place
place
)
{
#ifdef PADDLE_WITH_CUDA
if
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
)
!=
place_
)
{
place_
=
boost
::
get
<
platform
::
CUDAPlace
>
(
place
);
}
if
(
cuda_size_
<
this
->
size
()
||
cuda_ptr_
==
nullptr
)
{
cuda_ptr_
.
reset
(
memory
::
Alloc
<
platform
::
CUDAPlace
>
(
place_
,
this
->
size
()
*
sizeof
(
T
)),
memory
::
PlainDeleter
<
void
,
platform
::
CUDAPlace
>
(
place_
));
}
cuda_size_
=
this
->
size
();
platform
::
DeviceContextPool
&
pool
=
platform
::
DeviceContextPool
::
Instance
();
auto
*
ctx
=
pool
.
GetByPlace
(
place_
);
memory
::
Copy
(
place_
,
cuda_ptr_
.
get
(),
platform
::
CPUPlace
(),
static_cast
<
const
void
*>
(
this
->
data
()),
this
->
size
()
*
sizeof
(
T
),
ctx
->
stream
());
ctx
->
Wait
();
#endif
}
bool
IsInCUDA
()
const
{
return
flag_
&
kDataInCUDA
;
}
bool
IsInCPU
()
const
{
return
flag_
&
kDataInCPU
;
}
static
void
WaitPlace
(
const
platform
::
Place
place
)
{
if
(
platform
::
is_gpu_place
(
place
))
{
platform
::
DeviceContextPool
::
Instance
()
.
Get
(
boost
::
get
<
platform
::
CUDAPlace
>
(
place
))
->
Wait
();
}
}
mutable
int
flag_
;
mutable
Tensor
cpu_vec_
;
mutable
Tensor
cuda_vec_
;
size_t
size_
;
};
}
// namespace framework
}
// namespace paddle
paddle/framework/mixed_vector_test.cu
浏览文件 @
ef1aba39
...
...
@@ -11,62 +11,3 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cuda.h>
#include <cuda_runtime.h>
#include "gtest/gtest.h"
#include "paddle/framework/init.h"
#include "paddle/framework/mixed_vector.h"
using
namespace
paddle
::
framework
;
using
namespace
paddle
::
platform
;
using
namespace
paddle
::
memory
;
template
<
typename
T
>
__global__
void
test
(
T
*
data
,
int
size
)
{
for
(
int
i
=
blockIdx
.
x
*
blockDim
.
x
+
threadIdx
.
x
;
i
<
size
;
i
+=
blockDim
.
x
*
gridDim
.
x
)
{
data
[
i
]
*=
2
;
}
}
TEST
(
Vector
,
Normal
)
{
// fill the device context pool.
InitDevices
();
Vector
<
size_t
>
vec
({
1
,
2
,
3
});
size_t
*
ptr
=
vec
.
data
();
for
(
size_t
i
=
0
;
i
<
vec
.
size
();
++
i
)
{
EXPECT_EQ
(
vec
[
i
],
*
(
ptr
+
i
));
}
vec
.
clear
();
vec
.
CopyFromCUDA
();
std
::
vector
<
size_t
>
v
=
{
1
,
2
,
3
};
for
(
size_t
i
=
0
;
i
<
v
.
size
();
++
i
)
{
EXPECT_EQ
(
v
[
i
],
vec
[
i
]);
}
}
TEST
(
Vector
,
MultipleCopy
)
{
InitDevices
();
Vector
<
size_t
>
vec
({
1
,
2
,
3
});
CUDAPlace
place
(
0
);
vec
.
mutable_data
(
place
);
auto
vec2
=
Vector
<
size_t
>
(
vec
);
{
const
size_t
*
ptr
=
vec2
.
data
(
CPUPlace
());
for
(
size_t
i
=
0
;
i
<
vec2
.
size
();
++
i
)
{
EXPECT_EQ
(
*
(
ptr
+
i
),
vec
[
i
]);
}
}
test
<
size_t
><<<
3
,
3
>>>
(
vec2
.
mutable_data
(
place
),
vec2
.
size
());
vec2
.
CopyFromCUDA
();
{
const
size_t
*
ptr
=
vec2
.
data
(
CPUPlace
());
for
(
size_t
i
=
0
;
i
<
vec2
.
size
();
++
i
)
{
EXPECT_EQ
(
*
(
ptr
+
i
),
vec
[
i
]
*
2
);
}
}
}
paddle/framework/tensor.h
浏览文件 @
ef1aba39
...
...
@@ -128,6 +128,10 @@ class Tensor {
inline
void
set_layout
(
const
DataLayout
layout
)
{
layout_
=
layout
;
}
size_t
capacity
()
const
{
return
holder_
==
nullptr
?
0UL
:
holder_
->
size
()
-
offset_
;
}
private:
friend
class
LoDTensor
;
...
...
paddle/framework/tensor_impl.h
浏览文件 @
ef1aba39
...
...
@@ -52,7 +52,7 @@ struct SizeOfTypeFunctor<HEAD, TAIL...> {
};
static
inline
size_t
SizeOfType
(
std
::
type_index
type
)
{
SizeOfTypeFunctor
<
int
,
float
,
double
,
int16_t
,
int64_t
,
bool
>
functor
;
SizeOfTypeFunctor
<
int
,
float
,
double
,
int16_t
,
int64_t
,
bool
,
size_t
>
functor
;
size_t
size
=
functor
(
type
);
PADDLE_ENFORCE
(
size
!=
0UL
,
"Cannot get size of type %s"
,
type
.
name
());
return
size
;
...
...
paddle/operators/adagrad_op.cu
浏览文件 @
ef1aba39
...
...
@@ -101,9 +101,9 @@ struct SparseAdagradFunctor<platform::CUDADeviceContext, T> {
SparseAdagradFunctorKernel
<
T
,
256
><<<
grid2
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
grad_merge_data
,
merge_rows
.
cuda_data
(),
lr
,
param_data
,
moment_data
,
grad_width
,
epsilon
);
.
stream
()
>>>
(
grad_merge_data
,
merge_rows
.
CUDAMutableData
(
context
.
GetPlace
()),
lr
,
param_data
,
moment_data
,
grad_width
,
epsilon
);
}
};
...
...
paddle/operators/adam_op.h
浏览文件 @
ef1aba39
...
...
@@ -201,7 +201,7 @@ class AdamOpKernel : public framework::OpKernel<T> {
const
T
*
grad_data
=
grad_tensor
.
template
data
<
T
>();
int64_t
*
rows
=
nullptr
;
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
rows
=
grad_merge
.
mutable_rows
()
->
cuda_data
(
);
rows
=
grad_merge
.
mutable_rows
()
->
CUDAMutableData
(
ctx
.
GetPlace
()
);
}
else
{
rows
=
grad_merge
.
mutable_rows
()
->
data
();
}
...
...
paddle/operators/ctc_align_op.cu
浏览文件 @
ef1aba39
...
...
@@ -69,8 +69,9 @@ class CTCAlignOpCUDAKernel : public framework::OpKernel<T> {
auto
stream
=
ctx
.
cuda_device_context
().
stream
();
MergeAndDelCudaKernel
<
T
><<<
1
,
1
,
0
,
stream
>>>
(
num_tokens
,
tokens
,
num_seq
,
input_lod
[
level
].
cuda_data
(),
blank
,
merge_repeated
,
dev_out_lod0_ptr
,
output_data
);
num_tokens
,
tokens
,
num_seq
,
input_lod
[
level
].
CUDAMutableData
(
ctx
.
GetPlace
()),
blank
,
merge_repeated
,
dev_out_lod0_ptr
,
output_data
);
// set output lod
std
::
vector
<
size_t
>
host_out_lod0
(
dev_out_lod0
.
begin
(),
dev_out_lod0
.
end
());
...
...
paddle/operators/lookup_table_op.cu
浏览文件 @
ef1aba39
...
...
@@ -125,7 +125,9 @@ class LookupTableGradCUDAKernel : public framework::OpKernel<T> {
new_rows
.
resize
(
ids_dim
[
0
]);
auto
gpu_place
=
boost
::
get
<
platform
::
CUDAPlace
>
(
context
.
GetPlace
());
memory
::
Copy
(
platform
::
CPUPlace
(),
new_rows
.
cuda_data
(),
gpu_place
,
// TODO(yuyang18): Strange code here.
memory
::
Copy
(
platform
::
CPUPlace
(),
new_rows
.
CUDAMutableData
(
context
.
GetPlace
()),
gpu_place
,
ids_data
,
ids_dim
[
0
]
*
sizeof
(
int64_t
),
stream
);
d_table
->
set_rows
(
new_rows
);
...
...
paddle/operators/math/selected_rows_functor.cc
浏览文件 @
ef1aba39
...
...
@@ -128,7 +128,7 @@ struct SelectedRowsAddTo<platform::CPUDeviceContext, T> {
auto
*
in2_value
=
input2
->
mutable_value
();
// concat rows
in2_rows
.
insert
(
in2_rows
.
end
(),
in1_rows
.
begin
(),
in1_rows
.
end
());
in2_rows
.
Extend
(
in1_rows
.
begin
(),
in1_rows
.
end
());
auto
in1_place
=
input1
.
place
();
PADDLE_ENFORCE
(
platform
::
is_cpu_place
(
in1_place
));
...
...
paddle/operators/math/selected_rows_functor.cu
浏览文件 @
ef1aba39
...
...
@@ -126,7 +126,8 @@ struct SelectedRowsAddTensor<platform::CUDADeviceContext, T> {
dim3
grid
(
1
,
in1_rows
.
size
());
SelectedRowsAddTensorKernel
<
T
,
block_size
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
in1_rows
.
cuda_data
(),
out_data
,
in1_row_numel
);
in1_data
,
in1_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
in1_row_numel
);
auto
out_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
*
output
);
auto
in2_eigen
=
framework
::
EigenVector
<
T
>::
Flatten
(
input2
);
...
...
@@ -153,7 +154,7 @@ struct SelectedRowsAddTo<platform::CUDADeviceContext, T> {
auto
*
in2_value
=
input2
->
mutable_value
();
// concat rows
in2_rows
.
insert
(
in2_rows
.
end
(),
in1_rows
.
begin
(),
in1_rows
.
end
());
in2_rows
.
Extend
(
in1_rows
.
begin
(),
in1_rows
.
end
());
auto
in1_place
=
input1
.
place
();
PADDLE_ENFORCE
(
platform
::
is_gpu_place
(
in1_place
));
...
...
@@ -216,7 +217,8 @@ struct SelectedRowsAddToTensor<platform::CUDADeviceContext, T> {
dim3
grid
(
1
,
in1_rows
.
size
());
SelectedRowsAddToTensorKernel
<
T
,
block_size
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
in1_data
,
in1_rows
.
cuda_data
(),
in2_data
,
in1_row_numel
);
in1_data
,
in1_rows
.
CUDAData
(
context
.
GetPlace
()),
in2_data
,
in1_row_numel
);
}
};
...
...
@@ -283,9 +285,10 @@ struct MergeAdd<platform::CUDADeviceContext, T> {
MergeAddKernel
<
T
,
256
><<<
grid1
,
threads
,
0
,
reinterpret_cast
<
const
platform
::
CUDADeviceContext
&>
(
context
)
.
stream
()
>>>
(
input_data
,
input_rows
.
cuda_data
(),
out_data
,
out
.
mutable_rows
()
->
cuda_data
(),
out
.
rows
().
size
(),
input_width
);
.
stream
()
>>>
(
input_data
,
input_rows
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
out
.
mutable_rows
()
->
CUDAMutableData
(
context
.
GetPlace
()),
out
.
rows
().
size
(),
input_width
);
return
out
;
}
};
...
...
paddle/operators/math/sequence2batch.cu
浏览文件 @
ef1aba39
...
...
@@ -45,7 +45,6 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
const
framework
::
Tensor
&
src
,
framework
::
Vector
<
size_t
>
index_lod
,
framework
::
Tensor
&
dst
,
bool
is_src_index
)
{
size_t
*
index
=
index_lod
.
cuda_data
();
auto
src_dims
=
src
.
dims
();
auto
dst_dims
=
dst
.
dims
();
PADDLE_ENFORCE_EQ
(
src_dims
.
size
(),
2
,
...
...
@@ -63,7 +62,8 @@ class CopyMatrixRowsFunctor<platform::CUDADeviceContext, T> {
dim3
grid
(
8
,
1
);
auto
stream
=
context
.
stream
();
CopyMatrixRowsKernel
<
T
,
128
,
8
,
8
><<<
grid
,
threads
,
0
,
stream
>>>
(
src_data
,
dst_data
,
index
,
height
,
width
,
is_src_index
);
src_data
,
dst_data
,
index_lod
.
CUDAData
(
context
.
GetPlace
()),
height
,
width
,
is_src_index
);
}
};
...
...
paddle/operators/math/sequence_padding.cu
浏览文件 @
ef1aba39
...
...
@@ -121,12 +121,12 @@ class PaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if
(
norm_by_times
)
{
SequencePaddingKernel
<
T
,
1
,
1
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
padding_data
,
const_cast
<
T
*>
(
seq_data
),
abs_offset_lod
[
level
].
cuda_data
(
),
sequence_width
,
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()
),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
else
{
SequencePaddingKernel
<
T
,
0
,
1
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
padding_data
,
const_cast
<
T
*>
(
seq_data
),
abs_offset_lod
[
level
].
cuda_data
(
),
sequence_width
,
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()
),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
}
...
...
@@ -196,12 +196,12 @@ class UnpaddingLoDTensorFunctor<platform::CUDADeviceContext, T> {
if
(
norm_by_times
)
{
SequencePaddingKernel
<
T
,
1
,
0
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
const_cast
<
T
*>
(
padding_data
),
seq_data
,
abs_offset_lod
[
level
].
cuda_data
(
),
sequence_width
,
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()
),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
else
{
SequencePaddingKernel
<
T
,
0
,
0
><<<
grid
,
threads
,
0
,
context
.
stream
()
>>>
(
const_cast
<
T
*>
(
padding_data
),
seq_data
,
abs_offset_lod
[
level
].
cuda_data
(
),
sequence_width
,
abs_offset_lod
[
level
].
CUDAData
(
context
.
GetPlace
()
),
sequence_width
,
max_sequence_length
,
num_sequences
);
}
}
...
...
paddle/operators/math/sequence_pooling.cu
浏览文件 @
ef1aba39
...
...
@@ -73,7 +73,8 @@ class MaxSeqPoolFunctor<platform::CUDADeviceContext, T> {
dim3
grid
(
num_seq
,
1
);
auto
stream
=
context
.
stream
();
KeMaxSequencePool
<
T
><<<
grid
,
threads
,
0
,
stream
>>>
(
in_data
,
starts
.
cuda_data
(),
out_data
,
max_index
,
num_seq
,
dim
);
in_data
,
starts
.
CUDAData
(
context
.
GetPlace
()),
out_data
,
max_index
,
num_seq
,
dim
);
}
};
...
...
paddle/operators/math/sequence_scale.cu
浏览文件 @
ef1aba39
...
...
@@ -46,7 +46,8 @@ class ScaleLoDTensorFunctor<platform::CUDADeviceContext, T> {
SequenceScaleKernel
<
T
,
PADDLE_CUDA_NUM_THREADS
><<<
num_seq
,
PADDLE_CUDA_NUM_THREADS
,
0
,
context
.
stream
()
>>>
(
seq_data
,
abs_offset_lod
[
level
].
cuda_data
(),
scales
,
seq_width
);
seq_data
,
abs_offset_lod
[
level
].
CUDAMutableData
(
context
.
GetPlace
()),
scales
,
seq_width
);
}
};
...
...
paddle/operators/parallel_do_op.cc
浏览文件 @
ef1aba39
...
...
@@ -79,9 +79,6 @@ inline void CopyOrShare(const framework::Variable &src,
dst
->
GetMutable
<
LoDTensor
>
()
->
set_lod
(
src
.
Get
<
LoDTensor
>
().
lod
());
}
else
{
Copy
(
src
.
Get
<
LoDTensor
>
(),
dst_place
,
dst
->
GetMutable
<
LoDTensor
>
());
framework
::
LoD
lod
(
src
.
Get
<
LoDTensor
>
().
lod
());
lod
.
CopyToPeer
(
dst_place
);
dst
->
GetMutable
<
LoDTensor
>
()
->
set_lod
(
lod
);
}
}
else
if
(
src
.
IsType
<
SelectedRows
>
())
{
auto
&
src_sr
=
src
.
Get
<
SelectedRows
>
();
...
...
@@ -92,9 +89,6 @@ inline void CopyOrShare(const framework::Variable &src,
dst_sr
->
set_rows
(
src_sr
.
rows
());
}
else
{
Copy
(
src_sr
.
value
(),
dst_place
,
dst_sr
->
mutable_value
());
framework
::
Vector
<
int64_t
>
lod
(
src_sr
.
rows
());
lod
.
CopyToPeer
(
dst_place
);
dst_sr
->
set_rows
(
lod
);
}
}
else
{
PADDLE_THROW
(
"Expect LoDTensor/SelectedRows, get %s"
,
src
.
Type
().
name
());
...
...
@@ -152,9 +146,6 @@ class ParallelDoOp : public framework::OperatorBase {
auto
*
sub_scope
=
sub_scopes
[
i
];
auto
*
dst
=
sub_scope
->
Var
(
param
)
->
GetMutable
<
LoDTensor
>
();
framework
::
Copy
(
src
,
place
,
dst
);
framework
::
LoD
lod
(
src
.
lod
());
lod
.
CopyToPeer
(
place
);
dst
->
set_lod
(
lod
);
}
}
WaitOnPlaces
(
places
);
...
...
paddle/operators/row_conv_op.cu
浏览文件 @
ef1aba39
...
...
@@ -307,7 +307,7 @@ class RowConvKernel<platform::CUDADeviceContext, T>
int
input_dim
=
X
->
dims
()[
1
];
int
num_sequence
=
batch_indices
.
size
()
-
1
;
int
future_context
=
Filter
->
dims
()[
0
];
size_t
*
idx
=
batch_indices
.
cuda_data
(
);
size_t
*
idx
=
batch_indices
.
CUDAMutableData
(
context
.
GetPlace
()
);
auto
stream
=
context
.
cuda_device_context
().
stream
();
if
(
future_context
<=
32
)
{
...
...
@@ -345,7 +345,7 @@ class RowConvGradKernel<platform::CUDADeviceContext, T>
int
input_dim
=
X
->
dims
()[
1
];
int
num_sequence
=
batch_indices
.
size
()
-
1
;
int
future_context
=
Filter
->
dims
()[
0
];
size_t
*
idx
=
batch_indices
.
cuda_data
(
);
size_t
*
idx
=
batch_indices
.
CUDAMutableData
(
context
.
GetPlace
()
);
auto
&
device_ctx
=
context
.
cuda_device_context
();
math
::
SetConstant
<
platform
::
CUDADeviceContext
,
T
>
zero
;
...
...
paddle/operators/sequence_erase_op.cu
浏览文件 @
ef1aba39
...
...
@@ -87,8 +87,7 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
// Copy LoD to GPU
auto
lod0
=
lod
[
0
];
auto
lod_len
=
lod0
.
size
();
thrust
::
device_vector
<
size_t
>
dev_in_lod
=
lod0
;
size_t
*
dev_in_lod_ptr
=
thrust
::
raw_pointer_cast
(
dev_in_lod
.
data
());
const
size_t
*
dev_in_lod_ptr
=
lod0
.
CUDAData
(
ctx
.
GetPlace
());
// Calc output LoD
thrust
::
device_vector
<
size_t
>
dev_out_lod
(
lod_len
);
...
...
paddle/operators/sgd_op.cu
浏览文件 @
ef1aba39
...
...
@@ -102,8 +102,8 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
dim3
grid
(
1
,
in_rows
.
size
());
SparseSGDFunctorKernel
<
T
,
256
><<<
grid
,
threads
,
0
,
ctx
.
cuda_device_context
().
stream
()
>>>
(
in_data
,
in_rows
.
cuda_data
(),
learning_rate
->
data
<
T
>
(),
out_data
,
in_row_numel
);
in_data
,
in_rows
.
CUDAData
(
ctx
.
GetPlace
()),
learning_rate
->
data
<
T
>
()
,
out_data
,
in_row_numel
);
}
else
{
PADDLE_THROW
(
"Unsupported Variable Type of Grad"
);
...
...
paddle/operators/target_assign_op.h
浏览文件 @
ef1aba39
...
...
@@ -137,8 +137,8 @@ class TargetAssignKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_EQ
(
gt_lod
.
data
()[
i
],
gt_label_lod
.
data
()[
i
]);
}
size_t
*
gt_lod_data
=
gt_lod
.
d
ata
(
ctx
.
GetPlace
());
size_t
*
neg_lod_data
=
neg_lod
.
d
ata
(
ctx
.
GetPlace
());
size_t
*
gt_lod_data
=
gt_lod
.
MutableD
ata
(
ctx
.
GetPlace
());
size_t
*
neg_lod_data
=
neg_lod
.
MutableD
ata
(
ctx
.
GetPlace
());
TargetAssignFunctor
<
T
>
functor
(
box_data
,
label_data
,
match_idx_data
,
gt_lod_data
,
background_label
,
num
,
...
...
paddle/testing/paddle_gtest_main.cc
浏览文件 @
ef1aba39
...
...
@@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/memory/memory.h"
int
main
(
int
argc
,
char
**
argv
)
{
testing
::
InitGoogleTest
(
&
argc
,
argv
);
std
::
vector
<
char
*>
new_argv
;
std
::
string
gflags_env
;
for
(
int
i
=
0
;
i
<
argc
;
++
i
)
{
...
...
@@ -35,7 +36,6 @@ int main(int argc, char** argv) {
int
new_argc
=
static_cast
<
int
>
(
new_argv
.
size
());
char
**
new_argv_address
=
new_argv
.
data
();
google
::
ParseCommandLineFlags
(
&
new_argc
,
&
new_argv_address
,
false
);
testing
::
InitGoogleTest
(
&
argc
,
argv
);
paddle
::
memory
::
Used
(
paddle
::
platform
::
CPUPlace
());
#ifdef PADDLE_WITH_CUDA
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录