// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/tracer.h"

#include <map>
#include <set>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/platform/denormal.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/flags.h"

PHI_DECLARE_bool(use_mkldnn);
PHI_DECLARE_string(tracer_mkldnn_ops_on);
PHI_DECLARE_string(tracer_mkldnn_ops_off);
DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace imperative {
thread_local std::string Tracer::python_stack_ = "";

thread_local bool Tracer::enable_program_desc_tracing_ = false;

thread_local bool Tracer::has_grad_ = true;

thread_local bool Tracer::use_promote_ = true;

thread_local bool Tracer::use_layout_autotune_ = false;

thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;

thread_local phi::DataType Tracer::amp_dtype_ = phi::DataType::FLOAT32;

static std::shared_ptr<Tracer> g_current_tracer(nullptr);

const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }

void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
  g_current_tracer = tracer;
  VLOG(6) << "Set current tracer: " << g_current_tracer;
}

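// Propagate a stop_gradient value to every output VarBase of a traced op.
// Called from ComputeRequiredGrad with the OverridedStopGradient() of the
// first input that still requires grad, so the outputs inherit that setting.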
void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
  for (const auto& pair : outs) {
    for (const auto& var : pair.second) {
      // NOTE(zhiqiu): this happens when None outputs are passed from the
      // python side. For example, fake_quantize_dequantize_moving_average_abs_max
      // may pass a None OutAccum in eval mode.
      // It could be refined by generating several different pybind interfaces
      // for one operator with different function signatures.
      if (var == nullptr) {
        VLOG(4) << pair.first << " is NULL";
        continue;
      }
      VLOG(6) << "Set output: " << var->Name() << "'s OverridedStopGradient as "
              << generate_grad;
      var->InnerSetOverridedStopGradient(generate_grad);
    }
  }
}

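// Keep `var` alive until the kernels already launched on the current stream
// of the given place have finished, by handing a callback that captures the
// shared_ptr to that place's garbage collector.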
void IncreaseVarbaseReferenceCountUntilCopyComplete(
    const std::shared_ptr<imperative::VarBase>& var,
    const platform::Place& place) {
  // Note(zhiqiu): Follow the logic of TensorCopy to determine the place that we
  // need to add callback, see tensor_utils.cc:245
  auto place_ = platform::is_gpu_place(place) ? place : var->Place();

  auto tracer = imperative::GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(zhiqiu): This is an empty callback, the only way is to "reference"
  // var, so it will not be destructed until the kernels launched at current
  // stream of given place is finished.
  auto callback = [var, place_]() {
    VLOG(4) << "Run callback of var:" << var->Name() << " at place " << place_;
  };

  gc->DirectClearCallback(callback);
}

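// Lazily create (and cache) a GarbageCollector for the given place; the
// concrete collector type depends on the device the place refers to.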
paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
    const platform::Place& place) {
  // if not exists, create a new GarbageCollector at given place
  if (gcs_.count(place) == 0) {
    std::unique_ptr<framework::GarbageCollector> gc;
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      gc.reset(new framework::DefaultStreamGarbageCollector(place, 0));

      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use CUDA device since it's not compiled with CUDA, "
          "Please recompile or reinstall Paddle with GPU support."));
#endif
    } else if (platform::is_cuda_pinned_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      gc.reset(new framework::CUDAPinnedGarbageCollector(place, 0));

      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use CUDAPinned device since it's not compiled with "
          "CUDA, "
          "Please recompile or reinstall Paddle with GPU support."));
#endif
    } else if (platform::is_xpu_place(place)) {
#if defined(PADDLE_WITH_XPU)
      gc.reset(new framework::XPUGarbageCollector(place, 0));
      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use XPU device since it's not compiled with XPU, "
          "Please recompile or reinstall Paddle with XPU support."));
#endif
    } else if (platform::is_cpu_place(place)) {
      gc.reset(new framework::CPUGarbageCollector(place, 0));
      VLOG(10) << "Created GarbageCollector at " << place;
    } else if (platform::is_ipu_place(place)) {
#if defined(PADDLE_WITH_IPU)
      gc.reset(new framework::IPUGarbageCollector(place, 0));
      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use IPU device since it's not compiled with IPU, "
          "Please recompile or reinstall Paddle with IPU support."));
#endif
    } else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
      if (framework::IsFastEagerDeletionModeEnabled()) {
        gc.reset(
            new framework::CustomDeviceUnsafeFastGarbageCollector(place, 0));
        VLOG(10) << "Created UnsafeFastGarbageCollector at " << place;
      } else {
        gc.reset(new framework::CustomDefaultStreamGarbageCollector(place, 0));
        VLOG(10) << "Created GarbageCollector at " << place;
      }
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use CustomDevice since it's not compiled with "
          "CustomDevice, "
          "Please recompile or reinstall Paddle with CustomDevice "
          "support."));
#endif
    } else {
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "Unsupported place for garbage collection"));
    }
    gcs_.emplace(place, std::move(gc));
  }

  return gcs_.at(place).get();
}

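// Templated entry point shared by the legacy VarBase path and the eager
// EagerVariable path; it simply forwards to TraceOpImpl.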
template <typename VarType>
void Tracer::TraceOp(const std::string& type,
                     const NameVarMap<VarType>& ins,
                     const NameVarMap<VarType>& outs,
                     framework::AttributeMap attrs,
                     const platform::Place& place,
                     bool trace_backward,
                     const std::map<std::string, std::string>& inplace_map,
                     paddle::framework::AttributeMap* passed_default_attrs_,
                     bool use_default_attr_map) {
  TraceOpImpl<VarType>(type,
                       ins,
                       outs,
                       attrs,
                       place,
                       trace_backward,
                       inplace_map,
                       passed_default_attrs_,
                       use_default_attr_map);
}

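// Core of dynamic-graph op execution: optionally rewrites the inputs
// (use_mkldnn attribute, AMP casting, layout autotune), creates and runs the
// operator on the target place, and records a grad node when backward is
// required.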
template <typename VarType>
void Tracer::TraceOpImpl(const std::string& type,
                         const NameVarMap<VarType>& ins,
                         const NameVarMap<VarType>& outs,
                         framework::AttributeMap& attrs,
                         const platform::Place& place,
                         bool trace_backward,
                         const std::map<std::string, std::string>& inplace_map,
                         paddle::framework::AttributeMap* passed_default_attrs_,
                         bool use_default_attr_map) {
  platform::RecordEvent op_type_record_event(
      type, platform::TracerEventType::Operator, 1);
  platform::ScopedFlushDenormal flush;
  VLOG(4) << "Trace Op: " << type;
  if (FLAGS_use_mkldnn) {
    // If both lists are empty, all ops are enabled (the default for
    // FLAGS_use_mkldnn=1); if the ops_on list is not empty, only ops from
    // that list are enabled.
    if (!FLAGS_tracer_mkldnn_ops_on.empty()) {
      auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos;
      attrs["use_mkldnn"] = is_on;
    } else {
      // If the ops_on list is empty, all ops are enabled except the types in
      // the ops_off list.
      auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos;
      attrs["use_mkldnn"] = !is_off;
    }
  }

  auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
  const auto& op_info = op->Info();
  auto* attr_checker = op_info.Checker();
  if (attr_checker) {
    attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
  }

  const auto& extra_attr_checkers =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
  for (const auto& checker : extra_attr_checkers) {
    checker(&attrs, true);
  }

  static paddle::framework::AttributeMap empty_attrs_map = {};
  const paddle::framework::AttributeMap& default_attrs =
      attr_checker == nullptr ? empty_attrs_map
                              : attr_checker->GetDefaultAttrMap();

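  // AMP O1 auto-casts inputs per the AMP op lists, while O2 casts them to
  // pure fp16/bf16; the cast result is kept in `ins_amp` so the original
  // `ins` map stays untouched.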
  std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
  if (amp_level_ == AmpLevel::O1) {
    if (amp_dtype_ == phi::DataType::FLOAT16) {
      VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          AutoCastInputs<VarType>(type, ins));
    } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
      VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          AutoCastBF16Inputs<VarType>(type, ins));
    }
  } else if (amp_level_ == AmpLevel::O2) {
    if (amp_dtype_ == phi::DataType::FLOAT16) {
      VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          CastPureFp16Inputs<VarType>(type, ins));
    } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
      VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          CastPureBf16Inputs<VarType>(type, ins));
    }
  }

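  // Layout autotune (GPU only in this path): may rewrite input layouts and
  // the op's attributes to pick a faster data layout for this operator.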
  if (platform::is_gpu_place(place)) {
    const auto& new_tmp = ins_amp == nullptr ? ins : *ins_amp;
    const auto& tracer = imperative::GetCurrentTracer();
    ins_amp = std::make_unique<NameVarMap<VarType>>(
        imperative::AutoTuneLayout<VarType>(
            type, new_tmp, outs, &attrs, tracer));
  }

  const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;

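  // Bind the current thread to the device of `place` before running the op;
  // any error raised during execution is re-thrown with the op type attached.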
  try {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::SetDeviceId(place.device);
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "PaddlePaddle should be compiled with GPU if CUDAPlace is used."));
#endif
    } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
      platform::SetXPUDeviceId(place.device);
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "PaddlePaddle should be compiled with XPU if XPUPlace is used."));
#endif
    } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      phi::DeviceManager::SetDevice(place);
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "PaddlePaddle should be compiled with CustomDevice if "
          "CustomPlace is used."));
#endif
    }

    if (!use_default_attr_map) {
      PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_,
                              paddle::platform::errors::PermissionDenied(
                                  "Detected default_attrs = nullptr."));
      VLOG(6) << "Use passed in default attrs";
      OpBase::Run(*op, new_ins, outs, attrs, (*passed_default_attrs_), place);
    } else {
      VLOG(6) << "Use Checker's default attrs";
      if (passed_default_attrs_) {
        // TODO(jiabin): Update this without copy
        *passed_default_attrs_ = default_attrs;
      }
      OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
    }
  } catch (platform::EnforceNotMet& exception) {
    framework::AppendErrorOpHint(type, &exception);
    throw std::move(exception);
  } catch (std::exception& ex) {
    PADDLE_THROW(
        platform::errors::Fatal("Operator %s raises an %s exception.\n"
                                "The exception content is\n:%s.",
                                type,
                                platform::demangle(typeid(ex).name()),
                                ex.what()));
  } catch (...) {
    // NOTE: this branch represents a very serious bug with
    // low probability of occurrence, and we can't get its
    // exception content here.
    PADDLE_THROW(platform::errors::Fatal(
        "Operator %s raises an unknown exception.", type));
  }

  if (enable_program_desc_tracing_) {
    VLOG(5) << "Trace op " << type << " into ProgramDesc";
    program_desc_tracer_->InsertOp(type, new_ins, outs, attrs);
  }

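  // Create the grad node only when backward is actually needed, i.e. when
  // backward tracing is enabled and at least one input requires grad.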
  {
    platform::RecordEvent node_creation_record_event(
        "grad_node_creation", platform::TracerEventType::OperatorInner, 1);

    if (ComputeRequiredGrad(new_ins, outs, trace_backward)) {
      PADDLE_ENFORCE_EQ(
          passed_default_attrs_,
          nullptr,
          paddle::platform::errors::PermissionDenied(
              "We expect passed_default_attrs_ to be nullptr while "
              "use_default_attr_map is true, but got a non-null "
              "passed_default_attrs_. Please check your usage of trace_op."));
      CreateGradOpNode(
          *op, new_ins, outs, attrs, default_attrs, place, inplace_map);
    } else {
      VLOG(3) << "No Grad to track for Op: " << type;
    }
    VLOG(6) << "Finish Trace Op: " << type;
  }
}

template void Tracer::TraceOp<VarBase>(
    const std::string& type,
    const NameVarMap<VarBase>& ins,
    const NameVarMap<VarBase>& outs,
    framework::AttributeMap attrs,
    const platform::Place& place,
    bool trace_backward,
    const std::map<std::string, std::string>& inplace_map,
    paddle::framework::AttributeMap* default_attrs,
    bool use_default_attr_map);

template void Tracer::TraceOp<egr::EagerVariable>(
    const std::string& type,
    const NameVarMap<egr::EagerVariable>& ins,
    const NameVarMap<egr::EagerVariable>& outs,
    framework::AttributeMap attrs,
    const platform::Place& place,
    bool trace_backward,
    const std::map<std::string, std::string>& inplace_map_,
    paddle::framework::AttributeMap* default_attrs,
    bool use_default_attr_map);

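// The overloads below are the entry points used by the Python-side bindings.
// A minimal usage sketch (illustrative only: variable names are hypothetical
// and inplace_map is assumed to default to {} in tracer.h):
//
//   auto tracer = imperative::GetCurrentTracer();
//   imperative::NameVarBaseMap ins = {{"X", {x}}, {"Y", {y}}};
//   imperative::NameVarBaseMap outs = {{"Out", {out}}};
//   tracer->TraceOp("elementwise_add", ins, outs, framework::AttributeMap());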
void Tracer::TraceOp(const std::string& type,
                     const NameVarBaseMap& ins,
                     const NameVarBaseMap& outs,
                     framework::AttributeMap attrs,
                     const std::map<std::string, std::string>& inplace_map) {
  TraceOp<VarBase>(type,
                   ins,
                   outs,
                   std::move(attrs),
                   expected_place_,
                   has_grad_,
                   inplace_map);
}

void Tracer::TraceOp(const std::string& type,
                     const NameTensorMap& ins,
                     const NameTensorMap& outs,
                     paddle::framework::AttributeMap& attrs,
                     const paddle::platform::Place& place,
                     paddle::framework::AttributeMap* default_attrs,
                     bool use_default_attr_map,
                     const std::map<std::string, std::string>& inplace_map) {
  VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: "
          << use_default_attr_map;
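  // With FLAGS_use_stride_kernel, an inplace input that is not contiguous
  // cannot be written to directly: the op writes into a freshly created
  // temporary output instead, and need_backup_inputs2outputs remembers which
  // results must be copied back (via TransStride) into the original strided
  // tensors afterwards.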
  std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;
  if (FLAGS_use_stride_kernel) {
    for (auto& iter : inplace_map) {
      auto inputs_iter = ins.find(iter.first);
      for (size_t i = 0; i < inputs_iter->second.size(); i++) {
        auto var = inputs_iter->second[i]->MutableVar();
        if (var->IsType<phi::DenseTensor>()) {
          auto dense_tensor = var->GetMutable<phi::DenseTensor>();
          if (!dense_tensor->meta().is_contiguous()) {
            NameTensorMap* tmp_out = const_cast<NameTensorMap*>(&outs);
            auto outputs_iter = tmp_out->find(iter.second);
            outputs_iter->second[i] = std::make_shared<egr::EagerVariable>(
                egr::Controller::Instance().GenerateUniqueName());
            need_backup_inputs2outputs[dense_tensor] =
                outputs_iter->second[i]
                    ->MutableVar()
                    ->GetMutable<phi::DenseTensor>();
          }
        }
      }
    }

    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    place,
                                    false,
                                    {},
                                    default_attrs,
                                    use_default_attr_map);

    auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
    for (auto& iter : need_backup_inputs2outputs) {
      paddle::experimental::TransStride(dev_ctx, iter.second, iter.first);
    }
  } else {
    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    place,
                                    false,
                                    inplace_map,
                                    default_attrs,
                                    use_default_attr_map);
  }
}

void Tracer::TraceOp(const std::string& type,
                     const NameTensorMap& ins,
                     const NameTensorMap& outs,
                     paddle::framework::AttributeMap attrs) {
  VLOG(6) << "Running On Eager TraceOp(4 agrs): ";
  TraceOpImpl<egr::EagerVariable>(
      type, ins, outs, attrs, expected_place_, false, {}, nullptr, true);
}

void Tracer::TraceOp(const std::string& type,
                     const NameTensorMap& ins,
                     const NameTensorMap& outs,
                     paddle::framework::AttributeMap& attrs,
                     const std::map<std::string, std::string>& inplace_map) {
  VLOG(6) << "Running On Eager TraceOp(less): ";

  std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;

  if (FLAGS_use_stride_kernel) {
    for (auto& iter : inplace_map) {
      auto inputs_iter = ins.find(iter.first);
      for (size_t i = 0; i < inputs_iter->second.size(); i++) {
        auto var = inputs_iter->second[i]->MutableVar();
        if (var->IsType<phi::DenseTensor>()) {
          auto dense_tensor = var->GetMutable<phi::DenseTensor>();
          if (!dense_tensor->meta().is_contiguous()) {
            NameTensorMap* tmp_out = const_cast<NameTensorMap*>(&outs);
            auto outputs_iter = tmp_out->find(iter.second);
            outputs_iter->second[i] = std::make_shared<egr::EagerVariable>(
                egr::Controller::Instance().GenerateUniqueName());
            need_backup_inputs2outputs[dense_tensor] =
                outputs_iter->second[i]
                    ->MutableVar()
                    ->GetMutable<phi::DenseTensor>();
          }
        }
      }
    }
    // NOTE: mirror the stride handling in the overload above: run the op
    // against the temporary contiguous outputs, then copy the results back
    // into the original strided tensors.
    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    expected_place_,
                                    false,
                                    {},
                                    nullptr,
                                    true);

    auto dev_ctx =
        paddle::platform::DeviceContextPool::Instance().Get(expected_place_);
    for (auto& iter : need_backup_inputs2outputs) {
      paddle::experimental::TransStride(dev_ctx, iter.second, iter.first);
    }
  } else {
    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    expected_place_,
                                    false,
                                    inplace_map,
                                    nullptr,
                                    true);
  }
}

void Tracer::SetExpectedPlace(platform::Place place) {
  expected_place_ = place;
}

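// Backward is required iff backward tracing is enabled and at least one input
// has stop_gradient == false; in that case the setting is propagated to the
// outputs via PassStopGradient.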
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
                                 const NameVarBaseMap& outs,
                                 bool trace_backward) {
  if (!trace_backward) return false;

  for (const auto& name_pair : ins) {
    for (const auto& var_base : name_pair.second) {
      if (!var_base->OverridedStopGradient()) {
        VLOG(6) << "Find out input: " << var_base->Name()
                << "'s GeneratedGrad is True";
        PassStopGradient(outs, var_base->OverridedStopGradient());
        return true;
      }
    }
  }
  return false;
}

bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,
                                 const NameTensorMap& outs,
                                 bool trace_backward) {
  return false;
}

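// Query the phi kernel signature an op would use, by building a temporary
// operator and a CPU execution context; only OperatorWithKernel types are
// supported.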
phi::KernelSignature Tracer::GetExpectedKernelSignature(
    const std::string& type,
    const NameTensorMap& ins,
    const NameTensorMap& outs,
    framework::AttributeMap attrs) const {
  auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
  framework::RuntimeContext ctx({}, {});
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(phi::CPUPlace());
  const auto& op_info = op->Info();
  auto* attr_checker = op_info.Checker();
  if (attr_checker) {
    attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
  }
  static paddle::framework::AttributeMap empty_attrs_map = {};
  const paddle::framework::AttributeMap& default_attrs =
      attr_checker == nullptr ? empty_attrs_map
                              : attr_checker->GetDefaultAttrMap();
  auto dygraph_exe_ctx =
      imperative::DygraphExecutionContext<egr::EagerVariable>(
          *op,
          framework::Scope(),
          *dev_ctx,
          ctx,
          ins,
          outs,
          attrs,
          default_attrs);
  auto* opbase_with_kernel =
      dynamic_cast<framework::OperatorWithKernel*>(op.get());
  PADDLE_ENFORCE_NE(opbase_with_kernel,
                    nullptr,
                    platform::errors::InvalidArgument(
                        "This op type:`%s` is not a OperatorWithKernel, only "
                        "OperatorWithKernel can get KernelSignature",
                        type));
  if (phi::KernelFactory::Instance().HasStructuredKernel(type)) {
    return phi::KernelSignature(op->Type().c_str());
  } else {
    return phi::KernelSignature(std::move(
        opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
  }
}

}  // namespace imperative
}  // namespace paddle