提交 84b368d2 编写于 作者: Q QI JUN 提交者: GitHub

Merge pull request #3101 from QiJune/fix_bug_in_CUDADeviceContext

Fix bug in cuda device context
......@@ -7,17 +7,8 @@ INCLUDE_DIRECTORIES(${EIGEN_SOURCE_DIR}/src/extern_eigen3)
ExternalProject_Add(
extern_eigen3
${EXTERNAL_PROJECT_LOG_ARGS}
# for latest version, please get from official website
# URL "https://bitbucket.org/eigen/eigen/get/3.3.4.tar.gz"
# URL_MD5 "1a47e78efe365a97de0c022d127607c3"
# for no-ssl http support, please get from bazel's mirror
# URL "http://mirror.bazel.build/bitbucket.org/eigen/eigen/get/f3a22f35b044.tar.gz"
# URL_MD5 "4645c66075982da6fa0bcf6b20f3e8f7"
# get from github mirror
GIT_REPOSITORY "https://github.com/RLovelett/eigen.git"
GIT_TAG "a46d2e7337c4656f00abe54a8115f6d76153a048"
GIT_TAG "master"
PREFIX ${EIGEN_SOURCE_DIR}
UPDATE_COMMAND ""
CONFIGURE_COMMAND ""
......
......@@ -153,7 +153,7 @@ set(CUDA_PROPAGATE_HOST_FLAGS OFF)
# Release/Debug flags set by cmake. Such as -O3 -g -DNDEBUG etc.
# So, don't set these flags here.
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11)
LIST(APPEND CUDA_NVCC_FLAGS -std=c++11 --default-stream per-thread)
LIST(APPEND CUDA_NVCC_FLAGS --use_fast_math)
if(CMAKE_BUILD_TYPE STREQUAL "Debug")
......
......@@ -83,56 +83,38 @@ inline void Tensor::ShareDataWith(const Tensor& src) {
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::CPUDeviceContext& ctx) {
const platform::Place& dst_place) {
src.check_memory_size<T>();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = static_cast<const void*>(src.data<T>());
auto dst_place = ctx.GetPlace();
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place)) {
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size);
}
#ifndef PADDLE_ONLY_CPU
else if (platform::is_gpu_place(src_place)) {
else if (platform::is_gpu_place(src_place) &&
platform::is_cpu_place(dst_place)) {
memory::Copy(boost::get<platform::CPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
}
#endif
}
#ifndef PADDLE_ONLY_CPU
template <typename T>
inline void Tensor::CopyFrom(const Tensor& src,
const platform::CUDADeviceContext& ctx) {
src.check_memory_size<T>();
Resize(src.dims());
auto src_place = src.holder_->place();
auto src_ptr = static_cast<const void*>(src.data<T>());
auto dst_place = ctx.GetPlace();
auto dst_ptr = static_cast<void*>(mutable_data<T>(dst_place));
auto size = product(src.dims_) * sizeof(T);
if (platform::is_cpu_place(src_place)) {
} else if (platform::is_cpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::CPUPlace>(src_place), src_ptr, size,
ctx.stream());
} else if (platform::is_gpu_place(src_place)) {
boost::get<platform::CPUPlace>(src_place), src_ptr, size, 0);
} else if (platform::is_gpu_place(src_place) &&
platform::is_gpu_place(dst_place)) {
memory::Copy(boost::get<platform::GPUPlace>(dst_place), dst_ptr,
boost::get<platform::GPUPlace>(src_place), src_ptr, size,
ctx.stream());
boost::get<platform::GPUPlace>(src_place), src_ptr, size, 0);
}
}
#endif
}
template <typename T>
inline Tensor Tensor::Slice(const int& begin_idx, const int& end_idx) const {
......
......@@ -94,14 +94,7 @@ class Tensor {
* @note CopyFrom supports CPU <-> GPU, GPU <-> GPU.
*/
template <typename T>
inline void CopyFrom(const Tensor& src,
const platform::CPUDeviceContext& ctx);
#ifndef PADDLE_ONLY_CPU
template <typename T>
inline void CopyFrom(const Tensor& src,
const platform::CUDADeviceContext& ctx);
#endif
inline void CopyFrom(const Tensor& src, const platform::Place& dst_place);
/**
* @brief Return the slice of the tensor.
......
......@@ -198,8 +198,8 @@ TEST(Tensor, CopyFrom) {
int arr[9] = {1, 2, 3, 4, 5, 6, 7, 8, 9};
memcpy(src_ptr, arr, 9 * sizeof(int));
auto* cpu_ctx = new paddle::platform::CPUDeviceContext();
dst_tensor.CopyFrom<int>(src_tensor, *cpu_ctx);
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom<int>(src_tensor, *cpu_place);
const int* dst_ptr = dst_tensor.data<int>();
ASSERT_NE(src_ptr, dst_ptr);
......@@ -208,7 +208,7 @@ TEST(Tensor, CopyFrom) {
}
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
dst_tensor.CopyFrom<int>(slice_tensor, *cpu_ctx);
dst_tensor.CopyFrom<int>(slice_tensor, *cpu_place);
const int* slice_ptr = slice_tensor.data<int>();
dst_ptr = dst_tensor.data<int>();
ASSERT_NE(dst_ptr, slice_ptr);
......@@ -228,12 +228,12 @@ TEST(Tensor, CopyFrom) {
memcpy(src_ptr, arr, 9 * sizeof(int));
// CPU Tensor to GPU Tensor
auto gpu_ctx = new paddle::platform::CUDADeviceContext(0);
gpu_tensor.CopyFrom<int>(src_tensor, *gpu_ctx);
auto gpu_place = new paddle::platform::GPUPlace(0);
gpu_tensor.CopyFrom<int>(src_tensor, *gpu_place);
// GPU Tensor to CPU Tensor
auto cpu_ctx = new paddle::platform::CPUDeviceContext();
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_ctx);
auto cpu_place = new paddle::platform::CPUPlace();
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
// Compare Tensors
const int* dst_ptr = dst_tensor.data<int>();
......@@ -245,10 +245,10 @@ TEST(Tensor, CopyFrom) {
Tensor slice_tensor = src_tensor.Slice<int>(1, 2);
// CPU Slice Tensor to GPU Tensor
gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_ctx);
gpu_tensor.CopyFrom<int>(slice_tensor, *gpu_place);
// GPU Tensor to CPU Tensor
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_ctx);
dst_tensor.CopyFrom<int>(gpu_tensor, *cpu_place);
// Compare Slice Tensors
const int* slice_ptr = slice_tensor.data<int>();
......
......@@ -43,8 +43,19 @@ Eigen::GpuDevice* DeviceContext::get_eigen_device<Eigen::GpuDevice>() const {
CUDADeviceContext::CUDADeviceContext(GPUPlace place) : place_(place) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(cudaStreamCreate(&stream_));
eigen_stream_.reset(new Eigen::CudaStreamDevice(&stream_));
// TODO (qijun) Pass a created cuda stream to Eigen::CudaStreamDevice directly
// here will cause segment fault. We must implement a class derived from
// Eigen::StreamInterface, and reinitialize it with a cuda stream and a gpu id
// later. Please refer to the implementation of class EigenCudaStreamDevice
// in TensorFlow.
//
// We find that CUDA 7 introduces a new option, the per-thread default stream,
// that has two effects. Please refer to https://devblogs.nvidia.com/
// parallelforall/gpu-pro-tip-cuda-7-streams-simplify-concurrency/
//
// So, we decide to use default stream and add –default-stream per-thread nvcc
// flag. Than, two threads with two CUDADeviceContexts will run parallelly.
eigen_stream_.reset(new Eigen::CudaStreamDevice());
eigen_device_.reset(new Eigen::GpuDevice(eigen_stream_.get()));
}
......@@ -64,15 +75,12 @@ CUDADeviceContext::~CUDADeviceContext() {
}
eigen_stream_.reset();
eigen_device_.reset();
PADDLE_ENFORCE(cudaStreamDestroy(stream_));
}
Place CUDADeviceContext::GetPlace() const { return place_; }
cudaStream_t CUDADeviceContext::stream() const { return stream_; }
void CUDADeviceContext::Wait() const {
PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaStreamSynchronize(0));
}
Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
......@@ -83,7 +91,6 @@ cublasHandle_t CUDADeviceContext::cublas_handle() {
if (!cublas_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cublasCreate(&cublas_handle_));
PADDLE_ENFORCE(dynload::cublasSetStream(cublas_handle_, stream_));
}
return cublas_handle_;
}
......@@ -92,7 +99,6 @@ cudnnHandle_t CUDADeviceContext::cudnn_handle() {
if (!cudnn_handle_) {
SetDeviceId(place_.device);
PADDLE_ENFORCE(dynload::cudnnCreate(&cudnn_handle_));
PADDLE_ENFORCE(dynload::cudnnSetStream(cudnn_handle_, stream_));
}
return cudnn_handle_;
}
......@@ -104,7 +110,6 @@ curandGenerator_t CUDADeviceContext::curand_generator() {
CURAND_RNG_PSEUDO_DEFAULT));
PADDLE_ENFORCE(
dynload::curandSetPseudoRandomGeneratorSeed(curand_generator_, seed_));
PADDLE_ENFORCE(dynload::curandSetStream(curand_generator_, stream_));
}
return curand_generator_;
}
......
......@@ -61,9 +61,6 @@ class CUDADeviceContext : public DeviceContext {
/*! \brief Wait for all operations completion in the stream. */
void Wait() const;
/*! \brief Return CUDA stream in the device context. */
cudaStream_t stream() const;
/*! \brief Return place in the device context. */
Place GetPlace() const override;
......@@ -91,8 +88,6 @@ class CUDADeviceContext : public DeviceContext {
private:
uint64_t seed_;
cudaStream_t stream_;
// clang-format off
cudnnHandle_t cudnn_handle_ = nullptr;
cublasHandle_t cublas_handle_ = nullptr;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册