kernel.h 6.7 KB
Newer Older
S
superjomn 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/lite/core/context.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
#include "paddle/fluid/lite/core/type_system.h"
#include "paddle/fluid/lite/core/types.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"

namespace paddle {
namespace lite {

S
update  
superjomn 已提交
32 33
// An base with virtual functions to unify all the kernel implementation on
// different targets.
S
superjomn 已提交
34
class KernelBase {
S
superjomn 已提交
35
 public:
S
superjomn 已提交
36
  virtual void Run() = 0;
S
superjomn 已提交
37

S
superjomn 已提交
38 39 40 41
  template <TargetType Target>
  void SetContext(std::unique_ptr<Context<Target>>&& ctx) {
    context_.set<std::unique_ptr<Context<Target>>>(std::move(ctx));
  }
S
superjomn 已提交
42

S
superjomn 已提交
43 44 45 46
  template <typename T>
  void SetParam(T param) {
    param_.set<T>(param);
  }
S
superjomn 已提交
47 48 49

  template <typename Param>
  Param& param() const {
S
superjomn 已提交
50
    return param_.get<Param>();
S
superjomn 已提交
51 52
  }

S
superjomn 已提交
53 54 55
  void set_op_type(const std::string& type) { op_type_ = type; }
  const std::string& op_type() const { return op_type_; }

S
superjomn 已提交
56 57
  void Torch() {}

S
update  
superjomn 已提交
58 59
  virtual TargetType target() const = 0;
  virtual PrecisionType precision() const = 0;
S
superjomn 已提交
60
  virtual DataLayoutType layout() const = 0;
S
update  
superjomn 已提交
61

S
superjomn 已提交
62 63
  virtual std::string name() const = 0;

S
superjomn 已提交
64
  virtual ~KernelBase() = default;
S
update  
superjomn 已提交
65 66

 protected:
S
superjomn 已提交
67 68
  core::any_context_t context_;
  mutable operators::param_t param_;
S
superjomn 已提交
69 70
  // The corresponding op type.
  std::string op_type_;
S
superjomn 已提交
71 72
};

S
superjomn 已提交
73 74 75 76 77 78 79
/*
 * ParamType is used to represent a data type of a parameter for the kernel. It
 * can represent any Variable data type.
 * The element_type_hash is the hash code of the element, it should be
 * registered in the `TypeSystem`.
 */
struct ParamType {
S
superjomn 已提交
80
  // For unsupported types.
S
superjomn 已提交
81 82
  size_t element_type_hash{};
  Place tensor_place{};
S
superjomn 已提交
83
  const Type* type_;
S
superjomn 已提交
84

S
superjomn 已提交
85 86 87
  explicit ParamType() = default;
  explicit ParamType(size_t element_type_hash)
      : element_type_hash(element_type_hash) {}
S
superjomn 已提交
88 89
  ParamType(size_t element_type_hash, const Place& place)
      : element_type_hash(element_type_hash), tensor_place(place) {}
S
superjomn 已提交
90
  ParamType(const Type* type) : type_(type) {}
S
superjomn 已提交
91 92 93
};

/*
S
superjomn 已提交
94 95
 * The data types of kernel parameters. It is used to track the type of kernel's
 * inputs and outputs.
S
superjomn 已提交
96
 */
S
superjomn 已提交
97 98 99
struct ParamTypeRecorder {
  std::map<std::string, ParamType> inputs;
  std::map<std::string, ParamType> outputs;
S
superjomn 已提交
100

S
superjomn 已提交
101 102
  void RegisterInputType(const std::string& arg_name, const ParamType& type) {
    Register(&inputs, arg_name, type);
S
superjomn 已提交
103 104
  }

S
superjomn 已提交
105 106
  void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
    Register(&outputs, arg_name, type);
S
superjomn 已提交
107 108 109
  }

