// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/jit/function_schema.h"

#include "paddle/fluid/framework/program_desc.h"
#include "paddle/phi/core/enforce.h"

#include "paddle/fluid/jit/function_utils.h"
namespace paddle {
namespace jit {

Argument::Argument(const std::string& name, bool is_out)
    : name_(name), is_output_(is_out) {}

const std::string& Argument::Name() const { return name_; }

29
const std::vector<std::string> FunctionSchema::InputArgNames() const {
30 31 32 33 34 35 36
  std::vector<std::string> input_arg_names;
  for (auto& arg : input_args) {
    input_arg_names.emplace_back(arg.Name());
  }
  return input_arg_names;
}

37
const std::vector<std::string> FunctionSchema::OutputArgNames() const {
38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55
  std::vector<std::string> output_arg_names;
  for (auto& arg : output_args) {
    output_arg_names.emplace_back(arg.Name());
  }
  return output_arg_names;
}

// Appends a new input (non-output) argument named `name` to the schema.
void FunctionSchema::AddInputArg(const std::string& name) {
  constexpr bool kIsOutput = false;
  input_args.emplace_back(name, kIsOutput);
}

// Appends a new output argument named `name` to the schema.
void FunctionSchema::AddOutputArg(const std::string& name) {
  constexpr bool kIsOutput = true;
  output_args.emplace_back(name, kIsOutput);
}

// Builds the metadata for function `func_name` with parameters `param_names`.
// Stores its own copy of `program_desc` and derives the FunctionSchema from
// that copy. Every feed target name becomes an input argument and every fetch
// target name becomes an output argument, in the order the ProgramDesc
// reports them.
FunctionInfo::FunctionInfo(const std::string& func_name,
                           const std::vector<std::string>& param_names,
                           const framework::ProgramDesc& program_desc)
    : func_name_(func_name), param_names_(param_names) {
  // Copy the caller's program. Later edits such as RemoveDescFeedFetch() act
  // on this copy, not on `program_desc`.
  program_desc_.reset(new framework::ProgramDesc(program_desc));

  // Schema inputs are the program's feed targets...
  const auto feed_target_names = program_desc_->GetFeedTargetNames();
  for (const auto& feed_name : feed_target_names) {
    schema_.AddInputArg(feed_name);
  }

  // ...and schema outputs are its fetch targets.
  const auto fetch_target_names = program_desc_->GetFetchTargetNames();
  for (const auto& fetch_name : fetch_target_names) {
    schema_.AddOutputArg(fetch_name);
  }
}

67
const std::string& FunctionInfo::FunctionName() const { return func_name_; }
68

69
const framework::ProgramDesc& FunctionInfo::ProgramDesc() const {
70
  return *program_desc_.get();
71 72
}

73
const std::vector<std::string>& FunctionInfo::ParamNames() const {
74 75 76
  return param_names_;
}

77 78
const std::vector<std::string> FunctionInfo::InputArgNames() const {
  return schema_.InputArgNames();
79 80
}

81 82
const std::vector<std::string> FunctionInfo::OutputArgNames() const {
  return schema_.OutputArgNames();
83 84
}

85 86 87 88 89 90 91 92
const std::string& FunctionInfo::ProgramFilePath() const {
  return prog_file_path_;
}

// Records `path` as the program file path, replacing any previous value.
void FunctionInfo::SetProgramFilePath(const std::string& path) {
  prog_file_path_.assign(path);
}

93
void FunctionInfo::RemoveDescFeedFetch() {
94
  utils::RemoveFeedFetch(program_desc_.get());
95 96
}

97 98
}  // namespace jit
}  // namespace paddle