未验证 提交 29b25969 编写于 作者: Juncheng 提交者: GitHub

Reuse operator instance (#4251)

Co-authored-by: oneflow-ci-bot <69100618+oneflow-ci-bot@users.noreply.github.com>
上级 066426d4
......@@ -30,7 +30,7 @@ void DecodeH2DCompTaskNode::ProduceAllRegstsAndBindEdges() {
void DecodeH2DCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> sole_op = this->logical_node()->SoleOp();
std::shared_ptr<const Operator> sole_op = this->logical_node()->SoleOp();
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
......
......@@ -65,7 +65,8 @@ void LogicalGraph::NaiveBuildFwStruct(
CHECK(parallel_desc_ptr_it != name2parallel_desc.end());
const std::shared_ptr<ParallelDesc>& parallel_desc_ptr = parallel_desc_ptr_it->second;
cur_op_conf.set_device_tag(CHECK_JUST(DeviceTag4DeviceType(parallel_desc_ptr->device_type())));
std::shared_ptr<Operator> cur_op = ConstructOp(cur_op_conf, &GlobalJobDesc());
std::shared_ptr<const Operator> cur_op =
Global<OpGraph>::Get()->OpNode4OpName(cur_op_conf.name())->shared_op();
LogicalNode* cur_node = cur_op->NewProperLogicalNode();
AddAllocatedNode(cur_node);
cur_node->mut_op_vec() = {cur_op};
......
......@@ -43,9 +43,9 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
virtual ~LogicalNode() = default;
// op_vec_
std::shared_ptr<Operator> SoleOp() const;
const std::vector<std::shared_ptr<Operator>>& op_vec() const { return op_vec_; }
std::vector<std::shared_ptr<Operator>>& mut_op_vec() { return op_vec_; }
std::shared_ptr<const Operator> SoleOp() const;
const std::vector<std::shared_ptr<const Operator>>& op_vec() const { return op_vec_; }
std::vector<std::shared_ptr<const Operator>>& mut_op_vec() { return op_vec_; }
// parallel_desc_
std::shared_ptr<const ParallelDesc> parallel_desc() const { return parallel_desc_; }
......@@ -83,7 +83,7 @@ class LogicalNode : public Node<LogicalNode, LogicalEdge> {
private:
bool HasOpWithCondition(std::function<bool(const Operator*)>) const;
std::vector<std::shared_ptr<Operator>> op_vec_;
std::vector<std::shared_ptr<const Operator>> op_vec_;
std::shared_ptr<const ParallelDesc> parallel_desc_;
HashMap<const LogicalNode*, std::vector<LogicalBlobId>> dst2data_lbis_;
......
......@@ -43,6 +43,7 @@ class OpNode final : public Node<OpNode, OpEdge> {
const Shape* out_blob_time_shape() const;
bool IsTimeShapeIdentity() const;
const Operator& op() const { return *op_; }
std::shared_ptr<const Operator> shared_op() const { return op_; }
const ParallelDesc& parallel_desc() const { return parallel_desc_; }
const SbpSignature& sbp_signature() const { return *CHECK_JUST(op().sbp_signature()); }
const SbpParallel& SbpParallel4Lbi(const LogicalBlobId& lbi) const;
......
......@@ -22,7 +22,7 @@ namespace oneflow {
void CaseCompTaskNode::ConsumeAllRegsts() { ConsumeRegst("in", SoleInDataEdge()->GetSoleRegst()); }
void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() {
const std::shared_ptr<Operator> op = logical_node()->SoleOp();
const std::shared_ptr<const Operator> op = logical_node()->SoleOp();
HashMap<LogicalBlobId, int64_t> lbi2obn_id;
FOR_RANGE(int64_t, obn_id, 0, op->output_bns().size()) {
CHECK(lbi2obn_id.emplace(op->BnInOp2Lbi(GenRepeatedBn("out", obn_id)), obn_id).second);
......@@ -46,7 +46,7 @@ void CaseCompTaskNode::ProduceAllRegstsAndBindEdges() {
void CaseCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> sole_op = this->logical_node()->SoleOp();
std::shared_ptr<const Operator> sole_op = this->logical_node()->SoleOp();
node->mut_op() = sole_op;
node->BindBnWithRegst("in", GetSoleConsumedRegst("in"));
FOR_RANGE(int64_t, obn_id, 0, sole_op->output_bns().size()) {
......
......@@ -20,7 +20,7 @@ limitations under the License.
namespace oneflow {
void EsacCompTaskNode::ConsumeAllRegsts() {
const std::shared_ptr<Operator> op = logical_node()->SoleOp();
const std::shared_ptr<const Operator> op = logical_node()->SoleOp();
HashMap<LogicalBlobId, int64_t> lbi2ibn_id;
FOR_RANGE(int64_t, ibn_id, 0, op->input_bns().size()) {
CHECK(lbi2ibn_id.emplace(op->BnInOp2Lbi(GenRepeatedBn("in", ibn_id)), ibn_id).second);
......@@ -47,7 +47,7 @@ void EsacCompTaskNode::ProduceAllRegstsAndBindEdges() {
void EsacCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> sole_op = this->logical_node()->SoleOp();
std::shared_ptr<const Operator> sole_op = this->logical_node()->SoleOp();
node->mut_op() = sole_op;
FOR_RANGE(int64_t, ibn_id, 0, sole_op->input_bns().size()) {
node->BindBnWithRegst(GenRepeatedBn("in", ibn_id),
......
......@@ -83,7 +83,7 @@ void AddFuncForFindBldSubTskGphMthd(const std::string& k, BldSubTskGphMthd v) {
} // namespace
std::shared_ptr<Operator> LogicalNode::SoleOp() const {
std::shared_ptr<const Operator> LogicalNode::SoleOp() const {
  // This accessor is only meaningful for nodes holding exactly one operator;
  // fail fast otherwise.
  CHECK_EQ(op_vec_.size(), 1);
  return *op_vec_.begin();
}
......@@ -110,7 +110,7 @@ bool LogicalNode::IsDataLbiOnOutEdge(const LogicalBlobId& lbi) const {
std::string LogicalNode::VisualStr() const {
  // Label for graph visualization: the node's type name, then every op name,
  // each separated by a literal "\n" sequence (interpreted by the renderer).
  std::stringstream label;
  label << TypeName();
  for (const auto& op : op_vec_) {
    label << "\\n" << op->op_name();
  }
  return label.str();
}
......
......@@ -53,7 +53,7 @@ void RepeatCompTaskNode::ProduceAllRegstsAndBindEdges() {
void RepeatCompTaskNode::BuildExecGphAndRegst() {
ExecNode* node = mut_exec_gph().NewNode();
std::shared_ptr<Operator> sole_op = this->logical_node()->SoleOp();
std::shared_ptr<const Operator> sole_op = this->logical_node()->SoleOp();
node->mut_op() = sole_op;
node->BindBnWithRegst(sole_op->SoleIbn(), GetSoleConsumedRegst("in"));
std::shared_ptr<RegstDesc> out_regst = GetProducedRegst("out");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册