 private:
S
superjomn 已提交
110 111 112
  void Register(std::map<std::string, ParamType>* ts,
                const std::string& arg_name, ParamType type) {
    (*ts)[arg_name] = type;
S
superjomn 已提交
113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
  }
};

/*
 * The ParamTypeRegistry help register the input and output data types for all
 * the kernels. It is made singleton so that all the objects of the same kernel
 * can share the same information.
 *
 * Usage:
 * for register a kernel for FC operator.
 * ParamTypeRegistry::Global().Register(
 *        "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
 *        {typeid(Tensor), {TARGET(kCUDA)}});
 */
class ParamTypeRegistry {
 public:
S
superjomn 已提交
129 130
  enum class IO : int { kInput = 0, kOutput };

S
superjomn 已提交
131 132 133 134 135 136 137 138 139 140 141 142
  template <TargetType target, PrecisionType precision,
            DataLayoutType layout = DataLayoutType::kNCHW>
  /*
   * Helper class for registering a ParamType for a Kernel.
   * Usage:
   *
   * NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
   *   .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
   *   .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
   *                                               PRECISION(kFloat)});
   */
  struct NewInstance {
S
superjomn 已提交
143 144
    explicit NewInstance(const std::string& kernel_type)
        : kernel_type_(kernel_type) {}
S
superjomn 已提交
145

S
superjomn 已提交
146 147
    NewInstance& BindInput(const std::string& arg_name,
                           const ParamType& ptype) {
S
superjomn 已提交
148
      ParamTypeRegistry::Global().Register<IO::kInput>(
S
superjomn 已提交
149
          kernel_type_, Place{target, precision, layout}, arg_name, ptype);
S
superjomn 已提交
150 151
      return *this;
    }
S
superjomn 已提交
152 153
    NewInstance& BindOutput(const std::string& arg_name,
                            const ParamType& ptype) {
S
superjomn 已提交
154
      ParamTypeRegistry::Global().Register<IO::kOutput>(
S
superjomn 已提交
155
          kernel_type_, Place{target, precision, layout}, arg_name, ptype);
S
superjomn 已提交
156 157 158 159 160 161 162 163 164
      return *this;
    }

    bool Finalize() { return true; }

   private:
    std::string kernel_type_;
  };

S
superjomn 已提交
165
  template <IO io>
S
superjomn 已提交
166 167 168
  void Register(const std::string& kernel_type, const Place& place,
                const std::string& arg_name, ParamType data_type) {
    KernelIdTy key{kernel_type, place, io, arg_name};
S
superjomn 已提交
169 170
    types_[key] = data_type;
  }
S
superjomn 已提交
171 172 173 174 175 176 177 178 179 180 181 182 183

  ParamType Retrive(const Place& place, int offset);

  static ParamTypeRegistry& Global() {
    static ParamTypeRegistry x;
    return x;
  }

 private:
  ParamTypeRegistry() = default;

 public:
  // Identification for a Kernel.
S
superjomn 已提交
184
  struct KernelIdTy {
S
superjomn 已提交
185 186 187
    std::string kernel_type;
    Place place;
    IO io;
S
superjomn 已提交
188
    std::string arg_name;
S
superjomn 已提交
189 190
  };

S
superjomn 已提交
191
  using key_t = KernelIdTy;
S
superjomn 已提交
192 193 194 195 196 197 198 199
  struct KeyCmp {
    bool operator()(const key_t& a, const key_t& b) const;
  };

 private:
  std::map<key_t, ParamType, ParamTypeRegistry::KeyCmp> types_;
};

S
superjomn 已提交
200 201 202
// Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target
// device.
S
superjomn 已提交
203 204
template <TargetType Target, PrecisionType Precision,
          DataLayoutType DataLayout = DataLayoutType::kNCHW>
S
superjomn 已提交
205 206
class OpKernel : public KernelBase {
 public:
S
superjomn 已提交
207 208
  // Set runtime context.
  void SetContext(std::unique_ptr<KernelContext>&& ctx) { ctx_ = ctx; }
S
superjomn 已提交
209

S
superjomn 已提交
210 211
  // Run the kernel.
  virtual void Run() { CHECK(false) << "Not Implemented"; }
S
superjomn 已提交
212

S
superjomn 已提交
213 214
  TargetType target() const override { return Target; }
  PrecisionType precision() const override { return Precision; }
S
superjomn 已提交
215
  DataLayoutType layout() const override { return DataLayout; }
S
superjomn 已提交
216 217 218 219
  std::string name() const override {
    return op_type() + ":" + TargetToStr(Target) + "/" +
           PrecisionToStr(Precision) + "/" + DataLayoutToStr(DataLayout);
  }
S
superjomn 已提交
220

S
superjomn 已提交
221 222
  void Touch() {}

S
superjomn 已提交
223 224
  OpKernel() = default;
  virtual ~OpKernel() = default;
S
superjomn 已提交
225 226 227

 protected:
  std::unique_ptr<KernelContext> ctx_;
S
superjomn 已提交
228 229 230 231
};

}  // namespace lite
}  // namespace paddle