operator.h 7.8 KB
Newer Older
Z
zhaojiaying01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14 15 16

#pragma once

L
liuruilong 已提交
17
#include <map>
J
jameswu2014 已提交
18
#include <memory>
Z
zhaojiaying01 已提交
19
#include <string>
H
hjchen2 已提交
20
#include <utility>
Z
zhaojiaying01 已提交
21
#include <vector>
L
liuruilong 已提交
22

L
liuruilong 已提交
23 24
#include "common/enforce.h"
#include "common/type_define.h"
L
liuruilong 已提交
25 26
#include "common/types.h"
#include "common/variant.h"
L
liuruilong 已提交
27
#include "framework/attribute.h"
L
liuruilong 已提交
28
#include "framework/op_info.h"
L
liuruilong 已提交
29
#include "framework/op_kernel_type.h"
L
liuruilong 已提交
30 31
#include "framework/op_registry.h"
#include "framework/program/block_desc.h"
L
liuruilong 已提交
32
#include "framework/program/program-optimize/node.h"
L
liuruilong 已提交
33 34 35
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
L
liuruilong 已提交
36
#ifdef PADDLE_MOBILE_CL
L
liuruilong 已提交
37
#include "framework/cl/cl_helper.h"
38
#include "framework/cl/cl_scope.h"
L
liuruilong 已提交
39
#endif
40

朔-望's avatar
朔-望 已提交
41
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
42
namespace framework {
朔-望's avatar
朔-望 已提交
43

W
wangliu 已提交
44
template <typename T>
45
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
W
wangliu 已提交
46 47 48 49 50 51 52 53 54 55
                      const Scope &scope) {
  auto var_vec = var_map.at(key);
  if (!var_vec.empty()) {
    auto var = scope.FindVar(var_vec[0]);
    return var->GetMutable<T>();
  } else {
    return nullptr;
  }
}

朔-望's avatar
朔-望 已提交
56
template <typename Dtype>
L
liuruilong 已提交
57
class OperatorBase {
朔-望's avatar
朔-望 已提交
58
 public:
59 60 61 62
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
朔-望's avatar
朔-望 已提交
63

E
eclipsess 已提交
64
  virtual void Init() = 0;
65 66 67 68 69 70 71
  virtual void InferShape() const = 0;
  virtual void Run();
  virtual void RunImpl() = 0;

  std::vector<std::string> GetOutKeys() const;
  std::vector<std::string> GetInputKeys() const;

72 73 74 75
  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }
76

77 78 79
  void ClearVariables(const std::vector<std::string> &var_names) const {
    if (this->scope_) {
      this->scope_->EraseVars(var_names);
朔-望's avatar
朔-望 已提交
80
    }
81
  }
82
#ifdef PADDLE_MOBILE_FPGA
83
  void InsertTensors();
J
jameswu2014 已提交
84
  void ChangeNameMap(string key, std::vector<string> value);
85
#endif
J
jameswu2014 已提交
86

朔-望's avatar
朔-望 已提交
87
 protected:
88 89 90 91 92
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
L
liuruilong 已提交
93

朔-望's avatar
朔-望 已提交
94
 private:
95
  void CheckAllInputOutputSet() const;
朔-望's avatar
朔-望 已提交
96
};
朔-望's avatar
朔-望 已提交
97

L
liuruilong 已提交
98
template <typename Dtype, typename ParamType, typename KernelType>
朔-望's avatar
朔-望 已提交
99
class OperatorWithKernel : public OperatorBase<Dtype> {
朔-望's avatar
朔-望 已提交
100
 public:
J
jameswu2014 已提交
101
#ifndef PADDLE_MOBILE_FPGA1
102 103 104
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
L
liuruilong 已提交
105
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
L
liuruilong 已提交
106 107 108 109 110
        param_(inputs, outputs, attrs, *scope) {
#ifdef PADDLE_MOBILE_CL
    kernel_.InitCLHelper(scope->GetCLScpoe());
#endif
  }
J
jameswu2014 已提交
111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129
#else
  OperatorWithKernel(const std::string &type, const VariableNameMap inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope) {
    static int feed_num = 0;
    static int fetch_num = 0;
    if (type == "feed") {
      auto new_name = string("feed") + std::to_string(feed_num++);
      auto var = scope->Var(new_name);
      (const_cast<VariableNameMap &>(inputs)).at("X") = {string(new_name)};
    } else if (type == "fetch") {
      auto new_name = string("fetch") + std::to_string(fetch_num++);
      auto var = scope->Var(new_name);
      (const_cast<VariableNameMap &>(outputs)).at("Out") = {string(new_name)};
    }
    param_ = ParamType(inputs, outputs, attrs, *scope);
  }
#endif
L
liuruilong 已提交
130
  virtual void RunImpl() { this->kernel_.Compute(this->param_); }
131

W
wangliu 已提交
132
  virtual void InferShape() const = 0;
L
liuruilong 已提交
133

E
eclipsess 已提交
134 135
  void Init() {
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
L
liuruilong 已提交
136 137 138
                          this->type_.c_str());
  }

L
liuruilong 已提交
139 140 141
 protected:
  KernelType kernel_;
  ParamType param_;
朔-望's avatar
朔-望 已提交
142
};
朔-望's avatar
朔-望 已提交
143

朔-望's avatar
朔-望 已提交
144
/**
 * Base class for operator kernels.
 *
 * @tparam Dtype device/type tag the kernel is instantiated for.
 * @tparam P     parameter type consumed by Init() and Compute().
 */
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  OpKernelBase() = default;

