operator.cc 20.4 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14
#include <gflags/gflags.h>
D
dzhwinter 已提交
15
#include <glog/logging.h>
Q
Qiao Longfei 已提交
16

17
#include <algorithm>
D
dzhwinter 已提交
18

Y
Yi Wang 已提交
19 20 21 22 23
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
24
#include "paddle/fluid/platform/profiler.h"
Q
Qiao Longfei 已提交
25

D
dzhwinter 已提交
26
DECLARE_bool(benchmark);
D
dzhwinter 已提交
27

Q
Qiao Longfei 已提交
28 29 30
namespace paddle {
namespace framework {

31 32 33 34 35 36
// Candidate (place, library) kernel combinations, presumably ordered from most
// to least preferred (CUDNN before plain CUDA, MKLDNN before plain CPU).
// NOTE(review): not consulted yet — OperatorWithKernel::RunImpl only contains a
// commented-out selection loop over this list.
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN),
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain),
    std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN),
    std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};
D
dzhwinter 已提交
37

Q
qiaolongfei 已提交
38 39 40 41 42 43 44 45 46 47 48
// Returns the element data type stored in `var`. For a SelectedRows the type
// of its value tensor is reported. Throws for any other variable type.
proto::VarType::Type GetDataTypeOfVar(const Variable* var) {
  if (var->IsType<framework::LoDTensor>()) {
    const auto& lod_tensor = var->Get<framework::LoDTensor>();
    return framework::ToDataType(lod_tensor.type());
  }
  if (var->IsType<framework::SelectedRows>()) {
    const auto& selected_rows = var->Get<framework::SelectedRows>();
    return framework::ToDataType(selected_rows.value().type());
  }
  PADDLE_THROW("Var should be LoDTensor or SelectedRows");
}

49 50
// Returns the dims of variable `name` in `scope`, for debug printing.
// - LoDTensor: its dims.
// - SelectedRows: the value tensor's dims when `get_actual_dim` is true,
//   otherwise GetCompleteDims().
// A missing variable, or any other type, yields DDim({-1}).
static DDim GetDims(const Scope& scope, const std::string& name,
                    bool get_actual_dim = false) {
  Variable* found = scope.FindVar(name);
  if (found == nullptr) return DDim({-1});
  if (found->IsType<LoDTensor>()) return found->Get<LoDTensor>().dims();
  if (!found->IsType<SelectedRows>()) return DDim({-1});

  const auto& rows = found->Get<SelectedRows>();
  if (get_actual_dim) {
    return rows.value().dims();
  }
  return rows.GetCompleteDims();
}

Q
Qiao Longfei 已提交
69 70 71 72 73 74 75 76 77 78 79 80 81 82 83
// Returns the LoD of variable `name` in `scope` when it is a LoDTensor.
// Otherwise (missing variable or any other type) returns the default
// `LoD({{}})`.
static LoD GetLoD(const Scope& scope, const std::string& name) {
  Variable* found = scope.FindVar(name);
  if (found != nullptr && found->IsType<LoDTensor>()) {
    return found->Get<LoDTensor>().lod();
  }
  return LoD({{}});
}

84 85 86 87 88 89 90 91 92 93 94 95
// Runs the operator on `place`.
// For a GPU place, the device id is first passed to platform::SetDeviceId.
// Builds without PADDLE_WITH_CUDA throw for GPU places.
// All real work is delegated to RunImpl.
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
  if (platform::is_gpu_place(place)) {
#ifdef PADDLE_WITH_CUDA
    const auto& gpu_place = boost::get<platform::CUDAPlace>(place);
    platform::SetDeviceId(gpu_place.device);
#else
    PADDLE_THROW("Cannot run operator on place %s", place);
#endif
  }
  RunImpl(scope, place);
}

96 97 98 99 100 101 102 103
// Returns true iff the operator declares an input slot called `name`.
// The slot itself may still list zero variables.
bool OperatorBase::HasInputs(const std::string& name) const {
  return inputs_.find(name) != inputs_.end();
}

