operator.cc 19.2 KB
Newer Older
1
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Q
Qiao Longfei 已提交
2 3 4 5 6 7 8 9 10 11 12 13

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
D
dzhwinter 已提交
14
#include <gflags/gflags.h>
D
dzhwinter 已提交
15
#include <glog/logging.h>
Q
Qiao Longfei 已提交
16

17
#include <algorithm>
D
dzhwinter 已提交
18

Y
Yi Wang 已提交
19 20 21 22 23
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/shape_inference.h"
#include "paddle/fluid/framework/var_type.h"
24
#include "paddle/fluid/platform/profiler.h"
Q
Qiao Longfei 已提交
25

D
dzhwinter 已提交
26
DECLARE_bool(benchmark);
D
dzhwinter 已提交
27

Q
Qiao Longfei 已提交
28 29 30
namespace paddle {
namespace framework {

31 32 33 34 35 36
// Table of (place, library) candidates, listed as: CUDA device 0 with cuDNN,
// CUDA device 0 plain, CPU with MKLDNN, CPU plain.
// NOTE(review): within this file the table is only mentioned by a
// commented-out selection loop in OperatorWithKernel::RunImpl; presumably the
// order expresses preference from first to last -- confirm before relying on it.
std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority = {
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kCUDNN),
    std::make_tuple(platform::CUDAPlace(0), LibraryType::kPlain),
    std::make_tuple(platform::CPUPlace(), LibraryType::kMKLDNN),
    std::make_tuple(platform::CPUPlace(), LibraryType::kPlain),
};
D
dzhwinter 已提交
37

38 39
// Looks up `name` in `scope` and reports its dimensions for debug output.
// Returns the LoDTensor dims or the SelectedRows complete dims; a missing
// variable, or one holding any other type, yields the placeholder DDim({-1}).
static DDim GetDims(const Scope& scope, const std::string& name) {
  Variable* found = scope.FindVar(name);
  if (found != nullptr) {
    if (found->IsType<LoDTensor>()) {
      return found->Get<LoDTensor>().dims();
    }
    if (found->IsType<SelectedRows>()) {
      return found->Get<SelectedRows>().GetCompleteDims();
    }
  }
  // Variable absent or of an unsupported type.
  return DDim({-1});
}

Q
Qiao Longfei 已提交
53 54 55 56 57 58 59 60 61 62 63 64 65 66 67
// Looks up `name` in `scope` and returns its LoD for debug output.
// Only a LoDTensor variable carries a LoD; a missing variable or any other
// variable type yields the default LoD({{}}).
static LoD GetLoD(const Scope& scope, const std::string& name) {
  Variable* found = scope.FindVar(name);
  if (found != nullptr && found->IsType<LoDTensor>()) {
    return found->Get<LoDTensor>().lod();
  }
  return LoD({{}});
}

68 69 70 71 72 73 74 75 76
// Runs this operator on `place`.
// For a GPU place, builds without PADDLE_WITH_CUDA throw, while CUDA builds
// call platform::SetDeviceId with the device index stored in `place`.
// A platform::RecordEvent named after the operator type is then constructed
// before RunImpl() is invoked and destroyed when this function returns.
void OperatorBase::Run(const Scope& scope, const platform::Place& place) {
  const bool on_gpu = platform::is_gpu_place(place);
#ifdef PADDLE_WITH_CUDA
  if (on_gpu) {
    auto device_index = boost::get<platform::CUDAPlace>(place).device;
    platform::SetDeviceId(device_index);
  }
#else
  if (on_gpu) {
    PADDLE_THROW("Cannot run operator on place %s", place);
  }
#endif

  // profile
  auto* profiled_ctx = platform::DeviceContextPool::Instance().Get(place);
  platform::RecordEvent record_event(Type(), profiled_ctx);

  RunImpl(scope, place);
}

83
std::string OperatorBase::Input(const std::string& name) const {
Y
Yu Yang 已提交
84
  auto& ins = Inputs(name);
Y
Yu Yang 已提交
85
  PADDLE_ENFORCE_LE(ins.size(), 1UL,
86 87
                    "Operator %s's input %s should contain only one variable.",
                    type_, name);
Y
Yu Yang 已提交
88
  return ins.empty() ? kEmptyVarName : ins[0];
Y
Yan Chunwei 已提交
89 90
}

Y
Yu Yang 已提交
91 92
const std::vector<std::string>& OperatorBase::Inputs(
    const std::string& name) const {
Y
Yu Yang 已提交
93
  auto it = inputs_.find(name);
94 95
  PADDLE_ENFORCE(it != inputs_.end(), "Operator %s does not have the input %s.",
                 type_, name);
Y
Yu Yang 已提交
96
  return it->second;
Y
Yan Chunwei 已提交
97 98
}

99
std::string OperatorBase::Output(const std::string& name) const {
Y
Yu Yang 已提交
100
  auto& outs = Outputs(name);
Y
Yu Yang 已提交
101
  PADDLE_ENFORCE_LE(outs.size(), 1UL,
102 103
                    "Operator %s's output %s should contain only one variable.",
                    type_, name);
Y
Yu Yang 已提交
104
  return outs.empty() ? kEmptyVarName : outs[0];
Y
Yan Chunwei 已提交
105 106
}

Y
Yu Yang 已提交
107 108
const std::vector<std::string>& OperatorBase::Outputs(
    const std::string& name) const {
Y
Yu Yang 已提交
109
  auto it = outputs_.find(name);
110 111
  PADDLE_ENFORCE(it != outputs_.end(),
                 "Operator %s does not have an output called %s.", type_, name);
Y
Yu Yang 已提交
112
  return it->second;
Y
Yan Chunwei 已提交
113 114
}

115
std::string OperatorBase::DebugStringEx(const Scope* scope) const {
Q
Qiao Longfei 已提交
116
  std::stringstream ss;
Y
Yu Yang 已提交
117
  ss << "Op(" << type_ << "), inputs:{";
Y
Yu Yang 已提交
118 119
  for (auto it = inputs_.begin(); it != inputs_.end();) {
    auto& input = *it;
Y
Yu Yang 已提交
120 121 122
    ss << input.first << "[";
    for (size_t i = 0; i < input.second.size(); ++i) {
      ss << input.second[i];
123
      if (scope) {
Q
Qiao Longfei 已提交
124 125
        ss << "[" << GetDims(*scope, input.second[i]) << "]";
        ss << "(" << GetLoD(*scope, input.second[i]) << ")";
126
      }
Y
Yu Yang 已提交
127 128 129
      if (i != input.second.size() - 1) {
        ss << ", ";
      }
130
    }
Y
Yu Yang 已提交
131
    ss << "]";
Y
Yu Yang 已提交
132 133
    ++it;
    if (it != inputs_.end()) {
134 135
      ss << ", ";
    }
Q
Qiao Longfei 已提交
136
  }
Y
Yu Yang 已提交
137
  ss << "}, outputs:{";
Y
Yu Yang 已提交
138 139
  for (auto it = outputs_.begin(); it != outputs_.end();) {
    auto& output = *it;
Y
Yu Yang 已提交
140 141 142
    ss << output.first << "[";
    for (size_t i = 0; i < output.second.size(); ++i) {
      ss << output.second[i];
143
      if (scope) {
Q
Qiao Longfei 已提交
144 145
        ss << "[" << GetDims(*scope, output.second[i]) << "]";
        ss << "(" << GetLoD(*scope, output.second[i]) << ")";
146
      }
Y
Yu Yang 已提交
147 148 149
      if (i != output.second.size() - 1) {
        ss << ", ";
      }
150
    }
Y
Yu Yang 已提交
151
    ss << "]";
Y
Yu Yang 已提交
152 153
    ++it;
    if (it != outputs_.end()) {
154 155
      ss << ", ";
    }
Q
Qiao Longfei 已提交
156
  }
Y
Yu Yang 已提交
157
  ss << "}.";
Q
Qiao Longfei 已提交
158 159 160
  return ss.str();
}

D
dongzhihong 已提交
161 162
void OperatorBase::Rename(const std::string& old_name,
                          const std::string& new_name) {
Y
Yu Yang 已提交
163 164 165 166 167 168 169
  for (auto& input : inputs_) {
    std::replace(input.second.begin(), input.second.end(), old_name, new_name);
  }
  for (auto& output : outputs_) {
    std::replace(output.second.begin(), output.second.end(), old_name,
                 new_name);
  }
D
dongzhihong 已提交
170 171
}

Y
Yu Yang 已提交
172
// Constructs an operator of the given `type`, copying its input/output
// variable-name maps and its attributes.
OperatorBase::OperatorBase(const std::string& type,
                           const VariableNameMap& inputs,
                           const VariableNameMap& outputs,
                           const AttributeMap& attrs)
    : type_(type),
      inputs_(inputs),
      outputs_(outputs),
      attrs_(attrs) {
  // Give placeholder outputs (kTempVarName) unique names first ...
  GenerateTemporaryNames();
  // ... then verify every slot declared in the op proto has been provided.
  CheckAllInputOutputSet();
}
180

Q
qijun 已提交
181 182
std::vector<std::string> OperatorBase::InputVars() const {
  std::vector<std::string> ret_val;
Y
Yu Yang 已提交
183
  for (auto& o : inputs_) {
Q
qijun 已提交
184 185 186 187 188 189
    ret_val.reserve(ret_val.size() + o.second.size());
    ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
  }
  return ret_val;
}

Y
Yu Yang 已提交
190 191 192 193 194 195 196 197 198 199
std::vector<std::string> OperatorBase::OutputVars(bool has_intermediate) const {
  std::vector<std::string> ret_val;
  if (has_intermediate) {
    // push all outputs into ret_val
    for (auto& o : outputs_) {
      ret_val.reserve(ret_val.size() + o.second.size());
      ret_val.insert(ret_val.end(), o.second.begin(), o.second.end());
    }
    return ret_val;
  }
Y
Yu Yang 已提交
200
  auto& info = OpInfoMap::Instance().Get(Type());
Y
Yu Yang 已提交
201 202

  // get all OpProto::Var for outputs
Y
Yu Yang 已提交
203
  for (auto& o : info.Proto().outputs()) {
Y
Yu Yang 已提交
204 205 206 207 208 209 210 211 212
    // ignore all intermediate output
    if (o.intermediate()) continue;
    auto out = outputs_.find(o.name());
    if (out != outputs_.end()) {
      ret_val.reserve(ret_val.size() + out->second.size());
      ret_val.insert(ret_val.end(), out->second.begin(), out->second.end());
    }
  }
  return ret_val;
D
dongzhihong 已提交
213 214
}

215 216 217
void OperatorBase::CheckAllInputOutputSet() const {
  auto& info_map = OpInfoMap::Instance();
  auto* op_info = info_map.GetNullable(Type());
Y
Yu Yang 已提交
218
  if (op_info == nullptr || op_info->proto_ == nullptr) return;
219 220 221

  for (auto& in : op_info->Proto().inputs()) {
    PADDLE_ENFORCE(inputs_.find(in.name()) != inputs_.end(),
Y
Yu Yang 已提交
222
                   "Type %s's input %s is not set", Type(), in.name());
223 224 225 226
  }

  for (auto& out : op_info->Proto().outputs()) {
    PADDLE_ENFORCE(outputs_.find(out.name()) != outputs_.end(),
Y
Yu Yang 已提交
227
                   "Type %s's output %s is not set", Type(), out.name());
228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243
  }
}

// Rewrites every output variable whose name equals kTempVarName into
// "<kTempVarName><type>@<id>", where <id> is drawn from a function-local
// atomic counter shared by all operators.
void OperatorBase::GenerateTemporaryNames() {
  static std::atomic<size_t> next_temp_id(0UL);
  for (auto& output : outputs_) {
    for (auto& output_name : output.second) {
      if (output_name != kTempVarName) {
        continue;
      }
      output_name.append(type_);
      output_name.append("@");
      output_name.append(std::to_string(next_temp_id.fetch_add(1)));
    }
  }
}

244 245 246 247
// True when `var` holds a LoDTensor or a SelectedRows.
// `var` must be non-null; it is dereferenced unconditionally.
static bool VarIsTensor(const Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return true;
  }
  return var->IsType<SelectedRows>();
}

