operator.h 6.9 KB
Newer Older
Z
zhaojiaying01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14 15 16

#pragma once

L
liuruilong 已提交
17
#include <map>
Z
zhaojiaying01 已提交
18
#include <string>
H
hjchen2 已提交
19
#include <utility>
Z
zhaojiaying01 已提交
20
#include <vector>
L
liuruilong 已提交
21

L
liuruilong 已提交
22 23
#include "common/enforce.h"
#include "common/type_define.h"
L
liuruilong 已提交
24 25
#include "common/types.h"
#include "common/variant.h"
L
liuruilong 已提交
26
#include "framework/attribute.h"
L
liuruilong 已提交
27
#include "framework/op_info.h"
L
liuruilong 已提交
28
#include "framework/op_kernel_type.h"
L
liuruilong 已提交
29 30
#include "framework/op_registry.h"
#include "framework/program/block_desc.h"
L
liuruilong 已提交
31
#include "framework/program/program-optimize/node.h"
L
liuruilong 已提交
32 33 34
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
L
liuruilong 已提交
35
#ifdef PADDLE_MOBILE_CL
L
liuruilong 已提交
36
#include "framework/cl/cl_helper.h"
37
#include "framework/cl/cl_scope.h"
L
liuruilong 已提交
38
#endif
39

朔-望's avatar
朔-望 已提交
40
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
41
namespace framework {
朔-望's avatar
朔-望 已提交
42

W
wangliu 已提交
43
template <typename T>
44
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
W
wangliu 已提交
45 46 47 48 49 50 51 52 53 54
                      const Scope &scope) {
  auto var_vec = var_map.at(key);
  if (!var_vec.empty()) {
    auto var = scope.FindVar(var_vec[0]);
    return var->GetMutable<T>();
  } else {
    return nullptr;
  }
}

朔-望's avatar
朔-望 已提交
55
template <typename Dtype>
L
liuruilong 已提交
56
class OperatorBase {
朔-望's avatar
朔-望 已提交
57
 public:
58 59 60 61
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
朔-望's avatar
朔-望 已提交
62

E
eclipsess 已提交
63
  virtual void Init() = 0;
64 65 66 67 68 69 70
  virtual void InferShape() const = 0;
  virtual void Run();
  virtual void RunImpl() = 0;

  std::vector<std::string> GetOutKeys() const;
  std::vector<std::string> GetInputKeys() const;

71 72 73 74
  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }
75

76 77 78
  void ClearVariables(const std::vector<std::string> &var_names) const {
    if (this->scope_) {
      this->scope_->EraseVars(var_names);
朔-望's avatar
朔-望 已提交
79
    }
80
  }
81 82 83
#ifdef PADDLE_MOBILE_FPGA
  void ChangeNameMap(string key, std::vector<string> value);
#endif
朔-望's avatar
朔-望 已提交
84
 protected:
85 86 87 88 89
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
L
liuruilong 已提交
90

朔-望's avatar
朔-望 已提交
91
 private:
92
  void CheckAllInputOutputSet() const;
朔-望's avatar
朔-望 已提交
93
};
朔-望's avatar
朔-望 已提交
94

L
liuruilong 已提交
95
template <typename Dtype, typename ParamType, typename KernelType>
朔-望's avatar
朔-望 已提交
96
class OperatorWithKernel : public OperatorBase<Dtype> {
朔-望's avatar
朔-望 已提交
97
 public:
98 99 100
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
L
liuruilong 已提交
101
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
L
liuruilong 已提交
102 103 104 105 106
        param_(inputs, outputs, attrs, *scope) {
#ifdef PADDLE_MOBILE_CL
    kernel_.InitCLHelper(scope->GetCLScpoe());
#endif
  }
L
liuruilong 已提交
107

L
liuruilong 已提交
108
  virtual void RunImpl() { this->kernel_.Compute(this->param_); }
109

W
wangliu 已提交
110
  virtual void InferShape() const = 0;
L
liuruilong 已提交
111

E
eclipsess 已提交
112 113
  void Init() {
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
L
liuruilong 已提交
114 115 116
                          this->type_.c_str());
  }

L
liuruilong 已提交
117 118 119
 protected:
  KernelType kernel_;
  ParamType param_;
朔-望's avatar
朔-望 已提交
120
};
朔-望's avatar
朔-望 已提交
121

朔-望's avatar
朔-望 已提交
122
/// Base class for operator kernels on device type `Dtype` with parameter
/// type `P`.
///
/// Derived kernels must implement Compute(); Init() defaults to success.
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  OpKernelBase() = default;

#ifdef PADDLE_MOBILE_CL
  /// Rebinds the kernel's CL helper to `clScope`.
  virtual void InitCLHelper(CLScope *clScope) {
    cl_helper_ = CLHelper(clScope);
  }
#endif