104
std::string OperatorBase::Input(const std::string& name) const {
Y
Yu Yang 已提交
105
  auto& ins = Inputs(name);
Y
Yu Yang 已提交
106
  PADDLE_ENFORCE_LE(ins.size(), 1UL,
107 108
                    "Operator %s's input %s should contain only one variable.",
                    type_, name);
Y
Yu Yang 已提交
109
  return ins.empty() ? kEmptyVarName : ins[0];
Y
Yan Chunwei 已提交
110 111
}

Y
Yu Yang 已提交
112 113
const std::vector<std::string>& OperatorBase::Inputs(
    const std::string& name) const {
Y
Yu Yang 已提交
114
  auto it = inputs_.find(name);
115 116
  PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.",
                 type_, name);
Y
Yu Yang 已提交
117
  return it->second;
Y
Yan Chunwei 已提交
118 119
}

120 121 122 123 124 125 126 127
// Returns true iff the operator declares an output slot called `name`.
// The slot itself may still list zero variables.
bool OperatorBase::HasOutputs(const std::string& name) const {
  return outputs_.find(name) != outputs_.end();
}

128
std::string OperatorBase::Output(const std::string& name) const {
Y
Yu Yang 已提交
129
  auto& outs = Outputs(name);
Y
Yu Yang 已提交
130
  PADDLE_ENFORCE_LE(outs.size(), 1UL,
131 132
                    "Operator %s's output %s should contain only one variable.",
                    type_, name);
Y
Yu Yang 已提交
133
  return outs.empty() ? kEmptyVarName : outs[0];
Y
Yan Chunwei 已提交
134 135
}

Y
Yu Yang 已提交
136 137
const std::vector<std::string>& OperatorBase::Outputs(
    const std::string& name) const {
Y
Yu Yang 已提交
138
  auto it = outputs_.find(name);
139 140
  PADDLE_ENFORCE(it != outputs_.end(),
                 "Operator %s does not have an output called %s.", type_, name);
Y
Yu Yang 已提交
141
  return it->second;
Y
Yan Chunwei 已提交
142 143
}

144
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
Q
Qiao Longfei 已提交
145
  std::stringstream ss;
Y
Yu Yang 已提交
146
  ss << "Op(" << type_ << "), inputs:{";
Y
Yu Yang 已提交
147 148
  for (auto it = inputs_.begin(); it != inputs_.end();) {
    auto& input = *it;
Y
Yu Yang 已提交
149 150 151
    ss << input.first << "[";
    for (size_t i = 0; i < input.second.size(); ++i) {
      ss << input.second[i];
152
      if (scope) {
153
        ss << "[" << GetDims(*scope, input.second[i], true) << "]";
Q
Qiao Longfei 已提交
154
        ss << "(" << GetLoD(*scope, input.second[i]) << ")";
155
      }
Y
Yu Yang 已提交
156 157 158
      if (i != input.second.size() - 1) {
        ss << ", ";
      }
159
    }
Y
Yu Yang 已提交
160
    ss << "]";
Y
Yu Yang 已提交
161 162
    ++it;
    if (it != inputs_.end()) {
163 164
      ss << ", ";
    }
Q
Qiao Longfei 已提交
165
  }
Y
Yu Yang 已提交
166
  ss << "}, outputs:{";
Y
Yu Yang 已提交
167 168
  for (auto it = outputs_.begin(); it != outputs_.end();) {
    auto& output = *it;
Y
Yu Yang 已提交
169 170 171
    ss << output.first << "[";
    for (size_t i = 0; i < output.second.size(); ++i) {
      ss << output.second[i];
172
      if (scope) {
173
        ss << "[" << GetDims(*scope, output.second[i], true) << "]";
Q
Qiao Longfei 已提交
174
        ss << "(" << GetLoD(*scope, output.second[i]) << ")";
175
      }
Y
Yu Yang 已提交
176 177 178
      if (i != output.second.size() - 1) {
        ss << ", ";
      }
179
    }
Y
Yu Yang 已提交
180
    ss << "]";
Y
Yu Yang 已提交
181 182
    ++it;
    if (it != outputs_.end()) {
183 184
      ss << ", ";
    }
Q
Qiao Longfei 已提交
185
  }
