prepared_operator.cc 31.5 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/prepared_operator.h"
16

J
Jiabin Yang 已提交
17
#include "paddle/fluid/eager/eager_tensor.h"
18
#include "paddle/fluid/framework/data_type_transform.h"
19
#include "paddle/fluid/framework/details/nan_inf_utils.h"
20
#include "paddle/fluid/imperative/infer_shape_context.h"
21
#include "paddle/fluid/imperative/tracer.h"
22
#include "paddle/phi/common/int_array.h"
23
#include "paddle/phi/common/scalar.h"
24
#include "paddle/utils/small_vector.h"
Q
QingshuChen 已提交
25
#ifdef PADDLE_WITH_XPU
26
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
Q
QingshuChen 已提交
27
#endif
28 29 30
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif
31 32 33
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/platform/device/gpu/gpu_dnn.h"
#endif
L
Liu-xiandong 已提交
34
#include "paddle/fluid/framework/library_type.h"
35
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
C
chenjian 已提交
36
#include "paddle/fluid/platform/profiler/event_tracing.h"
C
chenjian 已提交
37
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
38

39
DECLARE_bool(check_nan_inf);
40
DECLARE_bool(benchmark);
F
Feng Xing 已提交
41
DECLARE_bool(run_kp_kernel);
42

J
Jiabin Yang 已提交
43 44 45
namespace paddle {
namespace imperative {

// Placeholder kernel bound by the fluid-kernel PreparedOp constructor, which
// has no real phi::Kernel to reference (phi_kernel_ is a const reference and
// must bind to something).
static const phi::Kernel empty_kernel;
// Dygraph execution does not use the static-graph RuntimeContext/Scope, so a
// single shared empty instance of each is passed everywhere below.
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;

// Cache the singleton lookups once at load time so every PrepareImpl call
// avoids repeated Instance() calls.
const phi::KernelFactory& PreparedOp::phi_kernel_factory =
    phi::KernelFactory::Instance();
const phi::OpUtilsMap& PreparedOp::phi_op_utils_map =
    phi::OpUtilsMap::Instance();
const phi::DefaultKernelSignatureMap& PreparedOp::default_phi_kernel_sig_map =
    phi::DefaultKernelSignatureMap::Instance();

57 58 59 60 61 62 63 64 65 66
// Extracts the underlying VariableWrapper from a dygraph VarBase handle.
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<paddle::imperative::VarBase>& var) {
  return var->SharedVar();
}

// Identity overload so templated callers can obtain a VariableWrapper
// uniformly from either VarBase or VariableWrapper handles.
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<VariableWrapper>& var) {
  return var;
}

67
// Returns a pointer to the dense tensor stored in `var`: the tensor itself
// for a DenseTensor variable, the value tensor for a SelectedRows variable,
// and nullptr for any other variable type.
const phi::DenseTensor* GetTensorFromVar(const framework::Variable& var) {
  if (var.IsType<phi::DenseTensor>()) {
    return &var.Get<phi::DenseTensor>();
  }
  if (var.IsType<phi::SelectedRows>()) {
    return &var.Get<phi::SelectedRows>().value();
  }
  return nullptr;
}

77
// Converts complex-typed grad outputs back to the real dtype recorded as the
// var's forward data type. This handles ops with type promotion, e.g.
// x(float32) + y(complex64): dx comes out complex64 from the grad kernel but
// must be returned as float32 (see the long comment in PreparedOpRunImpl).
template <typename VarType>
void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
  for (auto& pair : outs) {
    for (auto& var : pair.second) {
      if (var == nullptr) {
        continue;
      }
      // -1 is the sentinel for "forward dtype never recorded"; skip such vars.
      if (var->ForwardDataType() ==
          static_cast<framework::proto::VarType::Type>(-1)) {
        VLOG(6) << "Var (" << var->Name()
                << ")'s forward data type is not set.";
        continue;
      }
      // Only convert when the grad is complex but the forward value was real.
      if (!framework::IsComplexType(var->DataType()) ||
          framework::IsComplexType(var->ForwardDataType())) {
        continue;
      }
      const auto* tensor = GetTensorFromVar(var->Var());
      if (tensor && tensor->IsInitialized()) {
        VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType())
                << " var `" << var->Name() << "` to "
                << framework::DataTypeToString(var->ForwardDataType())
                << " real var in dynamic graph.";
        // Transform into a temporary, then install it back into the variable.
        phi::DenseTensor out;
        framework::TransComplexToReal(
            var->ForwardDataType(), var->DataType(), *tensor, &out);
        SetTensorToVariable(var->Var(), out, var->MutableVar());
      }
    }
  }
}