#ifdef PADDLE_MOBILE_MALI_GPU
  // This guard used to read PADDLE_McOBILE_MALI_GPU (typo), so GetAclOp() and
  // SetAclOp() were never compiled even though acl_op_ was. The old
  // `OpKernelBase() { acl_op_ = nullptr; }` would redefine the defaulted
  // constructor above, so acl_op_ now uses a default member initializer.
  void *GetAclOp() const { return acl_op_; }
  /// Stores `op` as the ACL operator of the kernel that `ob` points to;
  /// `ob` must point to an OpKernelBase<Dtype, P>.
  void SetAclOp(void *op, void *ob) const {
    static_cast<OpKernelBase<Dtype, P> *>(ob)->acl_op_ = op;
  }
#endif

  virtual void Compute(const P &para) = 0;
  /// One-time setup hook; the default ignores the parameters and succeeds.
  virtual bool Init(P * /*para*/) { return true; }
  virtual ~OpKernelBase() = default;

 protected:
#ifdef PADDLE_MOBILE_CL
  CLHelper cl_helper_;
#endif

 private:
#ifdef PADDLE_MOBILE_MALI_GPU
  void *acl_op_ = nullptr;
#endif
};
朔-望's avatar
朔-望 已提交
154

L
liuruilong 已提交
155
class FusionOpMatcher {
L
liuruilong 已提交
156 157 158 159 160
 public:
  FusionOpMatcher() {}

  virtual std::string Type() = 0;

L
liuruilong 已提交
161 162 163
  virtual void FolderNodes(
      Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
L
liuruilong 已提交
164
    node->Folder(node_.Depth(), Type(), {}, removed_nodes);
L
liuruilong 已提交
165 166
  }

L
liuruilong 已提交
167
  virtual Node &BeginNode() { return node_; }
L
liuruilong 已提交
168

L
liuruilong 已提交
169
  std::string BeginType() { return node_.Type(); }
L
liuruilong 已提交
170

171 172
  virtual std::vector<std::pair<int, std::string>> NeedCheck() { return {}; }

L
liuruilong 已提交
173 174 175 176 177 178
 protected:
  Node node_;
  std::string type_;
  std::shared_ptr<OpDesc> new_opdesc_;
};

179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211
// Declares class template `OpName##Op<DeviceType, T>`, an OperatorWithKernel
// whose parameter type is `OpParam<DeviceType>` and whose kernel is
// `operators::OpKernel<DeviceType, T>`. Only InferShape() is declared; its
// definition is not provided by this macro.
//
// Framework types are fully qualified (as in DEFINE_OP_CONSTRUCTOR) so the
// expansion does not depend on being inside namespace paddle_mobile.
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                        \
  template <typename DeviceType, typename T>                               \
  class OpName##Op                                                         \
      : public ::paddle_mobile::framework::OperatorWithKernel<             \
            DeviceType, OpParam<DeviceType>,                               \
            operators::OpKernel<DeviceType, T>> {                          \
   public:                                                                 \
    OpName##Op(const std::string &type,                                    \
               const ::paddle_mobile::VariableNameMap &inputs,             \
               const ::paddle_mobile::VariableNameMap &outputs,            \
               const ::paddle_mobile::framework::AttributeMap &attrs,      \
               std::shared_ptr<::paddle_mobile::framework::Scope> scope)   \
        : ::paddle_mobile::framework::OperatorWithKernel<                  \
              DeviceType, OpParam<DeviceType>,                             \
              operators::OpKernel<DeviceType, T>>(type, inputs, outputs,   \
                                                  attrs, scope) {}         \
                                                                           \
    void InferShape() const override;                                      \
  };

// Declares class template `OpName##Kernel<DeviceType, T>` deriving from
// OpKernelBase<DeviceType, OpParam<DeviceType>>. Init() and Compute() are only
// declared here. `override` makes the compiler reject a declaration whose
// signature drifts from the base's virtual hooks.
#define DECLARE_KERNEL(OpName, OpParam)                                   \
  template <typename DeviceType, typename T>                              \
  class OpName##Kernel                                                    \
      : public ::paddle_mobile::framework::OpKernelBase<                  \
            DeviceType, OpParam<DeviceType>> {                            \
   public:                                                                \
    bool Init(OpParam<DeviceType> *param) override;                       \
    void Compute(const OpParam<DeviceType> &param) override;              \
  };

// Expands to a constructor for class `cls` that forwards all five arguments to
// base `parent_cls<Dtype, T>`.
// NOTE: `Dtype` and `T` are NOT macro parameters. They are looked up at the
// expansion site, so the enclosing class (template) must have names `Dtype`
// and `T` in scope, e.g. as its template parameters.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)                                 \
  cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
      const ::paddle_mobile::VariableNameMap &outputs,                         \
      const ::paddle_mobile::framework::AttributeMap &attrs,                   \
      std::shared_ptr<::paddle_mobile::framework::Scope> scope)                \
      : parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {}

朔-望's avatar
朔-望 已提交
212 213
}  // namespace framework
}  // namespace paddle_mobile