boxing_task_node.h 3.2 KB
Newer Older
J
jiyuan 已提交
1 2
#ifndef ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
#define ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_
3

W
Will Zhang 已提交
4
#include "oneflow/core/graph/compute_task_node.h"
5 6 7

namespace oneflow {

W
Will Zhang 已提交
8 9
class ChainNode;

W
willzhang4a58 已提交
10
class BoxingTaskNode : public TaskNode {
11
 public:
W
Will Zhang 已提交
12 13 14 15 16 17
  struct EdgeInfo {
    const TaskEdge* edge;
    int64_t parallel_id_min;
    int64_t parallel_id_max;
  };

18 19
  OF_DISALLOW_COPY_AND_MOVE(BoxingTaskNode);
  BoxingTaskNode() = default;
W
willzhang4a58 已提交
20
  virtual ~BoxingTaskNode() = default;
W
willzhang4a58 已提交
21

W
Will Zhang 已提交
22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76
  void Init(int64_t machine_id);
  TodoTaskType GetTaskType() const override { return TodoTaskType::kBoxing; }

  void ProduceAllRegstsAndBindEdges() override;
  void ConsumeAllRegsts() override;
  void Build() override;

#define DECLARE_BLD_BOXING_OP_CONF_METHOD(x)                                   \
  void BldBoxingOpConfWith##x(                                                 \
      const std::string& lbn, const std::vector<EdgeInfo>& sorted_in_edges,    \
      int64_t in_parallel_num, int64_t in_edge_first, int64_t in_edge_last,    \
      const std::vector<EdgeInfo>& sorted_out_edges, int64_t out_parallel_num, \
      int64_t* used_out_edge_begin, BoxingOpConf*)

#define DECLARE_VIRTUAL_BLD_BOXING_OP_CONF_METHOD(x) \
  virtual DECLARE_BLD_BOXING_OP_CONF_METHOD(x) = 0

  DECLARE_BLD_BOXING_OP_CONF_METHOD();

  DECLARE_VIRTUAL_BLD_BOXING_OP_CONF_METHOD(DataConcatAndDataSplit);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(DataConcatAndClone);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(DataConcatAndModelSplit);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(ModelConcatAndDataSplit);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(ModelConcatAndClone);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndDataSplit);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndModelSplit);
  DECLARE_BLD_BOXING_OP_CONF_METHOD(AddAndClone);

 private:
  void InitChain2SortedEdgeInfo(
      const std::unordered_set<TaskEdge*>& (TaskNode::*GetEdges)() const,
      TaskEdge* (TaskNode::*SoleEdge)() const,
      TaskNode* (TaskEdge::*SoleNode)() const,
      HashMap<const ChainNode*, std::vector<EdgeInfo>>*);
  void BuildWithChainPair(const ChainNode* in_chain,
                          const std::vector<EdgeInfo>& sorted_in_edges,
                          const ChainNode* out_chain,
                          const std::vector<EdgeInfo>& sorted_out_edges);
  std::shared_ptr<Operator> NewBoxingOp(
      const std::string& lbn, const ChainNode* in_chain,
      const ChainNode* out_chain, const std::vector<EdgeInfo>& sorted_in_edges,
      const std::vector<EdgeInfo>& sorted_out_edges,
      int64_t* used_in_edge_begin, int64_t* used_out_edge_begin);
};

// Declares an override of BoxingTaskNode's BldBoxingOpConfWith<suffix>
// virtual. Used by BoxingTaskNode subclasses such as InBoxingTaskNode and
// OutBoxingTaskNode.
#define OVERRIDE_BLD_BOXING_OP_METHOD(suffix) \
  DECLARE_BLD_BOXING_OP_CONF_METHOD(suffix) override

// Concrete boxing task node that supplies its own
// BldBoxingOpConfWithDataConcatAndDataSplit strategy; all other behavior is
// inherited from BoxingTaskNode.
// NOTE(review): "In" presumably denotes boxing on the input/consumer side —
// confirm against where it is created.
class InBoxingTaskNode final : public BoxingTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(InBoxingTaskNode);
  InBoxingTaskNode() = default;
  ~InBoxingTaskNode() = default;

  OVERRIDE_BLD_BOXING_OP_METHOD(DataConcatAndDataSplit);

 private:
};

class OutBoxingTaskNode final : public BoxingTaskNode {
 public:
  OF_DISALLOW_COPY_AND_MOVE(OutBoxingTaskNode);
  OutBoxingTaskNode() = default;
  ~OutBoxingTaskNode() = default;
W
willzhang4a58 已提交
86

W
Will Zhang 已提交
87
  OVERRIDE_BLD_BOXING_OP_METHOD(DataConcatAndDataSplit);
W
willzhang4a58 已提交
88

W
Will Zhang 已提交
89
 private:
90 91
};

W
willzhang4a58 已提交
92
}  // namespace oneflow
93

W
willzhang4a58 已提交
94
#endif  // ONEFLOW_CORE_GRAPH_BOXING_TASK_NODE_H_