248
// Returns a read-only pointer to the tensor stored in `var`: the LoDTensor
// itself, or the value tensor of a SelectedRows. Throws for any other type.
// `var` must be non-null.
static const Tensor* GetTensorFromVar(Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return var->GetMutable<LoDTensor>();
  }
  if (var->IsType<SelectedRows>()) {
    return var->GetMutable<SelectedRows>()->mutable_value();
  }
  PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
               var->Type().name());
}

// Returns a writable pointer to the tensor stored in `var`: the LoDTensor
// itself, or the value tensor of a SelectedRows. Throws for any other type.
// `var` must be non-null.
static Tensor* GetMutableTensorFromVar(Variable* var) {
  if (var->IsType<LoDTensor>()) {
    return var->GetMutable<LoDTensor>();
  }
  if (var->IsType<SelectedRows>()) {
    return var->GetMutable<SelectedRows>()->mutable_value();
  }
  PADDLE_THROW("Variable type_id %s, expect LoDTensor/SelectedRows.",
               var->Type().name());
}

270
template <>
271
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const {
272
  auto* var = InputVar(name);
273 274
  return var == nullptr ? nullptr
                        : GetTensorFromVar(const_cast<Variable*>(var));
275 276 277
}

template <>
278
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
279 280 281 282
    const std::string& name) const {
  auto names = op().Inputs(name);
  std::vector<const Tensor*> res;
  res.reserve(names.size());
283 284 285 286 287
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) {
                   auto var = scope_.FindVar(sub_name);
                   return var == nullptr ? nullptr : GetTensorFromVar(var);
                 });
