cinn_launch_context.h 7.0 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/new_executor/interpretercore.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/phi/core/ddim.h"
29 30 31 32 33 34 35 36 37

// Forward declarations of the CINN types referenced below.
namespace cinn::hlir::framework {
class Program;
class Scope;
class Tensor;
}  // namespace cinn::hlir::framework
struct cinn_pod_value_t;
struct cinn_buffer_t;
38 39

namespace paddle {
40 41 42 43 44 45 46 47 48 49 50 51 52 53
namespace framework {
// Forward declaration of the graph type in the ir sub-namespace.
namespace ir {
class Graph;
}  // namespace ir

// Forward declaration of the compilation result type in paddle2cinn.
namespace paddle2cinn {
class CinnCompiledObject;
}  // namespace paddle2cinn

// Forward declarations of types that live directly in framework.
class VarDesc;
class Scope;
class ProgramDesc;
}  // namespace framework

54
namespace operators::details {
55 56 57

using CinnTensor = ::cinn::hlir::framework::Tensor;
using CinnScope = ::cinn::hlir::framework::Scope;
58
using CinnCompiledObject = framework::paddle2cinn::CinnCompiledObject;
59

60 61 62 63 64 65 66
// This class is used to cache some reusable data among repeated
// executions for efficiency and it also provides easy interfaces
// to get details of the compilation result.
// A object of this class is constructed and saved in the
// compilation cache once a graph compiled by CINN.
// Generally speaking, here, a variable is refer to a Paddle
// Variable while a CINN variable is called an Argument.
67 68
class CinnLaunchContext {
 public:
69 70 71 72 73 74 75 76
  explicit CinnLaunchContext(const framework::ir::Graph& graph,
                             const CinnCompiledObject& compiled_obj);

  // Initialize a ParallelExecutor to execute the runtime graph,
  // it will be constructed in the first call, and just update
  // the execution scope in the following usage.
  framework::ParallelExecutor* InitializePE(const platform::Place& place,
                                            framework::Scope* scope);
77

78 79 80
  framework::InterpreterCore* InitializeInterpreterCore(
      const platform::Place& place, framework::Scope* scope);

81 82 83 84 85
  // explicitly update several environment variables captured
  // by callback of execution arguments
  void UpdateCapturedEnv(const framework::Scope& scope,
                         const platform::Place& place);

86 87
  // Return whether a Paddle variable used in cinn execution
  bool IsVariableUsed(const std::string& var_name) const;
88

89 90 91
  // Check the equiality in type and dimension between the tensor
  // in Paddle and the compiled tensor returned by CINN of a same variable
  void CheckTensorEquivalent(const std::string& var_name,
92
                             const phi::DenseTensor& paddle_tensor);
93

94 95 96 97 98
  // Return the name list of variables skipped eager deletion
  const std::vector<std::string>& GetSkipEagerVars() const {
    return skip_eager_vars_;
  }

99 100 101 102
  // Return internal variable names list
  const std::unordered_set<std::string>& GetInternalVarNames() const {
    return internal_var_names_;
  }
103

104
  // Finalize all execution arguments and return the name->argument map
105 106 107
  const std::map<std::string, cinn_pod_value_t>& FinalizeArguments() const {
    return name2argument_;
  }
108

109 110
  // Return the cinn_buffer_t* of a specific variable
  cinn_buffer_t* GetCinnBufferOfVar(const std::string& var_name);
111

112
 private:
113 114 115
  // Get corresponding compiled tensor of a Paddle variable name
  CinnTensor GetCinnTensorOfVar(const std::string& var_name);

116 117 118 119 120 121
  // Build the name maps of paddle->cinn and cinn->paddle
  // in reverse for all variables used in cinn execution
  void BuildVarNameMap(
      const std::unordered_map<std::string, std::string>& compiled_varmap,
      const std::unordered_set<std::string>& argument_names);

122 123 124 125 126 127 128 129
  // Extract internal variable names from all applied variables
  // in execution by excluding the input and output variables
  std::unordered_set<std::string> ExtractInternalVarNames(
      const std::vector<std::string>& input_var_names,
      const std::vector<std::string>& output_var_names);

  // Initialize each execution argument with a cinn_buffer_t
  void InitializeArguments();
130

131 132 133 134 135 136 137 138
  // Assign tensor buffer to input or output variables
  void AssignExternalVariable(const std::string& var_name);

  // Assign tensor buffer to internal variables
  void AssignInternalVariable(const std::string& var_name);

  // Construct a Paddle ProgramDesc with the CINN runtime
  // instructions included in the compiled CINN Program
139
  std::unique_ptr<framework::ProgramDesc> BuildCompiledProgram(
140 141
      const framework::ir::Graph& graph,
      const CinnCompiledObject& compiled_obj);
142 143

 private:
144 145 146 147
  const framework::Scope* cached_scope_ = nullptr;
  const platform::Place* cached_place_ = nullptr;
  std::unique_ptr<framework::Scope> cached_temp_scope_ = nullptr;

148
  // a name map from paddle variables to cinn execution arguments
149
  std::unordered_map<std::string, std::string> paddle2cinn_varmap_;
150
  // a name map from cinn execution arguments to paddle variables
151
  std::unordered_map<std::string, std::string> cinn2paddle_varmap_;
152 153
  // a list of internal variable names in Paddle
  std::unordered_set<std::string> internal_var_names_;
154 155
  // the names of the cinn arguments used in compiled executable program
  std::unordered_set<std::string> cinn_argument_names_;
156 157 158
  // TODO(CtfGo): remove this list after fixing batch_norm bug
  // due to duplicate association in the same variable.
  std::vector<std::string> initialized_beforehand_vars_;
159
  // the variable scope compiled from cinn
160 161
  const std::shared_ptr<CinnScope> cinn_scope_;

162 163 164 165
  std::unique_ptr<framework::ProgramDesc> runtime_program_desc_;
  std::unique_ptr<framework::InterpreterCore> interpreter_core_;
  std::set<std::string> skip_gc_vars_;

166 167 168 169
  // the ir::Graph object converted from the program compiled by CINN
  std::unique_ptr<framework::ir::Graph> runtime_graph_;
  // a ParallelExecutor to execute the runtime graph
  std::unique_ptr<framework::ParallelExecutor> parallel_executor_;
170 171
  // the name list of skip_eager_vars in runtime
  std::vector<std::string> skip_eager_vars_;
172

173 174
  // because a cinn_pod_value_t does not own a cinn_buffer_t object,
  // an extra stroage is necessary to keep those objects and they can
175
  // not be released until the runtime program finish execution.
176
  std::vector<std::unique_ptr<cinn_buffer_t>> hold_buffers_;
177 178
  // this map saves all execution arguments with their cinn names as key,
  // and it is passed to the Execute interface of a cinn runtime program.
179 180 181
  std::map<std::string, cinn_pod_value_t> name2argument_;
};

182
}  // namespace operators::details
183
}  // namespace paddle