LODTensor-related Variable type deduction with inheritance
Created by: Superjomn
LODTensor 的逻辑需要在 Variable 里添加一个继承关系的推导,最终达到此目的:
- 前提, LODTensor 继承自 Tensor
- 如果 `variable` 已经被设定为 `LODTensor`,则 `variable.Get<Tensor>()` 和 `variable.GetMutable<Tensor>()` 都返回原有 `LODTensor` 中的内容
如此,必须记录相关类型继承的关系。
下面是一个简单的方案:
- 建立一个全局的 map,通过 `PrefixHash` 记录不同类型
- base class 的 `PrefixHash` 会是所有 derived class 的 `PrefixHash` 的前缀
- Variable 在判定两种类型是否有继承关系时,获取两者的 `PrefixHash` 并判定是否是前缀关系便可
#include <cstddef>
#include <cstring>
#include <iostream>
#include <map>
#include <string>
#include <typeindex>
#include <typeinfo>
// Encodes a type's position in the registered inheritance tree as a byte
// string. A derived type's hash is its parent's hash followed by one extra
// byte (the index of that child under the parent), so every ancestor's hash
// is a strict prefix of each of its descendants' hashes.
struct PrefixHash {
  // Path from the root of the tree to this type, one byte per level.
  std::string hash;
  // Number of direct children handed out so far; used as the next child's
  // trailing byte.
  size_t num_children;

  // Returns true when this hash denotes a strict descendant of `other`,
  // i.e. `other.hash` is a proper prefix of `hash`. Two equal hashes are
  // NOT considered descendants of one another.
  bool IsDescendentOf(const PrefixHash &other) const {
    return hash.size() > other.hash.size() &&
           hash.compare(0, other.hash.size(), other.hash) == 0;
  }

  // Makes this hash a fresh direct child of `other`: copies the parent's
  // path, appends the parent's next free child index, and bumps that counter
  // on the parent.
  //
  // Only the low 8 bits of the child index are stored, so a parent with more
  // than 256 direct children would produce colliding hashes.
  void SetDescendentOf(PrefixHash &other) {
    hash = other.hash;
    hash.push_back(static_cast<char>(other.num_children++));
    // A freshly placed node has no children yet; this also gives the counter
    // a defined value when the object was default-initialized.
    num_children = 0;
  }
};
// Empty marker type that serves as the root of every registered type
// hierarchy.
struct BaseType {
};
struct TypeDescendentDeducer {
TypeDescendentDeducer() {
// insert base type
PrefixHash hash{"0", 0};
type_map[typeid(BaseType)] = hash;
}
static TypeDescendentDeducer &Global() {
static TypeDescendentDeducer x;
return x;
}
template <typename Father, typename Child> void Register() {
const auto &father_type = std::type_index(typeid(Father));
const auto &child_type = std::type_index(typeid(Child));
auto father_it = type_map.find(father_type);
// insert father record
if (father_it == type_map.end()) {
auto &base_type = type_map[typeid(BaseType)];
PrefixHash hash;
hash.SetDescendentOf(base_type);
type_map[std::type_index(father_type)] = hash;
}
PrefixHash child_hash;
child_hash.SetDescendentOf(type_map[father_type]);
type_map[child_type] = child_hash;
}
template <typename T> bool IsDescendentOf(const std::type_info &child) {
auto child_it = type_map.find(child);
auto father_it = type_map.find(typeid(T));
if (child_it == type_map.end() || father_it == type_map.end())
return false;
return child_it->second.IsDescendentOf(father_it->second);
}
std::map<std::type_index, PrefixHash> type_map;
};
// Registers CHILD as a descendant of FATHER in the global
// TypeDescendentDeducer during static initialization.
//
// The macro defines a helper struct whose constructor performs the
// registration, plus one namespace-scope instance of it. Both arguments must
// be plain identifiers because they are token-pasted into the generated
// names; the `_` separator keeps distinct (FATHER, CHILD) pairs from
// producing the same name. The generated names avoid leading and double
// underscores, which C++ reserves for the implementation.
//
// The expansion already ends with `;`, so no trailing semicolon is needed at
// the call site.
#define REGISTER_TYPE_DESCENDENCE(FATHER, CHILD)                   \
  struct TypeDescendenceRegistrar_##FATHER##_##CHILD {             \
    TypeDescendenceRegistrar_##FATHER##_##CHILD() {                \
      TypeDescendentDeducer::Global().Register<FATHER, CHILD>();   \
    }                                                              \
  };                                                               \
  TypeDescendenceRegistrar_##FATHER##_##CHILD                      \
      g_type_descendence_registrar_##FATHER##_##CHILD;
具体使用方法:
// Example hierarchy: LODTensor publicly derives from Tensor, and that
// relationship is recorded with the registration macro.
class Tensor {
};

class LODTensor : public Tensor {
};

REGISTER_TYPE_DESCENDENCE(Tensor, LODTensor)
// Sketch of a Variable that remembers the dynamic type of the value it holds.
class Variable {
public:
  // ...

  // Returns true when the held type is exactly T, or is registered in the
  // global TypeDescendentDeducer as a descendant of T.
  template <typename T> bool IsType() const {
    const std::type_index requested(typeid(T));
    return requested == type_ ||
           TypeDescendentDeducer::Global().IsDescendentOf<T>(type_);
  }

private:
  // Type of the value currently held.
  std::type_index type_;
};