Commit 9948ddfe authored by willzhang4a58

add cudnn

Parent d2649862
......@@ -11,6 +11,7 @@ include(grpc)
include(tensorflow)
find_package(CUDA REQUIRED)
find_package(CuDNN REQUIRED)
set(oneflow_third_party_libs
${tensorflow_STATIC_LIBRARIES}
......@@ -28,6 +29,7 @@ set(oneflow_third_party_libs
${PNG_STATIC_LIBRARIES}
${JSONCPP_STATIC_LIBRARIES}
${CUDA_CUBLAS_LIBRARIES}
${CUDNN_LIBRARIES}
)
if(WIN32)
......@@ -81,4 +83,5 @@ include_directories(
${PNG_INCLUDE_DIR}
${JSONCPP_INCLUDE_DIR}
${EIGEN_INCLUDE_DIRS}
${CUDNN_INCLUDE_DIRS}
)
......@@ -7,7 +7,7 @@ namespace oneflow {
// need review
void CopyHdActor::ProcessMsg(const ActorMsg& msg,
const ThreadContext& thread_ctx) {
CudaKernelCtx kernel_ctx(thread_ctx.copy_hd_cuda_stream, nullptr);
CudaKernelCtx kernel_ctx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr);
ProcessMsgWithKernelCtx(msg, kernel_ctx);
}
......
......@@ -26,7 +26,7 @@ bool FwDataCompActor::IsReadReady() {
void FwDataCompActor::ProcessMsg(const ActorMsg& msg,
const ThreadContext& thread_ctx) {
CudaKernelCtx kernel_ctx(thread_ctx.compute_cuda_stream, nullptr);
CudaKernelCtx kernel_ctx(thread_ctx.compute_cuda_stream, nullptr, nullptr);
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
TODO();
}
......
......@@ -13,7 +13,7 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto) {
void MdUpdtCompActor::ProcessMsg(const ActorMsg& actor_msg,
const ThreadContext& thread_ctx) {
CudaKernelCtx kernel_ctx(thread_ctx.compute_cuda_stream, nullptr);
CudaKernelCtx kernel_ctx(thread_ctx.compute_cuda_stream, nullptr, nullptr);
(this->*cur_handle_)(actor_msg, kernel_ctx);
}
......
#ifndef ONEFLOW_CORE_UNIQUE_CUDNN_HANDLE_H_
#define ONEFLOW_CORE_UNIQUE_CUDNN_HANDLE_H_

#include "cudnn.h"
#include "oneflow/core/common/util.h"

namespace oneflow {

// RAII wrapper: creates a cuDNN handle bound to the given CUDA stream and
// destroys it when the wrapper goes out of scope.
class UniqueCudnnHandle final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(UniqueCudnnHandle);
  UniqueCudnnHandle() = delete;
  UniqueCudnnHandle(const cudaStream_t* cuda_stream) {
    CHECK_EQ(cudnnCreate(&handle_), CUDNN_STATUS_SUCCESS);
    CHECK_EQ(cudnnSetStream(handle_, *cuda_stream), CUDNN_STATUS_SUCCESS);
  }
  ~UniqueCudnnHandle() { cudnnDestroy(handle_); }

  const cudnnHandle_t* get() const { return &handle_; }

 private:
  cudnnHandle_t handle_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_UNIQUE_CUDNN_HANDLE_H_
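For orientation, a minimal sketch of how this RAII wrapper is meant to be used, mirroring the GpuThread constructor further down: the handle is created from an existing CUDA stream, and only the non-owning pointer from get() is handed out. The function name and stream setup below are illustrative, not part of this commit.

#include "oneflow/core/common/unique_cudnn_handle.h"

// Illustrative usage only (not part of this commit): bind a cuDNN handle to a
// freshly created CUDA stream; the handle is destroyed before the stream.
void ExampleCudnnHandleUsage() {
  cudaStream_t stream;
  CHECK_EQ(cudaStreamCreate(&stream), cudaSuccess);
  {
    UniqueCudnnHandle cudnn_handle(&stream);               // cudnnCreate + cudnnSetStream
    const cudnnHandle_t* handle_ptr = cudnn_handle.get();  // non-owning pointer for consumers
    // ... issue cuDNN calls through *handle_ptr; they run on `stream` ...
  }  // cudnnDestroy runs here, before the stream is torn down
  CHECK_EQ(cudaStreamDestroy(stream), cudaSuccess);
}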
......@@ -12,6 +12,7 @@
#include "cuda.h"
#include "cuda_runtime.h"
#include "cublas_v2.h"
#include "cudnn.h"
namespace oneflow {
......
......@@ -12,9 +12,11 @@ class CudaKernelCtx final : public KernelCtx {
~CudaKernelCtx() = default;
CudaKernelCtx(const cudaStream_t* cuda_stream,
const cublasHandle_t* cublas_handle) {
const cublasHandle_t* cublas_handle,
const cudnnHandle_t* cudnn_handle) {
set_cuda_stream(cuda_stream);
set_cublas_handle(cublas_handle);
set_cudnn_handle(cudnn_handle);
}
void AddCallBack(std::function<void()> callback) const override;
......
......@@ -14,13 +14,15 @@ class KernelCtx {
Channel<std::function<void()>>* cpu_stream() const { return cpu_stream_; }
const cudaStream_t& cuda_stream() const { return *cuda_stream_; }
const cublasHandle_t& cublas_handle() const { return *cublas_handle_; }
const cudnnHandle_t& cudnn_handle() const { return *cudnn_handle_; }
virtual void AddCallBack(std::function<void()>) const = 0;
protected:
KernelCtx() : cpu_stream_(nullptr),
cuda_stream_(nullptr),
cublas_handle_(nullptr) {}
cublas_handle_(nullptr),
cudnn_handle_(nullptr) {}
void set_cpu_stream(Channel<std::function<void()>>* val) {
cpu_stream_ = val;
......@@ -31,11 +33,15 @@ class KernelCtx {
void set_cublas_handle(const cublasHandle_t* val) {
cublas_handle_ = val;
}
void set_cudnn_handle(const cudnnHandle_t* val) {
cudnn_handle_ = val;
}
private:
Channel<std::function<void()>>* cpu_stream_;
const cudaStream_t* cuda_stream_;
const cublasHandle_t* cublas_handle_;
const cudnnHandle_t* cudnn_handle_;
};
......
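To show what the new accessor buys, here is a hedged sketch of a kernel helper that adds a bias tensor through cuDNN using the handle carried by the KernelCtx. Only cudnn_handle() comes from this commit; the function name, descriptors, and buffers are hypothetical placeholders.

// Hypothetical helper (not in this commit): y += bias, dispatched on the
// cuDNN handle stored in the kernel context.
void AddBiasWithCudnn(const KernelCtx& ctx,
                      cudnnTensorDescriptor_t bias_desc, const float* bias,
                      cudnnTensorDescriptor_t y_desc, float* y) {
  const float alpha = 1.0f;
  const float beta = 1.0f;
  // cudnn_handle() dereferences the pointer installed via set_cudnn_handle().
  CHECK_EQ(cudnnAddTensor(ctx.cudnn_handle(), &alpha, bias_desc, bias,
                          &beta, y_desc, y),
           CUDNN_STATUS_SUCCESS);
}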
......@@ -2,6 +2,7 @@
#include "cuda_runtime.h"
#include "oneflow/core/common/unique_cuda_stream.h"
#include "oneflow/core/common/unique_cublas_handle.h"
#include "oneflow/core/common/unique_cudnn_handle.h"
namespace oneflow {
......@@ -12,10 +13,12 @@ GpuThread::GpuThread(int device_phy_id) {
UniqueCudaStream compute_cuda_stream;
{
UniqueCublasHandle cublas_handle(compute_cuda_stream.get());
UniqueCudnnHandle cudnn_handle(compute_cuda_stream.get());
ThreadContext ctx;
ctx.copy_hd_cuda_stream = copy_hd_cuda_stream.get();
ctx.compute_cuda_stream = compute_cuda_stream.get();
ctx.cublas_handle = cublas_handle.get();
ctx.cudnn_handle = cudnn_handle.get();
PollMsgChannel(ctx);
}
});
......
......@@ -9,12 +9,14 @@ struct ThreadContext {
ThreadContext() : cpu_stream(nullptr),
copy_hd_cuda_stream(nullptr),
compute_cuda_stream(nullptr),
cublas_handle(nullptr) {}
cublas_handle(nullptr),
cudnn_handle(nullptr) {}
Channel<std::function<void()>>* cpu_stream;
const cudaStream_t* copy_hd_cuda_stream;
const cudaStream_t* compute_cuda_stream;
const cublasHandle_t* cublas_handle;
const cudnnHandle_t* cudnn_handle;
};
} // namespace oneflow
......
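As the actor changes above show, the actors currently still pass nullptr for the cuBLAS and cuDNN handles. A minimal sketch, assuming the ThreadContext and CudaKernelCtx from this commit, of how a compute actor could forward the thread-owned handles once it needs them; this wiring is not part of the commit and the function name is hypothetical.

// Hypothetical wiring (not in this commit): hand the thread's handles to the
// kernel context instead of the nullptr placeholders used above.
void ForwardHandles(const ThreadContext& thread_ctx) {
  CudaKernelCtx kernel_ctx(thread_ctx.compute_cuda_stream,
                           thread_ctx.cublas_handle,
                           thread_ctx.cudnn_handle);
  // kernel_ctx.cudnn_handle() now returns the handle created in GpuThread.
}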