// prepared_operator.cc
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/fluid/imperative/prepared_operator.h"
16

J
Jiabin Yang 已提交
17
#include "paddle/fluid/eager/eager_tensor.h"
18
#include "paddle/fluid/framework/data_type_transform.h"
19
#include "paddle/fluid/framework/details/nan_inf_utils.h"
20
#include "paddle/fluid/imperative/infer_shape_context.h"
21
#include "paddle/fluid/imperative/tracer.h"
22
#include "paddle/phi/common/int_array.h"
23
#include "paddle/phi/common/scalar.h"
24
#include "paddle/utils/small_vector.h"
Q
QingshuChen 已提交
25
#ifdef PADDLE_WITH_XPU
26
#include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
Q
QingshuChen 已提交
27
#endif
28 29 30
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_op_list.h"
#endif
L
Liu-xiandong 已提交
31
#include "paddle/fluid/framework/library_type.h"
32
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
33
#include "paddle/fluid/platform/place.h"
C
chenjian 已提交
34
#include "paddle/fluid/platform/profiler/event_tracing.h"
C
chenjian 已提交
35
#include "paddle/fluid/platform/profiler/supplement_tracing.h"
36

37
DECLARE_bool(check_nan_inf);
38
DECLARE_bool(benchmark);
F
Feng Xing 已提交
39
DECLARE_bool(run_kp_kernel);
40

