prepared_operator.cc 32.0 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/prepared_operator.h"
16

J
Jiabin Yang 已提交
17
#include "paddle/fluid/eager/eager_tensor.h"
18
#include "paddle/fluid/framework/data_type_transform.h"
19
#include "paddle/fluid/framework/details/nan_inf_utils.h"
20
#include "paddle/fluid/imperative/infer_shape_context.h"
21
#include "paddle/fluid/imperative/tracer.h"
22
#include "paddle/phi/common/int_array.h"
23
#include "paddle/phi/common/scalar.h"
24
#include "paddle/utils/small_vector.h"
Q
QingshuChen 已提交
25
#ifdef PADDLE_WITH_XPU
26
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
Q
QingshuChen 已提交
27
#endif
28 29 30
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif
L
Liu-xiandong 已提交
31
#include "paddle/fluid/framework/library_type.h"
32
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
33
#include "paddle/fluid/platform/place.h"
C
chenjian 已提交
34
#include "paddle/fluid/platform/profiler/event_tracing.h"
C
chenjian 已提交
35
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
36

37
DECLARE_bool(check_nan_inf);
38
DECLARE_bool(benchmark);
F
Feng Xing 已提交
39
DECLARE_bool(run_kp_kernel);
40

J
Jiabin Yang 已提交
41 42 43
namespace paddle {
namespace imperative {

44
static const phi::Kernel empty_kernel;
45 46
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;
47

48 49 50 51 52 53 54
const phi::KernelFactory& PreparedOp::phi_kernel_factory =
    phi::KernelFactory::Instance();
const phi::OpUtilsMap& PreparedOp::phi_op_utils_map =
    phi::OpUtilsMap::Instance();
const phi::DefaultKernelSignatureMap& PreparedOp::default_phi_kernel_sig_map =
    phi::DefaultKernelSignatureMap::Instance();

55 56 57 58 59 60 61 62 63 64
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<paddle::imperative::VarBase>& var) {
  return var->SharedVar();
}

const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<VariableWrapper>& var) {
  return var;
}

65
const phi::DenseTensor* GetTensorFromVar(const framework::Variable& var) {
66 67
  if (var.IsType<phi::DenseTensor>()) {
    return &(var.Get<phi::DenseTensor>());
68 69
  } else if (var.IsType<phi::SelectedRows>()) {
    return &(var.Get<phi::SelectedRows>().value());
J
Jiabin Yang 已提交
70 71 72 73 74
  } else {
    return nullptr;
  }
}

75
template <typename VarType>
J
Jiabin Yang 已提交
76
void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
  for (auto& pair : outs) {
    for (auto& var : pair.second) {
      if (var == nullptr) {
        continue;
      }
      if (var->ForwardDataType() ==
          static_cast<framework::proto::VarType::Type>(-1)) {
        VLOG(6) << "Var (" << var->Name()
                << ")'s forward data type is not set.";
        continue;
      }
      if (!framework::IsComplexType(var->DataType()) ||
          framework::IsComplexType(var->ForwardDataType())) {
        continue;
      }
      const auto* tensor = GetTensorFromVar(var->Var());
J
Jiabin Yang 已提交
93
      if (tensor && tensor->IsInitialized()) {
94 95 96 97
        VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType())
                << " var `" << var->Name() << "` to "
                << framework::DataTypeToString(var->ForwardDataType())
                << " real var in dynamic graph.";
98
        phi::DenseTensor out;
99 100
        framework::TransComplexToReal(
            var->ForwardDataType(), var->DataType(), *tensor, &out);
101
        SetTensorToVariable(var->Var(), out, var->MutableVar());
J
Jiabin Yang 已提交
102 103 104 105 106
      }
    }
  }
}

J
Jiabin Yang 已提交
107
template <>
108 109
void HandleComplexGradToRealGrad<egr::EagerVariable>(
    const NameVarMap<egr::EagerVariable>& outs) {
J
Jiabin Yang 已提交
110 111 112
  // TODO(jiabin): Support Complex here.
}

113 114 115 116 117
void TestHandleComplexGradToRealGradEager(
    const NameVarMap<egr::EagerVariable>& outs) {
  HandleComplexGradToRealGrad<egr::EagerVariable>(outs);
}

J
Jiabin Yang 已提交
118 119
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
120
                       const phi::KernelKey& kernel_key,