Y
Yu Yang 已提交
186
  ss << "}.";
Q
Qiao Longfei 已提交
187 188 189
  return ss.str();
}

Y
Yu Yang 已提交
190
// Stores the operator type, slot maps and attributes. It then:
// 1. expands kTempVarName output placeholders into unique names;
// 2. checks that every non-dispensable proto slot is present.
OperatorBase::OperatorBase(const std::string& type,
                           const VariableNameMap& inputs,
                           const VariableNameMap& outputs,
                           const AttributeMap& attrs)
    : type_(type),
      inputs_(inputs),
      outputs_(outputs),
      attrs_(attrs) {
  this->GenerateTemporaryNames();
  this->CheckAllInputOutputSet();
}
198

Q
qijun 已提交
199 200
std::vector<std::string> OperatorBase::InputVars() const {
  std::vector<std::string> ret_val;
Y
Yu Yang 已提交
201
  for (auto& o : inputs_) {
Q
qijun 已提交
202 203 204 205 206 207
    ret_val.reserve(ret_val.size() + o.second.size());
    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
  }
  return ret_val;
}

Y
Yu Yang 已提交
208 209 210 211 212 213 214 215 216 217
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
  std::vector<std::string> ret_val;
  if (has_intermediate) {
    // push all outputs into ret_val
    for (auto& o : outputs_) {
      ret_val.reserve(ret_val.size() + o.second.size());
      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
    }
    return ret_val;
  }
Y
Yu Yang 已提交
218
  auto& info = OpInfoMap::Instance().Get(Type());
Y
Yu Yang 已提交
219 220

  // get all OpProto::Var for outputs
Y
Yu Yang 已提交
221
  for (auto& o : info.Proto().outputs()) {
Y
Yu Yang 已提交
222 223 224 225 226 227 228 229 230
    // ignore all intermediate output
    if (o.intermediate()) continue;
    auto out = outputs_.find(o.name());
    if (out != outputs_.end()) {
      ret_val.reserve(ret_val.size() + out->second.size());
      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
    }
  }
  return ret_val;
D
dongzhihong 已提交
231 232
}

233 234 235
void OperatorBase::CheckAllInputOutputSet() const {
  auto& info_map = OpInfoMap::Instance();
  auto* op_info = info_map.GetNullable(Type());
Y
Yu Yang 已提交
236
  if (op_info == nullptr || op_info->proto_ == nullptr) return;
237 238

  for (auto& in : op_info->Proto().inputs()) {
239 240 241 242
    if (!in.dispensable()) {
      PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
                     "Operator %s's input, %s, is not set", Type(), in.name());
    }
243 244 245
  }

  for (auto& out : op_info->Proto().outputs()) {
246 247 248 249 250
    if (!out.dispensable()) {
      PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
                     "Operator %s's output, %s, is not set", Type(),
                     out.name());
    }
251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266
  }
}

void OperatorBase::GenerateTemporaryNames() {
  static std::atomic<size_t> gUniqId(0UL);
  for (auto& output : outputs_) {
    for (auto& output_name : output.second) {
      if (output_name == kTempVarName) {
        output_name += type_;
        output_name += "@";
        output_name += std::to_string(gUniqId.fetch_add(1));
      }
    }
  }
}

267 268 269 270
// True when `var` holds a LoDTensor or a SelectedRows, i.e. a variable from
// which GetTensorFromVar / GetMutableTensorFromVar can extract a Tensor.
static bool VarIsTensor(const Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return true;
  }
  return var->IsType<SelectedRows>();
}

271
// Returns a read-only pointer to the tensor held by `var`: the LoDTensor
// itself, or the value tensor of a SelectedRows. Throws for any other type.
//
// Takes `const Variable*` (callers holding a `Variable*` still convert
// implicitly) and reads through Get<T>(), so obtaining a const view never
// needs a mutable accessor or a const_cast.
static const Tensor* GetTensorFromVar(const Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return &var->Get<LoDTensor>();
  }
  if (var->IsType<SelectedRows>()) {
    return &(var->Get<SelectedRows>().value());
  }
  PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
               var->Type().name());
}

