// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cinn/cinn_launch_context.h"

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/instruction.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/intrinsic.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace operators::details {

using framework::Scope;
using framework::LoDTensor;
using framework::ParallelExecutor;
using CinnInstruction = ::cinn::hlir::framework::Instruction;
using CinnRuntimeProgram = ::cinn::hlir::framework::Program;
using framework::paddle2cinn::Name2VarInfoMap;
using framework::paddle2cinn::kMemOptVarInfoFromMainGraph;

CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
                                     const CinnCompiledObject& compiled_obj)
    : cinn_scope_(compiled_obj.scope) {
  // collect all names of the CINN execution arguments
  auto var_names = cinn_scope_->var_names();
  cinn_argument_names_.reserve(var_names.size());
  std::transform(
      var_names.begin(), var_names.end(),
      std::inserter(cinn_argument_names_, cinn_argument_names_.end()),
      [](const auto& name_view) { return std::string(name_view.data()); });
  // build name map between the original variables and compiled ones
  BuildVarNameMap(compiled_obj.paddle2cinn_varmap, cinn_argument_names_);

  const auto& input_var_names =
      graph.Get<std::vector<std::string>>(framework::paddle2cinn::kInputVars);
  const auto& output_var_names =
      graph.Get<std::vector<std::string>>(framework::paddle2cinn::kOutputVars);
  internal_var_names_ =
      ExtractInternalVarNames(input_var_names, output_var_names);
  // check completeness of output variables in compiled result
  for (auto&& var_name : output_var_names) {
    PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
                      platform::errors::PreconditionNotMet(
                          "Variable(%s) not applied in CINN", var_name));
  }

  // initialize all execution arguments
  InitializeArguments();
  // DEPRECATED(CtfGo): following callback assignment will be deprecated soon
  for (auto&& var_name : input_var_names) {
    if (IsVariableUsed(var_name)) {
      AssignExternalVariable(var_name);
    }
  }
  for (auto&& var_name : output_var_names) {
    AssignExternalVariable(var_name);
  }
  for (auto&& var_name : internal_var_names_) {
    AssignInternalVariable(var_name);
  }

  // Convert the CINN runtime program to a Paddle graph
  runtime_graph_ = std::make_unique<framework::ir::Graph>(
      BuildCompiledProgram(graph, compiled_obj));
  auto& outer_varinfo = graph.Get<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
  runtime_graph_->SetNotOwned<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph,
                                               &outer_varinfo);
  // collect skip_eager_vars
  skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
  auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
    // a variable tracked in the outer_varinfo map will be erased after
    // graph execution by the outer eager deletion pass, so only the
    // untracked variables need to be skipped from eager deletion here
    if (!outer_varinfo.count(var_name)) {
      skip_eager_vars_.emplace_back(var_name);
    }
  };
  std::for_each(input_var_names.begin(), input_var_names.end(),
                add_skip_var_fn);
  std::for_each(output_var_names.begin(), output_var_names.end(),
                add_skip_var_fn);
  VLOG(4) << string::Sprintf(
      "Distribution of variables in the graph compiled:"
      "input[%lu],internal[%lu],output[%lu],"
      "outer_eager_deletion[%lu],skip_eager_deletion[%lu],"
      "initialized_beforehand[%lu]",
      input_var_names.size(), internal_var_names_.size(),
      output_var_names.size(), outer_varinfo.size(), skip_eager_vars_.size(),
      initialized_beforehand_vars_.size());
}

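// Build the bidirectional name maps (paddle2cinn_varmap_/cinn2paddle_varmap_)
// between Paddle variables and CINN arguments, keeping only the entries of
// the compiled map whose CINN names are actually used as execution arguments;
// unmapped argument names are internal variables and get identity entries.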
void CinnLaunchContext::BuildVarNameMap(
    const std::unordered_map<std::string, std::string>& compiled_varmap,
    const std::unordered_set<std::string>& argument_names) {
  for (const auto& x : compiled_varmap) {
    if (!argument_names.count(x.second)) {
      // exclude variables not used
      continue;
    }
    // copy to local paddle2cinn map
    paddle2cinn_varmap_.emplace(x.first, x.second);
    // add an entry to local cinn2paddle map reversely
    auto res = cinn2paddle_varmap_.emplace(x.second, x.first);
    PADDLE_ENFORCE_EQ(
        res.second, true,
        platform::errors::InvalidArgument(
            "Cinn variable(%s) maps to more than one paddle variable(%s,%s)",
            x.second, res.first->second, x.first));
  }
  // supplement the relations of the remaining variables,
  // which do not appear in the map above; they are internal
  // variables and keep the names assigned by the CINN compiler
  for (const auto& var_name : argument_names) {
    if (!cinn2paddle_varmap_.count(var_name)) {
      cinn2paddle_varmap_.emplace(var_name, var_name);
      paddle2cinn_varmap_.emplace(var_name, var_name);
    }
  }

  PADDLE_ENFORCE_EQ(
      paddle2cinn_varmap_.size(), cinn2paddle_varmap_.size(),
      platform::errors::PreconditionNotMet(
          "Size of variables is not euqal, paddle[%ld] vs cinn[%ld]",
          paddle2cinn_varmap_.size(), cinn2paddle_varmap_.size()));
}

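// Cache the execution scope and place; rebuild the temporary scope used for
// internal variables whenever either of them changes.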
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
                                          const platform::Place& place) {
  if (std::addressof(scope) == cached_scope_ &&
      std::addressof(place) == cached_place_) {
    VLOG(4) << "Captured scope:" << cached_scope_ << ", place:" << cached_place_
            << " are not changed";
    return;
  }
  cached_scope_ = std::addressof(scope);
  cached_place_ = std::addressof(place);
  cached_temp_scope_ = scope.NewTmpScope();
  VLOG(4) << "Captured env is update, scope:" << cached_scope_ << "->"
          << std::addressof(scope) << ", place:" << cached_place_ << "->"
          << std::addressof(place);
}

bool CinnLaunchContext::IsVariableUsed(const std::string& var_name) const {
  return paddle2cinn_varmap_.count(var_name) > 0;
}

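// Return the compiled CINN tensor bound to the given Paddle variable.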
CinnTensor CinnLaunchContext::GetCinnTensorOfVar(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(
      IsVariableUsed(var_name), true,
      platform::errors::NotFound("Variable(%s) not applied in CINN", var_name));
  const auto& arg_name = paddle2cinn_varmap_.at(var_name);
  return cinn_scope_->GetTensor(arg_name);
}

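// Internal variables are all mapped variables except the inputs and outputs
// of the subgraph; they live only inside the kernel execution.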
std::unordered_set<std::string> CinnLaunchContext::ExtractInternalVarNames(
    const std::vector<std::string>& input_var_names,
    const std::vector<std::string>& output_var_names) {
  std::unordered_set<std::string> remain_var_names;
  remain_var_names.reserve(paddle2cinn_varmap_.size());
  std::transform(paddle2cinn_varmap_.begin(), paddle2cinn_varmap_.end(),
                 std::inserter(remain_var_names, remain_var_names.end()),
                 [](const auto& name_pair) { return name_pair.first; });

  // exclude the input variables and output variables
  auto exclude_names_fn = [&remain_var_names](const std::string& var_name) {
    remain_var_names.erase(var_name);
  };
  std::for_each(input_var_names.begin(), input_var_names.end(),
                exclude_names_fn);
  std::for_each(output_var_names.begin(), output_var_names.end(),
                exclude_names_fn);
  return remain_var_names;
}

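// Check that the shape and data type of a Paddle tensor match those of the
// corresponding compiled CINN tensor.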
void CinnLaunchContext::CheckTensorEquivalent(
    const std::string& var_name, const framework::LoDTensor& paddle_tensor) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  // check dimension
  auto cinn_tensor = GetCinnTensorOfVar(var_name);
  auto cinn_dims = phi::make_ddim(cinn_tensor->shape().data());
  PADDLE_ENFORCE_EQ(paddle_tensor.dims(), cinn_dims,
                    platform::errors::PreconditionNotMet(
                        "Tensors' shape in variable(%s) are not equivalent, "
                        "paddle is = [%s], but cinn is = [%s].",
                        var_name, paddle_tensor.dims(), cinn_dims));

  // check data type
  auto cinn_dtype =
      framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());
  PADDLE_ENFORCE_EQ(paddle_tensor.dtype(), cinn_dtype,
                    platform::errors::PreconditionNotMet(
                        "Tensors' dtype in variable(%s) are not equivalent, "
                        "paddle is = [%s], but cinn is = [%s].",
                        var_name, paddle_tensor.dtype(), cinn_dtype));
}

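// Create a cinn_buffer_t for each execution argument and register it in
// name2argument_; buffer shapes and types come from the compiled tensors,
// while the underlying memory is allocated lazily through the callbacks
// assigned in AssignExternalVariable/AssignInternalVariable.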
void CinnLaunchContext::InitializeArguments() {
  for (auto&& arg : cinn_argument_names_) {
    auto cinn_buffer = std::make_unique<cinn_buffer_t>();
    auto cinn_tensor = GetCinnTensorOfVar(cinn2paddle_varmap_.at(arg));
    // assign dimensions with corresponding compiled tensor
    cinn_buffer->resize(cinn_tensor->shape().data().data(),
                        cinn_tensor->shape().data().size());
    cinn_buffer->type = cinn::runtime::ToRuntimeType(cinn_tensor->type());
    VLOG(4) << string::Sprintf(
        "Append an argument:name(%s),dims(%s),type(%s)", arg,
        framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
        cinn_tensor->type());
    name2argument_.emplace(arg, cinn_buffer.get());
    hold_buffers_.emplace_back(std::move(cinn_buffer));
  }
  VLOG(4) << "Total argument size:" << name2argument_.size();
}

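// Bind malloc/free callbacks for an input/output variable: memory is taken
// from the variable's tensor in the execution scope, and freeing is left to
// the global garbage collector.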
void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  auto* cinn_buffer = GetCinnBufferOfVar(var_name);
  // assign external malloc/free callbacks of cinn_buffer_t
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor = cached_scope_->GetVar(var_name)->GetMutable<LoDTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // external variables will be recycled by global gc, so do nothing here
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [](void* ctx, cinn_buffer_t* buffer) {
        // Do nothing
        return 0;
      });
}

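// Bind malloc/free callbacks for an internal variable: memory is taken from
// a tensor in the temporary scope and released as soon as no instruction
// uses it any more.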
void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name), true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  auto* cinn_buffer = GetCinnBufferOfVar(var_name);
  // assign external malloc/free callbacks of cinn_buffer_t
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor =
            cached_temp_scope_->Var(var_name)->GetMutable<LoDTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(
            tensor->mutable_data<float>(*cached_place_));
        return 0;
      });

  // internal variables should release their buffers immediately
  // once no instruction uses them any more
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [this, var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor =
            cached_temp_scope_->GetVar(var_name)->GetMutable<LoDTensor>();
        tensor->clear();
        return 0;
      });
}

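// Translate the compiled CINN runtime program back into a Paddle ProgramDesc
// whose ops are all cinn_instruction_run, one per CINN instruction, so that
// it can be executed by a ParallelExecutor.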
framework::ProgramDesc CinnLaunchContext::BuildCompiledProgram(
    const framework::ir::Graph& graph, const CinnCompiledObject& compiled_obj) {
  CinnRuntimeProgram* runtime_program = compiled_obj.runtime_program.get();
  // Step 0: Create an empty program_desc, there will be only one block
  framework::ProgramDesc program_desc;
  auto* block = program_desc.MutableBlock(0);
  const std::vector<std::unique_ptr<CinnInstruction>>& instructions =
      runtime_program->GetRunInstructions();

  // build a map that links the name of a Paddle variable to its VarDesc
  const std::unordered_set<framework::ir::Node*>& nodes = graph.Nodes();
  std::unordered_map<std::string, framework::VarDesc*> original_vardescs;
  for (auto* node : nodes) {
    if (node->IsVar() && node->Var()) {
      original_vardescs.emplace(node->Name(), node->Var());
    }
  }

  // Step 1: Create a VarDesc for each execution argument:
  //   (1) For variables that are inputs or outputs of the original
  //   subgraph, an original VarDesc must exist, so we copy some useful
  //   info (such as IsParameter and Persistable) to the new VarDesc.
  //   (2) For all variables, the shape and data type of their VarDescs
  //   are set from the corresponding compiled tensors, including the
  //   in/out variables, whose equality with the CINN-compiled tensors
  //   is verified in the corresponding cinn_launch_op.
  std::unordered_set<std::string> has_refer_vars;
  for (auto&& arg : cinn_argument_names_) {
    const std::string& var_name = cinn2paddle_varmap_.at(arg);
    framework::VarDesc* var_desc = block->Var(var_name);
    var_desc->SetType(framework::proto::VarType::LOD_TENSOR);

    auto res = original_vardescs.find(var_name);
    if (res != original_vardescs.end()) {
      auto* ori_desc = res->second;
      var_desc->SetPersistable(ori_desc->Persistable());
      var_desc->SetIsParameter(ori_desc->IsParameter());
      has_refer_vars.insert(var_name);
    }

    auto cinn_tensor = GetCinnTensorOfVar(var_name);
    var_desc->SetDataType(framework::TransToProtoVarType(
        framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type())));
    var_desc->SetShape(std::vector<int64_t>(cinn_tensor->shape().data().begin(),
                                            cinn_tensor->shape().data().end()));
  }

  // transform names of the input or output arguments of a CINN instruction
  // to the corresponding Paddle variable names, and repack them as one vector
  auto trans_and_pack_args_fn =
      [this](const std::vector<std::vector<std::string>>& cinn_args_array) {
        std::vector<std::string> var_names;
        for (auto&& cinn_args : cinn_args_array) {
          for (auto&& arg : cinn_args) {
            auto res = cinn2paddle_varmap_.find(arg);
            PADDLE_ENFORCE_NE(
                res, cinn2paddle_varmap_.end(),
                platform::errors::NotFound("Argument(%s) not found", arg));
            var_names.emplace_back(res->second);
          }
        }
        return var_names;
      };

  // Step 2: create a VarDesc of cinn_instruction_run op for
  //         each CINN instruction and append it to the main block
  for (size_t ins_idx = 0; ins_idx < instructions.size(); ++ins_idx) {
    auto* ins = instructions.at(ins_idx).get();
    auto in_args = trans_and_pack_args_fn(ins->GetInArgs());
    auto out_args = trans_and_pack_args_fn(ins->GetOutArgs());
    for (auto&& var_name : in_args) {
      if (!has_refer_vars.count(var_name)) {
        initialized_beforehand_vars_.emplace_back(var_name);
      }
    }
    has_refer_vars.insert(out_args.begin(), out_args.end());

    auto* op_desc = block->AppendOp();
    op_desc->SetType("cinn_instruction_run");
    op_desc->SetInput(kX, in_args);
    op_desc->SetOutput(kOutputs, out_args);
    op_desc->SetAttr(kCachedIndex,
                     {static_cast<int64_t>(compiled_obj.cached_index)});
    op_desc->SetAttr(kInstructionIndex, {static_cast<int64_t>(ins_idx)});
  }

  return program_desc;
}

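// Lazily create the ParallelExecutor that runs the compiled program, then
// rebind it to the current scope and pre-allocate the variables that are
// read before any instruction writes them.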
ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place,
                                                  framework::Scope* scope) {
  if (!parallel_executor_) {
    framework::details::ExecutionStrategy exec_strategy;
    exec_strategy.num_threads_ = 1;
    exec_strategy.use_device_ = platform::Place2DeviceType(place);
    framework::details::BuildStrategy build_strategy;
    parallel_executor_ = std::make_unique<ParallelExecutor>(
        place, scope, exec_strategy, build_strategy, runtime_graph_.get());
  }

  // update the scope bound to an OpHandle and rebuild temporary variables
  VLOG(4) << "Reset scope and initialize temporary variables";
  std::unordered_map<Scope*, Scope*> scope_map = {
      {parallel_executor_->GetLocalScopes().front(), scope}};
  parallel_executor_->ResetOpHandleScopeMapOfGraphs(scope_map);
  parallel_executor_->PrepareVariables(scope);
  for (auto&& var_name : initialized_beforehand_vars_) {
    auto* var = scope->GetVar(var_name);
    auto* buffer = GetCinnBufferOfVar(var_name);
    auto dim = framework::DDim(buffer->dims, buffer->dimensions);
    var->GetMutable<LoDTensor>()->Resize(dim);
    var->GetMutable<LoDTensor>()->mutable_data<float>(place);
  }
  return parallel_executor_.get();
}

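// Look up the cinn_buffer_t argument bound to a Paddle variable, enforcing
// that the variable was compiled and its buffer has been initialized.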
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
    const std::string& var_name) {
  auto it = paddle2cinn_varmap_.find(var_name);
  PADDLE_ENFORCE_NE(
      it, paddle2cinn_varmap_.end(),
      platform::errors::InvalidArgument(
          "Variable(%s) not found in compilation result", var_name));
  auto res = name2argument_.find(it->second);
  PADDLE_ENFORCE_NE(res, name2argument_.end(),
                    platform::errors::NotFound(
                        "Argument(%s) has not been initialized", it->second));
  return static_cast<cinn_buffer_t*>(res->second);
}

}  // namespace operators::details
}  // namespace paddle