/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <atomic>
#include <memory>
#include <mutex>  // NOLINT
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

#include "glog/logging.h"  // For VLOG
#include "paddle/fluid/framework/attribute.h"
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_info.h"
#include "paddle/fluid/framework/op_kernel_type.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/unused_var_check.h"
#include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/framework/operator.h"

namespace paddle {
namespace framework {
class InferShapeContext;
class OpInfo;
class Scope;
class Variable;
}  // namespace framework
}  // namespace paddle

DECLARE_int32(inner_op_parallelism);

namespace paddle {
namespace framework {

/// If a variable is an empty variable, that name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, that name will be set in Python,
/// but it will be converted to a unique name in scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

constexpr size_t kGradVarSuffixSize = 5U;

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Variables with this suffix are the new Gradient.
constexpr char kNewGradSuffix[] = "@NEWGRAD@";

/// RuntimeContext is used to relate input/output names of an Operator with
/// the corresponding variables in the name scope.
/// If an Op has the attribute kEnableCacheRuntimeContext, the input/output
/// names of this Op do not change within a given name scope during execution,
/// so the RuntimeContext can be created only at the first iteration of this
/// Op's execution and then reused to save the elapsed time.
constexpr char kEnableCacheRuntimeContext[] = "@ENABLE_CACHE_RUNTIME_CONTEXT@";

/// If an Op has this attribute, all its kernels should calculate the output
/// variable's shape in the corresponding Compute() function, and
/// OperatorWithKernel::RunImpl() will skip calling this Op's InferShape()
/// function at runtime for speedup.
/// TODO(luotao): Note that this temporary attribute will be deleted after all
/// ops contain it.
constexpr char kAllKernelsMustComputeRuntimeShape[] =
    "@ALL_KERNELS_MUST_COMPUTE_RUNTIME_SHAPE@";

// define some kernel priority
/* Define multiple kernel type fallback order*/
extern std::vector<std::tuple<platform::Place, LibraryType>> kKernelPriority;

inline std::string GradVarName(const std::string& var_name) {
  std::string result;
  result.reserve(var_name.size() + kGradVarSuffixSize);
  result += var_name;
  result += kGradVarSuffix;
  return result;
}

inline std::string GradOriginalVarName(const std::string& grad_var_name) {
  std::size_t pos = grad_var_name.rfind(kGradVarSuffix);
  if (pos == std::string::npos) {
    return grad_var_name;
  } else {
    return grad_var_name.substr(0, pos);
  }
}
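
// For example (illustrative): GradVarName("x") yields "x@GRAD", and
// GradOriginalVarName("x@GRAD") recovers "x"; a name without the @GRAD
// suffix is returned unchanged.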

const Tensor* GetLoDTensorOrSelectedRowsValueFromVar(const Variable& var);
Tensor* GetMutableLoDTensorOrSelectedRowsValueFromVar(Variable* var);

class ExecutionContext;
class OperatorBase;

class RuntimeContext {
 public:
  RuntimeContext(const VariableNameMap& innames,
                 const VariableNameMap& outnames, const Scope& scope);

  RuntimeContext(const VariableValueMap& invars,
                 const VariableValueMap& outvars)
      : inputs(invars), outputs(outvars) {}

  VariableValueMap inputs;
  VariableValueMap outputs;
};
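
// Illustrative sketch (assumptions: `op` is an OperatorBase and `scope` holds
// the variables named in its input/output maps): a RuntimeContext can be
// built once from the operator's name maps and then reused, e.g.
//
//   paddle::framework::RuntimeContext run_ctx(op.Inputs(), op.Outputs(), scope);
//   op.RuntimeInferShape(scope, place, run_ctx);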

/**
 * OperatorBase has the basic elements that Net will call to do computation.
 * Only CreateOperator from OpRegistry will create an Operator directly. Users
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  OperatorBase(const std::string& type, const VariableNameMap& inputs,
               const VariableNameMap& outputs, const AttributeMap& attrs);

  virtual ~OperatorBase() {}

  /// Executor will call this interface function to Run an op.
  //  The implementation should be written in RunImpl
  void Run(const Scope& scope, const platform::Place& place);

  // FIXME(typhoonzero): this is only used for recv_op to stop event_loop.
  virtual void Stop() {}

  /// if scope is not null, also show dimensions of arguments
  virtual std::string DebugStringEx(const Scope* scope) const;
  std::string DebugString() const { return DebugStringEx(nullptr); }

  virtual bool SupportGPU() const { return false; }
  virtual bool SupportNPU() const { return false; }

  const std::string& Type() const { return type_; }

  bool HasAttr(const std::string& name) const { return attrs_.count(name); }
  template <typename T>
  inline const T& Attr(const std::string& name) const {
    PADDLE_ENFORCE_NE(
        attrs_.find(name), attrs_.end(),
        platform::errors::NotFound("(%s) is not found in AttributeMap.", name));
    return BOOST_GET_CONST(T, attrs_.at(name));
  }
  void SetAttr(const std::string& name, const Attribute& v) {
    PADDLE_ENFORCE_EQ(
        HasAttr(name), true,
        platform::errors::NotFound(
            "The attribute %s is not found in operator %s", name, Type()));

    attrs_[name] = v;
  }
  const AttributeMap& Attrs() const { return attrs_; }

  const VariableNameMap& Inputs() const { return inputs_; }
  const VariableNameMap& Outputs() const { return outputs_; }

  const OpInfo& Info() const {
    PADDLE_ENFORCE_NOT_NULL(
        info_, platform::errors::NotFound(
                   "OpInfo of operator (%s) is not found.", type_));
    return *info_;
  }

  bool HasInputs(const std::string& name) const;
  //! Get an input with argument's name described in `op_proto`
  std::string Input(const std::string& name) const;
  //! Get an input which has multiple variables.
  const std::vector<std::string>& Inputs(const std::string& name) const;
  //! Get all input variable names
  std::vector<std::string> InputVars() const;

  bool HasOutputs(const std::string& name) const;
  //! Get an output with argument's name described in `op_proto`
  std::string Output(const std::string& name) const;
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
  const std::vector<std::string>& Outputs(const std::string& name) const;
  //! Get all output variable names
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;

  void SetIsCalledByExecutor(bool x) { run_by_executor_ = x; }

  virtual void RuntimeInferShape(const Scope& scope,
                                 const platform::Place& place,
                                 const RuntimeContext& ctx) const {}

  virtual platform::Place GetExecutionPlace(
      const platform::Place& place) const {
    return place;
  }

 protected:
  std::string type_;
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
  VariableNameMap inputs_;

  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
  VariableNameMap outputs_;
  AttributeMap attrs_;

  // OpInfo
  const OpInfo* info_;

  // Whether this operator executes in an Executor.
  bool run_by_executor_{true};

 private:
  void GenerateTemporaryNames();
  void CheckAllInputOutputSet() const;
  virtual void RunImpl(const Scope& scope,
                       const platform::Place& place) const = 0;
};
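
// Illustrative usage sketch (not a complete recipe; `op_desc`, `scope` and
// `place` are assumed to be prepared by the caller): operators are normally
// created from an OpDesc through the registry rather than constructed
// directly, e.g.
//
//   auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
//   op->Run(scope, place);  // dispatches to the concrete RunImpl()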

class ExecutionContext {
 public:
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext& device_context,
                   const RuntimeContext& ctx)
      : op_(op), scope_(scope), device_context_(device_context), ctx_(ctx) {}
  virtual ~ExecutionContext() {}

  virtual std::string InputName(const std::string& name) const {
    return op_.Input(name);
  }
  virtual std::vector<std::string> InputNames(const std::string& name) const {
    return op_.Inputs(name);
  }
  virtual std::string OutputName(const std::string& name) const {
    return op_.Output(name);
  }

  virtual std::vector<std::string> OutputNames(const std::string& name) const {
    return op_.Outputs(name);
  }

  virtual bool HasAttr(const std::string& name) const {
    return op_.HasAttr(name);
  }
  virtual const AttributeMap& Attrs() const { return op_.Attrs(); }

  const std::string& Type() const { return op_.Type(); }

  const Scope& scope() const { return scope_; }

  template <typename T>
  inline const T& Attr(const std::string& name) const {
    return BOOST_GET_CONST(T, GetAttr(name));
  }

  virtual const Attribute& GetAttr(const std::string& name) const {
    return op_.Attrs().at(name);
  }

  virtual bool HasInput(const std::string& name) const;

  virtual bool HasOutput(const std::string& name) const;

  virtual size_t InputSize(const std::string& name) const {
    return op_.Inputs(name).size();
  }

  virtual size_t OutputSize(const std::string& name) const {
    return op_.Outputs(name).size();
  }

  virtual const Variable* InputVar(const std::string& name) const;

  virtual Variable* OutputVar(const std::string& name) const;

  virtual const std::vector<Variable*> MultiInputVar(
      const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto it = ctx_.inputs.find(name);
    if (it == ctx_.inputs.end()) {
      return {};
    }
    return {it->second.begin(), it->second.end()};
  }

  virtual std::vector<Variable*> MultiOutputVar(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    if (it == ctx_.outputs.end()) {
      return {};
    }
    return it->second;
  }

  virtual std::vector<std::string> InNameList() const {
    std::vector<std::string> vec_temp;
    vec_temp.reserve(ctx_.inputs.size());

    for (auto& input : ctx_.inputs) {
      vec_temp.push_back(input.first);
    }

    return vec_temp;
  }

  template <typename T>
  const T* Input(const std::string& name) const {
    auto* var = InputVar(name);
    return var == nullptr ? nullptr : &var->Get<T>();
  }

  template <typename T>
  T* Output(const std::string& name) const {
    auto var = OutputVar(name);
    return var == nullptr ? nullptr : var->GetMutable<T>();
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    LogVarUsageIfUnusedVarCheckEnabled(name);

    auto vars = MultiInputVar(name);
    if (vars.size() == 0) {
      return {};
    }
    std::vector<const T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(), vars.end(), std::back_inserter(res),
                   [&](const Variable* var) -> const T* {
                     return var == nullptr ? nullptr : &var->Get<T>();
                   });
    return res;
  }

  template <typename T>
  std::vector<T*> MultiOutput(const std::string& name) const {
    auto vars = MultiOutputVar(name);

    if (vars.size() == 0) {
      return {};
    }

    std::vector<T*> res;
    res.reserve(vars.size());
    std::transform(vars.begin(), vars.end(), std::back_inserter(res),
                   [&](Variable* var) -> T* {
                     return var == nullptr ? nullptr : var->GetMutable<T>();
                   });

    return res;
  }

  platform::Place GetPlace() const { return device_context_.GetPlace(); }

  template <typename DeviceContextType>
  const DeviceContextType& device_context() const {
    return *reinterpret_cast<const DeviceContextType*>(&device_context_);
  }

  const platform::DeviceContext& device_context() const {
    return device_context_;
  }

#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
  const inline platform::CUDADeviceContext& cuda_device_context() const {
    PADDLE_ENFORCE_EQ(platform::is_gpu_place(device_context_.GetPlace()), true,
                      platform::errors::PreconditionNotMet(
                          "Current device context place is not GPUPlace."));
    return *reinterpret_cast<const platform::CUDADeviceContext*>(
        &device_context_);
  }
#endif

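  // Illustrative usage note for AllocateTmpTensor below (the element type,
  // the dimensions and `dev_ctx` are assumptions taken from the surrounding
  // kernel code, not part of this interface):
  //   auto tmp = ctx.AllocateTmpTensor<float, platform::CPUDeviceContext>(
  //       framework::make_ddim({rows, cols}), dev_ctx);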
  template <typename T, typename DevContext>
  Tensor AllocateTmpTensor(const framework::DDim& dim,
                           const DevContext& dev_ctx) const {
    auto tmp_allocation_ptr = memory::Alloc(dev_ctx, product(dim) * sizeof(T));
    auto& deleter = tmp_allocation_ptr.get_deleter();
    auto* allocation_ptr = tmp_allocation_ptr.release();
    auto shared_allocation = std::shared_ptr<memory::allocation::Allocation>(
        allocation_ptr, deleter);

    PADDLE_ENFORCE_GE(
        allocation_ptr->size(), framework::product(dim) * sizeof(T),
        platform::errors::PreconditionNotMet(
            "The data memory size(%d) is less than the tensor needed memory "
            "size(%d).",
            allocation_ptr->size(), framework::product(dim) * sizeof(T)));

    paddle::framework::Tensor temp_tensor(
        framework::ToDataType(std::type_index(typeid(T))));
    temp_tensor.Resize(dim);
    temp_tensor.ResetHolder(std::move(shared_allocation));
    return temp_tensor;
  }

  const RuntimeContext Context() const { return ctx_; }

  std::string DebugString() const { return op_.DebugString(); }
  const OperatorBase& GetOp() const { return op_; }

 private:
  const OperatorBase& op_;
  const Scope& scope_;
  const platform::DeviceContext& device_context_;
  const RuntimeContext& ctx_;
};

template <>
const Tensor* ExecutionContext::Input<Tensor>(const std::string& name) const;

template <>
const std::vector<const Tensor*> ExecutionContext::MultiInput<Tensor>(
    const std::string& name) const;

template <>
Tensor* ExecutionContext::Output<Tensor>(const std::string& name) const;

template <>
std::vector<Tensor*> ExecutionContext::MultiOutput<Tensor>(
    const std::string& name) const;

class OpKernelBase {
 public:
  /**
   * ExecutionContext is the only parameter of a kernel's Compute function.
   * Compute will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
   * ExecutionContext. User should construct it before running the Operator.
   */

  virtual void Compute(const ExecutionContext& context) const = 0;

  virtual ~OpKernelBase() = default;
};

template <typename T>
class OpKernel : public OpKernelBase {
 public:
  using ELEMENT_TYPE = T;
};

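// Illustrative sketch of a concrete kernel written against this interface
// (the class name and the "X"/"Out" argument names are invented for the
// example):
//
//   template <typename T>
//   class MyReluKernel : public paddle::framework::OpKernel<T> {
//    public:
//     void Compute(const paddle::framework::ExecutionContext& ctx)
//         const override {
//       auto* x = ctx.Input<paddle::framework::Tensor>("X");
//       auto* out = ctx.Output<paddle::framework::Tensor>("Out");
//       // read from x, resize/allocate out, then write the results
//     }
//   };
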
class OperatorWithKernel : public OperatorBase {
 public:
  using OpKernelFunc = std::function<void(const ExecutionContext&)>;
  using OpKernelMap =
      std::unordered_map<OpKernelType, OpKernelFunc, OpKernelType::Hash>;

  OperatorWithKernel(const std::string& type, const VariableNameMap& inputs,
                     const VariableNameMap& outputs, const AttributeMap& attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
  }

  bool IsMKLDNNType() const {
    return ((this->kernel_type_) && (this->kernel_type_->data_layout_ ==
                                     framework::DataLayout::kMKLDNN));
  }

  bool SupportGPU() const override {
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_gpu_place(kern_pair.first.place_);
                       });
  }
  bool SupportNPU() const override {
    auto& op_kernels = OperatorWithKernel::AllOpKernels().at(type_);
    return std::any_of(op_kernels.begin(), op_kernels.end(),
                       [](OpKernelMap::const_reference kern_pair) {
                         return platform::is_npu_place(kern_pair.first.place_);
                       });
  }
  bool SupportsMKLDNN(proto::VarType::Type data_type) const;

  bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx,
                       proto::VarType::Type data_type) const;

  virtual void InferShape(InferShapeContext* ctx) const = 0;

  void RuntimeInferShape(const Scope& scope, const platform::Place& place,
                         const RuntimeContext& ctx) const override;

  proto::VarType::Type IndicateVarDataType(const ExecutionContext& ctx,
                                           const std::string& name) const;

  proto::VarType::Type IndicateOrPromoteVarDataTypes(
      const ExecutionContext& ctx, const std::string& name1,
      const std::string& name2) const;

  virtual OpKernelType GetExpectedKernelType(const ExecutionContext& ctx) const;

  // change this to public so that in dygraph mode we can call it to check if we
  // need to transform data
  virtual OpKernelType GetKernelTypeForVar(
      const std::string& var_name, const Tensor& tensor,
      const OpKernelType& expected_kernel_type) const;

  platform::Place GetExecutionPlace(
      const platform::Place& platform) const override {
    return kernel_type_->place_;
  }

 private:
  void RunImpl(const Scope& scope, const platform::Place& place) const final;
  void RunImpl(const Scope& scope, const platform::Place& place,
               RuntimeContext* runtime_ctx) const;

  /**
   * Transfer data from scope to a transferred scope. If there is no data that
   * needs to be transferred, it returns nullptr.
   *
   * * transfered_inplace_vars is an output vector.
   */
  Scope* PrepareData(const Scope& scope,
                     const OpKernelType& expected_kernel_key,
                     std::vector<std::string>* transfered_inplace_vars,
                     RuntimeContext* ctx) const;

  void TransferInplaceVarsBack(const Scope& scope,
                               const std::vector<std::string>& inplace_vars,
                               const Scope& exec_scope) const;

  void ChooseKernel(const RuntimeContext& ctx, const Scope& scope,
                    const platform::Place& place) const;

  void HandleComplexGradToRealGrad(const Scope& scope,
                                   RuntimeContext* ctx) const;

  /* Inner assist methods */
  // indicate kernel DataType by input data.
  // By default all input data must be the same.
  proto::VarType::Type IndicateDataType(const ExecutionContext& ctx) const;
  // used for IndicateDataType
  void ParseInputDataType(const ExecutionContext& ctx, const std::string& name,
                          proto::VarType::Type* type) const;
  // used for IndicateOrPromoteVarDataTypes
  Tensor* GetTensorFormInputSafely(const ExecutionContext& ctx,
                                   const std::string& name) const;

 protected:
  mutable std::unique_ptr<OpKernelType> kernel_type_;
  mutable std::unique_ptr<OpKernelFunc> kernel_func_;
  mutable std::unique_ptr<RuntimeContext> runtime_ctx_;
  mutable const Scope* pre_scope_ = nullptr;
  mutable bool need_prepare_data_ = true;
  mutable bool enable_cache_runtime_context_ = false;
  mutable bool all_kernels_must_compute_runtime_shape_ = false;
  mutable std::mutex cache_update_mutex_;
  mutable bool enable_cache_transfer_scope_ = false;
};
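
// Illustrative sketch of an operator built on OperatorWithKernel (the class
// name and the "X" input name are invented; kernel registration lives outside
// this header):
//
//   class MyOp : public paddle::framework::OperatorWithKernel {
//    public:
//     using paddle::framework::OperatorWithKernel::OperatorWithKernel;
//
//     void InferShape(paddle::framework::InferShapeContext* ctx) const override {
//       // derive the output dims from the input dims here
//     }
//
//     paddle::framework::OpKernelType GetExpectedKernelType(
//         const paddle::framework::ExecutionContext& ctx) const override {
//       return paddle::framework::OpKernelType(
//           IndicateVarDataType(ctx, "X"), ctx.GetPlace());
//     }
//   };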

extern bool OpSupportGPU(const std::string& op_type);

/*
class RuntimeInferShapeContext : public InferShapeContext {
 public:
  RuntimeInferShapeContext(const OperatorBase& op, const RuntimeContext& ctx)
      : op_(op), ctx_(ctx) {}
  bool HasInput(const std::string& name) const override {
    // has only one input
    const auto& ins = ctx_.inputs;
    auto it = ins.find(name);
    if (it == ins.end()) {
      return false;
    }
    const auto& in = it->second;
    if (in.size() == 0) return false;
    PADDLE_ENFORCE_EQ(
        in.size(), 1UL,
        platform::errors::InvalidArgument(
            "Input %s should not contain more than one inputs.", name));
    return in[0] != nullptr;
  }
  bool HasOutput(const std::string& name) const override {
    // has only one output
    const auto& outs = ctx_.outputs;
    auto it = outs.find(name);
    if (it == outs.end()) {
      return false;
    }
    const auto& out = it->second;
    if (out.size() == 0) {
      return false;
    }
    PADDLE_ENFORCE_EQ(
        out.size(), 1UL,
        platform::errors::InvalidArgument(
            "Output %s should not contain more than one outputs.", name));
    return out[0] != nullptr;
  }
  bool HasInputs(const std::string& name) const override {
    const auto& ins = ctx_.inputs;
    auto it = ins.find(name);
    if (it == ins.end() || it->second.empty()) {
      return false;
    }
    for (auto& input : it->second) {
      if (input == nullptr) {
        return false;
      }
    }
    return true;
  }
  bool HasOutputs(const std::string& name) const override {
    const auto& outs = ctx_.outputs;
    auto it = outs.find(name);
    if (it == outs.end() || it->second.empty()) {
      return false;
    }
    for (auto& output : it->second) {
      if (output == nullptr) {
        return false;
      }
    }
    return true;
  }
  AttrReader Attrs() const override { return AttrReader(op_.Attrs()); }
  std::vector<std::string> Inputs(const std::string& name) const override {
    return op_.Inputs(name);
  }
  std::vector<std::string> Outputs(const std::string& name) const override {
    return op_.Outputs(name);
  }
  std::string GetInputNameByIdx(size_t idx) const override {
    auto& op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(idx, op_proto->inputs().size(),
                      platform::errors::OutOfRange(
                          "The index should be less than the size of inputs of "
                          "operator %s, but got index is %d and size is %d",
                          op_.Type(), idx, op_proto->inputs().size()));
    return op_proto->inputs()[idx].name();
  }
  std::string GetOutputNameByIdx(size_t idx) const override {
    auto& op_proto =
        paddle::framework::OpInfoMap::Instance().Get(op_.Type()).proto_;
    PADDLE_ENFORCE_LT(
        idx, op_proto->outputs().size(),
        platform::errors::OutOfRange(
            "The index should be less than the size of outputs of "
            "operator %s, but got index is %d and size is %d",
            op_.Type(), idx, op_proto->outputs().size()));
    return op_proto->outputs()[idx].name();
  }
  void ShareDim(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) override {
    auto in_it = ctx_.inputs.find(in);
    auto out_it = ctx_.outputs.find(out);
    PADDLE_ENFORCE_NE(
        in_it, ctx_.inputs.end(),
        platform::errors::NotFound("Input %s does not exist.", in));
    PADDLE_ENFORCE_NE(
        out_it, ctx_.outputs.end(),
        platform::errors::NotFound("Output %s does not exist.", out));
    PADDLE_ENFORCE_LT(i, in_it->second.size(),
                      platform::errors::InvalidArgument(
                          "The index of input dimension is out of range, "
                          "excepted index less than %zu, but received %zu.",
                          in_it->second.size(), i));
    PADDLE_ENFORCE_LT(j, out_it->second.size(),
                      platform::errors::InvalidArgument(
                          "The index of output dimension is out of range, "
                          "excepted index less than %zu, but received %zu.",
                          out_it->second.size(), j));
    Variable* in_var = in_it->second[i];
    Variable* out_var = out_it->second[j];
    PADDLE_ENFORCE_EQ(
        in_var->Type(), out_var->Type(),
        platform::errors::InvalidArgument(
            "The type of input (%s) and output (%s) are inconsistent.", in,
            out));
    if (in_var->IsType<framework::SelectedRows>()) {
      auto& in_sele_rows = in_var->Get<framework::SelectedRows>();
      auto out_sele_rows = out_var->GetMutable<framework::SelectedRows>();
      out_sele_rows->mutable_value()->Resize(in_sele_rows.value().dims());
      out_sele_rows->set_rows(in_sele_rows.rows());
      out_sele_rows->set_height(in_sele_rows.height());
    } else if (in_var->IsType<framework::LoDTensor>()) {
      auto& in_lod_tensor = in_var->Get<framework::LoDTensor>();
      auto* out_lod_tensor = out_var->GetMutable<framework::LoDTensor>();
      out_lod_tensor->Resize(in_lod_tensor.dims());
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Currently, the input type of ShareDim only can be LoDTensor "
          "or SelectedRows."));
    }
  }
  void ShareAllLoD(const std::string& in,
                   const std::string& out) const override {
    auto in_it = ctx_.inputs.find(in);
    auto out_it = ctx_.outputs.find(out);
    PADDLE_ENFORCE_NE(in_it, ctx_.inputs.end(),
                      platform::errors::NotFound(
                          "Input [%s] found error in Op [%s]", in, op_.Type()));
    PADDLE_ENFORCE_NE(
        out_it, ctx_.outputs.end(),
        platform::errors::NotFound("Output [%s] found error in Op [%s]", out,
                                   op_.Type()));
    auto& in_var_list = in_it->second;
    auto& out_var_list = out_it->second;
    PADDLE_ENFORCE_EQ(
        in_var_list.size(), out_var_list.size(),
        platform::errors::PreconditionNotMet(
            "Op [%s]: Input var size should be equal with output var size",
            op_.Type()));
    auto& out_var_names = op_.Outputs(out);
    for (size_t i = 0; i < in_var_list.size(); ++i) {
      if (out_var_names[i] == framework::kEmptyVarName) {
        continue;
      }
      Variable* in_var = in_var_list[i];
      if (!in_var->IsType<LoDTensor>()) return;
      Variable* out_var = out_var_list[i];
      PADDLE_ENFORCE_EQ(out_var->IsType<LoDTensor>(), true,
                        platform::errors::PreconditionNotMet(
                            "The %d-th output of Output(%s) must be LoDTensor.",
                            i, out_var_names[i]));
      auto& in_tensor = in_var->Get<LoDTensor>();
      auto* out_tensor = out_var->GetMutable<LoDTensor>();
      out_tensor->set_lod(in_tensor.lod());
#ifdef PADDLE_WITH_MKLDNN
      if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
        out_tensor->set_layout(in_tensor.layout());
    }
  }
  void ShareLoD(const std::string& in, const std::string& out, size_t i = 0,
                size_t j = 0) const override {
    auto in_it = ctx_.inputs.find(in);
    auto out_it = ctx_.outputs.find(out);
    PADDLE_ENFORCE_NE(
        in_it, ctx_.inputs.end(),
        platform::errors::NotFound("Input %s does not exist.", in));
    PADDLE_ENFORCE_NE(
        out_it, ctx_.outputs.end(),
        platform::errors::NotFound("Output %s does not exist.", out));
    PADDLE_ENFORCE_LT(i, in_it->second.size(),
                      platform::errors::InvalidArgument(
                          "The index of input dimension is out of range, "
                          "excepted index less than %zu, but received %zu.",
                          in_it->second.size(), i));
    PADDLE_ENFORCE_LT(j, out_it->second.size(),
                      platform::errors::InvalidArgument(
                          "The index of output dimension is out of range, "
                          "excepted index less than %zu, but received %zu.",
                          out_it->second.size(), j));
    Variable* in_var = in_it->second.at(i);
    if (!in_var->IsType<LoDTensor>()) return;
    Variable* out_var = out_it->second.at(j);
    PADDLE_ENFORCE_EQ(
        out_var->IsType<LoDTensor>(), true,
        platform::errors::InvalidArgument(
            "The %zu-th output of Output(%s) must be LoDTensor.", j, out));
    auto& in_tensor = in_var->Get<LoDTensor>();
    auto* out_tensor = out_var->GetMutable<LoDTensor>();
    out_tensor->set_lod(in_tensor.lod());
// TODO(dzhwinter) : reuse ShareLoD in most operators.
// Need to call ShareLayout explicitly in sequence related ops.
// Shall we have a better method to shared info between in/out Tensor?
#ifdef PADDLE_WITH_MKLDNN
    // Fix me: ugly workaround below
    // Correct solution:
    //    set_layout() should NOT be called here (i.e. ShareLoD). Instead,
    //    layout of output tensor should be set "manually" in Compute()
    //    of each OPKernel. The reason layout should NOT be shared between
    //    input and output "automatically" (now by InferShape()->ShareLoD())
    //    is that layout transform may occur after InferShape().
    // Workaround:
    //    Skip set_layout() when input layout is kMKLDNN
    //    This is to avoid kMKLDNN is populated wrongly into a non-MKLDNN
    //    OPKernel. In all MKLDNN OPkernel, set_layout(kMKLDNN) should be called
    //    in Compute()
    if (in_tensor.layout() != DataLayout::kMKLDNN)
#endif
      out_tensor->set_layout(in_tensor.layout());
  }
  int32_t GetLoDLevel(const std::string& in, size_t i = 0) const override {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "GetLoDLevel is only used in compile time. The calculation of "
        "output's actual lod is different among operators so that should be "
        "set in the runtime kernel."));
  }
  void SetLoDLevel(const std::string& out, int32_t lod_level,
                   size_t j = 0) const override {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "SetLoDLevel is only used in compile time. The calculation of "
        "output's actual lod is different among operators so that should be "
        "set in the runtime kernel."));
  }
  bool IsRuntime() const override { return true; }
  // TODO(paddle-dev): Can this be template?
  std::vector<InferShapeVarPtr> GetInputVarPtrs(
      const std::string& name) override {
    const std::vector<Variable*>& vars = InputVars(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
  }
  std::vector<InferShapeVarPtr> GetOutputVarPtrs(
      const std::string& name) override {
    const std::vector<Variable*>& vars = OutputVars(name);
    std::vector<InferShapeVarPtr> res;
    res.reserve(vars.size());
    res.insert(res.begin(), vars.begin(), vars.end());
    return res;
  }
  DDim GetInputDim(const std::string& name) const override {
    const std::vector<Variable*>& vars = InputVars(name);
    PADDLE_ENFORCE_EQ(
        vars.size(), 1UL,
        platform::errors::InvalidArgument(
            "Input(%s) should hold one element, but now it holds %zu elements.",
            name, vars.size()));
    return this->GetDim(vars[0]);
  }
  std::vector<DDim> GetInputsDim(const std::string& name) const override {
    const std::vector<Variable*>& vars = InputVars(name);
    return GetDims(vars);
  }
  std::vector<proto::VarType::Type> GetInputsVarType(
      const std::string& name) const override {
    return GetVarTypes(InputVars(name));
  }
  std::vector<proto::VarType::Type> GetOutputsVarType(
      const std::string& name) const override {
    return GetVarTypes(OutputVars(name));
  }
  void SetOutputDim(const std::string& name, const DDim& dim) override {
    auto& vars = OutputVars(name);
    PADDLE_ENFORCE_EQ(
        vars.size(), 1UL,
        platform::errors::InvalidArgument("Output(%s) should hold one element, "
                                          "but now it holds %zu elements.",
                                          name, vars.size()));
    SetDim(vars[0], dim);
  }
  void SetOutputsDim(const std::string& name,
                     const std::vector<DDim>& dims) override {
    auto& vars = OutputVars(name);
    SetDims(vars, dims);
  }
 protected:
  DDim GetDim(Variable* var) const {
    PADDLE_ENFORCE_NOT_NULL(
        var, platform::errors::InvalidArgument("Input variable is nullptr."));
    if (var->IsType<LoDTensor>()) {
      return var->Get<LoDTensor>().dims();
    } else if (var->IsType<SelectedRows>()) {
      return var->Get<SelectedRows>().GetCompleteDims();
    } else {
      PADDLE_THROW(platform::errors::InvalidArgument(
          "Only LoDTensor or SelectedRows support 'GetDim', but input "
          "Variable's type is %s.",
          ToTypeName(var->Type())));
    }
  }
  std::vector<DDim> GetDims(const std::vector<Variable*>& vars) const {
    std::vector<DDim> ret;
    ret.reserve(vars.size());
    std::transform(vars.begin(), vars.end(), std::back_inserter(ret),
                   [this](Variable* var) { return this->GetDim(var); });
    return ret;
  }
  std::vector<DDim> GetRepeatedDims(const std::string& name) const override {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "GetRepeatedDims method only ban be used in compile time."));
  }
  void SetDim(Variable* var, const DDim& dim) {
    if (var->IsType<LoDTensor>()) {
      var->GetMutable<LoDTensor>()->Resize(dim);
    } else if (var->IsType<SelectedRows>()) {
      var->GetMutable<SelectedRows>()->set_height(dim[0]);
    } else {
      PADDLE_THROW(platform::errors::Unimplemented(
          "Variable type error, expect LoDTensor or SelectedRows, but received "
          "(%s).",
          ToTypeName(var->Type())));
    }
  }
  void SetDims(const std::vector<Variable*>& vars,
               const std::vector<DDim>& dims) {
    size_t length = vars.size();
    PADDLE_ENFORCE_EQ(length, dims.size(),
                      platform::errors::InvalidArgument(
                          "The number of input variables do not match the "
                          "number of input dimensions, the number of variables "
                          "is %zu, the number of dimensions is %zu.",
                          length, dims.size()));
    for (size_t i = 0; i < length; ++i) {
      if (vars[i] == nullptr) {
        continue;
      }
      SetDim(vars[i], dims[i]);
    }
  }
  void SetRepeatedDims(const std::string& name,
                       const std::vector<DDim>& dims) override {
    PADDLE_THROW(platform::errors::PreconditionNotMet(
        "SetRepeatedDims method only can be used in compile time."));
  }
  std::vector<proto::VarType::Type> GetVarTypes(
      const std::vector<Variable*>& vars) const {
    std::vector<proto::VarType::Type> retv;
    retv.resize(vars.size());
    std::transform(vars.begin(), vars.end(), retv.begin(),
                   std::bind(std::mem_fn(&RuntimeInferShapeContext::GetVarType),
                             this, std::placeholders::_1));
    return retv;
  }
  proto::VarType::Type GetVarType(Variable* var) const {
    return ToVarType(var->Type());
  }
 private:
  const std::vector<Variable*>& InputVars(const std::string& name) const {
    auto it = ctx_.inputs.find(name);
    PADDLE_ENFORCE_NE(
        it, ctx_.inputs.end(),
        platform::errors::NotFound(
            "Operator (%s) does not have the input (%s).", op_.Type(), name));
    return it->second;
  }
  const std::vector<Variable*>& OutputVars(const std::string& name) const {
    auto it = ctx_.outputs.find(name);
    PADDLE_ENFORCE_NE(
        it, ctx_.outputs.end(),
        platform::errors::NotFound(
            "Operator (%s) does not have the outputs (%s).", op_.Type(), name));
    return it->second;
  }
  const OperatorBase& op_;
  const RuntimeContext& ctx_;
};
*/

}  // namespace framework
}  // namespace paddle