Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
efdb4aa2
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
efdb4aa2
编写于
8月 16, 2017
作者:
G
gangliao
提交者:
GitHub
8月 16, 2017
浏览文件
操作
浏览文件
下载
差异文件
Merge pull request #3497 from QiJune/implement_EigenCudaStreamDevice
Implement EigenCudaStreamDevice
上级
c94a324c
13c20ad3
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
81 addition
and
44 deletion
+81
-44
Dockerfile
Dockerfile
+0
-14
cmake/flags.cmake
cmake/flags.cmake
+1
-8
paddle/memory/CMakeLists.txt
paddle/memory/CMakeLists.txt
+1
-1
paddle/memory/memcpy.cc
paddle/memory/memcpy.cc
+0
-2
paddle/platform/CMakeLists.txt
paddle/platform/CMakeLists.txt
+4
-1
paddle/platform/device_context.cc
paddle/platform/device_context.cc
+65
-14
paddle/platform/device_context.h
paddle/platform/device_context.h
+9
-4
paddle/platform/device_context_test.cc
paddle/platform/device_context_test.cc
+1
-0
未找到文件。
Dockerfile
浏览文件 @
efdb4aa2
...
...
@@ -71,20 +71,6 @@ RUN pip install -r /root/requirements.txt
RUN
apt-get
install
-y
libssl-dev libffi-dev
RUN
pip
install
certifi urllib3[secure]
# TODO(qijun) The template library Eigen doesn't work well with GCC 5
# coming with the default Docker image, so we switch to use GCC 4.8
# by default. And I will check Eigen library later.
RUN
ln
-sf
gcc-4.8 /usr/bin/gcc
&&
\
ln
-sf
gcc-ar-4.8 /usr/bin/gcc-ar
&&
\
ln
-sf
gcc-nm-4.8 /usr/bin/gcc-nm
&&
\
ln
-sf
gcc-ranlib-4.8 /usr/bin/gcc-ranlib
&&
\
ln
-sf
gcc-4.8 /usr/bin/x86_64-linux-gnu-gcc
&&
\
ln
-sf
gcc-ar-4.8 /usr/bin/x86_64-linux-gnu-gcc-ar
&&
\
ln
-sf
gcc-nm-4.8 /usr/bin/x86_64-linux-gnu-gcc-nm
&&
\
ln
-sf
gcc-ranlib-4.8 /usr/bin/x86_64-linux-gnu-gcc-ranlib
&&
\
ln
-sf
g++-4.8 /usr/bin/g++
&&
\
ln
-sf
g++-4.8 /usr/bin/x86_64-linux-gnu-g++
# Install woboq_codebrowser to /woboq
RUN
git clone https://github.com/woboq/woboq_codebrowser /woboq
&&
\
...
...
cmake/flags.cmake
浏览文件 @
efdb4aa2
...
...
@@ -9,13 +9,6 @@ function(CheckCompilerCXX11Flag)
if
(
${
CMAKE_CXX_COMPILER_VERSION
}
VERSION_LESS 4.8
)
message
(
FATAL_ERROR
"Unsupported GCC version. GCC >= 4.8 required."
)
endif
()
if
(
NOT ANDROID
)
# TODO(qijun) gcc 4.9 or later versions raise SEGV due to the optimization problem.
# Use Debug mode instead for now.
if
(
CMAKE_CXX_COMPILER_VERSION VERSION_GREATER 4.9 OR CMAKE_CXX_COMPILER_VERSION VERSION_EQUAL 4.9
)
set
(
CMAKE_BUILD_TYPE
"Debug"
CACHE STRING
""
FORCE
)
endif
()
endif
()
elseif
(
CMAKE_CXX_COMPILER_ID STREQUAL
"AppleClang"
OR CMAKE_CXX_COMPILER_ID STREQUAL
"Clang"
)
# cmake >= 3.0 compiler id "AppleClang" on Mac OS X, otherwise "Clang"
# Apple Clang is a different compiler than upstream Clang which havs different version numbers.
...
...
@@ -160,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
LIST
(
APPEND CUDA_NVCC_FLAGS -std=c++11
--default-stream per-thread
)
LIST
(
APPEND CUDA_NVCC_FLAGS -std=c++11
)
LIST
(
APPEND CUDA_NVCC_FLAGS --use_fast_math
)
if
(
CMAKE_BUILD_TYPE STREQUAL
"Debug"
)
...
...
paddle/memory/CMakeLists.txt
浏览文件 @
efdb4aa2
add_subdirectory
(
detail
)
cc_library
(
memory SRCS memory.cc
)
cc_library
(
memcpy SRCS memcpy.cc
DEPS device_context
)
cc_library
(
memcpy SRCS memcpy.cc
)
cc_library
(
paddle_memory
DEPS
...
...
paddle/memory/memcpy.cc
浏览文件 @
efdb4aa2
...
...
@@ -16,8 +16,6 @@ limitations under the License. */
#include <cstring> // for memcpy
#include "paddle/platform/device_context.h"
namespace
paddle
{
namespace
memory
{
...
...
paddle/platform/CMakeLists.txt
浏览文件 @
efdb4aa2
...
...
@@ -16,5 +16,8 @@ ELSE()
set
(
GPU_CTX_DEPS
)
ENDIF
()
cc_library
(
device_context SRCS device_context.cc DEPS place eigen3
${
GPU_CTX_DEPS
}
)
# memcpy deoends on device_context, here add deps individually for
# avoiding cycle dependencies
cc_library
(
device_context SRCS device_context.cc DEPS memory buddy_allocator
system_allocator memory_block meta_data meta_cache place eigen3
${
GPU_CTX_DEPS
}
)
nv_test
(
device_context_test SRCS device_context_test.cc DEPS device_context gpu_info
)
paddle/platform/device_context.cc
浏览文件 @
efdb4aa2
...
...
@@ -10,6 +10,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/platform/device_context.h"
#include "paddle/memory/memory.h"
namespace
paddle
{
namespace
platform
{
...
...
@@ -36,6 +37,59 @@ Place CPUDeviceContext::GetPlace() const { return CPUPlace(); }
#ifndef PADDLE_ONLY_CPU
class
EigenCudaStreamDevice
:
public
Eigen
::
StreamInterface
{
public:
EigenCudaStreamDevice
()
:
scratch_
(
nullptr
),
semaphore_
(
nullptr
)
{
Eigen
::
initializeDeviceProp
();
}
~
EigenCudaStreamDevice
()
override
{}
void
Reinitialize
(
const
cudaStream_t
*
cuda_stream
,
GPUPlace
place
)
{
stream_
=
cuda_stream
;
place_
=
place
;
device_prop_
=
&
Eigen
::
m_deviceProperties
[
place
.
device
];
}
const
cudaStream_t
&
stream
()
const
override
{
return
*
stream_
;
}
const
cudaDeviceProp
&
deviceProperties
()
const
override
{
return
*
device_prop_
;
}
void
*
allocate
(
size_t
num_bytes
)
const
override
{
return
paddle
::
memory
::
Alloc
(
place_
,
num_bytes
);
}
void
deallocate
(
void
*
buffer
)
const
override
{
paddle
::
memory
::
Free
(
place_
,
buffer
);
}
void
*
scratchpad
()
const
override
{
if
(
scratch_
==
NULL
)
{
scratch_
=
allocate
(
Eigen
::
kCudaScratchSize
+
sizeof
(
unsigned
int
));
}
return
scratch_
;
}
unsigned
int
*
semaphore
()
const
override
{
if
(
semaphore_
==
NULL
)
{
char
*
scratch
=
static_cast
<
char
*>
(
scratchpad
())
+
Eigen
::
kCudaScratchSize
;
semaphore_
=
reinterpret_cast
<
unsigned
int
*>
(
scratch
);
PADDLE_ENFORCE
(
cudaMemsetAsync
(
semaphore_
,
0
,
sizeof
(
unsigned
int
),
*
stream_
));
}
return
semaphore_
;
}
private:
GPUPlace
place_
;
const
cudaStream_t
*
stream_
;
// not owned;
const
cudaDeviceProp
*
device_prop_
;
// not owned;
mutable
void
*
scratch_
;
mutable
unsigned
int
*
semaphore_
;
};
template
<
>
Eigen
::
GpuDevice
*
DeviceContext
::
get_eigen_device
<
Eigen
::
GpuDevice
>
()
const
{
return
reinterpret_cast
<
const
CUDADeviceContext
*>
(
this
)
->
eigen_device
();
...
...
@@ -43,19 +97,9 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
CUDADeviceContext
::
CUDADeviceContext
(
GPUPlace
place
)
:
place_
(
place
)
{
SetDeviceId
(
place_
.
device
);
// TODO(qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from
// Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id
// later. Please refer to the implementation of class EigenCudaStreamDevice
// in TensorFlow.
//
// We find that CUDA 7 introduces a new option, the per-thread default stream,
// that has two effects. Please refer to https://devblogs.nvidia.com/
// parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
//
// So, we decide to use default stream and add –default-stream per-thread nvcc
// flag. Than, two threads with two CUDADeviceContexts will run parallelly.
eigen_stream_
.
reset
(
new
Eigen
::
CudaStreamDevice
());
PADDLE_ENFORCE
(
cudaStreamCreate
(
&
stream_
));
eigen_stream_
.
reset
(
new
EigenCudaStreamDevice
());
eigen_stream_
->
Reinitialize
(
&
stream_
,
place
);
eigen_device_
.
reset
(
new
Eigen
::
GpuDevice
(
eigen_stream_
.
get
()));
}
...
...
@@ -75,12 +119,13 @@ CUDADeviceContext::~CUDADeviceContext() {
}
eigen_stream_
.
reset
();
eigen_device_
.
reset
();
PADDLE_ENFORCE
(
cudaStreamDestroy
(
stream_
));
}
Place
CUDADeviceContext
::
GetPlace
()
const
{
return
place_
;
}
void
CUDADeviceContext
::
Wait
()
const
{
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
0
));
PADDLE_ENFORCE
(
cudaStreamSynchronize
(
stream_
));
}
Eigen
::
GpuDevice
*
CUDADeviceContext
::
eigen_device
()
const
{
...
...
@@ -91,6 +136,7 @@ cublasHandle_t CUDADeviceContext::cublas_handle() {
if
(
!
cublas_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cublasCreate
(
&
cublas_handle_
));
PADDLE_ENFORCE
(
dynload
::
cublasSetStream
(
cublas_handle_
,
stream_
));
}
return
cublas_handle_
;
}
...
...
@@ -99,10 +145,13 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if
(
!
cudnn_handle_
)
{
SetDeviceId
(
place_
.
device
);
PADDLE_ENFORCE
(
dynload
::
cudnnCreate
(
&
cudnn_handle_
));
PADDLE_ENFORCE
(
dynload
::
cudnnSetStream
(
cudnn_handle_
,
stream_
));
}
return
cudnn_handle_
;
}
cudaStream_t
CUDADeviceContext
::
stream
()
{
return
stream_
;
}
curandGenerator_t
CUDADeviceContext
::
curand_generator
()
{
if
(
!
curand_generator_
)
{
SetDeviceId
(
place_
.
device
);
...
...
@@ -110,6 +159,8 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
CURAND_RNG_PSEUDO_DEFAULT
));
PADDLE_ENFORCE
(
dynload
::
curandSetPseudoRandomGeneratorSeed
(
curand_generator_
,
seed_
));
PADDLE_ENFORCE
(
dynload
::
curandSetStream
(
curand_generator_
,
stream_
));
}
return
curand_generator_
;
}
...
...
paddle/platform/device_context.h
浏览文件 @
efdb4aa2
...
...
@@ -52,6 +52,7 @@ class CPUDeviceContext : public DeviceContext {
};
#ifndef PADDLE_ONLY_CPU
class
EigenCudaStreamDevice
;
class
CUDADeviceContext
:
public
DeviceContext
{
public:
...
...
@@ -76,6 +77,9 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Return curand handle in the device context. */
curandGenerator_t
curand_generator
();
/*! \brief Return cuda stream in the device context. */
cudaStream_t
stream
();
// clang-format on
private:
...
...
@@ -83,15 +87,16 @@ class CUDADeviceContext : public DeviceContext {
private:
std
::
unique_ptr
<
Eigen
::
GpuDevice
>
eigen_device_
;
std
::
unique_ptr
<
Eigen
::
CudaStreamDevice
>
eigen_stream_
;
std
::
unique_ptr
<
EigenCudaStreamDevice
>
eigen_stream_
;
private:
uint64_t
seed_
;
// clang-format off
cudnnHandle_t
cudnn_handle_
=
nullptr
;
cublasHandle_t
cublas_handle_
=
nullptr
;
curandGenerator_t
curand_generator_
=
nullptr
;
cudaStream_t
stream_
{
nullptr
};
cudnnHandle_t
cudnn_handle_
{
nullptr
};
cublasHandle_t
cublas_handle_
{
nullptr
};
curandGenerator_t
curand_generator_
{
nullptr
};
// clang-format on
};
...
...
paddle/platform/device_context_test.cc
浏览文件 @
efdb4aa2
...
...
@@ -45,6 +45,7 @@ TEST(Device, CUDADeviceContext) {
ASSERT_NE
(
nullptr
,
cublas_handle
);
curandGenerator_t
curand_handle
=
device_context
->
curand_generator
();
ASSERT_NE
(
nullptr
,
curand_handle
);
ASSERT_NE
(
nullptr
,
device_context
->
stream
());
delete
device_context
;
}
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录