288 289 290 291
  return res;
}

template <>
292
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const {
293
  auto var = OutputVar(name);
Q
QI JUN 已提交
294
  return var == nullptr ? nullptr : GetMutableTensorFromVar(var);
295 296 297
}

template <>
298
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
299 300 301 302
    const std::string& name) const {
  auto names = op().Outputs(name);
  std::vector<Tensor*> res;
  res.reserve(names.size());
303 304
  std::transform(names.begin(), names.end(), std::back_inserter(res),
                 [&](const std::string& sub_name) {
305 306
                   auto var = scope_.FindVar(sub_name);
                   return var == nullptr ? nullptr
Q
QI JUN 已提交
307
                                         : GetMutableTensorFromVar(var);
308
                 });
309 310 311
  return res;
}

Y
Yu Yang 已提交
312 313 314 315 316
// Reports whether operator `op_type` can run on a GPU place.
// An operator with no registered kernels is treated as supported (the
// original comment states that all control operators must support GPU);
// otherwise at least one registered kernel must target a GPU place.
bool OpSupportGPU(const std::string& op_type) {
  auto& registry = OperatorWithKernel::AllOpKernels();
  auto entry = registry.find(op_type);
  if (entry == registry.end()) {
    // All control operator must support GPU
    return true;
  }

  bool has_gpu_kernel = false;
  for (auto& kern_pair : entry->second) {
    if (platform::is_gpu_place(kern_pair.first.place_)) {
      has_gpu_kernel = true;
      break;
    }
  }
  return has_gpu_kernel;
}

