Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • Issue
  • #3396

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 8月 10, 2017 by saxon_zh@saxon_zhGuest

LODTensor related Variable type-deduce with inherience

Created by: Superjomn

LODTensor 的逻辑需要在 Variable 里添加一个继承关系的推导,最终达到此目的:

  • 前提, LODTensor 继承自 Tensor
  • 如果 variable 已经被设定为 LODTensor , 则 variable.Get<Tensor>() 和 variable.GetMutate<Tensor>() 都返回原有 LODTensor 中的内容

如此,必须记录相关类型继承的关系。

下面是一个简单的方案:

  • 建立一个全局的map 通过 PrefixHash 记录不同类型
  • base class 的 PrefixHash 会是所有 derived class 的前缀
  • Variable 在判定两种类型是否有继承关系时,获取两者的 PrefixHash 并判定是否是前缀关系便可
#include <cstring>
#include <iostream>
#include <map>
#include <typeindex>
#include <typeinfo>

struct PrefixHash {
  std::string hash;
  size_t num_children;

  bool IsDescendentOf(const PrefixHash &other) {
    if (hash.size() < other.hash.size() &&
        std::memcmp(hash.data(), other.hash.data(), hash.size())) {
      return true;
    }
    return false;
  }

  void SetDescendentOf(PrefixHash &other) {
    hash = other.hash;
    hash.push_back((unsigned char)other.num_children++);
  }
};

// base of all types
struct BaseType {};

struct TypeDescendentDeducer {

  TypeDescendentDeducer() {
    // insert base type
    PrefixHash hash{"0", 0};
    type_map[typeid(BaseType)] = hash;
  }

  static TypeDescendentDeducer &Global() {
    static TypeDescendentDeducer x;
    return x;
  }

  template <typename Father, typename Child> void Register() {
    const auto &father_type = std::type_index(typeid(Father));
    const auto &child_type = std::type_index(typeid(Child));

    auto father_it = type_map.find(father_type);
    // insert father record
    if (father_it == type_map.end()) {
      auto &base_type = type_map[typeid(BaseType)];
      PrefixHash hash;
      hash.SetDescendentOf(base_type);
      type_map[std::type_index(father_type)] = hash;
    }

    PrefixHash child_hash;
    child_hash.SetDescendentOf(type_map[father_type]);
    type_map[child_type] = child_hash;
  }

  template <typename T> bool IsDescendentOf(const std::type_info &child) {
    auto child_it = type_map.find(child);
    auto father_it = type_map.find(typeid(T));
    if (child_it == type_map.end() || father_it == type_map.end())
      return false;
    return child_it->second.IsDescendentOf(father_it->second);
  }

  std::map<std::type_index, PrefixHash> type_map;
};

#define REGISTER_TYPE_DESCENDENCE(__father, __child)                           \
  struct __type_descendence_##__father##__child##__ {                          \
    __type_descendence_##__father##__child##__() {                             \
      TypeDescendentDeducer::Global().Register<__father, __child>();           \
    }                                                                          \
  };                                                                           \
  __type_descendence_##__father##__child##__                                   \
      __type_descendence_##__father##__child##___;

具体使用方法:

// register LODTensor as a derived class of Tensor
class Tensor {};
class LODTensor : public Tensor {};

REGISTER_TYPE_DESCENDENCE(Tensor, LODTensor)

class Variable {

public:
  // ...
  template <typename T> bool IsType() const {
    if (std::type_index(typeid(T)) == type_) {
      return true;
    }
    return TypeDescendentDeducer::Global().IsDescendentOf<T>(type_);
  }

private:
  std::type_index type_;
};
指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/Paddle#3396
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7