operator.h 6.9 KB
Newer Older
Z
zhaojiaying01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14 15 16

#pragma once

L
liuruilong 已提交
17
#include <map>
Z
zhaojiaying01 已提交
18
#include <string>
H
hjchen2 已提交
19
#include <utility>
Z
zhaojiaying01 已提交
20
#include <vector>
L
liuruilong 已提交
21

L
liuruilong 已提交
22 23
#include "common/enforce.h"
#include "common/type_define.h"
L
liuruilong 已提交
24 25
#include "common/types.h"
#include "common/variant.h"
L
liuruilong 已提交
26
#include "framework/attribute.h"
L
liuruilong 已提交
27
#include "framework/op_info.h"
L
liuruilong 已提交
28
#include "framework/op_kernel_type.h"
L
liuruilong 已提交
29 30
#include "framework/op_registry.h"
#include "framework/program/block_desc.h"
L
liuruilong 已提交
31
#include "framework/program/program-optimize/node.h"
L
liuruilong 已提交
32 33 34
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
L
liuruilong 已提交
35
#ifdef PADDLE_MOBILE_CL
L
liuruilong 已提交
36
#include "framework/cl/cl_helper.h"
37
#include "framework/cl/cl_scope.h"
L
liuruilong 已提交
38
#endif
39

朔-望's avatar
朔-望 已提交
40
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
41
namespace framework {
朔-望's avatar
朔-望 已提交
42

W
wangliu 已提交
43
template <typename T>
44
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
W
wangliu 已提交
45 46 47 48 49 50 51 52 53 54
                      const Scope &scope) {
  auto var_vec = var_map.at(key);
  if (!var_vec.empty()) {
    auto var = scope.FindVar(var_vec[0]);
    return var->GetMutable<T>();
  } else {
    return nullptr;
  }
}

朔-望's avatar
朔-望 已提交
55
template <typename Dtype>
L
liuruilong 已提交
56
class OperatorBase {
朔-望's avatar
朔-望 已提交
57
 public:
58 59 60 61
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
朔-望's avatar
朔-望 已提交
62

E
eclipsess 已提交
63
  virtual void Init() = 0;
64 65 66 67 68 69 70
  virtual void InferShape() const = 0;
  virtual void Run();
  virtual void RunImpl() = 0;

  std::vector<std::string> GetOutKeys() const;
  std::vector<std::string> GetInputKeys() const;

71 72 73 74
  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }
75

76 77 78
  void ClearVariables(const std::vector<std::string> &var_names) const {
    if (this->scope_) {
      this->scope_->EraseVars(var_names);
朔-望's avatar
朔-望 已提交
79
    }
80
  }
81
#ifdef PADDLE_MOBILE_FPGA
82
  void InsertTensors();
83
#endif
朔-望's avatar
朔-望 已提交
84
 protected:
85 86 87 88 89
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
L
liuruilong 已提交
90

朔-望's avatar
朔-望 已提交
91
 private:
92
  void CheckAllInputOutputSet() const;
朔-望's avatar
朔-望 已提交
93
};
朔-望's avatar
朔-望 已提交
94

L
liuruilong 已提交
95
template <typename Dtype, typename ParamType, typename KernelType>
朔-望's avatar
朔-望 已提交
96
class OperatorWithKernel : public OperatorBase<Dtype> {
朔-望's avatar
朔-望 已提交
97
 public:
98 99 100
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
L
liuruilong 已提交
101
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
L
liuruilong 已提交
102 103 104 105 106
        param_(inputs, outputs, attrs, *scope) {
#ifdef PADDLE_MOBILE_CL
    kernel_.InitCLHelper(scope->GetCLScpoe());
#endif
  }
L
liuruilong 已提交
107
  virtual void RunImpl() { this->kernel_.Compute(this->param_); }
108

W
wangliu 已提交
109
  virtual void InferShape() const = 0;
L
liuruilong 已提交
110

E
eclipsess 已提交
111 112
  void Init() {
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
L
liuruilong 已提交
113 114 115
                          this->type_.c_str());
  }

L
liuruilong 已提交
116 117 118
 protected:
  KernelType kernel_;
  ParamType param_;
朔-望's avatar
朔-望 已提交
119
};
朔-望's avatar
朔-望 已提交
120

朔-望's avatar
朔-望 已提交
121
// Base class of a compute kernel.
//
//   Dtype  device tag.
//   P      parameter type handed to Init() and Compute().
//
// Derived kernels must implement Compute(); Init() defaults to reporting
// success. Extra members are compiled in for OpenCL (PADDLE_MOBILE_CL) and
// Mali GPU (PADDLE_MOBILE_MALI_GPU) builds.
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  OpKernelBase() = default;

#ifdef PADDLE_MOBILE_CL
  // Rebuilds the CL helper from the given CL scope.
  virtual void InitCLHelper(CLScope *clScope) {
    cl_helper_ = CLHelper(clScope);
  }
#endif

