提交 7150cc79 编写于 作者: Y yangqingyou

remove redundancy of crypto module

上级 53efcf70
set(PRIVC_SRCS set(PRIVC_SRCS
"crypto.cc"
"privc_context.cc" "privc_context.cc"
) )
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "crypto.h"
#include <openssl/ecdh.h>
#include <string.h>
#include "glog/logging.h"
namespace psi {
u8 *hash(const void *d, u64 n, void *md) {
return SHA1(reinterpret_cast<const u8 *>(d), n, reinterpret_cast<u8 *>(md));
}
int encrypt(const unsigned char *plaintext, int plaintext_len,
const unsigned char *key, const unsigned char *iv,
unsigned char *ciphertext) {
EVP_CIPHER_CTX *ctx = NULL;
int len = 0;
int aes_ciphertext_len = 0;
int ret = 0;
memcpy(ciphertext, iv, GCM_IV_LEN);
unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN;
unsigned char *tag = ciphertext + GCM_IV_LEN + plaintext_len;
ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
LOG(ERROR) << "openssl error";
return 0;
}
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
ret = EVP_EncryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
ret =
EVP_EncryptUpdate(ctx, aes_ciphertext, &len, plaintext, plaintext_len);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
aes_ciphertext_len = len;
ret = EVP_EncryptFinal_ex(ctx, aes_ciphertext + len, &len);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
aes_ciphertext_len += len;
if (aes_ciphertext_len != plaintext_len) {
LOG(ERROR) << "encrypt error: ciphertext len mismatched";
return 0;
}
ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, GCM_TAG_LEN, tag);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
EVP_CIPHER_CTX_free(ctx);
return aes_ciphertext_len + GCM_IV_LEN + GCM_TAG_LEN;
}
int decrypt(const unsigned char *ciphertext, int ciphertext_len,
const unsigned char *key, unsigned char *plaintext) {
EVP_CIPHER_CTX *ctx = NULL;
int len = 0;
int plaintext_len = 0;
int ret = 0;
const unsigned char *iv = ciphertext;
const unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN;
int aes_ciphertext_len = ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN;
unsigned char tag[GCM_TAG_LEN];
memcpy(tag, ciphertext + ciphertext_len - GCM_TAG_LEN, GCM_TAG_LEN);
ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptUpdate(ctx, plaintext, &len, aes_ciphertext,
aes_ciphertext_len);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
plaintext_len = len;
if (plaintext_len != ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_LEN, tag);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptFinal_ex(ctx, plaintext + len, &len);
EVP_CIPHER_CTX_free(ctx);
if (ret > 0) {
plaintext_len += len;
return plaintext_len;
} else {
return -1;
}
}
ECDH::ECDH() {
_error = false;
int ret = 0;
_group = EC_GROUP_new_by_curve_name(CURVE_ID);
if (_group == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
_key = EC_KEY_new();
if (_key == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
ret = EC_KEY_set_group(_key, _group);
if (ret != 1) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
_remote_key = EC_POINT_new(_group);
if (_remote_key == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
}
ECDH::~ECDH() {
EC_POINT_free(_remote_key);
EC_KEY_free(_key);
EC_GROUP_free(_group);
}
std::array<u8, POINT_BUFFER_LEN> ECDH::generate_key() {
int ret = 0;
std::array<u8, POINT_BUFFER_LEN> output;
if (_error) {
LOG(ERROR) << "internal error";
return output;
}
ret = EC_KEY_generate_key(_key);
if (ret != 1) {
LOG(ERROR) << "openssl error";
_error = true;
return output;
}
const EC_POINT *key_point = EC_KEY_get0_public_key(_key);
if (key_point == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return output;
}
ret = EC_POINT_point2oct(_group, key_point, POINT_CONVERSION_COMPRESSED,
output.data(), POINT_BUFFER_LEN, NULL);
if (ret == 0) {
LOG(ERROR) << "openssl error";
_error = true;
return output;
}
return output;
}
std::array<u8, POINT_BUFFER_LEN - 1>
ECDH::get_shared_secret(const std::array<u8, POINT_BUFFER_LEN> &remote_key) {
int ret = 0;
std::array<u8, POINT_BUFFER_LEN - 1> secret;
ret = EC_POINT_oct2point(_group, _remote_key, remote_key.data(),
remote_key.size(), NULL);
if (ret != 1) {
LOG(ERROR) << "openssl error";
_error = true;
return secret;
}
int secret_len = POINT_BUFFER_LEN - 1;
// compressed flag not included in secret, see
// http://www.secg.org/sec1-v2.pdf chapter 2.2.3
ret = ECDH_compute_key(secret.data(), secret_len, _remote_key, _key, NULL);
if (ret <= 0) {
LOG(ERROR) << "openssl error";
_error = true;
return secret;
}
return secret;
}
} // namespace psi
...@@ -31,21 +31,6 @@ typedef unsigned long long u64; ...@@ -31,21 +31,6 @@ typedef unsigned long long u64;
const block ZeroBlock = _mm_set_epi64x(0, 0); const block ZeroBlock = _mm_set_epi64x(0, 0);
const block OneBlock = _mm_set_epi64x(-1, -1); const block OneBlock = _mm_set_epi64x(-1, -1);
const int CURVE_ID = NID_secp160k1;
const int POINT_BUFFER_LEN = 21;
// only apply for 160 bit curve
// specification about point buf len, see http://www.secg.org/sec1-v2.pdf
// chapter 2.2.3
const int HASH_DIGEST_LEN = SHA_DIGEST_LENGTH;
const int GCM_IV_LEN = 12;
const int GCM_TAG_LEN = 16;
u8 *hash(const void *d, u64 n, void *md);
static block double_block(block bl); static block double_block(block bl);
static inline block hash_block(const block& x, const block& i = ZeroBlock) { static inline block hash_block(const block& x, const block& i = ZeroBlock) {
...@@ -63,84 +48,6 @@ static inline std::pair<block, block> hash_blocks(const std::pair<block, block>& ...@@ -63,84 +48,6 @@ static inline std::pair<block, block> hash_blocks(const std::pair<block, block>&
return {c[0] ^ k[0], c[1] ^ k[1]}; return {c[0] ^ k[0], c[1] ^ k[1]};
} }
template <typename T>
static inline block to_block(const T& val) {
block ret = ZeroBlock;
std::memcpy(&ret, &val, std::min(sizeof ret, sizeof val));
return ret;
}
// ciphertext = iv || aes_ciphertext || gcm_tag
// allocate buffer before call
int encrypt(const unsigned char *plaintext, int plaintext_len,
const unsigned char *key, const unsigned char *iv,
unsigned char *ciphertext);
int decrypt(const unsigned char *ciphertext, int ciphertext_len,
const unsigned char *key, unsigned char *plaintext);
class ECDH {
private:
EC_GROUP *_group;
EC_KEY *_key;
EC_POINT *_remote_key;
bool _error;
public:
ECDH();
~ECDH();
inline bool error() {return _error;}
ECDH(const ECDH &other) = delete;
ECDH operator=(const ECDH &other) = delete;
std::array<u8, POINT_BUFFER_LEN> generate_key();
std::array<u8, POINT_BUFFER_LEN - 1>
get_shared_secret(const std::array<u8, POINT_BUFFER_LEN> &remote_key);
};
/*
This file is part of JustGarble.
JustGarble is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
JustGarble is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with JustGarble. If not, see <http://www.gnu.org/licenses/>.
*/
/*------------------------------------------------------------------------
/ OCB Version 3 Reference Code (Optimized C) Last modified 08-SEP-2012
/-------------------------------------------------------------------------
/ Copyright (c) 2012 Ted Krovetz.
/
/ Permission to use, copy, modify, and/or distribute this software for any
/ purpose with or without fee is hereby granted, provided that the above
/ copyright notice and this permission notice appear in all copies.
/
/ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
/ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
/ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
/ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
/ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
/ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
/ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
/
/ Phillip Rogaway holds patents relevant to OCB. See the following for
/ his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
/
/ Special thanks to Keegan McAllister for suggesting several good improvements
/
/ Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
/------------------------------------------------------------------------- */
static inline block double_block(block bl) { static inline block double_block(block bl) {
const __m128i mask = _mm_set_epi32(135,1,1,1); const __m128i mask = _mm_set_epi32(135,1,1,1);
__m128i tmp = _mm_srai_epi32(bl, 31); __m128i tmp = _mm_srai_epi32(bl, 31);
......
...@@ -21,64 +21,6 @@ ...@@ -21,64 +21,6 @@
namespace psi { namespace psi {
TEST(crypto, hash) {
std::string input = "abc";
char output[HASH_DIGEST_LEN + 1];
output[HASH_DIGEST_LEN] = '\0';
const char *standard_vec =
"\xa9\x99\x3e\x36\x47\x06\x81\x6a\xba\x3e"
"\x25\x71\x78\x50\xc2\x6c\x9c\xd0\xd8\x9d";
hash(input.data(), input.size(), output);
EXPECT_STREQ(standard_vec, output);
}
TEST(crypto, enc) {
std::string input = "abc";
std::string iv = "0123456789ab";
std::string key = "0123456789abcdef"; // aes_128_gcm, key_len = 128bit
unsigned int cipher_len = GCM_IV_LEN + GCM_TAG_LEN + input.size();
auto *output = new unsigned char [cipher_len];
int enc_ret = encrypt((unsigned char *)input.data(), input.size(),
(unsigned char *)key.data(),
(unsigned char *)iv.data(), output);
ASSERT_EQ(cipher_len, (size_t)enc_ret);
char *plaintext = new char [input.size() + 1];
plaintext[input.size()] = '\0';
int dec_ret = decrypt(output, enc_ret, (unsigned char *)key.data(),
(unsigned char *)plaintext);
ASSERT_EQ(input.size(), (size_t)dec_ret);
EXPECT_STREQ(input.c_str(), plaintext);
delete output;
delete plaintext;
}
TEST(crypto, ecdh) {
ECDH alice;
ECDH bob;
auto ga = alice.generate_key();
auto gb = bob.generate_key();
auto ka = alice.get_shared_secret(gb);
auto kb = bob.get_shared_secret(ga);
ASSERT_EQ(ka.size(), kb.size());
EXPECT_TRUE(0 == std::memcmp(ka.data(), kb.data(), ka.size()));
}
TEST(crypto, hash_block) { TEST(crypto, hash_block) {
block in = ZeroBlock; block in = ZeroBlock;
......
...@@ -110,6 +110,7 @@ public: ...@@ -110,6 +110,7 @@ public:
virtual void get_triplet(TensorAdapter<T>* ret); virtual void get_triplet(TensorAdapter<T>* ret);
// TODO: use SecureML sec4.2 triplet generator trick to improve mat_mul
virtual void get_penta_triplet(TensorAdapter<T>* ret); virtual void get_penta_triplet(TensorAdapter<T>* ret);
std::queue<std::array<T, 3>> _triplet_buffer; std::queue<std::array<T, 3>> _triplet_buffer;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册