operator.h 10.9 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
Q
Qiao Longfei 已提交
18 19 20 21
#include <string>
#include <unordered_map>
#include <vector>

Y
Yi Wang 已提交
22
#include "paddle/framework/attribute.h"
Q
qijun 已提交
23
#include "paddle/framework/op_desc.pb.h"
Y
Yan Chunwei 已提交
24
#include "paddle/framework/op_proto.pb.h"
Q
qijun 已提交
25 26 27 28
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
Y
Yu Yang 已提交
29
#include "paddle/platform/variant.h"
Q
qijun 已提交
30
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
31 32 33 34

namespace paddle {
namespace framework {

35
/// If a variable is an empty variable, this name is used for it.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, this name is set in Python, but it
/// is converted to a unique name in the scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has this suffix, the variable is the gradient of
/// another variable, e.g. variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

/// Returns the name of the gradient variable of `var_name`, formed by
/// appending kGradVarSuffix (e.g. "x" -> "x@GRAD").
inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

class OperatorBase;
class InferShapeContext;
class ExecutionContext;
57

Q
Qiao Longfei 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  virtual ~OperatorBase() {}

  template <typename T>
  inline const T& GetAttr(const std::string& name) const {
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

75
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
76

Q
Qiao Longfei 已提交
77 78 79 80
  /// Init will be called after CreateOperator, you can put some initialization
  /// logic here.
  virtual void Init() {}

Q
Qiao Longfei 已提交
81 82
  /// InferShape infer the size of Variables used by this Operator with
  /// information inside scope
Y
Yu Yang 已提交
83
  virtual void InferShape(const Scope& scope) const = 0;
Q
Qiao Longfei 已提交
84 85

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
86
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
87 88
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
89 90
  virtual bool IsNetOp() const { return false; }

91 92
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
93 94 95
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
96
  //! Get a input with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
97
  const std::string& Input(const std::string& name) const;
Y
Yu Yang 已提交
98 99
  //! Get a input which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
100
  std::vector<std::string> Inputs(const std::string& name) const;
Y
Yi Wang 已提交
101

Y
Yu Yang 已提交
102
  //! Get a output with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
103
  const std::string& Output(const std::string& name) const;
Y
Yu Yang 已提交
104 105
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
106 107
  std::vector<std::string> Outputs(const std::string& name) const;

Y
Yi Wang 已提交
108 109 110 111 112
  const std::string Type() const { return type_; }
  const std::vector<std::string> Inputs() const { return inputs_; }
  const std::vector<std::string> Outputs() const { return outputs_; }
  const AttributeMap& Attrs() const { return attrs_; }

Q
Qiao Longfei 已提交
113
 public:
Q
Qiao Longfei 已提交
114
  std::string type_;
D
dongzhihong 已提交
115 116 117 118
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
Q
Qiao Longfei 已提交
119
  std::vector<std::string> inputs_;
D
dongzhihong 已提交
120 121
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Q
Qiao Longfei 已提交
122 123
  std::vector<std::string> outputs_;
  AttributeMap attrs_;
Y
Yan Chunwei 已提交
124
  // store the arguments' offset described in op_desc.
Y
Yu Yang 已提交
125
  std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
Y
Yan Chunwei 已提交
126 127
};

128 129 130 131 132 133 134
class NOP : public OperatorBase {
 public:
  void InferShape(const Scope& scope) const override {}
  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const override {}
};

135
class InferShapeContext {
Y
Yan Chunwei 已提交
136
 public:
137 138
  InferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}
139 140

  size_t InputSize() const { return op_.inputs_.size(); }
Y
Yan Chunwei 已提交
141

142 143
  size_t OutputSize() const { return op_.outputs_.size(); }

144
  const Variable* InputVar(const size_t index) const {
145
    return scope_.FindVar(op_.inputs_.at(index));
Y
Yan Chunwei 已提交
146 147
  }

148
  Variable* OutputVar(const size_t index) const {
149
    return scope_.FindVar(op_.outputs_.at(index));
Y
Yan Chunwei 已提交
150 151
  }

152
  const Variable* InputVar(const std::string& name) const {
Y
Yu Yang 已提交
153
    return scope_.FindVar(op_.Input(name));
Y
Yan Chunwei 已提交
154 155
  }

156
  Variable* OutputVar(const std::string& name) const {
Y
Yu Yang 已提交
157
    return scope_.FindVar(op_.Output(name));
Y
Yan Chunwei 已提交
158 159
  }

160 161
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
162 163
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
164
    res.reserve(names.size());
Y
Yan Chunwei 已提交
165
    std::transform(
166
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
167
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
168 169 170
    return res;
  }

171
  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
172 173
    auto names = op_.Outputs(name);
    std::vector<const Variable*> res;
174
    res.reserve(names.size());
Y
Yan Chunwei 已提交
175
    std::transform(
176
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
177
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
178 179 180
    return res;
  }

181
  template <typename T>
182 183
  const T* Input(const size_t index) const {
    auto var = InputVar(index);
Y
Yan Chunwei 已提交
184
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
185
    return &var->Get<T>();
186 187 188
  }