J
Jiabin Yang 已提交
41 42 43
namespace paddle {
namespace imperative {

44
// Shared fallback objects. A PreparedOp built for a fluid kernel still has a
// phi_kernel_ reference member, so it binds to this empty placeholder kernel.
static const phi::Kernel empty_kernel;
// Dygraph execution does not use a real Scope/RuntimeContext; these empty
// singletons are passed wherever the framework API requires one.
static const framework::RuntimeContext empty_ctx({}, {});
static const framework::Scope empty_scope;

// Cached references to the global phi singletons used by kernel selection.
const phi::KernelFactory& PreparedOp::phi_kernel_factory =
    phi::KernelFactory::Instance();
const phi::OpUtilsMap& PreparedOp::phi_op_utils_map =
    phi::OpUtilsMap::Instance();
const phi::DefaultKernelSignatureMap& PreparedOp::default_phi_kernel_sig_map =
    phi::DefaultKernelSignatureMap::Instance();

55 56 57 58 59 60 61 62 63 64
// Returns the VariableWrapper owned by a dygraph VarBase.
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<paddle::imperative::VarBase>& var) {
  return var->SharedVar();
}

// Identity overload: a VariableWrapper is already the wrapper itself.
const std::shared_ptr<VariableWrapper>& GetVariableWrapper(
    const std::shared_ptr<VariableWrapper>& var) {
  return var;
}

65
// Returns a pointer to the dense tensor stored in `var`: the tensor itself
// for a DenseTensor variable, the value tensor for a SelectedRows variable,
// and nullptr for any other variable type.
const phi::DenseTensor* GetTensorFromVar(const framework::Variable& var) {
  if (var.IsType<phi::DenseTensor>()) {
    return &(var.Get<phi::DenseTensor>());
  }
  if (var.IsType<phi::SelectedRows>()) {
    return &(var.Get<phi::SelectedRows>().value());
  }
  return nullptr;
}

75
template <typename VarType>
J
Jiabin Yang 已提交
76
void HandleComplexGradToRealGrad(const NameVarMap<VarType>& outs) {
77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92
  for (auto& pair : outs) {
    for (auto& var : pair.second) {
      if (var == nullptr) {
        continue;
      }
      if (var->ForwardDataType() ==
          static_cast<framework::proto::VarType::Type>(-1)) {
        VLOG(6) << "Var (" << var->Name()
                << ")'s forward data type is not set.";
        continue;
      }
      if (!framework::IsComplexType(var->DataType()) ||
          framework::IsComplexType(var->ForwardDataType())) {
        continue;
      }
      const auto* tensor = GetTensorFromVar(var->Var());
J
Jiabin Yang 已提交
93
      if (tensor && tensor->IsInitialized()) {
94 95 96 97
        VLOG(6) << "Transform " << framework::DataTypeToString(var->DataType())
                << " var `" << var->Name() << "` to "
                << framework::DataTypeToString(var->ForwardDataType())
                << " real var in dynamic graph.";
98
        phi::DenseTensor out;
99 100
        framework::TransComplexToReal(
            var->ForwardDataType(), var->DataType(), *tensor, &out);
101
        SetTensorToVariable(var->Var(), out, var->MutableVar());
J
Jiabin Yang 已提交
102 103 104 105 106
      }
    }
  }
}

J
Jiabin Yang 已提交
107
// Specialization for eager mode: complex-to-real grad conversion is not yet
// implemented for egr::EagerVariable outputs, so this is a deliberate no-op.
template <>
void HandleComplexGradToRealGrad<egr::EagerVariable>(
    const NameVarMap<egr::EagerVariable>& outs) {
  // TODO(jiabin): Support Complex here.
}

113 114 115 116 117
// Exposes the egr::EagerVariable specialization above so unit tests can
// invoke it through a plain (non-template) entry point.
void TestHandleComplexGradToRealGradEager(
    const NameVarMap<egr::EagerVariable>& outs) {
  HandleComplexGradToRealGrad<egr::EagerVariable>(outs);
}

J
Jiabin Yang 已提交
118 119
// Constructs a PreparedOp that runs a fluid kernel: `func` carries the
// selected OpKernelFunc, and phi_kernel_ binds to the shared empty_kernel
// placeholder (run_phi_kernel_ is not set here and keeps its default).
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
                       const phi::KernelKey& kernel_key,
                       const framework::OperatorWithKernel::OpKernelFunc& func,
                       const phi::ArgumentMappingFn* arg_map_fn,
                       const phi::KernelSignature* default_kernel_signature,
                       platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_key_(kernel_key),
      func_(func),
      dev_ctx_(dev_ctx),
      arg_map_fn_(arg_map_fn),
      default_kernel_signature_(default_kernel_signature),
      phi_kernel_(empty_kernel) {}
133

134 135
// Constructs a PreparedOp that runs a phi kernel: run_phi_kernel_ is set to
// true, func_ is left null, and the chosen kernel signature is moved in.
PreparedOp::PreparedOp(const framework::OperatorBase& op,
                       const framework::RuntimeContext& ctx,
                       const phi::KernelKey& kernel_key,
                       const phi::ArgumentMappingFn* arg_map_fn,
                       const phi::KernelSignature* default_kernel_signature,
                       phi::KernelSignature&& kernel_signature,
                       const phi::Kernel& phi_kernel,
                       platform::DeviceContext* dev_ctx)
    : op_(op),
      ctx_(ctx),
      kernel_key_(kernel_key),
      func_(nullptr),
      dev_ctx_(dev_ctx),
      run_phi_kernel_(true),
      arg_map_fn_(arg_map_fn),
      default_kernel_signature_(default_kernel_signature),
      kernel_signature_(std::move(kernel_signature)),
      phi_kernel_(phi_kernel) {}
152

153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
#ifdef PADDLE_WITH_MLU

// Splits `ops` on `delim` and inserts every piece into *op_set. Empty pieces
// produced by leading/trailing/consecutive delimiters are inserted too, and
// the tail after the last delimiter (or the whole string when there is no
// delimiter) is always inserted.
static void tokenize(const std::string& ops,
                     char delim,
                     std::unordered_set<std::string>* op_set) {
  // Use std::string::size_type throughout (the old loop mixed uint64_t with
  // size_type) so the comparison against std::string::npos is exact on every
  // platform.
  std::string::size_type beg = 0;
  std::string::size_type end = 0;
  while ((end = ops.find(delim, beg)) != std::string::npos) {
    op_set->insert(ops.substr(beg, end - beg));
    beg = end + 1;
  }
  op_set->insert(ops.substr(beg));
}

// Returns true when `op_name` appears in the MLU_BLACK_LIST environment
// variable (a comma-separated op list). Blacklisted ops are forced onto CPU
// by the caller (PrepareImpl).
static bool is_in_mlu_black_list(const std::string& op_name) {
  // NOTE: the previous hand-rolled double-checked locking read a non-atomic
  // `inited` flag outside the mutex, which is a data race (UB). A function
  // local static initialized by a lambda gives thread-safe one-time
  // initialization (guaranteed since C++11) with no explicit locking.
  static const std::unordered_set<std::string> mlu_black_list = [] {
    std::unordered_set<std::string> black_list;
    if (const char* env = std::getenv("MLU_BLACK_LIST")) {
      tokenize(std::string(env), ',', &black_list);
    }
    VLOG(3) << "MLU Black List: ";
    for (auto iter = black_list.begin(); iter != black_list.end(); ++iter) {
      VLOG(3) << *iter << " ";
    }
    return black_list;
  }();
  return mlu_black_list.find(op_name) != mlu_black_list.end();
}

#endif

195
// Core kernel-selection routine shared by all PreparedOp::Prepare overloads.
//
// ins/outs:   the op's input/output variable maps (VarBase, VariableWrapper
//             or egr::EagerVariable, depending on the instantiation).
// op:         the operator whose kernel is being chosen.
// place:      the place requested by the tracer; after device fallbacks the
//             returned PreparedOp may carry a different device context.
// attrs/default_attrs: runtime and default attribute maps.
// The last three parameters are the cached phi singletons (factory, op-utils
// map, default-signature map).
//
// Returns a PreparedOp holding either a selected phi kernel or, failing
// that, a fluid OpKernelFunc. Selection may fall back to CPU when the
// device-specific kernel is missing or unsupported.
template <typename VarType>
PreparedOp PrepareImpl(
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::OperatorWithKernel& op,
    const platform::Place& place,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs,
    const phi::KernelFactory& phi_kernel_factory,
    const phi::OpUtilsMap& phi_op_utils_map,
    const phi::DefaultKernelSignatureMap& default_phi_kernel_sig_map) {
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);

#ifdef PADDLE_WITH_MKLDNN
  // MKLDNN variant of code reads attributes in some of GetKernelTypeForVar and
  // GetKernelType functions, so we need to copy the attributes there.
  // Const qualifier of Attrs had to be discarded to overwrite it.
  if (FLAGS_use_mkldnn) {
    auto& mutable_op_attrs = const_cast<framework::AttributeMap&>(op.Attrs());
    mutable_op_attrs = default_attrs;
    for (auto& attr : attrs) {
      mutable_op_attrs[attr.first] = attr.second;
    }
  }
#endif
  // NOTE(zhiqiu): for kernels on given device, for example NPU, the order to
  // choose is:
  // phi npu kernel > fluid npu kernel > phi cpu kernel > fluid cpu kernel

