/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <gflags/gflags.h>
#include <glog/logging.h>

#include <algorithm>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/transfer_scope_cache.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/debug_support.h"
#include "paddle/fluid/platform/profiler.h"

DECLARE_bool(benchmark);
DEFINE_bool(check_nan_inf, false,
            "Checking whether operator produce NAN/INF or not. It will be "
            "extremely slow so please use this flag wisely.");

namespace paddle {
namespace framework {

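// Candidate (place, library) pairs used when choosing a kernel, ordered by
// preference.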
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN),
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain),
    std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN),
    std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};

proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
  if (var->IsType<framework::LoDTensor>()) {
    return var->Get<framework::LoDTensor>().type();
  } else if (var->IsType<framework::SelectedRows>()) {
    return var->Get<framework::SelectedRows>().value().type();
  } else {
    PADDLE_THROW("Var should be LoDTensor or SelectedRows");
  }
}

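// Debug helper: returns the dims of the variable `name` in `scope`, or
// DDim({-1}) when they cannot be determined.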
static DDim GetDims(const Scope& scope, const std::string& name,
                    bool get_actual_dim = false) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return DDim({-1});
  }

  if (var->IsType<LoDTensor>()) {
    const LoDTensor& tensor = var->Get<LoDTensor>();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return DDim({-1});
    }
    return tensor.dims();
  } else if (var->IsType<SelectedRows>()) {
    if (get_actual_dim) {
      return var->Get<SelectedRows>().value().dims();
    } else {
      return var->Get<SelectedRows>().GetCompleteDims();
    }
  } else {
    return DDim({-1});
  }
}

static bool VarInited(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) return false;
  return var->IsInitialized();
}

static std::string GetDtype(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return "";
  }

  if (var->IsType<LoDTensor>()) {
    const LoDTensor& tensor = var->Get<LoDTensor>();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return "";
    }
    return DataTypeToString(tensor.type());
  } else if (var->IsType<SelectedRows>()) {
    auto tensor = var->Get<SelectedRows>().value();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return "uninited";
    } else {
      return DataTypeToString(tensor.type());
    }
  } else {
    return "";
  }
}

static int GetRowSize(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  if (var == nullptr) {
    return -1;
  }

  if (var->IsType<SelectedRows>()) {
    return var->Get<SelectedRows>().rows().size();
  }

  return -1;
}

static LoD GetLoD(const Scope& scope, const std::string& name) {
  Variable* var = scope.FindVar(name);
  auto default_lod = LoD({{}});

  if (var == nullptr) {
    return default_lod;
  }

  if (var->IsType<LoDTensor>()) {
    const LoDTensor& tensor = var->Get<LoDTensor>();
    if (UNLIKELY(!tensor.IsInitialized())) {
      return default_lod;
    }
    return tensor.lod();
  } else {
    return default_lod;
  }
}

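// Resolve the operator's input/output variable names to Variable* pointers in
// the given scope.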
RuntimeContext::RuntimeContext(const VariableNameMap& innames,
                               const VariableNameMap& outnames,
                               const Scope& scope) {
  for (auto& var_name_item : innames) {
    std::vector<Variable*>& input_vars = inputs[var_name_item.first];
    input_vars.reserve(var_name_item.second.size());
    for (auto& var_name : var_name_item.second) {
      input_vars.push_back(scope.FindVar(var_name));
    }
  }
  for (auto& var_name_item : outnames) {
    std::vector<Variable*>& output_vars = outputs[var_name_item.first];
    output_vars.reserve(var_name_item.second.size());
    for (auto& var_name : var_name_item.second) {
      output_vars.push_back(scope.FindVar(var_name));
    }
  }
}

void OperatorBase::PreHook() {
  auto attrName = OpProtoAndCheckerMaker::OpCreationCallstackAttrName();
  if (HasAttr(attrName)) {
    auto& callstack = Attr<std::vector<std::string>>(attrName);
    platform::PythonDebugSupport::GetInstance()->SetInformation(callstack);
  }
}

void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
  VLOG(4) << "Call the prehook ... ";
  PreHook();

  VLOG(4) << place << " " << DebugStringEx(&scope);
  if (platform::is_gpu_place(place)) {
#ifndef PADDLE_WITH_CUDA
    PADDLE_THROW("Cannot run operator on place %s", place);
#else
    auto dev_id = boost::get<platform::CUDAPlace>(place).device;
    platform::SetDeviceId(dev_id);
#endif
  }

  // The profile has a process-wide mutex, results in serious performance issue
  // in concurrency scenerio. Here use an `if` to fix this issue.
  // Please not remove the `if`, ask @Superjomn if there are any concern.
  if (platform::IsProfileEnabled()) {
    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
    platform::RecordEvent record_event(Type(), pool.Get(place));
    RunImpl(scope, place);
  } else {
    RunImpl(scope, place);
191
  }
P
peizhilin 已提交
192
  VLOG(3) << place << " " << DebugStringEx(&scope);
193 194 195 196 197 198 199

  VLOG(4) << "Call the posthook ... ";
  PostHook();
}

void OperatorBase::PostHook() {
  // do nothing here
200 201
}

202
bool OperatorBase::HasInputs(const std::string& name) const {
M
minqiyang 已提交
203
  return inputs_.find(name) != inputs_.end();
204 205
}

206
std::string OperatorBase::Input(const std::string& name) const {
Y
  auto& ins = Inputs(name);
Y
  PADDLE_ENFORCE_LE(ins.size(), 1UL,
209 210
                    "Operator %s's input %s should contain only one variable.",
                    type_, name);
Y
  return ins.empty() ? kEmptyVarName : ins[0];
Y
}

Y
const std::vector<std::string>& OperatorBase::Inputs(
    const std::string& name) const {
Y
  auto it = inputs_.find(name);
217 218
  PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.",
                 type_, name);
Y
  return it->second;
Y
}

222
bool OperatorBase::HasOutputs(const std::string& name) const {
223
  if (outputs_.find(name) != outputs_.end()) {
224 225 226 227 228 229
    return true;
  } else {
    return false;
  }
}

230
std::string OperatorBase::Output(const std::string& name) const {
Y
  auto& outs = Outputs(name);
Y
  PADDLE_ENFORCE_LE(outs.size(), 1UL,
233 234
                    "Operator %s's output %s should contain only one variable.",
                    type_, name);
Y
  return outs.empty() ? kEmptyVarName : outs[0];
Y
}

Y
const std::vector<std::string>& OperatorBase::Outputs(
    const std::string& name) const {
Y
  auto it = outputs_.find(name);
241 242
  PADDLE_ENFORCE(it != outputs_.end(),
                 "Operator %s does not have an output called %s.", type_, name);
Y
  return it->second;
Y
}

246
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
Q
  std::stringstream ss;
Y
  ss << "Op(" << type_ << "), inputs:{";
Y
  for (auto it = inputs_.begin(); it != inputs_.end();) {
    auto& input = *it;
Y
    ss << input.first << "[";
    for (size_t i = 0; i < input.second.size(); ++i) {
Q
      auto var_name = input.second[i];
      ss << var_name;
255
      if (scope) {
Q
        if (!VarInited(*scope, var_name)) {
          ss << "[uninited]";
        } else {
          int row_size = GetRowSize(*scope, var_name);
          if (row_size >= 0) {
            ss << "[row_size=" << row_size << "]";
          }
          std::string dtype = GetDtype(*scope, var_name);
          ss << ":" << dtype;
          ss << "[" << GetDims(*scope, var_name, true) << "]";
          ss << "(" << GetLoD(*scope, var_name) << ")";
267
        }
268
      }
Y
      if (i != input.second.size() - 1) {
        ss << ", ";
      }
272
    }
Y
    ss << "]";
Y
    ++it;
    if (it != inputs_.end()) {
276 277
      ss << ", ";
    }
Q
  }
Y
  ss << "}, outputs:{";
Y
  for (auto it = outputs_.begin(); it != outputs_.end();) {
    auto& output = *it;
Y
    ss << output.first << "[";
    for (size_t i = 0; i < output.second.size(); ++i) {
Q
      auto var_name = output.second[i];
      ss << var_name;
286
      if (scope) {
Q
        if (!VarInited(*scope, var_name)) {
          ss << "[uninited]";
        } else {
          int row_size = GetRowSize(*scope, output.second[i]);
          if (row_size >= 0) {
            ss << "[row_size=" << row_size << "]";
          }
C
chengduo 已提交
294 295
          std::string dtype = GetDtype(*scope, output.second[i]);
          ss << ":" << dtype;
Q
          ss << "[" << GetDims(*scope, var_name, true) << "]";
          ss << "(" << GetLoD(*scope, var_name) << ")";
298
        }
299
      }
Y
      if (i != output.second.size() - 1) {
        ss << ", ";
      }
303
    }
Y
    ss << "]";
Y
    ++it;
    if (it != outputs_.end()) {
307 308
      ss << ", ";
    }
Q
  }
Y
  ss << "}.";
Q
  return ss.str();
}

Y
OperatorBase::OperatorBase(const std::string& type,
Y
                           const VariableNameMap& inputs,
                           const VariableNameMap& outputs,
Y
                           const AttributeMap& attrs)
    : type_(type), inputs_(inputs), outputs_(outputs), attrs_(attrs) {
319 320
  GenerateTemporaryNames();
  CheckAllInputOutputSet();
Y
}
322

Q
qijun 已提交
323 324
std::vector<std::string> OperatorBase::InputVars() const {
  std::vector<std::string> ret_val;
Y
  for (auto& o : inputs_) {
Q
qijun 已提交
326 327 328 329 330 331
    ret_val.reserve(ret_val.size() + o.second.size());
    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
  }
  return ret_val;
}

std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
  std::vector<std::string> ret_val;
  if (has_intermediate) {
    // push all outputs into ret_val
    for (auto& o : outputs_) {
      ret_val.reserve(ret_val.size() + o.second.size());
      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
    }
    return ret_val;
  }
  auto& info = OpInfoMap::Instance().Get(Type());

  // get all OpProto::Var for outputs
  for (auto& o : info.Proto().outputs()) {
    // ignore all intermediate output
    if (o.intermediate()) continue;
    auto out = outputs_.find(o.name());
    if (out != outputs_.end()) {
      ret_val.reserve(ret_val.size() + out->second.size());
      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
    }
  }
  return ret_val;
}

void OperatorBase::CheckAllInputOutputSet() const {
  auto& info_map = OpInfoMap::Instance();
  auto* op_info = info_map.GetNullable(Type());
  if (op_info == nullptr || op_info->proto_ == nullptr) return;

  for (auto& in : op_info->Proto().inputs()) {
    if (!in.dispensable()) {
      PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
                     "Operator %s's input, %s, is not set", Type(), in.name());
    }
  }

  for (auto& out : op_info->Proto().outputs()) {
    if (!out.dispensable()) {
      PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
                     "Operator %s's output, %s, is not set", Type(),
                     out.name());
    }
  }
}

void OperatorBase::GenerateTemporaryNames() {
  static std::atomic<size_t> gUniqId(0UL);
  for (auto& output : outputs_) {
    for (auto& output_name : output.second) {
      if (output_name == kTempVarName) {
        output_name += type_;
        output_name += "@";
        output_name += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
}

static bool VarIsTensor(const Variable& var) {
  return var.IsType<LoDTensor>() || var.IsType<SelectedRows>();
}

const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var) {
  if (var.IsType<LoDTensor>()) {
    return static_cast<const Tensor*>(&(var.Get<LoDTensor>()));
  } else if (var.IsType<SelectedRows>()) {
    return &(var.Get<SelectedRows>().value());
  } else {
    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
                 ToTypeName(var.Type()));
  }
}

Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return var->GetMutable<LoDTensor>();
  } else if (var->IsType<SelectedRows>()) {
    return var->GetMutable<SelectedRows>()->mutable_value();
  } else {
    PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
                 ToTypeName(var->Type()));
  }
}

bool ExecutionContext::HasInput(const std::string& name) const {
  if (!op_.HasInputs(name)) {
    return false;
  }
  auto& ins = Inputs(name);
  size_t length = ins.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Input %s should not have more than one inputs", name);
  auto arg = ins[0];
  auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
  return var != nullptr;
}

bool ExecutionContext::HasOutput(const std::string& name) const {
  if (!op_.HasOutputs(name)) {
    return false;
  }
  auto& outs = Outputs(name);
  size_t length = outs.size();
  if (length == 0) {
    return false;
  }
  PADDLE_ENFORCE_EQ(length, 1UL,
                    "Output %s should not have more than one inputs", name);
  auto arg = outs[0];
  auto* var = arg == kEmptyVarName ? nullptr : scope_.FindVar(arg);
  return var != nullptr;
}

const Variable* ExecutionContext::InputVar(const std::string& name) const {
  auto it = ctx_.inputs.find(name);
  if (it == ctx_.inputs.end()) return nullptr;

  PADDLE_ENFORCE_LE(it->second.size(), 1UL,
                    "Operator %s's input %s should contain only one variable.",
                    op_.Type(), name);
  return it->second.empty() ? nullptr : it->second[0];
}

const Variable* ExecutionContext::LegacyInputVar(
    const std::string& name) const {
  auto ipt = op_.Input(name);
  return ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
}

Variable* ExecutionContext::OutputVar(const std::string& name) const {
  auto it = ctx_.outputs.find(name);
  if (it == ctx_.outputs.end()) return nullptr;

  PADDLE_ENFORCE_LE(it->second.size(), 1UL,
                    "Operator %s's output %s should contain only one variable.",
                    op_.Type(), name);
  return it->second.empty() ? nullptr : it->second[0];
}

Variable* ExecutionContext::LegacyOutputVar(const std::string& name) const {
  auto opt = op_.Output(name);
  return opt == kEmptyVarName ? nullptr : scope_.FindVar(opt);
}

template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
  return Input<LoDTensor>(name);
}

template <>
const Tensor* ExecutionContext::LegacyInput<Tensor>(
    const std::string& name) const {
  return LegacyInput<LoDTensor>(name);
}

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const {
  auto it = ctx_.inputs.find(name);
  if (it == ctx_.inputs.end()) {
    return {};
  }
  const std::vector<Variable*>& vars = it->second;
  std::vector<const Tensor*> res;
  res.reserve(vars.size());
  std::transform(vars.begin(), vars.end(), std::back_inserter(res),
                 [&](Variable* var) -> const Tensor* {
                   if (var == nullptr) return nullptr;
                   PADDLE_ENFORCE(
                       var->IsType<LoDTensor>(),
                       "should be LoDTensor, but the received type is %s",
S
sneaxiy 已提交
507
                       ToTypeName(var->Type()));
X
                   return &(var->Get<LoDTensor>());
                 });
  return res;
}

template <>
const std::vector<const Tensor*> ExecutionContext::LegacyMultiInput<Tensor>(
    const std::string& name) const {
  auto names = op().Inputs(name);
  std::vector<const Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> const Tensor* {
                   auto var = scope_.FindVar(sub_name);
                   if (var == nullptr) return nullptr;
                   PADDLE_ENFORCE(
                       var->IsType<LoDTensor>(),
                       "%s should be LoDTensor, but the received type is %s",
                       sub_name, ToTypeName(var->Type()));
                   return &(var->Get<LoDTensor>());
                 });
  return res;
}

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
  return Output<LoDTensor>(name);
}

template <>
Tensor* ExecutionContext::LegacyOutput<Tensor>(const std::string& name) const {
  return LegacyOutput<LoDTensor>(name);
}

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const {
  auto names = op().Outputs(name);
  std::vector<Tensor*> res;
  res.reserve(names.size());
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) -> Tensor* {
                   auto var = scope_.FindVar(sub_name);
                   if (var == nullptr) return nullptr;
                   PADDLE_ENFORCE(
                       var->IsType<LoDTensor>(),
                       "%s should be LoDTensor, but the received type is %s",
                       sub_name, ToTypeName(var->Type()));
                   return var->GetMutable<LoDTensor>();
                 });
  return res;
}

bool OpSupportGPU(const std::string& op_type) {
  auto& all_kernels = OperatorWithKernel::AllOpKernels();
  auto it = all_kernels.find(op_type);
  if (it == all_kernels.end()) {
    // All control operators must support GPU
    return true;
  }
  for (auto& kern_pair : it->second) {
    if (platform::is_gpu_place(kern_pair.first.place_)) {
      return true;
    }
  }
  return false;
}

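// InferShapeContext implementation used at run time: shapes are read from and
// written to the variables resolved in the RuntimeContext.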
class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope,
                           const RuntimeContext& ctx)
      : op_(op), scope_(scope), ctx_(ctx) {}

  bool HasInput(const std::string& name) const override {
    // has only one input
    const auto& ins = ctx_.inputs;
    auto it = ins.find(name);
    if (it == ins.end()) {
      return false;
    }
    const auto& in = it->second;
    if (in.size() == 0) return false;
    PADDLE_ENFORCE_EQ(in.size(), 1UL,
                      "Input %s should not have more than one inputs", name);
    return in[0] != nullptr;
  }

  bool HasOutput(const std::string& name) const override {
    // has only one output
    const auto& outs = ctx_.outputs;
    auto it = outs.find(name);
    if (it == outs.end()) {
      return false;
    }
    const auto& out = it->second;
    if (out.size() == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(out.size(), 1UL,
                      "Output %s should not have more than one outputs", name);
    return out[0] != nullptr;
  }

  bool HasInputs(const std::string& name) const override {
    const auto& ins = ctx_.inputs;
    auto it = ins.find(name);
    if (it == ins.end() || it->second.empty()) {
      return false;
    }
    for (auto& input : it->second) {
      if (input == nullptr) {
        return false;
      }
    }
    return true;
  }

  bool HasOutputs(const std::string& name) const override {
    const auto& outs = ctx_.outputs;
    auto it = outs.find(name);
    if (it == outs.end() || it->second.empty()) {
      return false;
    }
    for (auto& output : it->second) {
      if (output == nullptr) {
        return false;
      }
    }
    return true;
  }

  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
    return op_.Inputs(name);
  }

  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
    return op_.Outputs(name);
  }

  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) override {
    auto in_it = ctx_.inputs.find(in);
    auto out_it = ctx_.outputs.find(out);
    PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
                   "Inputs %s should have %llu argument", in, i);
    PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
                   "Outputs %s should have %llu argument", out, j);

    Variable* in_var = in_it->second[i];
    Variable* out_var = out_it->second[j];

    PADDLE_ENFORCE(in_var->Type() == out_var->Type(),
                   "The type of %s and %s is not the same.", in, out);

    if (in_var->IsType<framework::SelectedRows>()) {
      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
      out_sele_rows->set_rows(in_sele_rows.rows());
      out_sele_rows->set_height(in_sele_rows.height());
    } else if (in_var->IsType<framework::LoDTensor>()) {
      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
      out_lod_tensor->Resize(in_lod_tensor.dims());
    } else {
      PADDLE_THROW(
          "Currently, the input type of ShareDim only can be LoDTensor "
          "or SelectedRows.");
    }
  }

  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    auto in_it = ctx_.inputs.find(in);
    auto out_it = ctx_.outputs.find(out);
    PADDLE_ENFORCE(in_it != ctx_.inputs.end() && in_it->second.size() > i,
                   "Inputs %s should have %llu argument", in, i);
    PADDLE_ENFORCE(out_it != ctx_.outputs.end() && out_it->second.size() > j,
                   "Outputs %s should have %llu argument", out, j);

    Variable* in_var = in_it->second.at(i);
    if (!in_var->IsType<LoDTensor>()) return;
    Variable* out_var = out_it->second.at(j);
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());

// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to share info between in/out Tensors?
#ifdef PADDLE_WITH_MKLDNN
    // Fix me: ugly workaround below
    // Correct solution:
    //    set_layout() should NOT be called here (i.e. ShareLoD). Instead,
    //    layout of output tensor should be set "manually" in Compute()
    //    of each OPKernel. The reason layout should NOT be shared between
    //    input and output "automatically" (now by InferShape()->ShareLoD())
    //    is that layout transform may occur after InferShape().
    // Workaround:
    //    Skip set_layout() when input layout is kMKLDNN
    //    This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
    //    OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
    //    in Compute()
    if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
      out_tensor->set_layout(in_tensor.layout());
  }

  void DecreaseLoDLevel(const std::string& in, const std::string& out,
                        size_t i = 0, size_t j = 0) const override {
    PADDLE_THROW("DecreaseLoDLevel is only used in compile time.");
  }

  bool IsRuntime() const override { return true; }

  // TODO(paddle-dev): Can this be template?
  std::vector<InferShapeVarPtr> GetInputVarPtrs(
      const std::string& name) override {
    const std::vector<Variable*>& vars = InputVars(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
  }

  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
      const std::string& name) override {
    const std::vector<Variable*>& vars = OutputVars(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
  }

  DDim GetInputDim(const std::string& name) const override {
    const std::vector<Variable*>& vars = InputVars(name);
    PADDLE_ENFORCE_EQ(vars.size(), 1UL,
                      "Input(%s) should hold one element, but now it holds %d",
                      name, vars.size());
    return this->GetDim(vars[0]);
  }

  std::vector<DDim> GetInputsDim(const std::string& name) const override {
    const std::vector<Variable*>& vars = InputVars(name);
    return GetDims(vars);
  }

  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override {
    return GetVarTypes(InputVars(name));
  }

  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override {
    return GetVarTypes(OutputVars(name));
  }

  void SetOutputDim(const std::string& name, const DDim& dim) override {
    auto& vars = OutputVars(name);
    PADDLE_ENFORCE_EQ(vars.size(), 1UL,
                      "Output(%s) should hold one element, but now it holds %d",
                      name, vars.size());
    SetDim(vars[0], dim);
  }

  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override {
    auto& vars = OutputVars(name);
    SetDims(vars, dims);
  }

 protected:
  DDim GetDim(Variable* var) const {
    PADDLE_ENFORCE_NOT_NULL(var);
    if (var->IsType<LoDTensor>()) {
      return var->Get<LoDTensor>().dims();
    } else if (var->IsType<SelectedRows>()) {
      return var->Get<SelectedRows>().GetCompleteDims();
    } else {
      PADDLE_THROW(
          "Only LoDTensor/SelectedRows support 'GetDim', but Variables "
          "type_id is %s.",
          ToTypeName(var->Type()));
    }
  }

  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
    std::vector<DDim> ret;
    ret.reserve(vars.size());
    std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
                   [this](Variable* var) { return this->GetDim(var); });
    return ret;
  }

  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW("Only compile time support this method");
  }

  void SetDim(Variable* var, const DDim& dim) {
    if (var->IsType<LoDTensor>()) {
      var->GetMutable<LoDTensor>()->Resize(dim);
    } else if (var->IsType<SelectedRows>()) {
      var->GetMutable<SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
                   ToTypeName(var->Type()));
    }
  }

  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims) {
    size_t length = vars.size();
    PADDLE_ENFORCE_EQ(length, dims.size());
    for (size_t i = 0; i < length; ++i) {
      if (vars[i] == nullptr) {
        continue;
      }
      SetDim(vars[i], dims[i]);
833 834 835
    }
  }

F
fengjiayi 已提交
836 837
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
Y
    PADDLE_THROW("Only compile time support this method");
F
fengjiayi 已提交
839 840
  }

X
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(vars.size());
    std::transform(vars.begin(), vars.end(), retv.begin(),
                   std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
                             this, std::placeholders::_1));
    return retv;
  }

  proto::VarType::Type GetVarType(Variable* var) const {
852 853 854
    return ToVarType(var->Type());
  }

855 856 857 858 859 860 861 862 863 864 865 866 867 868
 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const {
    auto it = ctx_.inputs.find(name);
    PADDLE_ENFORCE(it != ctx_.inputs.end(),
                   "Operator %s does not have the input %s.", op_.Type(), name);
    return it->second;
  }

  const std::vector<Variable*>& OutputVars(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    PADDLE_ENFORCE(it != ctx_.outputs.end(),
                   "Operator %s does not have the outputs %s.", op_.Type(),
                   name);
    return it->second;
F
fengjiayi 已提交
869 870
  }

871 872
  const OperatorBase& op_;
  const Scope& scope_;
X
  const RuntimeContext& ctx_;
874 875
};

C
chengduoZH 已提交
876 877 878 879 880
static void CheckTensorNANOrInf(const std::string& name,
                                const framework::Tensor& tensor) {
  if (tensor.memory_size() == 0) {
    return;
  }
Y
  if (tensor.type() != proto::VarType::FP32 &&
      tensor.type() != proto::VarType::FP64) {
C
chengduoZH 已提交
883 884 885 886 887 888 889 890
    return;
  }
  PADDLE_ENFORCE(!framework::TensorContainsInf(tensor),
                 "Tensor %s contains Inf", name);
  PADDLE_ENFORCE(!framework::TensorContainsNAN(tensor),
                 "Tensor %s contains NAN", name);
}

void OperatorWithKernel::RuntimeInferShape(const Scope& scope,
                                           const platform::Place& place,
                                           const RuntimeContext& ctx) const {
  RuntimeInferShapeContext infer_shape_ctx(*this, scope, ctx);
  this->InferShape(&infer_shape_ctx);
}

void OperatorWithKernel::RunImpl(const Scope& scope,
                                 const platform::Place& place) const {
  RuntimeContext ctx(Inputs(), Outputs(), scope);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);

  // check if op[type] has kernel registered.
  auto& all_op_kernels = AllOpKernels();
  auto kernels_iter = all_op_kernels.find(type_);
  if (kernels_iter == all_op_kernels.end()) {
    PADDLE_THROW(
        "There are no kernels which are registered in the %s operator.", type_);
  }

  OpKernelMap& kernels = kernels_iter->second;

  auto expected_kernel_key = this->GetExpectedKernelType(
      ExecutionContext(*this, scope, *dev_ctx, ctx));
  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
  // workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
  if (kernel_iter == kernels.end() &&
      expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
    VLOG(3) << "missing MKLDNN kernel: falling back to PLAIN one";
    expected_kernel_key.library_type_ = LibraryType::kPlain;
    expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
    kernel_iter = kernels.find(expected_kernel_key);
  }
#endif
  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("op %s does not have kernel for %s", type_,
                 KernelTypeToString(expected_kernel_key));
  }

  // do data transform; the transformed inputs may live in a separate transfer scope
  std::vector<std::string> transfered_inplace_vars;
  auto* transfer_scope =
      PrepareData(scope, expected_kernel_key, &transfered_inplace_vars, &ctx);

  // exec scope is the scope that kernel actually executed on.
  const Scope& exec_scope =
      (transfer_scope == nullptr ? scope : *transfer_scope);

  if (!(expected_kernel_key.place_ == dev_ctx->GetPlace())) {
    dev_ctx = pool.Get(expected_kernel_key.place_);
  }

  RuntimeInferShapeContext infer_shape_ctx(*this, exec_scope, ctx);
  this->InferShape(&infer_shape_ctx);
  // TODO(panyx0718): ExecutionContext should only depend on RuntimeContext
  // not Scope. Imperative mode only pass inputs and get outputs.
  kernel_iter->second(ExecutionContext(*this, exec_scope, *dev_ctx, ctx));

  if (!transfered_inplace_vars.empty()) {
    // Some in-place variables have been transferred; share them back into the
    // original scope.
    TransferInplaceVarsBack(scope, transfered_inplace_vars, *transfer_scope);
  }

  /*For profiling/benchmark only*/
  if (FLAGS_benchmark) {
    dev_ctx->Wait();
  }

  if (FLAGS_check_nan_inf) {
    for (auto& vname : OutputVars(true)) {
      auto* var = exec_scope.FindVar(vname);
      if (var == nullptr) continue;
      if (var->IsType<framework::LoDTensor>()) {
        CheckTensorNANOrInf(vname, var->Get<framework::LoDTensor>());
      } else if (var->IsType<framework::SelectedRows>()) {
        CheckTensorNANOrInf(vname, var->Get<framework::SelectedRows>().value());
      }
    }
  }
}

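// Share the transformed in-place variables in the transfer scope back into the
// original scope.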
void OperatorWithKernel::TransferInplaceVarsBack(
    const Scope& scope, const std::vector<std::string>& inplace_vars,
    const Scope& transfer_scope) const {
  for (auto& var_name : inplace_vars) {
    VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
    auto* original_tensor =
        GetMutableLoDTensorOrSelectedRowsValueFromVar(scope.FindVar(var_name));
    auto* var = transfer_scope.FindVar(var_name);
    PADDLE_ENFORCE(var != nullptr, "The var[%s] should not be nullptr",
                   var_name);
    auto* transformed_tensor = GetLoDTensorOrSelectedRowsValueFromVar(*var);
    original_tensor->ShareDataWith(*transformed_tensor);
  }
}

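// Transfer input tensors whose kernel type does not match the expected kernel
// type into a (possibly cached) transfer scope; returns that scope, or nullptr
// if no transfer was needed.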
Scope* OperatorWithKernel::PrepareData(
    const Scope& scope, const OpKernelType& expected_kernel_key,
    std::vector<std::string>* transfered_inplace_vars,
    RuntimeContext* ctx) const {
  Scope* new_scope = nullptr;
  for (auto& var_name_item : Inputs()) {
    std::vector<Variable*>& input_vars = ctx->inputs[var_name_item.first];

    for (size_t i = 0; i < var_name_item.second.size(); ++i) {
      auto& var_name = var_name_item.second[i];
      auto* var = input_vars[i];

      // Only tensors can be transferred to another device.
      if (var == nullptr || !VarIsTensor(*var)) {
        continue;
      }

      auto* tensor_in = GetLoDTensorOrSelectedRowsValueFromVar(*var);
      if (!tensor_in->IsInitialized()) {
        continue;
      }

      auto kernel_type_for_var = GetKernelTypeForVar(
          var_name_item.first, *tensor_in, expected_kernel_key);

      if (!NeedTransform(kernel_type_for_var, expected_kernel_key)) {
        continue;
      }

      auto out_var_names = OutputVars(true);
      if (std::find(out_var_names.begin(), out_var_names.end(), var_name) !=
          out_var_names.end()) {
        transfered_inplace_vars->emplace_back(var_name);
      }

      VLOG(3) << "Transform Variable " << var_name << " from "
              << kernel_type_for_var << " to " << expected_kernel_key;

      // In the inference scenario, scopes are reused across batches, so the
      // `new_scope` created here would cause GPU memory usage to explode as
      // operators keep running.
      // We use a thread_local cache to fix that issue; the key in the cache is
      // the combination of the `scope` argument, from_kernel_type, and
      // target_kernel_type.
      // Have a discussion with @Superjomn or the inference developers if any
      // change to this logic might not be tested on the other scenarios.
      // If this op is not called by an Executor or ParallelExecutor, it should
      // be called by a NaiveExecutor; the NaiveExecutor caches the scopes and
      // variables, which behaves quite differently.
      if (!run_by_executor_) {
        new_scope = TryCreateTransferScope(kernel_type_for_var,
                                           expected_kernel_key, &scope);
      }
      if (!new_scope) {
        new_scope = &scope.NewScope();
      }

      auto* trans_var = new_scope->Var(var_name);
      input_vars[i] = trans_var;

      Tensor out;
      TransformData(expected_kernel_key, kernel_type_for_var, *tensor_in, &out);
      SetTensorToVariable(*var, out, trans_var);
    }
  }

  return new_scope;
}

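// Infer the operator's data type by scanning all initialized input tensors;
// they must all agree on a single type.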
proto::VarType::Type OperatorWithKernel::IndicateDataType(
    const ExecutionContext& ctx) const {
  int data_type = -1;
  for (auto& input : this->inputs_) {
    const std::vector<const Variable*> vars = ctx.MultiInputVar(input.first);
    for (size_t i = 0; i < vars.size(); ++i) {
      const Variable* var = vars[i];
      if (var != nullptr) {
        const Tensor* t = nullptr;
        if (var->IsType<Tensor>()) {
          t = &var->Get<Tensor>();
        } else if (var->IsType<LoDTensor>()) {
          t = &var->Get<LoDTensor>();
        } else if (var->IsType<SelectedRows>()) {
          t = &(var->Get<SelectedRows>().value());
        }
        if (t != nullptr) {
          PADDLE_ENFORCE(t->IsInitialized(), "Input %s(%lu) is not initialized",
                         input.first, i);
          int tmp = static_cast<int>(t->type());
          PADDLE_ENFORCE(
              tmp == data_type || data_type == -1,
              "DataType of Paddle Op %s must be the same. Get (%d) != (%d)",
              Type(), data_type, tmp);
          data_type = tmp;
        }
      }
    }
  }
  PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
  return static_cast<proto::VarType::Type>(data_type);
}

OpKernelType OperatorWithKernel::GetExpectedKernelType(
    const ExecutionContext& ctx) const {
  return OpKernelType(IndicateDataType(ctx), ctx.GetPlace());
}

OpKernelType OperatorWithKernel::GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const OpKernelType& expected_kernel_type) const {
  return OpKernelType(expected_kernel_type.data_type_, tensor.place(),
                      tensor.layout());
}

}  // namespace framework
}  // namespace paddle