  template <typename T>
189 190
  T* Output(const size_t index) const {
    auto var = OutputVar(index);
Y
Yan Chunwei 已提交
191 192
    PADDLE_ENFORCE_NOT_NULL(
        var,
Y
Yan Chunwei 已提交
193 194 195
        "Output(%d) not be nullptr, which means variable [%s] does not "
        "exist in scope",
        index, op_.outputs_[index]);
196
    return var->GetMutable<T>();
197 198 199 200
  }

  template <typename T>
  const T* Input(const std::string& name) const {
201
    auto var = InputVar(name);
Y
Yan Chunwei 已提交
202
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
203
    return &var->Get<T>();
204 205 206 207
  }

  template <typename T>
  T* Output(const std::string& name) const {
208
    auto var = OutputVar(name);
Y
Yan Chunwei 已提交
209
    PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
210
    return var->GetMutable<T>();
211 212 213 214 215 216 217 218
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
219
                   [&](const std::string& sub_name) {
220
                     auto var = scope_.FindVar(sub_name);
Y
Yan Chunwei 已提交
221 222 223
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiInput(%s:%s) should not be nullptr", name,
                         sub_name);
224
                     return &var->Get<T>();
225 226 227 228 229 230 231 232 233 234
                   });
    return res;
  }

  template <typename T>
  std::vector<const T*> MultiOutput(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
235
                   [&](const std::string& sub_name) {
236
                     auto var = scope_.FindVar(sub_name);
Y
Yan Chunwei 已提交
237 238 239
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiOutput(%s:%s) should not be nullptr", name,
                         sub_name);
240
                     return var->GetMutable<T>();
241 242 243 244 245
                   });
    return res;
  }

  const OperatorBase& op_;
246
  const Scope& scope_;
247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263
};

/// Maps a platform Place type to an Eigen device type. ExecutionContext uses
/// it as the default DeviceType of GetEigenDevice<PlaceType>(). The primary
/// template is only declared, so naming EigenDeviceType for a Place without a
/// specialization below is a compile error.
template <typename T>
struct EigenDeviceConverter;

/// CPUPlace maps to Eigen::DefaultDevice.
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

// The GPUPlace mapping is only compiled when PADDLE_ONLY_CPU is not defined.
#ifndef PADDLE_ONLY_CPU
/// GPUPlace maps to Eigen::GpuDevice.
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
  using EigenDeviceType = Eigen::GpuDevice;
};
#endif

264
class ExecutionContext : public InferShapeContext {
265
 public:
266
  ExecutionContext(const OperatorBase& op, const Scope& scope,
D
dongzhihong 已提交
267
                   const platform::DeviceContext* device_context)
268
      : InferShapeContext(op, scope), device_context_(device_context) {}
269

Q
qijun 已提交
270 271 272
  template <typename PlaceType,
            typename DeviceType =
                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
273
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
274

D
dongzhihong 已提交
275
  platform::Place GetPlace() const { return device_context_->GetPlace(); }
Q
qijun 已提交
276

D
dongzhihong 已提交
277
  const platform::DeviceContext* device_context_;
Q
Qiao Longfei 已提交
278 279
};

Q
qijun 已提交
280 281
class OpKernel {
 public:
Q
qijun 已提交
282
  /**
283
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
284 285
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
286
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
287 288
   */

289
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
290 291 292 293

  virtual ~OpKernel() {}
};

Q
Qiao Longfei 已提交
294 295
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
296 297
  struct OpKernelKey {
    platform::Place place_;
Q
Qiao Longfei 已提交
298

Y
Yu Yang 已提交
299
    OpKernelKey() = default;
L
liaogang 已提交
300
    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
Y
Yu Yang 已提交
301 302 303
      place_ = dev_ctx.GetPlace();
    }

Q
qijun 已提交
304 305 306
    bool operator==(const OpKernelKey& o) const {
      return platform::places_are_same_class(place_, o.place_);
    }
Y
Yu Yang 已提交
307 308 309 310 311 312 313 314 315 316 317
  };

  struct OpKernelHash {
    std::hash<bool> hash_;
    size_t operator()(const OpKernelKey& key) const {
      return hash_(platform::is_gpu_place(key.place_));
    }
  };

  using OpKernelMap =
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
Q
Qiao Longfei 已提交
318

319
  void InferShape(const Scope& scope) const override {
320
    InferShape(InferShapeContext(*this, scope));
321 322
  }

Y
Yu Yang 已提交
323
  void Run(const Scope& scope,
Y
Yu Yang 已提交
324
           const platform::DeviceContext& dev_ctx) const final {
Q
Qiao Longfei 已提交
325
    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
326
    opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
Q
Qiao Longfei 已提交
327 328
  }

Y
Yu Yang 已提交
329 330 331 332
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
333
  }
Y
Yan Chunwei 已提交
334

335 336 337 338 339 340
  bool SupportGPU() const override {
    OperatorWithKernel::OpKernelKey key;
    key.place_ = platform::GPUPlace();
    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
  }

Y
Yu Yang 已提交
341
 protected:
342
  virtual void InferShape(const InferShapeContext& ctx) const = 0;
Q
Qiao Longfei 已提交
343 344 345 346
};

}  // namespace framework
}  // namespace paddle