/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <algorithm>
#include <functional>
#include <iterator>
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>

#include <boost/variant.hpp>

#include "paddle/framework/attribute.h"
#include "paddle/framework/op_desc.pb.h"
#include "paddle/framework/op_proto.pb.h"
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"

namespace paddle {
namespace framework {

/// If a variable is an empty variable, this name is used for it.
const std::string kEmptyVarName = "@EMPTY@";

/// Name of a temporary variable. It is set on the Python side and is
/// converted to a unique name in the scope after OpCreator.
const std::string kTempVarName = "@TEMP@";

/// A variable whose name ends with this suffix is the gradient of another
/// variable, e.g. variable "x@GRAD" is the gradient of variable "x".
const std::string kGradVarSuffix = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
const std::string kZeroVarSuffix = "@ZERO";

/// Returns the name of the gradient variable of `var_name`, i.e. `var_name`
/// followed by kGradVarSuffix (for example "x" -> "x@GRAD").
inline std::string GradVarName(const std::string& var_name) {
  return var_name + kGradVarSuffix;
}

// Forward declarations: these classes are referenced before they are defined.
class OperatorBase;
class InferShapeContext;
class ExecutionContext;

/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  virtual ~OperatorBase() {}

  template <typename T>
  inline const T& GetAttr(const std::string& name) const {
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

75
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
76

Q
Qiao Longfei 已提交
77 78 79 80
  /// Init will be called after CreateOperator, you can put some initialization
  /// logic here.
  virtual void Init() {}

Q
Qiao Longfei 已提交
81 82
  /// InferShape infer the size of Variables used by this Operator with
  /// information inside scope
Y
Yu Yang 已提交
83
  virtual void InferShape(const Scope& scope) const = 0;
Q
Qiao Longfei 已提交
84 85

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
86
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
87 88
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
89 90
  virtual bool IsNetOp() const { return false; }

91 92
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
93 94 95
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
96
  //! Get a input with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
97
  const std::string& Input(const std::string& name) const;
Y
Yu Yang 已提交
98

Y
Yu Yang 已提交
99 100
  //! Get a input which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
101
  std::vector<std::string> Inputs(const std::string& name) const;
Y
Yu Yang 已提交
102
  //! Get a output with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
103
  const std::string& Output(const std::string& name) const;
Y
Yu Yang 已提交
104 105
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
106 107
  std::vector<std::string> Outputs(const std::string& name) const;

Q
Qiao Longfei 已提交
108
 public:
Q
Qiao Longfei 已提交
109
  std::string type_;
D
dongzhihong 已提交
110 111 112 113
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
Q
Qiao Longfei 已提交
114
  std::vector<std::string> inputs_;
D
dongzhihong 已提交
115 116
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Q
Qiao Longfei 已提交
117 118
  std::vector<std::string> outputs_;
  AttributeMap attrs_;
Y
Yan Chunwei 已提交
119
  // store the arguments' offset described in op_desc.
Y
Yu Yang 已提交
120
  std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
Y
Yan Chunwei 已提交
121 122
};

123
class OperatorContext {
Y
Yan Chunwei 已提交
124
 public:
125
  OperatorContext(const OperatorBase* op, const Scope& scope)
126 127 128
      : op_(*op), scope_(scope) {}

  size_t InputSize() const { return op_.inputs_.size(); }
Y
Yan Chunwei 已提交
129

130 131
  size_t OutputSize() const { return op_.outputs_.size(); }

132
  const Variable* InputVar(const size_t index) const {
133
    return scope_.FindVar(op_.inputs_.at(index));
Y
Yan Chunwei 已提交
134 135
  }

136
  Variable* OutputVar(const size_t index) const {
137
    return scope_.FindVar(op_.outputs_.at(index));
Y
Yan Chunwei 已提交
138 139
  }

140
  const Variable* InputVar(const std::string& name) const {
Y
Yu Yang 已提交
141
    return scope_.FindVar(op_.Input(name));
Y
Yan Chunwei 已提交
142 143
  }

144
  Variable* OutputVar(const std::string& name) const {
Y
Yu Yang 已提交
145
    return scope_.FindVar(op_.Output(name));
Y
Yan Chunwei 已提交
146 147
  }

148 149
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
150 151
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
152
    res.reserve(names.size());
Y
Yan Chunwei 已提交
153
    std::transform(
154
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
155
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
156 157 158
    return res;
  }

159
  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
160 161
    auto names = op_.Outputs(name);
    std::vector<const Variable*> res;
162
    res.reserve(names.size());
Y
Yan Chunwei 已提交
163
    std::transform(
164
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
165
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
166 167 168
    return res;
  }

169
  template <typename T>
170 171
  const T* Input(const size_t index) const {
    auto var = InputVar(index);
Y
Yan Chunwei 已提交
172
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
173
    return &var->Get<T>();
174 175 176
  }

  template <typename T>
177 178
  T* Output(const size_t index) const {
    auto var = OutputVar(index);
Y
Yan Chunwei 已提交
179 180
    PADDLE_ENFORCE_NOT_NULL(
        var,
Y
Yan Chunwei 已提交
181 182 183
        "Output(%d) not be nullptr, which means variable [%s] does not "
        "exist in scope",
        index, op_.outputs_[index]);
184
    return var->GetMutable<T>();
185 186 187 188
  }

