提交 c5fb5df0，作者：willzhang4a58

CpuStream

上级 330ef814
......@@ -11,10 +11,10 @@ class CpuDeviceCtx final : public DeviceCtx {
CpuDeviceCtx() = delete;
~CpuDeviceCtx() = default;
CpuDeviceCtx(Channel<std::function<void()>>* chan) { set_cpu_stream(chan); }
CpuDeviceCtx(CpuStream* val) { set_cpu_stream(val); }
void AddCallBack(std::function<void()> callback) const override {
cpu_stream()->Send(callback);
cpu_stream()->SendWork(callback);
}
private:
......
#ifndef ONEFLOW_CORE_DEVICE_CPU_STREAM_H_
#define ONEFLOW_CORE_DEVICE_CPU_STREAM_H_

#include <functional>

#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/util.h"

namespace oneflow {

// A FIFO work queue that plays the role of a "stream" for the CPU device:
// producers enqueue closures with SendWork() and a worker thread drains them
// with ReceiveWork(), mirroring the asynchronous submit/execute semantics of
// a CUDA stream. All queueing behavior is delegated to the underlying
// Channel<std::function<void()>>.
class CpuStream final {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CpuStream);
  CpuStream() = default;
  ~CpuStream() = default;

  // Enqueues one unit of work. CHECK-fails if the underlying channel
  // reports failure (e.g. after CloseSendEnd()).
  void SendWork(std::function<void()> work) {
    CHECK_EQ(work_channel_.Send(work), 0);
  }

  // Blocks until work is available or the channel is closed.
  // Returns 0 on success, -1 on failure (channel closed and drained).
  int ReceiveWork(std::function<void()>* work) {
    return work_channel_.Receive(work);
  }

  // Closing the send end stops further SendWork() calls; the receiver may
  // still drain already-queued work until ReceiveWork() returns -1.
  void CloseSendEnd() { work_channel_.CloseSendEnd(); }
  void CloseReceiveEnd() { work_channel_.CloseReceiveEnd(); }

 private:
  Channel<std::function<void()>> work_channel_;
};

}  // namespace oneflow

