operator.h 7.8 KB
Newer Older
Z
zhaojiaying01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14 15 16

#pragma once

#include <map>
#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "common/enforce.h"
#include "common/type_define.h"
#include "common/types.h"
#include "common/variant.h"
#include "framework/attribute.h"
#include "framework/op_info.h"
#include "framework/op_kernel_type.h"
#include "framework/op_registry.h"
#include "framework/program/block_desc.h"
#include "framework/program/program-optimize/node.h"
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
#ifdef PADDLE_MOBILE_CL
#include "framework/cl/cl_helper.h"
#include "framework/cl/cl_scope.h"
#endif
39

朔-望's avatar
朔-望 已提交
40
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
41
namespace framework {
朔-望's avatar
朔-望 已提交
42

W
wangliu 已提交
43
template <typename T>
44
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
W
wangliu 已提交
45 46 47 48 49 50 51 52 53 54
                      const Scope &scope) {
  auto var_vec = var_map.at(key);
  if (!var_vec.empty()) {
    auto var = scope.FindVar(var_vec[0]);
    return var->GetMutable<T>();
  } else {
    return nullptr;
  }
}

朔-望's avatar
朔-望 已提交
55
template <typename Dtype>
L
liuruilong 已提交
56
class OperatorBase {
朔-望's avatar
朔-望 已提交
57
 public:
58 59 60 61
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
朔-望's avatar
朔-望 已提交
62

E
eclipsess 已提交
63
  virtual void Init() = 0;
64 65 66 67 68 69 70
  virtual void InferShape() const = 0;
  virtual void Run();
  virtual void RunImpl() = 0;

  std::vector<std::string> GetOutKeys() const;
  std::vector<std::string> GetInputKeys() const;

71 72 73 74
  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }
75

76 77 78
  void ClearVariables(const std::vector<std::string> &var_names) const {
    if (this->scope_) {
      this->scope_->EraseVars(var_names);
朔-望's avatar
朔-望 已提交
79
    }
80
  }
81
#ifdef PADDLE_MOBILE_FPGA
82
  void InsertTensors();
83 84
  void ChangeNameMap(string key, std::vector<string> value);
#endif
朔-望's avatar
朔-望 已提交
85
 protected:
86 87 88 89 90
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
L
liuruilong 已提交
91

朔-望's avatar
朔-望 已提交
92
 private:
93
  void CheckAllInputOutputSet() const;
朔-望's avatar
朔-望 已提交
94
};
朔-望's avatar
朔-望 已提交
95

L
liuruilong 已提交
96
template <typename Dtype, typename ParamType, typename KernelType>
朔-望's avatar
朔-望 已提交
97
class OperatorWithKernel : public OperatorBase<Dtype> {
朔-望's avatar
朔-望 已提交
98
 public:
99
#ifndef PADDLE_MOBILE_FPGA1
100 101 102
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
L
liuruilong 已提交
103
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
L
liuruilong 已提交
104 105 106 107 108
        param_(inputs, outputs, attrs, *scope) {
#ifdef PADDLE_MOBILE_CL
    kernel_.InitCLHelper(scope->GetCLScpoe());
#endif
  }
109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127
#else
  OperatorWithKernel(const std::string &type, const VariableNameMap inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {
    static int feed_num = 0;
    static int fetch_num = 0;
    if (type == "feed") {
      auto new_name = string("feed") + std::to_string(feed_num++);
      auto var = scope->Var(new_name);
      (const_cast<VariableNameMap &>(inputs)).at("X") = {string(new_name)};
    } else if (type == "fetch") {
      auto new_name = string("fetch") + std::to_string(fetch_num++);
      auto var = scope->Var(new_name);
      (const_cast<VariableNameMap &>(outputs)).at("Out") = {string(new_name)};
    }
    param_ = ParamType(inputs, outputs, attrs, *scope);
  }
#endif
L
liuruilong 已提交
128
  virtual void RunImpl() { this->kernel_.Compute(this->param_); }
129

W
wangliu 已提交
130
  virtual void InferShape() const = 0;
L
liuruilong 已提交
131

E
eclipsess 已提交
132 133
  void Init() {
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
L
liuruilong 已提交
134 135 136
                          this->type_.c_str());
  }

L
liuruilong 已提交
137 138 139
 protected:
  KernelType kernel_;
  ParamType param_;
朔-望's avatar
朔-望 已提交
140
};
朔-望's avatar
朔-望 已提交
141

/// Base class for kernels operating on a parameter bundle `P`.
///
/// Subclasses must implement Compute(). Init() may be overridden to prepare
/// state from the parameters; the default accepts any parameters and
/// returns true.
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  OpKernelBase() = default;

#ifdef PADDLE_MOBILE_CL
  /// Stores a CLHelper built from `clScope` for use by OpenCL kernels.
  virtual void InitCLHelper(CLScope *clScope) {
    cl_helper_ = CLHelper(clScope);
  }
#endif

