Commit c927939b authored by willzhang4a58

add register_gpu_kernel, cpu_channel

Parent f187b057
@@ -44,7 +44,7 @@ void Actor::Init(const TaskProto& task_proto) {
 }
 void Actor::WardKernel(
-    const KernelContext& kernel_ctx,
+    const KernelCtx& kernel_ctx,
     std::function<std::shared_ptr<RegstWarpper>(uint64_t)> Regst4RegstDescId) {
   for (const ExecKernel& ek : exec_kernel_vec_) {
     (ek.kernel->*ward_func_)(kernel_ctx, [&](const std::string& bn_in_op) {
@@ -34,7 +34,7 @@ class Actor {
   Actor() = default;
   void WardKernel(
-      const KernelContext& kernel_ctx,
+      const KernelCtx& kernel_ctx,
       std::function<std::shared_ptr<RegstWarpper>(uint64_t)> Regst4RegstDescId);
   void ForEachProducedRegst(std::function<void(Regst*)>);
   uint64_t RegstDescId4Name(const std::string& name) const {
@@ -11,7 +11,7 @@ void BoxingActor::Init(const TaskProto& task_proto) {
 void BoxingActor::ProcessMsg(const ActorMsg& msg,
                              const ThreadContext& thread_ctx) {
-  KernelContext kernel_ctx;
+  KernelCtx kernel_ctx;
   if (TryUpdtStateAsFromRegstReader(msg.regst_warpper()->regst_raw_ptr()) != 0) {
     std::shared_ptr<RegstWarpper> regst_wp = msg.regst_warpper();
     auto waiting_in_regst_it = waiting_in_regst_.find(regst_wp->piece_id());
@@ -36,7 +36,7 @@ void BoxingActor::ProcessMsg(const ActorMsg& msg,
   }
 }
-void BoxingActor::WardKernelAndSendMsg(const KernelContext& kernel_ctx) {
+void BoxingActor::WardKernelAndSendMsg(const KernelCtx& kernel_ctx) {
   uint64_t piece_id = ready_in_regst_.front().first;
   WardKernel(kernel_ctx, [this](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
     Regst* regst = GetCurWriteableRegst(regst_desc_id);
@@ -18,7 +18,7 @@ class BoxingActor final : public Actor {
   using RDescId2RwMap = HashMap<uint64_t, std::shared_ptr<RegstWarpper>>;
   using RDescId2RwMapPtr = std::unique_ptr<RDescId2RwMap>;
-  void WardKernelAndSendMsg(const KernelContext&);
+  void WardKernelAndSendMsg(const KernelCtx&);
   // <piece_id, map>
   HashMap<uint64_t, RDescId2RwMapPtr> waiting_in_regst_;
@@ -9,7 +9,7 @@ void CopyActor::Init(const TaskProto& task_proto) {
 }
 void CopyActor::ProcessMsgWithKernelCtx(const ActorMsg& msg,
-                                        const KernelContext& kernel_ctx) {
+                                        const KernelCtx& kernel_ctx) {
   if (TryUpdtStateAsFromRegstReader(msg.regst_warpper()->regst_raw_ptr()) != 0) {
     waiting_in_regst_.push(std::move(msg.regst_warpper()));
   }
@@ -15,7 +15,7 @@ public:
 protected:
   CopyActor() = default;
-  void ProcessMsgWithKernelCtx(const ActorMsg& msg, const KernelContext& kernel_ctx);
+  void ProcessMsgWithKernelCtx(const ActorMsg& msg, const KernelCtx& kernel_ctx);
 private:
   std::queue<std::shared_ptr<RegstWarpper>> waiting_in_regst_;
@@ -6,7 +6,7 @@ namespace oneflow {
 void CopyCommNetActor::ProcessMsg(const ActorMsg& msg,
                                   const ThreadContext&) {
-  KernelContext kernel_ctx;
+  KernelCtx kernel_ctx;
   ProcessMsgWithKernelCtx(msg, kernel_ctx);
 }
@@ -6,7 +6,7 @@ namespace oneflow {
 void CopyHdActor::ProcessMsg(const ActorMsg& msg,
                              const ThreadContext& thread_ctx) {
-  KernelContext kernel_ctx;
+  KernelCtx kernel_ctx;
   kernel_ctx.cuda_stream = thread_ctx.copy_hd_cuda_stream;
   ProcessMsgWithKernelCtx(msg, kernel_ctx);
 }
@@ -25,7 +25,7 @@ bool FwDataCompActor::IsReadReady() {
 void FwDataCompActor::ProcessMsg(const ActorMsg& msg,
                                  const ThreadContext& thread_ctx) {
-  KernelContext kernel_ctx;
+  KernelCtx kernel_ctx;
   kernel_ctx.cuda_stream = thread_ctx.compute_cuda_stream;
   if (msg.msg_type() == ActorMsgType::kCmdMsg) {
     TODO();
@@ -55,7 +55,7 @@ void FwDataCompActor::ProcessMsg(const ActorMsg& msg,
   }
 }
-void FwDataCompActor::WardKernelAndSendMsg(const KernelContext& kernel_ctx) {
+void FwDataCompActor::WardKernelAndSendMsg(const KernelCtx& kernel_ctx) {
   CHECK_EQ(in_.front()->piece_id(), expected_piece_id());
   ready_in_regst_[in_.front()->regst_desc_id()] = in_.front();
   uint64_t piece_id = in_.front()->piece_id();
@@ -16,7 +16,7 @@ public:
 private:
   bool IsReadReady();
-  void WardKernelAndSendMsg(const KernelContext&);
+  void WardKernelAndSendMsg(const KernelCtx&);
   uint64_t expected_model_version_id_;
   uint64_t model_regst_desc_id_;
@@ -12,14 +12,14 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto) {
 void MdUpdtCompActor::ProcessMsg(const ActorMsg& actor_msg,
                                  const ThreadContext& thread_ctx) {
-  KernelContext kernel_ctx;
+  KernelCtx kernel_ctx;
   kernel_ctx.cuda_stream = thread_ctx.compute_cuda_stream;
   (this->*cur_handle_)(actor_msg, kernel_ctx);
 }
 void MdUpdtCompActor::HandleBeforeInitializeModel(
     const ActorMsg& actor_msg,
-    const KernelContext& kernel_ctx) {
+    const KernelCtx& kernel_ctx) {
   CHECK(actor_msg.actor_cmd() == ActorCmd::kInitializeModel);
   Regst* model_regst = GetCurWriteableRegst(model_regst_desc_id_);
   model_regst->set_model_version_id(0);
@@ -46,7 +46,7 @@ void MdUpdtCompActor::HandleBeforeInitializeModel(
 void MdUpdtCompActor::HandleBeforeSendInitialModel(
     const ActorMsg& actor_msg,
-    const KernelContext& kernel_ctx) {
+    const KernelCtx& kernel_ctx) {
   CHECK(actor_msg.actor_cmd() == ActorCmd::kSendInitialModel);
   CurWriteDone();
   SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_);
@@ -55,7 +55,7 @@ void MdUpdtCompActor::HandleBeforeSendInitialModel(
 void MdUpdtCompActor::HandleForUpdateModel(
     const ActorMsg& actor_msg,
-    const KernelContext& kernel_ctx) {
+    const KernelCtx& kernel_ctx) {
   if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) {
     CHECK(actor_msg.actor_cmd() == ActorCmd::kStop);
     TODO();
@@ -69,7 +69,7 @@ void MdUpdtCompActor::HandleForUpdateModel(
 void MdUpdtCompActor::ProcessRegstFromMsg(
     std::shared_ptr<RegstWarpper> regst_warpper,
-    const KernelContext& kernel_ctx) {
+    const KernelCtx& kernel_ctx) {
   if (TryUpdtStateAsFromRegstReader(regst_warpper->regst_raw_ptr()) != 0) {
     waiting_model_diff_acc_queue_.push(regst_warpper);
   }
@@ -15,13 +15,13 @@ class MdUpdtCompActor final : public CompActor {
   void ProcessMsg(const ActorMsg&, const ThreadContext&) override;
 private:
-  void HandleBeforeInitializeModel(const ActorMsg&, const KernelContext&);
-  void HandleBeforeSendInitialModel(const ActorMsg&, const KernelContext&);
-  void HandleForUpdateModel(const ActorMsg&, const KernelContext&);
+  void HandleBeforeInitializeModel(const ActorMsg&, const KernelCtx&);
+  void HandleBeforeSendInitialModel(const ActorMsg&, const KernelCtx&);
+  void HandleForUpdateModel(const ActorMsg&, const KernelCtx&);
-  void ProcessRegstFromMsg(std::shared_ptr<RegstWarpper>, const KernelContext&);
+  void ProcessRegstFromMsg(std::shared_ptr<RegstWarpper>, const KernelCtx&);
-  void (MdUpdtCompActor::*cur_handle_)(const ActorMsg&, const KernelContext&);
+  void (MdUpdtCompActor::*cur_handle_)(const ActorMsg&, const KernelCtx&);
   uint64_t model_regst_desc_id_;
   uint64_t model_tmp_regst_desc_id_;
   std::queue<std::shared_ptr<RegstWarpper>> waiting_model_diff_acc_queue_;
@@ -5,19 +5,19 @@ namespace oneflow {
 template<typename floating_point_type>
 void ConvolutionKernel<DeviceType::kCPU, floating_point_type>::Forward(
-    const KernelContext&,
+    const KernelCtx&,
     std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
   TODO();
 }
 template<typename floating_point_type>
 void ConvolutionKernel<DeviceType::kCPU, floating_point_type>::Backward(
-    const KernelContext&,
+    const KernelCtx&,
     std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
   TODO();
 }
 INSTANTIATE_CPU_KERNEL_CLASS(ConvolutionKernel);
-REGISTER_KERNEL(OperatorConf::kConvolutionConf, ConvolutionKernel);
+REGISTER_CPU_KERNEL(OperatorConf::kConvolutionConf, ConvolutionKernel);
 } // namespace oneflow
@@ -5,18 +5,20 @@ namespace oneflow {
 template<typename floating_point_type>
 void ConvolutionKernel<DeviceType::kGPU, floating_point_type>::Forward(
-    const KernelContext&,
+    const KernelCtx&,
     std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
   TODO();
 }
 template<typename floating_point_type>
 void ConvolutionKernel<DeviceType::kGPU, floating_point_type>::Backward(
-    const KernelContext&,
+    const KernelCtx&,
     std::function<Blob*(const std::string&)> bn_in_op2blob_ptr) const {
   TODO();
 }
 INSTANTIATE_GPU_KERNEL_CLASS(ConvolutionKernel);
+REGISTER_GPU_KERNEL(OperatorConf::kConvolutionConf, ConvolutionKernel);
 } // namespace oneflow
@@ -19,8 +19,8 @@ class ConvolutionKernel<DeviceType::kCPU, floating_point_type> final : public Kernel {
   ConvolutionKernel() = default;
   ~ConvolutionKernel() = default;
-  void Forward(const KernelContext&, std::function<Blob*(const std::string&)>) const override;
-  void Backward(const KernelContext&, std::function<Blob*(const std::string&)>) const override;
+  void Forward(const KernelCtx&, std::function<Blob*(const std::string&)>) const override;
+  void Backward(const KernelCtx&, std::function<Blob*(const std::string&)>) const override;
 };
 template<typename floating_point_type>
@@ -30,8 +30,8 @@ class ConvolutionKernel<DeviceType::kGPU, floating_point_type> final : public Kernel {
   ConvolutionKernel() = default;
   ~ConvolutionKernel() = default;
-  void Forward(const KernelContext&, std::function<Blob*(const std::string&)>) const override;
-  void Backward(const KernelContext&, std::function<Blob*(const std::string&)>) const override;
+  void Forward(const KernelCtx&, std::function<Blob*(const std::string&)>) const override;
+  void Backward(const KernelCtx&, std::function<Blob*(const std::string&)>) const override;
 };
 } // namespace oneflow
@@ -9,7 +9,7 @@ void Kernel::InitFromOpProto(const OperatorProto& op_proto) {
 }
 void Kernel::InitModelAndModelTmpBlobs(
-    const KernelContext& ctx,
+    const KernelCtx& ctx,
     std::function<Blob*(const std::string&)> Blob4BnInOp) const {
   TODO();
 }
@@ -9,13 +9,10 @@
 #include "oneflow/core/operator/operator.h"
 #include "oneflow/core/operator/operator_manager.h"
 #include "oneflow/core/operator/operator.pb.h"
+#include "oneflow/core/kernel/kernel_context.h"
 namespace oneflow {
-struct KernelContext {
-  const cudaStream_t* cuda_stream;
-};
 class Kernel {
 public:
   OF_DISALLOW_COPY_AND_MOVE(Kernel);
@@ -24,17 +21,17 @@ class Kernel {
   void InitFromOpProto(const OperatorProto& op_proto);
   void InitModelAndModelTmpBlobs(
-      const KernelContext& ctx,
+      const KernelCtx& ctx,
       std::function<Blob*(const std::string&)> Blob4BnInOp) const;
   // For the Forward / Bp calculation of FwExecGraph and BpExecGraph nodes:
   // the bn_in_op2blob_ptr function yields the input and output blobs, and
   // the Kernel computes the result from the input blobs and fills the output.
   virtual void Forward(
-      const KernelContext& ctx,
+      const KernelCtx& ctx,
       std::function<Blob*(const std::string&)>) const = 0;
   virtual void Backward(
-      const KernelContext& ctx,
+      const KernelCtx& ctx,
       std::function<Blob*(const std::string&)>) const = 0;
   //
@@ -51,7 +48,7 @@ class Kernel {
 };
 using KernelWardFunc = void (Kernel::*)(
-    const KernelContext&, std::function<Blob*(const std::string&)>) const;
+    const KernelCtx&, std::function<Blob*(const std::string&)>) const;
 #define INSTANTIATE_CPU_KERNEL_CLASS(classname) \
   char gInstantiationGuardCPU##classname; \
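KernelWardFunc above is a pointer to a const member function of Kernel, and Actor::WardKernel dispatches through it with (ek.kernel->*ward_func_)(...). For readers unfamiliar with that syntax, here is a minimal self-contained sketch of the same dispatch pattern; Blob and KernelCtx are stubbed out, and the names are illustrative rather than oneflow's actual API:

#include <functional>
#include <iostream>
#include <string>

struct Blob {};       // stand-in for oneflow's Blob
struct KernelCtx {};  // stand-in for the real KernelCtx

struct Kernel {
  void Forward(const KernelCtx&, std::function<Blob*(const std::string&)>) const {
    std::cout << "Forward\n";
  }
  void Backward(const KernelCtx&, std::function<Blob*(const std::string&)>) const {
    std::cout << "Backward\n";
  }
};

// Same shape as the KernelWardFunc alias in kernel.h.
using KernelWardFunc = void (Kernel::*)(
    const KernelCtx&, std::function<Blob*(const std::string&)>) const;

int main() {
  Kernel kernel;
  // A forward actor would store &Kernel::Forward and a backward actor
  // &Kernel::Backward; both then share the single WardKernel call site.
  KernelWardFunc ward_func = &Kernel::Forward;
  (kernel.*ward_func)(KernelCtx{}, [](const std::string&) -> Blob* { return nullptr; });
}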
new file: oneflow/core/kernel/kernel_context.h
+#ifndef ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_
+#define ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_
+#include "oneflow/core/common/util.h"
+#include "oneflow/core/common/channel.h"
+namespace oneflow {
+struct KernelCtx {
+  Channel<std::function<void()>>* cpu_channel;
+  const cudaStream_t* cuda_stream;
+};
+} // namespace oneflow
+#endif // ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_
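The new cpu_channel field suggests that a CPU device thread will consume std::function<void()> tasks from a channel, mirroring how GPU work is enqueued on cuda_stream. Below is a speculative, self-contained sketch of that pattern; the real Channel lives in oneflow/core/common/channel.h and its API may differ from this stand-in:

#include <condition_variable>
#include <functional>
#include <mutex>
#include <queue>
#include <thread>

// Hypothetical stand-in for oneflow's Channel<T>.
template<typename T>
class Channel {
 public:
  void Send(const T& item) {
    std::lock_guard<std::mutex> lock(mtx_);
    queue_.push(item);
    cv_.notify_one();
  }
  T Receive() {
    std::unique_lock<std::mutex> lock(mtx_);
    cv_.wait(lock, [this] { return !queue_.empty(); });
    T item = queue_.front();
    queue_.pop();
    return item;
  }

 private:
  std::mutex mtx_;
  std::condition_variable cv_;
  std::queue<T> queue_;
};

int main() {
  Channel<std::function<void()>> cpu_channel;
  std::thread worker([&] { cpu_channel.Receive()(); });  // run one task

  // An actor would put &cpu_channel into kernel_ctx.cpu_channel; a CPU
  // kernel could then post its computation instead of running it inline.
  cpu_channel.Send([] { /* the kernel's CPU computation */ });
  worker.join();
}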
@@ -67,9 +67,11 @@ struct GpuDoubleKernelRegister {
   }
 };
-#define REGISTER_KERNEL(OpTypeCase, KernelType) \
+#define REGISTER_CPU_KERNEL(OpTypeCase, KernelType) \
   static CpuFloatKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, float>> g_##KernelType##_cpu_float_regst_var; \
-  static CpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, double>> g_##KernelType##_cpu_double_regst_var; \
+  static CpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kCPU, double>> g_##KernelType##_cpu_double_regst_var;
+#define REGISTER_GPU_KERNEL(OpTypeCase, KernelType) \
   static GpuFloatKernelRegister<OpTypeCase, KernelType<DeviceType::kGPU, float>> g_##KernelType##_gpu_float_regst_var; \
   static GpuDoubleKernelRegister<OpTypeCase, KernelType<DeviceType::kGPU, double>> g_##KernelType##_gpu_double_regst_var;
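Each REGISTER_*_KERNEL expansion defines file-scope statics whose constructors run during static initialization, before main(); that is what puts a kernel into the registry for its (device, floating-point type) pair. A rough sketch of one registrar under assumed names; the actual map and creator signature in kernel_manager.h may differ:

#include <functional>
#include <map>
#include <memory>

struct Kernel {};  // stub base class

// Assumed registry: OpTypeCase value -> factory for the matching kernel.
using KernelCreator = std::function<std::unique_ptr<Kernel>()>;
std::map<int, KernelCreator>& CpuFloatKernelRegistry() {
  static std::map<int, KernelCreator> registry;
  return registry;
}

// The real registrar is presumably templated on OperatorConf::OpTypeCase;
// int is used here only to keep the sketch self-contained.
template<int op_type_case, typename KernelType>
struct CpuFloatKernelRegister {
  CpuFloatKernelRegister() {
    CpuFloatKernelRegistry()[op_type_case] = [] {
      return std::unique_ptr<Kernel>(new KernelType);
    };
  }
};

// What one line of a REGISTER_CPU_KERNEL expansion boils down to:
struct ConvolutionKernelCpuFloat : Kernel {};
static CpuFloatKernelRegister<1, ConvolutionKernelCpuFloat>
    g_ConvolutionKernel_cpu_float_regst_var;

int main() { return CpuFloatKernelRegistry().count(1) == 1 ? 0 : 1; }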