operator.h 10.3 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

#include <boost/variant.hpp>

#include "paddle/framework/attr_checker.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
31 32 33 34 35

namespace paddle {
namespace framework {

// Forward declarations of the operator classes defined in this header.
class ExecutionContext;
class InferShapeContext;
class OperatorBase;
Q
Qiao Longfei 已提交
38 39 40 41 42 43 44 45
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
46 47 48 49 50 51 52
  /// If a variable is a empty variable, that name will be used.
  static std::string EMPTY_VAR_NAME() { return "@EMPTY@"; }

  /// If a variable is a temporary variable, that name will be set in Python,
  /// but it will be convert to a unique name in scope after OpCreator.
  static std::string TMP_VAR_NAME() { return "@TEMP@"; }

F
fengjiayi 已提交
53 54 55 56 57
  /// If a variable's name has a certain suffix, it means that the
  /// variable is the gradient of another varibale.
  /// e.g. Variable "x@GRAD" is the gradient of varibale "x".
  static std::string GRAD_VAR_SUFFIX() { return "@GRAD"; }

Q
Qiao Longfei 已提交
58 59 60 61
  static std::string GRAD_VAR_NAME(const std::string& name) {
    return name + GRAD_VAR_SUFFIX();
  }

62 63 64
  /// Variables with this suffix are supposed to be filled up with zeros.
  static std::string ZERO_VAR_SUFFIX() { return "@ZERO"; }

Q
Qiao Longfei 已提交
65 66 67 68 69 70 71 72 73
  virtual ~OperatorBase() {}

  template <typename T>
  inline const T& GetAttr(const std::string& name) const {
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

74
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
75

Q
Qiao Longfei 已提交
76 77 78 79
  /// Init will be called after CreateOperator, you can put some initialization
  /// logic here.
  virtual void Init() {}

Q
Qiao Longfei 已提交
80 81
  /// InferShape infer the size of Variables used by this Operator with
  /// information inside scope
Y
Yu Yang 已提交
82
  virtual void InferShape(const Scope& scope) const = 0;
Q
Qiao Longfei 已提交
83 84

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
85
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
86 87
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
88 89
  virtual bool IsNetOp() const { return false; }

D
dongzhihong 已提交
90 91 92
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
93
  //! Get a input with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
94
  const std::string& Input(const std::string& name) const;
Y
Yu Yang 已提交
95

Y
Yu Yang 已提交
96 97
  //! Get a input which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
98
  std::vector<std::string> Inputs(const std::string& name) const;
Y
Yu Yang 已提交
99
  //! Get a output with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
100
  const std::string& Output(const std::string& name) const;
Y
Yu Yang 已提交
101 102
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
103 104
  std::vector<std::string> Outputs(const std::string& name) const;

Q
Qiao Longfei 已提交
105
 public:
Q
Qiao Longfei 已提交
106
  std::string type_;
D
dongzhihong 已提交
107 108 109 110
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
Q
Qiao Longfei 已提交
111
  std::vector<std::string> inputs_;
D
dongzhihong 已提交
112 113
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Q
Qiao Longfei 已提交
114 115
  std::vector<std::string> outputs_;
  AttributeMap attrs_;
Y
Yan Chunwei 已提交
116
  // store the arguments' offset described in op_desc.
Y
Yu Yang 已提交
117
  std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
Y
Yan Chunwei 已提交
118 119
};

120
class OperatorContext {
Y
Yan Chunwei 已提交
121
 public:
122
  OperatorContext(const OperatorBase* op, const Scope& scope)
123 124 125
      : op_(*op), scope_(scope) {}

  size_t InputSize() const { return op_.inputs_.size(); }
Y
Yan Chunwei 已提交
126

127 128
  size_t OutputSize() const { return op_.outputs_.size(); }

129
  const Variable* InputVar(const size_t index) const {
130
    return scope_.FindVar(op_.inputs_.at(index));
Y
Yan Chunwei 已提交
131 132
  }

133
  Variable* OutputVar(const size_t index) const {
134
    return scope_.FindVar(op_.outputs_.at(index));
Y
Yan Chunwei 已提交
135 136
  }

137
  const Variable* InputVar(const std::string& name) const {
Y
Yu Yang 已提交
138
    return scope_.FindVar(op_.Input(name));
Y
Yan Chunwei 已提交
139 140
  }

141
  Variable* OutputVar(const std::string& name) const {
Y
Yu Yang 已提交
142
    return scope_.FindVar(op_.Output(name));
Y
Yan Chunwei 已提交
143 144
  }

145 146
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
147 148
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
149
    res.reserve(names.size());
Y
Yan Chunwei 已提交
150
    std::transform(
151
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
152
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
153 154 155
    return res;
  }

156
  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
157 158
    auto names = op_.Outputs(name);
    std::vector<const Variable*> res;
159
    res.reserve(names.size());
Y
Yan Chunwei 已提交
160
    std::transform(
161
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
162
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
163 164 165
    return res;
  }

166
  template <typename T>
167 168 169 170
  const T* Input(const size_t index) const {
    auto var = InputVar(index);
    PADDLE_ENFORCE(var != nullptr, "Input(%d) should not be nullptr", index);
    return &var->Get<T>();
171 172 173
  }

