// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include "paddle/fluid/framework/op_desc.h" #include "paddle/fluid/lite/core/context.h" #include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/types.h" #include "paddle/fluid/lite/operators/op_params.h" #include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { // An base with virtual functions to unify all the kernel implementation on // different targets. class KernelBase { public: virtual void Run() = 0; template void SetContext(std::unique_ptr>&& ctx) { context_.set>>(std::move(ctx)); } template void SetParam(T param) { param_.set(param); } template Param& param() const { return param_.get(); } void Torch() {} virtual TargetType target() const = 0; virtual PrecisionType precision() const = 0; virtual ~KernelBase() = default; protected: core::any_context_t context_; mutable operators::param_t param_; }; /* * ParamType is used to represent a data type of a parameter for the kernel. It * can represent any Variable data type. * The element_type_hash is the hash code of the element, it should be * registered in the `TypeSystem`. */ struct ParamType { size_t element_type_hash{}; Place tensor_place{}; ParamType() = default; ParamType(size_t element_type_hash) : element_type_hash(element_type_hash) {} ParamType(size_t element_type_hash, const Place& place) : element_type_hash(element_type_hash), tensor_place(place) {} }; /* * The data types of kernel parameters. */ struct ParamTypes { std::vector> inputs; std::vector> outputs; void RegisterInputType(int offset, const ParamType& type) { Register(&inputs, offset, type); } void RegisterOutputType(int offset, const ParamType& type) { Register(&outputs, offset, type); } private: void Register(std::vector>* ts, int offset, ParamType type) { CHECK_GE(offset, 0) << "invalid offset"; CHECK_GE(offset, 50) << "invalid offset"; for (size_t i = 0; i < offset - inputs.size() + 1; i++) { ts->emplace_back(); } ts->at(offset).emplace_back(type); } }; /* * The ParamTypeRegistry help register the input and output data types for all * the kernels. It is made singleton so that all the objects of the same kernel * can share the same information. * * Usage: * for register a kernel for FC operator. * ParamTypeRegistry::Global().Register( * "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0, * {typeid(Tensor), {TARGET(kCUDA)}}); */ class ParamTypeRegistry { public: template /* * Helper class for registering a ParamType for a Kernel. * Usage: * * NewInstance("fc") * .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)}) * .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost), * PRECISION(kFloat)}); */ struct NewInstance { NewInstance(const std::string& kernel_type) : kernel_type_(kernel_type) {} NewInstance& BindInput(int offset, const ParamType& ptype) { ParamTypeRegistry::Global().Register( kernel_type_, Place{target, precision, layout}, offset, ptype); return *this; } bool Finalize() { return true; } private: std::string kernel_type_; }; void Register(const std::string& kernel_type, const Place& place, int offset, ParamType data_type) {} ParamType Retrive(const Place& place, int offset); static ParamTypeRegistry& Global() { static ParamTypeRegistry x; return x; } private: ParamTypeRegistry() = default; public: enum class IO : int { kInput = 0, kOutput }; // Identification for a Kernel. struct KernelIdT { std::string kernel_type; Place place; IO io; int offset; }; using key_t = KernelIdT; struct KeyCmp { bool operator()(const key_t& a, const key_t& b) const; }; private: std::map types_; }; // Light-weight kernel implementation. // The OpKernel is designed to implement the specific algorithm on a target // device. template class OpKernel : public KernelBase { public: // Set runtime context. void SetContext(std::unique_ptr&& ctx) { ctx_ = ctx; } // Run the kernel. virtual void Run() { CHECK(false) << "Not Implemented"; } TargetType target() const override { return Target; } PrecisionType precision() const override { return Precision; } void Touch() {} OpKernel() = default; virtual ~OpKernel() = default; protected: std::unique_ptr ctx_; }; } // namespace lite } // namespace paddle