121
                       const framework::OperatorWithKernel::OpKernelFunc& func,
122 123
                       const phi::ArgumentMappingFn* arg_map_fn,
                       const phi::KernelSignature* default_kernel_signature,
124
                       platform::DeviceContext* dev_ctx)
125 126
    : op_(op),
      ctx_(ctx),
127
      kernel_key_(kernel_key),
128
      func_(func),
129
      dev_ctx_(dev_ctx),
130 131 132
      arg_map_fn_(arg_map_fn),
      default_kernel_signature_(default_kernel_signature),
      phi_kernel_(empty_kernel) {}
133

134 135
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
136
                       const phi::KernelKey& kernel_key,
137 138 139 140
                       const phi::ArgumentMappingFn* arg_map_fn,
                       const phi::KernelSignature* default_kernel_signature,
                       phi::KernelSignature&& kernel_signature,
                       const phi::Kernel& phi_kernel,
141 142 143
                       platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
144
      kernel_key_(kernel_key),
145 146
      func_(nullptr),
      dev_ctx_(dev_ctx),
147
      run_phi_kernel_(true),
148 149 150 151
      arg_map_fn_(arg_map_fn),
      default_kernel_signature_(default_kernel_signature),
      kernel_signature_(std::move(kernel_signature)),
      phi_kernel_(phi_kernel) {}
152

153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
#ifdef PADDLE_WITH_MLU

static void tokenize(const std::string& ops,
                     char delim,
                     std::unordered_set<std::string>* op_set) {
  std::string::size_type beg = 0;
  for (uint64_t end = 0; (end = ops.find(delim, end)) != std::string::npos;
       ++end) {
    op_set->insert(ops.substr(beg, end - beg));
    beg = end + 1;
  }

  op_set->insert(ops.substr(beg));
}

static bool is_in_mlu_black_list(const std::string& op_name) {
  static bool inited = false;
  static std::unordered_set<std::string> mlu_black_list;
  static std::mutex s_mtx;
  if (!inited) {
    std::lock_guard<std::mutex> guard(s_mtx);
    if (!inited) {
      if (std::getenv("MLU_BLACK_LIST") != nullptr) {
        std::string ops(std::getenv("MLU_BLACK_LIST"));
        tokenize(ops, ',', &mlu_black_list);
      }
      inited = true;
      VLOG(3) << "MLU Black List: ";
      for (auto iter = mlu_black_list.begin(); iter != mlu_black_list.end();
           ++iter) {
        VLOG(3) << *iter << " ";
      }
    }
  }
  if (mlu_black_list.find(op_name) != mlu_black_list.end()) {
    return true;
  }
  return false;
}

#endif

195
template <typename VarType>
196
PreparedOp PrepareImpl(
197 198 199 200
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::OperatorWithKernel& op,
    const platform::Place& place,
201 202 203 204 205
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    const phi::KernelFactory& phi_kernel_factory,
    const phi::OpUtilsMap& phi_op_utils_map,
    const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
206
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
207
  auto* dev_ctx = pool.Get(place);
208

209 210 211 212 213 214
#ifdef PADDLE_WITH_MKLDNN
  // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
  // GetKernelType functions, so we need to copy the attributes there.
  // Const qualifier of Attrs had to be discarded to overwrite it.
  if (FLAGS_use_mkldnn) {
    auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
215 216 217 218
    mutable_op_attrs = default_attrs;
    for (auto& attr : attrs) {
      mutable_op_attrs[attr.first] = attr.second;
    }
219 220
  }
#endif
221 222
  // NOTE(zhiqiu): for kernels on given device, for example NPU, the order to
  // choose is:
223
  // phi npu kernel > fluid npu kernel > phi cpu kernel > fluid cpu kernel
J
Jiabin Yang 已提交
224

225
  // 1. get expected kernel key
226
  auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
227
      op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
228
  auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);
229

230 231
  const phi::KernelSignature* default_kernel_signature = nullptr;
  phi::KernelSignature kernel_signature;
232
  std::string phi_kernel_name;
233 234

// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
235
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
236
// values are kPlain, so we need to modify the library_type and data_layout_
237 238 239 240
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel fallbacks to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
241
#ifdef PADDLE_WITH_MKLDNN
242
  if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) &&
243 244 245
      op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) {
    expected_kernel_key.set_backend(phi::Backend::ONEDNN);
    expected_kernel_key.set_layout(phi::DataLayout::ONEDNN);
246 247 248
  }
