// File: op_repository.h (1.5 KB)
// Committed by wangguibao.
#ifndef BAIDU_PADDLE_SERVING_PREDICTOR_OP_REPOSITORY_H
#define BAIDU_PADDLE_SERVING_PREDICTOR_OP_REPOSITORY_H

#include "common/inner_common.h"

namespace baidu {
namespace paddle_serving {
namespace predictor {

// Registers the Op subclass OP_CLASS with the process-wide OpRepository,
// using the stringified macro argument as the op type name, e.g.
// REGISTER_OP(FooOp) registers the name "FooOp".
// The expansion is a call to OpRepository::regist_op, which returns void,
// so the macro must be used as a statement inside a function body.
#define REGISTER_OP(OP_CLASS) \
    ::baidu::paddle_serving::predictor::OpRepository::instance() \
        .regist_op<OP_CLASS>(#OP_CLASS)

class Op;

// Abstract source of Op instances for one concrete Op type.
// OpFactory<T> in this header is the implementation used by OpRepository.
class Factory {
public:
    // Virtual so that destroying an implementation through a Factory*
    // runs the derived destructor instead of being undefined behavior.
    virtual ~Factory() {}

    // Hands out an Op instance; how it is obtained is up to the implementation.
    virtual Op* get_op() = 0;

    // Gives back an Op that was previously obtained from get_op().
    virtual void return_op(Op* op) = 0;
};

template<typename OP_TYPE>
class OpFactory : public Factory {
public:
    Op* get_op() {
        return base::get_object<OP_TYPE>();
    }

    void return_op(Op* op) {
        base::return_object<OP_TYPE>(dynamic_cast<OP_TYPE*>(op));
    }

    static OpFactory<OP_TYPE>& instance() {
        static OpFactory<OP_TYPE> ins; 
        return ins;
    }
};

class OpRepository {
public:
    typedef boost::unordered_map<std::string, Factory*> ManagerMap;

    OpRepository() {}
    ~OpRepository() {}

    static OpRepository& instance() {
        static OpRepository repo;
        return repo;
    }

    template<typename OP_TYPE>
    void regist_op(std::string op_type) {
        _repository[op_type] = &OpFactory<OP_TYPE>::instance();
        LOG(TRACE) << "Succ regist op: " << op_type << "!";
    }

    Op* get_op(std::string op_type);

    void return_op(Op* op);

    void return_op(const std::string& op_type, Op* op);

private:
    ManagerMap _repository;
};

} // predictor
} // paddle_serving
} // baidu

#endif