#ifdef PADDLE_MOBILE_CL
  // Rebinds this kernel's OpenCL helper to `clScope`.
  virtual void InitCLHelper(CLScope *clScope) {
    cl_helper_ = CLHelper(clScope);
  }
#endif

  // The guard must match the one around `acl_op_` below. A misspelled macro
  // here would leave the accessors out of Mali builds while the member stays
  // in. `acl_op_` is null-initialized in-class, so no dedicated constructor
  // is needed (a second default constructor would clash with the one above).
#ifdef PADDLE_MOBILE_MALI_GPU
  // Returns the opaque ACL operator handle (nullptr until one is stored).
  void *GetAclOp() const { return acl_op_; }
  // Stores `op` as the ACL handle of the kernel object pointed to by `ob`.
  void SetAclOp(void *op, void *ob) const {
    reinterpret_cast<OpKernelBase<Dtype, P> *>(ob)->acl_op_ = op;
  }
#endif

  // Runs the kernel on `para`; every concrete kernel must implement it.
  virtual void Compute(const P &para) = 0;
  // Optional setup step; the default implementation reports success.
  virtual bool Init(P *para) { return true; }
  virtual ~OpKernelBase() = default;

 protected:
#ifdef PADDLE_MOBILE_CL
  CLHelper cl_helper_;
#endif

 private:
#ifdef PADDLE_MOBILE_MALI_GPU
  void *acl_op_ = nullptr;
#endif
};
朔-望's avatar
朔-望 已提交
176

L
liuruilong 已提交
177
class FusionOpMatcher {
L
liuruilong 已提交
178 179 180 181 182
 public:
  FusionOpMatcher() {}

  virtual std::string Type() = 0;

L
liuruilong 已提交
183 184 185
  virtual void FolderNodes(
      Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
L
liuruilong 已提交
186
    node->Folder(node_.Depth(), Type(), {}, removed_nodes);
L
liuruilong 已提交
187 188
  }

L
liuruilong 已提交
189
  virtual Node &BeginNode() { return node_; }
L
liuruilong 已提交
190

L
liuruilong 已提交
191
  std::string BeginType() { return node_.Type(); }
L
liuruilong 已提交
192

193 194
  virtual std::vector<std::pair<int, std::string>> NeedCheck() { return {}; }

L
liuruilong 已提交
195 196 197 198 199 200
 protected:
  Node node_;
  std::string type_;
  std::shared_ptr<OpDesc> new_opdesc_;
};

201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233
// DECLARE_OPERATOR(OpName, OpParam, OpKernel) declares the class template
// `OpName##Op<DeviceType, T>`. The class derives from
// framework::OperatorWithKernel, using `OpParam<DeviceType>` as its parameter
// type and `operators::OpKernel<DeviceType, T>` as its kernel type.
// The generated constructor only forwards its arguments to OperatorWithKernel.
// `InferShape() const` is declared here but not defined; each operator must
// supply its own definition.
// At the expansion site, `VariableNameMap` (used unqualified) and the
// namespaces `framework` and `operators` must resolve.
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                          \
  template <typename DeviceType, typename T>                                 \
  class OpName##Op : public framework::OperatorWithKernel<                   \
                         DeviceType, OpParam<DeviceType>,                    \
                         operators::OpKernel<DeviceType, T>> {               \
   public:                                                                   \
    OpName##Op(const std::string &type, const VariableNameMap &inputs,       \
               const VariableNameMap &outputs,                               \
               const framework::AttributeMap &attrs,                         \
               std::shared_ptr<framework::Scope> scope)                      \
        : framework::OperatorWithKernel<DeviceType, OpParam<DeviceType>,     \
                                        operators::OpKernel<DeviceType, T>>( \
              type, inputs, outputs, attrs, scope) {}                        \
                                                                             \
    void InferShape() const override;                                        \
  };

// DECLARE_KERNEL(OpName, OpParam) declares the class template
// `OpName##Kernel<DeviceType, T>`, derived from
// framework::OpKernelBase<DeviceType, OpParam<DeviceType>>. Init and Compute
// are only declared; their definitions live out of line.
// Both are marked `override` (as DECLARE_OPERATOR does for InferShape), so a
// declaration that drifts from the OpKernelBase virtual signatures fails to
// compile instead of silently hiding them.
#define DECLARE_KERNEL(OpName, OpParam)                                   \
  template <typename DeviceType, typename T>                              \
  class OpName##Kernel                                                    \
      : public framework::OpKernelBase<DeviceType, OpParam<DeviceType>> { \
   public:                                                                \
    bool Init(OpParam<DeviceType> *param) override;                       \
    void Compute(const OpParam<DeviceType> &param) override;              \
  };

// DEFINE_OP_CONSTRUCTOR(cls, parent_cls) expands, inside a class body, to a
// constructor named `cls`. The constructor takes the standard operator
// arguments (type, inputs, outputs, attrs, scope) and forwards them unchanged
// to `parent_cls<Dtype, T>`.
// The names `Dtype` and `T` are not macro parameters, so they must already be
// in scope where the macro is expanded (e.g. as the enclosing class's
// template parameters).
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)                                 \
  cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
      const ::paddle_mobile::VariableNameMap &outputs,                         \
      const ::paddle_mobile::framework::AttributeMap &attrs,                   \
      std::shared_ptr<::paddle_mobile::framework::Scope> scope)                \
      : parent_cls<Dtype, T>(type, inputs, outputs, attrs, scope) {}

朔-望's avatar
朔-望 已提交
234 235
}  // namespace framework
}  // namespace paddle_mobile