// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/jit/function_schema.h"

#include "paddle/phi/core/enforce.h"

#include "paddle/fluid/jit/function_utils.h"
namespace paddle {
namespace jit {

Argument::Argument(const std::string& name, bool is_out)
    : name_(name), is_output_(is_out) {}

// Returns a reference to the name stored in this argument.
const std::string& Argument::Name() const {
  return name_;
}

const std::vector<std::string> FunctionSchema::InputArgNames() const {
29 30 31 32 33 34 35
  std::vector<std::string> input_arg_names;
  for (auto& arg : input_args) {
    input_arg_names.emplace_back(arg.Name());
  }
  return input_arg_names;
}

const std::vector<std::string> FunctionSchema::OutputArgNames() const {
37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66
  std::vector<std::string> output_arg_names;
  for (auto& arg : output_args) {
    output_arg_names.emplace_back(arg.Name());
  }
  return output_arg_names;
}

// Appends a new argument named `name` to the schema's input list.
void FunctionSchema::AddInputArg(const std::string& name) {
  constexpr bool kIsOutput = false;  // input arguments carry is_out == false
  input_args.emplace_back(name, kIsOutput);
}

// Appends a new argument named `name` to the schema's output list.
void FunctionSchema::AddOutputArg(const std::string& name) {
  constexpr bool kIsOutput = true;  // output arguments carry is_out == true
  output_args.emplace_back(name, kIsOutput);
}

// Builds a FunctionInfo from a function name, its parameter names and a
// program description. The three arguments are stored in the corresponding
// members, after which the schema is filled from the stored program: each
// feed target name is registered as an input argument and each fetch target
// name as an output argument.
FunctionInfo::FunctionInfo(const std::string& func_name,
                           const std::vector<std::string>& param_names,
                           const framework::ProgramDesc& program_desc)
    : func_name_(func_name),
      param_names_(param_names),
      program_desc_(program_desc) {
  // Inputs: one schema entry per feed target.
  const auto& feed_names = program_desc_.GetFeedTargetNames();
  for (const auto& feed_name : feed_names) {
    schema_.AddInputArg(feed_name);
  }

  // Outputs: one schema entry per fetch target.
  const auto& fetch_names = program_desc_.GetFetchTargetNames();
  for (const auto& fetch_name : fetch_names) {
    schema_.AddOutputArg(fetch_name);
  }
}

// Returns a reference to the function name given at construction.
const std::string& FunctionInfo::FunctionName() const {
  return func_name_;
}
// Returns a read-only reference to the program description held by this
// FunctionInfo.
const framework::ProgramDesc& FunctionInfo::ProgramDesc() const {
  return program_desc_;
}

const std::vector<std::string>& FunctionInfo::ParamNames() const {
74 75 76
  return param_names_;
}

// Returns the input argument names of this function by forwarding to the
// schema that was populated in the constructor.
const std::vector<std::string> FunctionInfo::InputArgNames() const {
  return schema_.InputArgNames();
}

// Returns the output argument names of this function by forwarding to the
// schema that was populated in the constructor.
const std::vector<std::string> FunctionInfo::OutputArgNames() const {
  return schema_.OutputArgNames();
}

// Passes the stored program description to utils::RemoveFeedFetch, which may
// modify it in place (presumably stripping the feed/fetch ops — confirm
// against function_utils.h).
void FunctionInfo::RemoveDescFeedFetch() {
  auto* desc = &program_desc_;
  utils::RemoveFeedFetch(desc);
}

}  // namespace jit
}  // namespace paddle