From 7150cc79651c9dbffc368aa95d826909a0186ca5 Mon Sep 17 00:00:00 2001 From: yangqingyou Date: Fri, 4 Sep 2020 03:53:38 +0000 Subject: [PATCH] remove redundancy of crypto module --- CMakeLists.txt | 2 +- core/privc/CMakeLists.txt | 1 - core/privc/crypto.cc | 258 --------------------------------- core/privc/crypto.h | 93 ------------ core/privc/crypto_test.cc | 58 -------- core/privc/triplet_generator.h | 3 +- 6 files changed, 3 insertions(+), 412 deletions(-) delete mode 100644 core/privc/crypto.cc diff --git a/CMakeLists.txt b/CMakeLists.txt index 7611386..0e9f3a7 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -134,4 +134,4 @@ install(TARGETS paddle_enc mpc_data_utils if (WITH_PSI) install(TARGETS psi LIBRARY DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH}) -endif() \ No newline at end of file +endif() diff --git a/core/privc/CMakeLists.txt b/core/privc/CMakeLists.txt index e641eb2..b1d10f1 100644 --- a/core/privc/CMakeLists.txt +++ b/core/privc/CMakeLists.txt @@ -1,5 +1,4 @@ set(PRIVC_SRCS - "crypto.cc" "privc_context.cc" ) diff --git a/core/privc/crypto.cc b/core/privc/crypto.cc deleted file mode 100644 index 477beaf..0000000 --- a/core/privc/crypto.cc +++ /dev/null @@ -1,258 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#include "crypto.h" - -#include -#include - -#include "glog/logging.h" - -namespace psi { - -u8 *hash(const void *d, u64 n, void *md) { - return SHA1(reinterpret_cast(d), n, reinterpret_cast(md)); -} - -int encrypt(const unsigned char *plaintext, int plaintext_len, - const unsigned char *key, const unsigned char *iv, - unsigned char *ciphertext) { - EVP_CIPHER_CTX *ctx = NULL; - int len = 0; - int aes_ciphertext_len = 0; - int ret = 0; - - memcpy(ciphertext, iv, GCM_IV_LEN); - - unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN; - unsigned char *tag = ciphertext + GCM_IV_LEN + plaintext_len; - - ctx = EVP_CIPHER_CTX_new(); - if (ctx == NULL) { - LOG(ERROR) << "openssl error"; - return 0; - } - - ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return 0; - } - - ret = EVP_EncryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return 0; - } - - ret = - EVP_EncryptUpdate(ctx, aes_ciphertext, &len, plaintext, plaintext_len); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return 0; - } - aes_ciphertext_len = len; - - ret = EVP_EncryptFinal_ex(ctx, aes_ciphertext + len, &len); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return 0; - } - aes_ciphertext_len += len; - - if (aes_ciphertext_len != plaintext_len) { - LOG(ERROR) << "encrypt error: ciphertext len mismatched"; - return 0; - } - - ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, GCM_TAG_LEN, tag); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return 0; - } - - EVP_CIPHER_CTX_free(ctx); - - return aes_ciphertext_len + GCM_IV_LEN + GCM_TAG_LEN; -} - -int decrypt(const unsigned char *ciphertext, int ciphertext_len, - const unsigned char *key, unsigned char *plaintext) { - EVP_CIPHER_CTX *ctx = NULL; - - int len = 0; - int plaintext_len = 0; - int ret = 0; - - const unsigned char *iv = ciphertext; - const unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN; - - int aes_ciphertext_len = ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN; - - unsigned char tag[GCM_TAG_LEN]; - - memcpy(tag, ciphertext + ciphertext_len - GCM_TAG_LEN, GCM_TAG_LEN); - - ctx = EVP_CIPHER_CTX_new(); - if (ctx == NULL) { - LOG(ERROR) << "openssl error"; - return -1; - } - - ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return -1; - } - - ret = EVP_DecryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return -1; - } - - ret = EVP_DecryptUpdate(ctx, plaintext, &len, aes_ciphertext, - aes_ciphertext_len); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return -1; - } - plaintext_len = len; - - if (plaintext_len != ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN) { - LOG(ERROR) << "openssl error"; - return -1; - } - - ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_LEN, tag); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - return -1; - } - - ret = EVP_DecryptFinal_ex(ctx, plaintext + len, &len); - - EVP_CIPHER_CTX_free(ctx); - - if (ret > 0) { - plaintext_len += len; - return plaintext_len; - } else { - return -1; - } -} - -ECDH::ECDH() { - _error = false; - int ret = 0; - - _group = EC_GROUP_new_by_curve_name(CURVE_ID); - if (_group == NULL) { - LOG(ERROR) << "openssl error"; - _error = true; - return; - } - - _key = EC_KEY_new(); - if (_key == NULL) { - LOG(ERROR) << "openssl error"; - _error = true; - return; - } - - ret = EC_KEY_set_group(_key, _group); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - _error = true; - return; - } - - _remote_key = EC_POINT_new(_group); - if (_remote_key == NULL) { - LOG(ERROR) << "openssl error"; - _error = true; - return; - } -} - -ECDH::~ECDH() { - EC_POINT_free(_remote_key); - EC_KEY_free(_key); - EC_GROUP_free(_group); -} - -std::array ECDH::generate_key() { - int ret = 0; - std::array output; - - if (_error) { - LOG(ERROR) << "internal error"; - return output; - } - - ret = EC_KEY_generate_key(_key); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - _error = true; - return output; - } - - const EC_POINT *key_point = EC_KEY_get0_public_key(_key); - if (key_point == NULL) { - LOG(ERROR) << "openssl error"; - _error = true; - return output; - } - - - ret = EC_POINT_point2oct(_group, key_point, POINT_CONVERSION_COMPRESSED, - output.data(), POINT_BUFFER_LEN, NULL); - if (ret == 0) { - LOG(ERROR) << "openssl error"; - _error = true; - return output; - } - - return output; -} - -std::array -ECDH::get_shared_secret(const std::array &remote_key) { - int ret = 0; - std::array secret; - - ret = EC_POINT_oct2point(_group, _remote_key, remote_key.data(), - remote_key.size(), NULL); - if (ret != 1) { - LOG(ERROR) << "openssl error"; - _error = true; - return secret; - } - - - int secret_len = POINT_BUFFER_LEN - 1; - // compressed flag not included in secret, see - // http://www.secg.org/sec1-v2.pdf chapter 2.2.3 - - ret = ECDH_compute_key(secret.data(), secret_len, _remote_key, _key, NULL); - - if (ret <= 0) { - LOG(ERROR) << "openssl error"; - _error = true; - return secret; - } - - return secret; -} -} // namespace psi diff --git a/core/privc/crypto.h b/core/privc/crypto.h index 08f6d20..8b21042 100644 --- a/core/privc/crypto.h +++ b/core/privc/crypto.h @@ -31,21 +31,6 @@ typedef unsigned long long u64; const block ZeroBlock = _mm_set_epi64x(0, 0); const block OneBlock = _mm_set_epi64x(-1, -1); -const int CURVE_ID = NID_secp160k1; - -const int POINT_BUFFER_LEN = 21; -// only apply for 160 bit curve -// specification about point buf len, see http://www.secg.org/sec1-v2.pdf -// chapter 2.2.3 - -const int HASH_DIGEST_LEN = SHA_DIGEST_LENGTH; - -const int GCM_IV_LEN = 12; - -const int GCM_TAG_LEN = 16; - -u8 *hash(const void *d, u64 n, void *md); - static block double_block(block bl); static inline block hash_block(const block& x, const block& i = ZeroBlock) { @@ -63,84 +48,6 @@ static inline std::pair hash_blocks(const std::pair& return {c[0] ^ k[0], c[1] ^ k[1]}; } -template -static inline block to_block(const T& val) { - block ret = ZeroBlock; - std::memcpy(&ret, &val, std::min(sizeof ret, sizeof val)); - return ret; -} - -// ciphertext = iv || aes_ciphertext || gcm_tag -// allocate buffer before call -int encrypt(const unsigned char *plaintext, int plaintext_len, - const unsigned char *key, const unsigned char *iv, - unsigned char *ciphertext); - -int decrypt(const unsigned char *ciphertext, int ciphertext_len, - const unsigned char *key, unsigned char *plaintext); - -class ECDH { -private: - EC_GROUP *_group; - EC_KEY *_key; - EC_POINT *_remote_key; - bool _error; - -public: - ECDH(); - ~ECDH(); - - inline bool error() {return _error;} - - ECDH(const ECDH &other) = delete; - ECDH operator=(const ECDH &other) = delete; - - std::array generate_key(); - - std::array - get_shared_secret(const std::array &remote_key); -}; - -/* - This file is part of JustGarble. - JustGarble is free software: you can redistribute it and/or modify - it under the terms of the GNU General Public License as published by - the Free Software Foundation, either version 3 of the License, or - (at your option) any later version. - JustGarble is distributed in the hope that it will be useful, - but WITHOUT ANY WARRANTY; without even the implied warranty of - MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the - GNU General Public License for more details. - You should have received a copy of the GNU General Public License - along with JustGarble. If not, see . - */ - - -/*------------------------------------------------------------------------ - / OCB Version 3 Reference Code (Optimized C) Last modified 08-SEP-2012 - /------------------------------------------------------------------------- - / Copyright (c) 2012 Ted Krovetz. - / - / Permission to use, copy, modify, and/or distribute this software for any - / purpose with or without fee is hereby granted, provided that the above - / copyright notice and this permission notice appear in all copies. - / - / THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES - / WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF - / MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR - / ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES - / WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN - / ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF - / OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE. - / - / Phillip Rogaway holds patents relevant to OCB. See the following for - / his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm - / - / Special thanks to Keegan McAllister for suggesting several good improvements - / - / Comments are welcome: Ted Krovetz - Dedicated to Laurel K - /------------------------------------------------------------------------- */ - static inline block double_block(block bl) { const __m128i mask = _mm_set_epi32(135,1,1,1); __m128i tmp = _mm_srai_epi32(bl, 31); diff --git a/core/privc/crypto_test.cc b/core/privc/crypto_test.cc index d76f0fb..706360b 100644 --- a/core/privc/crypto_test.cc +++ b/core/privc/crypto_test.cc @@ -21,64 +21,6 @@ namespace psi { -TEST(crypto, hash) { - std::string input = "abc"; - char output[HASH_DIGEST_LEN + 1]; - - output[HASH_DIGEST_LEN] = '\0'; - - const char *standard_vec = - "\xa9\x99\x3e\x36\x47\x06\x81\x6a\xba\x3e" - "\x25\x71\x78\x50\xc2\x6c\x9c\xd0\xd8\x9d"; - - hash(input.data(), input.size(), output); - - EXPECT_STREQ(standard_vec, output); - -} - -TEST(crypto, enc) { - std::string input = "abc"; - std::string iv = "0123456789ab"; - std::string key = "0123456789abcdef"; // aes_128_gcm, key_len = 128bit - - unsigned int cipher_len = GCM_IV_LEN + GCM_TAG_LEN + input.size(); - auto *output = new unsigned char [cipher_len]; - - int enc_ret = encrypt((unsigned char *)input.data(), input.size(), - (unsigned char *)key.data(), - (unsigned char *)iv.data(), output); - - ASSERT_EQ(cipher_len, (size_t)enc_ret); - - char *plaintext = new char [input.size() + 1]; - plaintext[input.size()] = '\0'; - int dec_ret = decrypt(output, enc_ret, (unsigned char *)key.data(), - (unsigned char *)plaintext); - - ASSERT_EQ(input.size(), (size_t)dec_ret); - - EXPECT_STREQ(input.c_str(), plaintext); - - delete output; - delete plaintext; -} - - -TEST(crypto, ecdh) { - ECDH alice; - ECDH bob; - - auto ga = alice.generate_key(); - auto gb = bob.generate_key(); - - auto ka = alice.get_shared_secret(gb); - auto kb = bob.get_shared_secret(ga); - - ASSERT_EQ(ka.size(), kb.size()); - EXPECT_TRUE(0 == std::memcmp(ka.data(), kb.data(), ka.size())); -} - TEST(crypto, hash_block) { block in = ZeroBlock; diff --git a/core/privc/triplet_generator.h b/core/privc/triplet_generator.h index 20ea6dc..ac7013c 100644 --- a/core/privc/triplet_generator.h +++ b/core/privc/triplet_generator.h @@ -110,6 +110,7 @@ public: virtual void get_triplet(TensorAdapter* ret); + // TODO: use SecureML sec4.2 triplet generator trick to improve mat_mul virtual void get_penta_triplet(TensorAdapter* ret); std::queue> _triplet_buffer; @@ -183,4 +184,4 @@ private: } // namespace privc -#include "triplet_generator_impl.h" \ No newline at end of file +#include "triplet_generator_impl.h" -- GitLab