提交 4ecca347 编写于 作者: W willzhang4a58

implement register_desc

上级 42402a1b
......@@ -72,7 +72,7 @@ void MdUpdtTaskGraph::CompleteUpdateTaskAndFwTask(
RegstDesc* model_diff_regst = bp_task->GetProducedRegstDesc("model_diff");
RegstDesc* model_regst = update_task->GetProducedRegstDesc("model");
// complete update task
model_regst->CopyLbnAndShape(model_diff_regst);
model_regst->CopyLbn2ShapeMap(model_diff_regst);
ExecNode* update_exec = update_task->exec_gph().SoleNode();
const std::string& ibn = update_exec->op()->SoleIbn();
if (update_task->in_edges().empty()) {
......
......@@ -8,6 +8,30 @@ RegstDesc::RegstDesc() {
producer_ = nullptr;
}
const char* ContigRegstDesc::kAllLbn = "OfReservedAllLbn";
void RegstDesc::CopyLbn2ShapeMap(const RegstDesc* rhs) {
for (const auto& pair : rhs->lbn2shape_) {
const std::string& lbn = pair.first;
std::unique_ptr<Shape> shape(new Shape);
*shape = *(pair.second);
CHECK(lbn2shape_.insert(std::make_pair(lbn, std::move(shape))).second);
}
}
Shape* RegstDesc::EnrollLbn(const std::string& lbn) {
Shape* raw_ptr = new Shape;
std::unique_ptr<Shape> uptr(raw_ptr);
CHECK(lbn2shape_.insert(std::make_pair(lbn, std::move(uptr))).second);
return raw_ptr;
}
const Shape& RegstDesc::GetShape(const std::string& lbn) {
return *(lbn2shape_.at(lbn));
}
Shape* RegstDesc::GetMutShapePtr(const std::string& lbn) {
return lbn2shape_.at(lbn).get();
}
const char* RegstDesc::kAllLbn = "OfReservedAllLbn";
} // namespace oneflow
......@@ -6,9 +6,10 @@
namespace oneflow {
class TaskNode;
// Regst : Register
// Contig : Contiguous
// Regst : Register
class TaskNode;
class RegstDesc {
public:
......@@ -16,32 +17,28 @@ class RegstDesc {
RegstDesc();
virtual ~RegstDesc() = default;
//
// Producer
const TaskNode* GetProducer() const { return producer_; }
void SetProducer(const TaskNode* task_node) { producer_ = task_node; }
void AddSubscriber(const TaskNode* task_node) {
CHECK(subscribers_.insert(task_node).second);
}
void CopyLbnAndShape(const RegstDesc*) { TODO(); }
Shape* EnrollLbn(const std::string& lbn) { TODO(); }
const Shape& GetShape(const std::string& lbn) { TODO(); }
Shape* GetMutShapePtr(const std::string& lbn) { TODO(); }
// Lbn and Shape
void CopyLbn2ShapeMap(const RegstDesc*);
Shape* EnrollLbn(const std::string& lbn);
const Shape& GetShape(const std::string& lbn);
Shape* GetMutShapePtr(const std::string& lbn);
static const char* kAllLbn;
private:
int32_t regst_desc_id_;
const TaskNode* producer_;
std::unordered_set<const TaskNode*> subscribers_;
HashMap<std::string, std::unique_ptr<Shape>> lbn2shape_;
};
// Contiguous
class ContigRegstDesc final : public RegstDesc {
public:
static const char* kAllLbn;
OF_DISALLOW_COPY_AND_MOVE(ContigRegstDesc);
ContigRegstDesc() = default;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册