Commit 03db51d7, authored by willzhang4a58

first auto format

Parent commit: 3fbb2567
......@@ -28,8 +28,8 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
}
// name2regst_desc_id_
for (const auto& pair : task_proto.produced_regst_desc()) {
CHECK(name2regst_desc_id_.emplace(pair.first,
pair.second.regst_desc_id()).second);
CHECK(name2regst_desc_id_.emplace(pair.first, pair.second.regst_desc_id())
.second);
}
for (const auto& pair : task_proto.subscribed_regst_desc_id()) {
CHECK(name2regst_desc_id_.emplace(pair.first, pair.second).second);
......@@ -48,9 +48,7 @@ void Actor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
// Looks up the regst_desc_id registered under `name` (produced or
// subscribed). Returns -1 when the name is unknown.
int64_t Actor::RegstDescId4Name(const std::string& name) const {
  auto find_it = name2regst_desc_id_.find(name);
  if (find_it != name2regst_desc_id_.end()) { return find_it->second; }
  return -1;
}
......@@ -61,7 +59,8 @@ KernelCtx Actor::GenDefaultKernelCtx() const {
}
int Actor::HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
if (total_reading_cnt_ == 0) {
msg_handle_ = nullptr;
return 1;
......@@ -92,7 +91,8 @@ void Actor::AsyncSendReadableRegstMsg() {
ActorMsgBus::Singleton().SendMsg(std::move(msg));
}
});
produced_regst2reading_cnt_.at(regst) = regst->subscribers_actor_id().size();
produced_regst2reading_cnt_.at(regst) =
regst->subscribers_actor_id().size();
total_reading_cnt_ += regst->subscribers_actor_id().size();
if (!regst->subscribers_actor_id().empty()) { pair.second.pop(); }
if (pair.second.empty()) { writeable_produced_regst_desc_num_ -= 1; }
......@@ -121,13 +121,11 @@ void Actor::AsyncDo(std::function<void()> func) {
device_ctx_->AddCallBack(func);
}
// Returns the regst wrapped by `wp` to its producer actor. The send is
// queued through AsyncDo so it happens after previously queued device
// work completes; `msg` is captured by copy for that reason.
void Actor::AsyncSendRegstMsgToProducer(
    const std::shared_ptr<RegstWarpper>& wp) {
  ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(wp->producer_actor_id(),
                                                   wp->regst_raw_ptr());
  AsyncDo([msg]() { ActorMsgBus::Singleton().SendMsg(msg); });
}
int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
......@@ -139,9 +137,7 @@ int Actor::TryUpdtStateAsProducedRegst(Regst* regst) {
if (reading_cnt_it->second != 0) { return 0; }
auto writeable_it = writeable_produced_regst_.find(regst->regst_desc_id());
if (writeable_it == writeable_produced_regst_.end()) { return 0; }
if (writeable_it->second.empty()) {
writeable_produced_regst_desc_num_ += 1;
}
if (writeable_it->second.empty()) { writeable_produced_regst_desc_num_ += 1; }
writeable_it->second.push(regst);
return 0;
}
......@@ -166,4 +162,4 @@ bool Actor::IsWriteReady() {
return writeable_produced_regst_desc_num_ == writeable_produced_regst_.size();
}
} // namespace oneflow
} // namespace oneflow
#ifndef ONEFLOW_CORE_ACTOR_ACTOR_H_
#define ONEFLOW_CORE_ACTOR_ACTOR_H_
#include "oneflow/core/common/cuda_stream_handle.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/actor/cpu_device_context.h"
#include "oneflow/core/actor/cuda_device_context.h"
#include "oneflow/core/common/cuda_stream_handle.h"
#include "oneflow/core/job/task.pb.h"
#include "oneflow/core/actor/actor_message_bus.h"
#include "oneflow/core/kernel/kernel_context.h"
#include "oneflow/core/kernel/kernel_manager.h"
#include "oneflow/core/persistence/snapshot_manager.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/register/register_manager.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/thread/thread_context.h"
#include "oneflow/core/persistence/snapshot_manager.h"
namespace oneflow {
......@@ -24,12 +24,10 @@ class Actor {
virtual void Init(const TaskProto&, const ThreadCtx&) = 0;
// 1: success, and actor finish
// 0: success, and actor not finish
// Forwards `msg` to the currently installed message handle.
int ProcessMsg(const ActorMsg& msg) { return (this->*msg_handle_)(msg); }
int64_t actor_id() const { return actor_id_; }
protected:
struct ExecKernel {
const Kernel* kernel;
......@@ -45,11 +43,11 @@ class Actor {
// Msg Handle
using MsgHandle = int (Actor::*)(const ActorMsg&);
void set_msg_handle(MsgHandle val) { msg_handle_ = val; }
// Switches the actor's message handle to `val`, logging the transition
// (with the handle's source name via #val) so handler changes can be
// traced per actor id.
#define OF_SET_MSG_HANDLE(val)                                     \
  do {                                                             \
    LOG(INFO) << "Actor " << actor_id() << " switch to " << #val;  \
    set_msg_handle(static_cast<MsgHandle>(val));                   \
  } while (0)
// Common Handles
int HandleWaitUntilReadingCntEqualZero(const ActorMsg& msg);
......@@ -80,22 +78,23 @@ class Actor {
int64_t actor_id_;
KernelWardFunc ward_func_;
std::vector<ExecKernel> exec_kernel_vec_;
HashMap<int64_t, std::vector<std::unique_ptr<Regst>>> produced_regsts_; // <regst_desc_id, regst>
HashMap<int64_t, std::vector<std::unique_ptr<Regst>>>
produced_regsts_; // <regst_desc_id, regst>
HashMap<std::string, int64_t> name2regst_desc_id_;
std::unique_ptr<DeviceCtx> device_ctx_;
MsgHandle msg_handle_;
// Status of Produced Registers
int64_t expected_piece_id_;
HashMap<int64_t, std::queue<Regst*>> writeable_produced_regst_; // <regst_desc_id, regst>
HashMap<int64_t, std::queue<Regst*>>
writeable_produced_regst_; // <regst_desc_id, regst>
int64_t writeable_produced_regst_desc_num_;
HashMap<Regst*, int64_t> produced_regst2reading_cnt_;
int64_t total_reading_cnt_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_ACTOR_H_
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/register/remote_register_warpper.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/runtime_context.h"
#include "oneflow/core/register/local_register_warpper.h"
#include "oneflow/core/register/remote_register_warpper.h"
namespace oneflow {
OF_DEFINE_ENUM_TO_OSTREAM_FUNC(ActorCmd);
OF_DEFINE_ENUM_TO_OSTREAM_FUNC(ActorMsgType);
// A default-constructed message carries an invalid destination id (-1)
// until set_dst_actor_id is called.
ActorMsg::ActorMsg() { dst_actor_id_ = -1; }
ActorMsg ActorMsg::BuildReadableRegstMsg(int64_t reader_actor_id,
Regst* regst_raw_ptr) {
......@@ -36,4 +34,4 @@ ActorMsg ActorMsg::BuildRegstMsgToProducer(int64_t writer_actor_id,
return msg;
}
} // namespace oneflow
} // namespace oneflow
......@@ -7,18 +7,15 @@
namespace oneflow {
// Commands an actor can receive besides regst messages.
enum class ActorCmd {
  kInitializeModel = 0,  // MdUpdt Actor
  kSendInitialModel,     // MdUpdt Actor
  kEORD,   // End Of Register Desc, All Actor except Source Actor
  kStart   // Source Actor
};
OF_DECLARE_ENUM_TO_OSTREAM_FUNC(ActorCmd);
// Discriminates the payload carried by an ActorMsg.
enum class ActorMsgType { kRegstMsg = 0, kCmdMsg };
OF_DECLARE_ENUM_TO_OSTREAM_FUNC(ActorMsgType);
......@@ -43,9 +40,7 @@ class ActorMsg final {
return actor_cmd_;
}
// Setters
// Sets the id of the actor this message is addressed to.
void set_dst_actor_id(int64_t val) { dst_actor_id_ = val; }
void set_regst_warpper(std::shared_ptr<RegstWarpper> val) {
msg_type_ = ActorMsgType::kRegstMsg;
regst_warpper_ = val;
......@@ -54,17 +49,15 @@ class ActorMsg final {
msg_type_ = ActorMsgType::kCmdMsg;
actor_cmd_ = val;
}
private:
private:
int64_t dst_actor_id_;
ActorMsgType msg_type_;
std::shared_ptr<RegstWarpper> regst_warpper_;
ActorCmd actor_cmd_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
#endif // ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_H_
......@@ -7,10 +7,10 @@ namespace oneflow {
void ActorMsgBus::SendMsg(const ActorMsg& msg) {
int64_t dst_machine_id =
IDMgr::Singleton().MachineId4ActorId(msg.dst_actor_id());
IDMgr::Singleton().MachineId4ActorId(msg.dst_actor_id());
if (dst_machine_id == RuntimeCtx::Singleton().this_machine_id()) {
int64_t thrd_loc_id =
IDMgr::Singleton().ThrdLocId4ActorId(msg.dst_actor_id());
IDMgr::Singleton().ThrdLocId4ActorId(msg.dst_actor_id());
ThreadMgr::Singleton().GetThrd(thrd_loc_id)->GetMsgChannelPtr()->Send(msg);
} else {
TODO();
......
......@@ -2,8 +2,8 @@
#define ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_BUS_H_
#include <stdint.h>
#include "oneflow/core/common/util.h"
#include "oneflow/core/actor/actor_message.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
......@@ -22,4 +22,4 @@ class ActorMsgBus final {
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_BUS_H_
#endif // ONEFLOW_CORE_ACTOR_ACTOR_MESSAGE_BUS_H_
......@@ -5,20 +5,21 @@ namespace oneflow {
namespace {
// Hash functor for (int, bool) pairs: folds the bool into the low bit
// of the left-shifted int before applying std::hash<int>.
struct PairHash {
  std::size_t operator()(const std::pair<int, bool>& p) const {
    return std::hash<int>{}((p.first << 1) | (static_cast<int>(p.second)));
  }
};
using ActorTypePair = std::pair<TaskType, bool>;
// Maps (task_type, is_forward) to a factory producing the matching actor.
using ActorCreatorMap =
    HashMap<ActorTypePair, std::function<Actor*()>, PairHash>;
// Accessor for the process-wide actor creator registry. The
// function-local static avoids static-initialization-order issues
// between registration translation units.
ActorCreatorMap& ActorType2Creator() {
  static ActorCreatorMap obj;
  return obj;
}
}
} // namespace
void AddActorCreator(TaskType task_type, bool is_forward,
std::function<Actor*()> creator) {
......
......@@ -20,8 +20,9 @@ struct ActorRegister {
};
// Defines a file-scope ActorRegister whose constructor registers
// ActorType as the creator for (TaskType, IsForward).
#define REGISTER_ACTOR(TaskType, IsForward, ActorType)  \
  static ActorRegister<TaskType, IsForward, ActorType>  \
      g_##ActorType##_##IsForward##_register_var;
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_ACTOR_REGISTRY_H_
#endif // ONEFLOW_CORE_ACTOR_ACTOR_REGISTRY_H_
......@@ -4,7 +4,8 @@
namespace oneflow {
void BoxingActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
void BoxingActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
num_of_subscribed_regsts_ = task_proto.subscribed_regst_desc_id().size();
num_of_read_empty_ = num_of_subscribed_regsts_;
......@@ -22,7 +23,8 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
OF_SET_MSG_HANDLE(&BoxingActor::HandleBoxingWhenNoReadableRegstMsg);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
std::shared_ptr<RegstWarpper> regst_wp = msg.regst_warpper();
num_of_read_empty_ -= read_regst_[regst_wp->regst_desc_id()].empty();
read_regst_.at(regst_wp->regst_desc_id()).push(regst_wp);
......@@ -35,7 +37,8 @@ int BoxingActor::HandleBoxing(const ActorMsg& msg) {
}
int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
if (num_of_read_empty_ == num_of_subscribed_regsts_) {
AsyncSendEORDMsgForAllProducedRegstDesc();
......@@ -49,25 +52,25 @@ int BoxingActor::HandleBoxingWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
return 0;
}
void BoxingActor::TryWardKernelAndSendMsg() {
if (!num_of_read_empty_ && IsWriteReady()) {
int64_t piece_id = expected_piece_id();
for (const auto& pair : read_regst_) {
CHECK_EQ(pair.second.front()->piece_id(), piece_id);
}
AsyncWardKernel(GenDefaultKernelCtx(),
AsyncWardKernel(
GenDefaultKernelCtx(),
[this](int64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return read_regst_.at(regst_desc_id).front();
} else {
return std::make_shared<LocalRegstWarpper> (regst);
}
});
ForEachCurWriteableRegst([piece_id](Regst* regst) {
regst->set_piece_id(piece_id);
});
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return read_regst_.at(regst_desc_id).front();
} else {
return std::make_shared<LocalRegstWarpper>(regst);
}
});
ForEachCurWriteableRegst(
[piece_id](Regst* regst) { regst->set_piece_id(piece_id); });
AsyncSendReadableRegstMsg();
for (auto& pair : read_regst_) {
AsyncSendRegstMsgToProducer(pair.second.front());
......
......@@ -24,9 +24,8 @@ class BoxingActor final : public Actor {
int num_of_eord_;
// <regst_desc_id, queue<regst_wp>>
HashMap<int64_t, std::queue<std::shared_ptr<RegstWarpper>>> read_regst_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_BOXING_ACTOR_H_
......@@ -4,7 +4,8 @@
namespace oneflow {
void BpDataCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
void BpDataCompActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
model_regst_desc_id_ = RegstDescId4Name("model");
model_tmp_regst_desc_id_ = RegstDescId4Name("model_tmp");
......@@ -24,11 +25,11 @@ void BpDataCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_
}
bool BpDataCompActor::IsReadReady() {
if (num_of_read_empty_) {
return false;
}
if (read_regst_.at(model_regst_desc_id_).front()->model_version_id() !=
read_regst_.at(activation_regst_desc_id_).front()->model_version_id()) {
if (num_of_read_empty_) { return false; }
if (read_regst_.at(model_regst_desc_id_).front()->model_version_id()
!= read_regst_.at(activation_regst_desc_id_)
.front()
->model_version_id()) {
AsyncSendRegstMsgToProducer(read_regst_.at(model_regst_desc_id_).front());
read_regst_.at(model_regst_desc_id_).pop();
num_of_read_empty_ += read_regst_.at(model_regst_desc_id_).empty();
......@@ -44,7 +45,8 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
OF_SET_MSG_HANDLE(&BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
std::shared_ptr<RegstWarpper> regst_wp = msg.regst_warpper();
if (regst_wp->regst_desc_id() == model_tmp_regst_desc_id_) {
CHECK(read_regst_.find(model_tmp_regst_desc_id_) == read_regst_.end());
......@@ -62,14 +64,16 @@ int BpDataCompActor::HandleBpComp(const ActorMsg& msg) {
}
int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
if (read_regst_.at(activation_regst_desc_id_).empty()) {
while (!read_regst_.at(model_regst_desc_id_).empty()) {
AsyncSendRegstMsgToProducer(read_regst_.at(model_regst_desc_id_).front());
read_regst_.at(model_regst_desc_id_).pop();
}
AsyncSendRegstMsgToProducer(read_regst_.at(model_tmp_regst_desc_id_).front());
AsyncSendRegstMsgToProducer(
read_regst_.at(model_tmp_regst_desc_id_).front());
read_regst_.at(model_tmp_regst_desc_id_).pop();
AsyncSendEORDMsgForAllProducedRegstDesc();
num_of_read_empty_ = 6;
......@@ -83,33 +87,40 @@ int BpDataCompActor::HandleBpCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
return 0;
}
void BpDataCompActor::TryWardKernelAndSendMsg() {
while (IsReadReady() && IsWriteReady()) {
int64_t cur_model = read_regst_.at(model_regst_desc_id_).front()->model_version_id();
int64_t cur_model =
read_regst_.at(model_regst_desc_id_).front()->model_version_id();
int64_t piece_id = expected_piece_id();
CHECK_EQ(cur_model, read_regst_.at(activation_regst_desc_id_).front()->model_version_id());
CHECK_EQ(cur_model, read_regst_.at(data_tmp_regst_desc_id_).front()->model_version_id());
CHECK_EQ(
cur_model,
read_regst_.at(activation_regst_desc_id_).front()->model_version_id());
CHECK_EQ(
cur_model,
read_regst_.at(data_tmp_regst_desc_id_).front()->model_version_id());
for (const auto& pair : read_regst_) {
if (pair.first != model_regst_desc_id_ && pair.first != model_tmp_regst_desc_id_) {
if (pair.first != model_regst_desc_id_
&& pair.first != model_tmp_regst_desc_id_) {
CHECK_EQ(pair.second.front()->piece_id(), piece_id);
}
}
AsyncWardKernel(GenDefaultKernelCtx(),
AsyncWardKernel(
GenDefaultKernelCtx(),
[this](int64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return read_regst_.at(regst_desc_id).front();
} else {
return std::make_shared<LocalRegstWarpper> (regst);
}
});
ForEachCurWriteableRegst([piece_id](Regst* regst) {
regst->set_piece_id(piece_id);
});
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return read_regst_.at(regst_desc_id).front();
} else {
return std::make_shared<LocalRegstWarpper>(regst);
}
});
ForEachCurWriteableRegst(
[piece_id](Regst* regst) { regst->set_piece_id(piece_id); });
AsyncSendReadableRegstMsg();
for (auto& pair : read_regst_) {
if (pair.first != model_regst_desc_id_ && pair.first != model_tmp_regst_desc_id_) {
if (pair.first != model_regst_desc_id_
&& pair.first != model_tmp_regst_desc_id_) {
AsyncSendRegstMsgToProducer(pair.second.front());
pair.second.pop();
num_of_read_empty_ += pair.second.empty();
......
......@@ -6,14 +6,14 @@
namespace oneflow {
class BpDataCompActor final : public Actor {
public:
public:
OF_DISALLOW_COPY_AND_MOVE(BpDataCompActor);
BpDataCompActor() = default;
~BpDataCompActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
private:
private:
int HandleBpComp(const ActorMsg&);
int HandleBpCompWhenNoReadableRegstMsg(const ActorMsg&);
......@@ -34,4 +34,4 @@ private:
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_BP_DATA_COMP_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_BP_DATA_COMP_ACTOR_H_
......@@ -13,7 +13,8 @@ class CompActor : public Actor {
protected:
CompActor() = default;
virtual void Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) override {
virtual void Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) override {
Actor::Init(task_proto, thread_ctx);
parallel_id_ = task_proto.parallel_id();
}
......@@ -26,9 +27,8 @@ class CompActor : public Actor {
ParallelPolicy parallel_policy_;
int64_t parallel_id_;
int64_t parallel_num_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_COMPUTE_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_COMPUTE_ACTOR_H_
......@@ -4,7 +4,8 @@
namespace oneflow {
void CopyCommNetActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
void CopyCommNetActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
CHECK(thread_ctx.cpu_stream);
mut_device_ctx().reset(new CpuDeviceCtx(thread_ctx.cpu_stream));
......@@ -14,19 +15,23 @@ void CopyCommNetActor::Init(const TaskProto& task_proto, const ThreadCtx& thread
int CopyCommNetActor::HandleCopyCommNet(const ActorMsg& msg) {
if (msg.msg_type() == ActorMsgType::kCmdMsg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg);
OF_SET_MSG_HANDLE(
&CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
auto regst_wp = msg.regst_warpper();
if (TryUpdtStateAsProducedRegst(regst_wp->regst_raw_ptr()) != 0) {
CHECK(piece_id2waiting_in_regst_.emplace(regst_wp->piece_id(), regst_wp).second);
CHECK(piece_id2waiting_in_regst_.emplace(regst_wp->piece_id(), regst_wp)
.second);
}
}
TryWardKernelAndSendMsg();
return 0;
}
int CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
int CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg(
const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
if (piece_id2waiting_in_regst_.empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
......@@ -40,23 +45,23 @@ int CopyCommNetActor::HandleCopyCommNetWhenNoReadableRegstMsg(const ActorMsg& ms
}
return 0;
}
void CopyCommNetActor::TryWardKernelAndSendMsg() {
auto next_regst_it = piece_id2waiting_in_regst_.find(expected_piece_id());
if (next_regst_it == piece_id2waiting_in_regst_.end()) {
return;
}
if (next_regst_it == piece_id2waiting_in_regst_.end()) { return; }
if (IsWriteReady()) {
std::shared_ptr<RegstWarpper> regst_wp = next_regst_it->second;
AsyncWardKernel(GenDefaultKernelCtx(),
[this, &regst_wp](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return regst_wp;
} else {
return std::make_shared<LocalRegstWarpper> (regst);
}
});
AsyncWardKernel(
GenDefaultKernelCtx(),
[this,
&regst_wp](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return regst_wp;
} else {
return std::make_shared<LocalRegstWarpper>(regst);
}
});
ForEachCurWriteableRegst([&regst_wp](Regst* regst) {
regst->set_piece_id(regst_wp->piece_id());
regst->set_model_version_id(regst_wp->model_version_id());
......
......@@ -6,22 +6,21 @@
namespace oneflow {
// Actor for copy-comm-net tasks: buffers incoming regsts by piece id
// and wards the kernel once the expected piece is available.
class CopyCommNetActor final : public Actor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CopyCommNetActor);
  CopyCommNetActor() = default;
  ~CopyCommNetActor() = default;

  void Init(const TaskProto&, const ThreadCtx&) override;

 private:
  int HandleCopyCommNet(const ActorMsg&);
  int HandleCopyCommNetWhenNoReadableRegstMsg(const ActorMsg&);
  void TryWardKernelAndSendMsg();

  // <piece_id, regst_wp> received but not yet processed.
  HashMap<int64_t, std::shared_ptr<RegstWarpper>> piece_id2waiting_in_regst_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_COPY_COMM_NET_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_COPY_COMM_NET_ACTOR_H_
......@@ -4,12 +4,12 @@
namespace oneflow {
void CopyHdActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
void CopyHdActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
Actor::Init(task_proto, thread_ctx);
CHECK(thread_ctx.copy_hd_cuda_stream);
mut_device_ctx().reset(new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream,
nullptr,
nullptr));
mut_device_ctx().reset(
new CudaDeviceCtx(thread_ctx.copy_hd_cuda_stream, nullptr, nullptr));
OF_SET_MSG_HANDLE(&CopyHdActor::HandleCopyHd);
}
......@@ -18,7 +18,8 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&CopyHdActor::HandleCopyHdWhenNoReadableRegstMsg);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
waiting_in_regst_.push(msg.regst_warpper());
}
}
......@@ -27,7 +28,8 @@ int CopyHdActor::HandleCopyHd(const ActorMsg& msg) {
}
int CopyHdActor::HandleCopyHdWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
if (waiting_in_regst_.empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
......@@ -46,16 +48,17 @@ void CopyHdActor::TryWardKernelAndSendMsg() {
if (!waiting_in_regst_.empty() && IsWriteReady()) {
std::shared_ptr<RegstWarpper> regst_wp = waiting_in_regst_.front();
CHECK_EQ(regst_wp->piece_id(), expected_piece_id());
AsyncWardKernel(GenDefaultKernelCtx(),
AsyncWardKernel(
GenDefaultKernelCtx(),
[this](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
CHECK_EQ(regst_desc_id, waiting_in_regst_.front()->regst_desc_id());
return waiting_in_regst_.front();
} else {
return std::make_shared<LocalRegstWarpper> (regst);
}
});
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
CHECK_EQ(regst_desc_id, waiting_in_regst_.front()->regst_desc_id());
return waiting_in_regst_.front();
} else {
return std::make_shared<LocalRegstWarpper>(regst);
}
});
ForEachCurWriteableRegst([&regst_wp](Regst* regst) {
regst->set_piece_id(regst_wp->piece_id());
regst->set_model_version_id(regst_wp->model_version_id());
......
......@@ -6,22 +6,21 @@
namespace oneflow {
// Actor for host<->device copy tasks; incoming regsts are processed
// in FIFO order.
class CopyHdActor final : public Actor {
 public:
  OF_DISALLOW_COPY_AND_MOVE(CopyHdActor);
  CopyHdActor() = default;
  ~CopyHdActor() = default;

  void Init(const TaskProto&, const ThreadCtx&) override;

 private:
  int HandleCopyHd(const ActorMsg&);
  int HandleCopyHdWhenNoReadableRegstMsg(const ActorMsg&);
  void TryWardKernelAndSendMsg();

  // Regsts waiting to be copied, oldest first.
  std::queue<std::shared_ptr<RegstWarpper>> waiting_in_regst_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_COPY_HD_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_COPY_HD_ACTOR_H_
......@@ -10,10 +10,8 @@ class CpuDeviceCtx final : public DeviceCtx {
// OF_DISALLOW_COPY_AND_MOVE(CpuDeviceCtx);
CpuDeviceCtx() = delete;
~CpuDeviceCtx() = default;
// Stores `chan` as the cpu work stream via set_cpu_stream; the channel
// pointer is borrowed, not owned — TODO confirm lifetime at call sites.
CpuDeviceCtx(Channel<std::function<void()>>* chan) { set_cpu_stream(chan); }
void AddCallBack(std::function<void()> callback) const override {
cpu_stream()->Send(callback);
......@@ -22,6 +20,6 @@ class CpuDeviceCtx final : public DeviceCtx {
private:
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_CPU_DEVICE_CONTEXT_H_
#endif // ONEFLOW_CORE_ACTOR_CPU_DEVICE_CONTEXT_H_
......@@ -4,23 +4,21 @@ namespace oneflow {
namespace {
// cudaStreamAddCallback trampoline: checks the stream status, invokes
// the heap-allocated std::function passed through `void_ptr`, then
// deletes it (allocated by CudaDeviceCtx::AddCallBack).
void CUDART_CB CudaCallBackHandle(cudaStream_t, cudaError_t status,
                                  void* void_ptr) {
  CHECK_EQ(status, cudaSuccess);
  auto callback_ptr = static_cast<std::function<void()>*>(void_ptr);
  (*callback_ptr)();
  delete callback_ptr;
}
} // namespace
} // namespace
// Copies `callback_stack` to the heap (the callback may fire after
// this frame returns) and registers it on the cuda stream;
// CudaCallBackHandle frees the copy after invoking it.
void CudaDeviceCtx::AddCallBack(std::function<void()> callback_stack) const {
  auto callback_heap = new std::function<void()>(callback_stack);
  CHECK_EQ(cudaStreamAddCallback(cuda_stream(), &CudaCallBackHandle,
                                 callback_heap, 0),
           cudaSuccess);
}
} // namespace oneflow
} // namespace oneflow
......@@ -24,6 +24,6 @@ class CudaDeviceCtx final : public DeviceCtx {
private:
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_CUDA_DEVICE_CONTEXT_H_
#endif // ONEFLOW_CORE_ACTOR_CUDA_DEVICE_CONTEXT_H_
#ifndef ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
#define ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/channel.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
......@@ -19,32 +19,26 @@ class DeviceCtx {
virtual void AddCallBack(std::function<void()>) const = 0;
protected:
// All stream/handle pointers start null; subclasses set only the ones
// their device type provides.
DeviceCtx()
    : cpu_stream_(nullptr),
      cuda_stream_(nullptr),
      cublas_handle_(nullptr),
      cudnn_handle_(nullptr) {}
// Setters for subclasses to install the streams/handles they own;
// DeviceCtx only stores the pointers.
void set_cpu_stream(Channel<std::function<void()>>* val) {
  cpu_stream_ = val;
}
void set_cuda_stream(const cudaStream_t* val) { cuda_stream_ = val; }
void set_cublas_handle(const cublasHandle_t* val) { cublas_handle_ = val; }
void set_cudnn_handle(const cudnnHandle_t* val) { cudnn_handle_ = val; }
private:
Channel<std::function<void()>>* cpu_stream_;
const cudaStream_t* cuda_stream_;
const cublasHandle_t* cublas_handle_;
const cudnnHandle_t* cudnn_handle_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
#endif // ONEFLOW_CORE_ACTOR_DEVICE_CONTEXT_H_
......@@ -25,24 +25,24 @@ void FwDataCompActor::Init(const TaskProto& task_proto,
kernel_ctx_.other = reinterpret_cast<void*>(parallel_id());
OF_SET_MSG_HANDLE(&FwDataCompActor::WaitToStart);
} else {
num_of_not_eord_ = 1 + (model_regst_desc_id_ != -1)
+ (model_tmp_regst_desc_id_ != -1);
num_of_not_eord_ =
1 + (model_regst_desc_id_ != -1) + (model_tmp_regst_desc_id_ != -1);
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwComp);
}
}
bool FwDataCompActor::IsReadReady() {
if (in_desc_id_ == -1) {
return true;
}
if (in_desc_id_ == -1) { return true; }
if (in_.empty() || (model_regst_desc_id_ != -1 && !model_regst_)
|| (model_tmp_regst_desc_id_ != -1 && !model_tmp_regst_)) {
|| (model_tmp_regst_desc_id_ != -1 && !model_tmp_regst_)) {
return false;
}
if (model_regst_desc_id_ != -1) {
//Ho Q, Cipar J, Cui H, et al. More effective distributed ml via a stale synchronous parallel parameter server
// Ho Q, Cipar J, Cui H, et al. More effective distributed ml via a stale
// synchronous parallel parameter server
int32_t staleness = JobDesc::Singleton().staleness();
int32_t num_of_piece_in_batch = JobDesc::Singleton().num_of_piece_in_batch();
int32_t num_of_piece_in_batch =
JobDesc::Singleton().num_of_piece_in_batch();
int64_t cur_iteration = in_.front()->piece_id() / num_of_piece_in_batch;
int64_t stale_version = cur_iteration - staleness;
return model_regst_->model_version_id() >= stale_version;
......@@ -65,7 +65,8 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
OF_SET_MSG_HANDLE(&FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg);
}
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
std::shared_ptr<RegstWarpper> regst_wp = msg.regst_warpper();
if (regst_wp->regst_desc_id() == model_tmp_regst_desc_id_) {
CHECK(!model_tmp_regst_);
......@@ -73,9 +74,7 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
ready_in_regst_[model_tmp_regst_desc_id_] = regst_wp;
} else if (regst_wp->regst_desc_id() == model_regst_desc_id_) {
CHECK_EQ(regst_wp->model_version_id(), expected_model_version_id_);
if (model_regst_) {
AsyncSendRegstMsgToProducer(model_regst_);
}
if (model_regst_) { AsyncSendRegstMsgToProducer(model_regst_); }
model_regst_ = regst_wp;
ready_in_regst_[model_regst_desc_id_] = regst_wp;
expected_model_version_id_ += 1;
......@@ -89,10 +88,12 @@ int FwDataCompActor::HandleFwComp(const ActorMsg& msg) {
}
int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
int total_piece_num = JobDesc::Singleton().total_piece_num();
if ((in_desc_id_!=-1 && in_.empty()) || expected_piece_id() == total_piece_num) {
if ((in_desc_id_ != -1 && in_.empty())
|| expected_piece_id() == total_piece_num) {
if (model_regst_desc_id_ != -1) {
AsyncSendRegstMsgToProducer(model_regst_);
model_regst_ = nullptr;
......@@ -112,7 +113,7 @@ int FwDataCompActor::HandleFwCompWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
return 0;
}
void FwDataCompActor::TryWardKernelAndSendMsg() {
while (IsReadReady() && IsWriteReady()) {
int64_t piece_id = expected_piece_id();
......@@ -121,18 +122,17 @@ void FwDataCompActor::TryWardKernelAndSendMsg() {
ready_in_regst_[in_.front()->regst_desc_id()] = in_.front();
}
int64_t model_version_id = -1;
if (model_regst_) {
model_version_id = model_regst_->model_version_id();
}
AsyncWardKernel(kernel_ctx_,
if (model_regst_) { model_version_id = model_regst_->model_version_id(); }
AsyncWardKernel(
kernel_ctx_,
[this](int64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return ready_in_regst_.at(regst_desc_id);
} else {
return std::make_shared<LocalRegstWarpper> (regst);
}
});
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
return ready_in_regst_.at(regst_desc_id);
} else {
return std::make_shared<LocalRegstWarpper>(regst);
}
});
ForEachCurWriteableRegst([piece_id, model_version_id](Regst* regst) {
regst->set_piece_id(piece_id);
regst->set_model_version_id(model_version_id);
......
......@@ -6,14 +6,14 @@
namespace oneflow {
class FwDataCompActor final : public CompActor {
public:
public:
OF_DISALLOW_COPY_AND_MOVE(FwDataCompActor);
FwDataCompActor() = default;
~FwDataCompActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
private:
private:
int WaitToStart(const ActorMsg&);
int HandleFwComp(const ActorMsg&);
int HandleFwCompWhenNoReadableRegstMsg(const ActorMsg&);
......@@ -36,4 +36,4 @@ private:
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_FW_DATA_COMP_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_FW_DATA_COMP_ACTOR_H_
......@@ -15,9 +15,8 @@ void MdDiffAccActor::Init(const TaskProto& task_proto,
cuda_handle_.cudnn_handle()));
}
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleMdDiffAcc);
ForEachCurWriteableRegst([this](Regst* regst) {
model_diff_acc_cnt_[regst] = 0;
});
ForEachCurWriteableRegst(
[this](Regst* regst) { model_diff_acc_cnt_[regst] = 0; });
}
int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
......@@ -25,7 +24,8 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
CHECK_EQ(msg.actor_cmd(), ActorCmd::kEORD);
OF_SET_MSG_HANDLE(&MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg);
} else if (msg.msg_type() == ActorMsgType::kRegstMsg) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()) != 0) {
if (TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr())
!= 0) {
waiting_in_regst_.push(msg.regst_warpper());
}
}
......@@ -34,7 +34,8 @@ int MdDiffAccActor::HandleMdDiffAcc(const ActorMsg& msg) {
}
int MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg& msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(TryUpdtStateAsProducedRegst(msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
if (waiting_in_regst_.empty()) {
AsyncSendEORDMsgForAllProducedRegstDesc();
......@@ -50,9 +51,7 @@ int MdDiffAccActor::HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg& msg) {
}
void MdDiffAccActor::TryWardKernelAndSendMsg() {
if (waiting_in_regst_.empty() || !IsWriteReady()) {
return;
}
if (waiting_in_regst_.empty() || !IsWriteReady()) { return; }
std::shared_ptr<RegstWarpper> regst_wp = waiting_in_regst_.front();
CHECK_EQ(regst_wp->piece_id(), expected_piece_id());
KernelCtx ctx = GenDefaultKernelCtx();
......@@ -67,15 +66,16 @@ void MdDiffAccActor::TryWardKernelAndSendMsg() {
});
diff_cnt->second = 0;
});
AsyncWardKernel(ctx, [this](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
CHECK_EQ(regst_desc_id, waiting_in_regst_.front()->regst_desc_id());
return waiting_in_regst_.front();
} else {
return std::make_shared<LocalRegstWarpper> (regst);
}
});
AsyncWardKernel(
ctx, [this](uint64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
Regst* regst = GetCurWriteableRegst(regst_desc_id);
if (regst == nullptr) {
CHECK_EQ(regst_desc_id, waiting_in_regst_.front()->regst_desc_id());
return waiting_in_regst_.front();
} else {
return std::make_shared<LocalRegstWarpper>(regst);
}
});
ForEachCurWriteableRegst([this, &regst_wp](Regst* regst) {
regst->set_piece_id(regst_wp->piece_id());
++model_diff_acc_cnt_.at(regst);
......
......@@ -6,14 +6,14 @@
namespace oneflow {
class MdDiffAccActor final : public CompActor {
public:
public:
OF_DISALLOW_COPY_AND_MOVE(MdDiffAccActor);
MdDiffAccActor() = default;
~MdDiffAccActor() = default;
void Init(const TaskProto&, const ThreadCtx&) override;
private:
private:
int HandleMdDiffAcc(const ActorMsg&);
int HandleMdDiffAccWhenNoReadableRegstMsg(const ActorMsg&);
......
......@@ -3,7 +3,8 @@
namespace oneflow {
void MdSaveCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
void MdSaveCompActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
CompActor::Init(task_proto, thread_ctx);
model_regst_desc_id_ = RegstDescId4Name("model");
CHECK(thread_ctx.cpu_stream);
......@@ -18,29 +19,27 @@ int MdSaveCompActor::HandleSaveModel(const ActorMsg& actor_msg) {
} else if (actor_msg.msg_type() == ActorMsgType::kRegstMsg) {
std::shared_ptr<RegstWarpper> regst_warpper = actor_msg.regst_warpper();
int64_t model_version_id = regst_warpper->model_version_id();
int32_t num_of_batches_in_snapshot =
int32_t num_of_batches_in_snapshot =
JobDesc::Singleton().num_of_batches_in_snapshot();
CHECK_GT(num_of_batches_in_snapshot, 0);
if (model_version_id % num_of_batches_in_snapshot == 0) {
int64_t snapshot_id = model_version_id / num_of_batches_in_snapshot;
Snapshot* snapshot = SnapshotMgr::Singleton().GetWriteableSnapshot(snapshot_id);
Snapshot* snapshot =
SnapshotMgr::Singleton().GetWriteableSnapshot(snapshot_id);
KernelCtx kernel_ctx = GenDefaultKernelCtx();
std::tuple<Snapshot*, int64_t> save_ctx = std::make_tuple(snapshot,
parallel_id());
std::tuple<Snapshot*, int64_t> save_ctx =
std::make_tuple(snapshot, parallel_id());
kernel_ctx.other = &save_ctx;
AsyncWardKernel(
kernel_ctx,
[&](int64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
CHECK_EQ(regst_desc_id, model_regst_desc_id_);
return regst_warpper;
});
CHECK_EQ(regst_desc_id, model_regst_desc_id_);
return regst_warpper;
});
}
ActorMsg msg = ActorMsg::BuildRegstMsgToProducer(
regst_warpper->producer_actor_id(),
regst_warpper->regst_raw_ptr());
AsyncDo([msg]() {
ActorMsgBus::Singleton().SendMsg(msg);
});
regst_warpper->producer_actor_id(), regst_warpper->regst_raw_ptr());
AsyncDo([msg]() { ActorMsgBus::Singleton().SendMsg(msg); });
} else {
UNEXPECTED_RUN();
}
......
......@@ -21,4 +21,4 @@ class MdSaveCompActor final : public CompActor {
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_MODEL_SAVE_COMP_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_MODEL_SAVE_COMP_ACTOR_H_
......@@ -4,7 +4,8 @@
namespace oneflow {
void MdUpdtCompActor::Init(const TaskProto& task_proto, const ThreadCtx& thread_ctx) {
void MdUpdtCompActor::Init(const TaskProto& task_proto,
const ThreadCtx& thread_ctx) {
CompActor::Init(task_proto, thread_ctx);
model_regst_desc_id_ = RegstDescId4Name("model");
model_tmp_regst_desc_id_ = RegstDescId4Name("model_tmp");
......@@ -31,25 +32,20 @@ int MdUpdtCompActor::HandleBeforeInitializeModel(const ActorMsg& actor_msg) {
};
model_regst->ForEachLbn(CollectKernelsFromLbn);
model_tmp_regst->ForEachLbn(CollectKernelsFromLbn);
for (const Kernel* kernel : kernels) {
kernel->InitModelAndModelTmpBlobs(
GenDefaultKernelCtx(),
parallel_policy(),
parallel_id(),
parallel_num(),
GenDefaultKernelCtx(), parallel_policy(), parallel_id(), parallel_num(),
SnapshotMgr::Singleton().GetReadableSnapshot(),
[&](const std::string& bn_in_op) {
const std::string& lbn = kernel->Lbn4BnInOp(bn_in_op);
Blob* ret = model_regst->GetBlobPtrFromLbn(lbn);
if (ret == nullptr) { ret = model_tmp_regst->GetBlobPtrFromLbn(lbn); }
CHECK(ret != nullptr);
return ret;
});
const std::string& lbn = kernel->Lbn4BnInOp(bn_in_op);
Blob* ret = model_regst->GetBlobPtrFromLbn(lbn);
if (ret == nullptr) { ret = model_tmp_regst->GetBlobPtrFromLbn(lbn); }
CHECK(ret != nullptr);
return ret;
});
}
AsyncDo([]() {
RuntimeCtx::Singleton().OneModelInitDone();
});
AsyncDo([]() { RuntimeCtx::Singleton().OneModelInitDone(); });
OF_SET_MSG_HANDLE(&MdUpdtCompActor::HandleBeforeSendInitialModel);
return 0;
}
......@@ -86,8 +82,9 @@ int MdUpdtCompActor::HandleUpdateModel(const ActorMsg& actor_msg) {
int MdUpdtCompActor::HandleUpdtModelWhenNoReadableRegstMsg(
const ActorMsg& actor_msg) {
CHECK_EQ(TryUpdtStateAsProducedRegst(
actor_msg.regst_warpper()->regst_raw_ptr()), 0);
CHECK_EQ(
TryUpdtStateAsProducedRegst(actor_msg.regst_warpper()->regst_raw_ptr()),
0);
TryWardKernelAndSendMsg();
if (waiting_model_diff_acc_queue_.empty()) {
AsyncSendEORDMsgToSubscribers(model_regst_desc_id_);
......@@ -109,14 +106,15 @@ void MdUpdtCompActor::TryWardKernelAndSendMsg() {
Regst* model_regst = GetCurWriteableRegst(model_regst_desc_id_);
auto model_wpr = std::make_shared<LocalRegstWarpper>(model_regst);
model_regst->set_model_version_id(next_model_version_id_++);
AsyncWardKernel(GenDefaultKernelCtx(),
AsyncWardKernel(
GenDefaultKernelCtx(),
[&](int64_t regst_desc_id) -> std::shared_ptr<RegstWarpper> {
if (regst_desc_id == model_regst_desc_id_) {
return model_wpr;
} else {
return model_diff_acc_wpr;
}
});
if (regst_desc_id == model_regst_desc_id_) {
return model_wpr;
} else {
return model_diff_acc_wpr;
}
});
AsyncSendReadableRegstMsg();
AsyncSendRegstMsgToProducer(model_diff_acc_wpr);
}
......
......@@ -27,9 +27,8 @@ class MdUpdtCompActor final : public CompActor {
int64_t model_tmp_regst_desc_id_;
std::queue<std::shared_ptr<RegstWarpper>> waiting_model_diff_acc_queue_;
int64_t next_model_version_id_;
};
} // namespace oneflow
#endif // ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMP_ACTOR_H_
#endif // ONEFLOW_CORE_ACTOR_MODEL_UPDATE_COMP_ACTOR_H_
此差异已折叠。
......@@ -5,116 +5,112 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template<>
float cblas_dot<float>(
const int n, const float* x, const int incx,
const float* y, const int incy) {
float cblas_dot<float>(const int n, const float* x, const int incx,
const float* y, const int incy) {
return cblas_sdot(n, x, incx, y, incy);
}
template<>
double cblas_dot<double>(
const int n, const double* x, const int incx,
const double* y, const int incy) {
double cblas_dot<double>(const int n, const double* x, const int incx,
const double* y, const int incy) {
return cblas_ddot(n, x, incx, y, incy);
}
// swap x and y
template<>
void cblas_swap<float>(
const int n, float* x, const int incx, float* y, const int incy) {
void cblas_swap<float>(const int n, float* x, const int incx, float* y,
const int incy) {
cblas_sswap(n, x, incx, y, incy);
}
template<>
void cblas_swap<double>(
const int n, double* x, const int incx, double* y, const int incy) {
void cblas_swap<double>(const int n, double* x, const int incx, double* y,
const int incy) {
cblas_dswap(n, x, incx, y, incy);
}
// copy x into y
template<>
void cblas_copy<float>(
const int n, const float* x, const int incx, float* y, const int incy) {
void cblas_copy<float>(const int n, const float* x, const int incx, float* y,
const int incy) {
cblas_scopy(n, x, incx, y, incy);
}
template<>
void cblas_copy<double>(
const int n, const double* x, const int incx, double* y, const int incy) {
void cblas_copy<double>(const int n, const double* x, const int incx, double* y,
const int incy) {
cblas_dcopy(n, x, incx, y, incy);
}
// y = a*x + y
template<>
void cblas_axpy<float>(
const int n, const float alpha, const float* x, const int incx,
float* y, const int incy) {
void cblas_axpy<float>(const int n, const float alpha, const float* x,
const int incx, float* y, const int incy) {
cblas_saxpy(n, alpha, x, incx, y, incy);
}
template<>
void cblas_axpy<double>(
const int n, const double alpha, const double* x, const int incx,
double* y, const int incy) {
void cblas_axpy<double>(const int n, const double alpha, const double* x,
const int incx, double* y, const int incy) {
cblas_daxpy(n, alpha, x, incx, y, incy);
}
// x = a*x
template<>
void cblas_scal<float>(
const int n, const float alpha, float* x, const int incx) {
void cblas_scal<float>(const int n, const float alpha, float* x,
const int incx) {
cblas_sscal(n, alpha, x, incx);
}
template<>
void cblas_scal<double>(
const int n, const double alpha, double* x, const int incx) {
void cblas_scal<double>(const int n, const double alpha, double* x,
const int incx) {
cblas_dscal(n, alpha, x, incx);
}
// level 2 matrix and vector
// matrix vector multiply
template<>
void cblas_gemv<float>(
const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a,
const int m, const int n, const float alpha,
const float* a, const int lda,
const float* x, const int incx, const float beta,
float* y, const int incy) {
void cblas_gemv<float>(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a, const int m,
const int n, const float alpha, const float* a,
const int lda, const float* x, const int incx,
const float beta, float* y, const int incy) {
cblas_sgemv(order, trans_a, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
template<>
void cblas_gemv<double>(
const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a,
const int m, const int n, const double alpha,
const double* a, const int lda,
const double* x, const int incx, const double beta,
double* y, const int incy) {
void cblas_gemv<double>(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a, const int m,
const int n, const double alpha, const double* a,
const int lda, const double* x, const int incx,
const double beta, double* y, const int incy) {
cblas_dgemv(order, trans_a, m, n, alpha, a, lda, x, incx, beta, y, incy);
}
// matrix matrix multiply
template<>
void cblas_gemm<float>(
const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
const float alpha, const float* a, const int lda,
const float* b, const int ldb, const float beta,
float* c, const int ldc) {
cblas_sgemm(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb,
beta, c, ldc);
void cblas_gemm<float>(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, const int m,
const int n, const int k, const float alpha,
const float* a, const int lda, const float* b,
const int ldb, const float beta, float* c,
const int ldc) {
cblas_sgemm(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c,
ldc);
}
template<>
void cblas_gemm<double>(
const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
const double alpha, const double* a, const int lda,
const double* b, const int ldb, const double beta,
double* c, const int ldc) {
cblas_dgemm(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb,
beta, c, ldc);
void cblas_gemm<double>(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, const int m,
const int n, const int k, const double alpha,
const double* a, const int lda, const double* b,
const int ldb, const double beta, double* c,
const int ldc) {
cblas_dgemm(order, trans_a, trans_b, m, n, k, alpha, a, lda, b, ldb, beta, c,
ldc);
}
} // namespace oneflow
......@@ -10,59 +10,52 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template<typename FloatingPointType>
FloatingPointType cblas_dot(
const int n,
const FloatingPointType* x, const int incx,
const FloatingPointType* y, const int incy);
FloatingPointType cblas_dot(const int n, const FloatingPointType* x,
const int incx, const FloatingPointType* y,
const int incy);
// swap x and y
template<typename FloatingPointType>
void cblas_swap(
const int n,
FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy);
void cblas_swap(const int n, FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy);
// copy x into y
template<typename FloatingPointType>
void cblas_copy(
const int n,
const FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy);
void cblas_copy(const int n, const FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy);
// y = a*x + y
template<typename FloatingPointType>
void cblas_axpy(
const int n,
const FloatingPointType alpha,
const FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy);
void cblas_axpy(const int n, const FloatingPointType alpha,
const FloatingPointType* x, const int incx,
FloatingPointType* y, const int incy);
// x = a*x
template<typename FloatingPointType>
void cblas_scal(
const int n,
const FloatingPointType alpha,
FloatingPointType* x, const int incx);
void cblas_scal(const int n, const FloatingPointType alpha,
FloatingPointType* x, const int incx);
// level 2 matrix and vector
// matrix vector multiply
template<typename FloatingPointType>
void cblas_gemv(
const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a,
const int m, const int n, const FloatingPointType alpha,
const FloatingPointType* a, const int lda,
const FloatingPointType* x, const int incx,
const FloatingPointType beta, FloatingPointType* y, const int incy);
void cblas_gemv(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a, const int m, const int n,
const FloatingPointType alpha, const FloatingPointType* a,
const int lda, const FloatingPointType* x, const int incx,
const FloatingPointType beta, FloatingPointType* y,
const int incy);
// level 3 matrix and matrix
// matrix matrix multiply
template<typename FloatingPointType>
void cblas_gemm(
const enum CBLAS_ORDER order, const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, const int m, const int n, const int k,
const FloatingPointType alpha, const FloatingPointType* a,
const int lda, const FloatingPointType* b, const int ldb,
const FloatingPointType beta, FloatingPointType* c, const int ldc);
void cblas_gemm(const enum CBLAS_ORDER order,
const enum CBLAS_TRANSPOSE trans_a,
const enum CBLAS_TRANSPOSE trans_b, const int m, const int n,
const int k, const FloatingPointType alpha,
const FloatingPointType* a, const int lda,
const FloatingPointType* b, const int ldb,
const FloatingPointType beta, FloatingPointType* c,
const int ldc);
} // namespace oneflow
#endif // ONEFLOW_CORE_BLAS_CBLAS_TEMPLATE_H_
......@@ -5,125 +5,117 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template<>
void cublas_dot<float>(
cublasHandle_t handle, int n, const float* x, int incx, const float* y,
int incy, float* result) {
void cublas_dot<float>(cublasHandle_t handle, int n, const float* x, int incx,
const float* y, int incy, float* result) {
CHECK_EQ(cublasSdot(handle, n, x, incx, y, incy, result),
CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_dot<double>(
cublasHandle_t handle, int n, const double* x, int incx, const double* y,
int incy, double* result) {
void cublas_dot<double>(cublasHandle_t handle, int n, const double* x, int incx,
const double* y, int incy, double* result) {
CHECK_EQ(cublasDdot(handle, n, x, incx, y, incy, result),
CUBLAS_STATUS_SUCCESS);
}
// swap x and y
template<>
void cublas_swap<float>(
cublasHandle_t handle, int n, float* x, int incx, float* y, int incy) {
void cublas_swap<float>(cublasHandle_t handle, int n, float* x, int incx,
float* y, int incy) {
CHECK_EQ(cublasSswap(handle, n, x, incx, y, incy), CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_swap<double>(
cublasHandle_t handle, int n, double* x, int incx, double* y, int incy) {
void cublas_swap<double>(cublasHandle_t handle, int n, double* x, int incx,
double* y, int incy) {
CHECK_EQ(cublasDswap(handle, n, x, incx, y, incy), CUBLAS_STATUS_SUCCESS);
}
// copy x into y
template<>
void cublas_copy<float>(
cublasHandle_t handle, int n, const float* x, int incx,
float* y, int incy) {
void cublas_copy<float>(cublasHandle_t handle, int n, const float* x, int incx,
float* y, int incy) {
CHECK_EQ(cublasScopy(handle, n, x, incx, y, incy), CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_copy<double>(
cublasHandle_t handle, int n, const double* x, int incx,
double* y, int incy) {
void cublas_copy<double>(cublasHandle_t handle, int n, const double* x,
int incx, double* y, int incy) {
CHECK_EQ(cublasDcopy(handle, n, x, incx, y, incy), CUBLAS_STATUS_SUCCESS);
}
// y = a*x + y
template<>
void cublas_axpy<float>(
cublasHandle_t handle, int n, const float* alpha, const float* x,
const int incx, float* y, const int incy) {
void cublas_axpy<float>(cublasHandle_t handle, int n, const float* alpha,
const float* x, const int incx, float* y,
const int incy) {
CHECK_EQ(cublasSaxpy(handle, n, alpha, x, incx, y, incy),
CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_axpy<double>(
cublasHandle_t handle, int n, const double* alpha,
const double* x, const int incx,
double* y, int incy) {
void cublas_axpy<double>(cublasHandle_t handle, int n, const double* alpha,
const double* x, const int incx, double* y, int incy) {
CHECK_EQ(cublasDaxpy(handle, n, alpha, x, incx, y, incy),
CUBLAS_STATUS_SUCCESS);
}
// x = a*x
template<>
void cublas_scal<float>(
cublasHandle_t handle, int n, const float* alpha, float* x, int incx) {
void cublas_scal<float>(cublasHandle_t handle, int n, const float* alpha,
float* x, int incx) {
CHECK_EQ(cublasSscal(handle, n, alpha, x, incx), CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_scal<double>(
cublasHandle_t handle, int n, const double* alpha, double* x, int incx) {
void cublas_scal<double>(cublasHandle_t handle, int n, const double* alpha,
double* x, int incx) {
CHECK_EQ(cublasDscal(handle, n, alpha, x, incx), CUBLAS_STATUS_SUCCESS);
}
// level 2 matrix and vector
// matrix vector multiply
template<>
void cublas_gemv<float>(
cublasHandle_t handle, cublasOperation_t trans, int m, int n,
const float* alpha, const float* a, int lda, const float* x, int incx,
const float* beta, float* y, int incy) {
CHECK_EQ(cublasSgemv(
handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy),
CUBLAS_STATUS_SUCCESS);
void cublas_gemv<float>(cublasHandle_t handle, cublasOperation_t trans, int m,
int n, const float* alpha, const float* a, int lda,
const float* x, int incx, const float* beta, float* y,
int incy) {
CHECK_EQ(
cublasSgemv(handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy),
CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_gemv<double>(
cublasHandle_t handle, cublasOperation_t trans, int m, int n,
const double* alpha, const double* a, int lda, const double* x, int incx,
const double* beta, double* y, int incy) {
CHECK_EQ(cublasDgemv(
handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy),
CUBLAS_STATUS_SUCCESS);
void cublas_gemv<double>(cublasHandle_t handle, cublasOperation_t trans, int m,
int n, const double* alpha, const double* a, int lda,
const double* x, int incx, const double* beta,
double* y, int incy) {
CHECK_EQ(
cublasDgemv(handle, trans, m, n, alpha, a, lda, x, incx, beta, y, incy),
CUBLAS_STATUS_SUCCESS);
}
// level 3 matrix and matrix
// matrix matrix multiply
template<>
void cublas_gemm<float>(
cublasHandle_t handle, cublasOperation_t cutrans_a,
cublasOperation_t cutrans_b, int m, int n, int k,
const float* alpha, const float* a, int lda,
const float* b, int ldb, const float* beta, float* c, int ldc) {
CHECK_EQ(cublasSgemm(
handle, cutrans_a, cutrans_b, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc),
void cublas_gemm<float>(cublasHandle_t handle, cublasOperation_t cutrans_a,
cublasOperation_t cutrans_b, int m, int n, int k,
const float* alpha, const float* a, int lda,
const float* b, int ldb, const float* beta, float* c,
int ldc) {
CHECK_EQ(cublasSgemm(handle, cutrans_a, cutrans_b, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc),
CUBLAS_STATUS_SUCCESS);
}
template<>
void cublas_gemm<double>(
cublasHandle_t handle, cublasOperation_t cutrans_a,
cublasOperation_t cutrans_b, int m, int n, int k,
const double* alpha, const double* a, int lda,
const double* b, int ldb, const double* beta, double* c, int ldc) {
CHECK_EQ(cublasDgemm(
handle, cutrans_a, cutrans_b, m, n, k, alpha, a, lda, b, ldb, beta,
c, ldc),
void cublas_gemm<double>(cublasHandle_t handle, cublasOperation_t cutrans_a,
cublasOperation_t cutrans_b, int m, int n, int k,
const double* alpha, const double* a, int lda,
const double* b, int ldb, const double* beta,
double* c, int ldc) {
CHECK_EQ(cublasDgemm(handle, cutrans_a, cutrans_b, m, n, k, alpha, a, lda, b,
ldb, beta, c, ldc),
CUBLAS_STATUS_SUCCESS);
}
......
......@@ -8,57 +8,48 @@ namespace oneflow {
// level 1 vector and vector
// dot product
template<typename FloatingPointType>
void cublas_dot(
cublasHandle_t handle, int n,
const FloatingPointType* x, int incx,
const FloatingPointType* y, int incy, FloatingPointType* result);
void cublas_dot(cublasHandle_t handle, int n, const FloatingPointType* x,
int incx, const FloatingPointType* y, int incy,
FloatingPointType* result);
// swap x and y
template<typename FloatingPointType>
void cublas_swap(
cublasHandle_t handle, int n,
FloatingPointType* x, int incx, FloatingPointType* y, int incy);
void cublas_swap(cublasHandle_t handle, int n, FloatingPointType* x, int incx,
FloatingPointType* y, int incy);
// copy x into y
template<typename FloatingPointType>
void cublas_copy(
cublasHandle_t handle, int n,
const FloatingPointType* x, int incx,
FloatingPointType* y, int incy);
void cublas_copy(cublasHandle_t handle, int n, const FloatingPointType* x,
int incx, FloatingPointType* y, int incy);
// y = a*x + y
template<typename FloatingPointType>
void cublas_axpy(
cublasHandle_t handle, int n,
const FloatingPointType* alpha,
const FloatingPointType* x, int incx,
FloatingPointType* y, int incy);
void cublas_axpy(cublasHandle_t handle, int n, const FloatingPointType* alpha,
const FloatingPointType* x, int incx, FloatingPointType* y,
int incy);
// x = a*x
template<typename FloatingPointType>
void cublas_scal(
cublasHandle_t handle, int n,
const FloatingPointType* alpha, FloatingPointType* x, int incx);
void cublas_scal(cublasHandle_t handle, int n, const FloatingPointType* alpha,
FloatingPointType* x, int incx);
// level 2 matrix and vector
// matrix vector multiply
template<typename FloatingPointType>
void cublas_gemv(
cublasHandle_t handle, cublasOperation_t trans, int m, int n,
const FloatingPointType* alpha, const FloatingPointType* a, int lda,
const FloatingPointType* x, int incx, const FloatingPointType* beta,
FloatingPointType* y, int incy);
void cublas_gemv(cublasHandle_t handle, cublasOperation_t trans, int m, int n,
const FloatingPointType* alpha, const FloatingPointType* a,
int lda, const FloatingPointType* x, int incx,
const FloatingPointType* beta, FloatingPointType* y, int incy);
// level 3 matrix and matrix
// matrix matrix multiply
template<typename FloatingPointType>
void cublas_gemm(
cublasHandle_t handle, cublasOperation_t cutrans_a,
cublasOperation_t cutrans_b, int m, int n, int k,
const FloatingPointType* alpha, const FloatingPointType* a, int lda,
const FloatingPointType* b, int ldb,
const FloatingPointType* beta, FloatingPointType* c, int ldc);
void cublas_gemm(cublasHandle_t handle, cublasOperation_t cutrans_a,
cublasOperation_t cutrans_b, int m, int n, int k,
const FloatingPointType* alpha, const FloatingPointType* a,
int lda, const FloatingPointType* b, int ldb,
const FloatingPointType* beta, FloatingPointType* c, int ldc);
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_BLAS_CUBLAS_TEMPLATE_H_
#endif // ONEFLOW_CORE_BLAS_CUBLAS_TEMPLATE_H_
......@@ -17,10 +17,10 @@ Range BalancedSplitter::At(int64_t idx) const {
upper_pound_num = lower_pound_num + (size_per_range_ + 1);
} else {
lower_pound_num = (size_per_range_ + 1) * change_pos_
+ size_per_range_ * (idx - change_pos_);
+ size_per_range_ * (idx - change_pos_);
upper_pound_num = lower_pound_num + size_per_range_;
}
return Range(lower_pound_num, upper_pound_num);
}
} // namespace oneflow
} // namespace oneflow
......@@ -2,8 +2,8 @@
#define ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
#include <stdint.h>
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/range.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
......@@ -27,11 +27,11 @@ class BalancedSplitter final {
Range At(int64_t idx) const;
private:
int64_t size_per_range_;
int64_t change_pos_;
int64_t split_num_;
int64_t size_per_range_;
int64_t change_pos_;
int64_t split_num_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
#endif // ONEFLOW_CORE_COMMON_BALANCED_SPLITTER_H_
......@@ -19,4 +19,4 @@ TEST(BalancedSplitter, split_2_to_3_part) {
ASSERT_TRUE(splitter.At(2) == Range(2, 2));
}
} // namespace oneflow
} // namespace oneflow
......@@ -26,7 +26,8 @@ class Channel final {
// close the channel's send end, the thread can't send item to the channel
void CloseSendEnd();
// close the channel's receive end , the thread can't receive item from channel
// close the channel's receive end , the thread can't receive item from
// channel
void CloseReceiveEnd();
private:
......@@ -40,9 +41,7 @@ class Channel final {
template<typename T>
int Channel<T>::Send(const T& item) {
std::unique_lock<std::mutex> lock(mutex_);
if (is_send_closed_) {
return -1;
}
if (is_send_closed_) { return -1; }
val_.push(item);
cond_.notify_one();
return 0;
......@@ -51,10 +50,10 @@ int Channel<T>::Send(const T& item) {
template<typename T>
int Channel<T>::Receive(T* item) {
std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this]() { return !val_.empty() || is_receive_closed_ || is_send_closed_; });
if (val_.empty() || is_receive_closed_) {
return -1;
}
cond_.wait(lock, [this]() {
return !val_.empty() || is_receive_closed_ || is_send_closed_;
});
if (val_.empty() || is_receive_closed_) { return -1; }
*item = val_.front();
val_.pop();
return 0;
......@@ -76,4 +75,4 @@ void Channel<T>::CloseReceiveEnd() {
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CHANNEL_H_
#endif // ONEFLOW_CORE_COMMON_CHANNEL_H_
......@@ -5,22 +5,16 @@ namespace oneflow {
void CallFromSenderThread(Channel<int>* channel, Range range) {
for (int i = range.begin(); i < range.end(); ++i) {
if (channel->Send(i) == -1) {
break;
}
if (channel->Send(i) == -1) { break; }
}
}
void CallFromReceiverThread(std::vector<int>* visit,
Channel<int>* channel) {
void CallFromReceiverThread(std::vector<int>* visit, Channel<int>* channel) {
int num = -1;
int* num_ptr = &num;
while (channel->Receive(num_ptr) == 0) {
++visit->at(*num_ptr);
}
while (channel->Receive(num_ptr) == 0) { ++visit->at(*num_ptr); }
}
TEST(Channel, 30sender40receiver) {
Channel<int> channel;
std::vector<std::thread> senders;
......@@ -31,34 +25,24 @@ TEST(Channel, 30sender40receiver) {
std::vector<std::vector<int>> visits;
for (int i = 0; i < receiver_num; ++i) {
std::vector<int> visit_i;
for (int j = 0; j < range_num; j++) {
visit_i.push_back(0);
}
for (int j = 0; j < range_num; j++) { visit_i.push_back(0); }
visits.push_back(visit_i);
}
for (int i = 0; i < sender_num; ++i) {
senders.push_back(std::thread(CallFromSenderThread,
&channel,
Range(0, range_num)));
senders.push_back(
std::thread(CallFromSenderThread, &channel, Range(0, range_num)));
}
for (int i = 0; i < receiver_num; ++i) {
receivers.push_back(std::thread(CallFromReceiverThread,
&visits[i],
&channel));
}
for (std::thread& this_thread : senders) {
this_thread.join();
receivers.push_back(
std::thread(CallFromReceiverThread, &visits[i], &channel));
}
for (std::thread& this_thread : senders) { this_thread.join(); }
channel.CloseSendEnd();
for (std::thread& this_thread : receivers) {
this_thread.join();
}
for (std::thread& this_thread : receivers) { this_thread.join(); }
channel.CloseReceiveEnd();
for (int i = 0; i < range_num; ++i) {
int visit_count = 0;
for (int j = 0; j < receiver_num; j++) {
visit_count += visits[j][i];
}
for (int j = 0; j < receiver_num; j++) { visit_count += visits[j][i]; }
ASSERT_EQ(visit_count, sender_num);
}
}
......
......@@ -29,15 +29,9 @@ const cudnnHandle_t* CudaStreamHandle::cudnn_handle() {
}
CudaStreamHandle::~CudaStreamHandle() {
if (cudnn_handle_) {
CHECK_EQ(cudnnDestroy(*cudnn_handle_), 0);
}
if (cublas_handle_) {
CHECK_EQ(cublasDestroy(*cublas_handle_), 0);
}
if (cuda_stream_) {
CHECK_EQ(cudaStreamDestroy(*cuda_stream_), 0);
}
if (cudnn_handle_) { CHECK_EQ(cudnnDestroy(*cudnn_handle_), 0); }
if (cublas_handle_) { CHECK_EQ(cublasDestroy(*cublas_handle_), 0); }
if (cuda_stream_) { CHECK_EQ(cudaStreamDestroy(*cuda_stream_), 0); }
}
} // namespace oneflow
} // namespace oneflow
......@@ -20,9 +20,8 @@ class CudaStreamHandle final {
std::unique_ptr<cudaStream_t> cuda_stream_;
std::unique_ptr<cublasHandle_t> cublas_handle_;
std::unique_ptr<cudnnHandle_t> cudnn_handle_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_CUDA_STREAM_HANDLE_H_
#endif // ONEFLOW_CORE_COMMON_CUDA_STREAM_HANDLE_H_
......@@ -10,21 +10,18 @@ inline void CudaCheck(cudaError_t error) {
}
// CUDA: grid stride looping
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; \
i < (n); \
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int32_t i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
// CUDA: check for error after kernel execution and exit loudly if there is one.
inline void CudaPostKernelCheck() {
CudaCheck(cudaPeekAtLastError());
}
inline void CudaPostKernelCheck() { CudaCheck(cudaPeekAtLastError()); }
const int32_t kCudaThreadsNumPerBlock = 512;
const int32_t kCudaMaxBlocksNum = 4096;
inline int32_t BlocksNum4ThreadsNum(const int32_t N) {
return std::min((N + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,
return std::min((N + kCudaThreadsNumPerBlock - 1) / kCudaThreadsNumPerBlock,
kCudaMaxBlocksNum);
}
......
......@@ -2,9 +2,9 @@
#define ONEFLOW_CORE_COMMON_PROCESS_STATE_H_
#if defined(_MSC_VER)
#include <WinSock2.h>
#include <direct.h>
#include <stdlib.h>
#include <WinSock2.h>
#pragma comment(lib, "Ws2_32.lib")
#else
#include <unistd.h>
......@@ -32,4 +32,3 @@ std::string GetCwd() {
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_PROCESS_STATE_H_
......@@ -35,8 +35,8 @@ void ParseProtoFromTextFile(const std::string& file_path, PbMessage* proto) {
}
void PrintProtoToTextFile(const PbMessage& proto,
const std::string& file_path) {
std::ofstream out_stream(
file_path.c_str(), std::ofstream::out | std::ofstream::trunc);
std::ofstream out_stream(file_path.c_str(),
std::ofstream::out | std::ofstream::trunc);
// make sure out_stream lives longer than output
{
OstreamOutputStream output(&out_stream);
......@@ -45,15 +45,15 @@ void PrintProtoToTextFile(const PbMessage& proto,
out_stream.close();
}
#define DEFINE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name) \
ret_type Get##func_name##FromPbMessage(const PbMessage& msg, \
const std::string& field_name) { \
const Descriptor* d = msg.GetDescriptor(); \
const FieldDescriptor* fd = d->FindFieldByName(field_name); \
CHECK_NOTNULL(fd); \
const Reflection* r = msg.GetReflection(); \
return r->Get##func_name (msg, fd); \
}
#define DEFINE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name) \
ret_type Get##func_name##FromPbMessage(const PbMessage& msg, \
const std::string& field_name) { \
const Descriptor* d = msg.GetDescriptor(); \
const FieldDescriptor* fd = d->FindFieldByName(field_name); \
CHECK_NOTNULL(fd); \
const Reflection* r = msg.GetReflection(); \
return r->Get##func_name(msg, fd); \
}
DEFINE_GET_VAL_FROM_PBMESSAGE(std::string, String);
DEFINE_GET_VAL_FROM_PBMESSAGE(int32_t, Int32);
......@@ -61,4 +61,4 @@ DEFINE_GET_VAL_FROM_PBMESSAGE(uint32_t, UInt32);
DEFINE_GET_VAL_FROM_PBMESSAGE(int64_t, Int64);
DEFINE_GET_VAL_FROM_PBMESSAGE(uint64_t, UInt64);
} // namespace oneflow
} // namespace oneflow
......@@ -4,10 +4,10 @@
#ifdef _MSC_VER
#include <io.h>
#endif
#include "oneflow/core/common/util.h"
#include "google/protobuf/descriptor.h"
#include "google/protobuf/map.h"
#include "google/protobuf/message.h"
#include "google/protobuf/descriptor.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
......@@ -24,16 +24,14 @@ void ParseProtoFromString(const std::string& str, PbMessage* proto);
void PrintProtoToString(const PbMessage& proto, std::string* str);
// Prototxt <-> File
void ParseProtoFromTextFile(const std::string& file_path,
PbMessage* proto);
void PrintProtoToTextFile(const PbMessage& proto,
const std::string& file_path);
void ParseProtoFromTextFile(const std::string& file_path, PbMessage* proto);
void PrintProtoToTextFile(const PbMessage& proto, const std::string& file_path);
// Get From PbMessage
#define DECLARE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name) \
ret_type Get##func_name##FromPbMessage(const PbMessage& msg, \
const std::string& field_name);
#define DECLARE_GET_VAL_FROM_PBMESSAGE(ret_type, func_name) \
ret_type Get##func_name##FromPbMessage(const PbMessage& msg, \
const std::string& field_name);
DECLARE_GET_VAL_FROM_PBMESSAGE(std::string, String);
DECLARE_GET_VAL_FROM_PBMESSAGE(int32_t, Int32);
......@@ -45,8 +43,7 @@ DECLARE_GET_VAL_FROM_PBMESSAGE(uint64_t, UInt64);
// Alias PbType
#define ALIAS_PB_TYPE(type, name) \
using Pb##name = google::protobuf::type; \
#define ALIAS_PB_TYPE(type, name) using Pb##name = google::protobuf::type;
ALIAS_PB_TYPE(int32, Int32);
ALIAS_PB_TYPE(int64, Int64);
......@@ -55,13 +52,11 @@ ALIAS_PB_TYPE(uint64, UInt64);
#undef ALIAS_PB_TYPE
// PbRpf <-> std::vector
inline std::vector<std::string> PbVec2StdVec(
const PbRpf<std::string>& rpf) {
return std::vector<std::string> (rpf.begin(), rpf.end());
// PbRpf <-> std::vector
inline std::vector<std::string> PbVec2StdVec(const PbRpf<std::string>& rpf) {
return std::vector<std::string>(rpf.begin(), rpf.end());
}
inline PbRpf<std::string> StdVec2PbVec (
const std::vector<std::string>& vec) {
inline PbRpf<std::string> StdVec2PbVec(const std::vector<std::string>& vec) {
using RetType = PbRpf<std::string>;
return RetType(vec.begin(), vec.end());
}
......@@ -69,7 +64,7 @@ inline PbRpf<std::string> StdVec2PbVec (
// ProtoMap <-> HashMap
template<typename K, typename V>
HashMap<K, V> PbMap2HashMap(const google::protobuf::Map<K, V>& pb_map) {
return HashMap<K, V> (pb_map.begin(), pb_map.end());
return HashMap<K, V>(pb_map.begin(), pb_map.end());
}
template<typename K, typename V>
......@@ -79,16 +74,16 @@ google::protobuf::Map<K, V> HashMap2PbMap(const HashMap<K, V>& hash_map) {
}
// operator
inline bool operator == (const google::protobuf::MessageLite& lhs,
const google::protobuf::MessageLite& rhs) {
inline bool operator==(const google::protobuf::MessageLite& lhs,
const google::protobuf::MessageLite& rhs) {
return lhs.SerializeAsString() == rhs.SerializeAsString();
}
inline bool operator != (const google::protobuf::MessageLite& lhs,
const google::protobuf::MessageLite& rhs) {
inline bool operator!=(const google::protobuf::MessageLite& lhs,
const google::protobuf::MessageLite& rhs) {
return !(lhs == rhs);
}
} // namespace caffe
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_PROTOBUF_H_
#endif // ONEFLOW_CORE_COMMON_PROTOBUF_H_
......@@ -13,13 +13,13 @@ class Range final {
Range(int64_t begin, int64_t end) : begin_(begin), end_(end) {}
bool operator == (const Range& rhs) const {
bool operator==(const Range& rhs) const {
return begin_ == rhs.begin_ && end_ == rhs.end_;
}
int64_t begin() const { return begin_; }
int64_t end() const { return end_; }
int64_t& mut_begin() { return begin_; }
int64_t& mut_end() { return end_; }
......@@ -30,6 +30,6 @@ class Range final {
int64_t end_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_RANGE_H_
#endif // ONEFLOW_CORE_COMMON_RANGE_H_
......@@ -9,28 +9,23 @@ Shape::Shape(const ShapeProto& shape_proto) {
}
void Shape::ToProto(ShapeProto* ret) const {
*(ret->mutable_dim()) = PbRf<PbInt64> (dim_vec_.begin(), dim_vec_.end());
*(ret->mutable_dim()) = PbRf<PbInt64>(dim_vec_.begin(), dim_vec_.end());
}
std::string Shape::DebugStr() const {
std::stringstream ss;
ss << "{";
for (int64_t dim : dim_vec_) {
ss << dim << ",";
}
for (int64_t dim : dim_vec_) { ss << dim << ","; }
ss << "(" << elem_cnt_ << ")}";
return ss.str();
}
int64_t Shape::Count(int64_t begin_axis, int64_t end_axis) const {
CHECK(0 <= begin_axis && begin_axis <= end_axis && end_axis <= NumAxes())
<< "[begin_axis:" << begin_axis
<< "][end_axis:" << end_axis
<< "[begin_axis:" << begin_axis << "][end_axis:" << end_axis
<< "][num_axes:" << NumAxes() << "]";
int64_t cnt = 1;
for (int64_t i = begin_axis; i < end_axis; ++i) {
cnt *= At(i);
}
for (int64_t i = begin_axis; i < end_axis; ++i) { cnt *= At(i); }
return cnt;
}
......@@ -42,17 +37,13 @@ int64_t Shape::CanonicalAxisIndex(int64_t axis_index) const {
void Shape::UpdateElemCnt() {
elem_cnt_ = 1;
for (int64_t s : dim_vec_) {
elem_cnt_ *= s;
}
if (dim_vec_.size() == 0) {
elem_cnt_ = 0;
}
for (int64_t s : dim_vec_) { elem_cnt_ *= s; }
if (dim_vec_.size() == 0) { elem_cnt_ = 0; }
}
std::ostream& operator<< (std::ostream& out, const Shape& shape) {
std::ostream& operator<<(std::ostream& out, const Shape& shape) {
out << shape.DebugStr();
return out;
}
} // namespace oneflow
} // namespace oneflow
#ifndef ONEFLOW_CORE_COMMON_SHAPE_H_
#define ONEFLOW_CORE_COMMON_SHAPE_H_
#include "oneflow/core/common/util.h"
#include "oneflow/core/common/shape.pb.h"
#include "oneflow/core/common/util.h"
namespace oneflow {
......@@ -13,8 +13,8 @@ class Shape final {
explicit Shape(const std::vector<int64_t>& dim_vec);
Shape(const ShapeProto& shape_proto);
~Shape() = default;
bool operator == (const Shape& rhs) const;
bool operator==(const Shape& rhs) const;
std::string DebugStr() const;
void ToProto(ShapeProto*) const;
......@@ -34,17 +34,15 @@ class Shape final {
std::vector<int64_t> dim_vec_;
int64_t elem_cnt_;
};
std::ostream& operator<< (std::ostream& out, const Shape& shape);
std::ostream& operator<<(std::ostream& out, const Shape& shape);
inline Shape::Shape(const std::vector<int64_t>& dim_vec) :
dim_vec_(dim_vec) {
inline Shape::Shape(const std::vector<int64_t>& dim_vec) : dim_vec_(dim_vec) {
UpdateElemCnt();
}
inline bool Shape::operator == (const Shape& rhs) const {
inline bool Shape::operator==(const Shape& rhs) const {
return dim_vec_ == rhs.dim_vec_;
}
......@@ -61,6 +59,6 @@ inline int64_t Shape::Count(int64_t begin_axis) const {
return Count(begin_axis, NumAxes());
}
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_SHAPE_H_
#endif // ONEFLOW_CORE_COMMON_SHAPE_H_
......@@ -31,8 +31,7 @@ uint64_t oneflow_cast(const std::string& s) {
return ret;
}
void Split(const std::string& text,
const std::string& delims,
void Split(const std::string& text, const std::string& delims,
std::function<void(std::string&&)> Func) {
size_t token_start = 0;
if (text.empty()) { return; }
......@@ -44,4 +43,4 @@ void Split(const std::string& text,
}
}
} // namespace oneflow
} // namespace oneflow
#ifndef ONEFLOW_CORE_COMMON_UTIL_H_
#define ONEFLOW_CORE_COMMON_UTIL_H_
#include <unordered_set>
#include <unordered_map>
#include <functional>
#include <algorithm>
#include <mutex>
#include <utility>
#include <memory>
#include <thread>
#include <list>
#include <condition_variable>
#include <atomic>
#include <queue>
#include <condition_variable>
#include <fstream>
#include <functional>
#include <iostream>
#include "glog/logging.h"
#include "gtest/gtest.h"
#include <list>
#include <memory>
#include <mutex>
#include <queue>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include "cublas_v2.h"
#include "cuda.h"
#include "cuda_runtime.h"
#include "cublas_v2.h"
#include "cudnn.h"
#include "glog/logging.h"
#include "gtest/gtest.h"
namespace oneflow {
#define OF_DISALLOW_COPY(ClassName) \
#define OF_DISALLOW_COPY(ClassName) \
ClassName(const ClassName&) = delete; \
ClassName& operator = (const ClassName&) = delete;
ClassName& operator=(const ClassName&) = delete;
#define OF_DISALLOW_MOVE(ClassName) \
ClassName(ClassName&&) = delete; \
ClassName& operator = (ClassName&&) = delete;
ClassName(ClassName&&) = delete; \
ClassName& operator=(ClassName&&) = delete;
#define OF_DISALLOW_COPY_AND_MOVE(ClassName) \
OF_DISALLOW_COPY(ClassName) \
OF_DISALLOW_COPY(ClassName) \
OF_DISALLOW_MOVE(ClassName)
#define UNEXPECTED_RUN() \
LOG(FATAL) << "Unexpected Run";
#define UNEXPECTED_RUN() LOG(FATAL) << "Unexpected Run";
#define TODO() \
LOG(FATAL) << "TODO";
#define TODO() LOG(FATAL) << "TODO";
#define OF_SINGLETON(ClassName) \
#define OF_SINGLETON(ClassName) \
static ClassName& Singleton() { \
static ClassName obj; \
return obj; \
static ClassName obj; \
return obj; \
}
template<typename T>
bool operator == (const std::weak_ptr<T>& lhs, const std::weak_ptr<T>& rhs) {
bool operator==(const std::weak_ptr<T>& lhs, const std::weak_ptr<T>& rhs) {
return lhs.lock().get() == rhs.lock().get();
}
......@@ -83,9 +81,7 @@ inline std::string LogDir() {
inline void str_replace(std::string* str, char old_ch, char new_ch) {
for (size_t i = 0; i < str->size(); ++i) {
if (str->at(i) == old_ch) {
str->at(i) = new_ch;
}
if (str->at(i) == old_ch) { str->at(i) = new_ch; }
}
}
......@@ -102,30 +98,26 @@ void EraseIf(HashMap<K, V>* hash_map,
}
#define OF_DECLARE_ENUM_TO_OSTREAM_FUNC(EnumType) \
std::ostream& operator << (std::ostream& out_stream, const EnumType&)
std::ostream& operator<<(std::ostream& out_stream, const EnumType&)
#define OF_DEFINE_ENUM_TO_OSTREAM_FUNC(EnumType) \
std::ostream& operator << (std::ostream& out_stream, const EnumType& x) { \
out_stream << static_cast<int> (x); \
return out_stream; \
}
#define OF_DEFINE_ENUM_TO_OSTREAM_FUNC(EnumType) \
std::ostream& operator<<(std::ostream& out_stream, const EnumType& x) { \
out_stream << static_cast<int>(x); \
return out_stream; \
}
template<typename OutType, typename InType>
OutType oneflow_cast(const InType&);
void Split(const std::string& text,
const std::string& delims,
void Split(const std::string& text, const std::string& delims,
std::function<void(std::string&&)> Func);
template<typename T>
void SplitAndParseAs(const std::string& text,
const std::string& delims,
void SplitAndParseAs(const std::string& text, const std::string& delims,
std::function<void(T&&)> Func) {
Split(text, delims, [&Func](std::string&& s) {
Func(oneflow_cast<T>(s));
});
Split(text, delims, [&Func](std::string&& s) { Func(oneflow_cast<T>(s)); });
}
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_COMMON_UTIL_H_
#endif // ONEFLOW_CORE_COMMON_UTIL_H_
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/operator/operator_manager.h"
#include "oneflow/core/operator/boxing_op.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
......@@ -31,7 +31,7 @@ void FwCompleteBoxOpConfFakerMdUpdt(BoxingOpConf* conf) {
conf->mutable_clone_box();
}
} // namespace
} // namespace
void BoxingTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
EnrollAllRegstAndBindRelatedEdge();
......@@ -68,14 +68,14 @@ void BoxingTaskNode::FwInitChain2SortedEdgesMaps(
}
for (auto& pair : *chain2sorted_edges) {
std::vector<const TaskEdge*>& edges = pair.second;
std::sort(edges.begin(), edges.end(), [&edge2stage](const TaskEdge* lhs,
const TaskEdge* rhs) {
const StageNode* lhs_stage = edge2stage.at(lhs);
const StageNode* rhs_stage = edge2stage.at(rhs);
CHECK(lhs_stage->chain_node() == rhs_stage->chain_node());
return lhs_stage->parallel_range().begin() <
rhs_stage->parallel_range().begin();
});
std::sort(edges.begin(), edges.end(),
[&edge2stage](const TaskEdge* lhs, const TaskEdge* rhs) {
const StageNode* lhs_stage = edge2stage.at(lhs);
const StageNode* rhs_stage = edge2stage.at(rhs);
CHECK(lhs_stage->chain_node() == rhs_stage->chain_node());
return lhs_stage->parallel_range().begin()
< rhs_stage->parallel_range().begin();
});
}
}
......@@ -91,12 +91,12 @@ void BoxingTaskNode::FwSortEdgesInnerStage(
}
return ret;
};
std::sort(edges_to_be_sorted->begin(), edges_to_be_sorted->end(), [&]
(const TaskEdge* lhs, const TaskEdge* rhs) {
const CompTaskNode* lhs_node = GetPredSuccCompTaskNode(lhs);
const CompTaskNode* rhs_node = GetPredSuccCompTaskNode(rhs);
return lhs_node->parallel_id() < rhs_node->parallel_id();
});
std::sort(edges_to_be_sorted->begin(), edges_to_be_sorted->end(),
[&](const TaskEdge* lhs, const TaskEdge* rhs) {
const CompTaskNode* lhs_node = GetPredSuccCompTaskNode(lhs);
const CompTaskNode* rhs_node = GetPredSuccCompTaskNode(rhs);
return lhs_node->parallel_id() < rhs_node->parallel_id();
});
}
void BoxingTaskNode::FwBuildChainSortedEdgesPair(
......@@ -141,9 +141,8 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
CHECK_EQ(lbns.size(), 1);
lbns.clear();
auto in_regst_0 = GetRelatedRegst(sorted_in_edges.at(0));
in_regst_0->ForEachLbn([&](const std::string& lbn) {
lbns.push_back(lbn);
});
in_regst_0->ForEachLbn(
[&](const std::string& lbn) { lbns.push_back(lbn); });
}
// Enroll Lbn
auto middle_regst = GetProducedRegstDesc("middle");
......@@ -173,11 +172,9 @@ void BoxingTaskNode::FwBuildChainSortedEdgesPair(
void BoxingTaskNode::FwInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
exec_gph().ConstForEachNode([this](const ExecNode* exec_node) {
exec_node->op()->InferShape4FwBlobs(
exec_node->GetMutShapePtr4BnInOpFunc(),
chain_node()->parallel_desc()->policy(),
0,
0);
exec_node->op()->InferShape4FwBlobs(exec_node->GetMutShapePtr4BnInOpFunc(),
chain_node()->parallel_desc()->policy(),
0, 0);
});
}
......@@ -191,7 +188,7 @@ std::shared_ptr<RegstDesc> GetBpRegstFromFwRegst(
return GetRelatedRegst(bp_edge);
}
}
} // namespace
void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
EnrollAllRegstAndBindRelatedEdge();
......@@ -231,7 +228,7 @@ void BoxingTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
});
mut_exec_gph().UpdateSourceAndSink();
}
void BoxingTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
for (TaskEdge* fw_in_edge : GetFwNode()->in_edges()) {
auto in_regst = GetRelatedRegst(fw_in_edge);
......@@ -240,8 +237,8 @@ void BoxingTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
}
}
auto fw_middle_regst = GetFwNode()->GetProducedRegstDesc("middle");
auto bp_middle_regst = GetProducedRegstDesc("middle");
auto bp_middle_regst = GetProducedRegstDesc("middle");
bp_middle_regst->CopyShapeFrom(fw_middle_regst.get());
}
} // namespace oneflow
} // namespace oneflow
......@@ -10,20 +10,18 @@ class BoxingTaskNode : public TaskNode {
OF_DISALLOW_COPY_AND_MOVE(BoxingTaskNode);
BoxingTaskNode() = default;
virtual ~BoxingTaskNode() = default;
std::string VisualStr() const override {
return TaskNode::VisualStr() + "Boxing";
}
void ToProto(TaskProto* ret) const override {
TaskNode::ToProto(ret);
};
void ToProto(TaskProto* ret) const override { TaskNode::ToProto(ret); };
protected:
virtual void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
}
using ChainEdgesPair =
std::pair<const ChainNode*, std::vector<const TaskEdge*>>;
using Chain2EdgesMap =
......@@ -33,10 +31,9 @@ class BoxingTaskNode : public TaskNode {
const std::unordered_set<TaskEdge*>& (TaskNode::*in_out_edges)() const,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const);
void FwSortEdgesInnerStage(
std::vector<const TaskEdge*>* edges_to_be_sorted,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const);
void FwSortEdgesInnerStage(std::vector<const TaskEdge*>* edges_to_be_sorted,
TaskNode* (TaskEdge::*src_dst_node)() const,
TaskEdge* (TaskNode::*SoleEdge)() const);
void FwBuildChainSortedEdgesPair(
const ChainEdgesPair& chain_sorted_in_edges,
const ChainEdgesPair& chain_sorted_out_edges);
......@@ -50,12 +47,11 @@ class BoxingTaskNode : public TaskNode {
void FwInferShapeOfBlobsInProducedRegsts(TaskGraph*);
void BpBuildExecAndEnrollLbn2Regsts(TaskGraph*);
void BpInferShapeOfBlobsInProducedRegsts(TaskGraph*);
void EnrollAllRegstAndBindRelatedEdge();
TaskType task_type() const override { return kBoxingTask; }
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
......@@ -27,10 +27,8 @@ void SetChainNodeWithChainIt(ChainNode* chain_node, ChainIt chain_it) {
}
}
void InitChains(
const LogicalGraph& logi_gph,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
void InitChains(const LogicalGraph& logi_gph, std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
chain_list->clear();
logical2chain_it->clear();
logi_gph.ConstForEachNode([&](const LogicalNode* node) {
......@@ -82,9 +80,8 @@ void InitChains(
});
}
void ModelMergeChains(
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
void ModelMergeChains(std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
for (auto& pair : *logical2chain_it) {
// Get cur_node, pred_node
const LogicalNode* cur_node = pair.first;
......@@ -101,8 +98,7 @@ void ModelMergeChains(
ChainIt pred_chain = logical2chain_it->at(pred_node);
ChainIt cur_chain = pair.second;
// Merge
pred_chain->nodes.insert(pred_chain->nodes.end(),
cur_chain->nodes.begin(),
pred_chain->nodes.insert(pred_chain->nodes.end(), cur_chain->nodes.begin(),
cur_chain->nodes.end());
for (const LogicalNode* node : cur_chain->nodes) {
pred_chain->descendants.erase(node);
......@@ -112,11 +108,10 @@ void ModelMergeChains(
}
}
bool TryMergeWithConnect(
const LogicalNode* up_node,
const LogicalNode* bottom_node,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
bool TryMergeWithConnect(const LogicalNode* up_node,
const LogicalNode* bottom_node,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
// Get chain
ChainIt up_chain = logical2chain_it->at(up_node);
ChainIt bottom_chain = logical2chain_it->at(bottom_node);
......@@ -146,11 +141,10 @@ bool TryMergeWithConnect(
return true;
}
bool TryMergeWithoutConnect(
const LogicalNode* lhs_node,
const LogicalNode* rhs_node,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
bool TryMergeWithoutConnect(const LogicalNode* lhs_node,
const LogicalNode* rhs_node,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
// Get chain
ChainIt lhs_chain = logical2chain_it->at(lhs_node);
ChainIt rhs_chain = logical2chain_it->at(rhs_node);
......@@ -170,11 +164,9 @@ bool TryMergeWithoutConnect(
return true;
}
bool TryDataMerge(
const LogicalNode* first,
const LogicalNode* second,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
bool TryDataMerge(const LogicalNode* first, const LogicalNode* second,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
if (first->parallel_desc()->Equal(second->parallel_desc().get()) == false) {
return false;
}
......@@ -186,10 +178,9 @@ bool TryDataMerge(
return false;
}
bool DoOneDataMerge(
const std::vector<const LogicalNode*>& data_parallel_node,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
bool DoOneDataMerge(const std::vector<const LogicalNode*>& data_parallel_node,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
for (const LogicalNode* first : data_parallel_node) {
for (const LogicalNode* second : data_parallel_node) {
if (first == second) { continue; }
......@@ -204,10 +195,9 @@ bool DoOneDataMerge(
return false;
}
void DataMergeChains(
const LogicalGraph& logical_gph,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
void DataMergeChains(const LogicalGraph& logical_gph,
std::list<Chain>* chain_list,
Logical2ChainItMap* logical2chain_it) {
std::vector<const LogicalNode*> data_parallel_node;
for (const auto& pair : *logical2chain_it) {
const LogicalNode* cur_logi_node = pair.first;
......@@ -215,17 +205,14 @@ void DataMergeChains(
if (cur_logi_node->IsLossNode()) { continue; }
data_parallel_node.push_back(cur_logi_node);
}
while (DoOneDataMerge(data_parallel_node, chain_list, logical2chain_it)) {
}
while (DoOneDataMerge(data_parallel_node, chain_list, logical2chain_it)) {}
}
} // namespace
} // namespace
std::string ChainNode::ConcatedOpsName() const {
std::stringstream ss;
for (auto op : op_vec_) {
ss << "\\n" << op->op_name();
}
for (auto op : op_vec_) { ss << "\\n" << op->op_name(); }
if (!op_vec_.empty()) {
return ss.str().substr(2);
} else {
......@@ -252,19 +239,21 @@ ChainGraph::ChainGraph(const LogicalGraph* logical_gph) {
DataMergeChains(*logical_gph, &chain_list, &logical2chain_it);
// Init chain_nodes
auto HashChainIt = [](const ChainIt& chain_it) {
return std::hash<Chain*> ()(&(*chain_it));
return std::hash<Chain*>()(&(*chain_it));
};
HashMap<ChainIt, ChainNode*, decltype(HashChainIt)>
chain_it2chain_node(11, HashChainIt);
HashMap<ChainIt, ChainNode*, decltype(HashChainIt)> chain_it2chain_node(
11, HashChainIt);
HashMap<ChainNode*, std::unordered_set<ChainNode*>> chain_node2pred;
for (auto chain_it = chain_list.begin(); chain_it != chain_list.end(); ++chain_it) {
for (auto chain_it = chain_list.begin(); chain_it != chain_list.end();
++chain_it) {
ChainNode* chain_node = NewNode();
chain_it2chain_node[chain_it] = chain_node;
chain_node2pred[chain_node] = {};
SetChainNodeWithChainIt(chain_node, chain_it);
}
// Record the predecessor
for (auto chain_it = chain_list.begin(); chain_it != chain_list.end(); ++chain_it) {
for (auto chain_it = chain_list.begin(); chain_it != chain_list.end();
++chain_it) {
ChainNode* chain_node = chain_it2chain_node.at(chain_it);
for (const LogicalNode* logi_node : chain_it->nodes) {
for (auto logi_in_edge : logi_node->in_edges()) {
......@@ -324,14 +313,12 @@ void ChainGraph::SetInOutLbn4AllChainNodeInDataTaskGraph() {
});
}
std::vector<std::string> FindLbnsBetween(const ChainNode* src_node,
std::vector<std::string> FindLbnsBetween(const ChainNode* src_node,
const ChainNode* dst_node) {
std::vector<std::string> matching_lbns;
for (const std::string& src_node_output_lbn : src_node->output_lbns()) {
for (const std::string& dst_node_input_lbn : dst_node->input_lbns()) {
if (src_node_output_lbn != dst_node_input_lbn) {
continue;
}
for (const std::string& dst_node_input_lbn : dst_node->input_lbns()) {
if (src_node_output_lbn != dst_node_input_lbn) { continue; }
matching_lbns.push_back(src_node_output_lbn);
break;
}
......@@ -343,10 +330,8 @@ std::vector<std::string> FindLbnsBetween(const ChainNode* src_node,
std::string ChainEdge::VisualStr() const {
std::vector<std::string> lbns = FindLbnsBetween(src_node(), dst_node());
std::stringstream ss;
for (const std::string& lbn : lbns) {
ss << "\\n" << lbn;
}
for (const std::string& lbn : lbns) { ss << "\\n" << lbn; }
return ss.str().substr(2);
}
} // namespace oneflow
} // namespace oneflow
......@@ -22,9 +22,7 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
const std::vector<std::shared_ptr<const Operator>>& op_vec() const {
return op_vec_;
}
std::vector<std::shared_ptr<const Operator>>& mut_op_vec() {
return op_vec_;
}
std::vector<std::shared_ptr<const Operator>>& mut_op_vec() { return op_vec_; }
std::shared_ptr<const ParallelDesc> parallel_desc() const {
return parallel_desc_;
......@@ -33,26 +31,18 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
return parallel_desc_;
}
const std::vector<std::string>& input_lbns() const {
return input_lbns_;
}
std::vector<std::string>& mut_input_lbns() {
return input_lbns_;
}
const std::vector<std::string>& output_lbns() const {
return output_lbns_;
}
std::vector<std::string>& mut_output_lbns() {
return output_lbns_;
}
const std::vector<std::string>& input_lbns() const { return input_lbns_; }
std::vector<std::string>& mut_input_lbns() { return input_lbns_; }
const std::vector<std::string>& output_lbns() const { return output_lbns_; }
std::vector<std::string>& mut_output_lbns() { return output_lbns_; }
bool IsLossNode() const {
return op_vec_.size() == 1 && op_vec_.front()->IsLossOp();
}
std::string VisualStr() const { return ConcatedOpsName(); }
bool HasOpWithModelOrModelTmpBlob() const;
private:
......@@ -60,10 +50,8 @@ class ChainNode final : public Node<ChainNode, ChainEdge> {
std::shared_ptr<const ParallelDesc> parallel_desc_;
std::vector<std::string> input_lbns_;
std::vector<std::string> output_lbns_;
};
class ChainEdge final : public Edge<ChainNode, ChainEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(ChainEdge);
......@@ -87,11 +75,10 @@ class ChainGraph final : public Graph<ChainNode, ChainEdge> {
private:
void SetInOutLbn4AllChainNodeInDataTaskGraph();
};
std::vector<std::string> FindLbnsBetween(const ChainNode*, const ChainNode*);
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_CHAIN_GRAPH_H_
#include "oneflow/core/graph/comp_task_node.h"
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/operator/operator_manager.h"
#include "oneflow/core/graph/model_update_task_graph.h"
#include "oneflow/core/operator/clone_op.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
std::string CompTaskNode::VisualStr() const {
std::stringstream ss;
ss << TaskNode::VisualStr()
<< "Compute" << ":"
<< stage_node()->machine_id_str() << ":"
<< thrd_loc_id_str() << "\\n"
ss << TaskNode::VisualStr() << "Compute"
<< ":" << stage_node()->machine_id_str() << ":" << thrd_loc_id_str()
<< "\\n"
<< chain_node()->VisualStr();
return ss.str();
}
std::string CompTaskNode::device_name() const {
return IDMgr::Singleton().MachineName4MachineId(stage_node()->machine_id())
+ ":"
+ std::to_string(IDMgr::Singleton().DevPhyId4ThrdLocId(thrd_loc_id()));
+ ":"
+ std::to_string(IDMgr::Singleton().DevPhyId4ThrdLocId(thrd_loc_id()));
}
void SortByParallelId(std::vector<CompTaskNode*>* comp_node_vec) {
std::sort(comp_node_vec->begin(), comp_node_vec->end(), []
(const CompTaskNode* lhs, const CompTaskNode* rhs) {
return lhs->parallel_id() < rhs->parallel_id();
});
std::sort(comp_node_vec->begin(), comp_node_vec->end(),
[](const CompTaskNode* lhs, const CompTaskNode* rhs) {
return lhs->parallel_id() < rhs->parallel_id();
});
}
} // namespace oneflow
} // namespace oneflow
......@@ -21,17 +21,16 @@ class CompTaskNode : public TaskNode {
protected:
virtual void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
auto fw_comp_code = static_cast<CompTaskNode*> (fw_node);
auto fw_comp_code = static_cast<CompTaskNode*>(fw_node);
parallel_id_ = fw_comp_code->parallel_id_;
}
private:
int64_t parallel_id_;
};
void SortByParallelId(std::vector<CompTaskNode*>* comp_node_vec);
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_COMP_TASK_NODE_H_
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/copy_hd_op.h"
#include "oneflow/core/operator/copy_comm_net_op.h"
#include "oneflow/core/operator/copy_hd_op.h"
namespace oneflow {
void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*){
void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*) {
auto out_regst = NewProducedRegstDesc("copy_out");
BindProducedRegstAndOutEdge(out_regst, SoleOutEdge());
std::shared_ptr<RegstDesc> in_regst = GetRelatedRegst(SoleInEdge());
SubscribeRegstDesc("copy_in", in_regst);
out_regst->CopyLbnFrom(in_regst.get());
ExecNode* node = mut_exec_gph().NewNode();
node->mut_op() = ConstructOp();
if (IsFwNode()) {
node->BindBnInOpAndRegst(node->op()->SoleIbn(), in_regst);
node->BindBnInOpAndRegst(node->op()->SoleObn(), out_regst);
......@@ -21,7 +21,7 @@ void CopyTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph*){
node->BindBnInOpAndRegst(node->op()->SoleOdbn(), in_regst);
node->BindBnInOpAndRegst(node->op()->SoleIdbn(), out_regst);
}
mut_exec_gph().UpdateSourceAndSink();
}
......@@ -56,4 +56,4 @@ std::shared_ptr<const Operator> CopyCommNetTaskNode::ConstructOp() const {
return OpMgr::Singleton().ConstructOp(op_conf);
}
} // namespace oneflow
} // namespace oneflow
......@@ -17,7 +17,6 @@ class CopyTaskNode : public TaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph*) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph*) override;
};
class CopyHDTaskNode final : public CopyTaskNode {
......@@ -25,26 +24,22 @@ class CopyHDTaskNode final : public CopyTaskNode {
OF_DISALLOW_COPY_AND_MOVE(CopyHDTaskNode);
CopyHDTaskNode() = default;
~CopyHDTaskNode() = default;
bool IsH2D() const {
return ((IsFwInCopy() && IsFwNode()) || (IsFwOutCopy() && IsBpNode()));
}
bool IsD2H() const {
return !IsH2D();
}
bool IsD2H() const { return !IsH2D(); }
bool IsFwInCopy() const { return is_fw_in_copy_; }
bool IsFwOutCopy() const { return !is_fw_in_copy_; }
void SetFwInCopy();
void SetFwOutCopy();
std::string VisualStr() const override {
return TaskNode::VisualStr() + "CopyHD";
}
void ToProto(TaskProto* ret) const override {
TaskNode::ToProto(ret);
};
void ToProto(TaskProto* ret) const override { TaskNode::ToProto(ret); };
private:
std::shared_ptr<const Operator> ConstructOp() const override;
......@@ -54,12 +49,11 @@ class CopyHDTaskNode final : public CopyTaskNode {
is_fw_in_copy_ = static_cast<CopyHDTaskNode*>(fw_node)->is_fw_in_copy_;
}
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CopyHDTaskNode> ();
return of_make_unique<CopyHDTaskNode>();
}
TaskType task_type() const override { return kCopyHdTask; }
bool is_fw_in_copy_;
};
class CopyCommNetTaskNode final : public CopyTaskNode {
......@@ -72,14 +66,12 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
return TaskNode::VisualStr() + "CommNet";
}
void ToProto(TaskProto* ret) const override {
TaskNode::ToProto(ret);
};
void ToProto(TaskProto* ret) const override { TaskNode::ToProto(ret); };
private:
std::shared_ptr<const Operator> ConstructOp() const override;
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<CopyCommNetTaskNode> ();
return of_make_unique<CopyCommNetTaskNode>();
}
void InitWithFwNode(TaskNode* fw_node) override {
TaskNode::InitWithFwNode(fw_node);
......@@ -87,9 +79,8 @@ class CopyCommNetTaskNode final : public CopyTaskNode {
set_task_id();
}
TaskType task_type() const override { return kCopyCommNetTask; }
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_COPY_TASK_NODE_H_
......@@ -21,22 +21,20 @@ void DataCompTaskNode::FwBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
FwSetExecNodeFromInRegst(extern_in_lbn2consumer);
FwEnrollLbn2OutRegst(lbn2producer);
FwEnrollLbn2ActivationRegst();
FwEnrollLbn2ModelAndTmpRegsts(); // model model_tmp data_tmp
FwEnrollLbn2ModelAndTmpRegsts(); // model model_tmp data_tmp
}
void DataCompTaskNode::FwInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
exec_gph().ConstTopoForEachNode([this](const ExecNode* node) {
node->op()->InferShape4FwBlobs(
node->GetMutShapePtr4BnInOpFunc(),
chain_node()->parallel_desc()->policy(),
parallel_id(),
chain_node()->parallel_desc()->policy(), parallel_id(),
chain_node()->parallel_desc()->parallel_num());
});
}
void DataCompTaskNode::FwBuildFromUserOps(
Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer) {
Lbn2NodeBnMap* lbn2producer, Lbn2NodeBnMap* extern_in_lbn2consumer) {
for (std::shared_ptr<const Operator> op : chain_node()->op_vec()) {
ExecNode* cur_node = mut_exec_gph().NewNode();
cur_node->mut_op() = op;
......@@ -56,8 +54,7 @@ void DataCompTaskNode::FwBuildFromUserOps(
edge->mut_dst_bn() = ibn;
Connect(producer_it->second.first, edge, cur_node);
} else {
CHECK(extern_in_lbn2consumer->insert({lbn,
{cur_node, ibn}}).second);
CHECK(extern_in_lbn2consumer->insert({lbn, {cur_node, ibn}}).second);
}
}
});
......@@ -161,8 +158,7 @@ void DataCompTaskNode::BpBuildExecAndEnrollLbn2Regsts(TaskGraph*) {
// Subscribe
SubscribeRegstDesc("activation",
GetFwNode()->GetProducedRegstDesc("activation"));
SubscribeRegstDesc("data_tmp",
GetFwNode()->GetProducedRegstDesc("data_tmp"));
SubscribeRegstDesc("data_tmp", GetFwNode()->GetProducedRegstDesc("data_tmp"));
SubscribeRegstDesc("model", GetFwNode()->GetSubscribedRegstDesc("model"));
SubscribeRegstDesc("model_tmp",
GetFwNode()->GetSubscribedRegstDesc("model_tmp"));
......@@ -179,7 +175,8 @@ void DataCompTaskNode::BpInferShapeOfBlobsInProducedRegsts(TaskGraph*) {
in_diff_regst->CopyShapeFrom(in_regst.get());
// model_diff_regst
if (auto md_diff_regst = GetProducedRegstDesc("model_diff")) {
md_diff_regst->CopyShapeFrom(GetFwNode()->GetSubscribedRegstDesc("model").get());
md_diff_regst->CopyShapeFrom(
GetFwNode()->GetSubscribedRegstDesc("model").get());
}
// activation_diff_regst
if (auto acti_diff_regst = GetProducedRegstDesc("activation_diff")) {
......@@ -201,8 +198,7 @@ void DataCompTaskNode::BpBuildExecGraph() {
bp_edge->set_lbn(fw_edge->lbn());
bp_edge->mut_src_bn() = GenDiffBn(fw_edge->dst_bn());
bp_edge->mut_dst_bn() = GenDiffBn(fw_edge->src_bn());
Connect(fw_node2bp_node.at(fw_edge->dst_node()),
bp_edge,
Connect(fw_node2bp_node.at(fw_edge->dst_node()), bp_edge,
fw_node2bp_node.at(fw_edge->src_node()));
});
mut_exec_gph().UpdateSourceAndSink();
......@@ -222,7 +218,7 @@ void DataCompTaskNode::BpEnrollLbn2ActivationDiffRegst() {
exec_gph().ConstForEachEdge([&](const ExecEdge* edge) {
edge->src_node()->BindBnInOpAndRegst(edge->src_bn(), activation_diff_regst);
edge->dst_node()->BindBnInOpAndRegst(edge->dst_bn(), activation_diff_regst);
edge->src_node()->BindBnInOpAndRegst(GenUnDiffBn(edge->src_bn()),
edge->src_node()->BindBnInOpAndRegst(GenUnDiffBn(edge->src_bn()),
activation_regst);
edge->dst_node()->BindBnInOpAndRegst(GenUnDiffBn(edge->dst_bn()),
activation_regst);
......@@ -280,4 +276,4 @@ void DataCompTaskNode::BpEnrollLbn2ModelDiffRegst() {
});
}
} // namespace oneflow
} // namespace oneflow
......@@ -21,17 +21,14 @@ class DataCompTaskNode final : public CompTaskNode {
private:
OVERRIDE_IF_FW_BP_FOR_FUNC(BuildExecAndEnrollLbn2Regsts);
OVERRIDE_IF_FW_BP_FOR_FUNC(InferShapeOfBlobsInProducedRegsts);
using Lbn2NodeBnMap =
HashMap<std::string, std::pair<ExecNode*, std::string>>;
using Lbn2NodeBnMap = HashMap<std::string, std::pair<ExecNode*, std::string>>;
void FwBuildExecAndEnrollLbn2Regsts(TaskGraph* gph);
void FwInferShapeOfBlobsInProducedRegsts(TaskGraph* gph);
void FwBuildFromUserOps(
Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer);
void FwSetExecNodeFromInRegst(
const Lbn2NodeBnMap& extern_in_lbn2consumer);
void FwBuildFromUserOps(Lbn2NodeBnMap* lbn2producer,
Lbn2NodeBnMap* extern_in_lbn2consumer);
void FwSetExecNodeFromInRegst(const Lbn2NodeBnMap& extern_in_lbn2consumer);
void FwEnrollLbn2OutRegst(const Lbn2NodeBnMap& lbn2producer);
void FwEnrollLbn2OutRegstWhenLoss();
void FwEnrollLbn2OutRegstWhenNotLoss(const Lbn2NodeBnMap& lbn2producer);
......@@ -45,16 +42,13 @@ class DataCompTaskNode final : public CompTaskNode {
void BpSetExecNodeFromOutDiffRegst();
void BpEnrollLbn2InDiffRegst();
void BpEnrollLbn2ModelDiffRegst();
TaskType task_type() const override {
return kDataCompTask;
}
TaskType task_type() const override { return kDataCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<DataCompTaskNode> ();
return of_make_unique<DataCompTaskNode>();
}
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_DATA_COMP_TASK_NODE_H_
......@@ -4,11 +4,9 @@ namespace oneflow {
class DataCompTaskNode;
DataTaskGraph::DataTaskGraph(
const std::string& name,
const DLNetConf& dl_net_conf,
const Strategy& strategy_conf,
bool need_bp) {
DataTaskGraph::DataTaskGraph(const std::string& name,
const DLNetConf& dl_net_conf,
const Strategy& strategy_conf, bool need_bp) {
mut_name() = name;
LogicalGraph logical_gph(dl_net_conf, strategy_conf);
auto chain_gph = of_make_unique<ChainGraph>(&logical_gph);
......@@ -16,4 +14,4 @@ DataTaskGraph::DataTaskGraph(
BuildExecAndEnrollLbn2Regsts();
}
}
} // namespace oneflow
......@@ -10,17 +10,15 @@ class DataTaskGraph final : public TaskGraph {
OF_DISALLOW_COPY_AND_MOVE(DataTaskGraph);
DataTaskGraph() = delete;
~DataTaskGraph() = default;
DataTaskGraph(const std::string& name,
const DLNetConf& dl_net_conf,
const Strategy& strategy_conf,
bool need_bp);
DataTaskGraph(const std::string& name, const DLNetConf& dl_net_conf,
const Strategy& strategy_conf, bool need_bp);
const char* TypeName() const override { return "DataTaskGraph"; }
private:
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_DATA_TASK_GRAPH_H_
......@@ -2,12 +2,10 @@
namespace oneflow {
void ExecEdge::set_lbn(const std::string& lbn) {
lbn_ = lbn;
}
void ExecEdge::set_lbn(const std::string& lbn) { lbn_ = lbn; }
std::function<Shape*(const std::string&)>
ExecNode::GetMutShapePtr4BnInOpFunc() const {
std::function<Shape*(const std::string&)> ExecNode::GetMutShapePtr4BnInOpFunc()
const {
return [this](const std::string& bn_in_op) -> Shape* {
auto it = this->bn_in_op2regst_.find(bn_in_op);
if (it == this->bn_in_op2regst_.end()) { return nullptr; }
......@@ -19,11 +17,11 @@ ExecNode::GetMutShapePtr4BnInOpFunc() const {
void ExecNode::ToProto(ExecNodeProto* ret) const {
ret->set_op_name(op_->op_name());
for (const auto& bn_regst: bn_in_op2regst_) {
for (const auto& bn_regst : bn_in_op2regst_) {
auto regst = bn_regst.second.lock();
if (regst) {
ret->mutable_bn_in_op2regst_desc_id()->insert({
bn_regst.first, regst->regst_desc_id()});
ret->mutable_bn_in_op2regst_desc_id()->insert(
{bn_regst.first, regst->regst_desc_id()});
}
}
}
......@@ -36,4 +34,4 @@ void ExecGraph::ToExecSequence(ExecSequence* ret) const {
});
}
} // namespace oneflow
} // namespace oneflow
#ifndef ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#include "oneflow/core/common/protobuf.h"
#include "oneflow/core/graph/exec_sequence.pb.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/graph/graph.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/register/register_desc.h"
#include "oneflow/core/common/protobuf.h"
namespace oneflow {
......@@ -32,7 +32,6 @@ class ExecEdge final : public Edge<ExecNode, ExecEdge> {
std::string lbn_;
std::string src_bn_;
std::string dst_bn_;
};
class ExecNode final : public Node<ExecNode, ExecEdge> {
......@@ -44,10 +43,12 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
std::shared_ptr<const Operator> op() const { return op_; }
std::shared_ptr<const Operator>& mut_op() { return op_; }
void BindBnInOpAndRegst(const std::string& bn_in_op, std::weak_ptr<RegstDesc> regst) {
void BindBnInOpAndRegst(const std::string& bn_in_op,
std::weak_ptr<RegstDesc> regst) {
CHECK(bn_in_op2regst_.emplace(bn_in_op, regst).second);
}
std::shared_ptr<RegstDesc> GetRegstFromBnInOp(const std::string& bn_in_op) const {
std::shared_ptr<RegstDesc> GetRegstFromBnInOp(
const std::string& bn_in_op) const {
return bn_in_op2regst_.at(bn_in_op).lock();
}
const HashMap<std::string, std::weak_ptr<RegstDesc>>& bn_in_op2regst() const {
......@@ -55,15 +56,14 @@ class ExecNode final : public Node<ExecNode, ExecEdge> {
}
std::function<Shape*(const std::string&)> GetMutShapePtr4BnInOpFunc() const;
std::string VisualStr() const { return op_->op_name(); }
void ToProto(ExecNodeProto* ret) const;
private:
std::shared_ptr<const Operator> op_;
HashMap<std::string, std::weak_ptr<RegstDesc>> bn_in_op2regst_;
};
class ExecGraph final : public Graph<ExecNode, ExecEdge> {
......@@ -71,14 +71,13 @@ class ExecGraph final : public Graph<ExecNode, ExecEdge> {
OF_DISALLOW_COPY_AND_MOVE(ExecGraph);
ExecGraph() = default;
~ExecGraph() = default;
void ToExecSequence(ExecSequence* ret) const;
const char* TypeName() const override { return "ExecGraph"; }
private:
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_EXEC_GRAPH_H_
#ifndef ONEFLOW_CORE_GRAPH_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_GRAPH_H_
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/lib/io/path.h"
#include "gflags/gflags.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "oneflow/core/graph/node.h"
#include "oneflow/core/persistence/persistent_out_stream.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
namespace oneflow {
......@@ -23,11 +23,11 @@ class Graph {
void ConstForEachNode(std::function<void(const NodeType*)>) const;
void ConstTopoForEachNode(std::function<void(const NodeType*)>) const;
void ConstReverseTopoForEachNode(std::function<void(const NodeType*)>) const;
// For Each Edge
void ForEachEdge(std::function<void(EdgeType*)>);
void ConstForEachEdge(std::function<void(const EdgeType*)>) const;
// Getters
const std::unordered_set<NodeType*>& source_nodes() const;
const std::unordered_set<NodeType*>& sink_nodes() const;
......@@ -37,7 +37,7 @@ class Graph {
size_t node_num() const { return nodes_.size(); }
size_t edge_num() const { return edges_.size(); }
virtual const char* TypeName() const { return "Not Defined"; }
// Setters
NodeType* NewNode();
EdgeType* NewEdge();
......@@ -57,9 +57,9 @@ class Graph {
class TopoIterator;
class ReverseTopoIterator;
TopoIterator begin() { return source_nodes_; }
TopoIterator end() { return std::unordered_set<NodeType*> (); }
TopoIterator end() { return std::unordered_set<NodeType*>(); }
ReverseTopoIterator rbegin() { return sink_nodes_; }
ReverseTopoIterator rend() { return std::unordered_set<NodeType*> (); }
ReverseTopoIterator rend() { return std::unordered_set<NodeType*>(); }
//
std::unordered_set<NodeType*> source_nodes_;
......@@ -68,25 +68,22 @@ class Graph {
std::vector<std::unique_ptr<EdgeType>> edges_;
};
template<typename NodeType, typename EdgeType>
class Graph<NodeType, EdgeType>::TopoIterator final {
public:
// OF_DISALLOW_COPY_AND_MOVE(TopoIterator);
TopoIterator() = default;
~TopoIterator() = default;
TopoIterator(const std::unordered_set<NodeType*>& source_nodes) {
for (NodeType* node : source_nodes) {
bfs_queue_.push(node);
}
for (NodeType* node : source_nodes) { bfs_queue_.push(node); }
}
NodeType& operator * () { return *(bfs_queue_.front()); }
NodeType* operator -> () { return &(*(*this)); }
TopoIterator& operator ++ ();
bool operator != (const TopoIterator&) const;
NodeType& operator*() { return *(bfs_queue_.front()); }
NodeType* operator->() { return &(*(*this)); }
TopoIterator& operator++();
bool operator!=(const TopoIterator&) const;
private:
std::queue<NodeType*> bfs_queue_;
......@@ -99,57 +96,48 @@ class Graph<NodeType, EdgeType>::ReverseTopoIterator final {
// OF_DISALLOW_COPY_AND_MOVE(ReverseTopoIterator);
ReverseTopoIterator() = default;
~ReverseTopoIterator() = default;
ReverseTopoIterator(const std::unordered_set<NodeType*>& sink_nodes) {
for (NodeType* node : sink_nodes) {
bfs_queue_.push(node);
}
for (NodeType* node : sink_nodes) { bfs_queue_.push(node); }
}
NodeType& operator * () { return *(bfs_queue_.front()); }
NodeType* operator -> () { return &(*(*this)); }
ReverseTopoIterator& operator ++ ();
bool operator != (const ReverseTopoIterator&) const;
NodeType& operator*() { return *(bfs_queue_.front()); }
NodeType* operator->() { return &(*(*this)); }
ReverseTopoIterator& operator++();
bool operator!=(const ReverseTopoIterator&) const;
private:
std::queue<NodeType*> bfs_queue_;
HashMap<NodeType*, int32_t > visited_cnt_;
HashMap<NodeType*, int32_t> visited_cnt_;
};
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachNode(
std::function<void(NodeType*)> func) {
for (auto& x : nodes_) {
func(x.get());
}
for (auto& x : nodes_) { func(x.get()); }
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::TopoForEachNode(
std::function<void(NodeType*)> func) {
for (TopoIterator it = begin(); it != end(); ++it) {
func(&(*it));
}
for (TopoIterator it = begin(); it != end(); ++it) { func(&(*it)); }
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ReverseTopoForEachNode(
std::function<void(NodeType*)> func) {
for (ReverseTopoIterator it = rbegin(); it != rend(); ++it) {
func(&(*it));
}
for (ReverseTopoIterator it = rbegin(); it != rend(); ++it) { func(&(*it)); }
}
#define OF_DEFINE_CONST_FOR_EACH_NODE(FuncName) \
template<typename NodeType, typename EdgeType> \
void Graph<NodeType, EdgeType>::Const##FuncName( \
std::function<void(const NodeType*)> func) const { \
auto cast_this = const_cast<Graph<NodeType, EdgeType>*> (this); \
cast_this->FuncName(std::bind(func, std::placeholders::_1)); \
}
#define OF_DEFINE_CONST_FOR_EACH_NODE(FuncName) \
template<typename NodeType, typename EdgeType> \
void Graph<NodeType, EdgeType>::Const##FuncName( \
std::function<void(const NodeType*)> func) const { \
auto cast_this = const_cast<Graph<NodeType, EdgeType>*>(this); \
cast_this->FuncName(std::bind(func, std::placeholders::_1)); \
}
OF_DEFINE_CONST_FOR_EACH_NODE(ForEachNode);
OF_DEFINE_CONST_FOR_EACH_NODE(TopoForEachNode);
......@@ -160,27 +148,25 @@ OF_DEFINE_CONST_FOR_EACH_NODE(ReverseTopoForEachNode);
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ForEachEdge(
std::function<void(EdgeType*)> func) {
for (auto& x : edges_) {
func(x.get());
}
for (auto& x : edges_) { func(x.get()); }
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ConstForEachEdge(
std::function<void(const EdgeType*)> func) const {
auto cast_this = const_cast<Graph<NodeType, EdgeType>*> (this);
auto cast_this = const_cast<Graph<NodeType, EdgeType>*>(this);
cast_this->ForEachEdge(std::bind(func, std::placeholders::_1));
}
template<typename NodeType, typename EdgeType>
const std::unordered_set<NodeType*>&
Graph<NodeType, EdgeType>::source_nodes() const {
const std::unordered_set<NodeType*>& Graph<NodeType, EdgeType>::source_nodes()
const {
return source_nodes_;
}
template<typename NodeType, typename EdgeType>
const std::unordered_set<NodeType*>&
Graph<NodeType, EdgeType>::sink_nodes() const {
const std::unordered_set<NodeType*>& Graph<NodeType, EdgeType>::sink_nodes()
const {
return sink_nodes_;
}
......@@ -241,12 +227,8 @@ void Graph<NodeType, EdgeType>::UpdateSourceAndSink() {
source_nodes_.clear();
sink_nodes_.clear();
for (const std::unique_ptr<NodeType>& node : nodes_) {
if (node->in_edges().empty()) {
source_nodes_.insert(node.get());
}
if (node->out_edges().empty()) {
sink_nodes_.insert(node.get());
}
if (node->in_edges().empty()) { source_nodes_.insert(node.get()); }
if (node->out_edges().empty()) { sink_nodes_.insert(node.get()); }
}
}
......@@ -259,14 +241,15 @@ void Graph<NodeType, EdgeType>::ToDotWithStream(StreamT& out_stream) const {
});
this->ConstForEachEdge([&](const EdgeType* edge) {
out_stream << "\"" << edge->src_node()->VisualStr() << "\" -> "
<< "\"" << edge->dst_node()->VisualStr() << "\""
<< "[label=\"" << edge->VisualStr() << "\"];\n";
<< "\"" << edge->dst_node()->VisualStr() << "\""
<< "[label=\"" << edge->VisualStr() << "\"];\n";
});
out_stream << "}\n";
}
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path) const {
void Graph<NodeType, EdgeType>::ToDotWithFilePath(
const std::string& file_path) const {
std::string dir_name = tensorflow::io::Dirname(file_path).ToString();
tensorflow::Env* env = tensorflow::Env::Default();
if (env->IsDirectory(dir_name).code() != tensorflow::error::OK) {
......@@ -278,17 +261,15 @@ void Graph<NodeType, EdgeType>::ToDotWithFilePath(const std::string& file_path)
template<typename NodeType, typename EdgeType>
void Graph<NodeType, EdgeType>::ToDotWithAutoFilePath() const {
std::string file_path = LogDir() + "/dot/" + TypeName()
+ "/" + NewUniqueId() + ".dot";
std::string file_path =
LogDir() + "/dot/" + TypeName() + "/" + NewUniqueId() + ".dot";
ToDotWithFilePath(file_path);
}
template<typename NodeType>
bool IsNotEqual4BfsQueue(const std::queue<NodeType*>& lhs,
const std::queue<NodeType*>& rhs) {
if (lhs.empty() != rhs.empty()) {
return true;
}
if (lhs.empty() != rhs.empty()) { return true; }
if (lhs.empty() == false && rhs.empty() == false) {
return lhs.front() != rhs.front();
}
......@@ -296,27 +277,28 @@ bool IsNotEqual4BfsQueue(const std::queue<NodeType*>& lhs,
}
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::TopoIterator::operator ++ () -> TopoIterator& {
auto Graph<NodeType, EdgeType>::TopoIterator::operator++() -> TopoIterator& {
NodeType* cur_node = bfs_queue_.front();
bfs_queue_.pop();
for (EdgeType* out_edge : cur_node->out_edges()) {
NodeType* dst_node = out_edge->dst_node();
visited_cnt_[dst_node] += 1;
if (visited_cnt_.at(dst_node) == dst_node->in_edges().size()) {
bfs_queue_.push(dst_node);
bfs_queue_.push(dst_node);
}
}
return *this;
}
template<typename NodeType, typename EdgeType>
bool Graph<NodeType, EdgeType>::TopoIterator::operator != (
bool Graph<NodeType, EdgeType>::TopoIterator::operator!=(
const TopoIterator& rhs) const {
return IsNotEqual4BfsQueue(bfs_queue_, rhs.bfs_queue_);
}
template<typename NodeType, typename EdgeType>
auto Graph<NodeType, EdgeType>::ReverseTopoIterator::operator ++ () -> ReverseTopoIterator& {
auto Graph<NodeType, EdgeType>::ReverseTopoIterator::operator++()
-> ReverseTopoIterator& {
NodeType* cur_node = bfs_queue_.front();
bfs_queue_.pop();
for (EdgeType* in_edge : cur_node->in_edges()) {
......@@ -330,11 +312,11 @@ auto Graph<NodeType, EdgeType>::ReverseTopoIterator::operator ++ () -> ReverseTo
}
template<typename NodeType, typename EdgeType>
bool Graph<NodeType, EdgeType>::ReverseTopoIterator::operator != (
bool Graph<NodeType, EdgeType>::ReverseTopoIterator::operator!=(
const ReverseTopoIterator& rhs) const {
return IsNotEqual4BfsQueue(bfs_queue_, rhs.bfs_queue_);
}
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_GRAPH_H_
......@@ -7,12 +7,11 @@ class TestEdge;
class TestNode final : public Node<TestNode, TestEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TestNode);
TestNode(int64_t node_id_) {
test_node_id_ = node_id_;
}
TestNode(int64_t node_id_) { test_node_id_ = node_id_; }
~TestNode() = default;
int64_t test_node_id() const { return test_node_id_; }
private:
int64_t test_node_id_;
};
......@@ -59,9 +58,11 @@ void DoOneTestGraph(const TestGraph& test_graph,
// 1. Determines whether the traversal result satisfies the topological order
HashMap<int64_t, int64_t> node_id2order, node_id2rorder;
auto NodePairHash = [](const NodeIdPair& val) { return val.first ^ val.second; };
std::unordered_set<NodeIdPair,
decltype(NodePairHash)> edges_node_pair(11, NodePairHash);
auto NodePairHash = [](const NodeIdPair& val) {
return val.first ^ val.second;
};
std::unordered_set<NodeIdPair, decltype(NodePairHash)> edges_node_pair(
11, NodePairHash);
int64_t order = 0;
test_graph.ConstTopoForEachNode([&](const TestNode* node) {
node_id2order.emplace(node->test_node_id(), order);
......@@ -76,7 +77,7 @@ void DoOneTestGraph(const TestGraph& test_graph,
});
ASSERT_EQ(node_id2rorder.size(), node_num);
// method :
// method :
// judge every directed edge <u,v>
// the node u's order is smaller than v
int64_t edge_num = 0;
......@@ -90,7 +91,7 @@ void DoOneTestGraph(const TestGraph& test_graph,
src_ord = node_id2rorder.at(src_node_id);
dst_ord = node_id2rorder.at(dst_node_id);
ASSERT_GE(src_ord, dst_ord);
//
//
++edge_num;
edges_node_pair.insert(std::make_pair(src_node_id, dst_node_id));
}
......@@ -109,8 +110,8 @@ void DoOneTestGraph(const TestGraph& test_graph,
test_graph.ConstForEachEdge([&](const TestEdge* cur_edge) {
int64_t src_node_id = cur_edge->src_node()->test_node_id();
int64_t dst_node_id = cur_edge->dst_node()->test_node_id();
ASSERT_TRUE(
edges_node_pair.count(std::make_pair(src_node_id, dst_node_id)) > 0);
ASSERT_TRUE(edges_node_pair.count(std::make_pair(src_node_id, dst_node_id))
> 0);
});
}
......@@ -129,4 +130,4 @@ TEST(TestGraph, test_graph_node_num_7) {
DoOneTestGraph(test_graph, graph_conf);
}
}// namespace oneflow
} // namespace oneflow
......@@ -4,15 +4,12 @@ namespace oneflow {
void InBoxingTaskNode::FwVirtualBuild() {
Chain2EdgesMap chain2sorted_in_edges;
FwInitChain2SortedEdgesMaps(&chain2sorted_in_edges,
&TaskNode::in_edges,
&TaskEdge::src_node,
&TaskNode::SoleInEdge);
FwInitChain2SortedEdgesMaps(&chain2sorted_in_edges, &TaskNode::in_edges,
&TaskEdge::src_node, &TaskNode::SoleInEdge);
ChainEdgesPair chain_sorted_out_edges;
chain_sorted_out_edges.first = chain_node();
chain_sorted_out_edges.second.assign(out_edges().begin(), out_edges().end());
FwSortEdgesInnerStage(&chain_sorted_out_edges.second,
&TaskEdge::dst_node,
FwSortEdgesInnerStage(&chain_sorted_out_edges.second, &TaskEdge::dst_node,
&TaskNode::SoleOutEdge);
for (const ChainEdgesPair& chain_sorted_in_edges : chain2sorted_in_edges) {
FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges);
......@@ -20,4 +17,4 @@ void InBoxingTaskNode::FwVirtualBuild() {
mut_exec_gph().UpdateSourceAndSink();
}
} // namespace oneflow
} // namespace oneflow
......@@ -13,15 +13,14 @@ class InBoxingTaskNode final : public BoxingTaskNode {
private:
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<InBoxingTaskNode> ();
return of_make_unique<InBoxingTaskNode>();
}
void InitWithFwNode(TaskNode* fw_node) override {
BoxingTaskNode::InitWithFwNode(fw_node);
}
void FwVirtualBuild() override;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_IN_BOXING_TASK_NODE_H_
......@@ -16,8 +16,7 @@ LogicalGraph::LogicalGraph(const DLNetConf& dl_net_conf,
}
void LogicalGraph::NaiveBuildGraphStruct(
const DLNetConf& dl_net_conf,
HashMap<LogicalEdge*, std::string>* edge2lbn,
const DLNetConf& dl_net_conf, HashMap<LogicalEdge*, std::string>* edge2lbn,
HashMap<LogicalEdge*, std::string>* edge2ibn) {
HashMap<std::string, LogicalNode*> lbn2producer;
// Process Op
......@@ -123,4 +122,4 @@ void LogicalGraph::AddOneCloneNode(
}
}
} // namespace oneflow
} // namespace oneflow
......@@ -2,10 +2,10 @@
#define ONEFLOW_CORE_GRAPH_LOGICAL_GRAPH_H_
#include "oneflow/core/graph/graph.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/job/dlnet_conf.pb.h"
#include "oneflow/core/job/strategy.pb.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/job/strategy.pb.h"
#include "oneflow/core/operator/operator.h"
namespace oneflow {
......@@ -17,12 +17,8 @@ class LogicalNode final : public Node<LogicalNode, LogicalEdge> {
LogicalNode() = default;
~LogicalNode() = default;
std::shared_ptr<Operator> op() const {
return op_;
}
std::shared_ptr<Operator>& mut_op() {
return op_;
}
std::shared_ptr<Operator> op() const { return op_; }
std::shared_ptr<Operator>& mut_op() { return op_; }
std::shared_ptr<const ParallelDesc> parallel_desc() const {
return parallel_desc_;
......@@ -38,7 +34,6 @@ class LogicalNode final : public Node<LogicalNode, LogicalEdge> {
private:
std::shared_ptr<Operator> op_;
std::shared_ptr<const ParallelDesc> parallel_desc_;
};
class LogicalEdge final : public Edge<LogicalNode, LogicalEdge> {
......@@ -56,16 +51,14 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
LogicalGraph() = delete;
~LogicalGraph() = default;
LogicalGraph(const DLNetConf& dl_net_conf,
const Strategy& strategy_conf);
LogicalGraph(const DLNetConf& dl_net_conf, const Strategy& strategy_conf);
const char* TypeName() const override { return "LogicalGraph"; }
private:
void NaiveBuildGraphStruct(
const DLNetConf& dl_net_conf,
HashMap<LogicalEdge*, std::string>* edge2lbn,
HashMap<LogicalEdge*, std::string>* edge2ibn);
void NaiveBuildGraphStruct(const DLNetConf& dl_net_conf,
HashMap<LogicalEdge*, std::string>* edge2lbn,
HashMap<LogicalEdge*, std::string>* edge2ibn);
void FillNodeWithParallelDesc(const Strategy& strategy_conf);
struct CloneInfo {
......@@ -73,18 +66,14 @@ class LogicalGraph final : public Graph<LogicalNode, LogicalEdge> {
LogicalNode* pred_node;
std::vector<LogicalEdge*> edges;
};
void AddCloneNodes(
const HashMap<LogicalEdge*, std::string>& edge2lbn,
const HashMap<LogicalEdge*, std::string>& edge2ibn);
void CollectCloneInfos(
std::vector<CloneInfo>* clone_infos,
const HashMap<LogicalEdge*, std::string>& edge2lbn);
void AddOneCloneNode(
const CloneInfo& clone_info,
const HashMap<LogicalEdge*, std::string>& edge2ibn);
void AddCloneNodes(const HashMap<LogicalEdge*, std::string>& edge2lbn,
const HashMap<LogicalEdge*, std::string>& edge2ibn);
void CollectCloneInfos(std::vector<CloneInfo>* clone_infos,
const HashMap<LogicalEdge*, std::string>& edge2lbn);
void AddOneCloneNode(const CloneInfo& clone_info,
const HashMap<LogicalEdge*, std::string>& edge2ibn);
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_LOGICAL_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_LOGICAL_GRAPH_H_
......@@ -5,8 +5,8 @@ namespace oneflow {
void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
CHECK(IsFwNode());
auto md_diff_acc_gph = static_cast<MdDiffAccTaskGraph*> (gph);
CompTaskNode* fw_task_ =
auto md_diff_acc_gph = static_cast<MdDiffAccTaskGraph*>(gph);
CompTaskNode* fw_task_ =
md_diff_acc_gph->GetFwTaskFromParallelId(parallel_id());
TaskNode* bp_task = fw_task_->GetBpNode();
std::shared_ptr<RegstDesc> model_diff_regst =
......@@ -30,7 +30,7 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
exec_node->BindBnInOpAndRegst(ibn, GetRelatedRegst(SoleInEdge()));
SubscribeRegstDesc(ibn, GetRelatedRegst(SoleInEdge()));
}
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(),
exec_node->BindBnInOpAndRegst(exec_node->op()->SoleObn(),
model_diff_acc_regst);
mut_exec_gph().UpdateSourceAndSink();
}
......@@ -38,10 +38,11 @@ void MdDiffAccCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
void MdDiffAccCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
if (!chain_node()->op_vec().empty()) {
std::shared_ptr<RegstDesc> in_regst =GetSubscribedRegstDesc("model_diff");
std::shared_ptr<RegstDesc> out_regst = GetProducedRegstDesc("model_diff_acc");
std::shared_ptr<RegstDesc> in_regst = GetSubscribedRegstDesc("model_diff");
std::shared_ptr<RegstDesc> out_regst =
GetProducedRegstDesc("model_diff_acc");
out_regst->CopyShapeFrom(in_regst.get());
}
}
} // namespace oneflow
} // namespace oneflow
......@@ -13,23 +13,23 @@ class MdDiffAccCompTaskNode final : public CompTaskNode {
void ToProto(TaskProto* proto) const override {
TaskNode::ToProto(proto);
proto->set_parallel_policy(fw_task_->chain_node()->parallel_desc()->policy());
proto->set_parallel_policy(
fw_task_->chain_node()->parallel_desc()->policy());
proto->set_parallel_id(fw_task_->parallel_id());
proto->set_parallel_num(fw_task_->chain_node()->parallel_desc()->parallel_num());
proto->set_parallel_num(
fw_task_->chain_node()->parallel_desc()->parallel_num());
}
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override {
return kMdDiffAccCompTask;
}
TaskType task_type() const override { return kMdDiffAccCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdDiffAccCompTaskNode> ();
return of_make_unique<MdDiffAccCompTaskNode>();
}
CompTaskNode* fw_task_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_COMP_TASK_NODE_H_
......@@ -4,8 +4,7 @@
namespace oneflow {
MdDiffAccTaskGraph::MdDiffAccTaskGraph(
const std::string& name,
const ChainNode* data_chain,
const std::string& name, const ChainNode* data_chain,
const std::vector<CompTaskNode*>& sorted_fw_comptasks4data_chain) {
mut_name() = name;
BuildTaskGraph(data_chain);
......@@ -22,7 +21,7 @@ void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
op_conf.mutable_model_diff_acc_conf();
auto model_diff_acc_op = OpMgr::Singleton().ConstructOp(op_conf);
// ModelDiffAccChain
auto chain_gph = of_make_unique<ChainGraph> ();
auto chain_gph = of_make_unique<ChainGraph>();
ChainNode* diff_acc_chain = chain_gph->NewNode();
diff_acc_chain->mut_op_vec() = {model_diff_acc_op};
auto parallel_desc4diff_acc =
......@@ -46,4 +45,4 @@ void MdDiffAccTaskGraph::BuildTaskGraph(const ChainNode* data_chain) {
BuildFromChainGph<MdDiffAccCompTaskNode>(std::move(chain_gph), false);
}
} // namespace oneflow
} // namespace oneflow
......@@ -12,14 +12,13 @@ class MdDiffAccTaskGraph final : public TaskGraph {
~MdDiffAccTaskGraph() = default;
MdDiffAccTaskGraph(
const std::string& name,
const ChainNode* data_chain,
const std::string& name, const ChainNode* data_chain,
const std::vector<CompTaskNode*>& sorted_fw_comptasks4data_chain);
CompTaskNode* GetFwTaskFromParallelId(int64_t parallel_id) const {
return parallel_id2fw_task_.at(parallel_id);
}
const char* TypeName() const override { return "MdDiffAccTaskGraph"; }
private:
......@@ -28,6 +27,6 @@ class MdDiffAccTaskGraph final : public TaskGraph {
HashMap<int64_t, CompTaskNode*> parallel_id2fw_task_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_MODEL_DIFF_ACCUMULATE_TASK_GRAPH_H_
#include "oneflow/core/graph/model_save_comp_task_node.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
#include "oneflow/core/graph/model_save_task_graph.h"
#include "oneflow/core/graph/model_update_comp_task_node.h"
namespace oneflow {
void MdSaveCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
CHECK(IsFwNode());
auto md_save_gph = static_cast<MdSaveTaskGraph*> (gph);
auto md_save_gph = static_cast<MdSaveTaskGraph*>(gph);
CompTaskNode* updt_task = md_save_gph->update_task();
if (in_edges().empty()) {
BindProducedRegstAndOutEdge(updt_task->GetProducedRegstDesc("model"),
......@@ -36,4 +36,4 @@ void MdSaveCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
}
} // namespace oneflow
} // namespace oneflow
......@@ -10,12 +10,14 @@ class MdSaveCompTaskNode final : public CompTaskNode {
OF_DISALLOW_COPY_AND_MOVE(MdSaveCompTaskNode);
MdSaveCompTaskNode() = default;
~MdSaveCompTaskNode() = default;
void ToProto(TaskProto* proto) const override {
TaskNode::ToProto(proto);
proto->set_parallel_policy(fw_task_->chain_node()->parallel_desc()->policy());
proto->set_parallel_policy(
fw_task_->chain_node()->parallel_desc()->policy());
proto->set_parallel_id(fw_task_->parallel_id());
proto->set_parallel_num(fw_task_->chain_node()->parallel_desc()->parallel_num());
proto->set_parallel_num(
fw_task_->chain_node()->parallel_desc()->parallel_num());
}
void set_fw_task(CompTaskNode* fw_task) { fw_task_ = fw_task; }
......@@ -28,15 +30,13 @@ class MdSaveCompTaskNode final : public CompTaskNode {
return !GetSubscribedRegstDesc("model");
}
TaskType task_type() const override {
return kMdSaveCompTask;
}
TaskType task_type() const override { return kMdSaveCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdSaveCompTaskNode> ();
return of_make_unique<MdSaveCompTaskNode>();
}
CompTaskNode* fw_task_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_COMP_TASK_NODE_H_
......@@ -13,12 +13,13 @@ MdSaveTaskGraph::MdSaveTaskGraph(const std::string& name,
}
void MdSaveTaskGraph::BuildTaskGraph() {
auto chain_gph = of_make_unique<ChainGraph> ();
auto chain_gph = of_make_unique<ChainGraph>();
// faker
ChainNode* faker_chain = chain_gph->NewNode();
ParallelConf faker_pr_conf;
faker_pr_conf.set_policy(kDataParallel);
faker_pr_conf.mutable_device_set()->add_device_name(update_task_->device_name());
faker_pr_conf.mutable_device_set()->add_device_name(
update_task_->device_name());
faker_chain->mut_parallel_desc().reset(new ParallelDesc(faker_pr_conf));
faker_chain->mut_output_lbns() = {kBaledBlobName};
// save
......@@ -27,7 +28,8 @@ void MdSaveTaskGraph::BuildTaskGraph() {
GetMachineNameFromDeviceName(update_task_->device_name());
ParallelConf save_pr_conf;
save_pr_conf.set_policy(kDataParallel);
save_pr_conf.mutable_device_set()->add_device_name(machine_name + ":persistence");
save_pr_conf.mutable_device_set()->add_device_name(machine_name
+ ":persistence");
save_chain->mut_parallel_desc().reset(new ParallelDesc(save_pr_conf));
save_chain->mut_input_lbns() = {kBaledBlobName};
//
......@@ -40,9 +42,10 @@ void MdSaveTaskGraph::BuildTaskGraph() {
if (model_save_comp_task_node != nullptr) {
auto model_update_comp_task_node =
static_cast<MdUpdtCompTaskNode*>(update_task_);
model_save_comp_task_node->set_fw_task(model_update_comp_task_node->fw_task());
model_save_comp_task_node->set_fw_task(
model_update_comp_task_node->fw_task());
}
});
}
} // namespace oneflow
} // namespace oneflow
......@@ -11,8 +11,7 @@ class MdSaveTaskGraph final : public TaskGraph {
MdSaveTaskGraph() = delete;
~MdSaveTaskGraph() = default;
MdSaveTaskGraph(const std::string& name,
CompTaskNode* update_task);
MdSaveTaskGraph(const std::string& name, CompTaskNode* update_task);
CompTaskNode* update_task() const { return update_task_; }
const char* TypeName() const override { return "MdSaveTaskGraph"; }
......@@ -23,6 +22,6 @@ class MdSaveTaskGraph final : public TaskGraph {
CompTaskNode* update_task_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_MODEL_SAVE_TASK_GRAPH_H_
......@@ -6,7 +6,7 @@ namespace oneflow {
void MdUpdtCompTaskNode::BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) {
CHECK(IsFwNode());
auto md_updt_gph = static_cast<MdUpdtTaskGraph*> (gph);
auto md_updt_gph = static_cast<MdUpdtTaskGraph*>(gph);
CompTaskNode* fw_task = md_updt_gph->fw_task();
CompTaskNode* diff_acc_task = md_updt_gph->diff_acc_task();
std::shared_ptr<RegstDesc> model_diff_acc_regst;
......@@ -33,4 +33,4 @@ void MdUpdtCompTaskNode::InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) {
CHECK(IsFwNode());
}
} // namespace oneflow
} // namespace oneflow
......@@ -13,9 +13,11 @@ class MdUpdtCompTaskNode final : public CompTaskNode {
void ToProto(TaskProto* proto) const override {
TaskNode::ToProto(proto);
proto->set_parallel_policy(fw_task_->chain_node()->parallel_desc()->policy());
proto->set_parallel_policy(
fw_task_->chain_node()->parallel_desc()->policy());
proto->set_parallel_id(fw_task_->parallel_id());
proto->set_parallel_num(fw_task_->chain_node()->parallel_desc()->parallel_num());
proto->set_parallel_num(
fw_task_->chain_node()->parallel_desc()->parallel_num());
}
void set_fw_task(CompTaskNode* fw_task) { fw_task_ = fw_task; }
......@@ -24,15 +26,13 @@ class MdUpdtCompTaskNode final : public CompTaskNode {
private:
void BuildExecAndEnrollLbn2Regsts(TaskGraph* gph) override;
void InferShapeOfBlobsInProducedRegsts(TaskGraph* gph) override;
TaskType task_type() const override {
return kMdUpdtCompTask;
}
TaskType task_type() const override { return kMdUpdtCompTask; }
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<MdUpdtCompTaskNode> ();
return of_make_unique<MdUpdtCompTaskNode>();
}
CompTaskNode* fw_task_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_COMP_TASK_NODE_H_
......@@ -3,8 +3,7 @@
namespace oneflow {
MdUpdtTaskGraph::MdUpdtTaskGraph(const std::string& name,
CompTaskNode* fw_task,
MdUpdtTaskGraph::MdUpdtTaskGraph(const std::string& name, CompTaskNode* fw_task,
CompTaskNode* diff_acc_task) {
mut_name() = name;
fw_task_ = fw_task;
......@@ -14,9 +13,9 @@ MdUpdtTaskGraph::MdUpdtTaskGraph(const std::string& name,
}
void MdUpdtTaskGraph::BuildTaskGraph() {
auto chain_gph = of_make_unique<ChainGraph> ();
auto chain_gph = of_make_unique<ChainGraph>();
OperatorConf op_conf;
op_conf.set_name("model_update_"+ NewUniqueId());
op_conf.set_name("model_update_" + NewUniqueId());
op_conf.mutable_model_update_conf();
auto model_updt_op = OpMgr::Singleton().ConstructOp(op_conf);
......@@ -39,4 +38,4 @@ void MdUpdtTaskGraph::BuildTaskGraph() {
});
}
} // namespace oneflow
} // namespace oneflow
......@@ -11,8 +11,7 @@ class MdUpdtTaskGraph final : public TaskGraph {
MdUpdtTaskGraph() = delete;
~MdUpdtTaskGraph() = default;
MdUpdtTaskGraph(const std::string& name,
CompTaskNode* fw_task,
MdUpdtTaskGraph(const std::string& name, CompTaskNode* fw_task,
CompTaskNode* diff_acc_task);
CompTaskNode* fw_task() const { return fw_task_; }
......@@ -26,6 +25,6 @@ class MdUpdtTaskGraph final : public TaskGraph {
CompTaskNode* diff_acc_task_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_MODEL_UPDATE_TASK_GRAPH_H_
......@@ -12,4 +12,4 @@ int64_t NewEdgeId() {
return edge_id++;
}
} // namespace oneflow
} // namespace oneflow
......@@ -8,9 +8,7 @@
namespace oneflow {
template<typename NodeType, typename EdgeType>
void Connect(NodeType* src_node,
EdgeType* edge,
NodeType* dst_node) {
void Connect(NodeType* src_node, EdgeType* edge, NodeType* dst_node) {
CHECK(src_node->out_edges_.insert(edge).second);
CHECK(dst_node->in_edges_.insert(edge).second);
CHECK(edge->src_node_ == nullptr);
......@@ -50,25 +48,21 @@ class Edge {
virtual std::string VisualStr() const { return ""; }
private:
friend void Connect<NodeType, EdgeType>(NodeType* src_node,
EdgeType* edge,
friend void Connect<NodeType, EdgeType>(NodeType* src_node, EdgeType* edge,
NodeType* dst_node);
friend void DisConnect<EdgeType>(EdgeType* edge);
int64_t edge_id_;
NodeType* src_node_;
NodeType* dst_node_;
};
template<typename NodeType, typename EdgeType>
class Node {
public:
OF_DISALLOW_COPY_AND_MOVE(Node);
Node() {
node_id_ = NewNodeId();
}
Node() { node_id_ = NewNodeId(); }
virtual ~Node() = default;
int64_t node_id() const { return node_id_; }
......@@ -82,36 +76,26 @@ class Node {
return *(out_edges_.begin());
}
const std::unordered_set<EdgeType*>& in_edges() const {
return in_edges_;
}
const std::unordered_set<EdgeType*>& out_edges() const {
return out_edges_;
}
const std::unordered_set<EdgeType*>& in_edges() const { return in_edges_; }
const std::unordered_set<EdgeType*>& out_edges() const { return out_edges_; }
void DisconnectAllEdges() {
for (EdgeType* edge : in_edges_) {
DisConnect(edge);
}
for (EdgeType* edge : out_edges_) {
DisConnect(edge);
}
for (EdgeType* edge : in_edges_) { DisConnect(edge); }
for (EdgeType* edge : out_edges_) { DisConnect(edge); }
}
virtual std::string VisualStr() const { return ""; }
private:
friend void Connect<NodeType, EdgeType>(NodeType* src_node,
EdgeType* edge,
friend void Connect<NodeType, EdgeType>(NodeType* src_node, EdgeType* edge,
NodeType* dst_node);
friend void DisConnect<EdgeType>(EdgeType* edge);
int64_t node_id_;
std::unordered_set<EdgeType*> in_edges_;
std::unordered_set<EdgeType*> out_edges_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_NODE_H_
......@@ -4,15 +4,12 @@ namespace oneflow {
void OutBoxingTaskNode::FwVirtualBuild() {
Chain2EdgesMap chain2sorted_out_edges;
FwInitChain2SortedEdgesMaps(&chain2sorted_out_edges,
&TaskNode::out_edges,
&TaskEdge::dst_node,
&TaskNode::SoleOutEdge);
FwInitChain2SortedEdgesMaps(&chain2sorted_out_edges, &TaskNode::out_edges,
&TaskEdge::dst_node, &TaskNode::SoleOutEdge);
ChainEdgesPair chain_sorted_in_edges;
chain_sorted_in_edges.first = chain_node();
chain_sorted_in_edges.second.assign(in_edges().begin(), in_edges().end());
FwSortEdgesInnerStage(&chain_sorted_in_edges.second,
&TaskEdge::src_node,
FwSortEdgesInnerStage(&chain_sorted_in_edges.second, &TaskEdge::src_node,
&TaskNode::SoleInEdge);
for (const ChainEdgesPair& chain_sorted_out_edges : chain2sorted_out_edges) {
FwBuildChainSortedEdgesPair(chain_sorted_in_edges, chain_sorted_out_edges);
......@@ -20,4 +17,4 @@ void OutBoxingTaskNode::FwVirtualBuild() {
mut_exec_gph().UpdateSourceAndSink();
}
} // namespace oneflow
} // namespace oneflow
......@@ -13,15 +13,14 @@ class OutBoxingTaskNode final : public BoxingTaskNode {
private:
std::unique_ptr<TaskNode> CreateSameTypeNode() const override {
return of_make_unique<OutBoxingTaskNode> ();
return of_make_unique<OutBoxingTaskNode>();
}
void InitWithFwNode(TaskNode* fw_node) override {
BoxingTaskNode::InitWithFwNode(fw_node);
}
void FwVirtualBuild() override;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
#endif // ONEFLOW_CORE_GRAPH_OUT_BOXING_TASK_NODE_H_
......@@ -19,7 +19,7 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
size_t device_num =
parallel_desc->sorted_device_phy_ids(machine_id).size();
if (device_num == 0) {
device_num = 1; // persistence
device_num = 1; // persistence
}
range_idx += device_num;
stage_node->mut_parallel_range().mut_end() = range_idx;
......@@ -43,4 +43,4 @@ StageGraph::StageGraph(std::unique_ptr<const ChainGraph>&& chain_gph) {
ToDotWithAutoFilePath();
}
} // namespace oneflow
} // namespace oneflow
......@@ -14,29 +14,17 @@ class StageNode final : public Node<StageNode, StageEdge> {
StageNode() = default;
~StageNode() = default;
std::string machine_id_str() const {
return std::to_string(machine_id_);
}
const int64_t& machine_id() const {
return machine_id_;
}
int64_t& mut_machine_id() {
return machine_id_;
}
std::string machine_id_str() const { return std::to_string(machine_id_); }
const int64_t& machine_id() const { return machine_id_; }
int64_t& mut_machine_id() { return machine_id_; }
const ChainNode* chain_node() const {
return chain_node_;
}
const ChainNode* chain_node() const { return chain_node_; }
void set_chain_node(const ChainNode* new_chain_node) {
chain_node_ = new_chain_node;
}
const Range& parallel_range() const {
return parallel_range_;
}
Range& mut_parallel_range() {
return parallel_range_;
}
const Range& parallel_range() const { return parallel_range_; }
Range& mut_parallel_range() { return parallel_range_; }
const std::vector<int64_t>& SortedDevicePhyIds() const {
return chain_node_->parallel_desc()->sorted_device_phy_ids(machine_id_);
......@@ -50,7 +38,6 @@ class StageNode final : public Node<StageNode, StageEdge> {
const ChainNode* chain_node_;
int64_t machine_id_;
Range parallel_range_;
};
class StageEdge final : public Edge<StageNode, StageEdge> {
......@@ -58,7 +45,7 @@ class StageEdge final : public Edge<StageNode, StageEdge> {
OF_DISALLOW_COPY_AND_MOVE(StageEdge);
StageEdge() = default;
~StageEdge() = default;
private:
};
......@@ -75,9 +62,8 @@ class StageGraph final : public Graph<StageNode, StageEdge> {
private:
std::unique_ptr<const ChainGraph> chain_gph_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_STAGE_GRAPH_H_
此差异已折叠。
#ifndef ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#define ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#include "oneflow/core/graph/stage_graph.h"
#include "oneflow/core/graph/boxing_task_node.h"
#include "oneflow/core/graph/copy_task_node.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/operator_manager.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/graph/stage_graph.h"
#include "oneflow/core/job/id_manager.h"
#include "oneflow/core/job/job_desc.h"
#include "oneflow/core/job/parallel_desc.h"
#include "oneflow/core/operator/operator.h"
#include "oneflow/core/operator/operator_manager.h"
namespace oneflow {
......@@ -16,22 +16,21 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
public:
OF_DISALLOW_COPY_AND_MOVE(TaskGraph);
virtual ~TaskGraph() = default;
// Getters
const StageGraph* stage_gph() const { return stage_gph_.get(); }
const ChainGraph* chain_gph() const { return stage_gph_->chain_gph(); }
std::vector<CompTaskNode*> CompTasksInChain(const ChainNode*);
void InferShapeOfBlobsInProducedRegsts();
const std::string& name() const { return name_; }
protected:
TaskGraph() = default;
template<typename CompTaskNodeType>
void BuildFromChainGph(std::unique_ptr<ChainGraph>&& chain_gph,
bool need_bp);
void BuildFromChainGph(std::unique_ptr<ChainGraph>&& chain_gph, bool need_bp);
void BuildExecAndEnrollLbn2Regsts();
std::string& mut_name() { return name_; }
......@@ -55,9 +54,8 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
BoxingTaskNode* in_boxing_task_node;
BoxingTaskNode* out_boxing_task_node;
};
using Stage2TaskNodesMap =
HashMap<const StageNode*, TaskNodesInStage>;
using Stage2TaskNodesMap = HashMap<const StageNode*, TaskNodesInStage>;
template<typename TaskNodeType>
void InitCompTaskNodes(Stage2TaskNodesMap* stage2task_nodes);
......@@ -73,15 +71,14 @@ class TaskGraph : public Graph<TaskNode, TaskEdge> {
void InitOutBoxingTaskNode(const StageNode* stage,
TaskNodesInStage* task_nodes_in_stage);
void ConnectBoxingTaskNodes(const Stage2TaskNodesMap* stage2task_nodes);
void GenerateRelatedBpNodes(std::vector<TaskNode*> *turning_node_vec);
void GenerateRelatedBpNodes(std::vector<TaskNode*>* turning_node_vec);
void BackwardConnect(const std::vector<TaskNode*>& turning_node_vec);
void BuildBpStruct();
std::unique_ptr<const StageGraph> stage_gph_;
std::string name_;
std::string name_;
};
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
#endif // ONEFLOW_CORE_GRAPH_TASK_GRAPH_H_
......@@ -2,7 +2,10 @@
namespace oneflow {
TaskNode::TaskNode() : produced_regst2out_edge_(11, [](const std::weak_ptr<RegstDesc>& v) { return std::hash<void*>() (v.lock().get()); }) {
TaskNode::TaskNode()
: produced_regst2out_edge_(11, [](const std::weak_ptr<RegstDesc>& v) {
return std::hash<void*>()(v.lock().get());
}) {
stage_node_ = nullptr;
related_fw_or_bp_node_ = nullptr;
}
......@@ -75,10 +78,11 @@ void TaskNode::TakeOverRegstDesc(TaskNode* rhs,
}
void TaskNode::EraseProducedEmptyRegsts() {
EraseIf<std::string, std::shared_ptr<RegstDesc>> (&produced_regst_descs_, []
(HashMap<std::string, std::shared_ptr<RegstDesc>>::iterator it) {
return it->second->NumOfLbn() == 0;
});
EraseIf<std::string, std::shared_ptr<RegstDesc>>(
&produced_regst_descs_,
[](HashMap<std::string, std::shared_ptr<RegstDesc>>::iterator it) {
return it->second->NumOfLbn() == 0;
});
}
void TaskNode::EraseZeroSizeBlobInProducedRegsts() {
......@@ -113,7 +117,7 @@ void TaskNode::BindProducedRegstAndOutEdge(std::weak_ptr<RegstDesc> regst,
std::shared_ptr<RegstDesc> TaskNode::NewProducedRegstDesc(
const std::string& regst_desc_name) {
auto regst_desc = std::make_shared<RegstDesc> ();
auto regst_desc = std::make_shared<RegstDesc>();
regst_desc->SetProducer(this);
regst_desc->set_regst_desc_id(IDMgr::Singleton().NewRegstDescId());
CHECK(produced_regst_descs_.emplace(regst_desc_name, regst_desc).second);
......@@ -136,14 +140,16 @@ void TaskNode::ToProto(TaskProto* ret) const {
for (const auto& pair : produced_regst_descs_) {
RegstDescProto regst_desc_proto;
pair.second->ToProto(&regst_desc_proto);
CHECK(ret->mutable_produced_regst_desc()->insert(
{pair.first, regst_desc_proto}).second);
CHECK(ret->mutable_produced_regst_desc()
->insert({pair.first, regst_desc_proto})
.second);
}
for (const auto& pair : subscribed_regst_descs_) {
auto regst_desc = pair.second.lock();
if (regst_desc) {
CHECK(ret->mutable_subscribed_regst_desc_id()->insert(
{pair.first, regst_desc->regst_desc_id()}).second);
CHECK(ret->mutable_subscribed_regst_desc_id()
->insert({pair.first, regst_desc->regst_desc_id()})
.second);
}
}
}
......@@ -159,10 +165,10 @@ std::string TaskNode::DebugStr() const {
std::stringstream ss;
ss << "{" << node_id_str() << "\t";
for (const auto& pair : produced_regst_descs_) {
ss << "{" << pair.first << ":" << pair.second->DebugStr() << "}";
ss << "{" << pair.first << ":" << pair.second->DebugStr() << "}";
}
ss << "}";
return ss.str();
}
} // namespace oneflow
} // namespace oneflow
此差异已折叠。
此差异已折叠。
......@@ -19,8 +19,10 @@ class IDMgr final {
machine_num_ = resource.machine_size();
CHECK_LT(machine_num_, static_cast<int64_t>(1) << machine_id_bit_num_);
device_num_per_machine_ = resource.device_num_per_machine();
// reserve 3 number of device_id for persistence_, boxing_ and commnet_ ThrdLocId
CHECK_LT(device_num_per_machine_, (static_cast<int64_t>(1) << device_id_bit_num_) - 3);
// reserve 3 number of device_id for persistence_, boxing_ and commnet_
// ThrdLocId
CHECK_LT(device_num_per_machine_,
(static_cast<int64_t>(1) << device_id_bit_num_) - 3);
for (int64_t i = 0; i < machine_num_; ++i) {
const std::string& machine_name = resource.machine(i).name();
CHECK(machine_name2machine_id_.emplace(machine_name, i).second);
......@@ -51,20 +53,15 @@ class IDMgr final {
int64_t machine_id64bit = machine_id << (63 - machine_id_bit_num_);
int64_t device_id64bit = thrd_local_id << task_id_bit_num_;
int64_t thrd_id = machine_id64bit | device_id64bit;
CHECK_LT(thread_id2num_of_tasks_[thrd_id], (static_cast<int64_t>(1) << task_id_bit_num_) - 1);
CHECK_LT(thread_id2num_of_tasks_[thrd_id],
(static_cast<int64_t>(1) << task_id_bit_num_) - 1);
return thrd_id | (thread_id2num_of_tasks_[thrd_id]++);
}
int64_t NewRegstDescId() {
return regst_desc_id_count_++;
}
int64_t NewRegstDescId() { return regst_desc_id_count_++; }
// Runtime
int64_t ActorId4TaskId(int64_t task_id) {
return task_id;
}
int64_t TaskId4ActorId(int64_t actor_id) {
return actor_id;
}
int64_t ActorId4TaskId(int64_t task_id) { return task_id; }
int64_t TaskId4ActorId(int64_t actor_id) { return actor_id; }
int64_t MachineId4ActorId(int64_t actor_id) {
return actor_id >> (63 - machine_id_bit_num_);
}
......@@ -99,4 +96,4 @@ class IDMgr final {
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_ID_MANAGER_H_
#endif // ONEFLOW_CORE_JOB_ID_MANAGER_H_
......@@ -50,4 +50,4 @@ void JobDesc::ToProto(JobDescProto* proto) const {
proto->set_total_batch_num(total_batch_num_);
}
} // namespace oneflow
} // namespace oneflow
此差异已折叠。
......@@ -4,4 +4,4 @@ namespace oneflow {
const char* kBaledBlobName = "_oneflow_BaledBlobName";
} // namespace oneflow
} // namespace oneflow
......@@ -5,6 +5,6 @@ namespace oneflow {
extern const char* kBaledBlobName;
} // namespace oneflow
} // namespace oneflow
#endif // ONEFLOW_CORE_JOB_KEYWORD_H_
#endif // ONEFLOW_CORE_JOB_KEYWORD_H_
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册