J
Jiabin Yang 已提交
109
// Eager mode does not perform complex-to-real grad conversion yet; this
// specialization is intentionally a no-op.
template <>
void HandleComplexGradToRealGrad<egr::EagerVariable>(
    const NameVarMap<egr::EagerVariable>& outs) {
  // TODO(jiabin): Support Complex here.
}

115 116 117 118 119
// Test-only entry point exposing the egr::EagerVariable specialization above
// (the template specialization itself is not directly callable from tests).
void TestHandleComplexGradToRealGradEager(
    const NameVarMap<egr::EagerVariable>& outs) {
  HandleComplexGradToRealGrad<egr::EagerVariable>(outs);
}

J
Jiabin Yang 已提交
120 121
// Constructor for the fluid-kernel path: `func` is the kernel to invoke at
// Run() time; run_phi_kernel_ stays false and phi_kernel_ binds to the
// shared empty_kernel placeholder.
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
                       const framework::OpKernelType& kernel_type,
                       const framework::OperatorWithKernel::OpKernelFunc& func,
                       const phi::ArgumentMappingFn* arg_map_fn,
                       const phi::KernelSignature* default_kernel_signature,
                       platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_type_(kernel_type),
      func_(func),
      dev_ctx_(dev_ctx),
      arg_map_fn_(arg_map_fn),
      default_kernel_signature_(default_kernel_signature),
      phi_kernel_(empty_kernel) {}
135

136 137 138
// Constructor for the phi-kernel path: stores the selected phi kernel and its
// argument-mapping signature, sets run_phi_kernel_ so Run() dispatches to
// PreparedOpRunPtImpl, and leaves the fluid kernel functor null.
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
                       const framework::OpKernelType& kernel_type,
                       const phi::ArgumentMappingFn* arg_map_fn,
                       const phi::KernelSignature* default_kernel_signature,
                       phi::KernelSignature&& kernel_signature,
                       const phi::Kernel& phi_kernel,
                       platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_type_(kernel_type),
      func_(nullptr),
      dev_ctx_(dev_ctx),
      run_phi_kernel_(true),
      arg_map_fn_(arg_map_fn),
      default_kernel_signature_(default_kernel_signature),
      kernel_signature_(std::move(kernel_signature)),
      phi_kernel_(phi_kernel) {}

155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196
#ifdef PADDLE_WITH_MLU

// Splits `ops` on `delim` and inserts every token (including a possibly
// empty trailing token) into `op_set`.
static void tokenize(const std::string& ops,
                     char delim,
                     std::unordered_set<std::string>* op_set) {
  std::string::size_type start = 0;
  std::string::size_type pos = ops.find(delim);
  while (pos != std::string::npos) {
    op_set->insert(ops.substr(start, pos - start));
    start = pos + 1;
    pos = ops.find(delim, pos + 1);
  }
  op_set->insert(ops.substr(start));
}

// Returns true if `op_name` must not run on MLU. The black list is read once
// from the MLU_BLACK_LIST environment variable (comma-separated op names).
//
// Fix: the previous implementation used double-checked locking on a plain
// bool (`if (!inited) { lock; if (!inited) ... }`), which reads `inited`
// outside the mutex without synchronization — a data race and undefined
// behavior in C++11. A function-local static initialized by a lambda gives
// the same lazy, once-only initialization with the compiler-guaranteed
// thread-safe guard ("magic static"), and needs no mutex or flag at all.
static bool is_in_mlu_black_list(const std::string& op_name) {
  static const std::unordered_set<std::string> mlu_black_list = [] {
    std::unordered_set<std::string> black_list;
    if (const char* env = std::getenv("MLU_BLACK_LIST")) {
      tokenize(std::string(env), ',', &black_list);
    }
    VLOG(3) << "MLU Black List: ";
    for (const auto& op : black_list) {
      VLOG(3) << op << " ";
    }
    return black_list;
  }();
  return mlu_black_list.count(op_name) > 0;
}