  template <typename T>
  const T* Input(const std::string& name) const {
189
    auto var = InputVar(name);
Y
Yan Chunwei 已提交
190
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
191
    return &var->Get<T>();
192 193 194 195
  }

  template <typename T>
  T* Output(const std::string& name) const {
196
    auto var = OutputVar(name);
Y
Yan Chunwei 已提交
197
    PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
198
    return var->GetMutable<T>();
199 200 201 202 203 204 205 206
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
207
                   [&](const std::string& sub_name) {
208
                     auto var = scope_.FindVar(sub_name);
Y
Yan Chunwei 已提交
209 210 211
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiInput(%s:%s) should not be nullptr", name,
                         sub_name);
212
                     return &var->Get<T>();
213 214 215 216 217 218 219 220 221 222
                   });
    return res;
  }

  template <typename T>
  std::vector<const T*> MultiOutput(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
223
                   [&](const std::string& sub_name) {
224
                     auto var = scope_.FindVar(sub_name);
Y
Yan Chunwei 已提交
225 226 227
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiOutput(%s:%s) should not be nullptr", name,
                         sub_name);
228
                     return var->GetMutable<T>();
229 230 231 232 233
                   });
    return res;
  }

  const OperatorBase& op_;
234
  const Scope& scope_;
235 236 237 238
};

class InferShapeContext : public OperatorContext {
 public:
239
  InferShapeContext(const OperatorBase* op, const Scope& scope)
240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259
      : OperatorContext(op, scope) {}
};

/// Compile-time map from a platform place type to its Eigen device type.
/// The primary template is only declared, so instantiating it for a place
/// type without a specialization below fails to compile.
template <typename PlaceType>
struct EigenDeviceConverter;

/// CPUPlace maps to Eigen::DefaultDevice.
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

#ifndef PADDLE_ONLY_CPU
/// GPUPlace maps to Eigen::GpuDevice. Only present when PADDLE_ONLY_CPU
/// is not defined.
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
  using EigenDeviceType = Eigen::GpuDevice;
};
#endif  // PADDLE_ONLY_CPU

class ExecutionContext : public OperatorContext {
 public:
260
  ExecutionContext(const OperatorBase* op, const Scope& scope,
D
dongzhihong 已提交
261
                   const platform::DeviceContext* device_context)
262 263
      : OperatorContext(op, scope), device_context_(device_context) {}

Q
qijun 已提交
264 265 266
  template <typename PlaceType,
            typename DeviceType =
                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
267
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
268

D
dongzhihong 已提交
269
  platform::Place GetPlace() const { return device_context_->GetPlace(); }
Q
qijun 已提交
270

D
dongzhihong 已提交
271
  const platform::DeviceContext* device_context_;
Q
Qiao Longfei 已提交
272 273
};

Q
qijun 已提交
274 275
class OpKernel {
 public:
Q
qijun 已提交
276
  /**
277
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
278 279
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
280
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
281 282
   */

283
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
284 285 286 287

  virtual ~OpKernel() {}
};

Q
Qiao Longfei 已提交
288 289
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
290 291
  struct OpKernelKey {
    platform::Place place_;
Q
Qiao Longfei 已提交
292

Y
Yu Yang 已提交
293
    OpKernelKey() = default;
L
liaogang 已提交
294
    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
Y
Yu Yang 已提交
295 296 297
      place_ = dev_ctx.GetPlace();
    }

Q
qijun 已提交
298 299 300
    bool operator==(const OpKernelKey& o) const {
      return platform::places_are_same_class(place_, o.place_);
    }
Y
Yu Yang 已提交
301 302 303 304 305 306 307 308 309 310 311
  };

  struct OpKernelHash {
    std::hash<bool> hash_;
    size_t operator()(const OpKernelKey& key) const {
      return hash_(platform::is_gpu_place(key.place_));
    }
  };

  using OpKernelMap =
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
Q
Qiao Longfei 已提交
312

313
  void InferShape(const Scope& scope) const override {
314 315 316
    InferShape(InferShapeContext(this, scope));
  }

Y
Yu Yang 已提交
317
  void Run(const Scope& scope,
Y
Yu Yang 已提交
318
           const platform::DeviceContext& dev_ctx) const final {
Q
Qiao Longfei 已提交
319
    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
D
dongzhihong 已提交
320
    opKernel->Compute(ExecutionContext(this, scope, &dev_ctx));
Q
Qiao Longfei 已提交
321 322
  }

Y
Yu Yang 已提交
323 324 325 326
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
327
  }
Y
Yan Chunwei 已提交
328

329 330 331 332 333 334
  bool SupportGPU() const override {
    OperatorWithKernel::OpKernelKey key;
    key.place_ = platform::GPUPlace();
    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
  }

Y
Yu Yang 已提交
335
 protected:
336
  virtual void InferShape(const InferShapeContext& ctx) const = 0;
Q
Qiao Longfei 已提交
337 338 339 340
};

}  // namespace framework
}  // namespace paddle