  template <typename T>
174 175 176 177
  T* Output(const size_t index) const {
    auto var = OutputVar(index);
    PADDLE_ENFORCE(var != nullptr, "Output(%d) should not be nullptr", index);
    return var->GetMutable<T>();
178 179 180 181
  }

  template <typename T>
  const T* Input(const std::string& name) const {
182 183 184
    auto var = InputVar(name);
    PADDLE_ENFORCE(var != nullptr, "Input(%s) should not be nullptr", name);
    return &var->Get<T>();
185 186 187 188
  }

  template <typename T>
  T* Output(const std::string& name) const {
189 190 191
    auto var = OutputVar(name);
    PADDLE_ENFORCE(var != nullptr, "Output(%s) should not be nullptr", name);
    return var->GetMutable<T>();
192 193 194 195 196 197 198 199
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
200
                   [&](const std::string& sub_name) {
201
                     auto var = scope_.FindVar(sub_name);
202 203 204 205
                     PADDLE_ENFORCE(var != nullptr,
                                    "MultiInput(%s:%s) should not be nullptr",
                                    name, sub_name);
                     return &var->Get<T>();
206 207 208 209 210 211 212 213 214 215
                   });
    return res;
  }

  template <typename T>
  std::vector<const T*> MultiOutput(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
216
                   [&](const std::string& sub_name) {
217
                     auto var = scope_.FindVar(sub_name);
218 219 220 221
                     PADDLE_ENFORCE(var != nullptr,
                                    "MultiOutput(%s:%s) should not be nullptr",
                                    name, sub_name);
                     return var->GetMutable<T>();
222 223 224 225 226
                   });
    return res;
  }

  const OperatorBase& op_;
227
  const Scope& scope_;
228 229 230 231
};

class InferShapeContext : public OperatorContext {
 public:
232
  InferShapeContext(const OperatorBase* op, const Scope& scope)
233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252
      : OperatorContext(op, scope) {}
};

/// Maps a platform place type to the Eigen device type used for it through
/// the nested `EigenDeviceType` alias. Only the specializations below are
/// defined; the primary template is intentionally left incomplete.
template <typename Place>
struct EigenDeviceConverter;

/// CPU places map to Eigen::DefaultDevice.
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
/// GPU places map to Eigen::GpuDevice; not available in CPU-only builds.
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
  using EigenDeviceType = Eigen::GpuDevice;
};
#endif  // PADDLE_ONLY_CPU

class ExecutionContext : public OperatorContext {
 public:
253
  ExecutionContext(const OperatorBase* op, const Scope& scope,
254 255 256
                   const platform::DeviceContext& device_context)
      : OperatorContext(op, scope), device_context_(device_context) {}

Q
qijun 已提交
257 258 259
  template <typename PlaceType,
            typename DeviceType =
                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
260
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
261 262 263

  platform::Place GetPlace() const { return device_context_.GetPlace(); }

Y
Yan Chunwei 已提交
264
  const platform::DeviceContext& device_context_;
Q
Qiao Longfei 已提交
265 266
};

Q
qijun 已提交
267 268
class OpKernel {
 public:
Q
qijun 已提交
269
  /**
270
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
271 272
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
273
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
274 275
   */

276
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
277 278 279 280

  virtual ~OpKernel() {}
};

Q
Qiao Longfei 已提交
281 282
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
283 284
  struct OpKernelKey {
    platform::Place place_;
Q
Qiao Longfei 已提交
285

Y
Yu Yang 已提交
286
    OpKernelKey() = default;
L
liaogang 已提交
287
    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
Y
Yu Yang 已提交
288 289 290
      place_ = dev_ctx.GetPlace();
    }

Q
qijun 已提交
291 292 293
    bool operator==(const OpKernelKey& o) const {
      return platform::places_are_same_class(place_, o.place_);
    }
Y
Yu Yang 已提交
294 295 296 297 298 299 300 301 302 303 304
  };

  struct OpKernelHash {
    std::hash<bool> hash_;
    size_t operator()(const OpKernelKey& key) const {
      return hash_(platform::is_gpu_place(key.place_));
    }
  };

  using OpKernelMap =
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
Q
Qiao Longfei 已提交
305

306
  void InferShape(const Scope& scope) const {
307 308 309
    InferShape(InferShapeContext(this, scope));
  }

Y
Yu Yang 已提交
310
  void Run(const Scope& scope,
Y
Yu Yang 已提交
311
           const platform::DeviceContext& dev_ctx) const final {
Q
Qiao Longfei 已提交
312
    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
313
    opKernel->Compute(ExecutionContext(this, scope, dev_ctx));
Q
Qiao Longfei 已提交
314 315
  }

Y
Yu Yang 已提交
316 317 318 319
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
320
  }
Y
Yan Chunwei 已提交
321

Y
Yu Yang 已提交
322
 protected:
323
  virtual void InferShape(const InferShapeContext& ctx) const = 0;
Q
Qiao Longfei 已提交
324 325 326 327
};

}  // namespace framework
}  // namespace paddle