328 329 330 331 332 333 334 335 336 337 338
class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

  bool HasInput(const std::string& name) const override {
    auto& ins = Inputs(name);
    size_t length = ins.size();
    if (length == 0) {
      return false;
    }
F
fengjiayi 已提交
339 340
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Input %s should not have more than one inputs", name);
341 342 343 344 345 346 347 348 349 350 351
    auto ipt = ins[0];
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

  bool HasOutput(const std::string& name) const override {
    auto& outs = Outputs(name);
    size_t length = outs.size();
    if (length == 0) {
      return false;
    }
F
fengjiayi 已提交
352 353
    PADDLE_ENFORCE_EQ(length, 1UL,
                      "Output %s should not have more than one inputs", name);
354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396
    auto ipt = outs[0];
    auto* var = ipt == kEmptyVarName ? nullptr : scope_.FindVar(ipt);
    return var != nullptr;
  }

  bool HasInputs(const std::string& name) const override {
    auto inputs = op_.Inputs(name);
    if (inputs.empty()) {
      return false;
    }
    for (auto& input : inputs) {
      if (scope_.FindVar(input) == nullptr) {
        return false;
      }
    }
    return true;
  }

  bool HasOutputs(const std::string& name) const override {
    auto outputs = op_.Outputs(name);
    if (outputs.empty()) {
      return false;
    }
    for (auto& output : outputs) {
      if (scope_.FindVar(output) == nullptr) {
        return false;
      }
    }
    return true;
  }

  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }

  const std::vector<std::string>& Inputs(
      const std::string& name) const override {
    return op_.Inputs(name);
  }

  const std::vector<std::string>& Outputs(
      const std::string& name) const override {
    return op_.Outputs(name);
  }

