// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/framework/new_executor/interpreter/plan.h"

#include <algorithm>

#include "paddle/fluid/framework/program_desc.h"

namespace paddle {
namespace framework {
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36
namespace interpreter {

Plan::Plan(const std::vector<std::shared_ptr<Job>>& job_list,
           const std::unordered_map<std::string, ProgramDesc*>& type_to_program)
    : job_list_(job_list),
      type_to_program_(type_to_program),
      micro_batch_num_(1) {
  for (size_t i = 0; i < job_list_.size(); ++i) {
    const auto& job = job_list_[i];
    PADDLE_ENFORCE(type_to_program_.find(job->Type()) != type_to_program_.end(),
                   phi::errors::InvalidArgument(
                       "The %d-th job (type:%s, micro_batch_id:%d) has no "
                       "corresponding Program.",
                       i,
                       job->Type(),
                       job->MicroBatchId()));
L
LiYuRio 已提交
37

38 39 40 41
    micro_batch_num_ = std::max(micro_batch_num_, job->MicroBatchId() + 1);
  }
}

42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63
Plan::Plan(
    const std::vector<std::shared_ptr<Job>>& job_list,
    const std::unordered_map<std::string, std::shared_ptr<::ir::Program>>&
        type_to_ir_program)
    : job_list_(job_list),
      type_to_ir_program_(type_to_ir_program),
      micro_batch_num_(1) {
  for (size_t i = 0; i < job_list_.size(); ++i) {
    const auto& job = job_list_[i];
    PADDLE_ENFORCE(
        type_to_ir_program_.find(job->Type()) != type_to_ir_program_.end(),
        phi::errors::InvalidArgument(
            "The %d-th job (type:%s, micro_batch_id:%d) has no "
            "corresponding Program.",
            i,
            job->Type(),
            job->MicroBatchId()));

    micro_batch_num_ = std::max(micro_batch_num_, job->MicroBatchId() + 1);
  }
}

64
const std::vector<std::shared_ptr<Job>>& Plan::JobList() const {
L
LiYuRio 已提交
65 66
  return job_list_;
}
L
LiYuRio 已提交
67

68 69
const ProgramDesc* Plan::Program(const std::string& job_type) const {
  return type_to_program_.at(job_type);
L
LiYuRio 已提交
70 71
}

72 73 74 75 76
std::shared_ptr<::ir::Program> Plan::IrProgram(
    const std::string& job_type) const {
  return type_to_ir_program_.at(job_type);
}

77 78 79
int64_t Plan::MicroBatchNum() const { return micro_batch_num_; }

}  // namespace interpreter
L
LiYuRio 已提交
80 81
}  // namespace framework
}  // namespace paddle