  // 1. get expected kernel key
  auto dygraph_exe_ctx = DygraphExecutionContext<VarType>(
      op, empty_scope, *dev_ctx, empty_ctx, ins, outs, attrs, default_attrs);
  auto expected_kernel_key = op.GetExpectedKernelType(dygraph_exe_ctx);

  const phi::KernelSignature* default_kernel_signature = nullptr;
  phi::KernelSignature kernel_signature;
  std::string phi_kernel_name;

// NOTE(jiahongyu): The registered MKLDNN kernel have library_type =
// LibraryType::kMKLDNN and data_layout_ = DataLayout::ONEDNN. But the default
// values are kPlain, so we need to modify the library_type and data_layout_
// here. There are three statements in if condition:
// 1. Whether mkldnn kernel fallbacks to plain kernel;
// 2. Whether this op has specific implementation;
// 3. Whether mkldnn kernel can be used.
#ifdef PADDLE_WITH_MKLDNN
  if (!op.DnnFallback() && !paddle::platform::in_mkldnn_white_list(op.Type()) &&
      op.CanMKLDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) {
    expected_kernel_key.set_backend(phi::Backend::ONEDNN);
    expected_kernel_key.set_layout(phi::DataLayout::ONEDNN);
  }
#endif

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  if (op.CanCUDNNBeUsed(dygraph_exe_ctx, expected_kernel_key.dtype())) {
    expected_kernel_key.set_backend(phi::Backend::GPUDNN);
  }
#endif

#if defined(PADDLE_WITH_XPU)
  // Whether the op/dtype combination is NOT supported on XPU; drives the CPU
  // fallback branches below.
  bool is_xpu_unsupport = expected_kernel_key.backend() == phi::Backend::XPU &&
                          !paddle::platform::is_xpu_support_op(
                              op.Type(), expected_kernel_key.dtype());
#endif

#ifdef PADDLE_WITH_MLU
  // Ops listed in the MLU_BLACK_LIST env var are forced onto CPU.
  if (is_in_mlu_black_list(op.Type())) {
    expected_kernel_key.set_backend(phi::Backend::CPU);
  }
#endif