#ifdef PADDLE_MOBILE_MALI_GPU
  // Returns the opaque handle stored by SetAclOp(), or nullptr if none was
  // stored yet.
  void *GetAclOp() const { return acl_op_; }
  // Stores `op` into the kernel that `ob` points to. `ob` must point to an
  // OpKernelBase<Dtype, P>; it is passed as void * so that the store can be
  // issued from a const context.
  void SetAclOp(void *op, void *ob) const {
    static_cast<OpKernelBase<Dtype, P> *>(ob)->acl_op_ = op;
  }
#endif
  // Executes the kernel on `para`.
  virtual void Compute(const P &para) = 0;
  // One-time set-up hook; returns true on success. The default does nothing.
  virtual bool Init(P *para) { return true; }
  virtual ~OpKernelBase() = default;

 protected:
#ifdef PADDLE_MOBILE_CL
  CLHelper cl_helper_;
#endif

 private:
#ifdef PADDLE_MOBILE_MALI_GPU
  // Initialised here so the defaulted constructor leaves the handle null.
  void *acl_op_ = nullptr;
#endif
};
朔-望's avatar
朔-望 已提交
153

L
liuruilong 已提交
154
class FusionOpMatcher {
L
liuruilong 已提交
155 156 157 158 159
 public:
  FusionOpMatcher() {}

  virtual std::string Type() = 0;

L
liuruilong 已提交
160 161 162
  virtual void FolderNodes(
      Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
L
liuruilong 已提交
163
    node->Folder(node_.Depth(), Type(), {}, removed_nodes);
L
liuruilong 已提交
164 165
  }

L
liuruilong 已提交
166
  virtual Node &BeginNode() { return node_; }
L
liuruilong 已提交
167

L
liuruilong 已提交
168
  std::string BeginType() { return node_.Type(); }
L
liuruilong 已提交
169

170 171
  virtual std::vector<std::pair<int, std::string>> NeedCheck() { return {}; }

L
liuruilong 已提交
172 173 174 175 176 177
 protected:
  Node node_;
  std::string type_;
  std::shared_ptr<OpDesc> new_opdesc_;
};

178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210
// Declares the class template `<OpName>Op<DeviceType, T>`, derived from
// framework::OperatorWithKernel<DeviceType, OpParam<DeviceType>,
// operators::OpKernel<DeviceType, T>>.
//
//   OpName    prefix of the generated class name (`OpName##Op`).
//   OpParam   parameter class template, instantiated with DeviceType.
//   OpKernel  kernel class template looked up as `operators::OpKernel`.
//
// The generated class only forwards its constructor arguments to the base and
// declares InferShape(); the definition of InferShape() must be provided
// elsewhere. `framework::`, `operators::` and `VariableNameMap` are written
// unqualified, so the macro must be expanded where those names resolve.
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                          \
  template <typename DeviceType, typename T>                                 \
  class OpName##Op : public framework::OperatorWithKernel<                   \
                         DeviceType, OpParam<DeviceType>,                    \
                         operators::OpKernel<DeviceType, T>> {               \
   public:                                                                   \
    OpName##Op(const std::string &type, const VariableNameMap &inputs,       \
               const VariableNameMap &outputs,                               \
               const framework::AttributeMap &attrs,                         \
               std::shared_ptr<framework::Scope> scope)                      \
        : framework::OperatorWithKernel<DeviceType, OpParam<DeviceType>,     \
                                        operators::OpKernel<DeviceType, T>>( \
              type, inputs, outputs, attrs, scope) {}                        \
                                                                             \
    void InferShape() const override;                                        \
  };

// Declares the class template `<OpName>Kernel<DeviceType, T>`, derived from
// framework::OpKernelBase<DeviceType, OpParam<DeviceType>>.
//
//   OpName   prefix of the generated class name (`OpName##Kernel`).
//   OpParam  parameter class template, instantiated with DeviceType.
//
// Init() and Compute() are declared only; their definitions must be provided
// elsewhere. Both are marked `override` so that a signature drifting away
// from the OpKernelBase virtuals is rejected at compile time.
#define DECLARE_KERNEL(OpName, OpParam)                                   \
  template <typename DeviceType, typename T>                              \
  class OpName##Kernel                                                    \
      : public framework::OpKernelBase<DeviceType, OpParam<DeviceType>> { \
   public:                                                                \
    bool Init(OpParam<DeviceType> *param) override;                       \
    void Compute(const OpParam<DeviceType> &param) override;              \
  };

// Expands to a constructor for `cls` that forwards (type, inputs, outputs,
// attrs, scope) to `parent_cls<Dtype, T>`.
//
//   cls         name of the class the constructor belongs to.
//   parent_cls  base class template, instantiated as parent_cls<Dtype, T>.
//
// The expansion site must therefore have names `Dtype` and `T` in scope.
// All paddle_mobile types are fully qualified, so the macro does not depend
// on the namespace it is expanded in.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)                                 \
  cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
      const ::paddle_mobile::VariableNameMap &outputs,                         \
      const ::paddle_mobile::framework::AttributeMap &attrs,                   \
      std::shared_ptr<::paddle_mobile::framework::Scope> scope)                \
      : parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {}

朔-望's avatar
朔-望 已提交
211 212
}  // namespace framework
}  // namespace paddle_mobile