// Returns a mutable pointer to the tensor held by `var`: the LoDTensor itself,
// or the value tensor of a SelectedRows. Throws for any other type.
static Tensor* GetMutableTensorFromVar(Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return var->GetMutable<LoDTensor>();
  }
  if (var->IsType<SelectedRows>()) {
    auto* selected_rows = var->GetMutable<SelectedRows>();
    return selected_rows->mutable_value();
  }
  PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
               var->Type().name());
}

293
template <>
294
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
295
  auto* var = InputVar(name);
296 297
  return var == nullptr ? nullptr
                        : GetTensorFromVar(const_cast<Variable*>(var));
298 299 300
}

template <>
301
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
302 303 304 305
    const std::string& name) const {
  auto names = op().Inputs(name);
  std::vector<const Tensor*> res;
  res.reserve(names.size());
306 307 308 309 310
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) {
                   auto var = scope_.FindVar(sub_name);
                   return var == nullptr ? nullptr : GetTensorFromVar(var);
                 });
311 312 313 314
  return res;
}

template <>
315
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
316
  auto var = OutputVar(name);
Q
QI JUN 已提交
317
  return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
318 319 320
}

template <>
321
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
322 323 324 325
    const std::string& name) const {
  auto names = op().Outputs(name);
  std::vector<Tensor*> res;
  res.reserve(names.size());
326 327
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) {
328 329
                   auto var = scope_.FindVar(sub_name);
                   return var == nullptr ? nullptr
Q
QI JUN 已提交
330
                                         : GetMutableTensorFromVar(var);
331
                 });
332 333 334
  return res;
}

Y
Yu Yang 已提交
335 336 337 338 339 340 341 342 343 344 345 346 347 348 349
bool OpSupportGPU(const std::string& op_type) {
  auto& all_kernels = OperatorWithKernel::AllOpKernels();
  auto it = all_kernels.find(op_type);
  if (it == all_kernels.end()) {
    // All control operator must support GPU
    return true;
  }
  for (auto& kern_pair : it->second) {
    if (platform::is_gpu_place(kern_pair.first.place_)) {
      return true;
    }
  }
  return false;
}

350 351 352 353 354 355
class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

  bool HasInput(const std::string& name) const override {
356 357 358
    if (!op_.HasInputs(name)) {
      return false;
    }
359 360 361 362 363
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
F
fengjiayi 已提交
364 365
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input %s should not have more than one inputs", name);
366 367 368 369 370 371
    auto ipt = ins[0];
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

  bool HasOutput(const std::string& name) const override {
372 373 374
    if (!op_.HasOutputs(name)) {
      return false;
    }
375 376 377 378 379
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
F
fengjiayi 已提交
380 381
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output %s should not have more than one inputs", name);
382 383 384 385 386 387
    auto ipt = outs[0];
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

  bool HasInputs(const std::string& name) const override {
388 389 390
    if (!op_.HasInputs(name)) {
      return false;
    }
391 392 393 394 395 396 397 398 399 400 401 402 403
    auto inputs = op_.Inputs(name);
    if (inputs.empty()) {
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

  bool HasOutputs(const std::string& name) const override {
404 405 406
    if (!op_.HasOutputs(name)) {
      return false;
    }
407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430
    auto outputs = op_.Outputs(name);
    if (outputs.empty()) {
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
    return op_.Inputs(name);
  }

  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
    return op_.Outputs(name);
  }

Q
Qiao Longfei 已提交
431 432 433 434 435 436 437 438 439 440 441 442
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
    if (!in_var->IsType<LoDTensor>()) return;
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
D
dzhwinter 已提交
443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461

    // TODO(dzhwinter) : reuse ShareLoD in most operators.
    // Need to call ShareLayout explicitly in sequence related ops.
    // Shall we have a better method to shared info between in/out Tensor?
    out_tensor->set_layout(in_tensor.layout());
  }

  void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
                   size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
    if (!in_var->IsType<LoDTensor>()) return;
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_layout(in_tensor.layout());
Q
Qiao Longfei 已提交
462 463
  }

464 465 466
  bool IsRuntime() const override { return true; }

 protected:
467 468 469 470 471 472 473
  DDim GetDim(const std::string& name) const override {
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      return var->Get<LoDTensor>().dims();
    } else if (var->IsType<SelectedRows>()) {
      return var->Get<SelectedRows>().GetCompleteDims();
    } else {
F
fengjiayi 已提交
474 475 476 477 478 479 480
      PADDLE_THROW(
          "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
          "type_id is %s.",
          name, var->Type().name());
    }
  }

F
fengjiayi 已提交
481
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Y
Yu Yang 已提交
482
    PADDLE_THROW("Only compile time support this method");
483 484 485 486 487 488 489 490 491
  }

  void SetDim(const std::string& name, const DDim& dim) override {
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      var->GetMutable<LoDTensor>()->Resize(dim);
    } else if (var->IsType<SelectedRows>()) {
      var->GetMutable<SelectedRows>()->set_height(dim[0]);
    } else {
Y
Yang Yang 已提交
492 493
      PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
                   name, var->Type().name());