  bool has_phi_kernel = false;

  const auto* arg_map_fn = phi_op_utils_map.GetArgumentMappingFn(op.Type());

  // A phi kernel signature comes either from an op-specific argument mapping
  // function or from the default signature map.
  if (arg_map_fn) {
    has_phi_kernel = true;
    kernel_signature = (*arg_map_fn)(
        framework::ExecutionArgumentMappingContext(dygraph_exe_ctx));
  } else {
    default_kernel_signature =
        default_phi_kernel_sig_map.GetNullable(op.Type());
    if (default_kernel_signature) {
      has_phi_kernel = true;
      kernel_signature = *default_kernel_signature;
    }
  }

  // 1.1 try to select a phi kernel for the expected key.
  if (has_phi_kernel) {
    VLOG(6) << kernel_signature;
    phi_kernel_name = kernel_signature.name;
// NOTE(Liu-xiandong): The register kernel used KP have library_type[KP],
// But the default library_type is Plain, so we need to modify the
// library_type here, otherwise it can't work.
#ifdef PADDLE_WITH_XPU_KP
    if (expected_kernel_key.backend() == phi::Backend::XPU) {
      bool use_xpu_kp_kernel_rt =
          FLAGS_run_kp_kernel && paddle::platform::is_xpu_support_op(
                                     op.Type(), expected_kernel_key.dtype());
      bool use_xpu_kp_kernel_debug =
          paddle::platform::is_in_xpu_kpwhite_list(op.Type());
      if (use_xpu_kp_kernel_rt) {
        VLOG(3) << "phi xpu_kp using rt mode ";
      }
      if (use_xpu_kp_kernel_debug) {
        VLOG(3) << "phi xpu_kp using debug mode ";
      }
      bool is_xpu_kp_support =
          (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
      if (is_xpu_kp_support) {
        // Try the KPS backend; restore the original backend if no KP kernel
        // is actually registered under it.
        auto expected_kernel_key_backend = expected_kernel_key.backend();
        expected_kernel_key.set_backend(phi::Backend::KPS);
        VLOG(3) << "modifing XPU KP kernel: " << phi_kernel_name
                << ", using_kernel_key:" << expected_kernel_key;

        if (!phi_kernel_factory.HasKernel(phi_kernel_name,
                                          expected_kernel_key)) {
          expected_kernel_key.set_backend(expected_kernel_key_backend);
          VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name
                  << " in dynamic graph is failed " << expected_kernel_key;
        } else {
          VLOG(3) << "modify XPU KP kernel: " << phi_kernel_name
                  << " in dynamic graph is succeed " << expected_kernel_key;
        }
      }
    }
#endif

    auto& phi_kernel =
        phi_kernel_factory.SelectKernel(phi_kernel_name, expected_kernel_key);

    if (phi_kernel.IsValid()
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
        && !is_xpu_unsupport
#endif
    ) {
      VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name
              << " | kernel key: " << expected_kernel_key
              << " | kernel: " << phi_kernel;

      // Switch the device context when the selected backend lives on a
      // different device class than the current context.
      if (!framework::backends_are_same_class(
              expected_kernel_key.backend(),
              phi::TransToPhiBackend(dev_ctx->GetPlace()))) {
        dev_ctx = pool.Get(phi::TransToPhiPlace(expected_kernel_key.backend()));
      }
      return PreparedOp(op,
                        empty_ctx,
                        expected_kernel_key,
                        arg_map_fn,
                        default_kernel_signature,
                        std::move(kernel_signature),
                        phi_kernel,
                        dev_ctx);
    } else {
      VLOG(6) << "Dynamic mode ChoosePhiKernel - kernel `" << phi_kernel_name
              << "` not found.";
    }
  }

  // 2. check if op[type] has kernel registered.
  auto& all_op_kernels = op.AllOpKernels();
  auto kernels_iter = all_op_kernels.find(op.Type());

// NOTE(Liu-xiandong): If we can't find heterogeneous kernel in phi,
// we need to select the heterogeneous kernel in fluid, but the kernel
// registered in KP use library_type[KP], we need to modify it.
#ifdef PADDLE_WITH_XPU_KP
  bool use_xpu_kp_kernel_rt =
      expected_kernel_key.backend() == phi::Backend::XPU &&
      FLAGS_run_kp_kernel &&
      paddle::platform::is_xpu_support_op(op.Type(),
                                          expected_kernel_key.dtype());
  bool use_xpu_kp_kernel_debug =
      expected_kernel_key.backend() == phi::Backend::XPU &&
      paddle::platform::is_in_xpu_kpwhite_list(op.Type());
  bool is_xpu_kp_support = (use_xpu_kp_kernel_rt || use_xpu_kp_kernel_debug);
  if (is_xpu_kp_support) {
    expected_kernel_key.set_backend(phi::Backend::KPS);
  }
#endif

  paddle::framework::OpKernelType fluid_kernel_type =
      paddle::framework::TransPhiKernelKeyToOpKernelType(expected_kernel_key);
  // When no fluid kernel matches (or the device cannot run it), first try a
  // phi CPU kernel before giving up.
  if ((kernels_iter == all_op_kernels.end() ||
       kernels_iter->second.find(fluid_kernel_type) ==
           kernels_iter->second.end())
#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
      || is_xpu_unsupport
#endif
#if defined(PADDLE_WITH_XPU_KP)
      || (is_xpu_unsupport && !is_xpu_kp_support)
#endif
  ) {
    if (has_phi_kernel) {
      auto phi_cpu_kernel_key = FallBackToCpu(expected_kernel_key, op);
      auto& phi_cpu_kernel =
          phi_kernel_factory.SelectKernel(phi_kernel_name, phi_cpu_kernel_key);
      if (phi_cpu_kernel.IsValid()) {
        VLOG(6) << "Dynamic mode PrepareImpl - kernel name: " << phi_kernel_name
                << " | kernel key: " << phi_cpu_kernel_key
                << " | kernel: " << phi_cpu_kernel;
        auto* cpu_ctx = pool.Get(paddle::platform::CPUPlace());
        return PreparedOp(op,
                          empty_ctx,
                          phi_cpu_kernel_key,
                          arg_map_fn,
                          default_kernel_signature,
                          std::move(kernel_signature),
                          phi_cpu_kernel,
                          cpu_ctx);
      }
    }
  }

  PADDLE_ENFORCE_NE(
      kernels_iter,
      all_op_kernels.end(),
      platform::errors::NotFound(
          "There are no kernels which are registered in the %s operator.",
          op.Type()));

  auto& kernels = kernels_iter->second;
  auto kernel_iter = kernels.find(fluid_kernel_type);

#if defined(PADDLE_WITH_XPU) && !defined(PADDLE_WITH_XPU_KP)
  if (paddle::platform::is_xpu_place(fluid_kernel_type.place_) &&
      (kernel_iter == kernels.end() || is_xpu_unsupport)) {
    VLOG(3) << "fluid missing XPU kernel: " << op.Type()
            << ", expected_kernel_key:" << fluid_kernel_type
            << ", fallbacking to CPU one!";
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
  }
#endif

#ifdef PADDLE_WITH_XPU_KP
  if (paddle::platform::is_xpu_place(fluid_kernel_type.place_)) {
    if (use_xpu_kp_kernel_rt) {
      VLOG(3) << "fluid xpu_kp using rt mode ";
    }
    if (use_xpu_kp_kernel_debug) {
      VLOG(3) << "fluid xpu_kp using debug mode ";
    }
    if (is_xpu_kp_support) {
      fluid_kernel_type.library_type_ = paddle::framework::LibraryType::kKP;
      kernel_iter = kernels.find(fluid_kernel_type);
      VLOG(3) << "using fluid XPU KP kernel: " << op.Type()
              << ", using_kernel_key:" << fluid_kernel_type;
    }
    if (!is_xpu_kp_support &&
        (kernel_iter == kernels.end() || is_xpu_unsupport)) {
      VLOG(3) << "fluid missing XPU kernel: " << op.Type()
              << ", expected_kernel_key:" << fluid_kernel_type
              << ", fallbacking to CPU one!";
      fluid_kernel_type.place_ = platform::CPUPlace();
      kernel_iter = kernels.find(fluid_kernel_type);
    }
  }
#endif

  // Device-specific CPU fallbacks when the fluid kernel is missing.
#ifdef PADDLE_WITH_ASCEND_CL
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_npu_place(fluid_kernel_type.place_)) {
    VLOG(3) << "missing NPU kernel: " << op.Type()
            << ", expected_kernel_key:" << fluid_kernel_type
            << ", fallbacking to CPU one!";
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
  }
#endif
#ifdef PADDLE_WITH_IPU
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_ipu_place(fluid_kernel_type.place_)) {
    VLOG(3) << "missing IPU kernel: " << op.Type()
            << ", expected_kernel_key:" << fluid_kernel_type
            << ", fallbacking to CPU one!";
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
  }
