diff --git a/paddle/fluid/lite/core/kernel.h b/paddle/fluid/lite/core/kernel.h index 6bba59263ace64884d4737ca490855d8ae964f1a..d3308d016f26cd1fc44926fa126a20cdfd63c1eb 100644 --- a/paddle/fluid/lite/core/kernel.h +++ b/paddle/fluid/lite/core/kernel.h @@ -45,9 +45,9 @@ class KernelBase { param_.set(param); } - template - Param& param() const { - return param_.get(); + template + P& Param() const { + return param_.get

(); } void set_op_type(const std::string& type) { op_type_ = type; } @@ -71,161 +71,6 @@ class KernelBase { std::string op_type_; }; -/* - * ParamType is used to represent a data type of a parameter for the kernel. It - * can represent any Variable data type. - * The element_type_hash is the hash code of the element, it should be - * registered in the `TypeSystem`. - */ -struct ParamType { - // For unsupported types. - size_t element_type_hash{}; - Place tensor_place{}; - const Type* type_; - - explicit ParamType() = default; - explicit ParamType(size_t element_type_hash) - : element_type_hash(element_type_hash) {} - ParamType(size_t element_type_hash, const Place& place) - : element_type_hash(element_type_hash), tensor_place(place) {} - ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); } - - std::string DebugString() const { return tensor_place.DebugString(); } -}; - -/* - * The data types of kernel parameters. It is used to track the type of kernel's - * inputs and outputs. - */ -struct ParamTypeRecorder { - std::map inputs; - std::map outputs; - - void RegisterInputType(const std::string& arg_name, const ParamType& type) { - Register(&inputs, arg_name, type); - } - - void RegisterOutputType(const std::string& arg_name, const ParamType& type) { - Register(&outputs, arg_name, type); - } - - private: - void Register(std::map* ts, - const std::string& arg_name, ParamType type) { - (*ts)[arg_name] = type; - } -}; - -/* - * The ParamTypeRegistry help register the input and output data types for all - * the kernels. It is made singleton so that all the objects of the same kernel - * can share the same information. - * - * Usage: - * for register a kernel for FC operator. - * ParamTypeRegistry::Global().Register( - * "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0, - * {typeid(Tensor), {TARGET(kCUDA)}}); - */ -class ParamTypeRegistry { - public: - enum class IO : int { kInput = 0, kOutput }; - - template - /* - * Helper class for registering a ParamType for a Kernel. - * Usage: - * - * NewInstance("fc") - * .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)}) - * .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost), - * PRECISION(kFloat)}); - */ - struct NewInstance { - explicit NewInstance(const std::string& kernel_type) - : kernel_type_(kernel_type) {} - - NewInstance& BindInput(const std::string& arg_name, - const ParamType& ptype) { - ParamTypeRegistry::Global().Register( - kernel_type_, Place{target, precision, layout}, arg_name, ptype); - return *this; - } - NewInstance& BindOutput(const std::string& arg_name, - const ParamType& ptype) { - ParamTypeRegistry::Global().Register( - kernel_type_, Place{target, precision, layout}, arg_name, ptype); - return *this; - } - - bool Finalize() { return true; } - - private: - std::string kernel_type_; - }; - - template - void Register(const std::string& kernel_type, const Place& place, - const std::string& arg_name, ParamType data_type) { - KernelIdTy key{kernel_type, place, io, arg_name}; - types_[key] = data_type; - CHECK(types_.count(key)); - } - - template - const ParamType* Retrieve(const Place& place, const std::string& op_type, - const std::string& arg_name) { - KernelIdTy key{op_type, place, io, arg_name}; - auto it = types_.find(key); - if (it == types_.end()) return nullptr; - return &it->second; - } - - static ParamTypeRegistry& Global() { - static ParamTypeRegistry x; - return x; - } - - friend std::ostream& operator<<(std::ostream& os, - const ParamTypeRegistry& other) { - for (auto& item : other.types_) { - os << item.first << " " << item.second.DebugString() << "\n"; - } - return os; - } - - private: - ParamTypeRegistry() = default; - - public: - // Identification for a Kernel. - struct KernelIdTy { - std::string kernel_type; - Place place; - IO io; - std::string arg_name; - - size_t hash() const { - std::hash h; - size_t hash = h(kernel_type); - hash = hash_combine(hash, place.hash()); - hash = hash_combine(hash, std::hash()(static_cast(io))); - hash = hash_combine(hash, std::hash()(arg_name)); - return hash; - } - friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other); - }; - - using key_t = KernelIdTy; - struct KeyCmp { - bool operator()(const key_t& a, const key_t& b) const; - }; - - private: - std::map types_; -}; - // Light-weight kernel implementation. // The OpKernel is designed to implement the specific algorithm on a target // device. diff --git a/paddle/fluid/lite/core/kernel_test.cc b/paddle/fluid/lite/core/kernel_test.cc index 7be5f8714c4af5066e27619b8ac8e8e742539c64..bf23990c8f369b9ac658ad31c26c083b49239c0e 100644 --- a/paddle/fluid/lite/core/kernel_test.cc +++ b/paddle/fluid/lite/core/kernel_test.cc @@ -25,8 +25,8 @@ class SomeKernel : public OpKernel { public: void Run() override { LOG(INFO) << "SomeKernel executed"; - LOG(INFO) << param().in_num_col_dims; - test_code = param().in_num_col_dims; + LOG(INFO) << Param().in_num_col_dims; + test_code = Param().in_num_col_dims; } TargetType target() const override { return TARGET(kHost); } diff --git a/paddle/fluid/lite/core/type_system.h b/paddle/fluid/lite/core/type_system.h index d9107afff8d21714137d12110e302412be7cf42b..05240e2943aa94293515887146d6bddc13a0370e 100644 --- a/paddle/fluid/lite/core/type_system.h +++ b/paddle/fluid/lite/core/type_system.h @@ -21,12 +21,14 @@ // for analysis and runtime. #include +#include #include #include #include #include #include #include "paddle/fluid/lite/core/tensor.h" +#include "paddle/fluid/lite/utils/all.h" namespace paddle { namespace lite { @@ -202,5 +204,160 @@ class TypeSystem { std::unordered_set names_; }; +/* + * ParamType is used to represent a data type of a parameter for the kernel. It + * can represent any Variable data type. + * The element_type_hash is the hash code of the element, it should be + * registered in the `TypeSystem`. + */ +struct ParamType { + // For unsupported types. + size_t element_type_hash{}; + Place tensor_place{}; + const Type* type_; + + explicit ParamType() = default; + explicit ParamType(size_t element_type_hash) + : element_type_hash(element_type_hash) {} + ParamType(size_t element_type_hash, const Place& place) + : element_type_hash(element_type_hash), tensor_place(place) {} + ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); } + + std::string DebugString() const { return tensor_place.DebugString(); } +}; + +/* + * The data types of kernel parameters. It is used to track the type of kernel's + * inputs and outputs. + */ +struct ParamTypeRecorder { + std::map inputs; + std::map outputs; + + void RegisterInputType(const std::string& arg_name, const ParamType& type) { + Register(&inputs, arg_name, type); + } + + void RegisterOutputType(const std::string& arg_name, const ParamType& type) { + Register(&outputs, arg_name, type); + } + + private: + void Register(std::map* ts, + const std::string& arg_name, ParamType type) { + (*ts)[arg_name] = type; + } +}; + +/* + * The ParamTypeRegistry help register the input and output data types for all + * the kernels. It is made singleton so that all the objects of the same kernel + * can share the same information. + * + * Usage: + * for register a kernel for FC operator. + * ParamTypeRegistry::Global().Register( + * "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0, + * {typeid(Tensor), {TARGET(kCUDA)}}); + */ +class ParamTypeRegistry { + public: + enum class IO : int { kInput = 0, kOutput }; + + template + /* + * Helper class for registering a ParamType for a Kernel. + * Usage: + * + * NewInstance("fc") + * .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)}) + * .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost), + * PRECISION(kFloat)}); + */ + struct NewInstance { + explicit NewInstance(const std::string& kernel_type) + : kernel_type_(kernel_type) {} + + NewInstance& BindInput(const std::string& arg_name, + const ParamType& ptype) { + ParamTypeRegistry::Global().Register( + kernel_type_, Place{target, precision, layout}, arg_name, ptype); + return *this; + } + NewInstance& BindOutput(const std::string& arg_name, + const ParamType& ptype) { + ParamTypeRegistry::Global().Register( + kernel_type_, Place{target, precision, layout}, arg_name, ptype); + return *this; + } + + bool Finalize() { return true; } + + private: + std::string kernel_type_; + }; + + template + void Register(const std::string& kernel_type, const Place& place, + const std::string& arg_name, ParamType data_type) { + KernelIdTy key{kernel_type, place, io, arg_name}; + types_[key] = data_type; + CHECK(types_.count(key)); + } + + template + const ParamType* Retrieve(const Place& place, const std::string& op_type, + const std::string& arg_name) { + KernelIdTy key{op_type, place, io, arg_name}; + auto it = types_.find(key); + if (it == types_.end()) return nullptr; + return &it->second; + } + + static ParamTypeRegistry& Global() { + static ParamTypeRegistry x; + return x; + } + + friend std::ostream& operator<<(std::ostream& os, + const ParamTypeRegistry& other) { + for (auto& item : other.types_) { + os << item.first << " " << item.second.DebugString() << "\n"; + } + return os; + } + + private: + ParamTypeRegistry() = default; + + public: + // Identification for a Kernel. + struct KernelIdTy { + std::string kernel_type; + Place place; + IO io; + std::string arg_name; + + size_t hash() const { + std::hash h; + size_t hash = h(kernel_type); + hash = hash_combine(hash, place.hash()); + hash = hash_combine(hash, std::hash()(static_cast(io))); + hash = hash_combine(hash, std::hash()(arg_name)); + return hash; + } + friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other); + }; + + using key_t = KernelIdTy; + struct KeyCmp { + bool operator()(const key_t& a, const key_t& b) const; + }; + + private: + std::map types_; +}; + } // namespace lite } // namespace paddle