operator.h 6.3 KB
Newer Older
Z
zhaojiaying01 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
朔-望's avatar
朔-望 已提交
14 15 16

#pragma once

L
liuruilong 已提交
17
#include <map>
Z
zhaojiaying01 已提交
18
#include <string>
H
hjchen2 已提交
19
#include <utility>
Z
zhaojiaying01 已提交
20
#include <vector>
L
liuruilong 已提交
21

L
liuruilong 已提交
22 23
#include "common/enforce.h"
#include "common/type_define.h"
L
liuruilong 已提交
24 25
#include "common/types.h"
#include "common/variant.h"
L
liuruilong 已提交
26
#include "framework/attribute.h"
L
liuruilong 已提交
27
#include "framework/op_info.h"
L
liuruilong 已提交
28
#include "framework/op_kernel_type.h"
L
liuruilong 已提交
29 30
#include "framework/op_registry.h"
#include "framework/program/block_desc.h"
L
liuruilong 已提交
31
#include "framework/program/program-optimize/node.h"
L
liuruilong 已提交
32 33 34
#include "framework/scope.h"
#include "framework/tensor.h"
#include "framework/variable.h"
L
liuruilong 已提交
35
#ifdef PADDLE_MOBILE_CL
L
liuruilong 已提交
36
#include "framework/cl/cl_helper.h"
37
#include "framework/cl/cl_scope.h"
L
liuruilong 已提交
38
#endif
39

朔-望's avatar
朔-望 已提交
40
namespace paddle_mobile {
朔-望's avatar
朔-望 已提交
41
namespace framework {
朔-望's avatar
朔-望 已提交
42

W
wangliu 已提交
43
template <typename T>
44
static T *GetVarValue(const std::string &key, const VariableNameMap &var_map,
W
wangliu 已提交
45 46 47 48 49 50 51 52 53 54
                      const Scope &scope) {
  auto var_vec = var_map.at(key);
  if (!var_vec.empty()) {
    auto var = scope.FindVar(var_vec[0]);
    return var->GetMutable<T>();
  } else {
    return nullptr;
  }
}

朔-望's avatar
朔-望 已提交
55
template <typename Dtype>
L
liuruilong 已提交
56
class OperatorBase {
朔-望's avatar
朔-望 已提交
57
 public:
58 59 60 61
  OperatorBase(const std::string &type, const VariableNameMap &inputs,
               const VariableNameMap &outputs, const AttributeMap &attrs,
               std::shared_ptr<Scope> scope);
  virtual ~OperatorBase() {}
朔-望's avatar
朔-望 已提交
62

E
eclipsess 已提交
63
  virtual void Init() = 0;
64 65 66 67 68 69 70
  virtual void InferShape() const = 0;
  virtual void Run();
  virtual void RunImpl() = 0;

  std::vector<std::string> GetOutKeys() const;
  std::vector<std::string> GetInputKeys() const;

71 72 73 74
  const VariableNameMap &Inputs() const { return inputs_; }
  const VariableNameMap &Outputs() const { return outputs_; }
  const std::string &Type() const { return type_; }
  const AttributeMap &Attrs() const { return attrs_; }
75

76 77 78
  void ClearVariables(const std::vector<std::string> &var_names) const {
    if (this->scope_) {
      this->scope_->EraseVars(var_names);
朔-望's avatar
朔-望 已提交
79
    }
80
  }
L
liuruilong 已提交
81

朔-望's avatar
朔-望 已提交
82
 protected:
83 84 85 86 87
  std::shared_ptr<Scope> scope_;
  std::string type_;
  VariableNameMap inputs_;
  VariableNameMap outputs_;
  AttributeMap attrs_;
L
liuruilong 已提交
88

朔-望's avatar
朔-望 已提交
89
 private:
90
  void CheckAllInputOutputSet() const;
朔-望's avatar
朔-望 已提交
91
};
朔-望's avatar
朔-望 已提交
92

L
liuruilong 已提交
93
template <typename Dtype, typename ParamType, typename KernelType>
朔-望's avatar
朔-望 已提交
94
class OperatorWithKernel : public OperatorBase<Dtype> {
朔-望's avatar
朔-望 已提交
95
 public:
96 97 98
  OperatorWithKernel(const std::string &type, const VariableNameMap &inputs,
                     const VariableNameMap &outputs, const AttributeMap &attrs,
                     std::shared_ptr<Scope> scope)
L
liuruilong 已提交
99
      : OperatorBase<Dtype>(type, inputs, outputs, attrs, scope),
L
liuruilong 已提交
100 101 102 103 104
        param_(inputs, outputs, attrs, *scope) {
#ifdef PADDLE_MOBILE_CL
    kernel_.InitCLHelper(scope->GetCLScpoe());
#endif
  }
L
liuruilong 已提交
105

L
liuruilong 已提交
106
  virtual void RunImpl() { this->kernel_.Compute(this->param_); }
107

W
wangliu 已提交
108
  virtual void InferShape() const = 0;
L
liuruilong 已提交
109

E
eclipsess 已提交
110 111
  void Init() {
    PADDLE_MOBILE_ENFORCE(kernel_.Init(&param_), "  %s kernel init failed",
L
liuruilong 已提交
112 113 114
                          this->type_.c_str());
  }

L
liuruilong 已提交
115 116 117
 protected:
  KernelType kernel_;
  ParamType param_;
朔-望's avatar
朔-望 已提交
118
};
朔-望's avatar
朔-望 已提交
119

朔-望's avatar
朔-望 已提交
120
/// Base class for operator kernels on device type Dtype taking parameters P.
template <typename Dtype, typename P>
class OpKernelBase {
 public:
  OpKernelBase() = default;

#ifdef PADDLE_MOBILE_CL
  /// Rebinds the kernel's CL helper to @p clScope.
  virtual void InitCLHelper(CLScope *clScope) {
    cl_helper_ = CLHelper(clScope);
  }
#endif

