operator.h 10.4 KB
Newer Older
Q
Qiao Longfei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

D
dongzhihong 已提交
17
#include <algorithm>
Q
Qiao Longfei 已提交
18 19 20 21 22
#include <boost/variant.hpp>
#include <string>
#include <unordered_map>
#include <vector>

Y
Yi Wang 已提交
23
#include "paddle/framework/attribute.h"
Q
qijun 已提交
24
#include "paddle/framework/op_desc.pb.h"
Y
Yan Chunwei 已提交
25
#include "paddle/framework/op_proto.pb.h"
Q
qijun 已提交
26 27 28 29 30
#include "paddle/framework/scope.h"
#include "paddle/framework/tensor.h"
#include "paddle/platform/device_context.h"
#include "paddle/platform/place.h"
#include "paddle/utils/Error.h"
Q
Qiao Longfei 已提交
31 32 33 34

namespace paddle {
namespace framework {

35 36 37 38 39 40 41 42 43 44 45 46 47
/// Name used for a variable slot that is intentionally left empty.
const std::string kEmptyVarName = "@EMPTY@";

/// Placeholder name for a temporary variable. It is set on the Python side
/// and is converted to a unique name in the scope after OpCreator.
const std::string kTempVarName = "@TEMP@";

/// A variable whose name ends with this suffix is the gradient of another
/// variable, e.g. variable "x@GRAD" is the gradient of variable "x".
const std::string kGradVarSuffix = "@GRAD";

/// Variables with this suffix are supposed to be filled up with zeros.
const std::string kZeroVarSuffix = "@ZERO";

/// Returns the name of the gradient variable of `var_name`, i.e.
/// `var_name` followed by kGradVarSuffix.
inline std::string GradVarName(const std::string& var_name) {
  std::string grad_name(var_name);
  grad_name.append(kGradVarSuffix);
  return grad_name;
}

// Forward declarations of classes defined later in this header.
class OperatorBase;
class InferShapeContext;
class ExecutionContext;

Q
Qiao Longfei 已提交
58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74
/**
 * OperatorBase has the basic element that Net will call to do computation.
 * Only CreateOperator from OpRegistry will new Operator directly. User
 * should always construct a proto message OpDesc and call
 * OpRegistry::CreateOp(op_desc) to get an Operator instance.
 */
class OperatorBase {
 public:
  virtual ~OperatorBase() {}

  template <typename T>
  inline const T& GetAttr(const std::string& name) const {
    PADDLE_ENFORCE(attrs_.count(name) != 0, "%s should be in AttributeMap",
                   name);
    return boost::get<T>(attrs_.at(name));
  }

75
  virtual std::string DebugString() const;
Q
Qiao Longfei 已提交
76

Q
Qiao Longfei 已提交
77 78 79 80
  /// Init will be called after CreateOperator, you can put some initialization
  /// logic here.
  virtual void Init() {}

Q
Qiao Longfei 已提交
81 82
  /// InferShape infer the size of Variables used by this Operator with
  /// information inside scope
Y
Yu Yang 已提交
83
  virtual void InferShape(const Scope& scope) const = 0;
Q
Qiao Longfei 已提交
84 85

  /// Net will call this function to Run an op.
Y
Yu Yang 已提交
86
  virtual void Run(const Scope& scope,
Y
Yu Yang 已提交
87 88
                   const platform::DeviceContext& dev_ctx) const = 0;

Y
Yu Yang 已提交
89 90
  virtual bool IsNetOp() const { return false; }

91 92
  virtual bool SupportGPU() const { return false; }

D
dongzhihong 已提交
93 94 95
  /// rename inputs outputs name
  void Rename(const std::string& old_name, const std::string& new_name);

Y
Yu Yang 已提交
96
  //! Get a input with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
97
  const std::string& Input(const std::string& name) const;
Y
Yu Yang 已提交
98

Y
Yu Yang 已提交
99 100
  //! Get a input which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
101
  std::vector<std::string> Inputs(const std::string& name) const;
Y
Yu Yang 已提交
102
  //! Get a output with argument's name described in `op_proto`
Y
Yan Chunwei 已提交
103
  const std::string& Output(const std::string& name) const;
Y
Yu Yang 已提交
104 105
  //! Get an output which has multiple variables.
  //! TODO add a vector_view to prevent memory copy.
Y
Yan Chunwei 已提交
106 107
  std::vector<std::string> Outputs(const std::string& name) const;

Q
Qiao Longfei 已提交
108
 public:
Q
Qiao Longfei 已提交
109
  std::string type_;
D
dongzhihong 已提交
110 111 112 113
  // NOTE: in case of OpGrad, inputs_ contains:
  // I (Inputs)
  // O (Outputs)
  // OG (Output Gradients)
Q
Qiao Longfei 已提交
114
  std::vector<std::string> inputs_;
D
dongzhihong 已提交
115 116
  // NOTE: in case of OpGrad, outputs_ contains
  // IG (Inputs Gradients)
Q
Qiao Longfei 已提交
117 118
  std::vector<std::string> outputs_;
  AttributeMap attrs_;
Y
Yan Chunwei 已提交
119
  // store the arguments' offset described in op_desc.
Y
Yu Yang 已提交
120
  std::shared_ptr<std::unordered_map<std::string, int>> in_out_idxs_;
Y
Yan Chunwei 已提交
121 122
};

123
class InferShapeContext {
Y
Yan Chunwei 已提交
124
 public:
125 126
  InferShapeContext(const OperatorBase& op, const Scope& scope)
      : op_(op), scope_(scope) {}
127 128