#endif
#ifdef PADDLE_WITH_MLU
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_mlu_place(fluid_kernel_type.place_)) {
    VLOG(3) << "missing MLU kernel: " << op.Type()
            << ", expected_kernel_key:" << fluid_kernel_type
            << ", fallbacking to CPU one!";
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
  }
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
  if (kernel_iter == kernels.end() &&
      paddle::platform::is_custom_place(fluid_kernel_type.place_)) {
    VLOG(3) << "missing " << place.GetDeviceType() << " kernel: " << op.Type()
            << ", expected_kernel_key:" << expected_kernel_key
            << ", fallbacking to CPU one!";
    fluid_kernel_type.place_ = platform::CPUPlace();
    kernel_iter = kernels.find(fluid_kernel_type);
  }
#endif
  // TODO(jiabin): Add operator.cc's line 1000 part back when we need that
  // case
  PADDLE_ENFORCE_NE(
      kernel_iter,
      kernels.end(),
      platform::errors::NotFound("Operator %s does not have kernel for %s.",
                                 op.Type(),
                                 KernelTypeToString(fluid_kernel_type)));

  if (!platform::places_are_same_class(fluid_kernel_type.place_,
                                       dev_ctx->GetPlace())) {
    dev_ctx = pool.Get(fluid_kernel_type.place_);
  }
  return PreparedOp(
      op,
      empty_ctx,
      framework::TransOpKernelTypeToPhiKernelKey(fluid_kernel_type),
      kernel_iter->second,
      arg_map_fn,
      default_kernel_signature,
      dev_ctx);
}

