// Copyright (c) 2021 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/backends/compiler.h"

#include <fstream>

#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/visualize_helper.h"
#ifdef CINN_WITH_CUDA
#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/runtime/cuda/cuda_module.h"
#include "paddle/cinn/runtime/cuda/cuda_util.h"
#include "paddle/cinn/runtime/flags.h"
#endif

DECLARE_string(cinn_source_code_save_path);
DECLARE_string(cinn_dump_group_lowered_func);
DECLARE_string(cinn_dump_group_source_code);
DECLARE_string(cinn_dump_group_ptx);
DECLARE_string(cinn_dump_group_instruction);

namespace cinn {
namespace backends {
using ir::Module;

static constexpr int DebugLogMaxLen = 30000;

// Dumps, for every fusion group, the first lowered function of that group
// into "lowered_function.txt" through Dump(). Does nothing when
// FLAGS_cinn_dump_group_lowered_func is empty.
void CompilationInfoDumper::DumpLoweredFunc() {
  const std::string& dump_root = FLAGS_cinn_dump_group_lowered_func;
  if (dump_root.empty()) return;

  int group_id = 0;
  for (const auto& group_funcs : info_.lowered_funcs) {
    // Only the front element of each group is serialized.
    std::stringstream text;
    text << group_funcs.front();
    Dump(dump_root, group_id, "lowered_function.txt", text.str());
    ++group_id;
  }
}

// Dumps the source code recorded for each fusion group into
// "source_code.cu" through Dump(). Does nothing when
// FLAGS_cinn_dump_group_source_code is empty.
void CompilationInfoDumper::DumpSourceCode() {
  const std::string& dump_root = FLAGS_cinn_dump_group_source_code;
  if (dump_root.empty()) return;

  int group_id = 0;
  for (const auto& group_source : info_.source_codes) {
    Dump(dump_root, group_id, "source_code.cu", group_source);
    ++group_id;
  }
}

// Dumps the PTX recorded for each fusion group into "source_ptx.ptx"
// through Dump(). Does nothing when FLAGS_cinn_dump_group_ptx is empty.
void CompilationInfoDumper::DumpPtxCode() {
  const std::string& dump_root = FLAGS_cinn_dump_group_ptx;
  if (dump_root.empty()) return;

  const int group_count = static_cast<int>(info_.source_ptxs.size());
  for (int group_id = 0; group_id < group_count; ++group_id) {
    Dump(dump_root, group_id, "source_ptx.ptx", info_.source_ptxs[group_id]);
  }
}

// Dumps the textual form returned by DumpInstruction() of each group's
// instruction into "instruction.txt" through Dump(). Does nothing when
// FLAGS_cinn_dump_group_instruction is empty.
void CompilationInfoDumper::DumpInstruction() {
  const std::string& dump_root = FLAGS_cinn_dump_group_instruction;
  if (dump_root.empty()) return;

  const int group_count = static_cast<int>(info_.instructions.size());
  for (int group_id = 0; group_id < group_count; ++group_id) {
    Dump(dump_root,
         group_id,
         "instruction.txt",
         info_.instructions[group_id]->DumpInstruction());
  }
}

// Writes `content` to "<base_path>/fusion_group_<idx>/<file_name>".
//
// The group directory is requested through hlir::framework::MakeDirectory
// with mode rwxr-xr-x (owner: rwx, group/others: r-x). Failures (directory
// creation, opening the file, writing/closing it) are reported with
// LOG(WARNING) only; nothing is propagated to the caller.
//
// @param base_path  Root directory of the dump.
// @param idx        Index of the fusion group; selects the sub-directory.
// @param file_name  Name of the file created inside the group directory.
// @param content    Text written to the file (the file is truncated first).
void CompilationInfoDumper::Dump(const std::string& base_path,
                                 const int idx,
                                 const std::string& file_name,
                                 const std::string& content) {
  auto dump_path =
      utils::StringFormat("%s/fusion_group_%d", base_path.c_str(), idx);
  if (!hlir::framework::MakeDirectory(
          dump_path, S_IRWXU | S_IRGRP | S_IXGRP | S_IROTH | S_IXOTH)) {
    // This helper is shared by all Dump* methods, so name the actual file
    // instead of always claiming that an "instruction" is being dumped.
    LOG(WARNING) << "Failed to make directory: \"" << dump_path << "\", \""
                 << file_name << "\" for this group will not dump.";
    return;
  }

  auto dump_file =
      utils::StringFormat("%s/%s", dump_path.c_str(), file_name.c_str());
  VLOG(7) << "Dump \"" << file_name << "\" to: " << dump_file;
  std::ofstream of(dump_file, std::ios_base::out);
  if (!of.is_open()) {
    LOG(WARNING) << "Failed to open file: " << dump_file
                 << ", please check your path.";
    return;
  }

  of << content;
  of.close();
  // close() flushes; a failed write or flush (e.g. disk full) leaves the
  // stream in a failed state, which was previously ignored silently.
  if (of.fail()) {
    LOG(WARNING) << "Failed to write file: " << dump_file
                 << ", the dumped content may be incomplete.";
  }
}

// Opens the output stream `of` on FLAGS_cinn_source_code_save_path when that
// flag is non-empty; otherwise the stream is left closed.
SourceCodePrint::SourceCodePrint() {
  const std::string& save_path = FLAGS_cinn_source_code_save_path;
  if (save_path.empty()) return;

  LOG(INFO) << "The CINN auto generated source code will writing into file: \""
            << save_path << "\"";
  of.open(save_path, std::ios_base::out);
}

// Closes the output stream if the constructor managed to open it.
SourceCodePrint::~SourceCodePrint() {
  if (!of.is_open()) return;
  of.close();
}

void SourceCodePrint::write(const std::string& source_code) {
  std::lock_guard<std::mutex> guard(mtx_);
  if (of.is_open()) {
    of << source_code << std::endl;
  } else if (!FLAGS_cinn_source_code_save_path.empty()) {
139 140
    LOG(WARNING) << "Failed to open \"" << FLAGS_cinn_source_code_save_path
                 << "\", source code will print.";
141
    if (source_code.size() > DebugLogMaxLen) {
142 143
      LOG(INFO) << "[CUDA] source code-0:\n"
                << source_code.substr(0, DebugLogMaxLen);
144
      for (int i = 1; i * DebugLogMaxLen < source_code.size(); ++i) {
145 146
        LOG(INFO) << "[CUDA] source code-" << i << ":\n"
                  << source_code.substr(DebugLogMaxLen * i, DebugLogMaxLen);
147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166
      }
    } else {
      LOG(INFO) << "[CUDA] source code:\n" << source_code;
    }
  }
}

// Compiles `module` for the architecture of `target_`.
// NVGPU targets go through CompileCudaModule(module, code); X86 targets go
// through CompileX86Module(module) and do not use `code`. Any other
// architecture hits CINN_NOT_IMPLEMENTED.
void Compiler::Build(const Module& module, const std::string& code) {
  if (target_.arch == Target::Arch::NVGPU) {
    CompileCudaModule(module, code);
    return;
  }
  if (target_.arch == Target::Arch::X86) {
    CompileX86Module(module);
    return;
  }
  CINN_NOT_IMPLEMENTED
}

// Returns the source generated by CodeGenCUDA_Dev for the device part of
// `module`. Only implemented for NVGPU targets in builds with CINN_WITH_CUDA;
// every other configuration hits CINN_NOT_IMPLEMENTED.
std::string Compiler::GetSourceCode(const ir::Module& module) {
  if (target_.arch == Target::Arch::NVGPU) {
#ifdef CINN_WITH_CUDA
    auto split_modules = SplitCudaAndHostModule(module);  // NOLINT
    // Element 1 of the split result is the device module; the host module
    // (element 0) is not needed here.
    auto& cuda_device_module = std::get<1>(split_modules);
    CodeGenCUDA_Dev cuda_codegen(target_);
    return cuda_codegen.Compile(cuda_device_module);
#else
    CINN_NOT_IMPLEMENTED
#endif
  } else {
    CINN_NOT_IMPLEMENTED
  }
}

// Compiles `module` for the architecture of `target_` without passing any
// source code explicitly: NVGPU -> CompileCudaModule(module),
// X86 -> CompileX86Module(module), anything else -> CINN_NOT_IMPLEMENTED.
void Compiler::BuildDefault(const Module& module) {
  if (target_.arch == Target::Arch::NVGPU) {
    CompileCudaModule(module);
    return;
  }
  if (target_.arch == Target::Arch::X86) {
    CompileX86Module(module);
    return;
  }
  CINN_NOT_IMPLEMENTED
}

void Compiler::CompileCudaModule(const Module& module,
                                 const std::string& code) {
194 195
#ifdef CINN_WITH_CUDA
  auto _host_module_device_module_ = SplitCudaAndHostModule(module);  // NOLINT
196 197
  auto& host_module = std::get<0>(_host_module_device_module_);
  auto& device_module = std::get<1>(_host_module_device_module_);
198 199 200 201 202 203 204 205 206 207
  VLOG(3) << "[CUDA] host module:\n" << host_module;

  VLOG(3) << "[CUDA] device module:\n" << device_module;
  std::string source_code;
  if (code.empty()) {
    CodeGenCUDA_Dev codegen(target_);
    source_code = codegen.Compile(device_module);
  } else {
    source_code = code;
  }
208 209 210
  CHECK(!source_code.empty())
      << "Compile CUDA C code failed from device module:\n"
      << device_module;
211 212 213 214 215 216
  VLOG(3) << "[CUDA] C:\n" << source_code;
  SourceCodePrint::GetInstance()->write(source_code);
  using runtime::cuda::CUDAModule;

  nvrtc::Compiler compiler;
  auto ptx = compiler(source_code);
217 218 219 220 221 222
  CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n"
                      << source_code;
  cuda_module_.reset(new CUDAModule(ptx,
                                    compiler.compile_to_cubin()
                                        ? CUDAModule::Kind::CUBIN
                                        : CUDAModule::Kind::PTX));
223 224 225 226

  RuntimeSymbols symbols;
  for (auto& fn : device_module.functions()) {
    std::string kernel_fn_name = fn->name;
227
    auto fn_kernel = cuda_module_->GetFunction(0, kernel_fn_name);
228 229
    CHECK(fn_kernel);

230 231
    symbols.RegisterVar(kernel_fn_name + "_ptr_",
                        reinterpret_cast<void*>(fn_kernel));
232 233 234 235 236 237 238 239 240 241
  }

  engine_ = ExecutionEngine::Create(ExecutionOptions(), std::move(symbols));
  engine_->Link<CodeGenCUDA_Host>(host_module);

#else
  CINN_NOT_IMPLEMENTED
#endif
}

void Compiler::CompileX86Module(const Module& module) {
  engine_->Link<CodeGenX86>(module);
}
void Compiler::ExportObject(const std::string& path) {
  engine_->ExportObject(path);
}

// Looks up `fn_name` in the execution engine.
//
// @param fn_name  Name of the symbol to resolve.
// @return The pointer returned by the engine, or nullptr when the engine
//         reports no such symbol.
void* Compiler::Lookup(absl::string_view fn_name) {
  CHECK(engine_);
  // A single lookup is enough: the previous code queried the engine twice
  // and returned nullptr exactly when the engine itself returned nullptr.
  return engine_->Lookup(fn_name);
}

}  // namespace backends
}  // namespace cinn