  size_t InputSize() const { return op_.inputs_.size(); }
Y
Yan Chunwei 已提交
129

130 131
  size_t OutputSize() const { return op_.outputs_.size(); }

132
  const Variable* InputVar(const size_t index) const {
133
    return scope_.FindVar(op_.inputs_.at(index));
Y
Yan Chunwei 已提交
134 135
  }

136
  Variable* OutputVar(const size_t index) const {
137
    return scope_.FindVar(op_.outputs_.at(index));
Y
Yan Chunwei 已提交
138 139
  }

140
  const Variable* InputVar(const std::string& name) const {
Y
Yu Yang 已提交
141
    return scope_.FindVar(op_.Input(name));
Y
Yan Chunwei 已提交
142 143
  }

144
  Variable* OutputVar(const std::string& name) const {
Y
Yu Yang 已提交
145
    return scope_.FindVar(op_.Output(name));
Y
Yan Chunwei 已提交
146 147
  }

148 149
  const std::vector<const Variable*> MultiInputVar(
      const std::string& name) const {
Y
Yan Chunwei 已提交
150 151
    auto names = op_.Inputs(name);
    std::vector<const Variable*> res;
152
    res.reserve(names.size());
Y
Yan Chunwei 已提交
153
    std::transform(
154
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
155
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
156 157 158
    return res;
  }

159
  std::vector<const Variable*> MultiOutputVar(const std::string& name) const {
Y
Yan Chunwei 已提交
160 161
    auto names = op_.Outputs(name);
    std::vector<const Variable*> res;
162
    res.reserve(names.size());
Y
Yan Chunwei 已提交
163
    std::transform(
164
        names.begin(), names.end(), std::back_inserter(res),
Y
Yu Yang 已提交
165
        [this](const std::string& name) { return scope_.FindVar(name); });
Y
Yan Chunwei 已提交
166 167 168
    return res;
  }

169
  template <typename T>
170 171
  const T* Input(const size_t index) const {
    auto var = InputVar(index);
Y
Yan Chunwei 已提交
172
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%d) should not be nullptr", index);
173
    return &var->Get<T>();
174 175 176
  }

  template <typename T>
177 178
  T* Output(const size_t index) const {
    auto var = OutputVar(index);
Y
Yan Chunwei 已提交
179 180
    PADDLE_ENFORCE_NOT_NULL(
        var,
Y
Yan Chunwei 已提交
181 182 183
        "Output(%d) not be nullptr, which means variable [%s] does not "
        "exist in scope",
        index, op_.outputs_[index]);
184
    return var->GetMutable<T>();
185 186 187 188
  }

  template <typename T>
  const T* Input(const std::string& name) const {
189
    auto var = InputVar(name);
Y
Yan Chunwei 已提交
190
    PADDLE_ENFORCE_NOT_NULL(var, "Input(%s) should not be nullptr", name);
191
    return &var->Get<T>();
192 193 194 195
  }

  template <typename T>
  T* Output(const std::string& name) const {
196
    auto var = OutputVar(name);
Y
Yan Chunwei 已提交
197
    PADDLE_ENFORCE_NOT_NULL(var, "Output(%s) should not be nullptr", name);
198
    return var->GetMutable<T>();
199 200 201 202 203 204 205 206
  }

  template <typename T>
  const std::vector<const T*> MultiInput(const std::string& name) const {
    auto names = op_.Inputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
207
                   [&](const std::string& sub_name) {
208
                     auto var = scope_.FindVar(sub_name);
Y
Yan Chunwei 已提交
209 210 211
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiInput(%s:%s) should not be nullptr", name,
                         sub_name);
212
                     return &var->Get<T>();
213 214 215 216 217 218 219 220 221 222
                   });
    return res;
  }

  template <typename T>
  std::vector<const T*> MultiOutput(const std::string& name) const {
    auto names = op_.Outputs(name);
    std::vector<const T*> res;
    res.reserve(names.size());
    std::transform(names.begin(), names.end(), std::back_inserter(res),
223
                   [&](const std::string& sub_name) {
224
                     auto var = scope_.FindVar(sub_name);
Y
Yan Chunwei 已提交
225 226 227
                     PADDLE_ENFORCE_NOT_NULL(
                         var, "MultiOutput(%s:%s) should not be nullptr", name,
                         sub_name);
228
                     return var->GetMutable<T>();
229 230 231 232 233
                   });
    return res;
  }

