diff --git a/oneflow/core/actor/actor.cpp b/oneflow/core/actor/actor.cpp index 7a209da2e3be91071f626b735f63348d4505f1f8..80469c09484cbe3af16683298d113d4a9f8c84ea 100644 --- a/oneflow/core/actor/actor.cpp +++ b/oneflow/core/actor/actor.cpp @@ -44,7 +44,7 @@ void Actor::Init(const TaskProto& task_proto) { } void Actor::WardKernel( - const KernelContext& kernel_ctx, + const KernelCtx& kernel_ctx, std::function(uint64_t)> Regst4RegstDescId) { for (const ExecKernel& ek : exec_kernel_vec_) { (ek.kernel->*ward_func_)(kernel_ctx, [&](const std::string& bn_in_op) { diff --git a/oneflow/core/actor/actor.h b/oneflow/core/actor/actor.h index 63985009776aa7d1614958f59d67d8704991820b..9ef7c4f3267797a17cbdd6ac4b54be6a2d660fd5 100644 --- a/oneflow/core/actor/actor.h +++ b/oneflow/core/actor/actor.h @@ -34,7 +34,7 @@ class Actor { Actor() = default; void WardKernel( - const KernelContext& kernel_ctx, + const KernelCtx& kernel_ctx, std::function(uint64_t)> Regst4RegstDescId); void ForEachProducedRegst(std::function); uint64_t RegstDescId4Name(const std::string& name) const { diff --git a/oneflow/core/actor/boxing_actor.cpp b/oneflow/core/actor/boxing_actor.cpp index a24f3379705fda7a62eb1c3dd9b0910404f4a5a7..cd1f9f16d14c1e0e74ac0b0c75bfa17b2479b852 100644 --- a/oneflow/core/actor/boxing_actor.cpp +++ b/oneflow/core/actor/boxing_actor.cpp @@ -11,7 +11,7 @@ void BoxingActor::Init(const TaskProto& task_proto) { void BoxingActor::ProcessMsg(const ActorMsg& msg, const ThreadContext& thread_ctx) { - KernelContext kernel_ctx; + KernelCtx kernel_ctx; if (TryUpdtStateAsFromRegstReader(msg.regst_warpper()->regst_raw_ptr()) != 0) { std::shared_ptr regst_wp = msg.regst_warpper(); auto waiting_in_regst_it = waiting_in_regst_.find(regst_wp->piece_id()); @@ -36,7 +36,7 @@ void BoxingActor::ProcessMsg(const ActorMsg& msg, } } -void BoxingActor::WardKernelAndSendMsg(const KernelContext& kernel_ctx) { +void BoxingActor::WardKernelAndSendMsg(const KernelCtx& kernel_ctx) { uint64_t piece_id = ready_in_regst_.front().first; WardKernel(kernel_ctx, [this](uint64_t regst_desc_id) -> std::shared_ptr { Regst* regst = GetCurWriteableRegst(regst_desc_id); diff --git a/oneflow/core/actor/boxing_actor.h b/oneflow/core/actor/boxing_actor.h index 7135710f524891c561551a89d84deace3ce6d1f5..af8d52e027af591b12bab76bbbaa093d9f438ca0 100644 --- a/oneflow/core/actor/boxing_actor.h +++ b/oneflow/core/actor/boxing_actor.h @@ -18,7 +18,7 @@ class BoxingActor final : public Actor { using RDescId2RwMap = HashMap>; using RDescId2RwMapPtr = std::unique_ptr; - void WardKernelAndSendMsg(const KernelContext&); + void WardKernelAndSendMsg(const KernelCtx&); // HashMap waiting_in_regst_; diff --git a/oneflow/core/actor/copy_actor.cpp b/oneflow/core/actor/copy_actor.cpp index 56a0ca9f77427b2b1e66072ce3c758d244976d15..8e23f1e364a4b752714948896844f40b37b63b2d 100644 --- a/oneflow/core/actor/copy_actor.cpp +++ b/oneflow/core/actor/copy_actor.cpp @@ -9,7 +9,7 @@ void CopyActor::Init(const TaskProto& task_proto) { } void CopyActor::ProcessMsgWithKernelCtx(const ActorMsg& msg, - const KernelContext& kernel_ctx) { + const KernelCtx& kernel_ctx) { if (TryUpdtStateAsFromRegstReader(msg.regst_warpper()->regst_raw_ptr()) != 0) { waiting_in_regst_.push(std::move(msg.regst_warpper())); } diff --git a/oneflow/core/actor/copy_actor.h b/oneflow/core/actor/copy_actor.h index d12a4848a39d9fc3023d0c7bfc256428f37503f3..b26be0260a859ca2851bff778fa507675c6037c3 100644 --- a/oneflow/core/actor/copy_actor.h +++ b/oneflow/core/actor/copy_actor.h @@ -15,7 +15,7 @@ public: protected: CopyActor() = default; - void ProcessMsgWithKernelCtx(const ActorMsg& msg, const KernelContext& kernel_ctx); + void ProcessMsgWithKernelCtx(const ActorMsg& msg, const KernelCtx& kernel_ctx); private: std::queue> waiting_in_regst_; diff --git a/oneflow/core/actor/copy_comm_net_actor.cpp b/oneflow/core/actor/copy_comm_net_actor.cpp index ee901ec028334d472dd14c27a77577dd383577e6..2e72be96d19afe8c1810941e0e7c9add83826d1c 100644 --- a/oneflow/core/actor/copy_comm_net_actor.cpp +++ b/oneflow/core/actor/copy_comm_net_actor.cpp @@ -6,7 +6,7 @@ namespace oneflow { void CopyCommNetActor::ProcessMsg(const ActorMsg& msg, const ThreadContext&) { - KernelContext kernel_ctx; + KernelCtx kernel_ctx; ProcessMsgWithKernelCtx(msg, kernel_ctx); } diff --git a/oneflow/core/actor/copy_hd_actor.cpp b/oneflow/core/actor/copy_hd_actor.cpp index c09200a8131fac0ade6193357b97ae0870728133..87398e71f1cdc7abd08c0609c251604008fdf008 100644 --- a/oneflow/core/actor/copy_hd_actor.cpp +++ b/oneflow/core/actor/copy_hd_actor.cpp @@ -6,7 +6,7 @@ namespace oneflow { void CopyHdActor::ProcessMsg(const ActorMsg& msg, const ThreadContext& thread_ctx) { - KernelContext kernel_ctx; + KernelCtx kernel_ctx; kernel_ctx.cuda_stream = thread_ctx.copy_hd_cuda_stream; ProcessMsgWithKernelCtx(msg, kernel_ctx); } diff --git a/oneflow/core/actor/fw_data_comp_actor.cpp b/oneflow/core/actor/fw_data_comp_actor.cpp index a8c9aea800cbee9e83cc34fc54f843f4ebca332e..126b8d549e6fbf758a65ea7748a2362c76a3406d 100644 --- a/oneflow/core/actor/fw_data_comp_actor.cpp +++ b/oneflow/core/actor/fw_data_comp_actor.cpp @@ -25,7 +25,7 @@ bool FwDataCompActor::IsReadReady() { void FwDataCompActor::ProcessMsg(const ActorMsg& msg, const ThreadContext& thread_ctx) { - KernelContext kernel_ctx; + KernelCtx kernel_ctx; kernel_ctx.cuda_stream = thread_ctx.compute_cuda_stream; if (msg.msg_type() == ActorMsgType::kCmdMsg) { TODO(); @@ -55,7 +55,7 @@ void FwDataCompActor::ProcessMsg(const ActorMsg& msg, } } -void FwDataCompActor::WardKernelAndSendMsg(const KernelContext& kernel_ctx) { +void FwDataCompActor::WardKernelAndSendMsg(const KernelCtx& kernel_ctx) { CHECK_EQ(in_.front()->piece_id(), expected_piece_id()); ready_in_regst_[in_.front()->regst_desc_id()] = in_.front(); uint64_t piece_id = in_.front()->piece_id(); diff --git a/oneflow/core/actor/fw_data_comp_actor.h b/oneflow/core/actor/fw_data_comp_actor.h index e088a5ba7585e300e1d42f473236487236421d7c..17362eeacec288d36803cd8700d8e397f1793c15 100644 --- a/oneflow/core/actor/fw_data_comp_actor.h +++ b/oneflow/core/actor/fw_data_comp_actor.h @@ -16,7 +16,7 @@ public: private: bool IsReadReady(); - void WardKernelAndSendMsg(const KernelContext&); + void WardKernelAndSendMsg(const KernelCtx&); uint64_t expected_model_version_id_; uint64_t model_regst_desc_id_; diff --git a/oneflow/core/actor/model_update_comp_actor.cpp b/oneflow/core/actor/model_update_comp_actor.cpp index d417b2f488e073c0121296ce5f40acc33075aea8..b78190f59245d573437ed2940d9fe27e6f963f9b 100644 --- a/oneflow/core/actor/model_update_comp_actor.cpp +++ b/oneflow/core/actor/model_update_comp_actor.cpp @@ -12,14 +12,14 @@ void MdUpdtCompActor::Init(const TaskProto& task_proto) { void MdUpdtCompActor::ProcessMsg(const ActorMsg& actor_msg, const ThreadContext& thread_ctx) { - KernelContext kernel_ctx; + KernelCtx kernel_ctx; kernel_ctx.cuda_stream = thread_ctx.compute_cuda_stream; (this->*cur_handle_)(actor_msg, kernel_ctx); } void MdUpdtCompActor::HandleBeforeInitializeModel( const ActorMsg& actor_msg, - const KernelContext& kernel_ctx) { + const KernelCtx& kernel_ctx) { CHECK(actor_msg.actor_cmd() == ActorCmd::kInitializeModel); Regst* model_regst = GetCurWriteableRegst(model_regst_desc_id_); model_regst->set_model_version_id(0); @@ -46,7 +46,7 @@ void MdUpdtCompActor::HandleBeforeInitializeModel( void MdUpdtCompActor::HandleBeforeSendInitialModel( const ActorMsg& actor_msg, - const KernelContext& kernel_ctx) { + const KernelCtx& kernel_ctx) { CHECK(actor_msg.actor_cmd() == ActorCmd::kSendInitialModel); CurWriteDone(); SetReadOnlyForRegstDescId(model_tmp_regst_desc_id_); @@ -55,7 +55,7 @@ void MdUpdtCompActor::HandleBeforeSendInitialModel( void MdUpdtCompActor::HandleForUpdateModel( const ActorMsg& actor_msg, - const KernelContext& kernel_ctx) { + const KernelCtx& kernel_ctx) { if (actor_msg.msg_type() == ActorMsgType::kCmdMsg) { CHECK(actor_msg.actor_cmd() == ActorCmd::kStop); TODO(); @@ -69,7 +69,7 @@ void MdUpdtCompActor::HandleForUpdateModel( void MdUpdtCompActor::ProcessRegstFromMsg( std::shared_ptr regst_warpper, - const KernelContext& kernel_ctx) { + const KernelCtx& kernel_ctx) { if (TryUpdtStateAsFromRegstReader(regst_warpper->regst_raw_ptr()) != 0) { waiting_model_diff_acc_queue_.push(regst_warpper); } diff --git a/oneflow/core/actor/model_update_comp_actor.h b/oneflow/core/actor/model_update_comp_actor.h index 704810637b56e9b9e219bb1d4f7c0d059673ddab..579765b961a86560b6b5a7e9843139f28cab4b97 100644 --- a/oneflow/core/actor/model_update_comp_actor.h +++ b/oneflow/core/actor/model_update_comp_actor.h @@ -15,13 +15,13 @@ class MdUpdtCompActor final : public CompActor { void ProcessMsg(const ActorMsg&, const ThreadContext&) override; private: - void HandleBeforeInitializeModel(const ActorMsg&, const KernelContext&); - void HandleBeforeSendInitialModel(const ActorMsg&, const KernelContext&); - void HandleForUpdateModel(const ActorMsg&, const KernelContext&); + void HandleBeforeInitializeModel(const ActorMsg&, const KernelCtx&); + void HandleBeforeSendInitialModel(const ActorMsg&, const KernelCtx&); + void HandleForUpdateModel(const ActorMsg&, const KernelCtx&); - void ProcessRegstFromMsg(std::shared_ptr, const KernelContext&); + void ProcessRegstFromMsg(std::shared_ptr, const KernelCtx&); - void (MdUpdtCompActor::*cur_handle_)(const ActorMsg&, const KernelContext&); + void (MdUpdtCompActor::*cur_handle_)(const ActorMsg&, const KernelCtx&); uint64_t model_regst_desc_id_; uint64_t model_tmp_regst_desc_id_; std::queue> waiting_model_diff_acc_queue_; diff --git a/oneflow/core/kernel/convolution_kernel.cpp b/oneflow/core/kernel/convolution_kernel.cpp index 463231bc4bb0862885bb8c563e50e342fdb3e21d..017ba0bd93adba281af59c4e1760e31f3e7e7006 100644 --- a/oneflow/core/kernel/convolution_kernel.cpp +++ b/oneflow/core/kernel/convolution_kernel.cpp @@ -5,19 +5,19 @@ namespace oneflow { template void ConvolutionKernel::Forward( - const KernelContext&, + const KernelCtx&, std::function bn_in_op2blob_ptr) const { TODO(); } template void ConvolutionKernel::Backward( - const KernelContext&, + const KernelCtx&, std::function bn_in_op2blob_ptr) const { TODO(); } INSTANTIATE_CPU_KERNEL_CLASS(ConvolutionKernel); -REGISTER_KERNEL(OperatorConf::kConvolutionConf, ConvolutionKernel); +REGISTER_CPU_KERNEL(OperatorConf::kConvolutionConf, ConvolutionKernel); } // namespace oneflow diff --git a/oneflow/core/kernel/convolution_kernel.cu b/oneflow/core/kernel/convolution_kernel.cu index 3f6d0a837ed9ada82cab4f529cf0a205ef179126..7febef1dfde4eb0ba62ede76094dd5076bbcbe7c 100644 --- a/oneflow/core/kernel/convolution_kernel.cu +++ b/oneflow/core/kernel/convolution_kernel.cu @@ -5,18 +5,20 @@ namespace oneflow { template void ConvolutionKernel::Forward( - const KernelContext&, + const KernelCtx&, std::function bn_in_op2blob_ptr) const { TODO(); } template void ConvolutionKernel::Backward( - const KernelContext&, + const KernelCtx&, std::function bn_in_op2blob_ptr) const { TODO(); } INSTANTIATE_GPU_KERNEL_CLASS(ConvolutionKernel); +REGISTER_GPU_KERNEL(OperatorConf::kConvolutionConf, ConvolutionKernel); + } // namespace oneflow diff --git a/oneflow/core/kernel/convolution_kernel.h b/oneflow/core/kernel/convolution_kernel.h index d4936b3b8f608158113f6b5ddb4ce272974bf68e..9e2f6c55bf195fd8d829808718c36de8bc806c5b 100644 --- a/oneflow/core/kernel/convolution_kernel.h +++ b/oneflow/core/kernel/convolution_kernel.h @@ -19,8 +19,8 @@ class ConvolutionKernel final : public Ke ConvolutionKernel() = default; ~ConvolutionKernel() = default; - void Forward(const KernelContext&, std::function) const override; - void Backward(const KernelContext&, std::function) const override; + void Forward(const KernelCtx&, std::function) const override; + void Backward(const KernelCtx&, std::function) const override; }; template @@ -30,8 +30,8 @@ class ConvolutionKernel final : public Ke ConvolutionKernel() = default; ~ConvolutionKernel() = default; - void Forward(const KernelContext&, std::function) const override; - void Backward(const KernelContext&, std::function) const override; + void Forward(const KernelCtx&, std::function) const override; + void Backward(const KernelCtx&, std::function) const override; }; } // namespace oneflow diff --git a/oneflow/core/kernel/kernel.cpp b/oneflow/core/kernel/kernel.cpp index 507e5b15238cc56845b42827dadbc628845523eb..c1e20defc78a165fd88bcbc295aed7495e51174e 100644 --- a/oneflow/core/kernel/kernel.cpp +++ b/oneflow/core/kernel/kernel.cpp @@ -9,7 +9,7 @@ void Kernel::InitFromOpProto(const OperatorProto& op_proto) { } void Kernel::InitModelAndModelTmpBlobs( - const KernelContext& ctx, + const KernelCtx& ctx, std::function Blob4BnInOp) const { TODO(); } diff --git a/oneflow/core/kernel/kernel.h b/oneflow/core/kernel/kernel.h index 5548ce2e27144d2dce28732d6535c7f269836c42..12d888e71b3194d9f075cbc9ad8c4b38d7e0d1c6 100644 --- a/oneflow/core/kernel/kernel.h +++ b/oneflow/core/kernel/kernel.h @@ -9,13 +9,10 @@ #include "oneflow/core/operator/operator.h" #include "oneflow/core/operator/operator_manager.h" #include "oneflow/core/operator/operator.pb.h" +#include "oneflow/core/kernel/kernel_context.h" namespace oneflow { -struct KernelContext { - const cudaStream_t* cuda_stream; -}; - class Kernel { public: OF_DISALLOW_COPY_AND_MOVE(Kernel); @@ -24,17 +21,17 @@ class Kernel { void InitFromOpProto(const OperatorProto& op_proto); void InitModelAndModelTmpBlobs( - const KernelContext& ctx, + const KernelCtx& ctx, std::function Blob4BnInOp) const; // for Forward / Bp Calculation in FwExecGragh node and BpExecGragh node // through bn_in_op2blob_ptr function get the input blob and output blob // the Kernel will using the input blob calculate the result and fill output virtual void Forward( - const KernelContext& ctx, + const KernelCtx& ctx, std::function) const = 0; virtual void Backward( - const KernelContext& ctx, + const KernelCtx& ctx, std::function) const = 0; // @@ -51,7 +48,7 @@ class Kernel { }; using KernelWardFunc = void (Kernel::*)( - const KernelContext&, std::function) const; + const KernelCtx&, std::function) const; #define INSTANTIATE_CPU_KERNEL_CLASS(classname) \ char gInstantiationGuardCPU##classname; \ diff --git a/oneflow/core/kernel/kernel_context.h b/oneflow/core/kernel/kernel_context.h new file mode 100644 index 0000000000000000000000000000000000000000..f567fa434680584bfc854733a6017fc94d3242d3 --- /dev/null +++ b/oneflow/core/kernel/kernel_context.h @@ -0,0 +1,16 @@ +#ifndef ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_ +#define ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_ + +#include "oneflow/core/common/util.h" +#include "oneflow/core/common/channel.h" + +namespace oneflow { + +struct KernelCtx { + Channel>* cpu_channel; + const cudaStream_t* cuda_stream; +}; + +} // namespace oneflow + +#endif // ONEFLOW_CORE_KERNEL_KERNEL_CONTEXT_H_ diff --git a/oneflow/core/kernel/kernel_manager.h b/oneflow/core/kernel/kernel_manager.h index 70ec21719adf21e3e5e7918bfe1639b154e2bae7..7449813121de8078cc5ac7cad1e00aa11ed1a081 100644 --- a/oneflow/core/kernel/kernel_manager.h +++ b/oneflow/core/kernel/kernel_manager.h @@ -67,9 +67,11 @@ struct GpuDoubleKernelRegister { } }; -#define REGISTER_KERNEL(OpTypeCase, KernelType) \ +#define REGISTER_CPU_KERNEL(OpTypeCase, KernelType) \ static CpuFloatKernelRegister> g_##KernelType##_cpu_float_regst_var; \ - static CpuDoubleKernelRegister> g_##KernelType##_cpu_double_regst_var; \ + static CpuDoubleKernelRegister> g_##KernelType##_cpu_double_regst_var; + +#define REGISTER_GPU_KERNEL(OpTypeCase, KernelType) \ static GpuFloatKernelRegister> g_##KernelType##_gpu_float_regst_var; \ static GpuDoubleKernelRegister> g_##KernelType##_gpu_double_regst_var;