#endif

249
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
250 251
  if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) {
    expected_kernel_key.set_backend(phi::Backend::GPUDNN);
252 253 254
  }
#endif

L
Liu-xiandong 已提交
255
#if defined(PADDLE_WITH_XPU)
256 257 258
  bool is_xpu_unsupport = expected_kernel_key.backend() == phi::Backend::XPU &&
                          !paddle::platform::is_xpu_support_op(
                              op.Type(), expected_kernel_key.dtype());
259
#endif
260

261 262
#ifdef PADDLE_WITH_MLU
  if (is_in_mlu_black_list(op.Type())) {
263
    expected_kernel_key.set_backend(phi::Backend::CPU);
264 265 266
  }
#endif

267 268
  bool has_phi_kernel = false;

269 270
  const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());

271 272
  if (arg_map_fn) {
    has_phi_kernel = true;
273
    kernel_signature = (*arg_map_fn)(
274 275
        framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
  } else {
276
    if (phi::KernelFactory::Instance().HasStructuredKernel(op.Type())) {
277
      has_phi_kernel = true;
278 279 280 281 282 283 284 285
      kernel_signature = phi::KernelSignature(op.Type().c_str());
    } else {
      default_kernel_signature =
          default_phi_kernel_sig_map.GetNullable(op.Type());
      if (default_kernel_signature) {
        has_phi_kernel = true;
        kernel_signature = *default_kernel_signature;
      }
286 287
    }
  }
288

289
  if (has_phi_kernel) {
290
    VLOG(6) << kernel_signature;
291
    phi_kernel_name = kernel_signature.name;
292

293 294 295
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
L
Liu-xiandong 已提交
296
#ifdef PADDLE_WITH_XPU_KP
297
    if (expected_kernel_key.backend() == phi::Backend::XPU) {
L
Liu-xiandong 已提交
298
      bool use_xpu_kp_kernel_rt =
299 300
          FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op(
                                     op.Type(), expected_kernel_key.dtype());
L
Liu-xiandong 已提交
301 302 303 304 305 306 307 308 309 310 311
      bool use_xpu_kp_kernel_debug =
          paddle::platform::is_in_xpu_kpwhite_list(op.Type());
      if (use_xpu_kp_kernel_rt) {
        VLOG(3) << "phi xpu_kp using rt mode ";
      }
      if (use_xpu_kp_kernel_debug) {
        VLOG(3) << "phi xpu_kp using debug mode ";
      }
      bool is_xpu_kp_support =
          (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
      if (is_xpu_kp_support) {
312 313
        auto expected_kernel_key_backend = expected_kernel_key.backend();
        expected_kernel_key.set_backend(phi::Backend::KPS);
314
        VLOG(3) << "modifing XPU KP kernel: " << phi_kernel_name
L
Liu-xiandong 已提交
315
                << ", using_kernel_key:" << expected_kernel_key;
316

317
        if (!phi_kernel_factory.HasKernel(phi_kernel_name,
318 319
                                          expected_kernel_key)) {
          expected_kernel_key.set_backend(expected_kernel_key_backend);
320
          VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name
321 322
                  << " in dynamic graph is failed " << expected_kernel_key;
        } else {
323
          VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name
324
                  << " in dynamic graph is succeed " << expected_kernel_key;
325
        }
L
Liu-xiandong 已提交
326 327 328
      }
    }
#endif
329

330
    auto& phi_kernel =
331
        phi_kernel_factory.SelectKernel(phi_kernel_name, expected_kernel_key);
332

333
    if (phi_kernel.IsValid()
L
Liu-xiandong 已提交
334
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
335 336
        && !is_xpu_unsupport
#endif
337
    ) {
338
      VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name
339
              << " | kernel key: " << expected_kernel_key
340
              << " | kernel: " << phi_kernel;
341

342 343 344 345
      if (!framework::backends_are_same_class(
              expected_kernel_key.backend(),
              phi::TransToPhiBackend(dev_ctx->GetPlace()))) {
        dev_ctx = pool.Get(phi::TransToPhiPlace(expected_kernel_key.backend()));
W
Wilber 已提交
346
      }
347 348 349 350 351 352 353 354
      return PreparedOp(op,
                        empty_ctx,
                        expected_kernel_key,
                        arg_map_fn,
                        default_kernel_signature,
                        std::move(kernel_signature),
                        phi_kernel,
                        dev_ctx);
355
    } else {
356
      VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << phi_kernel_name
357 358 359 360
              << "` not found.";
    }
  }

361
  // 2. check if op[type] has kernel registered.
J
Jiabin Yang 已提交
362 363
  auto& all_op_kernels = op.AllOpKernels();
  auto kernels_iter = all_op_kernels.find(op.Type());
364

365 366 367
// NOTE(Liu-xiandong): If we can't find heterogeneous kernel in phi,
// we need to select the heterogeneous kernel in fluid, but the kernel
// registered in KP use library_type[KP], we need to modify it.
368 369
#ifdef PADDLE_WITH_XPU_KP
  bool use_xpu_kp_kernel_rt =
370
      expected_kernel_key.backend() == phi::Backend::XPU &&
371
      FLAGS_run_kp_kernel &&
372 373
      paddle::platform::is_xpu_support_op(op.Type(),
                                          expected_kernel_key.dtype());
374
  bool use_xpu_kp_kernel_debug =
375
      expected_kernel_key.backend() == phi::Backend::XPU &&
376 377 378
      paddle::platform::is_in_xpu_kpwhite_list(op.Type());
  bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
  if (is_xpu_kp_support) {
379
    expected_kernel_key.set_backend(phi::Backend::KPS);
380 381 382
  }
#endif

383 384
  paddle::framework::OpKernelType fluid_kernel_type =
      paddle::framework::TransPhiKernelKeyToOpKernelType(expected_kernel_key);
385
  if ((kernels_iter == all_op_kernels.end() ||
386
       kernels_iter->second.find(fluid_kernel_type) ==
387
           kernels_iter->second.end())
L
Liu-xiandong 已提交
388
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
389
      || is_xpu_unsupport
390
#endif
391 392 393
#if defined(PADDLE_WITH_XPU_KP)
      || (is_xpu_unsupport && !is_xpu_kp_support)
#endif
394
  ) {
395
    if (has_phi_kernel) {
396
      auto phi_cpu_kernel_key = FallBackToCpu(expected_kernel_key, op);
397 398 399 400 401 402
      auto& phi_cpu_kernel =
          phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key);
      if (phi_cpu_kernel.IsValid()) {
        VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name
                << " | kernel key: " << phi_cpu_kernel_key
                << " | kernel: " << phi_cpu_kernel;
403
        auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
404 405 406 407 408 409 410 411
        return PreparedOp(op,
                          empty_ctx,
                          phi_cpu_kernel_key,
                          arg_map_fn,
                          default_kernel_signature,
                          std::move(kernel_signature),
                          phi_cpu_kernel,
                          cpu_ctx);
412 413 414 415
      }
    }
  }

416
  PADDLE_ENFORCE_NE(
417 418
      kernels_iter,
      all_op_kernels.end(),
419 420 421
      platform::errors::NotFound(
          "There are no kernels which are registered in the %s operator.",
          op.Type()));
422

J
Jiabin Yang 已提交
423
  auto& kernels = kernels_iter->second;
424
  auto kernel_iter = kernels.find(fluid_kernel_type);
425

426
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
427
  if (paddle::platform::is_xpu_place(fluid_kernel_type.place_) &&
428
      (kernel_iter == kernels.end() || is_xpu_unsupport)) {
429
    VLOG(3) << "fluid missing XPU kernel: " << op.Type()
430
            << ", expected_kernel_key:" << fluid_kernel_type
431
            << ", fallbacking to CPU one!";
432 433
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
434
  }
435
#endif
L
Liu-xiandong 已提交
436 437

#ifdef PADDLE_WITH_XPU_KP
438
  if (paddle::platform::is_xpu_place(fluid_kernel_type.place_)) {
439
    if (use_xpu_kp_kernel_rt) {
440
      VLOG(3) << "fluid xpu_kp using rt mode ";
441 442
    }
    if (use_xpu_kp_kernel_debug) {
443
      VLOG(3) << "fluid xpu_kp using debug mode ";
444 445
    }
    if (is_xpu_kp_support) {
446 447
      fluid_kernel_type.library_type_ = paddle::framework::LibraryType::kKP;
      kernel_iter = kernels.find(fluid_kernel_type);
448
      VLOG(3) << "using fluid XPU KP kernel: " << op.Type()
449
              << ", using_kernel_key:" << fluid_kernel_type;
450 451 452
    }
    if (!is_xpu_kp_support &&
        (kernel_iter == kernels.end() || is_xpu_unsupport)) {
453
      VLOG(3) << "fluid missing XPU kernel: " << op.Type()
454
              << ", expected_kernel_key:" << fluid_kernel_type
455
              << ", fallbacking to CPU one!";
456 457
      fluid_kernel_type.place_ = platform::CPUPlace();
      kernel_iter = kernels.find(fluid_kernel_type);
458
    }
L
Liu-xiandong 已提交
459 460 461
  }
#endif

462 463
#ifdef PADDLE_WITH_ASCEND_CL
  if (kernel_iter == kernels.end() &&
464
      paddle::platform::is_npu_place(fluid_kernel_type.place_)) {
465
    VLOG(3) << "missing NPU kernel: " << op.Type()
466
            << ", expected_kernel_key:" << fluid_kernel_type
467
            << ", fallbacking to CPU one!";
468 469
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
470 471 472 473
  }
#endif
#ifdef PADDLE_WITH_IPU
  if (kernel_iter == kernels.end() &&
474
      paddle::platform::is_ipu_place(fluid_kernel_type.place_)) {
475
    VLOG(3) << "missing IPU kernel: " << op.Type()
476
            << ", expected_kernel_key:" << fluid_kernel_type
477
            << ", fallbacking to CPU one!";
478 479
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
480
  }
481 482 483
#endif
#ifdef PADDLE_WITH_MLU
  if (kernel_iter == kernels.end() &&
484
      paddle::platform::is_mlu_place(fluid_kernel_type.place_)) {
485
    VLOG(3) << "missing MLU kernel: " << op.Type()
486
            << ", expected_kernel_key:" << fluid_kernel_type
487
            << ", fallbacking to CPU one!";
488 489
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
490
  }
491 492 493
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (kernel_iter == kernels.end() &&
494
      paddle::platform::is_custom_place(fluid_kernel_type.place_)) {
495 496 497
    VLOG(3) << "missing " << place.GetDeviceType() << " kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
498 499
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
500
  }
501
#endif
502 503
  // TODO(jiabin): Add operator.cc's line 1000 part back when we need that
  // case
504 505 506 507 508
  PADDLE_ENFORCE_NE(
      kernel_iter,
      kernels.end(),
      platform::errors::NotFound("Operator %s does not have kernel for %s.",
                                 op.Type(),
509 510 511 512 513 514 515 516 517 518 519 520 521 522
                                 KernelTypeToString(fluid_kernel_type)));

  if (!platform::places_are_same_class(fluid_kernel_type.place_,
                                       dev_ctx->GetPlace())) {
    dev_ctx = pool.Get(fluid_kernel_type.place_);
  }
  return PreparedOp(
      op,
      empty_ctx,
      framework::TransOpKernelTypeToPhiKernelKey(fluid_kernel_type),
      kernel_iter->second,
      arg_map_fn,
      default_kernel_signature,
      dev_ctx);
523 524
}

