/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <string>
#include <unordered_map>
#include <vector>

#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/platform/variant.h"
#include "paddle/utils/Error.h"

namespace paddle {
namespace framework {

/// If a variable is an empty variable, this name will be used.
constexpr char kEmptyVarName[] = "@EMPTY@";

/// If a variable is a temporary variable, its name will be set in Python,
/// but it will be converted to a unique name in the scope after OpCreator.
constexpr char kTempVarName[] = "@TEMP@";

/// If a variable's name has a certain suffix, it means that the
/// variable is the gradient of another variable.
/// e.g. Variable "x@GRAD" is the gradient of variable "x".
constexpr char kGradVarSuffix[] = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
constexpr char kZeroVarSuffix[] = "@ZERO";

inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}
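
// A minimal usage sketch ("x" is an arbitrary example variable name):
//   GradVarName("x");  // returns "x@GRAD"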

class OperatorBase;
class InferShapeContext;
class ExecutionContext;

/**
 * OperatorBase provides the basic elements that a Net calls to do
 * computation. Only OpRegistry::CreateOp should construct an Operator
 * directly; users should always build a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  virtual ~OperatorBase() {}

  template <typename T>
  inline const T& GetAttr(const std::string& name) const {
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

  virtual std::string DebugString() const;

  /// Init will be called after CreateOperator; you can put some
  /// initialization logic here.
  virtual void Init() {}

  /// InferShape infers the shapes of the Variables used by this Operator
  /// from the information inside the scope.
  virtual void InferShape(const Scope& scope) const = 0;

  /// Net will call this function to Run an op.
  virtual void Run(const Scope& scope,
                   const platform::DeviceContext& dev_ctx) const = 0;

  virtual bool IsNetOp() const { return false; }

  virtual bool SupportGPU() const { return false; }

  /// Rename old_name to new_name in both the inputs and the outputs.
  void Rename(const std::string& old_name, const std::string& new_name);

  //! Get an input with the argument name described in `op_proto`.
  const std::string& Input(const std::string& name) const;
  //! Get an input which has multiple variables.
  //! TODO: add a vector_view to prevent memory copies.
  std::vector<std::string> Inputs(const std::string& name) const;

  //! Get an output with the argument name described in `op_proto`.
  const std::string& Output(const std::string& name) const;
  //! Get an output which has multiple variables.
  //! TODO: add a vector_view to prevent memory copies.
  std::vector<std::string> Outputs(const std::string& name) const;

  const std::string Type() const { return type_; }
  const std::vector<std::string> Inputs() const { return inputs_; }
  const std::vector<std::string> Outputs() const { return outputs_; }
  const AttributeMap& Attrs() const { return attrs_; }

 public:
  std::string type_;
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
  std::vector<std::string> inputs_;
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Input Gradients)
  std::vector<std::string> outputs_;
  AttributeMap attrs_;
  // Stores the arguments' offsets as described in op_desc.
  std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
};
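
// An illustrative call sequence (a sketch; op_desc, scope, and dev_ctx are
// assumed to be a prepared OpDesc proto, a Scope, and a DeviceContext):
//
//   auto op = OpRegistry::CreateOp(op_desc);  // never `new` an Operator directly
//   op->InferShape(scope);                    // resize the output variables
//   op->Run(scope, dev_ctx);                  // do the computation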

class InferShapeContext {
 public:
  InferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}

  size_t InputSize() const { return op_.inputs_.size(); }

  size_t OutputSize() const { return op_.outputs_.size(); }

  const Variable* InputVar(const size_t index) const {
    return scope_.FindVar(op_.inputs_.at(index));
  }

  Variable* OutputVar(const size_t index) const {
    return scope_.FindVar(op_.outputs_.at(index));
  }

  const Variable* InputVar(const std::string& name) const {
    return scope_.FindVar(op_.Input(name));
  }

  Variable* OutputVar(const std::string& name) const {
    return scope_.FindVar(op_.Output(name));
  }

  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
    res.reserve(names.size());
    std::transform(
        names.begin(), names.end(), std::back_inserter(res),
        [this](const std::string& name) { return scope_.FindVar(name); });
    return res;
  }

  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const Variable*> res;
    res.reserve(names.size());
    std::transform(
        names.begin(), names.end(), std::back_inserter(res),
        [this](const std::string& name) { return scope_.FindVar(name); });
    return res;
  }

  template <typename T>
  const T* Input(const size_t index) const {
    auto var = InputVar(index);
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
    return &var->Get<T>();
  }

  template <typename T>
  T* Output(const size_t index) const {
    auto var = OutputVar(index);
    PADDLE_ENFORCE_NOT_NULL(
        var,
        "Output(%d) not be nullptr, which means variable [%s] does not "
        "exist in scope",
        index, op_.outputs_[index]);
    return var->GetMutable<T>();
  }

  template <typename T>
  const T* Input(const std::string& name) const {
    auto var = InputVar(name);
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
    return &var->Get<T>();
  }

  template <typename T>
  T* Output(const std::string& name) const {
    auto var = OutputVar(name);
    PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
    return var->GetMutable<T>();
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [&](const std::string& sub_name) {
                     auto var = scope_.FindVar(sub_name);
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiInput(%s:%s) should not be nullptr", name,
                         sub_name);
                     return &var->Get<T>();
                   });
    return res;
  }

  template <typename T>
  std::vector<const T*> MultiOutput(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
                   [&](const std::string& sub_name) {
                     auto var = scope_.FindVar(sub_name);
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiOutput(%s:%s) should not be nullptr", name,
                         sub_name);
                     return var->GetMutable<T>();
                   });
    return res;
  }

  const OperatorBase& op_;
  const Scope& scope_;
};
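
// A sketch of how a concrete operator might use this context in its
// InferShape override (MulOp and the argument names "X", "Y", and "Out"
// are hypothetical):
//
//   void MulOp::InferShape(const InferShapeContext& ctx) const {
//     auto dim0 = ctx.Input<Tensor>("X")->dims();
//     auto dim1 = ctx.Input<Tensor>("Y")->dims();
//     ctx.Output<Tensor>("Out")->Resize(make_ddim({dim0[0], dim1[1]}));
//   }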

template <typename T>
struct EigenDeviceConverter;

template <>
struct EigenDeviceConverter<platform::CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
  using EigenDeviceType = Eigen::GpuDevice;
};
#endif

class ExecutionContext : public InferShapeContext {
 public:
  ExecutionContext(const OperatorBase& op, const Scope& scope,
                   const platform::DeviceContext* device_context)
      : InferShapeContext(op, scope), device_context_(device_context) {}

  template <typename PlaceType,
            typename DeviceType =
                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
  DeviceType& GetEigenDevice() const;

  platform::Place GetPlace() const { return device_context_->GetPlace(); }

  const platform::DeviceContext* device_context_;
};
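
// A sketch of how a kernel body might fetch the Eigen device from an
// ExecutionContext (ctx, x, and out are illustrative Eigen-mapped names):
//
//   auto& dev = ctx.GetEigenDevice<platform::CPUPlace>();
//   out.device(dev) = x + x;  // Eigen expressions evaluate on that device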

class OpKernel {
 public:
  /**
   * ExecutionContext is the only parameter of Kernel Run function.
   * Run will get input/output variables, state such as momentum, and
   * device resources such as the CUDA stream and cublas handle from the
   * ExecutionContext. Users should construct it before running the Operator.
   */

  virtual void Compute(const ExecutionContext& context) const = 0;

  virtual ~OpKernel() {}
};
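
// A sketch of a concrete kernel (AddKernel and the argument names are
// hypothetical; real kernels are registered through the op registry):
//
//   template <typename Place, typename T>
//   class AddKernel : public OpKernel {
//    public:
//     void Compute(const ExecutionContext& ctx) const override {
//       auto* x = ctx.Input<Tensor>("X");
//       auto* y = ctx.Input<Tensor>("Y");
//       auto* out = ctx.Output<Tensor>("Out");
//       out->mutable_data<T>(ctx.GetPlace());
//       // ... element-wise add via Eigen on ctx.GetEigenDevice<Place>() ...
//     }
//   };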

class OperatorWithKernel : public OperatorBase {
 public:
  struct OpKernelKey {
    platform::Place place_;

    OpKernelKey() = default;
    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
      place_ = dev_ctx.GetPlace();
    }

    bool operator==(const OpKernelKey& o) const {
      return platform::places_are_same_class(place_, o.place_);
    }
  };

  struct OpKernelHash {
    std::hash<bool> hash_;
    size_t operator()(const OpKernelKey& key) const {
      return hash_(platform::is_gpu_place(key.place_));
    }
  };

  using OpKernelMap =
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;

  void InferShape(const Scope& scope) const override {
    InferShape(InferShapeContext(*this, scope));
  }

  void Run(const Scope& scope,
           const platform::DeviceContext& dev_ctx) const final {
    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
    opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
  }

  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
  }

  bool SupportGPU() const override {
    OperatorWithKernel::OpKernelKey key;
    key.place_ = platform::GPUPlace();
    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
  }

 protected:
  virtual void InferShape(const InferShapeContext& ctx) const = 0;
};
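
// A sketch of a concrete operator type (MulOp and the "mul" key are
// hypothetical; its kernels would be inserted into AllOpKernels()["mul"],
// keyed by OpKernelKey(place)):
//
//   class MulOp : public OperatorWithKernel {
//    protected:
//     void InferShape(const InferShapeContext& ctx) const override {
//       // check inputs/outputs and Resize the output tensors here
//     }
//   };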

}  // namespace framework
}  // namespace paddle