Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
PaddleDetection
提交
5364b394
P
PaddleDetection
项目概览
PaddlePaddle
/
PaddleDetection
大约 1 年 前同步成功
通知
694
Star
11112
Fork
2696
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
184
列表
看板
标记
里程碑
合并请求
40
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
PaddleDetection
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
184
Issue
184
列表
看板
标记
里程碑
合并请求
40
合并请求
40
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
提交
5364b394
编写于
7月 28, 2017
作者:
Q
qijun
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
use cuda default stream
上级
d962c2a9
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
19 addition
and
44 deletion
+19
-44
cmake/external/eigen.cmake
cmake/external/eigen.cmake
+1
-10
paddle/framework/detail/tensor-inl.h
paddle/framework/detail/tensor-inl.h
+4
-8
paddle/framework/tensor_test.cc
paddle/framework/tensor_test.cc
+9
-9
paddle/memory/memcpy.cc
paddle/memory/memcpy.cc
+3
-3
paddle/memory/memcpy.h
paddle/memory/memcpy.h
+1
-1
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+1
-8
paddle/platform/device_context.h
paddle/platform/device_context.h
+0
-5
未找到文件。
cmake/external/eigen.cmake
浏览文件 @
5364b394
...
...
@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add
(
extern_eigen3
${
EXTERNAL_PROJECT_LOG_ARGS
}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY
"https://github.com/RLovelett/eigen.git"
GIT_TAG
"
a46d2e7337c4656f00abe54a8115f6d76153a048
"
GIT_TAG
"
master
"
PREFIX
${
EIGEN_SOURCE_DIR
}
UPDATE_COMMAND
""
CONFIGURE_COMMAND
""
...
...
paddle/framework/detail/tensor-inl.h
浏览文件 @
5364b394
...
...
@@ -83,14 +83,13 @@ inline void Tensor::ShareDataWith(const Tensor& src) {
template
<
typename
T
>
inline
void
Tensor
::
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
CPU
DeviceContext
&
ctx
)
{
const
platform
::
CPU
Place
&
dst_place
)
{
src
.
check_memory_size
<
T
>
();
Resize
(
src
.
dims
());
auto
src_place
=
src
.
holder_
->
place
();
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
dst_place
=
ctx
.
GetPlace
();
auto
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
auto
size
=
product
(
src
.
dims_
)
*
sizeof
(
T
);
...
...
@@ -110,26 +109,23 @@ inline void Tensor::CopyFrom(const Tensor& src,
#ifndef PADDLE_ONLY_CPU
template
<
typename
T
>
inline
void
Tensor
::
CopyFrom
(
const
Tensor
&
src
,
const
platform
::
CUDADeviceContext
&
ctx
)
{
const
platform
::
GPUPlace
&
dst_place
)
{
src
.
check_memory_size
<
T
>
();
Resize
(
src
.
dims
());
auto
src_place
=
src
.
holder_
->
place
();
auto
src_ptr
=
static_cast
<
const
void
*>
(
src
.
data
<
T
>
());
auto
dst_place
=
ctx
.
GetPlace
();
auto
dst_ptr
=
static_cast
<
void
*>
(
mutable_data
<
T
>
(
dst_place
));
auto
size
=
product
(
src
.
dims_
)
*
sizeof
(
T
);
if
(
platform
::
is_cpu_place
(
src_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
),
src_ptr
,
size
,
ctx
.
stream
());
boost
::
get
<
platform
::
CPUPlace
>
(
src_place
),
src_ptr
,
size
,
0
);
}
else
if
(
platform
::
is_gpu_place
(
src_place
))
{
memory
::
Copy
(
boost
::
get
<
platform
::
GPUPlace
>
(
dst_place
),
dst_ptr
,
boost
::
get
<
platform
::
GPUPlace
>
(
src_place
),
src_ptr
,
size
,
ctx
.
stream
());
boost
::
get
<
platform
::
GPUPlace
>
(
src_place
),
src_ptr
,
size
,
0
);
}
}
#endif
...
...
paddle/framework/tensor_test.cc
浏览文件 @
5364b394
...
...
@@ -198,8 +198,8 @@ TEST(Tensor, CopyFrom) {
int
arr
[
9
]
=
{
1
,
2
,
3
,
4
,
5
,
6
,
7
,
8
,
9
};
memcpy
(
src_ptr
,
arr
,
9
*
sizeof
(
int
));
auto
*
cpu_ctx
=
new
paddle
::
platform
::
CPUDeviceContext
();
dst_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
*
cpu_
ctx
);
auto
cpu_place
=
new
paddle
::
platform
::
CPUPlace
();
dst_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
*
cpu_
place
);
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
src_ptr
,
dst_ptr
);
...
...
@@ -208,7 +208,7 @@ TEST(Tensor, CopyFrom) {
}
Tensor
slice_tensor
=
src_tensor
.
Slice
<
int
>
(
1
,
2
);
dst_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
*
cpu_
ctx
);
dst_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
*
cpu_
place
);
const
int
*
slice_ptr
=
slice_tensor
.
data
<
int
>
();
dst_ptr
=
dst_tensor
.
data
<
int
>
();
ASSERT_NE
(
dst_ptr
,
slice_ptr
);
...
...
@@ -228,12 +228,12 @@ TEST(Tensor, CopyFrom) {
memcpy
(
src_ptr
,
arr
,
9
*
sizeof
(
int
));
// CPU Tensor to GPU Tensor
auto
gpu_
ctx
=
new
paddle
::
platform
::
CUDADeviceContext
(
0
);
gpu_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
*
gpu_
ctx
);
auto
gpu_
place
=
new
paddle
::
platform
::
GPUPlace
(
0
);
gpu_tensor
.
CopyFrom
<
int
>
(
src_tensor
,
*
gpu_
place
);
// GPU Tensor to CPU Tensor
auto
cpu_
ctx
=
new
paddle
::
platform
::
CPUDeviceContext
();
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_
ctx
);
auto
cpu_
place
=
new
paddle
::
platform
::
CPUPlace
();
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_
place
);
// Compare Tensors
const
int
*
dst_ptr
=
dst_tensor
.
data
<
int
>
();
...
...
@@ -245,10 +245,10 @@ TEST(Tensor, CopyFrom) {
Tensor
slice_tensor
=
src_tensor
.
Slice
<
int
>
(
1
,
2
);
// CPU Slice Tensor to GPU Tensor
gpu_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
*
gpu_
ctx
);
gpu_tensor
.
CopyFrom
<
int
>
(
slice_tensor
,
*
gpu_
place
);
// GPU Tensor to CPU Tensor
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_
ctx
);
dst_tensor
.
CopyFrom
<
int
>
(
gpu_tensor
,
*
cpu_
place
);
// Compare Slice Tensors
const
int
*
slice_ptr
=
slice_tensor
.
data
<
int
>
();
...
...
paddle/memory/memcpy.cc
浏览文件 @
5364b394
...
...
@@ -34,7 +34,7 @@ void Copy<platform::CPUPlace, platform::GPUPlace>(platform::CPUPlace dst_place,
void
*
dst
,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
cudaStream_t
stream
=
0
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToHost
,
stream
);
}
...
...
@@ -44,7 +44,7 @@ void Copy<platform::GPUPlace, platform::CPUPlace>(platform::GPUPlace dst_place,
void
*
dst
,
platform
::
CPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
cudaStream_t
stream
=
0
)
{
platform
::
SetDeviceId
(
dst_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyHostToDevice
,
stream
);
}
...
...
@@ -54,7 +54,7 @@ void Copy<platform::GPUPlace, platform::GPUPlace>(platform::GPUPlace dst_place,
void
*
dst
,
platform
::
GPUPlace
src_place
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
)
{
cudaStream_t
stream
=
0
)
{
if
(
dst_place
==
src_place
)
{
platform
::
SetDeviceId
(
src_place
.
device
);
platform
::
GpuMemcpyAsync
(
dst
,
src
,
num
,
cudaMemcpyDeviceToDevice
,
stream
);
...
...
paddle/memory/memcpy.h
浏览文件 @
5364b394
...
...
@@ -51,7 +51,7 @@ void Copy(DstPlace, void* dst, SrcPlace, const void* src, size_t num);
*/
template
<
typename
DstPlace
,
typename
SrcPlace
>
void
Copy
(
DstPlace
,
void
*
dst
,
SrcPlace
,
const
void
*
src
,
size_t
num
,
cudaStream_t
stream
);
cudaStream_t
stream
=
0
);
#endif // PADDLE_ONLY_CPU
...
...
paddle/platform/device_context.cc
浏览文件 @
5364b394
...
...
@@ -43,7 +43,6 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
CUDADeviceContext
::
CUDADeviceContext
(
GPUPlace
place
)
:
place_
(
place
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
// TODO (qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from
// Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id
...
...
@@ -76,15 +75,12 @@ CUDADeviceContext::~CUDADeviceContext() {
}
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
}
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
const
{
return
stream_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
0
));
}
Eigen
::
GpuDevice
*
CUDADeviceContext
::
eigen_device
()
const
{
...
...
@@ -95,7 +91,6 @@ cublasHandle_t CUDADeviceContext::cublas_handle() {
if
(
!
cublas_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
}
return
cublas_handle_
;
}
...
...
@@ -104,7 +99,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if
(
!
cudnn_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
return
cudnn_handle_
;
}
...
...
@@ -116,7 +110,6 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
...
...
paddle/platform/device_context.h
浏览文件 @
5364b394
...
...
@@ -61,9 +61,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Wait for all operations completion in the stream. */
void
Wait
()
const
;
/*! \brief Return CUDA stream in the device context. */
cudaStream_t
stream
()
const
;
/*! \brief Return place in the device context. */
Place
GetPlace
()
const
override
;
...
...
@@ -91,8 +88,6 @@ class CUDADeviceContext : public DeviceContext {
private:
uint64_t
seed_
;
cudaStream_t
stream_
;
// clang-format off
cudnnHandle_t
cudnn_handle_
=
nullptr
;
cublasHandle_t
cublas_handle_
=
nullptr
;
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录