param_val.cpp 17.0 KB
Newer Older
1 2
#include "megbrain/common.h"

3 4 5 6
#if MGB_CUSTOM_OP

#include "megbrain/custom/param_val.h"

7 8 9 10 11 12 13 14 15 16
#pragma GCC diagnostic ignored "-Wsign-compare"

using namespace mgb;

namespace custom {

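/**
 * Macro callbacks: each one is handed to a CUSTOM_FOR_EACH_* iterator macro and
 * expands to one `case` label per parameter type.
 */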

M
Megvii Engine Team 已提交
17 18 19 20 21 22 23
/**
 * Case body: copy-constructs a new `static_type` from the object held by `rhs`
 * and installs it, together with the matching `impl_deleter`, into `m_ptr`.
 * `rhs` and `m_ptr` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS(dyn_type, static_type)          \
    case (ParamDynType::dyn_type): {                                       \
        const auto& src_val = TypedRef(static_type, rhs.m_ptr.get());      \
        std::unique_ptr<void, void_deleter> fresh_ptr(                     \
                new static_type(src_val), impl_deleter<static_type>);      \
        /* after the swap, fresh_ptr owns the previous payload and frees */ \
        /* it with the previous deleter when this scope ends             */ \
        m_ptr.swap(fresh_ptr);                                             \
        break;                                                             \
    }

M
Megvii Engine Team 已提交
26 27 28 29
/**
 * Case body: assigns the `static_type` value held by `rhs` to the object that
 * `m_ptr` already owns (both sides are reinterpreted as `static_type`).
 * `rhs` and `m_ptr` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS(dyn_type, static_type)         \
    case (ParamDynType::dyn_type): {                                       \
        const auto& src_val = TypedRef(static_type, rhs.m_ptr.get());      \
        auto& dst_val = TypedRef(static_type, m_ptr.get());                \
        dst_val = src_val;                                                 \
        break;                                                             \
    }

M
Megvii Engine Team 已提交
32 33 34 35
/**
 * Asserts that `operand` owns a payload (`m_ptr` is non-null) and that its
 * type tag is not ParamDynType::Invalid. The failure message contains the
 * stringified operand name and the stringified operator `opr`.
 */
#define CUSTOM_ASSERT_OPERAND_VALID(operand, opr)                                \
    mgb_assert(                                                                  \
            operand.m_ptr != nullptr && operand.m_type != ParamDynType::Invalid, \
            "invalid %s of operator %s of ParamVal", #operand, #opr)
36

M
Megvii Engine Team 已提交
37 38 39 40
/**
 * Asserts that `lhs` and `rhs` carry the same type tag. The failure message
 * names both types (looked up in `type2name`) and the stringified operator.
 * NOTE(review): when both tags are equal the assertion passes and execution
 * continues, so a call site that uses this macro as an unconditional
 * "unsupported expression" error falls through for equal types - confirm that
 * this is what each call site wants.
 */
#define CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op)                       \
    mgb_assert(                                                      \
            lhs.m_type == rhs.m_type, "`%s` %s `%s` is not allowed", \
            type2name[lhs.m_type].c_str(), #op, type2name[rhs.m_type].c_str())
41

M
Megvii Engine Team 已提交
42 43 44 45
/**
 * Inner case body of a binary operator: reads the right operand as
 * `static_type` and returns `lval op <right operand>`.
 * `lval` and `rhs` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL(dyn_type, static_type, op) \
    case (ParamDynType::dyn_type): {                                        \
        const auto& right_val = TypedRef(static_type, rhs.m_ptr.get());     \
        return lval op right_val;                                           \
    }

M
Megvii Engine Team 已提交
48 49 50 51 52 53 54 55 56 57
/**
 * Outer case body of a binary operator for a basic-typed left operand: binds
 * the left value and dispatches on the type tag of `rhs`; every basic rhs type
 * returns the result directly, any other rhs type goes through
 * CUSTOM_INVALID_EXPR_EXCP.
 */
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC(dyn_type, static_type, op)    \
    case (ParamDynType::dyn_type): {                                         \
        /* the name `lval` is consumed by the nested rhs callback below */   \
        const static_type& lval = TypedRef(static_type, lhs.m_ptr.get());    \
        switch (rhs.m_type) {                                                \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE_COPY(                            \
                    CUSTOM_CASE_TO_GET_BINARY_OP_RHS_AND_CAL, op)            \
            default:                                                         \
                CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);                      \
        }                                                                    \
        break;                                                               \
    }

M
Megvii Engine Team 已提交
60 61 62 63 64 65
/**
 * Case body of a binary operator for a non-basic left operand: first checks
 * (via CUSTOM_INVALID_EXPR_EXCP) that both operands carry the same type tag,
 * then returns `left op right` with both sides read as `static_type`.
 */
#define CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC(dyn_type, static_type, op)    \
    case (ParamDynType::dyn_type): {                                            \
        CUSTOM_INVALID_EXPR_EXCP(lhs, rhs, op);                                 \
        const static_type& left_val = TypedRef(static_type, lhs.m_ptr.get());   \
        const static_type& right_val = TypedRef(static_type, rhs.m_ptr.get());  \
        return left_val op right_val;                                           \
    }

M
Megvii Engine Team 已提交
68 69 70 71 72 73 74 75 76 77 78 79
/**
 * Defines `ret_type operator op(const ParamVal& lhs, const ParamVal& rhs)` for
 * operands holding basic types.
 *
 * Both operands are first checked with CUSTOM_ASSERT_OPERAND_VALID. A left
 * operand of any non-basic type is rejected unconditionally: the previous
 * default branch only asserted that the two type tags were equal, so an
 * unsupported expression with matching types (e.g. two strings) passed the
 * assertion and silently produced a value-initialized `ret_type`.
 */
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(op, ret_type)                    \
    ret_type operator op(const ParamVal& lhs, const ParamVal& rhs) {       \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op);                              \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op);                              \
                                                                           \
        switch (lhs.m_type) {                                              \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(                               \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op)            \
            default:                                                       \
                /* unsupported lhs type: fail even if rhs has same type */ \
                mgb_assert(                                                \
                        false, "`%s` %s `%s` is not allowed",              \
                        type2name[lhs.m_type].c_str(), #op,                \
                        type2name[rhs.m_type].c_str());                    \
        }                                                                  \
        return {};                                                         \
    }

M
Megvii Engine Team 已提交
82 83 84 85 86 87 88 89 90 91 92 93 94
/**
 * Defines `ret_type operator op(const ParamVal& lhs, const ParamVal& rhs)` for
 * operands holding basic types or strings.
 *
 * Both operands are first checked with CUSTOM_ASSERT_OPERAND_VALID. A left
 * operand of any other type is rejected unconditionally: the previous default
 * branch only asserted that the two type tags were equal, so an unsupported
 * expression with matching types (e.g. two lists) passed the assertion and
 * silently produced a value-initialized `ret_type`.
 */
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(op, ret_type)                     \
    ret_type operator op(const ParamVal& lhs, const ParamVal& rhs) {                   \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op);                                          \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op);                                          \
                                                                                       \
        switch (lhs.m_type) {                                                          \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(                                           \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op)                        \
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            default:                                                                   \
                /* unsupported lhs type: fail even if rhs has the same type */         \
                mgb_assert(                                                            \
                        false, "`%s` %s `%s` is not allowed",                          \
                        type2name[lhs.m_type].c_str(), #op,                            \
                        type2name[rhs.m_type].c_str());                                \
        }                                                                              \
        return {};                                                                     \
    }

M
Megvii Engine Team 已提交
97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
/**
 * Defines `ret_type operator op(const ParamVal& lhs, const ParamVal& rhs)` for
 * operands holding basic types, strings or lists.
 *
 * Both operands are first checked with CUSTOM_ASSERT_OPERAND_VALID. A left
 * operand whose type tag matches none of the listed cases is rejected
 * unconditionally instead of only when the two type tags differ, so such an
 * expression can never fall through to the value-initialized `ret_type`.
 */
#define CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(op, ret_type)            \
    ret_type operator op(const ParamVal& lhs, const ParamVal& rhs) {                   \
        CUSTOM_ASSERT_OPERAND_VALID(lhs, op);                                          \
        CUSTOM_ASSERT_OPERAND_VALID(rhs, op);                                          \
                                                                                       \
        switch (lhs.m_type) {                                                          \
            CUSTOM_FOR_EACH_BASIC_PARAMTYPE(                                           \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_BASIC, op)                        \
            CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op) \
            CUSTOM_FOR_EACH_LIST_PARAMTYPE(                                            \
                    CUSTOM_CASE_TO_CAL_BINARY_OP_FOR_NONBASIC, op)                     \
            default:                                                                   \
                /* unsupported lhs type: fail even if rhs has the same type */         \
                mgb_assert(                                                            \
                        false, "`%s` %s `%s` is not allowed",                          \
                        type2name[lhs.m_type].c_str(), #op,                            \
                        type2name[rhs.m_type].c_str());                                \
        }                                                                              \
        return {};                                                                     \
    }

M
Megvii Engine Team 已提交
114 115 116 117 118
/**
 * Case body for printing: streams the stored `static_type` value into `ss`.
 * The value is bound by const reference, so it is no longer copied (a whole
 * std::string for the string type) just to be printed.
 * `ss` and `m_ptr` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_PRINT_NONLIST(dyn_type, static_type)         \
    case (ParamDynType::dyn_type): {                                \
        const auto& rval = TypedRef(static_type, m_ptr.get());      \
        ss << rval;                                                 \
        break;                                                      \
    }

M
Megvii Engine Team 已提交
121 122 123 124 125
/**
 * Case body for printing: formats the stored `static_type` list with `vec2str`
 * and streams the result into `ss`. The stored list is handed to `vec2str`
 * directly, so it is no longer copied into a local first.
 * `ss` and `m_ptr` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_PRINT_LIST(dyn_type, static_type)        \
    case (ParamDynType::dyn_type): {                            \
        ss << vec2str(TypedRef(static_type, m_ptr.get()));      \
        break;                                                  \
    }

M
Megvii Engine Team 已提交
128 129 130 131
/**
 * Case body: returns `size()` of the stored `static_type` container.
 * `m_ptr` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_RET_SIZE(dyn_type, static_type)              \
    case (ParamDynType::dyn_type): {                                \
        const auto& stored = TypedRef(static_type, m_ptr.get());    \
        return stored.size();                                       \
    }

/**
 * Case body for serialization of a basic value. Resulting layout of `res`:
 *   [ParamDynType tag][raw bytes of the static_type value]
 * `res` and `value` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_DUMP_BASIC(dyn_type, static_type)               \
    case (ParamDynType::dyn_type): {                                   \
        constexpr size_t tag_bytes = sizeof(ParamDynType);             \
        constexpr size_t val_bytes = sizeof(static_type);              \
        res.resize(tag_bytes + val_bytes);                             \
        memcpy(&res[0], &(value.m_type), tag_bytes);                   \
        memcpy(&res[tag_bytes], value.m_ptr.get(), val_bytes);         \
        break;                                                         \
    }

M
Megvii Engine Team 已提交
142 143 144 145 146 147 148 149 150 151
/**
 * Case body for serialization of a contiguous container. Resulting layout of
 * `res`:
 *   [ParamDynType tag][size_t element count][raw bytes of the elements]
 * The element copy is skipped for an empty container: `data()` of an empty
 * vector may be a null pointer, and passing a null pointer to memcpy is
 * undefined even when the byte count is zero.
 * `res` and `value` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_DUMP_LIST(dyn_type, static_type)                    \
    case (ParamDynType::dyn_type): {                                       \
        auto& ref = TypedRef(static_type, value.m_ptr.get());              \
        const size_t len = ref.size();                                     \
        /* sizeof does not evaluate ref[0], so this is safe when empty */  \
        const size_t payload_bytes = len * sizeof(ref[0]);                 \
        const size_t header_bytes = sizeof(ParamDynType) + sizeof(len);    \
        res.resize(header_bytes + payload_bytes);                          \
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));            \
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));             \
        if (payload_bytes != 0) {                                          \
            memcpy(&res[header_bytes], ref.data(), payload_bytes);         \
        }                                                                  \
        break;                                                             \
    }

M
Megvii Engine Team 已提交
154 155 156 157 158 159 160
/**
 * Case body for deserialization of a basic value: copies sizeof(static_type)
 * bytes from `bytes` at `offset`, advances `offset` past them and returns the
 * value. Asserts beforehand that enough bytes are left, so a truncated buffer
 * is reported instead of being read out of bounds.
 * `bytes` and `offset` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_LOAD_BASIC(dyn_type, static_type)                    \
    case (ParamDynType::dyn_type): {                                        \
        mgb_assert(                                                         \
                offset <= bytes.size() &&                                   \
                        sizeof(static_type) <= bytes.size() - offset,       \
                "not enough bytes to load a basic ParamVal");               \
        static_type val;                                                    \
        memcpy(&val, &bytes[offset], sizeof(val));                          \
        offset += sizeof(val);                                              \
        return val;                                                         \
    }

M
Megvii Engine Team 已提交
163 164 165 166 167 168 169 170 171 172 173 174
/**
 * Case body for deserialization of a contiguous container stored as
 *   [size_t element count][raw bytes of the elements]
 * starting at `offset`; advances `offset` past everything it consumed and
 * returns the container.
 *
 * Both the count and the payload are bounds-checked against `bytes` before
 * they are read, which also keeps a corrupted count from triggering a huge
 * allocation. The element copy is skipped for an empty container because
 * taking `vals[0]` of an empty vector is undefined.
 * `bytes` and `offset` must be visible at the expansion site.
 */
#define CUSTOM_CASE_TO_LOAD_LIST(dyn_type, static_type)                          \
    case (ParamDynType::dyn_type): {                                             \
        size_t len = 0;                                                          \
        mgb_assert(                                                              \
                offset <= bytes.size() && sizeof(len) <= bytes.size() - offset,  \
                "not enough bytes to load the length of a list ParamVal");       \
        memcpy(&len, &bytes[offset], sizeof(len));                               \
        offset += sizeof(len);                                                   \
        static_type vals;                                                        \
        /* sizeof does not evaluate vals[0], so this is safe when empty */       \
        const size_t elem_size = sizeof(vals[0]);                                \
        mgb_assert(                                                              \
                len <= (bytes.size() - offset) / elem_size,                      \
                "not enough bytes to load the elements of a list ParamVal");     \
        vals.resize(len);                                                        \
        if (len != 0) {                                                          \
            memcpy(&vals[0], &bytes[offset], len* elem_size);                    \
        }                                                                        \
        offset += len * elem_size;                                               \
        return vals;                                                             \
    }

M
Megvii Engine Team 已提交
177
/**
 * Default constructor: no payload, a no-op deleter, and the type tag set to
 * ParamDynType::Invalid.
 */
ParamVal::ParamVal() : m_ptr(nullptr, [](void*) {}) {
    m_type = ParamDynType::Invalid;
}

M
Megvii Engine Team 已提交
181
ParamVal::ParamVal(const char* str) : ParamVal(std::string(str)) {}
182

M
Megvii Engine Team 已提交
183 184
/** Collects the C strings into a vector and delegates to the vector constructor. */
ParamVal::ParamVal(const std::initializer_list<const char*>& strs)
        : ParamVal(std::vector<const char*>(strs.begin(), strs.end())) {}
185

M
Megvii Engine Team 已提交
186 187 188
/**
 * Builds a StringList value: allocates an empty std::vector<std::string> and
 * appends one std::string per C string in `strs`, in order.
 */
ParamVal::ParamVal(const std::vector<const char*>& strs)
        : m_ptr(new std::vector<std::string>(),
                impl_deleter<std::vector<std::string>>) {
    m_type = ParamDynType::StringList;
    // resolve the typed view of the payload once instead of once per element
    auto& stored = TypedRef(std::vector<std::string>, m_ptr.get());
    for (const char* c_str : strs) {
        stored.emplace_back(c_str);
    }
}

M
Megvii Engine Team 已提交
195
/**
 * Copy constructor: asserts that `rhs` holds a valid value, takes over its
 * type tag and allocates a copy of its payload with the matching deleter.
 */
ParamVal::ParamVal(const ParamVal& rhs) : m_ptr(nullptr, [](void*) {}) {
    mgb_assert(
            rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
            "invalid rhs of copy constructor of ParamVal");

    m_type = rhs.m_type;
    switch (m_type) {
        // one case per valid type: copy-construct the payload into m_ptr
        CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS)
        default:
            mgb_assert(false, "invalid rhs of copy constructor of ParamVal");
    }
}

M
Megvii Engine Team 已提交
208
/** Assigns a C string by converting it to std::string and assigning that. */
ParamVal& ParamVal::operator=(const char* str) {
    *this = std::string(str);
    return *this;
}

M
Megvii Engine Team 已提交
213
/** Assigns a braced list of C strings by forwarding it as a vector of C strings. */
ParamVal& ParamVal::operator=(const std::initializer_list<const char*>& strs) {
    *this = std::vector<const char*>(strs.begin(), strs.end());
    return *this;
}

M
Megvii Engine Team 已提交
218
/**
 * Assigns a list of C strings: converts every entry to std::string, then
 * assigns the resulting std::vector<std::string>.
 */
ParamVal& ParamVal::operator=(const std::vector<const char*>& strs) {
    std::vector<std::string> owned_strs(strs.begin(), strs.end());
    *this = owned_strs;
    return *this;
}

M
Megvii Engine Team 已提交
227
/**
 * Copy assignment.
 *
 * Self-assignment is a no-op. `rhs` must hold a valid value (asserted). When
 * both sides already have the same type the payload is assigned in place;
 * otherwise a copy of the payload of `rhs` is allocated and swapped in.
 *
 * The type tag is updated only after the new payload has been installed, so if
 * the allocation or copy throws, `m_type` still describes the payload that
 * `m_ptr` actually owns.
 *
 * @return reference to *this
 */
ParamVal& ParamVal::operator=(const ParamVal& rhs) {
    if (&rhs == this)
        return *this;
    mgb_assert(
            rhs.m_type != ParamDynType::Invalid && rhs.m_ptr != nullptr,
            "invalid rhs of assignment operator of ParamVal");

    if (rhs.m_type == m_type) {
        switch (m_type) {
            CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ASSIGN_ACCORD_TO_RHS)
            default:
                mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
        }
    } else {
        // dispatch on the source type and commit the tag only on success
        switch (rhs.m_type) {
            CUSTOM_FOR_EACH_VALID_PARAMTYPE(CUSTOM_CASE_TO_ALLOC_ACCORD_TO_RHS)
            default:
                mgb_assert(false, "invalid rhs of assignment operator of ParamVal");
        }
        m_type = rhs.m_type;
    }
    return *this;
}

M
Megvii Engine Team 已提交
251
const void* ParamVal::raw_ptr(void) const {
252 253 254
    return m_ptr.get();
}

M
Megvii Engine Team 已提交
255
void* ParamVal::raw_ptr(void) {
256 257 258 259 260 261 262 263 264
    return m_ptr.get();
}

/** Returns the type tag of the stored value. */
ParamDynType ParamVal::type() const {
    return m_type;
}

std::string ParamVal::str() const {
    std::stringstream ss;
M
Megvii Engine Team 已提交
265 266
    ss << "type: " << type2name[m_type] << "\n"
       << "value: ";
267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285
    switch (m_type) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_PRINT_NONLIST)
        CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_PRINT_LIST)
        default:
            mgb_assert(false, "invalid data of assignment operator of ParamVal");
    }
    return ss.str();
}

/**
 * Returns `size()` of the stored string or list.
 *
 * Any other type tag triggers an assertion. The trailing return keeps every
 * control path of this non-void function returning a value, matching the
 * other switch-based functions in this file.
 */
size_t ParamVal::size(void) const {
    switch (m_type) {
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
        CUSTOM_FOR_EACH_LIST_PARAMTYPE(CUSTOM_CASE_TO_RET_SIZE)
        default:
            mgb_assert(false, "there is no size() for basic data types");
    }
    return 0;
}

M
Megvii Engine Team 已提交
286
std::string ParamVal::to_bytes(const ParamVal& value) {
287 288 289
    std::string res;
    // because the specialization of std::vector<bool>
    if (value.type() == ParamDynType::BoolList) {
M
Megvii Engine Team 已提交
290
        std::vector<bool>& ref = TypedRef(std::vector<bool>, value.m_ptr.get());
291 292
        size_t len = ref.size();
        size_t elem_size = sizeof(bool);
M
Megvii Engine Team 已提交
293
        res.resize(sizeof(ParamDynType) + sizeof(len) + len * elem_size);
294 295
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
M
Megvii Engine Team 已提交
296 297
        size_t startpos = sizeof(ParamDynType) + sizeof(len);
        for (size_t idx = 0; idx < len; idx++) {
298
            bool b = ref[idx];
M
Megvii Engine Team 已提交
299
            memcpy(&res[startpos + idx * sizeof(b)], &b, sizeof(b));
300 301
        }
        return res;
M
Megvii Engine Team 已提交
302 303 304
    } else if (value.type() == ParamDynType::StringList) {
        std::vector<std::string>& ref =
                TypedRef(std::vector<std::string>, value.m_ptr.get());
305 306 307 308
        size_t len = ref.size();
        res.resize(sizeof(ParamDynType) + sizeof(len));
        memcpy(&res[0], &(value.m_type), sizeof(ParamDynType));
        memcpy(&res[sizeof(ParamDynType)], &len, sizeof(len));
M
Megvii Engine Team 已提交
309
        for (size_t idx = 0; idx < ref.size(); ++idx) {
310 311 312 313 314 315 316 317
            size_t str_len = ref[idx].size();
            std::string bytes(sizeof(str_len) + str_len, ' ');
            memcpy(&bytes[0], &str_len, sizeof(str_len));
            memcpy(&bytes[sizeof(str_len)], ref[idx].data(), str_len);
            res += bytes;
        }
        return res;
    }
M
Megvii Engine Team 已提交
318
    switch (value.type()) {
319 320 321 322 323 324 325 326 327
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_DUMP_BASIC)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
        CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_DUMP_LIST)
        default:
            mgb_assert(false, "invalid param type");
    }
    return res;
}

M
Megvii Engine Team 已提交
328
/**
 * Deserializes one value from `bytes`, starting at `offset`.
 *
 * The buffer is expected to hold a type tag followed by a payload: raw value
 * bytes for basic types, a size_t count plus raw elements for strings and
 * basic lists, a size_t count plus one bool per element for BoolList, and a
 * size_t count plus length-prefixed strings for StringList.
 *
 * Every read in this function is preceded by a check that the buffer still
 * holds enough bytes, so a truncated or corrupted buffer triggers an assertion
 * instead of an out-of-bounds read or an oversized allocation.
 *
 * @param bytes   serialized data
 * @param offset  in/out: position to start reading at; advanced past the
 *                bytes that were consumed
 * @return the decoded value
 */
ParamVal ParamVal::from_bytes(const std::string& bytes, size_t& offset) {
    // number of unread bytes; 0 when offset already points past the end
    auto remaining = [&bytes, &offset]() -> size_t {
        return offset <= bytes.size() ? bytes.size() - offset : 0;
    };

    ParamDynType data_type = ParamDynType::Invalid;
    mgb_assert(
            sizeof(ParamDynType) <= remaining(),
            "not enough bytes to load the type of ParamVal");
    memcpy(&data_type, &bytes[offset], sizeof(ParamDynType));
    offset += sizeof(ParamDynType);

    if (data_type == ParamDynType::BoolList) {
        size_t len = 0;
        mgb_assert(
                sizeof(len) <= remaining(),
                "not enough bytes to load the length of a bool list");
        memcpy(&len, &bytes[offset], sizeof(len));
        offset += sizeof(len);
        mgb_assert(
                len <= remaining() / sizeof(bool),
                "not enough bytes to load the elements of a bool list");

        std::vector<bool> ret;
        ret.reserve(len);
        for (size_t idx = 0; idx < len; ++idx) {
            // any non-zero byte counts as true; a bool is never materialised
            // directly from arbitrary stored bits
            bool b = false;
            for (size_t k = 0; k < sizeof(bool); ++k) {
                b = b || bytes[offset + k] != 0;
            }
            offset += sizeof(bool);
            ret.push_back(b);
        }
        return ret;
    } else if (data_type == ParamDynType::StringList) {
        size_t len = 0;
        mgb_assert(
                sizeof(len) <= remaining(),
                "not enough bytes to load the length of a string list");
        memcpy(&len, &bytes[offset], sizeof(len));
        offset += sizeof(len);
        // every string occupies at least its own length prefix
        mgb_assert(
                len <= remaining() / sizeof(size_t),
                "not enough bytes to load the elements of a string list");

        std::vector<std::string> ret;
        ret.reserve(len);
        for (size_t idx = 0; idx < len; ++idx) {
            size_t str_len = 0;
            mgb_assert(
                    sizeof(str_len) <= remaining(),
                    "not enough bytes to load the length of a string");
            memcpy(&str_len, &bytes[offset], sizeof(str_len));
            offset += sizeof(str_len);
            mgb_assert(
                    str_len <= remaining(),
                    "not enough bytes to load the content of a string");
            // construct the string in place from the raw characters
            ret.emplace_back(bytes.data() + offset, str_len);
            offset += str_len;
        }
        return ret;
    }

    switch (data_type) {
        CUSTOM_FOR_EACH_BASIC_PARAMTYPE(CUSTOM_CASE_TO_LOAD_BASIC)
        CUSTOM_FOR_STRING_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST)
        CUSTOM_FOR_EACH_BASIC_LIST_PARAMTYPE(CUSTOM_CASE_TO_LOAD_LIST)
        default:
            mgb_assert(false, "invalid param type");
    }
    return {};
}

// Operator definitions generated from the macros above:
//   +            : basic types and strings, result is a ParamVal
//   -, *, /      : basic types only, result is a ParamVal
//   comparisons  : basic types, strings and lists, result is a bool
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING(+, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(-, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(*, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC(/, ParamVal)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(==, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(!=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<=, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(>, bool)
CUSTOM_DEFINE_BINARY_OP_FOR_BASIC_AND_STRING_AND_LIST(<, bool)

M
Megvii Engine Team 已提交
382
}  // namespace custom
383 384

#endif