#endif

197
// Selects the kernel and device context for one dygraph op invocation and
// packages them into a PreparedOp. Selection order (see NOTE below): phi
// device kernel > fluid device kernel > phi CPU kernel > fluid CPU kernel,
// with per-platform fallbacks (XPU/KP, NPU, IPU, MLU, custom device) to CPU
// when the device kernel is missing or black-listed.
template <typename VarType>
PreparedOp PrepareImpl(
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::OperatorWithKernel& op,
    const platform::Place& place,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    const phi::KernelFactory& phi_kernel_factory,
    const phi::OpUtilsMap& phi_op_utils_map,
    const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);

#ifdef PADDLE_WITH_MKLDNN
  // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
  // GetKernelType functions, so we need to copy the attributes there.
  // Const qualifier of Attrs had to be discarded to overwrite it.
  if (FLAGS_use_mkldnn) {
    auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
    mutable_op_attrs = default_attrs;
    for (auto& attr : attrs) {
      mutable_op_attrs[attr.first] = attr.second;
    }
  }
#endif
  // NOTE(zhiqiu): for kernels on given device, for example NPU, the order to
  // choose is:
  // phi npu kernel > fluid npu kernel > phi cpu kernel > fluid cpu kernel

  // 1. get expected kernel key
  auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
      op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
  auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);

  const phi::KernelSignature* default_kernel_signature = nullptr;
  phi::KernelSignature kernel_signature;
  phi::KernelKey phi_kernel_key;
  std::string phi_kernel_name;

// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::kMKLDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel fallbacks to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
  if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) &&
      op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.data_type_)) {
    expected_kernel_key.library_type_ = framework::LibraryType::kMKLDNN;
    expected_kernel_key.data_layout_ = framework::DataLayout::kMKLDNN;
  }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (!op.DnnFallback() && paddle::platform::CanCUDNNBeUsed(dygraph_exe_ctx)) {
    expected_kernel_key.library_type_ = framework::LibraryType::kCUDNN;
  }
#endif

#if defined(PADDLE_WITH_XPU)
  // NOTE(review): relies on && binding tighter than || — true when the op is
  // on XPU but unsupported there, or when it is explicitly black-listed.
  bool is_xpu_unsupport =
      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
          !paddle::platform::is_xpu_support_op(op.Type(),
                                               expected_kernel_key) ||
      paddle::platform::is_in_xpu_black_list(op.Type());
#endif

#ifdef PADDLE_WITH_MLU
  // Black-listed MLU ops are forced onto CPU before kernel selection.
  if (is_in_mlu_black_list(op.Type())) {
    expected_kernel_key.place_ = platform::CPUPlace();
  }