  // Guard must match the one on acl_op_ below (it was misspelled
  // PADDLE_McOBILE_MALI_GPU, so these accessors were never compiled).
  // acl_op_ is zeroed by its default member initializer rather than by a
  // second default constructor, which would clash with the one above.
#ifdef PADDLE_MOBILE_MALI_GPU
  void *GetAclOp() const { return acl_op_; }
  // NOTE(review): const member that writes acl_op_ of the kernel pointed to
  // by @p ob; presumably @p ob is `this` — confirm against callers.
  void SetAclOp(void *op, void *ob) const {
    reinterpret_cast<OpKernelBase<Dtype, P> *>(ob)->acl_op_ = op;
  }
#endif

  /// Executes the kernel with parameters @p para.
  virtual void Compute(const P &para) = 0;
  /// One-time setup hook; the default implementation accepts and returns true.
  virtual bool Init(P *para) { return true; }
  virtual ~OpKernelBase() = default;

 protected:
#ifdef PADDLE_MOBILE_CL
  CLHelper cl_helper_;
#endif

 private:
#ifdef PADDLE_MOBILE_MALI_GPU
  void *acl_op_ = nullptr;
#endif
};
朔-望's avatar
朔-望 已提交
152

153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169
// Declares the class template OpName##Op<DeviceType, T>, an
// OperatorWithKernel using parameter type OpParam<DeviceType> and kernel
// operators::OpKernel<DeviceType, T>. Only InferShape() is declared; its
// definition must be supplied separately. OpParam and operators::OpKernel are
// resolved at the expansion site; framework names are fully qualified so the
// macro does not depend on the enclosing namespace for them.
#define DECLARE_OPERATOR(OpName, OpParam, OpKernel)                           \
  template <typename DeviceType, typename T>                                  \
  class OpName##Op : public ::paddle_mobile::framework::OperatorWithKernel<   \
                         DeviceType, OpParam<DeviceType>,                     \
                         operators::OpKernel<DeviceType, T>> {                \
   public:                                                                    \
    OpName##Op(const std::string &type,                                       \
               const ::paddle_mobile::VariableNameMap &inputs,                \
               const ::paddle_mobile::VariableNameMap &outputs,               \
               const ::paddle_mobile::framework::AttributeMap &attrs,         \
               std::shared_ptr<::paddle_mobile::framework::Scope> scope)      \
        : ::paddle_mobile::framework::OperatorWithKernel<                     \
              DeviceType, OpParam<DeviceType>,                                \
              operators::OpKernel<DeviceType, T>>(type, inputs, outputs,      \
                                                  attrs, std::move(scope)) {} \
                                                                              \
    void InferShape() const override;                                         \
  };

Z
zhaojiaying01 已提交
170 171 172 173 174 175 176
// Expands to a constructor for class `cls` that forwards all arguments to
// `parent_cls<Dtype, T>`. The names `Dtype` and `T` must be in scope at the
// expansion site (normally the enclosing class template's parameters). The
// by-value scope handle is moved, not copied, into the parent.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls)                                 \
  cls(const std::string &type, const ::paddle_mobile::VariableNameMap &inputs, \
      const ::paddle_mobile::VariableNameMap &outputs,                         \
      const ::paddle_mobile::framework::AttributeMap &attrs,                   \
      std::shared_ptr<::paddle_mobile::framework::Scope> scope)                \
      : parent_cls<Dtype, T>(type, inputs, outputs, attrs,                     \
                             std::move(scope)) {}

L
liuruilong 已提交
177
class FusionOpMatcher {
L
liuruilong 已提交
178 179 180 181 182
 public:
  FusionOpMatcher() {}

  virtual std::string Type() = 0;

L
liuruilong 已提交
183 184 185
  virtual void FolderNodes(
      Node *node,
      std::vector<std::shared_ptr<framework::Node>> *removed_nodes) {
L
liuruilong 已提交
186
    node->Folder(node_.Depth(), Type(), {}, removed_nodes);
L
liuruilong 已提交
187 188
  }

L
liuruilong 已提交
189
  virtual Node &BeginNode() { return node_; }
L
liuruilong 已提交
190

L
liuruilong 已提交
191
  std::string BeginType() { return node_.Type(); }
L
liuruilong 已提交
192

193 194
  virtual std::vector<std::pair<int, std::string>> NeedCheck() { return {}; }

L
liuruilong 已提交
195 196 197 198 199 200
 protected:
  Node node_;
  std::string type_;
  std::shared_ptr<OpDesc> new_opdesc_;
};

朔-望's avatar
朔-望 已提交
201 202
}  // namespace framework
}  // namespace paddle_mobile