// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/operators/cinn/cinn_launch_context.h"

#include <algorithm>
#include <functional>
#include <utility>
#include <vector>

#include "cinn/frontend/op_mapper_registry.h"
#include "cinn/hlir/framework/graph_compiler.h"
#include "cinn/hlir/framework/instruction.h"
#include "cinn/hlir/framework/scope.h"
#include "cinn/hlir/framework/tensor.h"
#include "cinn/runtime/cinn_runtime.h"
#include "cinn/runtime/intrinsic.h"
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/details/build_strategy.h"
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/paddle2cinn/build_cinn_pass.h"
#include "paddle/fluid/framework/paddle2cinn/cinn_compiler.h"
#include "paddle/fluid/framework/paddle2cinn/transform_type.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/operators/cinn/cinn_op_helper.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/string/printf.h"
#include "paddle/phi/core/ddim.h"
#include "paddle/utils/string/string_helper.h"

namespace paddle {
namespace operators::details {

using framework::ParallelExecutor;
using framework::Scope;
using CinnInstruction = ::cinn::hlir::framework::Instruction;
using CinnRuntimeProgram = ::cinn::hlir::framework::Program;
using ::cinn::frontend::paddle::InplaceOutSuffix;
using framework::paddle2cinn::kInplaceVarNames;
using framework::paddle2cinn::kMemOptVarInfoFromMainGraph;
using framework::paddle2cinn::kSkipGcVarNames;
using framework::paddle2cinn::Name2VarInfoMap;

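// Build the launch context from the compiled object: collect the CINN
// execution arguments, map them to the original Paddle variables, bind
// buffers for external/internal variables, and convert the CINN runtime
// program into a Paddle ProgramDesc/Graph whose garbage-collection
// bookkeeping (skip_gc_vars_, skip_eager_vars_) mirrors the main graph.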
CinnLaunchContext::CinnLaunchContext(const framework::ir::Graph& graph,
                                     const CinnCompiledObject& compiled_obj)
    : cinn_scope_(compiled_obj.scope) {
  // collect all names of the CINN execution arguments
  auto var_names = cinn_scope_->var_names();
  cinn_argument_names_.reserve(var_names.size());
  std::transform(
      var_names.begin(),
      var_names.end(),
      std::inserter(cinn_argument_names_, cinn_argument_names_.end()),
      [](const auto& name_view) { return std::string(name_view.data()); });
  // build the name map between the original variables and the compiled ones
  BuildVarNameMap(compiled_obj.paddle2cinn_varmap, cinn_argument_names_);

  const auto& input_var_names =
      graph.Get<std::vector<std::string>>(framework::paddle2cinn::kInputVars);
  const auto& output_var_names =
      graph.Get<std::vector<std::string>>(framework::paddle2cinn::kOutputVars);
  inplace_var_names_ =
      graph.Get<std::unordered_set<std::string>>(kInplaceVarNames);
  internal_var_names_ =
      ExtractInternalVarNames(input_var_names, output_var_names);
  // initialize all execution arguments
  InitializeArguments();
  // DEPRECATED(CtfGo): the following callback assignments will be deprecated
  // soon
  for (auto&& var_name : input_var_names) {
    if (IsVariableUsed(var_name)) {
      AssignExternalVariable(var_name);
    }
  }
  for (auto&& var_name : output_var_names) {
    if (inplace_var_names_.count(var_name)) {
      VLOG(4) << "Inplaced variable:" << var_name << " -> "
              << var_name + InplaceOutSuffix << " as paddle2cinn varmap key";
      AssignExternalVariable(var_name + InplaceOutSuffix);
    } else {
      AssignExternalVariable(var_name);
    }
  }
  for (auto&& var_name : internal_var_names_) {
    AssignInternalVariable(var_name);
  }

  // Convert the CINN runtime program to a Paddle graph
  runtime_program_desc_ = BuildCompiledProgram(graph, compiled_obj);
  runtime_graph_ =
      std::make_unique<framework::ir::Graph>(*runtime_program_desc_.get());
  auto& outer_varinfo = graph.Get<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph);
  runtime_graph_->SetNotOwned<Name2VarInfoMap>(kMemOptVarInfoFromMainGraph,
                                               &outer_varinfo);
  // use the kSkipGcVarNames attribute of the graph to initialize skip_gc_vars_
  if (graph.Has(kSkipGcVarNames)) {
    const auto& skip_gc_vars =
        graph.Get<std::unordered_set<std::string>>(kSkipGcVarNames);
    skip_gc_vars_.insert(skip_gc_vars.begin(), skip_gc_vars.end());
    VLOG(4) << "Append skip_gc_vars:["
            << string::join_strings(skip_gc_vars, ',') << "]";
  }

  // collect the list of variable names to be skipped by GC
  skip_eager_vars_.reserve(input_var_names.size() + output_var_names.size());
  auto add_skip_var_fn = [&outer_varinfo, this](const std::string& var_name) {
    // if a variable exists in the outer_varinfo map, it will be erased
    // by the eager_deletion_op that follows the current cinn_launch op
    if (!outer_varinfo.count(var_name)) {
      skip_eager_vars_.emplace_back(var_name);
      skip_gc_vars_.insert(var_name);
      VLOG(4) << "Append a skip_gc_var:" << var_name;
    }
  };
  std::for_each(
      input_var_names.begin(), input_var_names.end(), add_skip_var_fn);
  std::for_each(
      output_var_names.begin(), output_var_names.end(), add_skip_var_fn);
  VLOG(4) << string::Sprintf(
      "Distribution of variables in the compiled graph:"
      "input[%lu],internal[%lu],output[%lu],"
      "outer_eager_deletion[%lu],skip_eager_deletion[%lu],"
      "skip_gc_vars_[%lu]",
      input_var_names.size(),
      internal_var_names_.size(),
      output_var_names.size(),
      outer_varinfo.size(),
      skip_eager_vars_.size(),
      skip_gc_vars_.size());
}

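// Build the bidirectional maps between Paddle variable names and CINN
// argument names: entries of the compiled map whose CINN name is not an
// execution argument are skipped, and arguments without a Paddle
// counterpart are mapped to themselves.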
void CinnLaunchContext::BuildVarNameMap(
    const std::unordered_map<std::string, std::string>& compiled_varmap,
    const std::unordered_set<std::string>& argument_names) {
  for (const auto& x : compiled_varmap) {
    if (!argument_names.count(x.second)) {
      // exclude variables not used
      continue;
    }
    // copy to the local paddle2cinn map
    paddle2cinn_varmap_.emplace(x.first, x.second);
    // add the reverse entry to the local cinn2paddle map
    auto res = cinn2paddle_varmap_.emplace(x.second, x.first);
    PADDLE_ENFORCE_EQ(
        res.second,
        true,
        platform::errors::InvalidArgument(
            "Cinn variable(%s) maps to more than one paddle variable(%s,%s)",
            x.second,
            res.first->second,
            x.first));
  }
  // supplement the relations of the remaining variables that do not appear
  // in the map above; they are internal variables, and we use the names
  // from the CINN compilation result directly.
  for (const auto& var_name : argument_names) {
    if (!cinn2paddle_varmap_.count(var_name)) {
      cinn2paddle_varmap_.emplace(var_name, var_name);
      paddle2cinn_varmap_.emplace(var_name, var_name);
    }
  }

  PADDLE_ENFORCE_EQ(
      paddle2cinn_varmap_.size(),
      cinn2paddle_varmap_.size(),
      platform::errors::PreconditionNotMet(
          "Size of variables is not equal, paddle[%ld] vs cinn[%ld]",
          paddle2cinn_varmap_.size(),
          cinn2paddle_varmap_.size()));
}

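// Capture the scope and place of the current execution; a fresh temporary
// scope is created whenever either of them changes.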
void CinnLaunchContext::UpdateCapturedEnv(const framework::Scope& scope,
                                          const platform::Place& place) {
  if (std::addressof(scope) == cached_scope_ &&
      std::addressof(place) == cached_place_) {
    VLOG(4) << "Captured scope:" << cached_scope_ << ", place:" << cached_place_
            << " are not changed";
    return;
  }
  // log the old -> new transition before overwriting the cached pointers
  VLOG(4) << "Captured env is updated, scope:" << cached_scope_ << "->"
          << std::addressof(scope) << ", place:" << cached_place_ << "->"
          << std::addressof(place);
  cached_scope_ = std::addressof(scope);
  cached_place_ = std::addressof(place);
  cached_temp_scope_ = scope.NewTmpScope();
}

bool CinnLaunchContext::IsVariableUsed(const std::string& var_name) const {
  return paddle2cinn_varmap_.count(var_name) > 0;
}

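// Look up the compiled CINN tensor that corresponds to a Paddle variable.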
CinnTensor CinnLaunchContext::GetCinnTensorOfVar(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(
      IsVariableUsed(var_name),
      true,
      platform::errors::NotFound("Variable(%s) not applied in CINN", var_name));
  const auto& arg_name = paddle2cinn_varmap_.at(var_name);
  return cinn_scope_->GetTensor(arg_name);
}

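// Collect the names of variables that are used by the compiled program but
// are neither inputs nor outputs of the subgraph, i.e. temporaries that
// live only inside this launch context.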
std::unordered_set<std::string> CinnLaunchContext::ExtractInternalVarNames(
    const std::vector<std::string>& input_var_names,
    const std::vector<std::string>& output_var_names) {
  std::unordered_set<std::string> remain_var_names;
  remain_var_names.reserve(paddle2cinn_varmap_.size());
  std::transform(paddle2cinn_varmap_.begin(),
                 paddle2cinn_varmap_.end(),
                 std::inserter(remain_var_names, remain_var_names.end()),
                 [](const auto& name_pair) { return name_pair.first; });

  // exclude the input and output variables
  auto exclude_names_fn = [this,
                           &remain_var_names](const std::string& var_name) {
    remain_var_names.erase(var_name);
    if (inplace_var_names_.count(var_name)) {
      remain_var_names.erase(var_name + InplaceOutSuffix);
    }
  };
  std::for_each(
      input_var_names.begin(), input_var_names.end(), exclude_names_fn);
  std::for_each(
      output_var_names.begin(), output_var_names.end(), exclude_names_fn);
  return remain_var_names;
}

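// Verify that the shape and data type of a Paddle tensor match those of
// the corresponding compiled CINN tensor.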
void CinnLaunchContext::CheckTensorEquivalent(
    const std::string& var_name, const phi::DenseTensor& paddle_tensor) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name),
                    true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  // check dimension
  auto cinn_tensor = GetCinnTensorOfVar(var_name);
  auto cinn_dims = phi::make_ddim(cinn_tensor->shape().data());
  PADDLE_ENFORCE_EQ(paddle_tensor.dims(),
                    cinn_dims,
                    platform::errors::PreconditionNotMet(
                        "Tensors' shape in variable(%s) are not equivalent, "
                        "paddle is = [%s], but cinn is = [%s].",
                        var_name,
                        paddle_tensor.dims(),
                        cinn_dims));

  // check data type
  auto cinn_dtype =
      framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type());
  PADDLE_ENFORCE_EQ(paddle_tensor.dtype(),
                    cinn_dtype,
                    platform::errors::PreconditionNotMet(
                        "Tensors' dtype in variable(%s) are not equivalent, "
                        "paddle is = [%s], but cinn is = [%s].",
                        var_name,
                        paddle_tensor.dtype(),
                        cinn_dtype));
}

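// Create a cinn_buffer_t for every execution argument, sized and typed from
// the compiled tensor, and register it in both the name2argument_ and
// paddle2argument_ maps; ownership of the buffers stays in hold_buffers_.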
void CinnLaunchContext::InitializeArguments() {
  for (auto&& arg : cinn_argument_names_) {
    auto cinn_buffer = std::make_unique<cinn_buffer_t>();
    auto cinn_tensor = GetCinnTensorOfVar(cinn2paddle_varmap_.at(arg));
    // assign dimensions with the corresponding compiled tensor
    cinn_buffer->resize(cinn_tensor->shape().data().data(),
                        cinn_tensor->shape().data().size());
    cinn_buffer->type = cinn::runtime::ToRuntimeType(cinn_tensor->type());
    VLOG(4) << string::Sprintf(
        "Append an argument:name(%s),dims(%s),type(%s)",
        arg,
        framework::DDim(cinn_buffer->dims, cinn_buffer->dimensions).to_str(),
        cinn_tensor->type());
    name2argument_.emplace(arg, cinn_buffer.get());
    const auto& paddle_var_name = cinn2paddle_varmap_.at(arg);
    paddle2argument_.emplace(paddle_var_name, cinn_buffer.get());
    hold_buffers_.emplace_back(std::move(cinn_buffer));
  }
  VLOG(4) << "Total argument size:" << name2argument_.size();
}

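// Bind an external (input/output) variable to its cinn_buffer_t: memory is
// allocated in the captured execution scope on demand, and freeing is left
// to the global garbage collector.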
void CinnLaunchContext::AssignExternalVariable(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name),
                    true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  auto* cinn_buffer = GetCinnBufferOfVar(var_name);
  std::string revise_var_name = RedirectVarName(var_name);
  // assign the external malloc/free callbacks of cinn_buffer_t
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, revise_var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor = cached_scope_->GetVar(revise_var_name)
                           ->GetMutable<phi::DenseTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
            *cached_place_,
            framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
        return 0;
      });

  // external variables will be recycled by the global gc, so do nothing here
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [](void* ctx, cinn_buffer_t* buffer) {
        // Do nothing
        return 0;
      });
}

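// Bind an internal variable to its cinn_buffer_t: memory lives in the
// temporary scope and is released as soon as no instruction uses it.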
void CinnLaunchContext::AssignInternalVariable(const std::string& var_name) {
  PADDLE_ENFORCE_EQ(IsVariableUsed(var_name),
                    true,
                    platform::errors::InvalidArgument(
                        "Variable(%s) not applied in cinn", var_name));
  auto* cinn_buffer = GetCinnBufferOfVar(var_name);
  std::string revise_var_name = RedirectVarName(var_name);
  // assign the external malloc/free callbacks of cinn_buffer_t
  cinn_buffer->external_malloc = new std::function<int(void*, cinn_buffer_t*)>(
      [this, revise_var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor = cached_temp_scope_->Var(revise_var_name)
                           ->GetMutable<phi::DenseTensor>();
        tensor->Resize(framework::DDim(buffer->dims, buffer->dimensions));
        buffer->memory = reinterpret_cast<uint8_t*>(tensor->mutable_data(
            *cached_place_,
            framework::paddle2cinn::TransToPaddleDataType(buffer->type)));
        return 0;
      });

  // an internal variable should release its buffer immediately
  // once no instruction uses it
  cinn_buffer->external_free = new std::function<int(void*, cinn_buffer_t*)>(
      [this, revise_var_name](void* ctx, cinn_buffer_t* buffer) {
        auto* tensor = cached_temp_scope_->GetVar(revise_var_name)
                           ->GetMutable<phi::DenseTensor>();
        tensor->clear();
        return 0;
      });
}

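// Translate the compiled CINN runtime program into a Paddle ProgramDesc:
// one VarDesc per execution argument and one cinn_instruction_run op per
// CINN instruction.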
std::unique_ptr<framework::ProgramDesc> CinnLaunchContext::BuildCompiledProgram(
    const framework::ir::Graph& graph, const CinnCompiledObject& compiled_obj) {
  CinnRuntimeProgram* runtime_program = compiled_obj.runtime_program.get();
  // Step 0: Create an empty program_desc; there will be only one block
  auto program_desc = std::make_unique<framework::ProgramDesc>();
  auto* block = program_desc->MutableBlock(0);
  const std::vector<std::unique_ptr<CinnInstruction>>& instructions =
      runtime_program->GetRunInstructions();

  // build a map that links the name of a Paddle variable to its VarDesc
  const std::unordered_set<framework::ir::Node*>& nodes = graph.Nodes();
  std::unordered_map<std::string, framework::VarDesc*> original_vardescs;
  for (auto* node : nodes) {
    if (node->IsVar() && node->Var()) {
      original_vardescs.emplace(node->Name(), node->Var());
    }
  }

  // Step 1: Create a VarDesc for each execution argument:
  //   (1) For variables that are input or output variables of the
  //   original subgraph, an original VarDesc must exist, so we copy
  //   some useful info (such as IsParameter, Persistable) to the new VarDesc.
  //   (2) For all variables, the shape and data type of their VarDescs
  //   are set from the corresponding compiled tensors, including the
  //   in/out variables, where the equality between their tensors and the
  //   CINN compiled ones is verified in the corresponding cinn_launch_op.
  for (auto&& arg : cinn_argument_names_) {
    const std::string& var_name = cinn2paddle_varmap_.at(arg);
    framework::VarDesc* var_desc = block->Var(var_name);
    var_desc->SetType(framework::proto::VarType::LOD_TENSOR);

    auto res = original_vardescs.find(var_name);
    if (res != original_vardescs.end()) {
      auto* ori_desc = res->second;
      var_desc->SetPersistable(ori_desc->Persistable());
      var_desc->SetIsParameter(ori_desc->IsParameter());
    }

    auto cinn_tensor = GetCinnTensorOfVar(var_name);
    var_desc->SetDataType(framework::TransToProtoVarType(
        framework::paddle2cinn::TransToPaddleDataType(cinn_tensor->type())));
    var_desc->SetShape(std::vector<int64_t>(cinn_tensor->shape().data().begin(),
                                            cinn_tensor->shape().data().end()));
  }

  // transform the names of the input or output arguments of a CINN instruction
  // to the corresponding Paddle variable names, and repack them as one vector
  auto trans_and_pack_args_fn =
      [this](const std::vector<std::vector<std::string>>& cinn_args_array) {
        std::vector<std::string> var_names;
        for (auto&& cinn_args : cinn_args_array) {
          for (auto&& arg : cinn_args) {
            auto res = cinn2paddle_varmap_.find(arg);
            PADDLE_ENFORCE_NE(
                res,
                cinn2paddle_varmap_.end(),
                platform::errors::NotFound("Argument(%s) not found", arg));
            var_names.emplace_back(res->second);
          }
        }
        return var_names;
      };

  // Step 2: create an OpDesc of the cinn_instruction_run op for
  //         each CINN instruction and append it to the main block
  for (size_t ins_idx = 0; ins_idx < instructions.size(); ++ins_idx) {
    auto* ins = instructions.at(ins_idx).get();
    auto in_args = trans_and_pack_args_fn(ins->GetInArgs());
    auto out_args = trans_and_pack_args_fn(ins->GetOutArgs());
    auto* op_desc = block->AppendOp();
    op_desc->SetType("cinn_instruction_run");
    op_desc->SetInput(kX, in_args);
    op_desc->SetOutput(kOutputs, out_args);
    op_desc->SetAttr(kCachedIndex,
                     {static_cast<int64_t>(compiled_obj.cached_index)});
    op_desc->SetAttr(kInstructionIndex, {static_cast<int64_t>(ins_idx)});
  }

  return program_desc;
}

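// Lazily create a ParallelExecutor over the runtime graph, then rebind its
// op handles to the given scope and initialize internal variables in it.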
ParallelExecutor* CinnLaunchContext::InitializePE(const platform::Place& place,
                                                  framework::Scope* scope) {
  if (!parallel_executor_) {
    framework::details::ExecutionStrategy exec_strategy;
    exec_strategy.num_threads_ = 1;
    exec_strategy.use_device_ = platform::Place2DeviceType(place);
    framework::details::BuildStrategy build_strategy;
    parallel_executor_ = std::make_unique<ParallelExecutor>(
        place, scope, exec_strategy, build_strategy, runtime_graph_.get());
  }

  // update the scope bound to an OpHandle and rebuild temporary variables
  VLOG(4) << "Reset scope and initialize temporary variables";
  std::unordered_map<Scope*, Scope*> scope_map = {
      {parallel_executor_->GetLocalScopes().front(), scope}};
  parallel_executor_->ResetOpHandleScopeMapOfGraphs(scope_map);
  // instead of using the PrepareVariables function of ParallelExecutor to
  // initialize all variables, here we only initialize internal variables
  // because external variables are already included in the parent scope.
  for (auto&& var_name : internal_var_names_) {
    auto* var = scope->FindVar(var_name);
    if (var != nullptr) {
      VLOG(5) << "internal variable:" << var_name
              << " has been initialized beforehand in global scope, skipped.";
      continue;
    }
    framework::InitializeVariable(scope->Var(var_name),
                                  framework::proto::VarType::LOD_TENSOR);
  }

  return parallel_executor_.get();
}

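// Lazily create (or re-scope) an InterpreterCore over the compiled program,
// initializing internal variables in the given scope beforehand.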
framework::InterpreterCore* CinnLaunchContext::InitializeInterpreterCore(
    const platform::Place& place, framework::Scope* scope) {
  if (!interpreter_core_ || scope != cached_scope_) {
    VLOG(1) << "interpreter_core_ is null or scope != cached_scope_: "
               "interpreter_core_: "
            << interpreter_core_.get() << "; scope: " << scope
            << "; cached_scope_: " << cached_scope_;
    for (auto&& var_name : internal_var_names_) {
      auto* var = scope->FindVar(var_name);
      if (var != nullptr) {
        continue;
      }
      framework::InitializeVariable(scope->Var(var_name),
                                    framework::proto::VarType::LOD_TENSOR);
    }
    if (!interpreter_core_) {
      interpreter_core_ = std::make_unique<framework::InterpreterCore>(
          place,
          runtime_program_desc_->Block(0),
          skip_gc_vars_,
          scope,
          /*used_for_jit*/ false,
          /*used_for_control_flow_op*/ false,
          /*used_for_cinn*/ true);
    } else {
      interpreter_core_->reset_scope(scope);
    }
    UpdateCapturedEnv(*scope, place);
  }
  return interpreter_core_.get();
}

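// Strip the InplaceOutSuffix from an inplace output name so it resolves to
// the underlying Paddle variable; non-inplace names pass through unchanged.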
std::string CinnLaunchContext::RedirectVarName(const std::string& var_name) {
  auto pos = var_name.find(InplaceOutSuffix);
  if (pos == std::string::npos) {
    return var_name;
  }
  std::string remove_suffix_name = var_name.substr(0, pos);
  if (!inplace_var_names_.count(remove_suffix_name)) {
    LOG(WARNING) << "Variable:" << remove_suffix_name
                 << " was not marked as inplace by Paddle,"
                 << " but is treated as inplace by CINN";
  }
  VLOG(4) << "Inplaced variable:" << var_name << " redirect to "
          << remove_suffix_name;
  return remove_suffix_name;
}

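// Look up the cinn_buffer_t argument bound to a Paddle variable.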
cinn_buffer_t* CinnLaunchContext::GetCinnBufferOfVar(
    const std::string& var_name) {
  auto res = paddle2argument_.find(var_name);
  PADDLE_ENFORCE_NE(
      res,
      paddle2argument_.end(),
      platform::errors::NotFound("Variable(%s) not found in compilation result",
                                 var_name));
  return static_cast<cinn_buffer_t*>(res->second);
}

}  // namespace operators::details
}  // namespace paddle