提交 e55a5cd9 编写于 作者: S Superjomn

move ParamTypes to type_system.h

上级 43036686
...@@ -45,9 +45,9 @@ class KernelBase { ...@@ -45,9 +45,9 @@ class KernelBase {
param_.set<T>(param); param_.set<T>(param);
} }
template <typename Param> template <typename P>
Param& param() const { P& Param() const {
return param_.get<Param>(); return param_.get<P>();
} }
void set_op_type(const std::string& type) { op_type_ = type; } void set_op_type(const std::string& type) { op_type_ = type; }
...@@ -71,161 +71,6 @@ class KernelBase { ...@@ -71,161 +71,6 @@ class KernelBase {
std::string op_type_; std::string op_type_;
}; };
/*
* ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type.
* The element_type_hash is the hash code of the element, it should be
* registered in the `TypeSystem`.
*/
struct ParamType {
// For unsupported types.
size_t element_type_hash{};
Place tensor_place{};
const Type* type_;
explicit ParamType() = default;
explicit ParamType(size_t element_type_hash)
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); }
std::string DebugString() const { return tensor_place.DebugString(); }
};
/*
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
struct ParamTypeRecorder {
std::map<std::string, ParamType> inputs;
std::map<std::string, ParamType> outputs;
void RegisterInputType(const std::string& arg_name, const ParamType& type) {
Register(&inputs, arg_name, type);
}
void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
Register(&outputs, arg_name, type);
}
private:
void Register(std::map<std::string, ParamType>* ts,
const std::string& arg_name, ParamType type) {
(*ts)[arg_name] = type;
}
};
/*
* The ParamTypeRegistry help register the input and output data types for all
* the kernels. It is made singleton so that all the objects of the same kernel
* can share the same information.
*
* Usage:
* for register a kernel for FC operator.
* ParamTypeRegistry::Global().Register(
* "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
* {typeid(Tensor), {TARGET(kCUDA)}});
*/
class ParamTypeRegistry {
public:
enum class IO : int { kInput = 0, kOutput };
template <TargetType target, PrecisionType precision,
DataLayoutType layout = DataLayoutType::kNCHW>
/*
* Helper class for registering a ParamType for a Kernel.
* Usage:
*
* NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
* .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
* .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
* PRECISION(kFloat)});
*/
struct NewInstance {
explicit NewInstance(const std::string& kernel_type)
: kernel_type_(kernel_type) {}
NewInstance& BindInput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kInput>(
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
NewInstance& BindOutput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kOutput>(
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
bool Finalize() { return true; }
private:
std::string kernel_type_;
};
template <IO io>
void Register(const std::string& kernel_type, const Place& place,
const std::string& arg_name, ParamType data_type) {
KernelIdTy key{kernel_type, place, io, arg_name};
types_[key] = data_type;
CHECK(types_.count(key));
}
template <IO io>
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
}
static ParamTypeRegistry& Global() {
static ParamTypeRegistry x;
return x;
}
friend std::ostream& operator<<(std::ostream& os,
const ParamTypeRegistry& other) {
for (auto& item : other.types_) {
os << item.first << " " << item.second.DebugString() << "\n";
}
return os;
}
private:
ParamTypeRegistry() = default;
public:
// Identification for a Kernel.
struct KernelIdTy {
std::string kernel_type;
Place place;
IO io;
std::string arg_name;
size_t hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
};
using key_t = KernelIdTy;
struct KeyCmp {
bool operator()(const key_t& a, const key_t& b) const;
};
private:
std::map<key_t, ParamType, ParamTypeRegistry::KeyCmp> types_;
};
// Light-weight kernel implementation. // Light-weight kernel implementation.
// The OpKernel is designed to implement the specific algorithm on a target // The OpKernel is designed to implement the specific algorithm on a target
// device. // device.
......
...@@ -25,8 +25,8 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> { ...@@ -25,8 +25,8 @@ class SomeKernel : public OpKernel<TARGET(kHost), PRECISION(kFloat)> {
public: public:
void Run() override { void Run() override {
LOG(INFO) << "SomeKernel executed"; LOG(INFO) << "SomeKernel executed";
LOG(INFO) << param<operators::FcParam>().in_num_col_dims; LOG(INFO) << Param<operators::FcParam>().in_num_col_dims;
test_code = param<operators::FcParam>().in_num_col_dims; test_code = Param<operators::FcParam>().in_num_col_dims;
} }
TargetType target() const override { return TARGET(kHost); } TargetType target() const override { return TARGET(kHost); }
......
...@@ -21,12 +21,14 @@ ...@@ -21,12 +21,14 @@
// for analysis and runtime. // for analysis and runtime.
#include <glog/logging.h> #include <glog/logging.h>
#include <map>
#include <string> #include <string>
#include <typeinfo> #include <typeinfo>
#include <unordered_map> #include <unordered_map>
#include <unordered_set> #include <unordered_set>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/tensor.h" #include "paddle/fluid/lite/core/tensor.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
...@@ -202,5 +204,160 @@ class TypeSystem { ...@@ -202,5 +204,160 @@ class TypeSystem {
std::unordered_set<std::string> names_; std::unordered_set<std::string> names_;
}; };
/*
* ParamType is used to represent a data type of a parameter for the kernel. It
* can represent any Variable data type.
* The element_type_hash is the hash code of the element, it should be
* registered in the `TypeSystem`.
*/
struct ParamType {
// For unsupported types.
size_t element_type_hash{};
Place tensor_place{};
const Type* type_;
explicit ParamType() = default;
explicit ParamType(size_t element_type_hash)
: element_type_hash(element_type_hash) {}
ParamType(size_t element_type_hash, const Place& place)
: element_type_hash(element_type_hash), tensor_place(place) {}
ParamType(const Type* type) : type_(type) { tensor_place = type_->place(); }
std::string DebugString() const { return tensor_place.DebugString(); }
};
/*
* The data types of kernel parameters. It is used to track the type of kernel's
* inputs and outputs.
*/
struct ParamTypeRecorder {
std::map<std::string, ParamType> inputs;
std::map<std::string, ParamType> outputs;
void RegisterInputType(const std::string& arg_name, const ParamType& type) {
Register(&inputs, arg_name, type);
}
void RegisterOutputType(const std::string& arg_name, const ParamType& type) {
Register(&outputs, arg_name, type);
}
private:
void Register(std::map<std::string, ParamType>* ts,
const std::string& arg_name, ParamType type) {
(*ts)[arg_name] = type;
}
};
/*
* The ParamTypeRegistry help register the input and output data types for all
* the kernels. It is made singleton so that all the objects of the same kernel
* can share the same information.
*
* Usage:
* for register a kernel for FC operator.
* ParamTypeRegistry::Global().Register(
* "fc", {TARGET(kCUDA), PRECISION(kFloat)}, 0,
* {typeid(Tensor), {TARGET(kCUDA)}});
*/
class ParamTypeRegistry {
public:
enum class IO : int { kInput = 0, kOutput };
template <TargetType target, PrecisionType precision,
DataLayoutType layout = DataLayoutType::kNCHW>
/*
* Helper class for registering a ParamType for a Kernel.
* Usage:
*
* NewInstance<TARGET(kHost), PRECISION(kFloat)>("fc")
* .BindInput(0, {typeid(Tensor).hash_code(), {TARGET(kHost)})
* .BindInput(1, {typeid(Tensor).hash_code(), {TARGET(kHost),
* PRECISION(kFloat)});
*/
struct NewInstance {
explicit NewInstance(const std::string& kernel_type)
: kernel_type_(kernel_type) {}
NewInstance& BindInput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kInput>(
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
NewInstance& BindOutput(const std::string& arg_name,
const ParamType& ptype) {
ParamTypeRegistry::Global().Register<IO::kOutput>(
kernel_type_, Place{target, precision, layout}, arg_name, ptype);
return *this;
}
bool Finalize() { return true; }
private:
std::string kernel_type_;
};
template <IO io>
void Register(const std::string& kernel_type, const Place& place,
const std::string& arg_name, ParamType data_type) {
KernelIdTy key{kernel_type, place, io, arg_name};
types_[key] = data_type;
CHECK(types_.count(key));
}
template <IO io>
const ParamType* Retrieve(const Place& place, const std::string& op_type,
const std::string& arg_name) {
KernelIdTy key{op_type, place, io, arg_name};
auto it = types_.find(key);
if (it == types_.end()) return nullptr;
return &it->second;
}
static ParamTypeRegistry& Global() {
static ParamTypeRegistry x;
return x;
}
friend std::ostream& operator<<(std::ostream& os,
const ParamTypeRegistry& other) {
for (auto& item : other.types_) {
os << item.first << " " << item.second.DebugString() << "\n";
}
return os;
}
private:
ParamTypeRegistry() = default;
public:
// Identification for a Kernel.
struct KernelIdTy {
std::string kernel_type;
Place place;
IO io;
std::string arg_name;
size_t hash() const {
std::hash<std::string> h;
size_t hash = h(kernel_type);
hash = hash_combine(hash, place.hash());
hash = hash_combine(hash, std::hash<int>()(static_cast<int>(io)));
hash = hash_combine(hash, std::hash<std::string>()(arg_name));
return hash;
}
friend std::ostream& operator<<(std::ostream& os, const KernelIdTy& other);
};
using key_t = KernelIdTy;
struct KeyCmp {
bool operator()(const key_t& a, const key_t& b) const;
};
private:
std::map<key_t, ParamType, ParamTypeRegistry::KeyCmp> types_;
};
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册