#endif

  bool has_phi_kernel = false;

  // Resolve the fluid-attrs -> phi-args mapping: prefer the op-specific
  // mapping fn; otherwise fall back to the auto-generated default signature.
  const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());

  if (arg_map_fn) {
    has_phi_kernel = true;
    kernel_signature = (*arg_map_fn)(
        framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
  } else {
    default_kernel_signature =
        default_phi_kernel_sig_map.GetNullable(op.Type());
    if (default_kernel_signature) {
      has_phi_kernel = true;
      kernel_signature = *default_kernel_signature;
    }
  }

  // 2. try the phi kernel first.
  if (has_phi_kernel) {
    VLOG(6) << kernel_signature;
    phi_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
#ifdef PADDLE_WITH_XPU_KP
    if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
      bool use_xpu_kp_kernel_rt =
          FLAGS_run_kp_kernel && paddle::platform::is_xpu_kp_support_op(
                                     op.Type(), expected_kernel_key);
      bool use_xpu_kp_kernel_debug =
          paddle::platform::is_in_xpu_kpwhite_list(op.Type());
      if (use_xpu_kp_kernel_rt) {
        VLOG(3) << "phi xpu_kp using rt mode ";
      }
      if (use_xpu_kp_kernel_debug) {
        VLOG(3) << "phi xpu_kp using debug mode ";
      }
      bool is_xpu_kp_support =
          (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
      if (is_xpu_kp_support) {
        auto expected_kernel_key_library_type =
            expected_kernel_key.library_type_;
        // Tentatively switch to the KP library; revert if no such phi kernel
        // is actually registered.
        expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
        VLOG(3) << "modifing XPU KP kernel: " << phi_kernel_name
                << ", using_kernel_key:" << expected_kernel_key;

        phi::KernelKey try_phi_kernel_key =
            TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
        if (!phi_kernel_factory.HasKernel(phi_kernel_name,
                                          try_phi_kernel_key)) {
          expected_kernel_key.library_type_ = expected_kernel_key_library_type;
          VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name
                  << " in dynamic graph is failed " << expected_kernel_key;
        } else {
          VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name
                  << " in dynamic graph is succeed " << expected_kernel_key;
        }
      }
    }
#endif

    phi_kernel_key = TransOpKernelTypeToPhiKernelKey(expected_kernel_key);
    auto& phi_kernel =
        phi_kernel_factory.SelectKernel(phi_kernel_name, phi_kernel_key);

    if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
        && !is_xpu_unsupport
#endif
    ) {
      VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name
              << " | kernel key: " << phi_kernel_key
              << " | kernel: " << phi_kernel;

      // The kernel may run on a different place than requested (e.g. after a
      // fallback); fetch the matching device context.
      if (expected_kernel_key.place_ != place) {
        dev_ctx = pool.Get(expected_kernel_key.place_);
      }

      return PreparedOp(op,
                        empty_ctx,
                        expected_kernel_key,
                        arg_map_fn,
                        default_kernel_signature,
                        std::move(kernel_signature),
                        phi_kernel,
                        dev_ctx);
    } else {
      VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << phi_kernel_name
              << "` not found.";
    }
  }

  // 2. check if op[type] has kernel registered.
  auto& all_op_kernels = op.AllOpKernels();
  auto kernels_iter = all_op_kernels.find(op.Type());

// NOTE(Liu-xiandong): If we can't find heterogeneous kernel in phi,
// we need to select the heterogeneous kernel in fluid, but the kernel
// registered in KP use library_type[KP], we need to modify it.
#ifdef PADDLE_WITH_XPU_KP
  bool use_xpu_kp_kernel_rt =
      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
      FLAGS_run_kp_kernel &&
      paddle::platform::is_xpu_kp_support_op(op.Type(), expected_kernel_key);
  bool use_xpu_kp_kernel_debug =
      paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
      paddle::platform::is_in_xpu_kpwhite_list(op.Type());
  bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
  if (is_xpu_kp_support) {
    expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
  }
#endif

  // No fluid kernel for the expected key (or the device is unsupported):
  // try the phi CPU kernel before erroring out.
  if ((kernels_iter == all_op_kernels.end() ||
       kernels_iter->second.find(expected_kernel_key) ==
           kernels_iter->second.end())
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
      || is_xpu_unsupport
#endif
#if defined(PADDLE_WITH_XPU_KP)
      || (is_xpu_unsupport && !is_xpu_kp_support)
#endif
  ) {
    if (has_phi_kernel) {
      auto phi_cpu_kernel_key =
          FallBackToCpu(expected_kernel_key, phi_kernel_key, op);
      auto& phi_cpu_kernel =
          phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key);
      if (phi_cpu_kernel.IsValid()) {
        VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name
                << " | kernel key: " << phi_cpu_kernel_key
                << " | kernel: " << phi_cpu_kernel;
        auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
        return PreparedOp(
            op,
            empty_ctx,
            framework::TransPhiKernelKeyToOpKernelType(phi_cpu_kernel_key),
            arg_map_fn,
            default_kernel_signature,
            std::move(kernel_signature),
            phi_cpu_kernel,
            cpu_ctx);
      }
    }
  }

  PADDLE_ENFORCE_NE(
      kernels_iter,
      all_op_kernels.end(),
      platform::errors::NotFound(
          "There are no kernels which are registered in the %s operator.",
          op.Type()));

  // 3. select a fluid kernel, with per-platform CPU fallbacks below.
  auto& kernels = kernels_iter->second;
  auto kernel_iter = kernels.find(expected_kernel_key);