525 526 527 528
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                               const NameVarMap<VarBase>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
529
                               const framework::AttributeMap& attrs,
530
                               const framework::AttributeMap& default_attrs) {
531 532 533 534 535 536 537 538
  return PrepareImpl<VarBase>(ins,
                              outs,
                              op,
                              place,
                              attrs,
                              default_attrs,
                              phi_kernel_factory,
                              phi_op_utils_map,
539
                              default_phi_kernel_sig_map);
540 541 542 543 544 545
}

PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                               const NameVarMap<VariableWrapper>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
546
                               const framework::AttributeMap& attrs,
547
                               const framework::AttributeMap& default_attrs) {
548 549 550 551 552 553 554 555 556
  return PrepareImpl<VariableWrapper>(ins,
                                      outs,
                                      op,
                                      place,
                                      attrs,
                                      default_attrs,
                                      phi_kernel_factory,
                                      phi_op_utils_map,
                                      default_phi_kernel_sig_map);
557 558
}

559 560
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
                               const NameVarMap<egr::EagerVariable>& outs,
J
Jiabin Yang 已提交
561 562 563 564
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
565 566 567 568 569 570 571 572 573
  return PrepareImpl<egr::EagerVariable>(ins,
                                         outs,
                                         op,
                                         place,
                                         attrs,
                                         default_attrs,
                                         phi_kernel_factory,
                                         phi_op_utils_map,
                                         default_phi_kernel_sig_map);
