// Copyright (c) 2023 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once
#include <string>
#include <utility>
#include <vector>

#include "paddle/cinn/common/type.h"
#include "paddle/cinn/frontend/net_builder.h"
#include "paddle/cinn/frontend/syntax.h"
#include "paddle/cinn/utils/type_defs.h"

namespace cinn {
namespace tests {

// This struct packing data fields for variable constructor
struct VariableInfo {
  std::string id;
  std::vector<int> shape;
  common::Type type;
31 32 33
  VariableInfo(std::string name,
               std::vector<int> shape,
               common::Type dtype = common::Float(32))
34 35 36
      : id(name), shape(shape), type(dtype) {}
};

37 38 39 40
// This class define a general interface to build a frontend::Program easily for
// test usage, developer can implement derived classes to build customized
// programs which may be reused by others through specifying the detail of input
// variables and attributes
41 42
class ProgramBuilder {
 public:
43
  explicit ProgramBuilder(const std::string& name) : builder_(name) {}
44 45

  /*
46 47 48 49
   * \brief Build a frontend::Program with the input variables info and
   * attributes
   * @param input_varinfo The detail data fields of each input variable, input
   * order should match their usage in override implement
50 51 52 53
   * @param attrs The detail value of each input attributes, input order should
   *                      match their usage in override implementdefinition
   * @return The built program
   */
54 55 56
  virtual frontend::Program Build(
      const std::vector<VariableInfo>& inputs_varinfo,
      const utils::AttributeMap& attrs) = 0;
57 58 59 60 61 62 63 64 65 66 67 68 69 70 71

  // return the output variables
  const std::vector<frontend::Variable>& GetOutputs() const { return outputs_; }

 protected:
  void AppendOutput(frontend::Variable var) { outputs_.emplace_back(var); }

  frontend::NetBuilder builder_;
  std::vector<frontend::Variable> outputs_;
};

// Build a frontend::Program which has only one operator
class OpBuilder final : public ProgramBuilder {
 public:
  /*
72 73
   * @param op_name name of the built operator, you should lookup the name from
   * its registry exactly
74 75 76
   */
  OpBuilder(const std::string& op_name);

77 78
  // the item order in `inputs_varinfo` and `attrs` should match their usage in
  // the underlying operator
79 80 81 82 83 84 85 86 87 88 89 90 91 92
  frontend::Program Build(const std::vector<VariableInfo>& inputs_varinfo,
                          const utils::AttributeMap& attrs = {}) override;

 private:
  std::string op_name_;
};

// Build a frontend::Program by loading paddle model from local files
class PaddleModelBuilder final : public ProgramBuilder {
 public:
  /*
   * @param model_path the path to local model files
   * @param target device type
   */
93 94
  PaddleModelBuilder(const std::string& model_path,
                     const common::Target& target);
95 96 97 98 99 100 101 102 103 104 105

  frontend::Program Build(const std::vector<VariableInfo>& inputs_varinfo,
                          const utils::AttributeMap& attrs = {}) override;

 private:
  std::string model_path_;
  common::Target target_;
};

}  // namespace tests
}  // namespace cinn