// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include <functional>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"

namespace paddle {
namespace lite {

S
update  
superjomn 已提交
32 33
// An base with virtual functions to unify all the kernel implementation on
// different targets.
S
superjomn 已提交
34
class KernelBase {
S
superjomn 已提交
35
 public:
S
superjomn 已提交
36
  virtual void Run() = 0;
S
superjomn 已提交
37

S
superjomn 已提交
38 39 40 41
  template <TargetType Target>
  void SetContext(std::unique_ptr<Context<Target>>&& ctx) {
    context_.set<std::unique_ptr<Context<Target>>>(std::move(ctx));
  }
S
superjomn 已提交
42

S
superjomn 已提交
43 44 45 46
  template <typename T>
  void SetParam(T param) {
    param_.set<T>(param);
  }
S
superjomn 已提交
47 48 49

  template <typename Param>
  Param& param() const {
S
superjomn 已提交
50
    return param_.get<Param>();
S
superjomn 已提交
51 52
  }

S
superjomn 已提交
53 54 55
  void set_op_type(const std::string& type) { op_type_ = type; }
  const std::string& op_type() const { return op_type_; }

S
superjomn 已提交
56 57
  void Torch() {}

58
  virtual Place place() const = 0;
S
update  
superjomn 已提交
59 60
  virtual TargetType target() const = 0;
  virtual PrecisionType precision() const = 0;
S
superjomn 已提交
61
  virtual DataLayoutType layout() const = 0;
S
update  
superjomn 已提交
62

S
superjomn 已提交
63 64
  virtual std::string name() const = 0;

S
superjomn 已提交
65
  virtual ~KernelBase() = default;
S
update  
superjomn 已提交
66 67

 protected:
S
superjomn 已提交
68 69
  core::any_context_t context_;
  mutable operators::param_t param_;
S
superjomn 已提交
70 71
  // The corresponding op type.
  std::string op_type_;
S
superjomn 已提交
72 73
};

S
superjomn 已提交
74 75 76 77 78 79 80
/*
 * ParamType is used to represent a data type of a parameter for the kernel. It
 * can represent any Variable data type.
 * The element_type_hash is the hash code of the element, it should be
 * registered in the `TypeSystem`.
 */
struct ParamType {
S
superjomn 已提交
81
  // For unsupported types.
S
superjomn 已提交
82 83
  size_t element_type_hash{};
  Place tensor_place{};
S
superjomn 已提交
84
  const Type* type_;
S
superjomn 已提交
85

S
superjomn 已提交
86 87 88
  explicit ParamType() = default;
  explicit ParamType(size_t element_type_hash)
      : element_type_hash(element_type_hash) {}
S
superjomn 已提交
89 90
  ParamType(size_t element_type_hash, const Place& place)
      : element_type_hash(element_type_hash), tensor_place(place) {}
91 92 93
  ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); }

  std::string DebugString() const { return tensor_place.DebugString(); }
S
superjomn 已提交
94 95 96
};

/*
S
superjomn 已提交
97 98
 * The data types of kernel parameters. It is used to track the type of kernel's
 * inputs and outputs.
S
superjomn 已提交
99
 */
S
superjomn 已提交
100 101 102
struct ParamTypeRecorder {
  std::map<std::string, ParamType> inputs;
  std::map<std::string, ParamType> outputs;
S
superjomn 已提交
103

S
superjomn 已提交
104 105
  void RegisterInputType(const std::string& arg_name, const ParamType& type) {
    Register(&inputs, arg_name, type);
S
superjomn 已提交
106 107
  }

S
superjomn 已提交
108 109
  void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
    Register(&outputs, arg_name, type);
S
superjomn 已提交
110 111 112
  }

 private:
S
superjomn 已提交
113 114 115
  void Register(std::map<std::string, ParamType>* ts,
                const std::string& arg_name, ParamType type) {
    (*ts)[arg_name] = type;
S
superjomn 已提交
116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131
  }
};

/*
 * The ParamTypeRegistry help register the input and output data types for all
 * the kernels. It is made singleton so that all the objects of the same kernel
 * can share the same information.
 *
 * Usage:
 * for register a kernel for FC operator.
 * ParamTypeRegistry::Global().Register(
 *        "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
 *        {typeid(Tensor), {TARGET(kCUDA)}});
 */