Q
Qiao Longfei 已提交
397 398 399 400 401 402 403 404 405 406 407 408
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
    if (!in_var->IsType<LoDTensor>()) return;
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
D
dzhwinter 已提交
409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427

    // TODO(dzhwinter) : reuse ShareLoD in most operators.
    // Need to call ShareLayout explicitly in sequence related ops.
    // Shall we have a better method to shared info between in/out Tensor?
    out_tensor->set_layout(in_tensor.layout());
  }

  void ShareLayout(const std::string& in, const std::string& out, size_t i = 0,
                   size_t j = 0) const {
    PADDLE_ENFORCE_LT(i, Inputs(in).size());
    PADDLE_ENFORCE_LT(j, Outputs(out).size());
    Variable* in_var = scope_.FindVar(Inputs(in)[i]);
    Variable* out_var = scope_.FindVar(Outputs(out)[j]);
    if (!in_var->IsType<LoDTensor>()) return;
    PADDLE_ENFORCE(out_var->IsType<LoDTensor>(),
                   "The %d-th output of Output(%s) must be LoDTensor.", j, out);
    auto in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_layout(in_tensor.layout());
Q
Qiao Longfei 已提交
428 429
  }

430 431 432
  bool IsRuntime() const override { return true; }

 protected:
433 434 435 436 437 438 439
  DDim GetDim(const std::string& name) const override {
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      return var->Get<LoDTensor>().dims();
    } else if (var->IsType<SelectedRows>()) {
      return var->Get<SelectedRows>().GetCompleteDims();
    } else {
F
fengjiayi 已提交
440 441 442 443 444 445 446
      PADDLE_THROW(
          "Only LoDTensor/SelectedRows support 'GetDim', but Variable %s's "
          "type_id is %s.",
          name, var->Type().name());
    }
  }

F
fengjiayi 已提交
447
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
Y
Yu Yang 已提交
448
    PADDLE_THROW("Only compile time support this method");
449 450 451 452 453 454 455 456 457
  }

  void SetDim(const std::string& name, const DDim& dim) override {
    Variable* var = scope_.FindVar(name);
    if (var->IsType<LoDTensor>()) {
      var->GetMutable<LoDTensor>()->Resize(dim);
    } else if (var->IsType<SelectedRows>()) {
      var->GetMutable<SelectedRows>()->set_height(dim[0]);
    } else {
Y
Yang Yang 已提交
458 459
      PADDLE_THROW("Variable %s type_id %s, expect LoDTensor/SelectedRows.",
                   name, var->Type().name());
460 461 462
    }
  }

F
fengjiayi 已提交
463 464
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
Y
Yu Yang 已提交
465
    PADDLE_THROW("Only compile time support this method");
F
fengjiayi 已提交
466 467
  }

468
  proto::VarType::Type GetVarType(const std::string& name) const override {
469 470 471 472
    auto* var = scope_.FindVar(name);
    return ToVarType(var->Type());
  }

F
fengjiayi 已提交
473 474 475 476
  InferShapeVarPtr GetVarPtr(const std::string& name) override {
    return scope_.FindVar(name);
  }

477
 private:
478 479 480 481
  const OperatorBase& op_;
  const Scope& scope_;
};