#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
  if (paddle::platform::is_xpu_place(expected_kernel_key.place_) &&
      (kernel_iter == kernels.end() || is_xpu_unsupport)) {
    VLOG(3) << "fluid missing XPU kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif

#ifdef PADDLE_WITH_XPU_KP
  if (paddle::platform::is_xpu_place(expected_kernel_key.place_)) {
    if (use_xpu_kp_kernel_rt) {
      VLOG(3) << "fluid xpu_kp using rt mode ";
    }
    if (use_xpu_kp_kernel_debug) {
      VLOG(3) << "fluid xpu_kp using debug mode ";
    }
    if (is_xpu_kp_support) {
      expected_kernel_key.library_type_ = paddle::framework::LibraryType::kKP;
      kernel_iter = kernels.find(expected_kernel_key);
      VLOG(3) << "using fluid XPU KP kernel: " << op.Type()
              << ", using_kernel_key:" << expected_kernel_key;
    }
    if (!is_xpu_kp_support &&
        (kernel_iter == kernels.end() || is_xpu_unsupport)) {
      VLOG(3) << "fluid missing XPU kernel: " << op.Type()
              << ", expected_kernel_key:" << expected_kernel_key
              << ", fallbacking to CPU one!";
      expected_kernel_key.place_ = platform::CPUPlace();
      kernel_iter = kernels.find(expected_kernel_key);
    }
  }
#endif

#ifdef PADDLE_WITH_ASCEND_CL
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_npu_place(expected_kernel_key.place_)) {
    VLOG(3) << "missing NPU kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
#ifdef PADDLE_WITH_IPU
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_ipu_place(expected_kernel_key.place_)) {
    VLOG(3) << "missing IPU kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
#ifdef PADDLE_WITH_MLU
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_mlu_place(expected_kernel_key.place_)) {
    VLOG(3) << "missing MLU kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_custom_place(expected_kernel_key.place_)) {
    VLOG(3) << "missing " << place.GetDeviceType() << " kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    expected_kernel_key.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
  // TODO(jiabin): Add operator.cc's line 1000 part back when we need that
  // case
  PADDLE_ENFORCE_NE(
      kernel_iter,
      kernels.end(),
      platform::errors::NotFound("Operator %s does not have kernel for %s.",
                                 op.Type(),
                                 KernelTypeToString(expected_kernel_key)));

  if (!(expected_kernel_key.place_ == place)) {
    dev_ctx = pool.Get(expected_kernel_key.place_);
  }

  return PreparedOp(op,
                    empty_ctx,
                    expected_kernel_key,
                    kernel_iter->second,
                    arg_map_fn,
                    default_kernel_signature,
                    dev_ctx);
}

525 526 527 528
// Prepare overload for VarBase inputs/outputs (legacy dygraph mode); forwards
// to PrepareImpl with the cached phi singletons.
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                               const NameVarMap<VarBase>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
  return PrepareImpl<VarBase>(ins,
                              outs,
                              op,
                              place,
                              attrs,
                              default_attrs,
                              phi_kernel_factory,
                              phi_op_utils_map,
                              default_phi_kernel_sig_map);
}

// Prepare overload for VariableWrapper inputs/outputs; forwards to
// PrepareImpl with the cached phi singletons.
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                               const NameVarMap<VariableWrapper>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
  return PrepareImpl<VariableWrapper>(ins,
                                      outs,
                                      op,
                                      place,
                                      attrs,
                                      default_attrs,
                                      phi_kernel_factory,
                                      phi_op_utils_map,
                                      default_phi_kernel_sig_map);
}