class ParamTypeRegistry {
 public:
S
superjomn 已提交
132 133
  enum class IO : int { kInput = 0, kOutput };

S
superjomn 已提交
134 135 136 137 138 139 140 141 142 143 144 145
  template <TargetType target, PrecisionType precision,
            DataLayoutType layout = DataLayoutType::kNCHW>
  /*
   * Helper class for registering a ParamType for a Kernel.
   * Usage:
   *
   * NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
   *   .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
   *   .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
   *                                               PRECISION(kFloat)});
   */
  struct NewInstance {
S
superjomn 已提交
146 147
    explicit NewInstance(const std::string& kernel_type)
        : kernel_type_(kernel_type) {}
S
superjomn 已提交
148

S
superjomn 已提交
149 150
    NewInstance& BindInput(const std::string& arg_name,
                           const ParamType& ptype) {
S
superjomn 已提交
151
      ParamTypeRegistry::Global().Register<IO::kInput>(
S
superjomn 已提交
152
          kernel_type_, Place{target, precision, layout}, arg_name, ptype);
S
superjomn 已提交
153 154
      return *this;
    }
S
superjomn 已提交
155 156
    NewInstance& BindOutput(const std::string& arg_name,
                            const ParamType& ptype) {
S
superjomn 已提交
157
      ParamTypeRegistry::Global().Register<IO::kOutput>(
S
superjomn 已提交
158
          kernel_type_, Place{target, precision, layout}, arg_name, ptype);
S
superjomn 已提交
159 160 161 162 163 164 165 166 167
      return *this;
    }

    bool Finalize() { return true; }

   private:
    std::string kernel_type_;
  };

S
superjomn 已提交
168
  template <IO io>
S
superjomn 已提交
169 170 171
  void Register(const std::string& kernel_type, const Place& place,
                const std::string& arg_name, ParamType data_type) {
    KernelIdTy key{kernel_type, place, io, arg_name};
S
superjomn 已提交
172
    types_[key] = data_type;
173
    CHECK(types_.count(key));
S
superjomn 已提交
174
  }
S
superjomn 已提交
175

176 177 178 179 180 181 182 183
  template <IO io>
  const ParamType* Retrieve(const Place& place, const std::string& op_type,
                            const std::string& arg_name) {
    KernelIdTy key{op_type, place, io, arg_name};
    auto it = types_.find(key);
    if (it == types_.end()) return nullptr;
    return &it->second;
  }
S
superjomn 已提交
184 185 186 187 188 189

  static ParamTypeRegistry& Global() {
    static ParamTypeRegistry x;
    return x;
  }

190 191 192 193 194 195 196 197
  friend std::ostream& operator<<(std::ostream& os,
                                  const ParamTypeRegistry& other) {
    for (auto& item : other.types_) {
      os << item.first << " " << item.second.DebugString() << "\n";
    }
    return os;
  }

S
superjomn 已提交
198 199 200 201 202
 private:
  ParamTypeRegistry() = default;

 public:
  // Identification for a Kernel.
S
superjomn 已提交
203
  struct KernelIdTy {
S
superjomn 已提交
204 205 206
    std::string kernel_type;
    Place place;
    IO io;
S
superjomn 已提交
207
    std::string arg_name;
208 209 210 211 212 213 214 215 216 217

    size_t hash() const {
      std::hash<std::string> h;
      size_t hash = h(kernel_type);
      hash = hash_combine(hash, place.hash());
      hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
      hash = hash_combine(hash, std::hash<std::string>()(arg_name));
      return hash;
    }
    friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
S
superjomn 已提交
218 219
  };

S
superjomn 已提交
220
  using key_t = KernelIdTy;
S
superjomn 已提交
221 222 223 224 225 226 227 228
  struct KeyCmp {
    bool operator()(const key_t& a, const key_t& b) const;
  };

 private:
  std::map<key_t, ParamType, ParamTypeRegistry::KeyCmp> types_;
};

S
superjomn 已提交
229 230 231
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
S
superjomn 已提交
232 233
template <TargetType Target, PrecisionType Precision,
          DataLayoutType DataLayout = DataLayoutType::kNCHW>
S
superjomn 已提交
234 235
class OpKernel : public KernelBase {
 public:
S
superjomn 已提交
236 237
  // Set runtime context.
  void SetContext(std::unique_ptr<KernelContext>&& ctx) { ctx_ = ctx; }
S
superjomn 已提交
238

S
superjomn 已提交
239 240
  // Run the kernel.
  virtual void Run() { CHECK(false) << "Not Implemented"; }
S
superjomn 已提交
241

S
superjomn 已提交
242 243
  TargetType target() const override { return Target; }
  PrecisionType precision() const override { return Precision; }
S
superjomn 已提交
244
  DataLayoutType layout() const override { return DataLayout; }
245
  Place place() const override { return Place{Target, Precision, DataLayout}; }
S
superjomn 已提交
246 247 248 249
  std::string name() const override {
    return op_type() + ":" + TargetToStr(Target) + "/" +
           PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
  }
S
superjomn 已提交
250

S
superjomn 已提交
251 252
  void Touch() {}

S
superjomn 已提交
253 254
  OpKernel() = default;
  virtual ~OpKernel() = default;
S
superjomn 已提交
255 256 257

 protected:
  std::unique_ptr<KernelContext> ctx_;
S
superjomn 已提交
258 259 260 261
};

}  // namespace lite
}  // namespace paddle