parallel_compiler.cc 8.3 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38
// Copyright (c) 2022 CINN Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/cinn/hlir/framework/parallel_compiler.h"

#include <algorithm>
#include <fstream>
#include <thread>

#include "paddle/cinn/backends/codegen_cuda_dev.h"
#include "paddle/cinn/backends/codegen_cuda_host.h"
#include "paddle/cinn/backends/codegen_cuda_util.h"
#include "paddle/cinn/backends/compiler.h"
#include "paddle/cinn/backends/llvm/codegen_x86.h"
#include "paddle/cinn/backends/llvm/runtime_symbol_registry.h"
#include "paddle/cinn/backends/nvrtc/nvrtc_util.h"
#include "paddle/cinn/common/context.h"
#include "paddle/cinn/hlir/framework/pass.h"
#include "paddle/cinn/ir/module.h"
#include "paddle/cinn/runtime/flags.h"

DECLARE_int32(cinn_parallel_compile_thread);

namespace cinn {
namespace hlir {
namespace framework {

39
ParallelCompiler::CompilationResult ParallelCompiler::operator()() {
40 41 42 43 44 45 46 47 48 49 50 51 52
  if (graph_->fusion_groups.size() == 0) {
    hlir::framework::ApplyPasses(graph_.get(), {"BuildNonFusedGroupsPass"});
  }
  // Task Spilt
  SplitTask();
  // launch task
  LaunchTask();
  // merge instruction
  return MergeResult();
}

void ParallelCompiler::SplitTask() {
  CHECK(graph_->fusion_groups.size());
53 54
  CHECK(graph_->fusion_groups.size() == option_.lowered_funcs.size() ||
        option_.lowered_funcs.size() == 0);
55 56 57 58 59 60 61 62 63
  // Assign fusion_group to each task.
  // The maximum number of tasks is determined by the number of threads.
  // Fusion_group is assigned to tasks in order and continuous.
  int fusion_group_size = graph_->fusion_groups.size();
  int thread_size = FLAGS_cinn_parallel_compile_thread > 0
                        ? FLAGS_cinn_parallel_compile_thread
                        : 1;
  int group_per_task =
      (graph_->fusion_groups.size() + thread_size - 1) / thread_size;
64
  for (int idx = 0; idx < graph_->fusion_groups.size(); idx += group_per_task) {
65 66 67 68 69 70
    Task task(this, scope_, graph_, option_, target_);
    task.start_gidx = idx;
    task.stop_gidx =
        (idx + group_per_task > fusion_group_size ? fusion_group_size
                                                  : idx + group_per_task);
    tasks_.emplace_back(std::move(task));
71 72 73 74
  }
  VLOG(2) << "Split task to " << tasks_.size() << " sub-task!";
}

75
void ParallelCompiler::RunTask(ParallelCompiler::Task* task) {
76 77 78 79 80 81 82 83 84 85 86 87 88 89
  VLOG(2) << "Stark run sub-task, Thread Id : " << std::this_thread::get_id();
  VLOG(4) << "Start Lowering";
  task->Lowering();
  VLOG(4) << "Start CodegenAndJit";
  task->CodegenAndJit();
  VLOG(4) << "Start BuildInstruction";
  task->BuildInstruction();
  VLOG(2) << "Finish run sub-task, Thread Id : " << std::this_thread::get_id();
}

void ParallelCompiler::LaunchTask() {
  // start sub-task.
  std::vector<std::thread> threads;
  for (int idx = 1; idx < tasks_.size(); ++idx) {
90
    threads.emplace_back(&ParallelCompiler::RunTask, this, &tasks_[idx]);
91 92 93 94
  }

  RunTask(&tasks_[0]);
  // syncthreads.
95
  for_each(threads.begin(), threads.end(), std::mem_fn(&std::thread::join));
96 97
}

98 99
ParallelCompiler::CompilationResult ParallelCompiler::MergeResult() {
  ParallelCompiler::CompilationResult res;
100
  for (auto& task : tasks_) {
101 102 103 104 105 106 107 108 109 110 111
    for (auto& lowered_func : task.lowered_funcs) {
      res.lowered_funcs.emplace_back(lowered_func);
    }
    for (auto& source_code : task.source_codes) {
      res.source_codes.emplace_back(source_code);
    }
    for (auto& source_ptx : task.source_ptxs) {
      res.source_ptxs.emplace_back(source_ptx);
    }
    for (auto& instruction : task.instructions) {
      res.instructions.emplace_back(std::move(instruction));
112 113
    }
  }
114
  return res;
115 116 117 118 119 120
}

void ParallelCompiler::Task::Lowering() {
  if (options.lowered_funcs.size()) {
    CHECK_EQ(options.lowered_funcs.size(), graph->fusion_groups.size());
  }
121 122 123 124 125 126
  auto& dtype_dict =
      graph->GetMutableAttrs<absl::flat_hash_map<std::string, Type>>(
          "inferdtype");
  auto& shape_dict =
      graph->GetMutableAttrs<absl::flat_hash_map<std::string, shape_t>>(
          "infershape");
127 128

  OpLowerer op_lowerer(dtype_dict, shape_dict, target);
129
  for (int idx = start_gidx; idx < stop_gidx; ++idx) {
130 131 132 133 134
    if (options.lowered_funcs.size()) {
      lowered_funcs.push_back(options.lowered_funcs[idx]);
      continue;
    }
    auto& group = graph->fusion_groups[idx];
135 136
    VLOG(1) << "Start Lowering Group " << idx << " at "
            << std::this_thread::get_id() << " :\n"
137 138
            << "Group " << idx << " {\n"
            << graph->DebugGroupedGraph(group->CollectNodes()) << "}\n";
139 140 141
    auto lowered_group = op_lowerer.Lower(group);
    CHECK_EQ(lowered_group.size(), 1) << "Lowerd Function Is Not Equal 1!";
    lowered_funcs.emplace_back(std::move(lowered_group));
142 143 144 145
  }
}

void ParallelCompiler::Task::CodegenAndJit() {
146 147
  VLOG(2) << "Start Codegen and JIT with Group [" << start_gidx << "-"
          << stop_gidx << ") at thread" << std::this_thread::get_id();
148 149 150 151 152 153 154 155 156 157 158
  // build module
  ir::Module::Builder builder(common::UniqName("module"), target);
  for (auto& func : lowered_funcs) {
    CHECK_EQ(func.size(), 1);
    builder.AddFunction(func[0]);
  }

  auto ir_module = builder.Build();
  if (target == common::DefaultNVGPUTarget()) {
#ifdef CINN_WITH_CUDA
    auto splited_module = backends::SplitCudaAndHostModule(ir_module);
159 160
    auto hmodule = std::get<0>(splited_module);
    auto dmodule = std::get<1>(splited_module);
161 162 163 164 165

    VLOG(3) << "Host Code:\n" << hmodule;
    VLOG(3) << "Device Code:\n" << dmodule;
    backends::CodeGenCUDA_Dev codegen(target);
    auto cuda_c = codegen.Compile(dmodule);
166 167
    CHECK(!cuda_c.empty()) << "Compile CUDA C code failed from device module:\n"
                           << dmodule;
168 169 170 171 172 173 174

    cinn::backends::SourceCodePrint::GetInstance()->write(cuda_c);

    using runtime::cuda::CUDAModule;
    backends::nvrtc::Compiler compiler;
    auto ptx = compiler(cuda_c);
    CHECK(!ptx.empty()) << "Compile PTX failed from source code:\n" << cuda_c;
175

176
    // load cumodule
177 178 179 180
    cumodule.reset(new CUDAModule(ptx,
                                  compiler.compile_to_cubin()
                                      ? CUDAModule::Kind::CUBIN
                                      : CUDAModule::Kind::PTX));
181

182 183 184
    source_codes.emplace_back(std::move(cuda_c));
    source_ptxs.emplace_back(std::move(ptx));

185 186 187 188 189 190 191
    // register kernel
    backends::RuntimeSymbols symbols;
    for (auto& fn : dmodule.functions()) {
      auto cufunc = cumodule->GetFunction(0, fn->name);
      CHECK(cufunc);
      symbols.RegisterVar(fn->name + "_ptr_", reinterpret_cast<void*>(cufunc));
    }
192 193
    engine = backends::ExecutionEngine::Create(backends::ExecutionOptions(),
                                               std::move(symbols));
194 195 196 197 198 199 200 201 202 203
    engine->Link<backends::CodeGenCUDA_Host>(hmodule);
#endif
  } else {
    engine = backends::ExecutionEngine::Create(backends::ExecutionOptions());
    engine->Link<backends::CodeGenX86>(ir_module);
  }
}

void ParallelCompiler::Task::BuildInstruction() {
  // create instruction.
204
  for (int idx = start_gidx; idx < stop_gidx; ++idx) {
205 206
    VLOG(2) << "Start BuildInstruction of Group " << idx << " at "
            << std::this_thread::get_id();
207 208
    auto& group = graph->fusion_groups[idx];
    CHECK(group->input_names.size() > 0 || group->output_names.size() > 0);
209 210 211 212 213 214
    auto instr =
        std::unique_ptr<Instruction>(new Instruction(target,
                                                     scope.get(),
                                                     group->input_names,
                                                     group->output_names,
                                                     group->GetFuncName()));
215 216 217

    auto fn_ptr = engine->Lookup(group->GetFuncName());
    CHECK(fn_ptr) << "Can't find jit function : " << group->GetFuncName();
218 219
    instr->SetLoweredFunc(reinterpret_cast<void*>(fn_ptr),
                          group->GetFuncName());
220 221 222 223 224 225 226 227 228

    instr->Finalize();
    instructions.push_back(std::move(instr));
  }
}

}  // namespace framework
}  // namespace hlir
}  // namespace cinn