J
Jiabin Yang 已提交
574
}
575 576
template <typename VarType>
static void PreparedOpRunImpl(
577 578
    const framework::OperatorBase& op,
    const framework::RuntimeContext& ctx,
579
    const phi::KernelKey& kernel_key,
580
    const framework::OperatorWithKernel::OpKernelFunc& func,
581 582
    const phi::ArgumentMappingFn* arg_map_fn,
    const phi::KernelSignature* default_kernel_signature,
583 584 585 586
    platform::DeviceContext* dev_ctx,
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::AttributeMap& attrs,
587
    const framework::AttributeMap& default_attrs) {
J
Jiabin Yang 已提交
588
  // TODO(zjl): remove scope in dygraph
H
hong 已提交
589

590
  {
591
    platform::RecordEvent record_event("infer_shape",
C
chenjian 已提交
592
                                       platform::TracerEventType::OperatorInner,
593 594 595 596 597 598 599
                                       1,
                                       platform::EventRole::kInnerOp);
    DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
                                                      &outs,
                                                      &attrs,
                                                      &default_attrs,
                                                      op.Type(),
600
                                                      &kernel_key,
601 602
                                                      arg_map_fn,
                                                      default_kernel_signature);
603
    op.Info().infer_shape_(&infer_shape_ctx);
C
chenjian 已提交
604 605
    record_event.End();
    platform::RecordOpInfoSupplement(
606
        op.Type(), op.Attrs(), infer_shape_ctx, ctx, op.Id());
607 608 609
  }

  {
610
    platform::RecordEvent record_event("compute",
C
chenjian 已提交
611
                                       platform::TracerEventType::OperatorInner,
612 613
                                       1,
                                       platform::EventRole::kInnerOp);
H
hong 已提交
614

615 616
    func(DygraphExecutionContext<VarType>(
        op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs));
617
  }
