提交 4ecca347 编写于 作者: W willzhang4a58

implement register_desc

上级 42402a1b
...@@ -72,7 +72,7 @@ void MdUpdtTaskGraph::CompleteUpdateTaskAndFwTask( ...@@ -72,7 +72,7 @@ void MdUpdtTaskGraph::CompleteUpdateTaskAndFwTask(
RegstDesc* model_diff_regst = bp_task->GetProducedRegstDesc("model_diff"); RegstDesc* model_diff_regst = bp_task->GetProducedRegstDesc("model_diff");
RegstDesc* model_regst = update_task->GetProducedRegstDesc("model"); RegstDesc* model_regst = update_task->GetProducedRegstDesc("model");
// complete update task // complete update task
model_regst->CopyLbnAndShape(model_diff_regst); model_regst->CopyLbn2ShapeMap(model_diff_regst);
ExecNode* update_exec = update_task->exec_gph().SoleNode(); ExecNode* update_exec = update_task->exec_gph().SoleNode();
const std::string& ibn = update_exec->op()->SoleIbn(); const std::string& ibn = update_exec->op()->SoleIbn();
if (update_task->in_edges().empty()) { if (update_task->in_edges().empty()) {
......
...@@ -8,6 +8,30 @@ RegstDesc::RegstDesc() { ...@@ -8,6 +8,30 @@ RegstDesc::RegstDesc() {
producer_ = nullptr; producer_ = nullptr;
} }
const char* ContigRegstDesc::kAllLbn = "OfReservedAllLbn"; void RegstDesc::CopyLbn2ShapeMap(const RegstDesc* rhs) {
for (const auto& pair : rhs->lbn2shape_) {
const std::string& lbn = pair.first;
std::unique_ptr<Shape> shape(new Shape);
*shape = *(pair.second);
CHECK(lbn2shape_.insert(std::make_pair(lbn, std::move(shape))).second);
}
}
Shape* RegstDesc::EnrollLbn(const std::string& lbn) {
Shape* raw_ptr = new Shape;
std::unique_ptr<Shape> uptr(raw_ptr);
CHECK(lbn2shape_.insert(std::make_pair(lbn, std::move(uptr))).second);
return raw_ptr;
}
const Shape& RegstDesc::GetShape(const std::string& lbn) {
return *(lbn2shape_.at(lbn));
}
Shape* RegstDesc::GetMutShapePtr(const std::string& lbn) {
return lbn2shape_.at(lbn).get();
}
const char* RegstDesc::kAllLbn = "OfReservedAllLbn";
} // namespace oneflow } // namespace oneflow
...@@ -6,9 +6,10 @@ ...@@ -6,9 +6,10 @@
namespace oneflow { namespace oneflow {
class TaskNode; // Regst : Register
// Contig : Contiguous
// Regst : Register class TaskNode;
class RegstDesc { class RegstDesc {
public: public:
...@@ -16,32 +17,28 @@ class RegstDesc { ...@@ -16,32 +17,28 @@ class RegstDesc {
RegstDesc(); RegstDesc();
virtual ~RegstDesc() = default; virtual ~RegstDesc() = default;
// // Producer
const TaskNode* GetProducer() const { return producer_; } const TaskNode* GetProducer() const { return producer_; }
void SetProducer(const TaskNode* task_node) { producer_ = task_node; } void SetProducer(const TaskNode* task_node) { producer_ = task_node; }
void AddSubscriber(const TaskNode* task_node) {
CHECK(subscribers_.insert(task_node).second);
}
void CopyLbnAndShape(const RegstDesc*) { TODO(); }
Shape* EnrollLbn(const std::string& lbn) { TODO(); } // Lbn and Shape
const Shape& GetShape(const std::string& lbn) { TODO(); } void CopyLbn2ShapeMap(const RegstDesc*);
Shape* GetMutShapePtr(const std::string& lbn) { TODO(); } Shape* EnrollLbn(const std::string& lbn);
const Shape& GetShape(const std::string& lbn);
Shape* GetMutShapePtr(const std::string& lbn);
static const char* kAllLbn;
private: private:
int32_t regst_desc_id_; int32_t regst_desc_id_;
const TaskNode* producer_; const TaskNode* producer_;
std::unordered_set<const TaskNode*> subscribers_;
HashMap<std::string, std::unique_ptr<Shape>> lbn2shape_; HashMap<std::string, std::unique_ptr<Shape>> lbn2shape_;
}; };
// Contiguous
class ContigRegstDesc final : public RegstDesc { class ContigRegstDesc final : public RegstDesc {
public: public:
static const char* kAllLbn;
OF_DISALLOW_COPY_AND_MOVE(ContigRegstDesc); OF_DISALLOW_COPY_AND_MOVE(ContigRegstDesc);
ContigRegstDesc() = default; ContigRegstDesc() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册