559 560
// Prepare overload for egr::EagerVariable inputs/outputs (eager mode);
// forwards to PrepareImpl with the cached phi singletons.
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
                               const NameVarMap<egr::EagerVariable>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
  return PrepareImpl<egr::EagerVariable>(ins,
                                         outs,
                                         op,
                                         place,
                                         attrs,
                                         default_attrs,
                                         phi_kernel_factory,
                                         phi_op_utils_map,
                                         default_phi_kernel_sig_map);
}
575 576
// Executes a prepared op via the fluid kernel path: infer shape, invoke the
// kernel functor, then run the optional NaN/Inf check, benchmark sync, and
// complex-to-real grad conversion.
template <typename VarType>
static void PreparedOpRunImpl(
    const framework::OperatorBase& op,
    const framework::RuntimeContext& ctx,
    const framework::OpKernelType& kernel_type,
    const framework::OperatorWithKernel::OpKernelFunc& func,
    const phi::ArgumentMappingFn* arg_map_fn,
    const phi::KernelSignature* default_kernel_signature,
    platform::DeviceContext* dev_ctx,
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs) {
  // TODO(zjl): remove scope in dygraph

  // Phase 1: infer output shapes, wrapped in a profiler event.
  {
    platform::RecordEvent record_event("infer_shape",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);
    DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
                                                      &outs,
                                                      &attrs,
                                                      &default_attrs,
                                                      op.Type(),
                                                      &kernel_type,
                                                      arg_map_fn,
                                                      default_kernel_signature);
    op.Info().infer_shape_(&infer_shape_ctx);
    record_event.End();
    platform::RecordOpInfoSupplement(
        op.Type(), op.Attrs(), infer_shape_ctx, ctx);
  }

  // Phase 2: run the fluid kernel functor.
  {
    platform::RecordEvent record_event("compute",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);

    func(DygraphExecutionContext<VarType>(
        op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs));
  }

  if (FLAGS_check_nan_inf) {
    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
        op.Type(), outs, dev_ctx->GetPlace());
  }

  // Benchmark mode synchronizes the device so per-op timings are accurate.
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
  }

  /**
   * [ Why need handle complex gradient to real gradient? ]
   *
   * After the introduction of complex number calculations, Ops that support
   * complex number calculations generally support type promotion, such as
   * x(float32) + y(complex64) = out(complex64), then the type of the grad
   * tensor should be dout(complex64), dx(float32), dy (complex64).
   *
   * But because the dout is complex64, the dx is also complex64 after
   * grad op kernel executed, we need to recognize this situation and
   * convert dx to float32 type. HandleComplexGradToRealGrad does this thing.
   */
  if (framework::IsComplexType(kernel_type.data_type_)) {
    HandleComplexGradToRealGrad<VarType>(outs);
  }
}
H
hong 已提交
648