618

619 620 621 622 623
  if (FLAGS_check_nan_inf) {
    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
        op.Type(), outs, dev_ctx->GetPlace());
  }

L
Leo Chen 已提交
624 625 626 627 628 629 630 631
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
  }

632 633 634 635 636 637 638 639 640 641 642 643
  /**
   * [ Why need handle complex gradient to real gradient? ]
   *
   * After the introduction of complex number calculations, Ops that support
   * complex number calculations generally support type promotion, such as
   * x(float32) + y(complex64) = out(complex64), then the type of the grad
   * tensor should be dout(complex64), dx(float32), dy (complex64).
   *
   * But because the dout is complex64, the dx is also complex64 after
   * grad op kernel executed, we need to recognize this situation and
   * convert dx to float32 type. HandleComplexGradToRealGrad does this thing.
   */
644
  if (framework::IsComplexType(kernel_key.dtype())) {
645 646
    HandleComplexGradToRealGrad<VarType>(outs);
  }
647
}
H
hong 已提交
648

649 650 651
template <typename VarType>
static void PreparedOpRunPtImpl(
    const framework::OperatorBase& op,
652
    const phi::KernelKey& kernel_key,
653 654
    const phi::ArgumentMappingFn* arg_map_fn,
    const phi::KernelSignature* default_kernel_signature,
655 656
    const phi::KernelSignature& kernel_signature,
    const phi::Kernel& phi_kernel,
657
    const framework::RuntimeContext& ctx,
658 659 660 661
    platform::DeviceContext* dev_ctx,
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::AttributeMap& attrs,
662
    const framework::AttributeMap& default_attrs) {
663
  {
664
    platform::RecordEvent record_event("infer_shape",
C
chenjian 已提交
665
                                       platform::TracerEventType::OperatorInner,
666 667 668 669 670 671 672
                                       1,
                                       platform::EventRole::kInnerOp);
    DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
                                                      &outs,
                                                      &attrs,
                                                      &default_attrs,
                                                      op.Type(),
673
                                                      &kernel_key,
674 675
                                                      arg_map_fn,
                                                      default_kernel_signature);
676
    op.Info().infer_shape_(&infer_shape_ctx);
C
chenjian 已提交
677 678 679
    record_event.End();
    platform::RecordOpInfoSupplement(
        op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);
680 681 682
  }

  {
683
    platform::RecordEvent record_event("compute",
C
chenjian 已提交
684
                                       platform::TracerEventType::OperatorInner,
685 686
                                       1,
                                       platform::EventRole::kInnerOp);
687

688 689 690 691 692 693 694 695 696 697 698 699
    if (phi_kernel.GetKernelRegisteredType() ==
        phi::KernelRegisteredType::FUNCTION) {
      PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);
      phi::KernelContext phi_kernel_context;
      BuildDygraphPhiKernelContext<VarType>(kernel_signature,
                                            phi_kernel,
                                            ins,
                                            outs,
                                            attrs,
                                            default_attrs,
                                            dev_ctx,
                                            &phi_kernel_context);
700

701 702 703 704 705 706
      phi_kernel(&phi_kernel_context);
    } else {
      DygraphExecutionContext<VarType> exe_ctx(
          op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs);
      phi_kernel(&exe_ctx);
    }