519 520 521 522
// Prepare overload for VarBase (classic dygraph): delegates kernel and
// device-context selection to PrepareImpl with the cached phi singletons.
PreparedOp PreparedOp::Prepare(const NameVarMap<VarBase>& ins,
                               const NameVarMap<VarBase>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
  return PrepareImpl<VarBase>(ins,
                              outs,
                              op,
                              place,
                              attrs,
                              default_attrs,
                              phi_kernel_factory,
                              phi_op_utils_map,
                              default_phi_kernel_sig_map);
}

// Prepare overload for VariableWrapper: same delegation as the VarBase
// overload, instantiated for the wrapper type.
PreparedOp PreparedOp::Prepare(const NameVarMap<VariableWrapper>& ins,
                               const NameVarMap<VariableWrapper>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
  return PrepareImpl<VariableWrapper>(ins,
                                      outs,
                                      op,
                                      place,
                                      attrs,
                                      default_attrs,
                                      phi_kernel_factory,
                                      phi_op_utils_map,
                                      default_phi_kernel_sig_map);
}

553 554
// Prepare overload for egr::EagerVariable (eager mode): same delegation,
// instantiated for the eager tensor type.
PreparedOp PreparedOp::Prepare(const NameVarMap<egr::EagerVariable>& ins,
                               const NameVarMap<egr::EagerVariable>& outs,
                               const framework::OperatorWithKernel& op,
                               const platform::Place& place,
                               const framework::AttributeMap& attrs,
                               const framework::AttributeMap& default_attrs) {
  return PrepareImpl<egr::EagerVariable>(ins,
                                         outs,
                                         op,
                                         place,
                                         attrs,
                                         default_attrs,
                                         phi_kernel_factory,
                                         phi_op_utils_map,
                                         default_phi_kernel_sig_map);
}
569 570
// Executes a fluid kernel in dygraph mode: runs infer-shape under a profiler
// scope, invokes `func` with a dygraph execution context, then performs the
// optional nan/inf check, benchmark synchronization, and complex-to-real
// gradient conversion.
template <typename VarType>
static void PreparedOpRunImpl(
    const framework::OperatorBase& op,
    const framework::RuntimeContext& ctx,
    const phi::KernelKey& kernel_key,
    const framework::OperatorWithKernel::OpKernelFunc& func,
    const phi::ArgumentMappingFn* arg_map_fn,
    const phi::KernelSignature* default_kernel_signature,
    platform::DeviceContext* dev_ctx,
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs) {
  // TODO(zjl): remove scope in dygraph

  {
    // Profile the infer-shape phase separately from the compute phase.
    platform::RecordEvent record_event("infer_shape",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);
    DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
                                                      &outs,
                                                      &attrs,
                                                      &default_attrs,
                                                      op.Type(),
                                                      &kernel_key,
                                                      arg_map_fn,
                                                      default_kernel_signature);
    op.Info().infer_shape_(&infer_shape_ctx);
    record_event.End();
    platform::RecordOpInfoSupplement(
        op.Type(), op.Attrs(), infer_shape_ctx, ctx, op.Id());
  }

  {
    platform::RecordEvent record_event("compute",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);

    func(DygraphExecutionContext<VarType>(
        op, empty_scope, *dev_ctx, ctx, ins, outs, attrs, default_attrs));
  }

  if (FLAGS_check_nan_inf) {
    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
        op.Type(), outs, dev_ctx->GetPlace());
  }

  // In benchmark mode, wait for the device so per-op timing is accurate.
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
  }

  /**
   * [ Why need handle complex gradient to real gradient? ]
   *
   * After the introduction of complex number calculations, Ops that support
   * complex number calculations generally support type promotion, such as
   * x(float32) + y(complex64) = out(complex64), then the type of the grad
   * tensor should be dout(complex64), dx(float32), dy (complex64).
   *
   * But because the dout is complex64, the dx is also complex64 after
   * grad op kernel executed, we need to recognize this situation and
   * convert dx to float32 type. HandleComplexGradToRealGrad does this thing.
   */
  if (framework::IsComplexType(kernel_key.dtype())) {
    HandleComplexGradToRealGrad<VarType>(outs);
  }
}
H
hong 已提交
642

