Commit 7d553c48 authored by yangqingyou

add fixedpoint tensor and triplet generator

Parent acfda7c1
@@ -90,6 +90,7 @@ if (USE_ABY3_TRUNC1)
     add_compile_definitions(USE_ABY3_TRUNC1)
 endif(USE_ABY3_TRUNC1)
+add_subdirectory(core/privc)
 add_subdirectory(core/privc3)
 add_subdirectory(core/paddlefl_mpc/mpc_protocol)
 add_subdirectory(core/paddlefl_mpc/operators)
@@ -133,4 +134,4 @@ install(TARGETS paddle_enc mpc_data_utils
 if (WITH_PSI)
     install(TARGETS psi LIBRARY DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH})
 endif()
\ No newline at end of file
set(PRIVC_SRCS
"crypto.cc"
"privc_context.cc"
)
add_library(privc_o OBJECT ${PRIVC_SRCS})
add_dependencies(privc_o crypto privc3)
add_library(privc SHARED $<TARGET_OBJECTS:privc_o>)
target_link_libraries(privc psi)
#set(CMAKE_BUILD_TYPE Debug)
cc_test(crypto_test SRCS crypto_test.cc DEPS privc)
cc_test(privc_fixedpoint_tensor_test SRCS fixedpoint_tensor_test.cc DEPS privc)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "crypto.h"
#include <openssl/ecdh.h>
#include <string.h>
#include "glog/logging.h"
namespace psi {
u8 *hash(const void *d, u64 n, void *md) {
return SHA1(reinterpret_cast<const u8 *>(d), n, reinterpret_cast<u8 *>(md));
}
int encrypt(const unsigned char *plaintext, int plaintext_len,
const unsigned char *key, const unsigned char *iv,
unsigned char *ciphertext) {
EVP_CIPHER_CTX *ctx = NULL;
int len = 0;
int aes_ciphertext_len = 0;
int ret = 0;
memcpy(ciphertext, iv, GCM_IV_LEN);
unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN;
unsigned char *tag = ciphertext + GCM_IV_LEN + plaintext_len;
ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
LOG(ERROR) << "openssl error";
return 0;
}
ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
ret = EVP_EncryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
ret =
EVP_EncryptUpdate(ctx, aes_ciphertext, &len, plaintext, plaintext_len);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
aes_ciphertext_len = len;
ret = EVP_EncryptFinal_ex(ctx, aes_ciphertext + len, &len);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
aes_ciphertext_len += len;
if (aes_ciphertext_len != plaintext_len) {
LOG(ERROR) << "encrypt error: ciphertext len mismatched";
return 0;
}
ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, GCM_TAG_LEN, tag);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return 0;
}
EVP_CIPHER_CTX_free(ctx);
return aes_ciphertext_len + GCM_IV_LEN + GCM_TAG_LEN;
}
int decrypt(const unsigned char *ciphertext, int ciphertext_len,
const unsigned char *key, unsigned char *plaintext) {
EVP_CIPHER_CTX *ctx = NULL;
int len = 0;
int plaintext_len = 0;
int ret = 0;
const unsigned char *iv = ciphertext;
const unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN;
int aes_ciphertext_len = ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN;
unsigned char tag[GCM_TAG_LEN];
memcpy(tag, ciphertext + ciphertext_len - GCM_TAG_LEN, GCM_TAG_LEN);
ctx = EVP_CIPHER_CTX_new();
if (ctx == NULL) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptUpdate(ctx, plaintext, &len, aes_ciphertext,
aes_ciphertext_len);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
plaintext_len = len;
if (plaintext_len != ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_LEN, tag);
if (ret != 1) {
LOG(ERROR) << "openssl error";
return -1;
}
ret = EVP_DecryptFinal_ex(ctx, plaintext + len, &len);
EVP_CIPHER_CTX_free(ctx);
if (ret > 0) {
plaintext_len += len;
return plaintext_len;
} else {
return -1;
}
}
ECDH::ECDH() {
_error = false;
int ret = 0;
_group = EC_GROUP_new_by_curve_name(CURVE_ID);
if (_group == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
_key = EC_KEY_new();
if (_key == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
ret = EC_KEY_set_group(_key, _group);
if (ret != 1) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
_remote_key = EC_POINT_new(_group);
if (_remote_key == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return;
}
}
ECDH::~ECDH() {
EC_POINT_free(_remote_key);
EC_KEY_free(_key);
EC_GROUP_free(_group);
}
std::array<u8, POINT_BUFFER_LEN> ECDH::generate_key() {
int ret = 0;
std::array<u8, POINT_BUFFER_LEN> output;
if (_error) {
LOG(ERROR) << "internal error";
return output;
}
ret = EC_KEY_generate_key(_key);
if (ret != 1) {
LOG(ERROR) << "openssl error";
_error = true;
return output;
}
const EC_POINT *key_point = EC_KEY_get0_public_key(_key);
if (key_point == NULL) {
LOG(ERROR) << "openssl error";
_error = true;
return output;
}
ret = EC_POINT_point2oct(_group, key_point, POINT_CONVERSION_COMPRESSED,
output.data(), POINT_BUFFER_LEN, NULL);
if (ret == 0) {
LOG(ERROR) << "openssl error";
_error = true;
return output;
}
return output;
}
std::array<u8, POINT_BUFFER_LEN - 1>
ECDH::get_shared_secret(const std::array<u8, POINT_BUFFER_LEN> &remote_key) {
int ret = 0;
std::array<u8, POINT_BUFFER_LEN - 1> secret;
ret = EC_POINT_oct2point(_group, _remote_key, remote_key.data(),
remote_key.size(), NULL);
if (ret != 1) {
LOG(ERROR) << "openssl error";
_error = true;
return secret;
}
int secret_len = POINT_BUFFER_LEN - 1;
// compressed flag not included in secret, see
// http://www.secg.org/sec1-v2.pdf chapter 2.2.3
ret = ECDH_compute_key(secret.data(), secret_len, _remote_key, _key, NULL);
if (ret <= 0) {
LOG(ERROR) << "openssl error";
_error = true;
return secret;
}
return secret;
}
} // namespace psi
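A minimal usage sketch of these AES-128-GCM helpers (illustrative only, not part of this commit; it assumes the header above is included as "crypto.h" and mirrors the pattern used in crypto_test.cc below). The caller sizes the output buffer for the iv || aes_ciphertext || gcm_tag layout:

#include <algorithm>
#include <string>
#include <vector>
#include "crypto.h"

bool gcm_roundtrip(const std::string& msg) {
    std::string key = "0123456789abcdef"; // 16-byte key for AES-128-GCM
    std::string iv = "0123456789ab";      // GCM_IV_LEN = 12 bytes
    // output layout: iv || aes_ciphertext || gcm_tag
    std::vector<unsigned char> ct(psi::GCM_IV_LEN + msg.size() + psi::GCM_TAG_LEN);
    int ct_len = psi::encrypt(reinterpret_cast<const unsigned char*>(msg.data()), msg.size(),
                              reinterpret_cast<const unsigned char*>(key.data()),
                              reinterpret_cast<const unsigned char*>(iv.data()), ct.data());
    std::vector<unsigned char> pt(msg.size());
    int pt_len = psi::decrypt(ct.data(), ct_len,
                              reinterpret_cast<const unsigned char*>(key.data()), pt.data());
    // decrypt returns -1 if the GCM tag does not verify
    return pt_len == static_cast<int>(msg.size()) &&
           std::equal(pt.begin(), pt.end(), msg.begin());
}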
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstring>
#include <utility>
#include <openssl/ec.h>
#include <openssl/evp.h>
#include <openssl/sha.h>
#include "../psi/aes.h"
namespace psi {
typedef unsigned char u8;
typedef unsigned long long u64;
const block ZeroBlock = _mm_set_epi64x(0, 0);
const block OneBlock = _mm_set_epi64x(-1, -1);
const int CURVE_ID = NID_secp160k1;
const int POINT_BUFFER_LEN = 21;
// only applies to 160-bit curves; for the point buffer length
// specification, see http://www.secg.org/sec1-v2.pdf, chapter 2.2.3
const int HASH_DIGEST_LEN = SHA_DIGEST_LENGTH;
const int GCM_IV_LEN = 12;
const int GCM_TAG_LEN = 16;
u8 *hash(const void *d, u64 n, void *md);
static block double_block(block bl);
static inline block hash_block(const block& x, const block& i = ZeroBlock) {
static AES pi(ZeroBlock);
block k = double_block(x) ^ i;
return pi.ecb_enc_block(k) ^ k;
}
static inline std::pair<block, block> hash_blocks(const std::pair<block, block>& x,
const std::pair<block, block>& i = {ZeroBlock, ZeroBlock}) {
static AES pi(ZeroBlock);
block k[2] = {double_block(x.first) ^ i.first, double_block(x.second) ^ i.second};
block c[2];
pi.ecb_enc_blocks(k, 2, c);
return {c[0] ^ k[0], c[1] ^ k[1]};
}
template <typename T>
static inline block to_block(const T& val) {
block ret = ZeroBlock;
std::memcpy(&ret, &val, std::min(sizeof ret, sizeof val));
return ret;
}
// ciphertext = iv || aes_ciphertext || gcm_tag
// allocate buffer before call
int encrypt(const unsigned char *plaintext, int plaintext_len,
const unsigned char *key, const unsigned char *iv,
unsigned char *ciphertext);
int decrypt(const unsigned char *ciphertext, int ciphertext_len,
const unsigned char *key, unsigned char *plaintext);
class ECDH {
private:
EC_GROUP *_group;
EC_KEY *_key;
EC_POINT *_remote_key;
bool _error;
public:
ECDH();
~ECDH();
inline bool error() {return _error;}
ECDH(const ECDH &other) = delete;
ECDH operator=(const ECDH &other) = delete;
std::array<u8, POINT_BUFFER_LEN> generate_key();
std::array<u8, POINT_BUFFER_LEN - 1>
get_shared_secret(const std::array<u8, POINT_BUFFER_LEN> &remote_key);
};
/*
This file is part of JustGarble.
JustGarble is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
JustGarble is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with JustGarble. If not, see <http://www.gnu.org/licenses/>.
*/
/*------------------------------------------------------------------------
/ OCB Version 3 Reference Code (Optimized C) Last modified 08-SEP-2012
/-------------------------------------------------------------------------
/ Copyright (c) 2012 Ted Krovetz.
/
/ Permission to use, copy, modify, and/or distribute this software for any
/ purpose with or without fee is hereby granted, provided that the above
/ copyright notice and this permission notice appear in all copies.
/
/ THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
/ WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
/ MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
/ ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
/ WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
/ ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
/ OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
/
/ Phillip Rogaway holds patents relevant to OCB. See the following for
/ his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
/
/ Special thanks to Keegan McAllister for suggesting several good improvements
/
/ Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
/------------------------------------------------------------------------- */
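// double_block below computes doubling in GF(2^128) (multiplication by x): the
// 128-bit block is shifted left by one bit, carries are propagated across the
// 32-bit lanes, and an overflow of the top bit is reduced with the constant
// 0x87 (135), following the OCB3 reference code quoted above.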
static inline block double_block(block bl) {
const __m128i mask = _mm_set_epi32(135,1,1,1);
__m128i tmp = _mm_srai_epi32(bl, 31);
tmp = _mm_and_si128(tmp, mask);
tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2,1,0,3));
bl = _mm_slli_epi32(bl, 1);
return _mm_xor_si128(bl,tmp);
}
} // namespace psi
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "crypto.h"
#include <cstring>
#include <string>
#include "gtest/gtest.h"
namespace psi {
TEST(crypto, hash) {
std::string input = "abc";
char output[HASH_DIGEST_LEN + 1];
output[HASH_DIGEST_LEN] = '\0';
const char *standard_vec =
"\xa9\x99\x3e\x36\x47\x06\x81\x6a\xba\x3e"
"\x25\x71\x78\x50\xc2\x6c\x9c\xd0\xd8\x9d";
hash(input.data(), input.size(), output);
EXPECT_STREQ(standard_vec, output);
}
TEST(crypto, enc) {
std::string input = "abc";
std::string iv = "0123456789ab";
std::string key = "0123456789abcdef"; // aes_128_gcm, key_len = 128bit
unsigned int cipher_len = GCM_IV_LEN + GCM_TAG_LEN + input.size();
auto *output = new unsigned char [cipher_len];
int enc_ret = encrypt((unsigned char *)input.data(), input.size(),
(unsigned char *)key.data(),
(unsigned char *)iv.data(), output);
ASSERT_EQ(cipher_len, (size_t)enc_ret);
char *plaintext = new char [input.size() + 1];
plaintext[input.size()] = '\0';
int dec_ret = decrypt(output, enc_ret, (unsigned char *)key.data(),
(unsigned char *)plaintext);
ASSERT_EQ(input.size(), (size_t)dec_ret);
EXPECT_STREQ(input.c_str(), plaintext);
delete [] output;
delete [] plaintext;
}
TEST(crypto, ecdh) {
ECDH alice;
ECDH bob;
auto ga = alice.generate_key();
auto gb = bob.generate_key();
auto ka = alice.get_shared_secret(gb);
auto kb = bob.get_shared_secret(ga);
ASSERT_EQ(ka.size(), kb.size());
EXPECT_TRUE(0 == std::memcmp(ka.data(), kb.data(), ka.size()));
}
TEST(crypto, hash_block) {
block in = ZeroBlock;
for (size_t i = 0; i < 1e6; ++i) {
hash_block(in);
}
}
TEST(crypto, hash_blocks) {
block in = ZeroBlock;
for (size_t i = 0; i < 1e6; ++i) {
hash_blocks({in, in});
}
}
} // namespace psi
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "privc_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "../privc3/paddle_tensor.h"
#include "./triplet_generator.h"
namespace privc {
template<typename T>
using TensorAdapter = aby3::TensorAdapter<T>;
using TensorAdapterFactory = aby3::TensorAdapterFactory;
template<size_t N>
inline void fixed64_tensor_mult(const TensorAdapter<int64_t>* lhs,
const TensorAdapter<int64_t>* rhs,
TensorAdapter<int64_t>* ret) {
std::transform(lhs->data(), lhs->data() + lhs->numel(),
rhs->data(), ret->data(),
[] (const int64_t& lhs, const int64_t& rhs) -> int64_t {
return fixed64_mult<N>(lhs, rhs);
});
}
template<typename T, size_t N>
class FixedPointTensor {
public:
explicit FixedPointTensor(TensorAdapter<T>* share_tensor);
~FixedPointTensor() {};
template<typename T_>
class Type2Type {
typedef T_ type;
};
//get mutable share of tensor
TensorAdapter<T>* mutable_share();
const TensorAdapter<T>* share() const;
size_t numel() const {
return _share->numel();
}
// reveal fixedpointtensor to one party
void reveal_to_one(size_t party, TensorAdapter<T>* ret) const;
// reveal fixedpointtensor to all parties
void reveal(TensorAdapter<T>* ret) const;
const std::vector<size_t> shape() const;
//convert TensorAdapter to shares
static void share(const TensorAdapter<T>* input,
TensorAdapter<T>* output_shares[2],
block seed = psi::g_zero_block);
// element-wise add with FixedPointTensor
void add(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
// element-wise add with TensorAdapter
void add(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
// element-wise sub with FixedPointTensor
void sub(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
// element-wise sub with TensorAdapter
void sub(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
// negative
void negative(FixedPointTensor* ret) const;
// element-wise mul with FixedPointTensor using truncate1
void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const {
mul_impl<T>(rhs, ret, Type2Type<T>());
}
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
}
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
// element-wise mul with TensorAdapter
void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
// div by TensorAdapter
void div(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
//sum all elements
void sum(FixedPointTensor* ret) const;
// mat_mul with FixedPointTensor
void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
// mat_mul with TensorAdapter
void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
// exp approximate: exp(x) = \lim_{n->inf} (1+x/n)^n
// where n = 2^iter
// void exp(FixedPointTensor* ret, size_t iter = 8) const;
// element-wise relu
void relu(FixedPointTensor* ret) const;
// element-wise relu with relu'
// void relu_with_derivative(FixedPointTensor* ret, BooleanTensor<T>* derivative) const;
// element-wise sigmoid using 3 piecewise polynomials
void sigmoid(FixedPointTensor* ret) const;
// softmax axis = -1
//void softmax(FixedPointTensor* ret) const;
// argmax
void argmax(FixedPointTensor* ret) const;
// element-wise compare
// <
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void lt(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// <=
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void leq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// >
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void gt(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// >=
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void geq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// ==
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void eq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// !=
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void neq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// element-wise max
// if not null, cmp stores true if rhs is bigger
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void max(const CTensor<T, N1...>* rhs,
FixedPointTensor* ret,
CTensor<T, N1...>* cmp = nullptr) const;
private:
static inline std::shared_ptr<AbstractContext> privc_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
}
static inline std::shared_ptr<TensorAdapterFactory> tensor_factory() {
return paddle::mpc::ContextHolder::tensor_factory();
}
static inline std::shared_ptr<TripletGenerator<T, N>> tripletor() {
return std::dynamic_pointer_cast<PrivCContext>(privc_ctx())->triplet_generator();
}
static size_t party() {
return privc_ctx()->party();
}
static size_t next_party() {
return privc_ctx()->next_party();
}
TensorAdapter<T>* _share;
};
} //namespace privc
#include "fixedpoint_tensor_imp.h"
This diff is collapsed.
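For orientation, FixedPointTensor stores a real value v as the integer v * 2^N with N = SCALING_N = 32, and a plaintext is split into two additive shares over the 2^64 ring. The fragment below is a standalone sketch of that encoding and sharing on plain integers; it is not part of this commit and does not use the tensor API above:

#include <cstdint>
#include <random>

// encode/decode between double and fixed-point with N fractional bits
template <size_t N>
int64_t encode(double v) { return static_cast<int64_t>(v * (1LL << N)); }

template <size_t N>
double decode(int64_t f) { return static_cast<double>(f) / (1LL << N); }

// product of two fixed-point values needs one truncation by 2^N,
// mirroring fixed64_mult<N> (128-bit intermediate, then shift right by N)
template <size_t N>
int64_t fixed_mult(int64_t a, int64_t b) {
    return static_cast<int64_t>((static_cast<__int128_t>(a) * b) >> N);
}

int main() {
    constexpr size_t N = 32;                        // SCALING_N
    int64_t x = encode<N>(1.5);
    int64_t y = encode<N>(2.0);
    double xy = decode<N>(fixed_mult<N>(x, y));     // ~= 3.0

    // additive sharing: x = x0 + x1 (mod 2^64); each share alone looks random
    std::mt19937_64 prng(42);
    uint64_t x0 = prng();
    uint64_t x1 = static_cast<uint64_t>(x) - x0;
    int64_t reconstructed = static_cast<int64_t>(x0 + x1);   // == x
    return (reconstructed == x && xy > 2.9 && xy < 3.1) ? 0 : 1;
}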
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <cmath>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "./privc_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "fixedpoint_tensor.h"
#include "core/privc/triplet_generator.h"
#include "core/privc3/paddle_tensor.h"
namespace privc {
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N32 = FixedPointTensor<int64_t, SCALING_N>;
using AbstractContext = paddle::mpc::AbstractContext;
class FixedTensorTest : public ::testing::Test {
public:
paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<AbstractContext> _mpc_ctx[2];
std::shared_ptr<gloo::rendezvous::HashStore> _store;
std::thread _t[2];
std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
virtual ~FixedTensorTest() noexcept {}
void SetUp() {
paddle::framework::OperatorBase* op = nullptr;
paddle::framework::Scope scope;
paddle::framework::RuntimeContext ctx({}, {});
// only device_ctx is needed
_exec_ctx = std::make_shared<paddle::framework::ExecutionContext>(
*op, scope, _cpu_ctx, ctx);
_store = std::make_shared<gloo::rendezvous::HashStore>();
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&FixedTensorTest::gen_mpc_ctx, this, i);
}
for (auto& ti : _t) {
ti.join();
}
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&FixedTensorTest::init_triplet, this, i);
}
for (auto& ti : _t) {
ti.join();
}
_s_tensor_factory = std::make_shared<aby3::PaddleTensorFactory>(&_cpu_ctx);
}
std::shared_ptr<paddle::mpc::MeshNetwork> gen_network(size_t idx) {
return std::make_shared<paddle::mpc::MeshNetwork>(idx,
"127.0.0.1",
2,
"test_prefix_privc",
_store);
}
void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<PrivCContext>(idx, net);
}
void init_triplet(size_t idx) {
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> tripletor
= std::make_shared<TripletGenerator<int64_t, SCALING_N>>(_mpc_ctx[idx]);
tripletor->init();
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_triplet_generator(tripletor);
}
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
return _s_tensor_factory->template create<int64_t>(shape);
}
};
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
return g_ctx_holder::tensor_factory()->template create<int64_t>(shape);
}
TEST_F(FixedTensorTest, share) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl = gen(shape);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
TensorAdapter<int64_t>* output_share[2] = {ret[0].get(), ret[1].get()};
sl->data()[0] = (int64_t)1 << SCALING_N;
sl->scaling_factor() = SCALING_N;
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
Fix64N32::share(sl.get(), output_share);
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){});
}
);
for (auto &t: _t) {
t.join();
}
auto p = gen(shape);
output_share[0]->add(output_share[1], p.get());
EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, reveal) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
// lhs = 3 = 1 + 2
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
auto p0 = gen(shape);
auto p1 = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.reveal(p0.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.reveal(p1.get());
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(p0->data()[0], p1->data()[0]);
EXPECT_EQ(3, p0->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, addplain) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr = gen(shape);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 3 = 1 + 2
// rhs = 3
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
sr->data()[0] = (int64_t)3 << SCALING_N;
sr->scaling_factor() = SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.add(sr.get(), &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.add(sr.get(), &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(6, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, addfixed) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 3 = 1 + 2
// rhs = 3 = 1 + 2
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
sr[0]->data()[0] = (int64_t)1 << SCALING_N;
sr[1]->data()[0] = (int64_t)2 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fr0(sr[0].get());
Fix64N32 fr1(sr[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.add(&fr0, &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.add(&fr1, &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(6, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, subplain) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr = gen(shape);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 3 = 1 + 2
// rhs = 2
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
sr->data()[0] = (int64_t)2 << SCALING_N;
sr->scaling_factor() = SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.sub(sr.get(), &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.sub(sr.get(), &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, subfixed) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 3 = 1 + 2
// rhs = 2 = 1 + 1
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
sr[0]->data()[0] = (int64_t)1 << SCALING_N;
sr[1]->data()[0] = (int64_t)1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fr0(sr[0].get());
Fix64N32 fr1(sr[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.sub(&fr0, &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.sub(&fr1, &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, negative) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 3 = 1 + 2
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.negative(&fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.negative(&fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(-3, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, triplet) {
std::vector<size_t> shape = { 1 };
auto shape_triplet = shape;
shape_triplet.insert(shape_triplet.begin(), 3);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = {gen(shape_triplet), gen(shape_triplet)};
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
->triplet_generator()->get_triplet(ret[0].get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
->triplet_generator()->get_triplet(ret[1].get());
});
}
);
for (auto &t: _t) {
t.join();
}
auto num_triplet = ret[0]->numel() / 3;
for (int i = 0; i < ret[0]->numel() / 3; ++i) {
auto ret0_ptr = ret[0]->data();
auto ret1_ptr = ret[1]->data();
uint64_t a_idx = i;
uint64_t b_idx = num_triplet + i;
uint64_t c_idx = 2 * num_triplet + i;
int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), 20);
}
}
TEST_F(FixedTensorTest, mulfixed) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 2 = 1 + 1
// rhs = 2 = 1 + 1
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)1 << SCALING_N;
sr[0]->data()[0] = (int64_t)1 << SCALING_N;
sr[1]->data()[0] = (int64_t)1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fr0(sr[0].get());
Fix64N32 fr1(sr[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.mul(&fr0, &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.mul(&fr1, &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_NEAR(4, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
}
} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include <memory>
#include "core/privc/triplet_generator.h"
#include "core/privc/privc_context.h"
namespace privc {
void PrivCContext::set_triplet_generator(std::shared_ptr<TripletGenerator<int64_t, SCALING_N>>& tripletor) {
_tripletor = tripletor;
}
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> PrivCContext::triplet_generator() {
PADDLE_ENFORCE_NE(_tripletor, nullptr, "must set triplet generator first.");
return _tripletor;
}
} // namespace privc
@@ -13,24 +13,28 @@
 // limitations under the License.
 #pragma once
-#include <algorithm>
 #include <algorithm>
 #include <memory>
 #include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
 #include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
-#include "prng_utils.h"
+#include "core/privc3/prng_utils.h"
-namespace aby3 {
+namespace privc {
 using AbstractNetwork = paddle::mpc::AbstractNetwork;
 using AbstractContext = paddle::mpc::AbstractContext;
+using block = psi::block;
+static const size_t SCALING_N = 32;
+template <typename T, size_t N>
+class TripletGenerator;
 class PrivCContext : public AbstractContext {
 public:
     PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
-                 block seed = g_zero_block):
-        AbstractContext::AbstractContext(party, network) {
+                 block seed = psi::g_zero_block):
+        AbstractContext::AbstractContext(party, network), _tripletor{nullptr} {
         set_num_party(2);
         if (psi::equals(seed, psi::g_zero_block)) {
@@ -42,13 +46,28 @@ public:
     PrivCContext(const PrivCContext &other) = delete;
     PrivCContext &operator=(const PrivCContext &other) = delete;
+    /*
+    block get_private_block() {
+        std::array<int64_t, 2> ret_block;
+        ret_block[0] = gen_random_private<int64_t>();
+        ret_block[1] = gen_random_private<int64_t>();
+        return *(reinterpret_cast<block*>(ret_block.data()));
+    }
+    */
+    void set_triplet_generator(std::shared_ptr<TripletGenerator<int64_t, SCALING_N>>& tripletor);
+    std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> triplet_generator();
 protected:
-    PseudorandomNumberGenerator& get_prng(size_t idx) override {
+    psi::PseudorandomNumberGenerator& get_prng(size_t idx) override {
         return _prng;
     }
 private:
-    PseudorandomNumberGenerator _prng;
+    std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> _tripletor;
+    psi::PseudorandomNumberGenerator _prng;
 };
-} // namespace aby3
+} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <array>
#include "paddle/fluid/platform/enforce.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/prng_utils.h"
#include "core/privc/crypto.h"
#include "core/psi/naorpinkas_ot.h"
#include "core/psi/ot_extension.h"
#include "core/privc3/tensor_adapter.h"
namespace privc {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
using block = psi::block;
using NaorPinkasOTsender = psi::NaorPinkasOTsender;
using NaorPinkasOTreceiver = psi::NaorPinkasOTreceiver;
template<typename T>
using OTExtSender = psi::OTExtSender<T>;
template<typename T>
using OTExtReceiver = psi::OTExtReceiver<T>;
template <typename T>
using TensorAdapter = aby3::TensorAdapter<T>;
template<size_t N>
inline int64_t fixed64_mult(const int64_t a, const int64_t b) {
__int128_t res = (__int128_t)a * (__int128_t)b;
return static_cast<int64_t>(res >> N);
}
template<size_t N>
inline uint64_t lshift(uint64_t lhs, size_t rhs) {
return fixed64_mult<N>(lhs, (uint64_t)1 << rhs);
}
inline std::string block_to_string(const block &b) {
return std::string(reinterpret_cast<const char *>(&b), sizeof(block));
}
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
uint64_t input,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
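// For each bit of `input`, consume one OT-extension instance (t0, t1): keep t0
// for deriving the local key later, and publish the IKNP correction
// choice ^ t0 ^ t1, where `choice` is the receiver's choice bit expanded to an
// all-zero or all-one block.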
for (uint64_t idx = 0; idx < word_width; idx += 1) {
auto ot_instance = ot_ext_recver.get_ot_instance();
block choice = (input >> idx) & 1 ? psi::OneBlock : psi::ZeroBlock;
t0_buffer.emplace_back(ot_instance[0]);
ot_masks.emplace_back(choice ^ ot_instance[0] ^ ot_instance[1]);
}
}
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
const int64_t* input,
size_t size,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (size_t i = 0; i < size; ++i) {
gen_ot_masks(ot_ext_recver, input[i], ot_masks, t0_buffer, word_width);
}
}
template <typename T>
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
const std::vector<T>& input,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (const auto& i: input) {
gen_ot_masks(ot_ext_recver, i, ot_masks, t0_buffer, word_width);
}
}
template<typename T, size_t N>
class TripletGenerator {
public:
TripletGenerator(std::shared_ptr<AbstractContext>& circuit_context) :
_base_ot_choices(circuit_context->gen_random_private<block>()),
_np_ot_sender(sizeof(block) * 8),
_np_ot_recver(sizeof(block) * 8, block_to_string(_base_ot_choices)) {
_privc_ctx = circuit_context;
};
void init();
virtual void get_triplet(TensorAdapter<T>* ret);
virtual void get_penta_triplet(TensorAdapter<T>* ret) {}
std::queue<std::array<T, 3>> _triplet_buffer;
static const size_t _s_triplet_step = 1 << 8;
static constexpr double _s_fixed_point_compensation = 0.3;
static const size_t OT_SIZE = sizeof(block) * 8;
protected:
// T = int64
template<typename T_>
class Type2Type {
typedef T_ type;
};
void fill_triplet_buffer() { fill_triplet_buffer_impl<T>(Type2Type<T>()); }
template<typename T__>
void fill_triplet_buffer_impl(const Type2Type<T__>) {
PADDLE_THROW("type except `int64_t` for generating triplet is not implemented yet");
}
// specialize template method by overload
template<typename T__>
void fill_triplet_buffer_impl(const Type2Type<int64_t>);
private:
std::shared_ptr<AbstractContext> privc_ctx() {
return _privc_ctx;
}
AbstractNetwork* net() {
return _privc_ctx->network();
}
size_t party() {
return privc_ctx()->party();
}
size_t next_party() {
return privc_ctx()->next_party();
}
// gen triplet for int64_t type
std::vector<uint64_t> gen_product(const std::vector<uint64_t> &input) {
return gen_product_impl<T>(input, Type2Type<T>());
};
template<typename T__>
std::vector<uint64_t> gen_product_impl(const std::vector<uint64_t> &input,
Type2Type<T__>) {
PADDLE_THROW("type except `int64_t` for generating triplet is not implemented yet");
}
template<typename T__>
std::vector<uint64_t> gen_product_impl(const std::vector<uint64_t> &input,
Type2Type<int64_t>);
const block _base_ot_choices;
NaorPinkasOTsender _np_ot_sender;
NaorPinkasOTreceiver _np_ot_recver;
OTExtSender<block> _ot_ext_sender;
OTExtReceiver<block> _ot_ext_recver;
std::shared_ptr<AbstractContext> _privc_ctx;
};
} // namespace privc
#include "triplet_generator_impl.h"
\ No newline at end of file
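The triples buffered by TripletGenerator are standard Beaver triples: additive shares of random a, b and of c = a*b. They are consumed by the fixed-point mul in the (collapsed) fixedpoint_tensor_imp.h; the sketch below only illustrates the generic Beaver identity z = c + e*b + f*a + e*f with e = x - a, f = y - b on plain 64-bit shares, and is not the library's implementation (the fixed-point version additionally truncates by 2^SCALING_N):

#include <cstdint>
#include <iostream>

// one party's additive shares, all arithmetic mod 2^64 (hence unsigned)
struct Shares { uint64_t x, y, a, b, c; };

// local step of Beaver multiplication, given the opened values e = x - a, f = y - b
uint64_t beaver_mul_share(const Shares& s, uint64_t e, uint64_t f, bool is_party0) {
    uint64_t z = s.c + e * s.b + f * s.a;
    if (is_party0) z += e * f;   // the public e*f term is added by exactly one party
    return z;
}

int main() {
    // secrets x = 6, y = 7 and a triple (a, b, c = a*b), all additively shared
    uint64_t x = 6, y = 7, a = 1234567, b = 7654321, c = a * b;
    uint64_t x0 = 11, x1 = x - x0, y0 = 22, y1 = y - y0;
    uint64_t a0 = 33, a1 = a - a0, b0 = 44, b1 = b - b0, c0 = 55, c1 = c - c0;

    // each party reveals x_i - a_i and y_i - b_i, so e and f become public
    uint64_t e = (x0 - a0) + (x1 - a1);   // == x - a
    uint64_t f = (y0 - b0) + (y1 - b1);   // == y - b

    uint64_t z0 = beaver_mul_share({x0, y0, a0, b0, c0}, e, f, true);
    uint64_t z1 = beaver_mul_share({x1, y1, a1, b1, c1}, e, f, false);
    std::cout << z0 + z1 << std::endl;    // prints 42 == x * y
    return 0;
}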
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/privc/triplet_generator.h"
#include "core/privc/crypto.h"
#include "core/privc/privc_context.h"
namespace privc {
template<typename T, size_t N>
void TripletGenerator<T, N>::init() {
auto np_ot_send_pre = [&]() {
std::array<std::array<std::array<unsigned char,
psi::POINT_BUFFER_LEN>, 2>, OT_SIZE> send_buffer;
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
send_buffer[idx] = _np_ot_sender.send_pre(idx);
}
net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
};
auto np_ot_send_post = [&]() {
std::array<std::array<unsigned char, psi::POINT_BUFFER_LEN>, OT_SIZE> recv_buffer;
net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
_np_ot_sender.send_post(idx, recv_buffer[idx]);
}
};
auto np_ot_recv = [&]() {
std::array<std::array<std::array<unsigned char,
psi::POINT_BUFFER_LEN>, 2>, OT_SIZE> recv_buffer;
std::array<std::array<unsigned char, psi::POINT_BUFFER_LEN>, OT_SIZE> send_buffer;
net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
send_buffer[idx] = _np_ot_recver.recv(idx, recv_buffer[idx]);
}
net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
};
if (party() == 0) {
np_ot_recv();
np_ot_send_pre();
np_ot_send_post();
} else { // party == Bob
np_ot_send_pre();
np_ot_send_post();
np_ot_recv();
}
_ot_ext_sender.init(_base_ot_choices, _np_ot_recver._msgs);
_ot_ext_recver.init(_np_ot_sender._msgs);
}
template<typename T, size_t N>
void TripletGenerator<T, N>::get_triplet(TensorAdapter<T>* ret) {
size_t num_trip = ret->numel() / 3;
if (_triplet_buffer.size() < num_trip) {
fill_triplet_buffer();
}
for (int i = 0; i < num_trip; ++i) {
auto triplet = _triplet_buffer.front();
auto ret_ptr = ret->data() + i;
*ret_ptr = triplet[0];
*(ret_ptr + num_trip) = triplet[1];
*(ret_ptr + 2 * num_trip) = triplet[2];
_triplet_buffer.pop();
}
}
template<typename T, size_t N>
template<typename T_>
void TripletGenerator<T, N>::fill_triplet_buffer_impl(const Type2Type<int64_t>) {
std::vector<uint64_t> a(_s_triplet_step);
std::vector<uint64_t> b(_s_triplet_step);
std::for_each(a.data(), a.data() + a.size(),
[this](uint64_t& val) {
val = privc_ctx()-> template gen_random_private<uint64_t>(); });
std::for_each(b.data(), b.data() + b.size(),
[this](uint64_t& val) {
val = privc_ctx()-> template gen_random_private<uint64_t>(); });
std::vector<uint64_t> ab0;
std::vector<uint64_t> ab1;
std::function<std::vector<uint64_t>(const std::vector<uint64_t>&)> gen_p
= [this](const std::vector<uint64_t>& v) {
return gen_product(v);
};
ab0 = gen_p(privc_ctx()->party() == 0 ? a : b);
ab1 = gen_p(privc_ctx()->party() == 0 ? b : a);
for (uint64_t i = 0; i < a.size(); i += 1) {
std::array<int64_t, 3> item = {
static_cast<int64_t>(a[i]),
static_cast<int64_t>(b[i]),
static_cast<int64_t>(fixed64_mult<N>(a[i], b[i]) + ab0[i] + ab1[i])};
_triplet_buffer.push(std::move(item));
}
}
template<typename T, size_t N>
template<typename T_>
std::vector<uint64_t> TripletGenerator<T, N>::gen_product_impl(
const std::vector<uint64_t> &input,
Type2Type<int64_t>) {
size_t word_width = 8 * sizeof(uint64_t);
std::vector<uint64_t> ret;
if (party() == 0) {
std::vector<uint64_t> s1_buffer;
std::vector<block> ot_mask;
ot_mask.resize(input.size() * word_width);
net()->recv(next_party(), ot_mask.data(), sizeof(block) * ot_mask.size());
size_t ot_mask_idx = 0;
for (const auto &a: input) {
uint64_t ret_val = 0;
for (uint64_t idx = 0; idx < word_width; idx += 1) {
const block& round_ot_mask = ot_mask.at(ot_mask_idx);
//net()->recv(next_party(), &round_ot_mask, sizeof(block));
// bad naming from OT extension
block q = _ot_ext_sender.get_ot_instance();
q ^= (round_ot_mask & _base_ot_choices);
auto s = psi::hash_blocks({q, q ^ _base_ot_choices});
uint64_t s0 = *reinterpret_cast<uint64_t *>(&s.first);
uint64_t s1 = *reinterpret_cast<uint64_t *>(&s.second);
s1 ^= lshift<N>(a, idx) - s0;
s1_buffer.emplace_back(s1);
ret_val += s0;
ot_mask_idx++;
}
ret.emplace_back(ret_val);
}
net()->send(next_party(), s1_buffer.data(), sizeof(uint64_t) * s1_buffer.size());
} else { // as ot recver
std::vector<block> ot_masks;
std::vector<block> t0_buffer;
gen_ot_masks(_ot_ext_recver, input, ot_masks, t0_buffer);
net()->send(next_party(), ot_masks.data(), sizeof(block) * ot_masks.size());
std::vector<uint64_t> ot_msg;
ot_msg.resize(input.size() * word_width);
net()->recv(next_party(), ot_msg.data(), sizeof(uint64_t) * ot_msg.size());
size_t ot_msg_idx = 0;
uint64_t b_idx = 0;
for (const auto &b: input) {
uint64_t ret_val = 0;
int b_weight = 0;
for (size_t idx = 0; idx < word_width; idx += 1) {
const uint64_t& round_ot_msg = ot_msg.at(ot_msg_idx);
auto t0_hash = psi::hash_block(t0_buffer[b_idx * word_width + idx]);
uint64_t key = *reinterpret_cast<uint64_t *>(&t0_hash);
bool b_i = (b >> idx) & 1;
b_weight += b_i;
ret_val += b_i ? round_ot_msg ^ key : -key;
ot_msg_idx++;
}
// compensation for precision loss
ret.emplace_back(ret_val + static_cast<uint64_t>(_s_fixed_point_compensation * b_weight));
b_idx += 1;
}
}
return ret;
}
} // namespace privc
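A brief note on why gen_product yields additive shares of the fixed-point product (a Gilboa-style OT multiplication); this reasoning is reconstructed from the code above, not text from the commit. Let Delta be the base-OT choice block and s_{0,j} = H(q_j) the sender's per-bit key. The sender keeps u = sum_j s_{0,j} and sends m_j = H(q_j XOR Delta) XOR (lshift_N(a, j) - s_{0,j}); from m_j the receiver recovers lshift_N(a, j) - s_{0,j} when b_j = 1 and -s_{0,j} when b_j = 0, so its sum v satisfies

$$
v = \sum_{j=0}^{63} \Big( b_j \cdot \big\lfloor a\,2^{j}/2^{N} \big\rfloor - s_{0,j} \Big),
\qquad
u + v = \sum_{j=0}^{63} b_j \big\lfloor a\,2^{j}/2^{N} \big\rfloor \approx \frac{a\,b}{2^{N}}.
$$

The two return values are therefore additive shares of the fixed-point product a*b; each floored term loses less than one unit, which the receiver offsets by adding _s_fixed_point_compensation (0.3) per set bit of b.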
@@ -13,7 +13,6 @@
 // limitations under the License.
 #pragma once
-#include <algorithm>
 #include <algorithm>
 #include <memory>
...