From c5fb5df0417d4b59bb3f2bf7c1a4aa8ae8f8a14c Mon Sep 17 00:00:00 2001
From: willzhang4a58
Date: Sat, 1 Jul 2017 16:35:19 +0800
Subject: [PATCH] CpuStream

---
 oneflow/core/device/cpu_device_context.h      |  4 +--
 oneflow/core/device/cpu_stream.h              | 33 +++++++++++++++++++
 oneflow/core/device/device_context.h          | 11 +++----
 oneflow/core/kernel/clone_kernel_test.cpp     |  4 +--
 .../core/kernel/innerproduct_kernel_test.cpp  |  4 +--
 oneflow/core/kernel/kernel_util.cpp           | 18 +++++-----
 oneflow/core/kernel/model_save_kernel.cpp     |  2 +-
 oneflow/core/thread/cpu_thread.cpp            |  2 +-
 oneflow/core/thread/cpu_thread.h              |  3 +-
 oneflow/core/thread/thread_context.h          |  4 +--
 10 files changed, 58 insertions(+), 27 deletions(-)
 create mode 100644 oneflow/core/device/cpu_stream.h

diff --git a/oneflow/core/device/cpu_device_context.h b/oneflow/core/device/cpu_device_context.h
index 065bc91024..db56c329b4 100644
--- a/oneflow/core/device/cpu_device_context.h
+++ b/oneflow/core/device/cpu_device_context.h
@@ -11,10 +11,10 @@ class CpuDeviceCtx final : public DeviceCtx {
   CpuDeviceCtx() = delete;
   ~CpuDeviceCtx() = default;
 
-  CpuDeviceCtx(Channel<std::function<void()>>* chan) { set_cpu_stream(chan); }
+  CpuDeviceCtx(CpuStream* val) { set_cpu_stream(val); }
 
   void AddCallBack(std::function<void()> callback) const override {
-    cpu_stream()->Send(callback);
+    cpu_stream()->SendWork(callback);
   }
 
  private:
diff --git a/oneflow/core/device/cpu_stream.h b/oneflow/core/device/cpu_stream.h
new file mode 100644
index 0000000000..38862bbc51
--- /dev/null
+++ b/oneflow/core/device/cpu_stream.h
@@ -0,0 +1,33 @@
+#ifndef ONEFLOW_CORE_DEVICE_CPU_STREAM_H_
+#define ONEFLOW_CORE_DEVICE_CPU_STREAM_H_
+
+#include "oneflow/core/common/channel.h"
+
+namespace oneflow {
+
+class CpuStream final {
+ public:
+  OF_DISALLOW_COPY_AND_MOVE(CpuStream);
+  CpuStream() = default;
+  ~CpuStream() = default;
+
+  void SendWork(std::function<void()> work) {
+    CHECK_EQ(work_channel_.Send(work), 0);
+  }
+
+  // 0: success
+  // -1: fail
+  int ReceiveWork(std::function<void()>* work) {
+    return work_channel_.Receive(work);
+  }
+
+  void CloseSendEnd() { work_channel_.CloseSendEnd(); }
+  void CloseReceiveEnd() { work_channel_.CloseReceiveEnd(); }
+
+ private:
+  Channel<std::function<void()>> work_channel_;
+};
+
+} // namespace oneflow
+
+#endif // ONEFLOW_CORE_DEVICE_CPU_STREAM_H_
diff --git a/oneflow/core/device/device_context.h b/oneflow/core/device/device_context.h
index 8123cb9f8d..aa56bdadcf 100644
--- a/oneflow/core/device/device_context.h
+++ b/oneflow/core/device/device_context.h
@@ -1,8 +1,7 @@
 #ifndef ONEFLOW_CORE_DEVICE_DEVICE_CONTEXT_H_
 #define ONEFLOW_CORE_DEVICE_DEVICE_CONTEXT_H_
 
-#include "oneflow/core/common/channel.h"
-#include "oneflow/core/common/util.h"
+#include "oneflow/core/device/cpu_stream.h"
 
 namespace oneflow {
 
@@ -11,7 +10,7 @@ class DeviceCtx {
   // OF_DISALLOW_COPY_AND_MOVE(DeviceCtx);
   virtual ~DeviceCtx() = default;
 
-  Channel<std::function<void()>>* cpu_stream() const { return cpu_stream_; }
+  CpuStream* cpu_stream() const { return cpu_stream_; }
   const cudaStream_t& cuda_stream() const { return *cuda_stream_; }
   const cublasHandle_t& cublas_handle() const { return *cublas_handle_; }
   const cudnnHandle_t& cudnn_handle() const { return *cudnn_handle_; }
@@ -25,15 +24,13 @@ class DeviceCtx {
         cublas_handle_(nullptr),
         cudnn_handle_(nullptr) {}
 
-  void set_cpu_stream(Channel<std::function<void()>>* val) {
-    cpu_stream_ = val;
-  }
+  void set_cpu_stream(CpuStream* val) { cpu_stream_ = val; }
   void set_cuda_stream(const cudaStream_t* val) { cuda_stream_ = val; }
   void set_cublas_handle(const cublasHandle_t* val) { cublas_handle_ = val; }
   void set_cudnn_handle(const cudnnHandle_t* val) { cudnn_handle_ = val; }
 
  private:
-  Channel<std::function<void()>>* cpu_stream_;
+  CpuStream* cpu_stream_;
   const cudaStream_t* cuda_stream_;
   const cublasHandle_t* cublas_handle_;
   const cudnnHandle_t* cudnn_handle_;
diff --git a/oneflow/core/kernel/clone_kernel_test.cpp b/oneflow/core/kernel/clone_kernel_test.cpp
index 87a027e294..58eb3563fd 100644
--- a/oneflow/core/kernel/clone_kernel_test.cpp
+++ b/oneflow/core/kernel/clone_kernel_test.cpp
@@ -125,7 +125,7 @@ void InitBn2BlobPtr(HashMap<std::string, Blob*>& bn2blob_ptr,
 
 void CPUStreamExec(int out_num, std::function fp) {
   KernelCtx ctx;
-  ctx.device_ctx = new CpuDeviceCtx(new Channel<std::function<void()>>);
+  ctx.device_ctx = new CpuDeviceCtx(new CpuStream);
 
   auto clone_kernel = ConstructCloneKernel(out_num, "clone_kernel_test");
 
@@ -136,7 +136,7 @@ void CPUStreamExec(int out_num, std::function fp) {
     std::function<void()> work;
     // Both Forward and Backward receive out_num times
     for (int i = 0; i < out_num * 2; ++i) {
-      if (ctx.device_ctx->cpu_stream()->Receive(&work) == 0) { work(); }
+      if (ctx.device_ctx->cpu_stream()->ReceiveWork(&work) == 0) { work(); }
     }
   });
   cpu_thread.join();
diff --git a/oneflow/core/kernel/innerproduct_kernel_test.cpp b/oneflow/core/kernel/innerproduct_kernel_test.cpp
index bc13cb11a7..771a9337f7 100644
--- a/oneflow/core/kernel/innerproduct_kernel_test.cpp
+++ b/oneflow/core/kernel/innerproduct_kernel_test.cpp
@@ -107,7 +107,7 @@ void BuildKernelCtx(KernelCtx* ctx);
 
 template<>
 void BuildKernelCtx(KernelCtx* ctx) {
-  auto cpu_stream = new Channel<std::function<void()>>;
+  auto cpu_stream = new CpuStream;
   ctx->device_ctx = new CpuDeviceCtx(cpu_stream);
 }
 
@@ -153,7 +153,7 @@ void SyncStream(KernelCtx* ctx) {
 
   auto cpu_thread = std::thread([&] {
     std::function<void()> work;
-    while (ctx->device_ctx->cpu_stream()->Receive(&work) == 0) { work(); }
+    while (ctx->device_ctx->cpu_stream()->ReceiveWork(&work) == 0) { work(); }
   });
   cpu_thread.join();
 }
diff --git a/oneflow/core/kernel/kernel_util.cpp b/oneflow/core/kernel/kernel_util.cpp
index 402aa3215e..fb89f1223a 100644
--- a/oneflow/core/kernel/kernel_util.cpp
+++ b/oneflow/core/kernel/kernel_util.cpp
@@ -11,13 +11,13 @@ class KernelUtil final {
   static void Memcpy(
       const KernelCtx& ctx, void* dst, const void* src, size_t sz,
       cudaMemcpyKind kind = cudaMemcpyKind::cudaMemcpyHostToHost) {
-    ctx.device_ctx->cpu_stream()->Send(
+    ctx.device_ctx->cpu_stream()->SendWork(
         [dst, src, sz]() { memcpy(dst, src, sz); });
   }
 
   static void Memset(const KernelCtx& ctx, void* dst, const char value,
                      size_t sz) {
-    ctx.device_ctx->cpu_stream()->Send(
+    ctx.device_ctx->cpu_stream()->SendWork(
         [dst, value, sz]() { memset(dst, value, sz); });
   }
 
@@ -25,7 +25,7 @@ class KernelUtil final {
                        const FloatingPointType alpha, const FloatingPointType* x,
                        const int incx, FloatingPointType* y, const int incy) {
-    ctx.device_ctx->cpu_stream()->Send([n, alpha, x, incx, y, incy]() {
+    ctx.device_ctx->cpu_stream()->SendWork([n, alpha, x, incx, y, incy]() {
       cblas_axpy(n, alpha, x, incx, y, incy);
     });
   }
 
@@ -33,7 +33,7 @@ class KernelUtil final {
   static void BlasScal(const KernelCtx& ctx, const int n,
                        const FloatingPointType alpha, FloatingPointType* x,
                        const int incx) {
-    ctx.device_ctx->cpu_stream()->Send(
+    ctx.device_ctx->cpu_stream()->SendWork(
         [n, alpha, x, incx]() { cblas_scal(n, alpha, x, incx); });
   }
 
@@ -43,7 +43,7 @@ class KernelUtil final {
                        const FloatingPointType* x, const int incx,
                        const FloatingPointType beta, FloatingPointType* y,
                        const int incy) {
-    ctx.device_ctx->cpu_stream()->Send([=]() {
+    ctx.device_ctx->cpu_stream()->SendWork([=]() {
       // Set col major to keep it as the same with cublas
       cblas_gemv(CBLAS_ORDER::CblasColMajor, trans, m, n, alpha, a, lda, x,
                  incx, beta, y, incy);
@@ -58,7 +58,7 @@ class KernelUtil final {
                        const FloatingPointType* b, const int ldb,
                        const FloatingPointType beta, FloatingPointType* c,
                        const int ldc) {
-    ctx.device_ctx->cpu_stream()->Send([=]() {
+    ctx.device_ctx->cpu_stream()->SendWork([=]() {
       cblas_gemm(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb,
                  beta, c, ldc);
     });
@@ -68,20 +68,20 @@ class KernelUtil final {
                       const FloatingPointType* x, const int incx,
                       const FloatingPointType* y, const int incy,
                       FloatingPointType* result) {
-    ctx.device_ctx->cpu_stream()->Send(
+    ctx.device_ctx->cpu_stream()->SendWork(
         [=]() { *result = cblas_dot(n, x, incx, y, incy); });
   }
 
   static void BlasSwap(const KernelCtx& ctx, const int n, FloatingPointType* x,
                        const int incx, FloatingPointType* y, const int incy) {
-    ctx.device_ctx->cpu_stream()->Send(
+    ctx.device_ctx->cpu_stream()->SendWork(
         [=]() { cblas_swap(n, x, incx, y, incy); });
   }
 
   static void BlasCopy(const KernelCtx& ctx, const int n,
                        const FloatingPointType* x, const int incx,
                        FloatingPointType* y, const int incy) {
-    ctx.device_ctx->cpu_stream()->Send(
+    ctx.device_ctx->cpu_stream()->SendWork(
         [=]() { cblas_copy(n, x, incx, y, incy); });
   }
 };
diff --git a/oneflow/core/kernel/model_save_kernel.cpp b/oneflow/core/kernel/model_save_kernel.cpp
index 2de8d02897..1ff07808d2 100644
--- a/oneflow/core/kernel/model_save_kernel.cpp
+++ b/oneflow/core/kernel/model_save_kernel.cpp
@@ -13,7 +13,7 @@ void ModelSaveKernel::Forward(
   for (const std::string& ibn : op()->input_bns()) {
     const std::string& lbn = op()->Lbn4BnInOp(ibn);
     Blob* blob_ptr = BnInOp2BlobPtr(ibn);
-    kernel_ctx.device_ctx->cpu_stream()->Send([=]() {
+    kernel_ctx.device_ctx->cpu_stream()->SendWork([=]() {
       std::unique_ptr out_stream =
           snapshot->GetOutStream(lbn, parallel_id);
       out_stream->Write(
diff --git a/oneflow/core/thread/cpu_thread.cpp b/oneflow/core/thread/cpu_thread.cpp
index 1c4a01524c..ee50c3381c 100644
--- a/oneflow/core/thread/cpu_thread.cpp
+++ b/oneflow/core/thread/cpu_thread.cpp
@@ -5,7 +5,7 @@ namespace oneflow {
 CpuThread::CpuThread() {
   cpu_device_ = std::thread([this]() {
     std::function<void()> work;
-    while (cpu_stream_.Receive(&work) == 0) { work(); }
+    while (cpu_stream_.ReceiveWork(&work) == 0) { work(); }
   });
   mut_actor_thread() = std::thread([this]() {
     ThreadCtx ctx;
diff --git a/oneflow/core/thread/cpu_thread.h b/oneflow/core/thread/cpu_thread.h
index 8c7ad355d4..b3f58fb7cd 100644
--- a/oneflow/core/thread/cpu_thread.h
+++ b/oneflow/core/thread/cpu_thread.h
@@ -1,6 +1,7 @@
 #ifndef ONEFLOW_CORE_THREAD_CPU_THREAD_H_
 #define ONEFLOW_CORE_THREAD_CPU_THREAD_H_
 
+#include "oneflow/core/device/cpu_stream.h"
 #include "oneflow/core/thread/thread.h"
 
 namespace oneflow {
@@ -13,7 +14,7 @@ class CpuThread final : public Thread {
 
  private:
   std::thread cpu_device_;
-  Channel<std::function<void()>> cpu_stream_;
+  CpuStream cpu_stream_;
 };
 
 } // namespace oneflow
diff --git a/oneflow/core/thread/thread_context.h b/oneflow/core/thread/thread_context.h
index 5e79e2e65b..8e3fa68db3 100644
--- a/oneflow/core/thread/thread_context.h
+++ b/oneflow/core/thread/thread_context.h
@@ -1,14 +1,14 @@
 #ifndef ONEFLOW_CORE_THREAD_THREAD_CONTEXT_H_
 #define ONEFLOW_CORE_THREAD_THREAD_CONTEXT_H_
 
-#include "oneflow/core/common/channel.h"
+#include "oneflow/core/device/cpu_stream.h"
 
 namespace oneflow {
 
 struct ThreadCtx {
   ThreadCtx() : cpu_stream(nullptr), copy_hd_cuda_stream(nullptr) {}
 
-  Channel<std::function<void()>>* cpu_stream;
+  CpuStream* cpu_stream;
   const cudaStream_t* copy_hd_cuda_stream;
 };
 
-- 
GitLab