rc4_cryption_impl.cpp 7.4 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
/**
 * \file src/decryption/rc4/rc4_cryption_impl.cpp
 *
 * This file is part of MegEngine, a deep learning framework developed by
 * Megvii.
 *
 * \copyright Copyright (c) 2020-2021 Megvii Inc. All rights reserved.
 */

#include "rc4_cryption_impl.h"
#include "../../misc.h"

#include <cstring>

using namespace lite;

/*!
 * \brief Read the input stream once in order to initialize the decryption
 *        state.
 */
void RC4Impl::init_rc4_state() {
    rc4::RC4RandStream enc_stream(m_enc_key);
    rc4::FastHash64 dechash(m_hash_key);

    size_t offset = 0;

    std::vector<uint64_t> buffer(128);
    size_t remaining = m_model_length - sizeof(uint64_t);
    while (remaining > 0) {
        size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
        memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
               toread);
        offset += toread;
        remaining -= toread;

        for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
            uint64_t value = buffer[i];
            value ^= enc_stream.next64();
            dechash.feed(value);
        }
    }

    uint64_t hashvalue;
    memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
           sizeof(hashvalue));
    offset += sizeof(hashvalue);

    hashvalue ^= dechash.get() ^ enc_stream.next64();
    m_state.hash_stream.reset(hashvalue);
    m_state.enc_stream.reset(m_enc_key);
}

/*!
 * \brief Produce the plaintext by XOR-ing every input byte with the next byte
 *        of both decryption streams prepared by init_rc4_state().
 *
 * \return a buffer of m_model_length bytes.
 */
std::vector<uint8_t> RC4Impl::decrypt_model() {
    const auto* cipher = static_cast<const uint8_t*>(m_model_mem);
    std::vector<uint8_t> plain(cipher, cipher + m_model_length);
    for (auto& byte : plain) {
        byte ^= m_state.hash_stream.next8() ^ m_state.enc_stream.next8();
    }
    return plain;
}

/*! \brief Encrypt the data in m_buffer.
 *
 * The basic idea is to calculate a 64-bit hash from the buffer and append
 * it to the end of the buffer. The basic requirement is that the change of
 * every byte including the hash value will destroy the whole model in every
 * byte.
 *
 * Encryption:
 *
 * 1. First calculate a 64-bit hash, called plain hash value, from the
 * buffer.
 * 2. Initialize a RC4 stream with the plain hash value.
 * 3. Obfuscate the model body with the RC4 stream defined in step 2.
 * 4. Calculate the hash value of the obfuscated model, called hash value
 *    after hashing.
 * 5. Encrypt the model body with a RC4 stream made from the encryption key.
 * 6. Bit-xor the hash value after hashing with the plain hash value, called
 *    mixed hash.
 * 7. Encrypt the mixed hash with the RC4 stream defined in step 5, called
 * the protected hash.
 * 8. Append the protected hash to the buffer.
 *
 * Decryption:
 * 1. Decrypt the model body with a RC4 stream made from the encryption key,
 *    which is the reverse of step 5 and 7 of encryption and get the mixed
 *    hash.
 * 2. Calculate the hash value of the decrypted model, which equals to the
 *    hash value after hashing in step 4 of encryption.
 * 3. Bit-xor the hash value after hashing and the mixed hash to get the
 * plain hash value, which is the reverse of step 6 of encryption.
 * 4. Un-obfuscate the model body with the plain hash value, which is the
 *    reverse of step 3 of encryption.
 *
 * Think:
 * 1. If any byte in the model body is broken, the hash value after hashing
 *    will be broken in step 2, and hence the plain hash value in step 3
 * will be also broken, and finally, the model body will be broken in
 * step 4.
 * 2. If the protected hash is broken, the plain hash value in step 3 will
 * be broken, and finally the model body will be broken.
 */
std::vector<uint8_t> RC4Impl::encrypt_model() {
    size_t total_length = (m_model_length + (sizeof(size_t) - 1)) /
                          sizeof(size_t) * sizeof(size_t);
    std::vector<uint8_t> pad_model(total_length, 0);
    memcpy(pad_model.data(), m_model_mem, m_model_length);

    // Calculate the hash of the model.
    rc4::FastHash64 plainhash(m_hash_key);
    uint64_t* ptr = reinterpret_cast<uint64_t*>(pad_model.data());
    size_t len = pad_model.size() / sizeof(uint64_t);

    for (size_t i = 0; i < len; ++i)
        plainhash.feed(ptr[i]);
    uint64_t plainhash_value = plainhash.get();

    // Encrypt the model.
    rc4::RC4RandStream hash_enc(plainhash_value);
    rc4::RC4RandStream outmost_enc(m_enc_key);
    rc4::FastHash64 afterhashenc_hash(m_hash_key);

    for (size_t i = 0; i < len; ++i) {
        uint64_t value = ptr[i] ^ hash_enc.next64();
        afterhashenc_hash.feed(value);
        ptr[i] = value ^ outmost_enc.next64();
    }

    uint64_t protected_hash =
            plainhash_value ^ afterhashenc_hash.get() ^ outmost_enc.next64();

    size_t end = pad_model.size();
    pad_model.resize(pad_model.size() + sizeof(uint64_t));
    ptr = reinterpret_cast<uint64_t*>(&pad_model[end]);
    *ptr = protected_hash;
    return pad_model;
}

/*!
 * \brief Read the input stream once in order to initialize the decryption
 *        state.
 */
void SimpleFastRC4Impl::init_sfrc4_state() {
    rc4::RC4RandStream enc_stream(m_enc_key);
    rc4::FastHash64 dechash(m_hash_key);

    size_t offset = 0;
    std::vector<uint64_t> buffer(128);
    size_t remaining = m_model_length - sizeof(uint64_t);
    while (remaining > 0) {
        size_t toread = std::min(remaining, buffer.size() * sizeof(uint64_t));
        memcpy(buffer.data(), static_cast<const uint8_t*>(m_model_mem) + offset,
               toread);
        offset += toread;
        remaining -= toread;

        for (size_t i = 0; i < toread / sizeof(uint64_t); ++i) {
            uint64_t value = buffer[i];
            dechash.feed(value);
        }
    }

    uint64_t hashvalue;
    memcpy(&hashvalue, static_cast<const uint8_t*>(m_model_mem) + offset,
           sizeof(hashvalue));

    offset += sizeof(hashvalue);

    /*! \brief test the hash_val. */
    if (hashvalue != dechash.get())
        LITE_THROW(
                "The checksum of the file cannot be verified. The file may "
                "be encrypted in the wrong algorithm or different keys.");

    m_state.hash_stream.reset(m_hash_key);
    m_state.enc_stream.reset(m_enc_key);
}

/*!
 * \brief Produce the plaintext by XOR-ing every input byte with the next byte
 *        of the key stream prepared by init_sfrc4_state().
 *
 * \return a buffer of m_model_length bytes.
 */
std::vector<uint8_t> SimpleFastRC4Impl::decrypt_model() {
    const auto* cipher = static_cast<const uint8_t*>(m_model_mem);
    std::vector<uint8_t> plain(cipher, cipher + m_model_length);
    for (auto& byte : plain)
        byte ^= m_state.enc_stream.next8();
    return plain;
}

std::vector<uint8_t> SimpleFastRC4Impl::encrypt_model() {
    size_t total_length = (m_model_length + (sizeof(size_t) - 1)) /
                          sizeof(size_t) * sizeof(size_t);
    std::vector<uint8_t> pad_model(total_length, 0);
    memcpy(pad_model.data(), m_model_mem, m_model_length);

    // Calculate the hash of the model.
    rc4::FastHash64 enchash(m_hash_key);
    uint64_t* ptr = reinterpret_cast<uint64_t*>(pad_model.data());
    size_t len = pad_model.size() / sizeof(uint64_t);

    // Encrypt the model.
    rc4::RC4RandStream out_enc(m_enc_key);
    for (size_t i = 0; i < len; ++i) {
        ptr[i] = ptr[i] ^ out_enc.next64();
        enchash.feed(ptr[i]);
    }

    uint64_t hash_value = enchash.get();

    size_t end = pad_model.size();
    pad_model.resize(pad_model.size() + sizeof(uint64_t));
    ptr = reinterpret_cast<uint64_t*>(&pad_model[end]);
    *ptr = hash_value;

    return pad_model;
}

// vim: syntax=cpp.doxygen foldmethod=marker foldmarker=f{{{,f}}}