482 483
// Executes the kernel selected for this operator:
//   1. runs InferShape against the live scope,
//   2. looks up the kernel matching GetExpectedKernelType(),
//   3. for each initialized tensor input whose kernel type differs from the
//      expected one, writes a transformed copy into a child scope,
//   4. runs the kernel against that child scope on the device context of the
//      expected kernel's place.
// Throws when no kernel is registered for the operator type, when none matches
// the expected kernel key, or when a variable needing transformation is also
// one of this operator's outputs.
void OperatorWithKernel::RunImpl(const Scope& scope,
                                 const platform::Place& place) const {
  RuntimeInferShapeContext infer_shape_ctx(*this, scope);
  this->InferShape(&infer_shape_ctx);

  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
  auto* dev_ctx = pool.Get(place);

  // check if op[type] has kernel registered.
  auto& all_op_kernels = AllOpKernels();
  auto kernels_iter = all_op_kernels.find(type_);
  if (kernels_iter == all_op_kernels.end()) {
    PADDLE_THROW(
        "There are no kernels which are registered in the %s operator.", type_);
  }

  ExecutionContext ctx(*this, scope, *dev_ctx);
  OpKernelMap& kernels = kernels_iter->second;

  // TODO(dzhwinter) : kernel fallback mechanism will be added when all the
  // transform functions are ready.

  // for (auto& candidate : kKernelPriority) {
  //   Do selection
  // }

  auto expected_kernel_key = this->GetExpectedKernelType(ctx);
  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);
  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("op %s does not have kernel for %s", type_,
                 KernelTypeToString(expected_kernel_key));
  }

  // do data transform
  // NOTE(review): the child scope is created on every call and is not
  // released anywhere in this function -- presumably the parent scope owns
  // it; confirm it is reclaimed.
  Scope& new_scope = scope.NewScope();

  for (auto& var_name_item : this->Inputs()) {
    for (auto& var_name : var_name_item.second) {
      auto* var = scope.FindVar(var_name);
      if (var == nullptr || !VarIsTensor(var)) {
        continue;
      }

      auto* tensor_in = GetTensorFromVar(var);
      if (!tensor_in->IsInitialized()) {
        continue;
      }

      auto kernel_type_for_var = this->GetKernelTypeForVar(
          var_name_item.first, *tensor_in, expected_kernel_key);
      if (!TransFromNeeded(kernel_type_for_var, expected_kernel_key)) {
        continue;
      }

      // A variable that is also an output cannot be shadowed by a
      // transformed copy in the child scope.
      auto out_var_names = OutputVars(true);
      const bool also_output =
          std::find(out_var_names.begin(), out_var_names.end(), var_name) !=
          out_var_names.end();
      if (also_output) {
        PADDLE_THROW(
            "var %s is both input and output, "
            "does not support transform",
            var_name);
      }

      VLOG(3) << "Transform Variable " << var_name << " from "
              << kernel_type_for_var << " to " << expected_kernel_key;
      auto* trans_var = new_scope.Var(var_name);
      std::shared_ptr<Tensor> out(new Tensor);
      DataTransform(expected_kernel_key, kernel_type_for_var, *tensor_in,
                    out.get());
      CopyVariableWithTensor(*var, *(out.get()), *trans_var);
    }
  }

  auto* new_dev_ctx = pool.Get(expected_kernel_key.place_);
  kernel_iter->second->Compute(
      ExecutionContext(*this, new_scope, *new_dev_ctx));

  /*For profiling/benchmark only*/
  if (FLAGS_benchmark) {
    new_dev_ctx->Wait();
  }
}

559
// Derives the operator's data type from its inputs: every input variable that
// exists in the scope and holds a Tensor, LoDTensor or SelectedRows must
// report the same data type. Raises when two inputs disagree or when no input
// provides a data type.
proto::VarType::Type OperatorWithKernel::IndicateDataType(
    const ExecutionContext& ctx) const {
  auto& scope = ctx.scope();
  int data_type = -1;  // -1 means "not determined yet"
  for (auto& input : this->inputs_) {
    for (auto& ipt_name : input.second) {
      auto* var = scope.FindVar(ipt_name);
      if (var == nullptr) {
        continue;
      }

      const Tensor* t = nullptr;
      if (var->IsType<Tensor>()) {
        t = &var->Get<Tensor>();
      } else if (var->IsType<LoDTensor>()) {
        t = &var->Get<LoDTensor>();
      } else if (var->IsType<SelectedRows>()) {
        t = &(var->Get<SelectedRows>().value());
      }
      if (t == nullptr) {
        continue;
      }

      int tmp = static_cast<int>(ToDataType(t->type()));
      PADDLE_ENFORCE(tmp == data_type || data_type == -1,
                     "DataType of Paddle Op %s must be the same.", Type());
      data_type = tmp;
    }
  }
  PADDLE_ENFORCE(data_type != -1, "DataType should be indicated by input");
  return static_cast<proto::VarType::Type>(data_type);
}
587

588 589 590 591 592 593 594 595 596 597 598
// Default kernel key: the data type indicated by the inputs, on the place of
// the execution context.
OpKernelType OperatorWithKernel::GetExpectedKernelType(
    const ExecutionContext& ctx) const {
  OpKernelType expected(IndicateDataType(ctx), ctx.GetPlace());
  return expected;
}

// Default kernel type of an input variable: the expected data type combined
// with the place the tensor currently resides on. `var_name` is not used by
// this default implementation.
OpKernelType OperatorWithKernel::GetKernelTypeForVar(
    const std::string& var_name, const Tensor& tensor,
    const OpKernelType& expected_kernel_type) const {
  OpKernelType type_for_var(expected_kernel_type.data_type_, tensor.place());
  return type_for_var;
}

Q
Qiao Longfei 已提交
599
}  // namespace framework
L
liaogang 已提交
600
}  // namespace paddle