// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/parallel_compiler.h"

#include <algorithm>
#include <fstream>
#include <thread>

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/graph_compiler_util.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/runtime/flags.h"

DECLARE_int32(cinn_parallel_compile_thread);

namespace cinn {
namespace hlir {
namespace framework {

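// Compiles every fusion group of the graph in parallel and returns the
// aggregated CompilationResult. Each group goes through the pipeline
// Lowering -> CodegenAndJit -> BuildInstruction, stopping early if the
// context requests an intermediate CompilationStage.
//
// Typical usage (a sketch only; the constructor signature and how the
// CompilationContext is populated are assumptions here, see
// graph_compiler_util.h):
//   CompilationContext context(graph, scope, target);
//   ParallelCompiler compiler(&context);
//   CompilationResult result = compiler();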
CompilationResult ParallelCompiler::operator()() {
  if (context_->graph->fusion_groups.empty()) {
    hlir::framework::ApplyPasses(context_->graph.get(),
                                 {"BuildNonFusedGroupsPass"});
  }
  // init compilation result
  result_.InitCompilationResult(context_->graph->fusion_groups.size());
  // split tasks, one per fusion group
  SplitTask();
  // launch task
  LaunchTask();
  // return compilation result
  return std::move(result_);
}

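// Creates one compilation Task per fusion group; each task carries its group
// index so results can be written back to the matching slot in result_.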
void ParallelCompiler::SplitTask() {
  CHECK(!context_->graph->fusion_groups.empty());
  CHECK(context_->lowered_funcs.empty() ||
        context_->graph->fusion_groups.size() ==
            context_->lowered_funcs.size());
  for (int i = 0; i < context_->graph->fusion_groups.size(); ++i) {
    tasks_.emplace_back(i, this, context_);
  }
}

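// Worker loop: repeatedly claims the next unprocessed task index and drives
// that task through the compilation stages, returning once the stage
// requested in the context has been reached or no tasks remain.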
void ParallelCompiler::RunTask() {
  while (true) {
    int idx = GetTaskIdx();
    if (idx < 0) {
      return;
    }
    VLOG(4) << "Start run task " << idx
            << " on thread: " << std::this_thread::get_id();
    VLOG(4) << "Start lowering on task " << idx;
    tasks_[idx].Lowering();
    if (context_->stage == CompilationStage::LOWERING) {
      VLOG(4) << "Just lowering, finish task " << idx
              << " on thread: " << std::this_thread::get_id();
      return;
    }
    VLOG(4) << "Start CodegenAndJit";
    tasks_[idx].CodegenAndJit();
    if (context_->stage == CompilationStage::CODEGEN_AND_JIT) {
      VLOG(4) << "Just codegen and jit, finish task " << idx
              << " on thread: " << std::this_thread::get_id();
      return;
    }
    VLOG(4) << "Start BuildInstruction";
    tasks_[idx].BuildInstruction();
    if (context_->stage == CompilationStage::BUILD_INSTRUCTION) {
      VLOG(4) << "Just build instruction, finish task " << idx
              << " on thread: " << std::this_thread::get_id();
      return;
    }
    VLOG(4) << "Finish task " << idx
            << " on thread: " << std::this_thread::get_id();
  }
}

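// Spawns FLAGS_cinn_parallel_compile_thread - 1 worker threads, runs the
// task loop on the current thread as well, then joins all workers.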
void ParallelCompiler::LaunchTask() {
  // multi-threaded compilation
  std::vector<std::thread> threads;
  VLOG(4) << "Compile with " << FLAGS_cinn_parallel_compile_thread
          << " threads";
  for (int idx = 1; idx < FLAGS_cinn_parallel_compile_thread; ++idx) {
    threads.emplace_back(&ParallelCompiler::RunTask, this);
  }

  RunTask();
  // wait for every worker thread to finish
  std::for_each(threads.begin(), threads.end(),
                std::mem_fn(&std::thread::join));
}

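// Produces the lowered functions for this task's fusion group: either reuses
// the pre-lowered functions supplied through the context, or lowers the group
// from scratch with OpLowerer using the graph's inferred dtype/shape dicts.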
void ParallelCompiler::Task::Lowering() {
  if (!context->lowered_funcs.empty()) {
    CHECK_EQ(context->lowered_funcs.size(),
             context->graph->fusion_groups.size());
    pcompiler->result_.lowered_funcs[group_id] =
        context->lowered_funcs[group_id];
  } else {
    auto& dtype_dict =
        context->graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
            "inferdtype");
    auto& shape_dict =
        context->graph
            ->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
                "infershape");
    OpLowerer op_lowerer(dtype_dict, shape_dict, context->target);
    auto& group = context->graph->fusion_groups[group_id];
    VLOG(4) << "Start Lowering Group " << group_id << " at "
            << std::this_thread::get_id() << " :\n"
            << "Group " << group_id << " {\n"
            << context->graph->DebugGroupedGraph(group->CollectNodes())
            << "}\n";
    auto lowered_group = op_lowerer.Lower(group);
    CHECK_EQ(lowered_group.size(), 1)
        << "Lowered function size is not equal to 1!";
    pcompiler->result_.lowered_funcs[group_id] = std::move(lowered_group);
  }
  backends::CompilationInfoDumper::DumpLoweredFuncByGroupIndex(
      pcompiler->result_.lowered_funcs[group_id].front(), group_id);
}

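// Turns the lowered functions into an executable module. On NVGPU targets
// the module is split into host and device parts: the device part is
// compiled to CUDA C, then to PTX/CUBIN via NVRTC and loaded as a CUDAModule
// whose kernels are registered as runtime symbols before the host part is
// JIT-linked. On other targets the whole module is JIT-compiled for x86.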
void ParallelCompiler::Task::CodegenAndJit() {
  VLOG(2) << "Start Codegen and JIT on Group " << group_id
          << " at thread: " << std::this_thread::get_id();
  // build module
  ir::Module::Builder builder(common::UniqName("module"), context->target);
  for (auto& func : pcompiler->result_.lowered_funcs[group_id]) {
    builder.AddFunction(func);
  }

  auto ir_module = builder.Build();
  if (context->target == common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA
    auto split_module = backends::SplitCudaAndHostModule(ir_module);
    auto hmodule = std::get<0>(split_module);
    auto dmodule = std::get<1>(split_module);

    VLOG(4) << "Host Code:\n" << hmodule;
    VLOG(4) << "Device Code:\n" << dmodule;
    std::string cuda_c;
    if (context->attached_source_code.empty()) {
      backends::CodeGenCUDA_Dev codegen(context->target);
      cuda_c = codegen.Compile(dmodule);
    } else {
      VLOG(4) << "Codegen and jit with attached source code.";
      cuda_c = context->attached_source_code;
    }
    CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
                           << dmodule;
    backends::CompilationInfoDumper::DumpSourceCodeByGroupIndex(cuda_c,
                                                                group_id);
    pcompiler->result_.source_codes[group_id] = cuda_c;

    cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);

    using runtime::cuda::CUDAModule;
    backends::nvrtc::Compiler compiler;
    auto ptx = compiler(cuda_c);
    CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
    backends::CompilationInfoDumper::DumpPtxCodeByGroupIndex(ptx, group_id);
    pcompiler->result_.source_ptxs[group_id] = ptx;
    // load the compiled module: CUBIN when NVRTC produced one, otherwise PTX
    cumodule = std::make_unique<CUDAModule>(ptx,
                                            compiler.compile_to_cubin()
                                                ? CUDAModule::Kind::CUBIN
                                                : CUDAModule::Kind::PTX);

    // register each device kernel's function pointer under the
    // "<name>_ptr_" symbol so the JIT-linked host module can launch it
    backends::RuntimeSymbols symbols;
    for (auto& fn : dmodule.functions()) {
      auto cufunc = cumodule->GetFunction(0, fn->name);
      CHECK(cufunc);
      symbols.RegisterVar(fn->name + "_ptr_", reinterpret_cast<void*>(cufunc));
    }
    engine = backends::ExecutionEngine::Create(backends::ExecutionOptions(),
                                               std::move(symbols));
    engine->Link<backends::CodeGenCUDA_Host>(hmodule);
#endif
  } else {
    engine = backends::ExecutionEngine::Create(backends::ExecutionOptions());
    engine->Link<backends::CodeGenX86>(ir_module);
  }
}

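// Wraps the JIT-compiled function of this task's group into an Instruction
// bound to the group's input/output variables in the execution scope.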
void ParallelCompiler::Task::BuildInstruction() {
  // create instruction.
  VLOG(4) << "Start BuildInstruction of Group " << group_id
          << " at thread: " << std::this_thread::get_id();
  auto& group = context->graph->fusion_groups[group_id];
  CHECK(!group->input_names.empty() || !group->output_names.empty());
  auto instr = std::make_unique<Instruction>(context->target,
                                             context->scope.get(),
                                             group->input_names,
                                             group->output_names,
                                             group->GetFuncName());

  auto fn_ptr = engine->Lookup(group->GetFuncName());
  CHECK(fn_ptr) << "Can't find jit function : " << group->GetFuncName();
  instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr), group->GetFuncName());

  instr->Finalize();
  backends::CompilationInfoDumper::DumpInstructionByGroupIndex(instr, group_id);
  pcompiler->result_.instructions[group_id] = std::move(instr);
}

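// Thread-safe dispenser of task indices; returns -1 once all tasks have been
// claimed, which signals the worker loop to exit.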
int ParallelCompiler::GetTaskIdx() {
  std::lock_guard<std::mutex> lock(mtx_);
  if (task_idx_ < tasks_.size()) {
    return task_idx_++;
  } else {
    return -1;
  }
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn