aes.cc 3.9 KB
Newer Older
J
jingqinghe 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "aes.h"

J
jhjiangcs 已提交
17
#ifdef USE_AES_NI
J
jingqinghe 已提交
18
#include <wmmintrin.h>
J
jhjiangcs 已提交
19
#endif
J
jingqinghe 已提交
20 21 22

namespace psi {

J
jhjiangcs 已提交
23
#ifdef USE_AES_NI
J
jingqinghe 已提交
24
// One step of the AES-128 key schedule.
//
// `key` is the previous round key and `key_rcon` is the value produced by
// _mm_aeskeygenassist_si128 on that key. How it works:
//   - The highest 32-bit lane of `key_rcon` (lane 3, chosen by
//     _MM_SHUFFLE(3, 3, 3, 3)) is broadcast to all four lanes.
//   - `key` is turned into a running prefix-XOR of its 32-bit words by three
//     "shift left by 4 bytes, then XOR" steps.
//   - XORing the two yields the next round key.
static block aes128_key_expansion(block key, block key_rcon) {
    const block broadcast_rcon =
        _mm_shuffle_epi32(key_rcon, _MM_SHUFFLE(3, 3, 3, 3));
    block prefix = key;
    for (int step = 0; step < 3; ++step) {
        prefix = _mm_xor_si128(prefix, _mm_slli_si128(prefix, 4));
    }
    return _mm_xor_si128(prefix, broadcast_rcon);
}

J
jhjiangcs 已提交
32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54
void AES::set_key(const block& user_key) {
    _round_key[0] = user_key;
    _round_key[1] = aes128_key_expansion(
        _round_key[0], _mm_aeskeygenassist_si128(_round_key[0], 0x01));
    _round_key[2] = aes128_key_expansion(
        _round_key[1], _mm_aeskeygenassist_si128(_round_key[1], 0x02));
    _round_key[3] = aes128_key_expansion(
        _round_key[2], _mm_aeskeygenassist_si128(_round_key[2], 0x04));
    _round_key[4] = aes128_key_expansion(
        _round_key[3], _mm_aeskeygenassist_si128(_round_key[3], 0x08));
    _round_key[5] = aes128_key_expansion(
        _round_key[4], _mm_aeskeygenassist_si128(_round_key[4], 0x10));
    _round_key[6] = aes128_key_expansion(
        _round_key[5], _mm_aeskeygenassist_si128(_round_key[5], 0x20));
    _round_key[7] = aes128_key_expansion(
        _round_key[6], _mm_aeskeygenassist_si128(_round_key[6], 0x40));
    _round_key[8] = aes128_key_expansion(
        _round_key[7], _mm_aeskeygenassist_si128(_round_key[7], 0x80));
    _round_key[9] = aes128_key_expansion(
        _round_key[8], _mm_aeskeygenassist_si128(_round_key[8], 0x1B));
    _round_key[10] = aes128_key_expansion(
        _round_key[9], _mm_aeskeygenassist_si128(_round_key[9], 0x36));
}
J
jingqinghe 已提交
55

J
jhjiangcs 已提交
56 57 58 59 60 61 62 63 64 65 66 67
void AES::ecb_enc_block(const block& plaintext, block& cyphertext) const {
    cyphertext = _mm_xor_si128(plaintext, _round_key[0]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[1]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[2]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[3]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[4]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[5]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[6]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[7]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[8]);
    cyphertext = _mm_aesenc_si128(cyphertext, _round_key[9]);
    cyphertext = _mm_aesenclast_si128(cyphertext, _round_key[10]);
J
jingqinghe 已提交
68 69
}

J
jhjiangcs 已提交
70 71 72 73 74 75
#else
// openssl aes
// Installs `user_key` as a 128-bit AES encryption key in the OpenSSL key
// schedule `_aes_key`.
// - The block is handed to OpenSSL as raw bytes, which relies on
//   sizeof(block) being 128 bits.
// - NOTE(review): the status returned by AES_set_encrypt_key is not checked.
void AES::set_key(const block& user_key) {
    const int key_bits = 128;
    const unsigned char* const key_bytes =
        reinterpret_cast<const unsigned char*>(&user_key);
    AES_set_encrypt_key(key_bytes, key_bits, &_aes_key);
}

J
jhjiangcs 已提交
78 79 80 81
// Encrypts a single block with the OpenSSL key schedule prepared by set_key.
// Both blocks are passed to AES_encrypt as raw byte buffers.
void AES::ecb_enc_block(const block& plaintext, block& cyphertext) const {
    const unsigned char* const in =
        reinterpret_cast<const unsigned char*>(&plaintext);
    unsigned char* const out = reinterpret_cast<unsigned char*>(&cyphertext);
    AES_encrypt(in, out, &_aes_key);
}
J
jhjiangcs 已提交
83
#endif
J
jingqinghe 已提交
84

J
jhjiangcs 已提交
85 86
void AES::ecb_enc_blocks(const block* plaintexts, size_t block_num,
                         block* cyphertext) const {
J
jingqinghe 已提交
87

J
jhjiangcs 已提交
88 89 90 91 92 93
#pragma omp parallel num_threads(4)
#pragma omp for
    for (size_t i = 0; i < block_num; ++i) {
        ecb_enc_block(plaintexts[i], cyphertext[i]);
    }
}
J
jingqinghe 已提交
94

J
jhjiangcs 已提交
95
// Constructs an encryptor keyed with `user_key` by running set_key on it.
AES::AES(const block& user_key) {
    this->set_key(user_key);
}
J
jingqinghe 已提交
96

J
jhjiangcs 已提交
97 98 99 100
// Convenience overload: encrypts one block and returns the result by value,
// rather than writing it through an output reference.
block AES::ecb_enc_block(const block& plaintext) const {
    block encrypted;
    this->ecb_enc_block(plaintext, encrypted);
    return encrypted;
}
J
jhjiangcs 已提交
102

J
jingqinghe 已提交
103
} // namespace psi