aes_decrypt.h 1.8 KB
Newer Older
1 2
/**
 * \file src/decryption/aes_decrypt.h
 * MegEngine is Licensed under the Apache License, Version 2.0 (the "License")
 *
 * Copyright (c) 2014-2021 Megvii Inc. All rights reserved.
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 */

#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <stdexcept>
#include <vector>

#include "./mbedtls/aes.h"
#include "decrypt_base.h"

namespace lite {

class AESDcryption {
public:
M
Megvii Engine Team 已提交
19 20
    static std::vector<uint8_t> decrypt_model(
            const void* model_mem, size_t size, const std::vector<uint8_t>& key) {
21 22 23 24 25 26 27 28 29 30 31 32 33 34 35
        mbedtls_aes_context ctx;
        mbedtls_aes_init(&ctx);
        mbedtls_aes_setkey_dec(&ctx, key.data(), 256);

        auto data = static_cast<const uint8_t*>(model_mem);
        //! first 16 bytes is IV
        uint8_t iv[16];
        //! last 8 bytes is file size(length)
        auto length_ptr = data + size - 8;
        size_t length = 0;
        for (int i = 0; i < 8; i++) {
            length |= length_ptr[i] << (8 * (7 - i));
        }
        std::copy(data, data + 16, iv);
        auto output = std::vector<uint8_t>(size - 24);
M
Megvii Engine Team 已提交
36 37
        mbedtls_aes_crypt_cbc(
                &ctx, MBEDTLS_AES_DECRYPT, size - 24, iv, data + 16, output.data());
38 39 40 41 42 43
        mbedtls_aes_free(&ctx);
        output.erase(output.begin() + length, output.end());
        return output;
    }

    static std::vector<uint8_t> get_decrypt_key() {
44 45 46 47
        std::vector<uint8_t> key(32);
        key = {0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0A,
               0x0B, 0x0C, 0x0D, 0x0E, 0x0F, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15,
               0x16, 0x17, 0x18, 0x19, 0x1A, 0x1B, 0x1C, 0x1D, 0x1E, 0x1F};
48 49 50 51 52 53
        return key;
    }
};
}  // namespace lite

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}