#ifndef ONEFLOW_CORE_GRAPH_TASK_NODE_H_ #define ONEFLOW_CORE_GRAPH_TASK_NODE_H_ #include "oneflow/core/graph/exec_graph.h" #include "oneflow/core/job/id_manager.h" #include "oneflow/core/job/task.pb.h" #include "oneflow/core/operator/operator.h" namespace oneflow { bool IsForwardTaskType(TaskType); bool IsBackwardTaskType(TaskType); bool IsMdUpdtTaskType(TaskType); RegstDescProto* FindOrCreateProducedCtrlRegstDesc(TaskProto* task_proto, const std::string& regst_desc_name); RegstDescIdSet* FindOrCreateConsumedCtrlRegstDescIdSet(TaskProto* task_proto, const std::string& regst_desc_name); class TaskEdge; class TaskNode : public Node { public: OF_DISALLOW_COPY_AND_MOVE(TaskNode); TaskNode(); virtual ~TaskNode() = default; // Getters int64_t machine_id() const { return machine_id_; } int64_t thrd_id() const { return thrd_id_; } int64_t task_id() const { return task_id_; } int64_t area_id() const { return area_id_; } int64_t chain_id() const { return chain_id_; } int64_t order_in_graph() const { return order_in_graph_; } const ExecGraph& exec_gph() const { return exec_gph_; } std::shared_ptr GetProducedRegst(const std::string& name); const std::list>& GetConsumedRegst(const std::string& name); std::shared_ptr GetSoleConsumedRegst(const std::string& name); const HashMap>& produced_regsts() { return produced_regsts_; } const HashMap>>& consumed_regsts() { return consumed_regsts_; } const HashSet ancestors() const { return ancestors_; } HashSet& mut_ancestors() { return ancestors_; } DeviceType device_type() const; virtual const ParallelContext* parallel_ctx() const { return nullptr; } int64_t LocalWorkStreamId() const; int64_t GlobalWorkStreamId() const; int64_t GpuPhyId() const { return Global::Get()->GetGpuPhyIdFromThrdId(thrd_id_); } // Setters void set_machine_id(int64_t val); void set_thrd_id(int64_t val); void set_area_id(int64_t val); void set_chain_id(int64_t val); void set_order_in_graph(int64_t val); // Build virtual void ProduceAllRegstsAndBindEdges() = 0; virtual void ConsumeAllRegsts() = 0; void PinConsumedRegst(); void Build(); virtual bool IsReadyForBuild() { return IsAllConsumedRegstLocked(); } virtual void EraseEmptyProducedRegst(); void ClearOutOfDateConsumedRegst(); // Others virtual TaskType GetTaskType() const { return TaskType::kInvalid; } std::string VisualStr() const override; virtual bool IsMeaningLess(); virtual void ToProto(TaskProto*); virtual bool IsPersistence() const { return false; } void BindEdgeWithProducedRegst(TaskEdge*, const std::string& name); virtual int64_t MemZoneId121() const; // TODO: there is bug for reduce task node void BuildCtrlRegstDescIfNeed(TaskNode* dst_node); RegstDesc* BuildCtrlRegstDesc(TaskNode* dst_node); protected: std::shared_ptr ProduceRegst(const std::string& name, bool enable_mem_sharing); std::shared_ptr ProduceRegst(const std::string& name, bool enable_mem_sharing, int32_t min_register_num, int32_t max_register_num); std::shared_ptr ProduceRegst(const std::string& name, bool enable_mem_sharing, int32_t min_register_num, int32_t max_register_num, const RegstDescTypeProto&); std::shared_ptr NewProducedRegst(bool enable_mem_sharing, int32_t min_register_num, int32_t max_register_num, const RegstDescTypeProto&); virtual void InitProducedRegstMemCase(RegstDesc* regst); virtual void InitProducedRegstMemCase(MemoryCase*); virtual void PinConsumedRegstMemCase(MemoryCase*); void ConsumeRegst(const std::string& name, std::shared_ptr); bool IsAllConsumedRegstLocked(); ExecGraph& mut_exec_gph() { return exec_gph_; } void TryLockConsumedRegst(const std::string& name); virtual void BuildExecGphAndRegst() = 0; virtual void LockRegsts(); virtual void FixRegisterNumRange(); virtual int64_t AllocateLocalWorkStreamId(); private: void UpdateTaskId(); int64_t machine_id_; int64_t thrd_id_; int64_t task_id_; int64_t area_id_; int64_t chain_id_; int64_t order_in_graph_; ExecGraph exec_gph_; HashMap> produced_regsts_; HashMap>> consumed_regsts_; HashSet ancestors_; }; class TaskEdge final : public Edge { public: OF_DISALLOW_COPY_AND_MOVE(TaskEdge); TaskEdge() = default; ~TaskEdge() = default; std::shared_ptr GetRegst(const std::string& name_in_producer) const; std::shared_ptr GetSoleRegst() const; void AddRegst(const std::string& name_in_producer, std::shared_ptr regst); private: HashMap> name_in_producer2regst_; }; extern std::map task_type2color; } // namespace oneflow #endif // ONEFLOW_CORE_GRAPH_TASK_NODE_H_