494 495 496
    }
  }

F
fengjiayi 已提交
497 498
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
Y
Yu Yang 已提交
499
    PADDLE_THROW("Only compile time support this method");
F
fengjiayi 已提交
500 501
  }

502
  proto::VarType::Type GetVarType(const std::string& name) const override {
503 504 505 506
    auto* var = scope_.FindVar(name);
    return ToVarType(var->Type());
  }

F
fengjiayi 已提交
507 508 509 510
  InferShapeVarPtr GetVarPtr(const std::string& name) override {
    return scope_.FindVar(name);
  }

511
 private:
512 513 514 515
  const OperatorBase& op_;
  const Scope& scope_;
};

516 517
// Executes the operator with its registered kernel:
//  1. runs InferShape against `scope`;
//  2. picks the kernel matching GetExpectedKernelType(), throwing if none;
//  3. for each initialized tensor input whose kernel type needs a transform,
//     writes the transformed data into a child scope (`new_scope`);
//  4. runs the kernel in `new_scope` on the expected place's device context;
//  5. shares transformed in-place variables (inputs that are also outputs)
//     back into the original scope;
//  6. waits on the device context when FLAGS_benchmark is set.
void OperatorWithKernel::RunImpl(const Scope& scope,
                                 const platform::Place& place) const {
  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
  this->InferShape(&infer_shape_ctx);
  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);

  // For profiling, don't move out of this function because that will result
  // in the failure of multi-GPU profiling.
  platform::RecordEvent record_event(Type(), dev_ctx);
  // check if op[type] has kernel registered.
  auto& all_op_kernels = AllOpKernels();
  auto kernels_iter = all_op_kernels.find(type_);
  if (kernels_iter == all_op_kernels.end()) {
    PADDLE_THROW(
        "There are no kernels which are registered in the %s operator.", type_);
  }

  ExecutionContext ctx(*this, scope, *dev_ctx);

  OpKernelMap& kernels = kernels_iter->second;

  // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
  // transform functions are ready.

  // for (auto& candidate : kKernelPriority) {
  //   Do selection
  // }

  auto expected_kernel_key = this->GetExpectedKernelType(ctx);
  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);
  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("op %s does not have kernel for %s", type_,
                 KernelTypeToString(expected_kernel_key));
  }

  // do data transform
  Scope& new_scope = scope.NewScope();

  // Transformed inputs that are also outputs of this op; after the kernel
  // runs their data is shared back to the original scope.
  std::vector<std::string> inplace_vars;
  // The output-name list does not change inside the loop but is only needed
  // when a transform happens, so build it lazily and at most once.
  std::vector<std::string> out_var_names;
  bool out_var_names_ready = false;
  for (auto& var_name_item : this->Inputs()) {
    for (auto& var_name : var_name_item.second) {
      auto* var = scope.FindVar(var_name);
      if (var && VarIsTensor(var)) {
        auto* tensor_in = GetTensorFromVar(var);
        if (tensor_in->IsInitialized()) {
          auto kernel_type_for_var = this->GetKernelTypeForVar(
              var_name_item.first, *tensor_in, expected_kernel_key);
          if (TransFromNeeded(kernel_type_for_var, expected_kernel_key)) {
            if (!out_var_names_ready) {
              out_var_names = OutputVars(true);
              out_var_names_ready = true;
            }
            if (std::find(out_var_names.begin(), out_var_names.end(),
                          var_name) != out_var_names.end()) {
              inplace_vars.push_back(var_name);
            }
            VLOG(3) << "Transform Variable " << var_name << " from "
                    << kernel_type_for_var << " to " << expected_kernel_key;
            auto* trans_var = new_scope.Var(var_name);
            // `out` never leaves this block, so a local object is enough; no
            // heap-allocated shared_ptr is needed.
            Tensor out;
            DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
                          &out);
            CopyVariableWithTensor(*var, out, trans_var);
          }
        }
      }
    }
  }

  auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
  kernel_iter->second->Compute(
      ExecutionContext(*this, new_scope, *new_dev_ctx));

  for (auto& var_name : inplace_vars) {
    VLOG(3) << "share inplace var " + var_name + " back to it's original scope";
    auto* original_tensor = GetMutableTensorFromVar(scope.FindVar(var_name));
    auto* transformed_tensor = GetTensorFromVar(new_scope.FindVar(var_name));
    original_tensor->ShareDataWith(*transformed_tensor);
  }

  /*For profiling/benchmark only*/
  if (FLAGS_benchmark) {
    new_dev_ctx->Wait();
  }
}

