// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/imperative/tracer.h"

#include <map>
#include <set>
#include <unordered_set>
#include <utility>

#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/imperative/amp_auto_cast.h"
#include "paddle/fluid/imperative/execution_context.h"
#include "paddle/fluid/imperative/layout_autotune.h"
#include "paddle/fluid/imperative/op_base.h"
#include "paddle/fluid/operators/ops_extra_info.h"
#include "paddle/fluid/platform/denormal.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
#include "paddle/fluid/string/string_helper.h"
#include "paddle/phi/api/lib/api_gen_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/flags.h"

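// Runtime flags consumed below: FLAGS_use_mkldnn toggles oneDNN kernels,
// FLAGS_tracer_mkldnn_ops_on / FLAGS_tracer_mkldnn_ops_off hold comma-separated
// op-type lists that narrow that choice, and FLAGS_use_stride_kernel enables
// the strided-kernel path in the eager TraceOp overloads.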
PHI_DECLARE_bool(use_mkldnn);
PHI_DECLARE_string(tracer_mkldnn_ops_on);
PHI_DECLARE_string(tracer_mkldnn_ops_off);
DECLARE_bool(use_stride_kernel);

namespace paddle {
namespace imperative {
thread_local std::string Tracer::python_stack_ = "";

thread_local bool Tracer::enable_program_desc_tracing_ = false;

thread_local bool Tracer::has_grad_ = true;

thread_local bool Tracer::use_promote_ = true;

thread_local bool Tracer::use_layout_autotune_ = false;

thread_local AmpLevel Tracer::amp_level_ = AmpLevel::O0;

thread_local phi::DataType Tracer::amp_dtype_ = phi::DataType::FLOAT32;

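// Process-wide tracer shared by the dygraph API; the per-thread knobs above
// are thread_local, while the current tracer itself is global. A typical
// setup (sketch) is:
//   auto tracer = std::make_shared<imperative::Tracer>();
//   imperative::SetCurrentTracer(tracer);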
static std::shared_ptr<Tracer> g_current_tracer(nullptr);

const std::shared_ptr<Tracer>& GetCurrentTracer() { return g_current_tracer; }

void SetCurrentTracer(const std::shared_ptr<Tracer>& tracer) {
  g_current_tracer = tracer;
  VLOG(6) << "Set current tracer: " << g_current_tracer;
}

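// Sets OverridedStopGradient on every non-null output variable. Called from
// ComputeRequiredGrad once some input is found to require gradients.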
void PassStopGradient(const NameVarBaseMap& outs, bool generate_grad) {
  for (const auto& pair : outs) {
    for (const auto& var : pair.second) {
      // NOTE(zhiqiu): this happens when None outputs are passed from the
      // Python side. For example,
      // fake_quantize_dequantize_moving_average_abs_max may pass a None
      // OutAccum in eval mode.
      // It could be refined by generating several different pybind interfaces
      // for one operator, each with a different function signature.
      if (var == nullptr) {
        VLOG(4) << pair.first << " is NULL";
        continue;
      }
      VLOG(6) << "Set output: " << var->Name() << "'s OverridedStopGradient as "
81
              << generate_grad;
82
      var->InnerSetOverridedStopGradient(generate_grad);
83 84 85 86
    }
  }
}

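// Keeps `var` alive until the kernels already launched on the current stream
// of `place` finish, by registering an empty callback that captures it in the
// garbage collector (see the notes inside).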
void IncreaseVarbaseReferenceCountUntilCopyComplete(
    const std::shared_ptr<imperative::VarBase>& var,
    const platform::Place& place) {
  // Note(zhiqiu): Follow the logic of TensorCopy to determine the place that we
  // need to add callback, see tensor_utils.cc:245
  auto place_ = platform::is_gpu_place(place) ? place : var->Place();

  auto tracer = imperative::GetCurrentTracer();
  auto gc = tracer->MutableGarbageCollectorIfNotExists(place_);

  // Note(zhiqiu): This is an empty callback whose only purpose is to
  // "reference" var, so var will not be destructed until the kernels launched
  // on the current stream of the given place have finished.
  auto callback = [var, place_]() {
    VLOG(4) << "Run callback of var:" << var->Name() << " at place " << place_;
  };

  gc->DirectClearCallback(callback);
}

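// Lazily creates (and caches) a GarbageCollector for `place`; each supported
// device family gets its own collector type, and unsupported builds throw.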
paddle::framework::GarbageCollector* Tracer::MutableGarbageCollectorIfNotExists(
    const platform::Place& place) {
  // if not exists, create a new GarbageCollector at given place
  if (gcs_.count(place) == 0) {
    std::unique_ptr<framework::GarbageCollector> gc;
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      gc = std::make_unique<framework::DefaultStreamGarbageCollector>(place, 0);

      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use CUDA device since it's not compiled with CUDA. "
          "Please recompile or reinstall Paddle with GPU support."));
#endif
    } else if (platform::is_cuda_pinned_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      gc = std::make_unique<framework::CUDAPinnedGarbageCollector>(place, 0);

      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use CUDAPinned device since it's not compiled with "
          "CUDA. "
          "Please recompile or reinstall Paddle with GPU support."));
#endif
    } else if (platform::is_xpu_place(place)) {
#if defined(PADDLE_WITH_XPU)
      gc = std::make_unique<framework::XPUGarbageCollector>(place, 0);
      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use XPU device since it's not compiled with XPU. "
          "Please recompile or reinstall Paddle with XPU support."));
#endif
    } else if (platform::is_cpu_place(place)) {
      gc = std::make_unique<framework::CPUGarbageCollector>(place, 0);
      VLOG(10) << "Created GarbageCollector at " << place;
    } else if (platform::is_ipu_place(place)) {
#if defined(PADDLE_WITH_IPU)
      gc = std::make_unique<framework::IPUGarbageCollector>(place, 0);
      VLOG(10) << "Created GarbageCollector at " << place;
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use IPU device since it's not compiled with IPU. "
          "Please recompile or reinstall Paddle with IPU support."));
#endif
    } else if (platform::is_custom_place(place)) {
#if defined(PADDLE_WITH_CUSTOM_DEVICE)
      if (framework::IsFastEagerDeletionModeEnabled()) {
        gc =
            std::make_unique<framework::CustomDeviceUnsafeFastGarbageCollector>(
                place, 0);
        VLOG(10) << "Created UnsafeFastGarbageCollector at " << place;
      } else {
        gc = std::make_unique<framework::CustomDefaultStreamGarbageCollector>(
            place, 0);
        VLOG(10) << "Created GarbageCollector at " << place;
      }
#else
      PADDLE_THROW(platform::errors::PermissionDenied(
          "Paddle can't use CustomDevice since it's not compiled with "
          "CustomDevice. "
          "Please recompile or reinstall Paddle with CustomDevice "
          "support."));
#endif
    } else {
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "Unsupported place for garbage collection"));
    }
    gcs_.emplace(place, std::move(gc));
  }

  return gcs_.at(place).get();
}

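// TraceOp runs one operator eagerly: it resolves attributes, applies AMP
// casting and layout autotuning to the inputs, dispatches the kernel on the
// target place, and records a grad node when backward is required. This thin
// wrapper simply forwards to TraceOpImpl.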
template <typename VarType>
void Tracer::TraceOp(const std::string& type,
                     const NameVarMap<VarType>& ins,
                     const NameVarMap<VarType>& outs,
                     framework::AttributeMap attrs,
                     const platform::Place& place,
                     bool trace_backward,
                     const std::map<std::string, std::string>& inplace_map,
                     paddle::framework::AttributeMap* passed_default_attrs_,
                     bool use_default_attr_map) {
  TraceOpImpl<VarType>(type,
                       ins,
                       outs,
                       attrs,
                       place,
                       trace_backward,
                       inplace_map,
                       passed_default_attrs_,
                       use_default_attr_map);
}

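// TraceOpImpl contains the real tracing logic: attribute checking, AMP input
// casting, layout autotuning, device selection, kernel dispatch via
// OpBase::Run, optional ProgramDesc tracing, and grad-node creation.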
template <typename VarType>
void Tracer::TraceOpImpl(const std::string& type,
                         const NameVarMap<VarType>& ins,
                         const NameVarMap<VarType>& outs,
                         framework::AttributeMap& attrs,
                         const platform::Place& place,
                         bool trace_backward,
                         const std::map<std::string, std::string>& inplace_map,
                         paddle::framework::AttributeMap* passed_default_attrs_,
                         bool use_default_attr_map) {
  platform::RecordEvent op_type_record_event(
      type, platform::TracerEventType::Operator, 1);
  platform::ScopedFlushDenormal flush;
  VLOG(4) << "Trace Op: " << type;
  if (FLAGS_use_mkldnn) {
    // If both lists are empty, all ops are enabled (the default for
    // FLAGS_use_mkldnn=1). If the ops_on list is not empty, only ops from
    // that list are enabled. Note that the check below is a substring match
    // over the comma-separated flag value.
    if (!FLAGS_tracer_mkldnn_ops_on.empty()) {
      auto is_on = FLAGS_tracer_mkldnn_ops_on.find(type) != std::string::npos;
      attrs["use_mkldnn"] = is_on;
    } else {
      // If the ops_on list is empty, all ops are enabled except the types in
      // the off list.
      auto is_off = FLAGS_tracer_mkldnn_ops_off.find(type) != std::string::npos;
      attrs["use_mkldnn"] = !is_off;
    }
  }

  auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
  const auto& op_info = op->Info();
  auto* attr_checker = op_info.Checker();
  if (attr_checker) {
    attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
  }

  const auto& extra_attr_checkers =
      operators::ExtraInfoUtils::Instance().GetExtraAttrsChecker(type);
  for (const auto& checker : extra_attr_checkers) {
    checker(&attrs, true);
  }

  static paddle::framework::AttributeMap empty_attrs_map = {};
  const paddle::framework::AttributeMap& default_attrs =
      attr_checker == nullptr ? empty_attrs_map
                              : attr_checker->GetDefaultAttrMap();

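  // AMP: depending on the O1/O2 level and target dtype, the inputs may be
  // replaced by auto-cast copies before dispatch; ins_amp stays null when no
  // cast is needed.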
  std::unique_ptr<NameVarMap<VarType>> ins_amp = nullptr;
  if (amp_level_ == AmpLevel::O1) {
    if (amp_dtype_ == phi::DataType::FLOAT16) {
      VLOG(5) << "Float16 Auto Mixed Precision O1 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          AutoCastInputs<VarType>(type, ins));
    } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
      VLOG(5) << "BFloat16 Auto Mixed Precision O1 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          AutoCastBF16Inputs<VarType>(type, ins));
    }
  } else if (amp_level_ == AmpLevel::O2) {
    if (amp_dtype_ == phi::DataType::FLOAT16) {
      VLOG(5) << "Float16 Auto Mixed Precision O2 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          CastPureFp16Inputs<VarType>(type, ins));
    } else if (amp_dtype_ == phi::DataType::BFLOAT16) {
      VLOG(5) << "BFloat16 Auto Mixed Precision O2 run operator: " << type;
      ins_amp = std::make_unique<NameVarMap<VarType>>(
          CastPureBf16Inputs<VarType>(type, ins));
    }
  }

  if (platform::is_gpu_place(place)) {
    const auto& new_tmp = ins_amp == nullptr ? ins : *ins_amp;
    const auto& tracer = imperative::GetCurrentTracer();
    ins_amp = std::make_unique<NameVarMap<VarType>>(
        imperative::AutoTuneLayout<VarType>(
            type, new_tmp, outs, &attrs, tracer));
  }

  const auto& new_ins = ins_amp == nullptr ? ins : *ins_amp;

  try {
    if (platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
      platform::SetDeviceId(place.device);
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "PaddlePaddle should be compiled with GPU support "
          "if CUDAPlace is used."));
#endif
    } else if (platform::is_xpu_place(place)) {
#ifdef PADDLE_WITH_XPU
      platform::SetXPUDeviceId(place.device);
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "PaddlePaddle should be compiled with XPU support "
          "if XPUPlace is used."));
#endif
    } else if (platform::is_custom_place(place)) {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
      phi::DeviceManager::SetDevice(place);
#else
      PADDLE_THROW(platform::errors::PreconditionNotMet(
          "PaddlePaddle should be compiled with CustomDevice support "
          "if CustomPlace is used."));
#endif
    }

    if (!use_default_attr_map) {
      PADDLE_ENFORCE_NOT_NULL(passed_default_attrs_,
                              paddle::platform::errors::PermissionDenied(
                                  "Detected default_attrs = nullptr."));
      VLOG(6) << "Use passed in default attrs";
      OpBase::Run(*op, new_ins, outs, attrs, (*passed_default_attrs_), place);
    } else {
      VLOG(6) << "Use Checker's default attrs";
      if (passed_default_attrs_) {
        // TODO(jiabin): Update this without copy
        *passed_default_attrs_ = default_attrs;
      }
      OpBase::Run(*op, new_ins, outs, attrs, default_attrs, place);
    }
  } catch (platform::EnforceNotMet& exception) {
    framework::AppendErrorOpHint(type, &exception);
    throw std::move(exception);
  } catch (std::exception& ex) {
    PADDLE_THROW(
        platform::errors::Fatal("Operator %s raises an %s exception.\n"
                                "The exception content is\n:%s.",
                                type,
                                platform::demangle(typeid(ex).name()),
                                ex.what()));
  } catch (...) {
    // NOTE: this branch represents a very serious, low-probability bug; the
    // exception content cannot be recovered here.
    PADDLE_THROW(platform::errors::Fatal(
        "Operator %s raises an unknown exception.", type));
  }

  if (enable_program_desc_tracing_) {
    VLOG(5) << "Trace op " << type << " into ProgramDesc";
    program_desc_tracer_->InsertOp(type, new_ins, outs, attrs);
  }

  {
    platform::RecordEvent node_creation_record_event(
347
        "grad_node_creation", platform::TracerEventType::OperatorInner, 1);
348 349 350

    if (ComputeRequiredGrad(new_ins, outs, trace_backward)) {
      PADDLE_ENFORCE_EQ(
          passed_default_attrs_,
          nullptr,
          paddle::platform::errors::PermissionDenied(
              "Expected passed_default_attrs_ to be nullptr while "
              "use_default_attr_map is true, but got a non-null "
              "passed_default_attrs_. Please check your usage of trace_op."));
      CreateGradOpNode(
          *op, new_ins, outs, attrs, default_attrs, place, inplace_map);
    } else {
      VLOG(3) << "No Grad to track for Op: " << type;
    }
    VLOG(6) << "Finish Trace Op: " << type;
  }
}

template void Tracer::TraceOp<VarBase>(
    const std::string& type,
    const NameVarMap<VarBase>& ins,
    const NameVarMap<VarBase>& outs,
    framework::AttributeMap attrs,
    const platform::Place& place,
    bool trace_backward,
    const std::map<std::string, std::string>& inplace_map,
    paddle::framework::AttributeMap* default_attrs,
    bool use_default_attr_map);

template void Tracer::TraceOp<egr::EagerVariable>(
    const std::string& type,
    const NameVarMap<egr::EagerVariable>& ins,
    const NameVarMap<egr::EagerVariable>& outs,
    framework::AttributeMap attrs,
    const platform::Place& place,
    bool trace_backward,
    const std::map<std::string, std::string>& inplace_map_,
    paddle::framework::AttributeMap* default_attrs,
    bool use_default_attr_map);

void Tracer::TraceOp(const std::string& type,
                     const NameVarBaseMap& ins,
                     const NameVarBaseMap& outs,
                     framework::AttributeMap attrs,
                     const std::map<std::string, std::string>& inplace_map) {
  TraceOp<VarBase>(type,
                   ins,
                   outs,
                   std::move(attrs),
                   expected_place_,
                   has_grad_,
                   inplace_map);
}

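// Eager-mode TraceOp. When FLAGS_use_stride_kernel is on, non-contiguous
// inplace inputs are redirected to temporary outputs before dispatch; after
// the op runs, the strided results are copied back so the inplace contract
// still holds for the caller.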
402 403
void Tracer::TraceOp(const std::string& type,
                     const NameTensorMap& ins,
                     const NameTensorMap& outs,
                     paddle::framework::AttributeMap& attrs,
                     const paddle::platform::Place& place,
                     paddle::framework::AttributeMap* default_attrs,
                     bool use_default_attr_map,
                     const std::map<std::string, std::string>& inplace_map) {
  VLOG(6) << "Running On Eager TraceOp with use_default_attr_map: "
          << use_default_attr_map;
  std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;
  std::map<phi::DenseTensor*, std::shared_ptr<phi::Allocation>>
      need_backup_inputs2holder;
  std::map<phi::DenseTensor*, phi::DDim> need_backup_inputs2strides;
  if (FLAGS_use_stride_kernel) {
    for (auto& iter : inplace_map) {
      auto inputs_iter = ins.find(iter.first);
      for (size_t i = 0; i < inputs_iter->second.size(); i++) {
        auto var = inputs_iter->second[i]->MutableVar();
        if (var->IsType<phi::DenseTensor>()) {
          auto dense_tensor = var->GetMutable<phi::DenseTensor>();
          if (!dense_tensor->meta().is_contiguous()) {
            NameTensorMap* tmp_out = const_cast<NameTensorMap*>(&outs);
            auto outputs_iter = tmp_out->find(iter.second);
            outputs_iter->second[i] = std::make_shared<egr::EagerVariable>(
                egr::Controller::Instance().GenerateUniqueName());
            need_backup_inputs2outputs[dense_tensor] =
                outputs_iter->second[i]
                    ->MutableVar()
                    ->GetMutable<phi::DenseTensor>();
            need_backup_inputs2holder[dense_tensor] = dense_tensor->Holder();
            need_backup_inputs2strides[dense_tensor] = dense_tensor->strides();
          }
        }
      }
    }
    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    place,
                                    false,
                                    {},
                                    default_attrs,
                                    use_default_attr_map);

    auto dev_ctx = paddle::platform::DeviceContextPool::Instance().Get(place);
    for (auto& iter : need_backup_inputs2outputs) {
      iter.first->ResetHolder(need_backup_inputs2holder[iter.first]);
      iter.first->set_strides(need_backup_inputs2strides[iter.first]);
      paddle::experimental::TransStrideLegacy(dev_ctx, iter.second, iter.first);
      iter.second->ResetHolder(need_backup_inputs2holder[iter.first]);
      iter.second->set_strides(need_backup_inputs2strides[iter.first]);
    }
  } else {
    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    place,
                                    false,
                                    inplace_map,
                                    default_attrs,
                                    use_default_attr_map);
  }
}

void Tracer::TraceOp(const std::string& type,
                     const NameTensorMap& ins,
                     const NameTensorMap& outs,
                     paddle::framework::AttributeMap attrs) {
  VLOG(6) << "Running On Eager TraceOp(4 agrs): ";
  TraceOpImpl<egr::EagerVariable>(
      type, ins, outs, attrs, expected_place_, false, {}, nullptr, true);
}

void Tracer::TraceOp(const std::string& type,
                     const NameTensorMap& ins,
                     const NameTensorMap& outs,
                     paddle::framework::AttributeMap& attrs,
                     const std::map<std::string, std::string>& inplace_map) {
  VLOG(6) << "Running On Eager TraceOp(less): ";

  std::map<phi::DenseTensor*, phi::DenseTensor*> need_backup_inputs2outputs;

  if (FLAGS_use_stride_kernel) {
    for (auto& iter : inplace_map) {
      auto inputs_iter = ins.find(iter.first);
      for (size_t i = 0; i < inputs_iter->second.size(); i++) {
        auto var = inputs_iter->second[i]->MutableVar();
        if (var->IsType<phi::DenseTensor>()) {
          auto dense_tensor = var->GetMutable<phi::DenseTensor>();
          if (!dense_tensor->meta().is_contiguous()) {
            NameTensorMap* tmp_out = const_cast<NameTensorMap*>(&outs);
            auto outputs_iter = tmp_out->find(iter.second);
            outputs_iter->second[i] = std::make_shared<egr::EagerVariable>(
                egr::Controller::Instance().GenerateUniqueName());
            need_backup_inputs2outputs[dense_tensor] =
                outputs_iter->second[i]
                    ->MutableVar()
                    ->GetMutable<phi::DenseTensor>();
          }
        }
      }
    }
    // Dispatch with an empty inplace map, then copy the strided results back
    // into the original non-contiguous inputs (a sketch mirroring the
    // FLAGS_use_stride_kernel path of the overload above).
    TraceOpImpl<egr::EagerVariable>(
        type, ins, outs, attrs, expected_place_, false, {}, nullptr, true);

    auto dev_ctx =
        paddle::platform::DeviceContextPool::Instance().Get(expected_place_);
    for (auto& iter : need_backup_inputs2outputs) {
      paddle::experimental::TransStrideLegacy(dev_ctx, iter.second, iter.first);
    }
  } else {
    TraceOpImpl<egr::EagerVariable>(type,
                                    ins,
                                    outs,
                                    attrs,
                                    expected_place_,
                                    false,
                                    inplace_map,
                                    nullptr,
                                    true);
  }
}

void Tracer::SetExpectedPlace(platform::Place place) {
  expected_place_ = place;
}

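// Backward is required iff tracing is on and at least one input does not have
// stop-gradient set; in that case the outputs inherit the setting via
// PassStopGradient.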
bool Tracer::ComputeRequiredGrad(const NameVarBaseMap& ins,
                                 const NameVarBaseMap& outs,
                                 bool trace_backward) {
  if (!trace_backward) return false;

  for (const auto& name_pair : ins) {
    for (const auto& var_base : name_pair.second) {
      if (!var_base->OverridedStopGradient()) {
        VLOG(6) << "Find out input: " << var_base->Name()
                << "'s GeneratedGrad is True";
        PassStopGradient(outs, var_base->OverridedStopGradient());
        return true;
      }
    }
  }
  return false;
}

bool Tracer::ComputeRequiredGrad(const NameTensorMap& ins,
                                 const NameTensorMap& outs,
                                 bool trace_backward) {
  return false;
}

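// Builds a throwaway op and a CPU execution context just to query which phi
// kernel signature the op would resolve to; no kernel is actually run.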
phi::KernelSignature Tracer::GetExpectedKernelSignature(
    const std::string& type,
    const NameTensorMap& ins,
    const NameTensorMap& outs,
    framework::AttributeMap attrs) const {
  auto op = framework::OpRegistry::CreateOp(type, {}, {}, {}, false);
  framework::RuntimeContext ctx({}, {});
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(phi::CPUPlace());
  const auto& op_info = op->Info();
  auto* attr_checker = op_info.Checker();
  if (attr_checker) {
    attr_checker->Check(&attrs, true, /*only_check_exist_value=*/true);
  }
  static paddle::framework::AttributeMap empty_attrs_map = {};
  const paddle::framework::AttributeMap& default_attrs =
      attr_checker == nullptr ? empty_attrs_map
                              : attr_checker->GetDefaultAttrMap();
  auto dygraph_exe_ctx =
      imperative::DygraphExecutionContext<egr::EagerVariable>(
          *op,
          framework::Scope(),
          *dev_ctx,
          ctx,
          ins,
          outs,
          attrs,
          default_attrs);
  auto* opbase_with_kernel =
      dynamic_cast<framework::OperatorWithKernel*>(op.get());
  PADDLE_ENFORCE_NE(opbase_with_kernel,
                    nullptr,
                    platform::errors::InvalidArgument(
                        "This op type:`%s` is not an OperatorWithKernel; only "
                        "an OperatorWithKernel can get a KernelSignature",
                        type));
  if (phi::KernelFactory::Instance().HasStructuredKernel(type)) {
    return phi::KernelSignature(op->Type().c_str());
  } else {
    return phi::KernelSignature(std::move(
        opbase_with_kernel->GetExpectedPhiKernelArgs(dygraph_exe_ctx)));
  }
}

}  // namespace imperative
}  // namespace paddle