// manager.cpp
#include "megbrain_build_config.h"

#if MGB_CUSTOM_OP

#include "gtest/gtest.h"
#include "megbrain/custom/custom.h"
#include "megbrain/custom/manager.h"

// Set to 1 to print the registered op names (and, in TestOpReg, each op's
// str() description) to std::cout while the tests run; 0 compiles the
// logging loops out entirely.
#define MANAGER_TEST_LOG 0

namespace custom {

// Exercises CustomOpManager end to end: register two ops, look each op up by
// name and by runtime id and check that both lookups agree, then erase the
// ops and check that the registry shrinks back to its starting size.
//
// All counts are relative to the ops already registered when the test starts
// (builtin_op_num), because CustomOpManager::inst() may hand back an instance
// that already holds ops.
TEST(TestOpManager, TestOpManager) {
    CustomOpManager* com = CustomOpManager::inst();

    const size_t builtin_op_num = com->op_name_list().size();

    // gtest's ASSERT_* macros return from the test body on failure, which
    // would skip the erase() calls at the end and leave "Op1"/"Op2" in the
    // manager for whatever runs next. This guard erases them on every exit
    // path. erase() returns false for a name that is not registered (see the
    // "Op0" check below), so the extra calls after a successful run only
    // report that the ops are already gone.
    struct InsertedOpsEraser {
        CustomOpManager* manager;
        ~InsertedOpsEraser() {
            (void)manager->erase("Op1");
            (void)manager->erase("Op2");
        }
    } inserted_ops_eraser{com};

    com->insert("Op1", CUSTOM_OP_VERSION);
    com->insert("Op2", CUSTOM_OP_VERSION);

    const std::vector<std::string> op_names = com->op_name_list();
    const std::vector<RunTimeId> op_ids = com->op_id_list();

    ASSERT_EQ(op_names.size(), builtin_op_num + 2);
    ASSERT_EQ(op_ids.size(), builtin_op_num + 2);

#if MANAGER_TEST_LOG
    for (const std::string& name : op_names) {
        std::cout << name << std::endl;
    }
#endif

    // name -> op -> runtime id -> op must land on the very same op object.
    for (const std::string& name : op_names) {
        std::shared_ptr<const CustomOp> op = com->find(name);
        ASSERT_NE(op, nullptr);
        ASSERT_EQ(op->op_type(), name);
        RunTimeId id = com->to_id(name);
        ASSERT_EQ(com->find(id), op);
    }

    // runtime id -> op -> name -> op must land on the very same op object.
    for (const RunTimeId& id : op_ids) {
        std::shared_ptr<const CustomOp> op = com->find(id);
        ASSERT_NE(op, nullptr);
        ASSERT_EQ(op->runtime_id(), id);
        std::string name = com->to_name(id);
        ASSERT_EQ(com->find(name), op);
    }

    // Erasing a name that was never inserted must report failure.
    ASSERT_FALSE(com->erase("Op0"));
#if MANAGER_TEST_LOG
    for (const auto& name : com->op_name_list()) {
        std::cout << name << std::endl;
    }
#endif
    ASSERT_TRUE(com->erase("Op1"));
    ASSERT_EQ(com->op_id_list().size(), builtin_op_num + 1);
    ASSERT_EQ(com->op_name_list().size(), builtin_op_num + 1);
    ASSERT_TRUE(com->erase("Op2"));
    ASSERT_EQ(com->op_id_list().size(), builtin_op_num);
    ASSERT_EQ(com->op_name_list().size(), builtin_op_num);
}

// Registers two ops through the CUSTOM_OP_REG builder chain. The chain covers
// input/output counts (add_inputs/add_outputs), named inputs and outputs, and
// params with int, double, C-string and bool values, plus one param with an
// empty name. The test contains no assertions: it passes as long as the
// registration chains run without failing.
//
// NOTE(review): nothing here erases Op1/Op2 afterwards. If CUSTOM_OP_REG
// registers into CustomOpManager::inst() (presumably so, since the logging
// loop below reads from it), the ops stay registered. They could then collide
// with TestOpManager's own "Op1"/"Op2" when the tests run in the other order
// (e.g. --gtest_shuffle). Confirm against the CUSTOM_OP_REG definition.
TEST(TestOpManager, TestOpReg) {
    CUSTOM_OP_REG(Op1)
            .add_inputs(2)
            .add_outputs(3)
            .add_input("lhs")
            .add_param("param1", 1)
            .add_param("param2", 3.45);

    CUSTOM_OP_REG(Op2)
            .add_input("lhs")
            .add_input("rhs")
            .add_output("out")
            .add_param("param1", "test")
            .add_param("param2", true)
            .add_param("", "no name");

    // CUSTOM_OP_REG(Name) presumably declares a variable named _Name; these
    // casts keep the compiler from warning that it is otherwise unused.
    (void)_Op1;
    (void)_Op2;

#if MANAGER_TEST_LOG
    for (const auto& name : CustomOpManager::inst()->op_name_list()) {
        std::cout << CustomOpManager::inst()->find(name)->str() << std::endl;
    }
#endif
}

}  // namespace custom

#endif