649 650 651
// Executes a prepared op via the phi kernel path: infer shape, prepare phi
// inputs, build the phi KernelContext, invoke the kernel, then run the same
// post-run checks as the fluid path (NaN/Inf, benchmark sync, complex grads).
template <typename VarType>
static void PreparedOpRunPtImpl(
    const framework::OperatorBase& op,
    const framework::OpKernelType& kernel_type,
    const phi::ArgumentMappingFn* arg_map_fn,
    const phi::KernelSignature* default_kernel_signature,
    const phi::KernelSignature& kernel_signature,
    const phi::Kernel& phi_kernel,
    platform::DeviceContext* dev_ctx,
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs) {
  // Phase 1: infer output shapes, wrapped in a profiler event.
  {
    platform::RecordEvent record_event("infer_shape",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);
    DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
                                                      &outs,
                                                      &attrs,
                                                      &default_attrs,
                                                      op.Type(),
                                                      &kernel_type,
                                                      arg_map_fn,
                                                      default_kernel_signature);
    op.Info().infer_shape_(&infer_shape_ctx);
    record_event.End();
    platform::RecordOpInfoSupplement(
        op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);
  }

  // Phase 2: build the phi kernel context and run the kernel.
  {
    platform::RecordEvent record_event("compute",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);

    PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);

    phi::KernelContext phi_kernel_context;
    BuildDygraphPhiKernelContext<VarType>(kernel_signature,
                                          phi_kernel,
                                          ins,
                                          outs,
                                          attrs,
                                          default_attrs,
                                          dev_ctx,
                                          &phi_kernel_context);

    phi_kernel(&phi_kernel_context);
  }

  if (FLAGS_check_nan_inf) {
    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
        op.Type(), outs, dev_ctx->GetPlace());
  }

  // Benchmark mode synchronizes the device so per-op timings are accurate.
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
  }

  // See the [ Why need handle complex gradient to real gradient? ] note in
  // PreparedOpRunImpl.
  if (framework::IsComplexType(kernel_type.data_type_)) {
    HandleComplexGradToRealGrad<VarType>(outs);
  }
}

720 721
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                     const NameVarMap<VarBase>& outs,
722 723
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
724
  if (run_phi_kernel_) {
725 726 727 728 729 730 731 732 733 734
    PreparedOpRunPtImpl<VarBase>(op_,
                                 kernel_type_,
                                 arg_map_fn_,
                                 default_kernel_signature_,
                                 kernel_signature_,
                                 phi_kernel_,
                                 dev_ctx_,
                                 ins,
                                 outs,
                                 attrs,
735
                                 default_attrs);
736
  } else {
737 738 739 740 741 742 743 744 745 746 747
    PreparedOpRunImpl<VarBase>(op_,
                               ctx_,
                               kernel_type_,
                               func_,
                               arg_map_fn_,
                               default_kernel_signature_,
                               dev_ctx_,
                               ins,
                               outs,
                               attrs,
                               default_attrs);
748
  }
749
}
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                     const NameVarMap<VariableWrapper>& outs,
753 754
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
755
  if (run_phi_kernel_) {
756 757 758 759 760 761 762 763 764 765 766
    PreparedOpRunPtImpl<VariableWrapper>(op_,
                                         kernel_type_,
                                         arg_map_fn_,
                                         default_kernel_signature_,
                                         kernel_signature_,
                                         phi_kernel_,
                                         dev_ctx_,
                                         ins,
                                         outs,
                                         attrs,
                                         default_attrs);
767
  } else {
768 769 770 771 772 773 774 775 776 777 778
    PreparedOpRunImpl<VariableWrapper>(op_,
                                       ctx_,
                                       kernel_type_,
                                       func_,
                                       arg_map_fn_,
                                       default_kernel_signature_,
                                       dev_ctx_,
                                       ins,
                                       outs,
                                       attrs,
                                       default_attrs);
779
  }
J
Jiabin Yang 已提交
780 781
}

void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
                     const NameVarMap<egr::EagerVariable>& outs,
J
Jiabin Yang 已提交
784 785
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
786
  if (run_phi_kernel_) {
787 788 789 790 791 792 793 794 795 796 797
    PreparedOpRunPtImpl<egr::EagerVariable>(op_,
                                            kernel_type_,
                                            arg_map_fn_,
                                            default_kernel_signature_,
                                            kernel_signature_,
                                            phi_kernel_,
                                            dev_ctx_,
                                            ins,
                                            outs,
                                            attrs,
                                            default_attrs);
J
Jiabin Yang 已提交
798
  } else {
799 800 801 802 803 804 805 806 807 808 809
    PreparedOpRunImpl<egr::EagerVariable>(op_,
                                          ctx_,
                                          kernel_type_,
                                          func_,
                                          arg_map_fn_,
                                          default_kernel_signature_,
                                          dev_ctx_,
                                          ins,
                                          outs,
                                          attrs,
                                          default_attrs);
J
Jiabin Yang 已提交
810 811 812
  }
}

}  // namespace imperative
}  // namespace paddle