// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <absl/strings/string_view.h>

#include <fstream>
#include <memory>
#include <mutex>
#include <string>

#include "paddle/cinn/backends/llvm/codegen_llvm.h"
#include "paddle/cinn/backends/llvm/execution_engine.h"
#include "paddle/cinn/backends/llvm/simple_jit.h"
#include "paddle/cinn/hlir/framework/parallel_compiler.h"
#include "paddle/cinn/lang/packed_func.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#endif

namespace cinn {
namespace backends {

/**
 * A class for dumping the code after compilation.
 * Use FLAGS_cinn_dump_group_lowered_func to specify the directory to dump
 * lowered function. Use FLAGS_cinn_dump_group_source_code to specify the
 * directory to dump the source code. Use FLAGS_cinn_dump_group_ptx to specify
 * the directory to dump ptx. Use FLAGS_cinn_dump_group_instruction to specify
 * the directory to dump instruction.
 */
class CompilationInfoDumper {
 public:
46
  explicit CompilationInfoDumper(const hlir::framework::CompilationResult& info)
47 48 49 50 51 52 53
      : info_(info) {
    DumpLoweredFunc();
    DumpSourceCode();
    DumpPtxCode();
    DumpInstruction();
  }

54 55 56 57 58 59 60 61 62 63
  static void DumpLoweredFuncByGroupIndex(const ir::LoweredFunc& lowered_func,
                                          const int gidx);
  static void DumpSourceCodeByGroupIndex(const std::string& source_code,
                                         const int gidx);
  static void DumpPtxCodeByGroupIndex(const std::string& source_ptx,
                                      const int gidx);
  static void DumpInstructionByGroupIndex(
      const std::unique_ptr<cinn::hlir::framework::Instruction>& instr,
      const int gidx);

64 65 66 67 68
 private:
  void DumpLoweredFunc();
  void DumpSourceCode();
  void DumpPtxCode();
  void DumpInstruction();
69 70 71 72
  static void Dump(const std::string& base_path,
                   const int idx,
                   const std::string& file_name,
                   const std::string& content);
73

74
  const hlir::framework::CompilationResult& info_;
75 76
};

/**
 * Process-wide holder of an output stream for source code.
 *
 * The single instance is obtained through GetInstance(). Construction and
 * destruction are private, so no other instances can be created. Copying is
 * implicitly disabled because std::mutex is non-copyable.
 */
class SourceCodePrint {
 public:
  /**
   * Returns the process-wide instance.
   *
   * The instance is a function-local static: it is constructed on the first
   * call, and C++11 guarantees that this initialization is thread-safe.
   */
  static SourceCodePrint* GetInstance() {
    static SourceCodePrint instance;
    return &instance;
  }

  /// Writes `source_code`. The implementation is defined out of line.
  void write(const std::string& source_code);

 private:
  // Only GetInstance() may create or destroy the object.
  SourceCodePrint();
  ~SourceCodePrint();

  // Output stream the source code is written to.
  std::ofstream of;
  // NOTE(review): presumably serializes access to `of` inside write().
  // Confirm in the implementation.
  std::mutex mtx_;
};

class Compiler final {
 public:
  static std::unique_ptr<Compiler> Create(const Target& target) {
    return std::unique_ptr<Compiler>(new Compiler(target));
  }

  /**
   * Compile and link to a CINN module.
   */
  void Build(const ir::Module& module, const std::string& code = "");

  void ExportObject(const std::string& path);

  std::string GetSourceCode(const ir::Module& module);

  void BuildDefault(const ir::Module& module);

  /**
   * Retrieve a function by \p fn_name.
   * @return function address or null if not exists.
   */
  void* Lookup(absl::string_view fn_name);

 private:
118 119
  void CompileCudaModule(const ir::Module& module,
                         const std::string& code = "");
120 121 122

  void CompileX86Module(const ir::Module& module);

123 124
  explicit Compiler(const Target& target)
      : target_(target), engine_(ExecutionEngine::Create(ExecutionOptions())) {}
125 126 127 128 129 130 131 132 133 134 135 136 137 138

  CINN_DISALLOW_COPY_AND_ASSIGN(Compiler);

 private:
  Target target_;
  std::unique_ptr<ExecutionEngine> engine_;

#ifdef CINN_WITH_CUDA
  std::unique_ptr<runtime::cuda::CUDAModule> cuda_module_;
#endif
};

}  // namespace backends
}  // namespace cinn