/***************************************************************************
 * 
 * Copyright (c) 2018 Baidu.com, Inc. All Rights Reserved
 * 
 **************************************************************************/
 
/**
 * @file include/factory.h
 * @author wanlijin01(wanlijin01@baidu.com)
 * @date 2018/07/10 22:09:57
 * @brief Tag-keyed factory registry (FactoryPool) and self-registration macros.
 *  
 **/

#ifndef BAIDU_PADDLE_SERVING_PREDICTOR_FACTORY_H
#define BAIDU_PADDLE_SERVING_PREDICTOR_FACTORY_H

#include "common/inner_common.h"
#include "glog/raw_logging.h"
namespace baidu {
namespace paddle_serving {
namespace predictor {

//////////////// DECLARE INTERFACE ////////////////
/**
 * DECLARE_FACTORY_OBJECT(D, B): expands to a static function
 * `int regist(const std::string& tag)` meant to live inside class D (D derives
 * from B). regist() allocates a FactoryDerive<D, B> and registers it in
 * FactoryPool<B> under `tag`.
 *
 * Returns 0 on success, -1 if allocation or registration fails. Failures are
 * logged with RAW_LOG_FATAL. If that call returns, the factory (which the
 * pool did not store) is deleted before returning -1.
 *
 * Names are fully qualified so the macro also works in classes declared
 * outside namespace baidu::paddle_serving::predictor.
 */
#define DECLARE_FACTORY_OBJECT(D, B)                                           \
    static int regist(const ::std::string& tag) {                              \
        ::baidu::paddle_serving::predictor::FactoryDerive<D, B>* factory =     \
            new (::std::nothrow)                                               \
                ::baidu::paddle_serving::predictor::FactoryDerive<D, B>();     \
        if (factory == NULL                                                    \
                || ::baidu::paddle_serving::predictor::FactoryPool<B>::        \
                       instance().register_factory(tag, factory) != 0) {       \
            RAW_LOG_FATAL("Failed regist factory: %s in macro!", #D);          \
            delete factory;                                                    \
            return -1;                                                         \
        }                                                                      \
        return 0;                                                              \
    }

// Token-pasting helper used to build unique function names.
// PDS_STR_CAT first macro-expands its arguments (e.g. __LINE__ -> 42) and then
// forwards them to PDS_STR_CAT_I, which performs the actual `##` paste. The
// indirection is required because operands of `##` are not expanded, so a
// direct paste would yield the literal token `GlobalRegistObject__LINE__`.
#define PDS_STR_CAT(a, b) PDS_STR_CAT_I(a, b)
#define PDS_STR_CAT_I(a, b) a ## b

/*
 * DEFINE_FACTORY_OBJECT(D): expands to a file-local (static) function carrying
 * the GCC/Clang `constructor` attribute, so it runs before main(). That
 * function calls D::regist() with the stringified class name, i.e. D is
 * registered under the tag "D". The value returned by regist() is discarded.
 * The generated function name embeds __LINE__, so expand this at most once
 * per source line.
 */
#define DEFINE_FACTORY_OBJECT(D)                                               \
    __attribute__((constructor))                                               \
    static void PDS_STR_CAT(GlobalRegistObject, __LINE__)(void) {              \
        (void) D::regist(#D);                                                  \
    }

//////////////// REGISTER INTERFACE ////////////////

/**
 * REGIST_FACTORY_OBJECT_IMPL(D, B): registers a FactoryDerive<D, B> in
 * FactoryPool<B> under the stringified name of D.
 *
 * Registration happens in a static function with the `constructor`
 * attribute, so it runs before main(). Expand it at namespace scope, at most
 * once per source line (the function name embeds __LINE__). Expanding it in
 * a header included by several translation units would register the same
 * tag more than once, which FactoryPool rejects.
 *
 * On failure the error is logged with RAW_LOG_FATAL. If that call returns,
 * the factory the pool did not store is deleted.
 */
#define REGIST_FACTORY_OBJECT_IMPL(D, B)                                       \
__attribute__((constructor)) static void PDS_STR_CAT(GlobalRegistObject, __LINE__)(void) \
{                                                                              \
    ::baidu::paddle_serving::predictor::FactoryDerive<D, B>* factory =         \
            new (::std::nothrow)                                               \
                ::baidu::paddle_serving::predictor::FactoryDerive<D, B>();     \
    if (factory == NULL                                                        \
            || ::baidu::paddle_serving::predictor::FactoryPool<B>::instance()  \
                   .register_factory(#D, factory) != 0) {                      \
        RAW_LOG_FATAL("Failed regist factory: %s->%s in macro!", #D, #B);      \
        delete factory;                                                        \
        return ;                                                               \
    }                                                                          \
}

/**
 * REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(D, B, N): same as
 * REGIST_FACTORY_OBJECT_IMPL, but registers FactoryDerive<D, B> in
 * FactoryPool<B> under the caller-supplied tag N instead of the stringified
 * class name.
 *
 * N is expanded several times (registration plus each log call) and is
 * passed to a "%s" format, so it should be a side-effect-free C-string
 * expression.
 *
 * On failure the error is logged with RAW_LOG_FATAL. If that call returns,
 * the unstored factory is deleted. Success is logged with RAW_LOG_WARNING.
 */
#define REGIST_FACTORY_OBJECT_IMPL_WITH_NAME(D, B, N)                          \
__attribute__((constructor)) static void PDS_STR_CAT(GlobalRegistObject, __LINE__)(void) \
{                                                                              \
    ::baidu::paddle_serving::predictor::FactoryDerive<D, B>* factory =         \
            new (::std::nothrow)                                               \
                ::baidu::paddle_serving::predictor::FactoryDerive<D, B>();     \
    if (factory == NULL                                                        \
            || ::baidu::paddle_serving::predictor::FactoryPool<B>::instance()  \
                   .register_factory(N, factory) != 0) {                       \
        RAW_LOG_FATAL("Failed regist factory: %s->%s, tag: %s in macro!", #D, #B, N); \
        delete factory;                                                        \
        return ;                                                               \
    }                                                                          \
    RAW_LOG_WARNING("Succ regist factory: %s->%s, tag: %s in macro!", #D, #B, N); \
}

/**
 * @brief Abstract factory interface producing and destroying objects of
 *        base type B.
 *
 * Factories are stored and used through FactoryBase<B>* (see FactoryPool).
 * The virtual destructor makes deleting a concrete factory through this
 * interface well-defined.
 */
template<typename B>
class FactoryBase {
public:
    virtual ~FactoryBase() {}

    /**
     * @brief Creates a new object.
     * @return The new object; implementations may return NULL on failure.
     */
    virtual B* gen() = 0;

    /**
     * @brief Destroys an object previously obtained from gen().
     * @param obj Object to destroy.
     */
    virtual void del(B* obj) = 0;
};

/**
 * @brief Concrete factory that creates objects of type D (derived from B)
 *        and hands them out as B*.
 */
template<typename D, typename B>
class FactoryDerive : public FactoryBase<B> {
public:
    /**
     * @brief Allocates a value-initialized D with non-throwing new.
     * @return The new object, or NULL if allocation failed.
     */
    B* gen() {
        D* created = new (std::nothrow) D();
        return created;
    }

    /**
     * @brief Destroys `obj` as a D.
     *
     * The pointer is downcast with dynamic_cast, so B must be polymorphic.
     * If `obj` is not actually a D, the cast yields NULL and nothing is
     * deleted.
     */
    void del(B* obj) {
        D* concrete = dynamic_cast<D*>(obj);
        delete concrete;
    }
};

template<typename B>
class FactoryPool {
public:
    static FactoryPool<B>& instance() {
        static FactoryPool<B> singleton;
        return singleton;
    }

    int register_factory(const std::string& tag,
            FactoryBase<B>* factory) {
        typename std::map<std::string, FactoryBase<B>*>::iterator it 
            = _pool.find(tag);
        if (it != _pool.end()) {
W
serving  
wangguibao 已提交
110
            RAW_LOG_FATAL("Insert duplicate with tag: %s", tag.c_str());
W
wangguibao 已提交
111 112 113 114 115 116 117
            return -1;
        }

        std::pair<
            typename std::map<std::string, FactoryBase<B>*>::iterator, 
            bool> r = _pool.insert(std::make_pair(tag, factory));
        if (!r.second) {
W
serving  
wangguibao 已提交
118
            RAW_LOG_FATAL("Failed insert new factory with: %s", tag.c_str());
W
wangguibao 已提交
119 120 121
            return -1;
        }

W
serving  
wangguibao 已提交
122
        RAW_LOG_INFO("Succ insert one factory, tag: %s, base type %s", tag.c_str(), typeid(B).name());
W
wangguibao 已提交
123 124 125 126 127 128 129 130

        return 0;
    }

    B* generate_object(const std::string& tag) {
        typename std::map<std::string, FactoryBase<B>*>::iterator it 
            = _pool.find(tag);
        if (it == _pool.end() || it->second == NULL) {
W
serving  
wangguibao 已提交
131
            RAW_LOG_FATAL("Not found factory pool, tag: %s, pool size %u", tag.c_str(), _pool.size());
W
wangguibao 已提交
132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154
            return NULL;
        }

        return it->second->gen();
    }

    template<typename D>
    void return_object(B* object) {
        FactoryDerive<D, B> factory;
        factory.del(object);
    }

private:
    std::map<std::string, FactoryBase<B>*> _pool;
};

} // predictor
} // paddle_serving
} // baidu

#endif  //BAIDU_PADDLE_SERVING_PREDICTOR_FACTORY_H

/* vim: set expandtab ts=4 sw=4 sts=4 tw=100: */