643 644 645
// Executes a phi kernel in dygraph mode: runs infer-shape, prepares the
// phi input data, builds a phi::KernelContext, and invokes the kernel; then
// performs the optional nan/inf check, benchmark synchronization, and
// complex-to-real gradient conversion (same tail as PreparedOpRunImpl).
template <typename VarType>
static void PreparedOpRunPtImpl(
    const framework::OperatorBase& op,
    const phi::KernelKey& kernel_key,
    const phi::ArgumentMappingFn* arg_map_fn,
    const phi::KernelSignature* default_kernel_signature,
    const phi::KernelSignature& kernel_signature,
    const phi::Kernel& phi_kernel,
    platform::DeviceContext* dev_ctx,
    const NameVarMap<VarType>& ins,
    const NameVarMap<VarType>& outs,
    const framework::AttributeMap& attrs,
    const framework::AttributeMap& default_attrs) {
  {
    // Profile the infer-shape phase separately from the compute phase.
    platform::RecordEvent record_event("infer_shape",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);
    DygraphInferShapeContext<VarType> infer_shape_ctx(&ins,
                                                      &outs,
                                                      &attrs,
                                                      &default_attrs,
                                                      op.Type(),
                                                      &kernel_key,
                                                      arg_map_fn,
                                                      default_kernel_signature);
    op.Info().infer_shape_(&infer_shape_ctx);
    record_event.End();
    platform::RecordOpInfoSupplement(
        op.Type(), op.Attrs(), infer_shape_ctx, kernel_signature);
  }

  {
    platform::RecordEvent record_event("compute",
                                       platform::TracerEventType::OperatorInner,
                                       1,
                                       platform::EventRole::kInnerOp);

    PreparePhiData<VarType>(phi_kernel, kernel_signature, ins);

    phi::KernelContext phi_kernel_context;
    BuildDygraphPhiKernelContext<VarType>(kernel_signature,
                                          phi_kernel,
                                          ins,
                                          outs,
                                          attrs,
                                          default_attrs,
                                          dev_ctx,
                                          &phi_kernel_context);

    phi_kernel(&phi_kernel_context);
  }

  if (FLAGS_check_nan_inf) {
    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
        op.Type(), outs, dev_ctx->GetPlace());
  }

  // In benchmark mode, wait for the device so per-op timing is accurate.
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    PADDLE_ENFORCE_GPU_SUCCESS(platform::GpuGetLastError());
    VLOG(4) << "Operator(" << op.Type() << "): context wait and get last error";