#ifdef PADDLE_MOBILE_MALI_GPU
  /// Returns the stored ACL operator handle (nullptr until one is set).
  void *GetAclOp() const { return acl_op_; }
  /// Stores `op` into the kernel that `ob` points to (not into *this), which
  /// is why this can be a const member. `ob` must point to an
  /// OpKernelBase<Dtype, P>.
  void SetAclOp(void *op, void *ob) const {
    static_cast<OpKernelBase<Dtype, P> *>(ob)->acl_op_ = op;
  }
#endif

  virtual void Compute(const P &para) = 0;
  virtual bool Init(P *para) { return true; }
  virtual ~OpKernelBase() = default;

 protected:
#ifdef PADDLE_MOBILE_CL
  CLHelper cl_helper_;
#endif

 private:
#ifdef PADDLE_MOBILE_MALI_GPU
  // Initialized in-class so the defaulted constructor above covers it; a
  // second user-written default constructor would be a redeclaration.
  void *acl_op_ = nullptr;
#endif
};

L
liuruilong 已提交
175
class FusionOpMatcher {
L
liuruilong 已提交
176 177 178 179 180
 public:
  FusionOpMatcher() {}

  virtual std::string Type() = 0;

L
liuruilong 已提交
181 182 183
  virtual void FolderNodes(
      Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
L
liuruilong 已提交
184
    node->Folder(node_.Depth(), Type(), {}, removed_nodes);
L
liuruilong 已提交
185 186
  }

L
liuruilong 已提交
187
  virtual Node &BeginNode() { return node_; }
L
liuruilong 已提交
188

L
liuruilong 已提交
189
  std::string BeginType() { return node_.Type(); }
L
liuruilong 已提交
190

191 192
  virtual std::vector<std::pair<int, std::string>> NeedCheck() { return {}; }

L
liuruilong 已提交
193 194 195 196 197 198
 protected:
  Node node_;
  std::string type_;
  std::shared_ptr<OpDesc> new_opdesc_;
};

// DECLARE_OPERATOR(OpName, OpParam, OpKernel) declares the class template
// `OpName##Op<DeviceType, T>` deriving from
//   OperatorWithKernel<DeviceType, OpParam<DeviceType>,
//                      operators::OpKernel<DeviceType, T>>
// with a constructor forwarding all arguments to that base, plus a
// declaration of InferShape() whose definition must be supplied elsewhere.
// Framework names are spelled with full `::paddle_mobile::` qualification
// (as in DEFINE_OP_CONSTRUCTOR) so the expansion does not depend on the
// namespace it is used in; `OpParam` is still looked up at the use site.
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                         \
  template <typename DeviceType, typename T>                                \
  class OpName##Op : public ::paddle_mobile::framework::OperatorWithKernel< \
                         DeviceType, OpParam<DeviceType>,                   \
                         ::paddle_mobile::operators::OpKernel<DeviceType,   \
                                                              T>> {         \
   public:                                                                  \
    OpName##Op(const std::string &type,                                     \
               const ::paddle_mobile::VariableNameMap &inputs,              \
               const ::paddle_mobile::VariableNameMap &outputs,             \
               const ::paddle_mobile::framework::AttributeMap &attrs,       \
               std::shared_ptr<::paddle_mobile::framework::Scope> scope)    \
        : ::paddle_mobile::framework::OperatorWithKernel<                   \
              DeviceType, OpParam<DeviceType>,                              \
              ::paddle_mobile::operators::OpKernel<DeviceType, T>>(         \
              type, inputs, outputs, attrs, scope) {}                       \
                                                                            \
    void InferShape() const override;                                       \
  };

// DECLARE_KERNEL(OpName, OpParam) declares the class template
// `OpName##Kernel<DeviceType, T>` deriving from
// OpKernelBase<DeviceType, OpParam<DeviceType>> and declares its Init() and
// Compute() hooks (definitions must be supplied elsewhere). `override` makes
// the compiler check both signatures against OpKernelBase.
#define DECLARE_KERNEL(OpName, OpParam)                                   \
  template <typename DeviceType, typename T>                              \
  class OpName##Kernel : public ::paddle_mobile::framework::OpKernelBase< \
                             DeviceType, OpParam<DeviceType>> {           \
   public:                                                                \
    bool Init(OpParam<DeviceType> *param) override;                       \
    void Compute(const OpParam<DeviceType> &param) override;              \
  };

// DEFINE_OP_CONSTRUCTOR(ClassName, ParentTemplate) expands to a constructor
// for `ClassName` that forwards (type, inputs, outputs, attrs, scope)
// unchanged to `ParentTemplate<Dtype, T>`.
// NOTE: the expansion names `Dtype` and `T` literally, so both must be in
// scope (e.g. as the enclosing class template's parameters) where it is used.
#define DEFINE_OP_CONSTRUCTOR(ClassName, ParentTemplate)              \
  ClassName(const std::string &type,                                  \
            const ::paddle_mobile::VariableNameMap &inputs,           \
            const ::paddle_mobile::VariableNameMap &outputs,          \
            const ::paddle_mobile::framework::AttributeMap &attrs,    \
            std::shared_ptr<::paddle_mobile::framework::Scope> scope) \
      : ParentTemplate<Dtype, T>(type, inputs, outputs, attrs, scope) {}

朔-望's avatar
朔-望 已提交
232 233
}  // namespace framework
}  // namespace paddle_mobile