dict_plugin.h 3.4 KB
Newer Older
X
xiexionghang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128
#pragma once

#include <list>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include <glog/logging.h>
#include "paddle/fluid/feed/src/common/bhopscotch_map.h"

namespace paddle {
namespace framework {
// Abstract interface implemented by every loadable dictionary plugin.
class DictPlugin {
public:
    DictPlugin() = default;
    virtual ~DictPlugin() = default;

    // Loads the dictionary from `path`; `converter` is passed through to the
    // concrete implementation untouched. The meaning of the returned int is
    // defined by each implementation (presumably 0 on success, matching the
    // other status codes in this header -- confirm against implementations).
    virtual int Load(const std::string& path, const std::string& converter) = 0;
};

// Key/value table that also remembers insertion order: `_dict_data` maps each
// key to its value and `_key_list` holds the keys in the order they were first
// appended.
template <class K, class V>
class KvEntity {
public:
    KvEntity() = default;
    ~KvEntity() = default;

    // Number of keys appended so far, narrowed to uint32_t.
    uint32_t Size() const {
        return static_cast<uint32_t>(_key_list.size());
    }

    // Stores (k, v) unless `k` is already present; in that case the call is a
    // no-op and the previously stored value is kept.
    void Append(const K& k, const V& v) {
        // emplace() reports whether it actually inserted, so a single hash
        // lookup both detects duplicates and stores new entries.
        if (_dict_data.emplace(k, v).second) {
            _key_list.push_back(k);
        }
    }

    std::vector<K> _key_list;
    tsl::bhopscotch_pg_map<K, V> _dict_data;
};

// Dictionary plugin that keeps one KvEntity per version slot. The constructor
// creates two slots (indices 0 and 1) so the table can be double-buffered.
template <class K, class V>
class KvDictPlugin : public DictPlugin {
public:
    KvDictPlugin() {
        versioned_entity_.resize(2);
    }
    virtual ~KvDictPlugin() {}

    // Looks `key` up in the slot selected by `version`.
    // CHECK-fails when `version` is not a valid slot index.
    // Returns 0 and writes the stored value to `v` on a hit; returns -1 and
    // leaves `v` untouched on a miss.
    virtual int GetValueWithVersion(uint32_t version, const K& key, V& v) {
        return LookupInSlot(version, key, v);
    }

    // Looks `key` up in the slot selected by the current version_. The version
    // that was used is written to `version` on both hit and miss.
    // Returns 0 and writes the stored value to `v` on a hit; returns -1 and
    // leaves `v` untouched on a miss.
    virtual int GetValue(const K& key, V& v, uint32_t& version) {
        // Read version_ once so the reported version is the one searched.
        version = version_;
        return LookupInSlot(version, key, v);
    }

    // Current value of version_, converted to int.
    virtual int GetVersion() {
        return version_;
    }
protected:
    uint32_t version_ = 0;
    // double-buffer support version:0 1
    std::vector<KvEntity<K, V>> versioned_entity_;
private:
    // Shared lookup used by both getters so they validate the slot index and
    // report hit/miss in exactly the same way.
    int LookupInSlot(uint32_t slot, const K& key, V& v) {
        CHECK(slot < versioned_entity_.size());
        auto& entity = versioned_entity_[slot];
        auto itr = entity._dict_data.find(key);
        if (itr == entity._dict_data.end()) {
            return -1;  // miss
        }
        v = itr->second;
        return 0;
    }
};

// Concrete dictionary with uint64_t keys and uint32_t values. Load() is only
// declared here; its definition is not part of this header.
class FeasignCacheDict : public KvDictPlugin<uint64_t, uint32_t> {
public:
    FeasignCacheDict() = default;
    ~FeasignCacheDict() override = default;
    int Load(const std::string& path, const std::string& converter) override;
};

class DictPluginManager {
public:
    DictPluginManager() {}
    virtual ~DictPluginManager(){}

    static DictPluginManager& Instance() {
        static DictPluginManager manager;
        return manager;
    }
    inline int CreateDict(const std::string& dict_name) {
        #define PADDLE_DICT_PLUGIN_REGIST(dict)                  \
        if (dict_name == #dict) {                                \
            dicts_map_[dict_name].reset(new dict());             \
            return 0;                                            \
        }
        
        PADDLE_DICT_PLUGIN_REGIST(FeasignCacheDict)
        #undef PADDLE_DICT_PLUGIN_REGIST
        return -1;
    }
    inline DictPlugin* GetDict(const std::string& dict_name) {
        if (dicts_map_.count(dict_name)) {
            return dicts_map_[dict_name].get();
        }
        return nullptr;
    }
    inline int LoadDict(const std::string& dict_name,
        const std::string& path, const std::string converter) {
        auto dict = GetDict(dict_name);
        if (!dict) {
            return -1;
        }
        return dict->Load(path, converter);
    }
private:
    std::unordered_map<std::string, std::shared_ptr<DictPlugin>> dicts_map_;
};

}  // namespace framework
}  // namespace paddle