/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#define GLOG_NO_ABBREVIATED_SEVERITIES
#define GOOGLE_GLOG_DLL_DECL

#include <gflags/gflags.h>
#include <glog/logging.h>

#include <algorithm>

#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false,
            "Checking whether operator produce NAN/INF or not. It will be "
            "extremely slow so please use this flag wisely.");

namespace paddle {
namespace framework {

// Combine two hash values to a single hash.
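// The 0x9e3779b9 constant and shift mixing follow the well-known
// boost::hash_combine recipe; CombineHash is used below to build the
// transfer-scope cache key under PADDLE_ON_INFERENCE.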
inline size_t CombineHash(size_t seed, size_t a) {
  return (seed ^ a) + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

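// Preferred (place, library) combinations for kernel selection, ordered from
// most to least preferred. Not used for automatic fallback yet; see the TODO
// in OperatorWithKernel::RunImpl().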
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN),
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain),
    std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN),
    std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};

proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
  if (var->IsType<framework::LoDTensor>()) {
    return framework::ToDataType(var->Get<framework::LoDTensor>().type());
  } else if (var->IsType<framework::SelectedRows>()) {
    return framework::ToDataType(
        var->Get<framework::SelectedRows>().value().type());
  } else {
    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
  }
}

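// The static helpers below (GetDims, VarInited, GetDtype, GetRowSize, GetLoD)
// only serve DebugStringEx(); they return placeholder values instead of
// throwing when a variable is missing or uninitialized.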
static DDim GetDims(const Scope& scope, const std::string& name,
                    bool get_actual_dim = false) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return DDim({-1});
  }

  if (var->IsType<LoDTensor>()) {
    const LoDTensor& tensor = var->Get<LoDTensor>();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return DDim({-1});
    }
    return tensor.dims();
  } else if (var->IsType<SelectedRows>()) {
    if (get_actual_dim) {
      return var->Get<SelectedRows>().value().dims();
    } else {
      return var->Get<SelectedRows>().GetCompleteDims();
    }
  } else {
    return DDim({-1});
  }
}

static bool VarInited(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) return false;
  return var->IsInitialized();
}

static std::string GetDtype(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return "";
  }

  if (var->IsType<LoDTensor>()) {
    const LoDTensor& tensor = var->Get<LoDTensor>();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return "";
    }
    return DataTypeToString(ToDataType(tensor.type()));
  } else if (var->IsType<SelectedRows>()) {
    auto tensor = var->Get<SelectedRows>().value();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return "uninited";
    } else {
      return DataTypeToString(ToDataType(tensor.type()));
    }
  } else {
    return "";
  }
}

static int GetRowSize(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return -1;
  }

  if (var->IsType<SelectedRows>()) {
    return var->Get<SelectedRows>().rows().size();
  }

  return -1;
}

static LoD GetLoD(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  auto default_lod = LoD({{}});

  if (var == nullptr) {
    return default_lod;
  }

  if (var->IsType<LoDTensor>()) {
    const LoDTensor& tensor = var->Get<LoDTensor>();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return default_lod;
    }
    return tensor.lod();
  } else {
    return default_lod;
  }
}

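// Run() selects the device for GPU places, optionally wraps execution in a
// profiler RecordEvent, and then delegates to the type-specific RunImpl().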
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
  VLOG(40) << place << " " << DebugStringEx(&scope);
  if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
    PADDLE_THROW("Cannot run operator on place %s", place);
#else
    auto dev_id = boost::get<platform::CUDAPlace>(place).device;
    platform::SetDeviceId(dev_id);
#endif
  }

// The profiler has a process-wide mutex, which causes a serious performance
// issue in concurrent scenarios. Here we use an `if` to skip it when profiling
// is disabled. Please do not remove the `if`; ask @Superjomn if there is any
// concern.
#ifndef _WIN32
  if (platform::IsProfileEnabled()) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::RecordEvent record_event(Type(), pool.Get(place));
    RunImpl(scope, place);
  } else
#endif
  {
    RunImpl(scope, place);
  }
  VLOG(30) << place << " " << DebugStringEx(&scope);
}

bool OperatorBase::HasInputs(const std::string& name) const {
  if (inputs_.find(name) != inputs_.end()) {
    return true;
  } else {
    return false;
  }
}

std::string OperatorBase::Input(const std::string& name) const {
  auto& ins = Inputs(name);
  PADDLE_ENFORCE_LE(ins.size(), 1UL,
                    "Operator %s's input %s should contain only one variable.",
                    type_, name);
  return ins.empty() ? kEmptyVarName : ins[0];
}

const std::vector<std::string>& OperatorBase::Inputs(
    const std::string& name) const {
  auto it = inputs_.find(name);
  PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.",
                 type_, name);
  return it->second;
}

bool OperatorBase::HasOutputs(const std::string& name) const {
  if (outputs_.find(name) != outputs_.end()) {
    return true;
  } else {
    return false;
  }
}

std::string OperatorBase::Output(const std::string& name) const {
  auto& outs = Outputs(name);
  PADDLE_ENFORCE_LE(outs.size(), 1UL,
                    "Operator %s's output %s should contain only one variable.",
                    type_, name);
  return outs.empty() ? kEmptyVarName : outs[0];
}

const std::vector<std::string>& OperatorBase::Outputs(
    const std::string& name) const {
  auto it = outputs_.find(name);
  PADDLE_ENFORCE(it != outputs_.end(),
                 "Operator %s does not have an output called %s.", type_, name);
  return it->second;
}

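// Builds a human-readable description of this operator. When a scope is given,
// every input/output variable is annotated with its row size (for
// SelectedRows), dtype, dims and LoD, or with "[uninited]" if it has no value.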
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
  std::stringstream ss;
  ss << "Op(" << type_ << "), inputs:{";
  for (auto it = inputs_.begin(); it != inputs_.end();) {
    auto& input = *it;
    ss << input.first << "[";
    for (size_t i = 0; i < input.second.size(); ++i) {
      auto var_name = input.second[i];
      ss << var_name;
      if (scope) {
        if (!VarInited(*scope, var_name)) {
          ss << "[uninited]";
        } else {
          int row_size = GetRowSize(*scope, var_name);
          if (row_size >= 0) {
            ss << "[row_size=" << row_size << "]";
          }
          std::string dtype = GetDtype(*scope, var_name);
          ss << ":" << dtype;
          ss << "[" << GetDims(*scope, var_name, true) << "]";
          ss << "(" << GetLoD(*scope, var_name) << ")";
        }
      }
      if (i != input.second.size() - 1) {
        ss << ", ";
      }
    }
    ss << "]";
    ++it;
    if (it != inputs_.end()) {
      ss << ", ";
    }
  }
  ss << "}, outputs:{";
  for (auto it = outputs_.begin(); it != outputs_.end();) {
    auto& output = *it;
    ss << output.first << "[";
    for (size_t i = 0; i < output.second.size(); ++i) {
      auto var_name = output.second[i];
      ss << var_name;
      if (scope) {
        if (!VarInited(*scope, var_name)) {
          ss << "[uninited]";
        } else {
          int row_size = GetRowSize(*scope, output.second[i]);
          if (row_size >= 0) {
            ss << "[row_size=" << row_size << "]";
          }
          std::string dtype = GetDtype(*scope, output.second[i]);
          ss << ":" << dtype;
          ss << "[" << GetDims(*scope, var_name, true) << "]";
          ss << "(" << GetLoD(*scope, var_name) << ")";
        }
      }
      if (i != output.second.size() - 1) {
        ss << ", ";
      }
    }
    ss << "]";
    ++it;
    if (it != outputs_.end()) {
      ss << ", ";
    }
  }
  ss << "}.";
  return ss.str();
}

OperatorBase::OperatorBase(const std::string& type,
                           const VariableNameMap& inputs,
                           const VariableNameMap& outputs,
                           const AttributeMap& attrs)
    : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
  GenerateTemporaryNames();
  CheckAllInputOutputSet();
}

std::vector<std::string> OperatorBase::InputVars() const {
  std::vector<std::string> ret_val;
  for (auto& o : inputs_) {
    ret_val.reserve(ret_val.size() + o.second.size());
    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
  }
  return ret_val;
}

std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
  std::vector<std::string> ret_val;
  if (has_intermediate) {
    // push all outputs into ret_val
    for (auto& o : outputs_) {
      ret_val.reserve(ret_val.size() + o.second.size());
      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
    }
    return ret_val;
  }
  auto& info = OpInfoMap::Instance().Get(Type());

  // get all OpProto::Var for outputs
  for (auto& o : info.Proto().outputs()) {
    // ignore all intermediate output
    if (o.intermediate()) continue;
    auto out = outputs_.find(o.name());
    if (out != outputs_.end()) {
      ret_val.reserve(ret_val.size() + out->second.size());
      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
    }
  }
  return ret_val;
}

void OperatorBase::CheckAllInputOutputSet() const {
  auto& info_map = OpInfoMap::Instance();
  auto* op_info = info_map.GetNullable(Type());
  if (op_info == nullptr || op_info->proto_ == nullptr) return;

  for (auto& in : op_info->Proto().inputs()) {
    if (!in.dispensable()) {
      PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
                     "Operator %s's input, %s, is not set", Type(), in.name());
    }
  }

  for (auto& out : op_info->Proto().outputs()) {
    if (!out.dispensable()) {
      PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
                     "Operator %s's output, %s, is not set", Type(),
                     out.name());
    }
  }
}

void OperatorBase::GenerateTemporaryNames() {
  static std::atomic<size_t> gUniqId(0UL);
  for (auto& output : outputs_) {
    for (auto& output_name : output.second) {
      if (output_name == kTempVarName) {
        output_name += type_;
        output_name += "@";
        output_name += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
}

static bool VarIsTensor(const Variable& var) {
  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
}

const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
  if (var.IsType<LoDTensor>()) {
    return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
  } else if (var.IsType<SelectedRows>()) {
    return &(var.Get<SelectedRows>().value());
  } else {
    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
                 var.Type().name());
  }
}

Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return var->GetMutable<LoDTensor>();
  } else if (var->IsType<SelectedRows>()) {
    return var->GetMutable<SelectedRows>()->mutable_value();
  } else {
    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
                 var->Type().name());
  }
}

bool ExecutionContext::HasInput(const std::string& name) const {
  if (!op_.HasInputs(name)) {
    return false;
  }
  auto& ins = Inputs(name);
  size_t length = ins.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Input %s should not have more than one inputs", name);
  auto arg = ins[0];
  auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
  return var != nullptr;
}

bool ExecutionContext::HasOutput(const std::string& name) const {
  if (!op_.HasOutputs(name)) {
    return false;
  }
  auto& outs = Outputs(name);
  size_t length = outs.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Output %s should not have more than one inputs", name);
  auto arg = outs[0];
  auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
  return var != nullptr;
}

template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
  return Input<LoDTensor>(name);
}

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const {
  auto names = op().Inputs(name);
  std::vector<const Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> const Tensor* {
                   auto var = scope_.FindVar(sub_name);
                   if (var == nullptr) return nullptr;
                   PADDLE_ENFORCE(
                       var->IsType<LoDTensor>(),
                       "%s should be LoDTensor, but the received type is %s",
                       sub_name, var->Type().name());
                   return &(var->Get<LoDTensor>());
                 });
  return res;
}

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
  return Output<LoDTensor>(name);
}

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const {
  auto names = op().Outputs(name);
  std::vector<Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> Tensor* {
                   auto var = scope_.FindVar(sub_name);
                   if (var == nullptr) return nullptr;
                   PADDLE_ENFORCE(
                       var->IsType<LoDTensor>(),
                       "%s should be LoDTensor, but the received type is %s",
                       sub_name, var->Type().name());
                   return var->GetMutable<LoDTensor>();
                 });
  return res;
}

bool OpSupportGPU(const std::string& op_type) {
  auto& all_kernels = OperatorWithKernel::AllOpKernels();
  auto it = all_kernels.find(op_type);
  if (it == all_kernels.end()) {
    // All control operators must support GPU
    return true;
  }
  for (auto& kern_pair : it->second) {
    if (platform::is_gpu_place(kern_pair.first.place_)) {
      return true;
    }
  }
  return false;
}

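// An InferShapeContext backed by the runtime Scope: dims, LoDs and variable
// types are read from (and written to) the actual variables instead of the
// compile-time program description.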
class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

  bool HasInput(const std::string& name) const override {
    // has only one input
    const auto& ins = op_.Inputs();
    auto it = ins.find(name);
    if (it == ins.end()) {
      return false;
    }
    const auto& in = it->second;
    if (in.size() == 0 || in[0] == kEmptyVarName) {
      return false;
    }
    PADDLE_ENFORCE_EQ(in.size(), 1UL,
                      "Input %s should not have more than one inputs", name);
    return scope_.FindVar(in[0]) != nullptr;
  }

  bool HasOutput(const std::string& name) const override {
    // has only one output
    const auto& outs = op_.Outputs();
    auto it = outs.find(name);
    if (it == outs.end()) {
      return false;
    }
    const auto& out = it->second;
    if (out.size() == 0 || out[0] == kEmptyVarName) {
      return false;
    }
    PADDLE_ENFORCE_EQ(out.size(), 1UL,
                      "Output %s should not have more than one outputs", name);
    return scope_.FindVar(out[0]) != nullptr;
  }

  bool HasInputs(const std::string& name) const override {
    if (!op_.HasInputs(name)) {
      return false;
    }
    auto inputs = op_.Inputs(name);
    if (inputs.empty()) {
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

  bool HasOutputs(const std::string& name) const override {
    if (!op_.HasOutputs(name)) {
      return false;
    }
    auto outputs = op_.Outputs(name);
    if (outputs.empty()) {
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
    return op_.Inputs(name);
  }

  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
    return op_.Outputs(name);
  }

  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    const std::string& input_n = Inputs(in)[i];
    const std::string& output_n = Outputs(out)[j];

    Variable* in_var = scope_.FindVar(input_n);
    Variable* out_var = scope_.FindVar(output_n);
    PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
                   "The type of %s and %s is not the same.", output_n,
                   input_n);

    if (in_var->IsType<framework::SelectedRows>()) {
      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
      out_sele_rows->set_rows(in_sele_rows.rows());
      out_sele_rows->set_height(in_sele_rows.height());
    } else if (in_var->IsType<framework::LoDTensor>()) {
      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
      out_lod_tensor->Resize(in_lod_tensor.dims());
    } else {
      PADDLE_THROW(
          "Currently, the input type of ShareDim only can be LoDTensor "
          "or SelectedRows.");
    }
  }

  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    const std::vector<std::string>& inputs = Inputs(in);
    const std::vector<std::string>& outputs = Outputs(out);
    PADDLE_ENFORCE_LT(i, inputs.size());
    PADDLE_ENFORCE_LT(j, outputs.size());
    Variable* in_var = scope_.FindVar(inputs.at(i));
    if (!in_var->IsType<LoDTensor>()) return;
    Variable* out_var = scope_.FindVar(outputs.at(j));
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());

// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to share info between in/out Tensors?
#ifdef PADDLE_WITH_MKLDNN
    // Fix me: ugly workaround below
    // Correct solution:
    //    set_layout() should NOT be called here (i.e. ShareLoD). Instead,
    //    layout of output tensor should be set "manually" in Compute()
    //    of each OPKernel. The reason layout should NOT be shared between
    //    input and output "automatically" (now by InferShape()->ShareLoD())
    //    is that layout transform may occur after InferShape().
    // Workaround:
    //    Skip set_layout() when input layout is kMKLDNN
    //    This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
    //    OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
    //    in Compute()
    if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
      out_tensor->set_layout(in_tensor.layout());
  }

  bool IsRuntime() const override { return true; }

 protected:
  DDim GetDim(const std::string& name) const override {
    Variable* var = scope_.FindVar(name);
    PADDLE_ENFORCE_NOT_NULL(var);
    if (var->IsType<LoDTensor>()) {
      return var->Get<LoDTensor>().dims();
    } else if (var->IsType<SelectedRows>()) {
      return var->Get<SelectedRows>().GetCompleteDims();
    } else {
      PADDLE_THROW(
          "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
          "type_id is %s.",
          name, var->Type().name());
    }
  }

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW("Only compile time support this method");
  }

  void SetDim(const std::string& name, const DDim& dim) override {
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      var->GetMutable<LoDTensor>()->Resize(dim);
    } else if (var->IsType<SelectedRows>()) {
      var->GetMutable<SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
                   name, var->Type().name());
    }
  }

  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW("Only compile time support this method");
  }

  proto::VarType::Type GetVarType(const std::string& name) const override {
    auto* var = scope_.FindVar(name);
    return ToVarType(var->Type());
  }

  InferShapeVarPtr GetVarPtr(const std::string& name) override {
    return scope_.FindVar(name);
  }

 private:
  const OperatorBase& op_;
  const Scope& scope_;
};

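// Used when FLAGS_check_nan_inf is set: verifies that a float/double tensor
// contains neither Inf nor NaN values.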
static void CheckTensorNANOrInf(const std::string& name,
                                const framework::Tensor& tensor) {
  if (tensor.memory_size() == 0) {
    return;
  }
  if (!IsType<float>(tensor.type()) && !IsType<double>(tensor.type())) {
    return;
  }
  PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
                 "Tensor %s contains Inf", name);
  PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
                 "Tensor %s contains NAN", name);
}

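// RunImpl: run InferShape, pick the kernel matching GetExpectedKernelType()
// (falling back from MKLDNN to a plain kernel if necessary), transfer inputs
// whose place/layout/dtype differ from what the kernel expects, execute the
// kernel, and finally share transferred in-place outputs back to the original
// scope.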
void OperatorWithKernel::RunImpl(const Scope& scope,
                                 const platform::Place& place) const {
  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
  this->InferShape(&infer_shape_ctx);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);

  // check if op[type] has kernel registered.
  auto& all_op_kernels = AllOpKernels();
  auto kernels_iter = all_op_kernels.find(type_);
  if (kernels_iter == all_op_kernels.end()) {
    PADDLE_THROW(
        "There are no kernels which are registered in the %s operator.", type_);
  }

  OpKernelMap& kernels = kernels_iter->second;

  // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
  // transform functions are ready.

  // for (auto& candidate : kKernelPriority) {
  //   Do selection
  // }

  auto expected_kernel_key =
      this->GetExpectedKernelType(ExecutionContext(*this, scope, *dev_ctx));
  VLOG(30) << "expected_kernel_key:" << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
  if (kernel_iter == kernels.end() &&
      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
    VLOG(30) << "missing MKLDNN kernel: fallbacking to PLAIN one";
    expected_kernel_key.library_type_ = LibraryType::kPlain;
    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("op %s does not have kernel for %s", type_,
                 KernelTypeToString(expected_kernel_key));
  }

  // do data transform
  std::vector<std::string> transfered_inplace_vars;
  auto* transfer_scope =
      TryTransferData(scope, expected_kernel_key, &transfered_inplace_vars);

  // exec_scope is the scope that the kernel is actually executed on.
  const Scope& exec_scope =
      (transfer_scope == nullptr ? scope : *transfer_scope);

  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
    dev_ctx = pool.Get(expected_kernel_key.place_);
  }

  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx));

  if (!transfered_inplace_vars.empty()) {
    // Some in-place variables have been transferred; share them back.
    TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
  }

  /*For profiling/benchmark only*/
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
  }

  if (FLAGS_check_nan_inf) {
    for (auto& vname : OutputVars(true)) {
      auto* var = exec_scope.FindVar(vname);
      if (var == nullptr) continue;
      if (var->IsType<framework::LoDTensor>()) {
        CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
      } else if (var->IsType<framework::SelectedRows>()) {
        CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
      }
    }
  }
}
void OperatorWithKernel::TransferInplaceVarsBack(
    const Scope& scope, const std::vector<std::string>& inplace_vars,
    const Scope& transfer_scope) const {
  for (auto& var_name : inplace_vars) {
    VLOG(30) << "share inplace var " + var_name +
                    " back to it's original scope";
    auto* original_tensor =
        GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name));
    auto* var = transfer_scope.FindVar(var_name);
    PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
                   var_name);
    auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
    original_tensor->ShareDataWith(*transformed_tensor);
  }
}

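// Returns a scope holding transformed copies of the inputs, or nullptr when no
// input needed a transform. Inputs that are also outputs (in-place variables)
// are recorded in transfered_inplace_vars so they can be shared back later.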
Scope* OperatorWithKernel::TryTransferData(
    const Scope& scope, const OpKernelType& expected_kernel_key,
    std::vector<std::string>* transfered_inplace_vars) const {
// In the inference scenario, the scopes will be reused across the batches, so
// the `new_scope` here would result in GPU memory explosion over the running
// of operators.
// We use a thread_local cache to fix that issue; the key in the cache is the
// combination of the `scope` argument, from_kernel_type, and target_kernel_type.
// Have a discussion with @Superjomn or the inference developers before changing
// this logic, since changes guarded by this macro may not be tested in the
// other scenarios.
#ifdef PADDLE_ON_INFERENCE
  thread_local std::unordered_map<size_t, Scope*> infer_transfer_scope_cache;
#endif

  Scope* new_scope = nullptr;
  for (auto& var_name_item : Inputs()) {
    for (auto& var_name : var_name_item.second) {
      auto* var = scope.FindVar(var_name);
      // Only tensors can be transferred to another device.
      if (var == nullptr || !VarIsTensor(*var)) {
        continue;
      }

      auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
      if (!tensor_in->IsInitialized()) {
        continue;
      }

      auto kernel_type_for_var = GetKernelTypeForVar(
          var_name_item.first, *tensor_in, expected_kernel_key);

      if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
        continue;
      }

      auto out_var_names = OutputVars(true);
      if (std::find(out_var_names.begin(), out_var_names.end(), var_name) !=
          out_var_names.end()) {
        transfered_inplace_vars->emplace_back(var_name);
      }

      VLOG(30) << "Transform Variable " << var_name << " from "
               << kernel_type_for_var << " to " << expected_kernel_key;

#ifdef PADDLE_ON_INFERENCE
      size_t infer_cache_key =
          CombineHash(OpKernelType::Hash()(kernel_type_for_var),
                      OpKernelType::Hash()(expected_kernel_key));
      infer_cache_key =
          CombineHash(infer_cache_key, std::hash<const Scope*>()(&scope));

      auto it = infer_transfer_scope_cache.find(infer_cache_key);
      if (it != infer_transfer_scope_cache.end()) {
        new_scope = infer_transfer_scope_cache[infer_cache_key];
      } else {
        new_scope = &scope.NewScope();
        infer_transfer_scope_cache[infer_cache_key] = new_scope;
      }
#endif

      if (new_scope == nullptr) {
        new_scope = &scope.NewScope();
      }

      auto* trans_var = new_scope->Var(var_name);

      Tensor out;
      TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
      SetTensorToVariable(*var, out, trans_var);
    }
  }

  return new_scope;
}

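// Scans all initialized input tensors and returns their common data type,
// enforcing that every input tensor carries the same type.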
proto::VarType::Type OperatorWithKernel::IndicateDataType(
    const ExecutionContext& ctx) const {
  auto& scope = ctx.scope();
  int data_type = -1;
  std::string last_input_name;
  for (auto& input : this->inputs_) {
    for (auto& ipt_name : input.second) {
      auto* var = scope.FindVar(ipt_name);
      if (var != nullptr) {
        const Tensor* t = nullptr;
        if (var->IsType<Tensor>()) {
          t = &var->Get<Tensor>();
        } else if (var->IsType<LoDTensor>()) {
          t = &var->Get<LoDTensor>();
        } else if (var->IsType<SelectedRows>()) {
          t = &(var->Get<SelectedRows>().value());
        }
        if (t != nullptr) {
          int tmp = static_cast<int>(ToDataType(t->type()));
          PADDLE_ENFORCE(
              tmp == data_type || data_type == -1,
              "DataType of Paddle Op %s must be the same. Get %s(%d) != %s(%d)",
              Type(), last_input_name, data_type, ipt_name, tmp);
          data_type = tmp;
          last_input_name = ipt_name;
        }
      }
    }
  }
  PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
  return static_cast<proto::VarType::Type>(data_type);
}

OpKernelType OperatorWithKernel::GetExpectedKernelType(
    const ExecutionContext& ctx) const {
  return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}

OpKernelType OperatorWithKernel::GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const OpKernelType& expected_kernel_type) const {
  return OpKernelType(expected_kernel_type.data_type_, tensor.place(),
                      tensor.layout());
}

}  // namespace framework
}  // namespace paddle