707
  }
708

709 710 711 712 713
  if (FLAGS_check_nan_inf) {
    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
        op.Type(), outs, dev_ctx->GetPlace());
  }

714 715
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
716 717
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
718 719 720 721
    VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
  }

722
  if (framework::IsComplexType(kernel_key.dtype())) {
723 724
    HandleComplexGradToRealGrad<VarType>(outs);
  }
725 726
}

727 728
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                     const NameVarMap<VarBase>& outs,
729 730
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
731
  if (run_phi_kernel_) {
732
    PreparedOpRunPtImpl<VarBase>(op_,
733
                                 kernel_key_,
734 735 736 737
                                 arg_map_fn_,
                                 default_kernel_signature_,
                                 kernel_signature_,
                                 phi_kernel_,
738
                                 ctx_,
739 740 741 742
                                 dev_ctx_,
                                 ins,
                                 outs,
                                 attrs,
743
                                 default_attrs);
744
  } else {
745 746
    PreparedOpRunImpl<VarBase>(op_,
                               ctx_,
747
                               kernel_key_,
748 749 750 751 752 753 754 755
                               func_,
                               arg_map_fn_,
                               default_kernel_signature_,
                               dev_ctx_,
                               ins,
                               outs,
                               attrs,
                               default_attrs);
756
  }
757
}
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                     const NameVarMap<VariableWrapper>& outs,
761 762
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
763
  if (run_phi_kernel_) {
764
    PreparedOpRunPtImpl<VariableWrapper>(op_,
765
                                         kernel_key_,
766 767 768 769
                                         arg_map_fn_,
                                         default_kernel_signature_,
                                         kernel_signature_,
                                         phi_kernel_,
770
                                         ctx_,
771 772 773 774 775
                                         dev_ctx_,
                                         ins,
                                         outs,
                                         attrs,
                                         default_attrs);
776
  } else {
777 778
    PreparedOpRunImpl<VariableWrapper>(op_,
                                       ctx_,
779
                                       kernel_key_,
780 781 782 783 784 785 786 787
                                       func_,
                                       arg_map_fn_,
                                       default_kernel_signature_,
                                       dev_ctx_,
                                       ins,
                                       outs,
                                       attrs,
                                       default_attrs);
788
  }
J
Jiabin Yang 已提交
789 790
}

void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
                     const NameVarMap<egr::EagerVariable>& outs,
J
Jiabin Yang 已提交
793 794
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
795
  if (run_phi_kernel_) {
796
    PreparedOpRunPtImpl<egr::EagerVariable>(op_,
797
                                            kernel_key_,
798 799 800 801
                                            arg_map_fn_,
                                            default_kernel_signature_,
                                            kernel_signature_,
                                            phi_kernel_,
802
                                            ctx_,
803 804 805 806 807
                                            dev_ctx_,
                                            ins,
                                            outs,
                                            attrs,
                                            default_attrs);
J
Jiabin Yang 已提交
808
  } else {
809 810
    PreparedOpRunImpl<egr::EagerVariable>(op_,
                                          ctx_,
811
                                          kernel_key_,
812 813 814 815 816 817 818 819
                                          func_,
                                          arg_map_fn_,
                                          default_kernel_signature_,
                                          dev_ctx_,
                                          ins,
                                          outs,
                                          attrs,
                                          default_attrs);
J
Jiabin Yang 已提交
820 821 822
  }
}

}  // namespace imperative
}  // namespace paddle