  const OperatorBase& op_;
234
  const Scope& scope_;
235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251
};

/// Maps a platform place type to the Eigen device type used to evaluate
/// Eigen expressions on it. The primary template is declared but never
/// defined, so only the specializations below can be used.
template <typename PlaceType>
struct EigenDeviceConverter;

/// CPU places evaluate Eigen expressions on Eigen::DefaultDevice.
template <>
struct EigenDeviceConverter<platform::CPUPlace> {
  using EigenDeviceType = Eigen::DefaultDevice;
};

#if !defined(PADDLE_ONLY_CPU)
/// GPU places (compiled only when PADDLE_ONLY_CPU is not defined) use
/// Eigen::GpuDevice.
template <>
struct EigenDeviceConverter<platform::GPUPlace> {
  using EigenDeviceType = Eigen::GpuDevice;
};
#endif  // !defined(PADDLE_ONLY_CPU)

252
class ExecutionContext : public InferShapeContext {
253
 public:
254
  ExecutionContext(const OperatorBase& op, const Scope& scope,
D
dongzhihong 已提交
255
                   const platform::DeviceContext* device_context)
256
      : InferShapeContext(op, scope), device_context_(device_context) {}
257

Q
qijun 已提交
258 259 260
  template <typename PlaceType,
            typename DeviceType =
                typename EigenDeviceConverter<PlaceType>::EigenDeviceType>
261
  DeviceType& GetEigenDevice() const;
Q
qijun 已提交
262

D
dongzhihong 已提交
263
  platform::Place GetPlace() const { return device_context_->GetPlace(); }
Q
qijun 已提交
264

D
dongzhihong 已提交
265
  const platform::DeviceContext* device_context_;
Q
Qiao Longfei 已提交
266 267
};

Q
qijun 已提交
268 269
class OpKernel {
 public:
Q
qijun 已提交
270
  /**
271
   * ExecutionContext is the only parameter of Kernel Run function.
Q
qijun 已提交
272 273
   * Run will get input/output variables, state such as momentum and
   * device resource such as CUDA stream, cublas handle, etc. from
274
   * ExecutionContext. User should construct it before run the Operator.
Q
qijun 已提交
275 276
   */

277
  virtual void Compute(const ExecutionContext& context) const = 0;
Y
Yu Yang 已提交
278 279 280 281

  virtual ~OpKernel() {}
};

Q
Qiao Longfei 已提交
282 283
class OperatorWithKernel : public OperatorBase {
 public:
Y
Yu Yang 已提交
284 285
  struct OpKernelKey {
    platform::Place place_;
Q
Qiao Longfei 已提交
286

Y
Yu Yang 已提交
287
    OpKernelKey() = default;
L
liaogang 已提交
288
    explicit OpKernelKey(const platform::DeviceContext& dev_ctx) {
Y
Yu Yang 已提交
289 290 291
      place_ = dev_ctx.GetPlace();
    }

Q
qijun 已提交
292 293 294
    bool operator==(const OpKernelKey& o) const {
      return platform::places_are_same_class(place_, o.place_);
    }
Y
Yu Yang 已提交
295 296 297 298 299 300 301 302 303 304 305
  };

  struct OpKernelHash {
    std::hash<bool> hash_;
    size_t operator()(const OpKernelKey& key) const {
      return hash_(platform::is_gpu_place(key.place_));
    }
  };

  using OpKernelMap =
      std::unordered_map<OpKernelKey, std::unique_ptr<OpKernel>, OpKernelHash>;
Q
Qiao Longfei 已提交
306

307
  void InferShape(const Scope& scope) const override {
308
    InferShape(InferShapeContext(*this, scope));
309 310
  }

Y
Yu Yang 已提交
311
  void Run(const Scope& scope,
Y
Yu Yang 已提交
312
           const platform::DeviceContext& dev_ctx) const final {
Q
Qiao Longfei 已提交
313
    auto& opKernel = AllOpKernels().at(type_).at(OpKernelKey(dev_ctx));
314
    opKernel->Compute(ExecutionContext(*this, scope, &dev_ctx));
Q
Qiao Longfei 已提交
315 316
  }

Y
Yu Yang 已提交
317 318 319 320
  static std::unordered_map<std::string /* op_type */, OpKernelMap>&
  AllOpKernels() {
    static std::unordered_map<std::string, OpKernelMap> g_all_op_kernels;
    return g_all_op_kernels;
Y
Yu Yang 已提交
321
  }
Y
Yan Chunwei 已提交
322

323 324 325 326 327 328
  bool SupportGPU() const override {
    OperatorWithKernel::OpKernelKey key;
    key.place_ = platform::GPUPlace();
    return OperatorWithKernel::AllOpKernels().at(type_).count(key) != 0;
  }

Y
Yu Yang 已提交
329
 protected:
330
  virtual void InferShape(const InferShapeContext& ctx) const = 0;
Q
Qiao Longfei 已提交
331 332 333 334
};

}  // namespace framework
}  // namespace paddle