#endif  // ONEFLOW_CORE_DEVICE_CPU_STREAM_H_
#ifndef ONEFLOW_CORE_DEVICE_DEVICE_CONTEXT_H_
#define ONEFLOW_CORE_DEVICE_DEVICE_CONTEXT_H_
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/util.h"
#include "oneflow/core/device/cpu_stream.h"
namespace oneflow {
......@@ -11,7 +10,7 @@ class DeviceCtx {
// OF_DISALLOW_COPY_AND_MOVE(DeviceCtx);
virtual ~DeviceCtx() = default;
Channel<std::function<void()>>* cpu_stream() const { return cpu_stream_; }
CpuStream* cpu_stream() const { return cpu_stream_; }
const cudaStream_t& cuda_stream() const { return *cuda_stream_; }
const cublasHandle_t& cublas_handle() const { return *cublas_handle_; }
const cudnnHandle_t& cudnn_handle() const { return *cudnn_handle_; }
......@@ -25,15 +24,13 @@ class DeviceCtx {
cublas_handle_(nullptr),
cudnn_handle_(nullptr) {}
void set_cpu_stream(Channel<std::function<void()>>* val) {
cpu_stream_ = val;
}
void set_cpu_stream(CpuStream* val) { cpu_stream_ = val; }
void set_cuda_stream(const cudaStream_t* val) { cuda_stream_ = val; }
void set_cublas_handle(const cublasHandle_t* val) { cublas_handle_ = val; }
void set_cudnn_handle(const cudnnHandle_t* val) { cudnn_handle_ = val; }
private:
Channel<std::function<void()>>* cpu_stream_;
CpuStream* cpu_stream_;
const cudaStream_t* cuda_stream_;
const cublasHandle_t* cublas_handle_;
const cudnnHandle_t* cudnn_handle_;
......
......@@ -125,7 +125,7 @@ void InitBn2BlobPtr(HashMap<std::string, Blob*>& bn2blob_ptr,
void CPUStreamExec(int out_num, std::function<Blob*(const std::string&)> fp) {
KernelCtx ctx;
ctx.device_ctx = new CpuDeviceCtx(new Channel<std::function<void()>>);
ctx.device_ctx = new CpuDeviceCtx(new CpuStream);
auto clone_kernel =
ConstructCloneKernel<DeviceType::kCPU>(out_num, "clone_kernel_test");
......@@ -136,7 +136,7 @@ void CPUStreamExec(int out_num, std::function<Blob*(const std::string&)> fp) {
std::function<void()> work;
// Both Forward and Backward receive out_num times
for (int i = 0; i < out_num * 2; ++i) {
if (ctx.device_ctx->cpu_stream()->Receive(&work) == 0) { work(); }
if (ctx.device_ctx->cpu_stream()->ReceiveWork(&work) == 0) { work(); }
}
});
cpu_thread.join();
......
......@@ -107,7 +107,7 @@ void BuildKernelCtx(KernelCtx* ctx);
template<>
void BuildKernelCtx<DeviceType::kCPU>(KernelCtx* ctx) {
  // CpuStream is the CPU analogue of a CUDA stream: kernels push their work
  // onto it and a dedicated thread executes it (see SyncStream<kCPU>).
  // NOTE(review): cpu_stream is allocated with raw new; presumably the
  // CpuDeviceCtx (or the test harness) is responsible for its lifetime --
  // confirm against the caller before relying on this.
  auto cpu_stream = new CpuStream;
  ctx->device_ctx = new CpuDeviceCtx(cpu_stream);
}
......@@ -153,7 +153,7 @@ void SyncStream<DeviceType::kCPU>(KernelCtx* ctx) {
auto cpu_thread = std::thread([&] {
std::function<void()> work;
while (ctx->device_ctx->cpu_stream()->Receive(&work) == 0) { work(); }
while (ctx->device_ctx->cpu_stream()->ReceiveWork(&work) == 0) { work(); }
});
cpu_thread.join();
}
......
......@@ -11,13 +11,13 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
static void Memcpy(
const KernelCtx& ctx, void* dst, const void* src, size_t sz,
cudaMemcpyKind kind = cudaMemcpyKind::cudaMemcpyHostToHost) {
ctx.device_ctx->cpu_stream()->Send(
ctx.device_ctx->cpu_stream()->SendWork(
[dst, src, sz]() { memcpy(dst, src, sz); });
}
static void Memset(const KernelCtx& ctx, void* dst, const char value,
size_t sz) {
ctx.device_ctx->cpu_stream()->Send(
ctx.device_ctx->cpu_stream()->SendWork(
[dst, value, sz]() { memset(dst, value, sz); });
}
......@@ -25,7 +25,7 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
const FloatingPointType alpha,
const FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy) {
ctx.device_ctx->cpu_stream()->Send([n, alpha, x, incx, y, incy]() {
ctx.device_ctx->cpu_stream()->SendWork([n, alpha, x, incx, y, incy]() {
cblas_axpy(n, alpha, x, incx, y, incy);
});
}
......@@ -33,7 +33,7 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
static void BlasScal(const KernelCtx& ctx, const int n,
const FloatingPointType alpha, FloatingPointType* x,
const int incx) {
ctx.device_ctx->cpu_stream()->Send(
ctx.device_ctx->cpu_stream()->SendWork(
[n, alpha, x, incx]() { cblas_scal(n, alpha, x, incx); });
}
......@@ -43,7 +43,7 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
const FloatingPointType* x, const int incx,
const FloatingPointType beta, FloatingPointType* y,
const int incy) {
ctx.device_ctx->cpu_stream()->Send([=]() {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
// Set col major to keep it as the same with cublas
cblas_gemv(CBLAS_ORDER::CblasColMajor, trans, m, n, alpha, a, lda, x,
incx, beta, y, incy);
......@@ -58,7 +58,7 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
const FloatingPointType* b, const int ldb,
const FloatingPointType beta, FloatingPointType* c,
const int ldc) {
ctx.device_ctx->cpu_stream()->Send([=]() {
ctx.device_ctx->cpu_stream()->SendWork([=]() {
cblas_gemm(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc);
});
......@@ -68,20 +68,20 @@ class KernelUtil<DeviceType::kCPU, FloatingPointType> final {
const FloatingPointType* x, const int incx,
const FloatingPointType* y, const int incy,
FloatingPointType* result) {
ctx.device_ctx->cpu_stream()->Send(
ctx.device_ctx->cpu_stream()->SendWork(
[=]() { *result = cblas_dot(n, x, incx, y, incy); });
}
static void BlasSwap(const KernelCtx& ctx, const int n, FloatingPointType* x,
const int incx, FloatingPointType* y, const int incy) {
ctx.device_ctx->cpu_stream()->Send(
ctx.device_ctx->cpu_stream()->SendWork(
[=]() { cblas_swap(n, x, incx, y, incy); });
}
static void BlasCopy(const KernelCtx& ctx, const int n,
const FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy) {
ctx.device_ctx->cpu_stream()->Send(
ctx.device_ctx->cpu_stream()->SendWork(
[=]() { cblas_copy(n, x, incx, y, incy); });
}
};
......
......@@ -13,7 +13,7 @@ void ModelSaveKernel<DeviceType::kCPU, FloatingPointType>::Forward(
for (const std::string& ibn : op()->input_bns()) {
const std::string& lbn = op()->Lbn4BnInOp(ibn);
Blob* blob_ptr = BnInOp2BlobPtr(ibn);
kernel_ctx.device_ctx->cpu_stream()->Send([=]() {
kernel_ctx.device_ctx->cpu_stream()->SendWork([=]() {
std::unique_ptr<PersistentOutStream> out_stream =
snapshot->GetOutStream(lbn, parallel_id);
out_stream->Write(
......
......@@ -5,7 +5,7 @@ namespace oneflow {
CpuThread::CpuThread() {
cpu_device_ = std::thread([this]() {
std::function<void()> work;
while (cpu_stream_.Receive(&work) == 0) { work(); }
while (cpu_stream_.ReceiveWork(&work) == 0) { work(); }
});
mut_actor_thread() = std::thread([this]() {
ThreadCtx ctx;
......
#ifndef ONEFLOW_CORE_THREAD_CPU_THREAD_H_
#define ONEFLOW_CORE_THREAD_CPU_THREAD_H_
#include "oneflow/core/device/cpu_stream.h"
#include "oneflow/core/thread/thread.h"
namespace oneflow {
......@@ -13,7 +14,7 @@ class CpuThread final : public Thread {
private:
std::thread cpu_device_;
Channel<std::function<void()>> cpu_stream_;
CpuStream cpu_stream_;
};
} // namespace oneflow
......
#ifndef ONEFLOW_CORE_THREAD_THREAD_CONTEXT_H_
#define ONEFLOW_CORE_THREAD_THREAD_CONTEXT_H_
#include "oneflow/core/common/channel.h"
#include "oneflow/core/device/cpu_stream.h"
namespace oneflow {
struct ThreadCtx {
  ThreadCtx() : cpu_stream(nullptr), copy_hd_cuda_stream(nullptr) {}
  // Work queue for CPU kernels executed on this thread.
  // NOTE(review): raw pointer; ownership presumably lives with whoever
  // created the stream (e.g. CpuThread) -- confirm.
  CpuStream* cpu_stream;
  // CUDA stream used for host<->device copies; non-owning pointer.
  const cudaStream_t* copy_hd_cuda_stream;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册