#endif
  }

  // See the comment in PreparedOpRunImpl for why complex grads may need to
  // be converted back to real.
  if (framework::IsComplexType(kernel_key.dtype())) {
    HandleComplexGradToRealGrad<VarType>(outs);
  }
}

714 715
void PreparedOp::Run(const NameVarMap<VarBase>& ins,
                     const NameVarMap<VarBase>& outs,
716 717
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
718
  if (run_phi_kernel_) {
719
    PreparedOpRunPtImpl<VarBase>(op_,
720
                                 kernel_key_,
721 722 723 724 725 726 727 728
                                 arg_map_fn_,
                                 default_kernel_signature_,
                                 kernel_signature_,
                                 phi_kernel_,
                                 dev_ctx_,
                                 ins,
                                 outs,
                                 attrs,
729
                                 default_attrs);
730
  } else {
731 732
    PreparedOpRunImpl<VarBase>(op_,
                               ctx_,
733
                               kernel_key_,
734 735 736 737 738 739 740 741
                               func_,
                               arg_map_fn_,
                               default_kernel_signature_,
                               dev_ctx_,
                               ins,
                               outs,
                               attrs,
                               default_attrs);
742
  }
743
}
H
hong 已提交
744

745 746
void PreparedOp::Run(const NameVarMap<VariableWrapper>& ins,
                     const NameVarMap<VariableWrapper>& outs,
747 748
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
749
  if (run_phi_kernel_) {
750
    PreparedOpRunPtImpl<VariableWrapper>(op_,
751
                                         kernel_key_,
752 753 754 755 756 757 758 759 760
                                         arg_map_fn_,
                                         default_kernel_signature_,
                                         kernel_signature_,
                                         phi_kernel_,
                                         dev_ctx_,
                                         ins,
                                         outs,
                                         attrs,
                                         default_attrs);
761
  } else {
762 763
    PreparedOpRunImpl<VariableWrapper>(op_,
                                       ctx_,
764
                                       kernel_key_,
765 766 767 768 769 770 771 772
                                       func_,
                                       arg_map_fn_,
                                       default_kernel_signature_,
                                       dev_ctx_,
                                       ins,
                                       outs,
                                       attrs,
                                       default_attrs);
773
  }
J
Jiabin Yang 已提交
774 775
}

776 777
void PreparedOp::Run(const NameVarMap<egr::EagerVariable>& ins,
                     const NameVarMap<egr::EagerVariable>& outs,
J
Jiabin Yang 已提交
778 779
                     const framework::AttributeMap& attrs,
                     const framework::AttributeMap& default_attrs) {
780
  if (run_phi_kernel_) {
781
    PreparedOpRunPtImpl<egr::EagerVariable>(op_,
782
                                            kernel_key_,
783 784 785 786 787 788 789 790 791
                                            arg_map_fn_,
                                            default_kernel_signature_,
                                            kernel_signature_,
                                            phi_kernel_,
                                            dev_ctx_,
                                            ins,
                                            outs,
                                            attrs,
                                            default_attrs);
J
Jiabin Yang 已提交
792
  } else {
793 794
    PreparedOpRunImpl<egr::EagerVariable>(op_,
                                          ctx_,
795
                                          kernel_key_,
796 797 798 799 800 801 802 803
                                          func_,
                                          arg_map_fn_,
                                          default_kernel_signature_,
                                          dev_ctx_,
                                          ins,
                                          outs,
                                          attrs,
                                          default_attrs);
J
Jiabin Yang 已提交
804 805 806
  }
}

J
Jiabin Yang 已提交
807 808
}  // namespace imperative
}  // namespace paddle