602
// Infers the kernel data type from the operator's inputs.
// Only input variables present in the scope and holding a Tensor, LoDTensor
// or SelectedRows are considered. All of them must agree on one data type,
// which is returned.
// Throws if two inputs disagree, or if no input supplies a type.
proto::VarType::Type OperatorWithKernel::IndicateDataType(
    const ExecutionContext& ctx) const {
  auto& scope = ctx.scope();
  // -1 marks "no input data type seen yet".
  int found_type = -1;
  for (auto& slot : this->inputs_) {
    for (auto& var_name : slot.second) {
      auto* var = scope.FindVar(var_name);
      if (var == nullptr) {
        continue;
      }
      const Tensor* tensor = nullptr;
      if (var->IsType<Tensor>()) {
        tensor = &var->Get<Tensor>();
      } else if (var->IsType<LoDTensor>()) {
        tensor = &var->Get<LoDTensor>();
      } else if (var->IsType<SelectedRows>()) {
        tensor = &(var->Get<SelectedRows>().value());
      }
      if (tensor == nullptr) {
        continue;
      }
      const int var_type = static_cast<int>(ToDataType(tensor->type()));
      PADDLE_ENFORCE(var_type == found_type || found_type == -1,
                     "DataType of Paddle Op %s must be the same.", Type());
      found_type = var_type;
    }
  }
  PADDLE_ENFORCE(found_type != -1, "DataType should be indicated by input");
  return static_cast<proto::VarType::Type>(found_type);
}
630

631 632 633 634 635 636 637 638 639 640 641
// Default expected kernel: the data type inferred from the inputs
// (IndicateDataType), placed on the context's place.
OpKernelType OperatorWithKernel::GetExpectedKernelType(
    const ExecutionContext& ctx) const {
  auto data_type = IndicateDataType(ctx);
  return OpKernelType(data_type, ctx.GetPlace());
}

// Default kernel type describing an input tensor: it keeps the expected data
// type but uses the place where `tensor` currently resides. `var_name` is not
// consulted by this default implementation.
OpKernelType OperatorWithKernel::GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const OpKernelType& expected_kernel_type) const {
  const auto& data_type = expected_kernel_type.data_type_;
  return OpKernelType(data_type, tensor.place());
}

Q
Qiao Longfei 已提交
642
}  // namespace framework
L
liaogang 已提交
643
}  // namespace paddle