// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <functional>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "paddle/cinn/backends/cuda_util.h"
#include "paddle/cinn/hlir/framework/scope.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#endif
#include "paddle/cinn/utils/string.h"
#include "paddle/cinn/utils/timer.h"

namespace cinn {
namespace hlir {
namespace framework {

/**
 * Instruction is the basic executable element in runtime. It holds a pointer
 * to the JIT-compiled LoweredFunc, collects the cinn_buffers of the inputs
 * and outputs from the scope, prepares the arguments and finally passes them
 * into the LoweredFunc and executes it.
 */
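// A minimal usage sketch (illustrative only; `target`, `scope`, and `fn_addr`
// are hypothetical values produced elsewhere, e.g. by the JIT compiler):
//
//   Instruction instr(target, scope, /*in_args=*/{"X", "Y"},
//                     /*out_args=*/{"Out"}, "fn_elementwise_add");
//   instr.SetLoweredFunc(fn_addr, "fn_elementwise_add");
//   instr.Finalize();  // no more functions can be appended after this
//   instr.Run();       // gathers buffers from `scope` and executes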
class Instruction {
 public:
  using infershape_t =
      std::function<void(Scope*, const std::vector<std::string>&)>;

  /**
   * Constructor.
   * @param target The \p target the instruction runs on.
   * @param scope The scope containing all the runtime variables (Tensors and
   * PODs).
   * @param in_args The names of the inputs.
   * @param out_args The names of the outputs.
   * @param function_name The name of the compiled function.
   */
  Instruction(const Target& target,
              Scope* scope,
              const std::vector<std::string>& in_args,
              const std::vector<std::string>& out_args,
              const std::string& function_name = "")
      : target_(target),
        scope_(scope),
        in_args_({in_args}),
        out_args_({out_args}),
        function_name_(function_name) {}

  /**
   * Append a compiled function address.
   * @param fn_ptr The JIT-compiled function address.
   * @param name The name of the compiled function.
   */
  void SetLoweredFunc(void* fn_ptr, const std::string& name = "") {
    fn_ptrs_.push_back(fn_ptr);
    fn_names_.push_back(name);
  }

  // Explicitly finalize the instruction; no function can be appended after
  // calling it.
  void Finalize();

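  // Rebuild the argument cache: each appended function gets a vector of
  // cinn_pod_value_t, taken from `name2podargs` when provided, otherwise
  // looked up in the scope.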
  void UpdateArgsCache(
      const std::map<std::string, cinn_pod_value_t>* name2podargs);
81 82 83
  /**
   * Run the Instruction.
   * @param name2podargs An optional name-to-POD-value map; when set, the
   * arguments are taken from it instead of being looked up in the scope.
   * @param dryrun Whether to only prepare the arguments without executing.
   * @param stream The stream to execute on (CUDA only).
   * @param use_cache Whether to reuse the cached arguments.
   */
  void Run(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr,
      bool dryrun = false,
      void* stream = nullptr,
      bool use_cache = true);
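
  // A sketch of overriding the scope lookup with explicit POD arguments
  // (`x_buf` and `out_buf` are hypothetical cinn_buffer_t* owned elsewhere):
  //
  //   std::map<std::string, cinn_pod_value_t> pods;
  //   pods.emplace("X", cinn_pod_value_t(x_buf));
  //   pods.emplace("Out", cinn_pod_value_t(out_buf));
  //   instr.Run(&pods, /*dryrun=*/false, /*stream=*/nullptr);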

  void PreRun(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr) {
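    // PreRun is specialized for instructions carrying exactly four appended
    // functions (note the fixed bound in the loop below).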
    CHECK_EQ(fn_ptrs_.size(), 4);
    if (fn_ptrs_.size() > 1 && fn_ptrs_.size() != in_args_.size()) {
      out_args_.back()[0] = out_args_.front()[0];
      out_args_.erase(out_args_.begin());
      in_args_.erase(in_args_.begin());
    }
    UpdateArgsCache(name2podargs);

    CHECK_EQ(fn_ptrs_.size(), in_args_.size());
    CHECK_EQ(fn_ptrs_.size(), out_args_.size());

    int flag = -1;
    void* stream = nullptr;
    for (int idx = 0; idx < 4; idx++) {
      if (utils::Startswith(out_args_[idx][0], "kernel_pack")) {
        VLOG(3) << "PreRun " << idx << "-th function of fn_:" << fn_names_[idx];
        flag = idx;
        auto& pod_args = args_cached_[idx];
        CHECK(fn_ptrs_[idx]) << "The LoweredFunc address should be set first "
                                "by calling SetLoweredFunc method";
        if (target_ == common::DefaultNVGPUTarget()) {
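          // On NVGPU, lowered functions follow the (args, num_args, stream)
          // calling convention.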
          ((lower_func_ptr_g)fn_ptrs_[idx])(
              static_cast<void*>(pod_args.data()), pod_args.size(), stream);
        } else {
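          // On CPU, lowered functions follow the (args, num_args) calling
          // convention.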
          ((lower_func_ptr_t)fn_ptrs_[idx])(static_cast<void*>(pod_args.data()),
                                            pod_args.size());
        }
#ifdef CINN_WITH_CUDA
        CUDA_CALL(cudaDeviceSynchronize());
#endif
      }
    }
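    // Drop the "kernel_pack" function that has just been pre-run so that a
    // later Run() will not execute it again.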
    if (flag >= 0) {
      args_cached_.erase(args_cached_.begin() + flag);
      in_args_.erase(in_args_.begin() + flag);
      out_args_.erase(out_args_.begin() + flag);
      fn_ptrs_.erase(fn_ptrs_.begin() + flag);
      fn_names_.erase(fn_names_.begin() + flag);
    }
  }

  int size() const { return static_cast<int>(fn_ptrs_.size()); }

  std::vector<std::vector<std::string>> GetInArgs() { return in_args_; }
  std::vector<std::vector<std::string>> GetOutArgs() { return out_args_; }
  void ClearInArgs() { in_args_.clear(); }
  void ClearOutArgs() { out_args_.clear(); }
  std::vector<std::string> GetFnNames() { return fn_names_; }
  void AddInArgs(const std::vector<std::string>& in_args) {
    in_args_.push_back(in_args);
  }
  void AddOutArgs(const std::vector<std::string>& out_args) {
    out_args_.push_back(out_args);
  }
  std::vector<int> attrs;
  std::vector<std::string> str_attrs;
  bool pre_run = false;
  Target target_;

 protected:
  void CheckResults(
      const std::map<std::string, cinn_pod_value_t>* name2podargs = nullptr,
      void* stream = nullptr);

 private:
  bool finalized_flag_ = false;
  Scope* scope_{};
  std::string function_name_;
  std::vector<std::vector<std::string>> in_args_;
  std::vector<std::vector<std::string>> out_args_;

  std::vector<std::vector<cinn_pod_value_t>> args_cached_;

  std::vector<void*> fn_ptrs_{};
  std::vector<std::string> fn_names_;
};

}  // namespace framework
}  // namespace hlir
}  // namespace cinn