diff --git a/CMakeLists.txt b/CMakeLists.txt
index d14c3faa9d10b19f060cd468ca55720b8e8894b2..761138627773f72525225de489adfaff57ff5100 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -90,6 +90,7 @@ if (USE_ABY3_TRUNC1)
     add_compile_definitions(USE_ABY3_TRUNC1)
 endif(USE_ABY3_TRUNC1)
 
+add_subdirectory(core/privc)
 add_subdirectory(core/privc3)
 add_subdirectory(core/paddlefl_mpc/mpc_protocol)
 add_subdirectory(core/paddlefl_mpc/operators)
@@ -133,4 +134,4 @@ install(TARGETS paddle_enc mpc_data_utils
 
 if (WITH_PSI)
     install(TARGETS psi LIBRARY DESTINATION ${PADDLE_ENCRYPTED_LIB_PATH})
-endif()
+endif()
\ No newline at end of file
diff --git a/core/privc/CMakeLists.txt b/core/privc/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e641eb2f6783218590e34233d6360afca61bf111
--- /dev/null
+++ b/core/privc/CMakeLists.txt
@@ -0,0 +1,17 @@
+set(PRIVC_SRCS
+    "crypto.cc"
+    "privc_context.cc"
+)
+
+add_library(privc_o OBJECT ${PRIVC_SRCS})
+add_dependencies(privc_o crypto privc3)
+
+add_library(privc SHARED $<TARGET_OBJECTS:privc_o>)
+
+target_link_libraries(privc psi)
+
+#set(CMAKE_BUILD_TYPE Debug)
+
+cc_test(crypto_test SRCS crypto_test.cc DEPS privc)
+cc_test(privc_fixedpoint_tensor_test SRCS fixedpoint_tensor_test.cc DEPS privc)
+
diff --git a/core/privc/crypto.cc b/core/privc/crypto.cc
new file mode 100644
index 0000000000000000000000000000000000000000..477beaf205cdef6799a7ccff5edd5045a30905d8
--- /dev/null
+++ b/core/privc/crypto.cc
@@ -0,0 +1,258 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "crypto.h"
+
+#include <string>
+#include <cstring>
+
+#include "glog/logging.h"
+
+namespace psi {
+
+u8 *hash(const void *d, u64 n, void *md) {
+    return SHA1(reinterpret_cast<const u8 *>(d), n,
+                reinterpret_cast<u8 *>(md));
+}
+
+int encrypt(const unsigned char *plaintext, int plaintext_len,
+            const unsigned char *key, const unsigned char *iv,
+            unsigned char *ciphertext) {
+    EVP_CIPHER_CTX *ctx = NULL;
+    int len = 0;
+    int aes_ciphertext_len = 0;
+    int ret = 0;
+
+    memcpy(ciphertext, iv, GCM_IV_LEN);
+
+    unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN;
+    unsigned char *tag = ciphertext + GCM_IV_LEN + plaintext_len;
+
+    ctx = EVP_CIPHER_CTX_new();
+    if (ctx == NULL) {
+        LOG(ERROR) << "openssl error";
+        return 0;
+    }
+
+    ret = EVP_EncryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return 0;
+    }
+
+    ret = EVP_EncryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return 0;
+    }
+
+    ret =
+        EVP_EncryptUpdate(ctx, aes_ciphertext, &len, plaintext, plaintext_len);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return 0;
+    }
+    aes_ciphertext_len = len;
+
+    ret = EVP_EncryptFinal_ex(ctx, aes_ciphertext + len, &len);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return 0;
+    }
+    aes_ciphertext_len += len;
+
+    if (aes_ciphertext_len != plaintext_len) {
+        LOG(ERROR) << "encrypt error: ciphertext len mismatched";
+        return 0;
+    }
+
+    ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_GET_TAG, GCM_TAG_LEN, tag);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return 0;
+    }
+
+    EVP_CIPHER_CTX_free(ctx);
+
+    return aes_ciphertext_len + GCM_IV_LEN + GCM_TAG_LEN;
+}
+
+int decrypt(const unsigned char *ciphertext, int ciphertext_len,
+            const unsigned char *key, unsigned char *plaintext) {
+    EVP_CIPHER_CTX *ctx = NULL;
+
+    int len = 0;
+    int plaintext_len = 0;
+    int ret = 0;
+
+    const unsigned char *iv = ciphertext;
+    const unsigned char *aes_ciphertext = ciphertext + GCM_IV_LEN;
+
+    int aes_ciphertext_len = ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN;
+
+    unsigned char tag[GCM_TAG_LEN];
+
+    memcpy(tag, ciphertext + ciphertext_len - GCM_TAG_LEN, GCM_TAG_LEN);
+
+    ctx = EVP_CIPHER_CTX_new();
+    if (ctx == NULL) {
+        LOG(ERROR) << "openssl error";
+        return -1;
+    }
+
+    ret = EVP_DecryptInit_ex(ctx, EVP_aes_128_gcm(), NULL, key, iv);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return -1;
+    }
+
+    ret = EVP_DecryptUpdate(ctx, NULL, &len, iv, GCM_IV_LEN);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return -1;
+    }
+
+    ret = EVP_DecryptUpdate(ctx, plaintext, &len, aes_ciphertext,
+                            aes_ciphertext_len);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return -1;
+    }
+    plaintext_len = len;
+
+    if (plaintext_len != ciphertext_len - GCM_IV_LEN - GCM_TAG_LEN) {
+        LOG(ERROR) << "openssl error";
+        return -1;
+    }
+
+    ret = EVP_CIPHER_CTX_ctrl(ctx, EVP_CTRL_GCM_SET_TAG, GCM_TAG_LEN, tag);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        return -1;
+    }
+
+    ret = EVP_DecryptFinal_ex(ctx, plaintext + len, &len);
+
+    EVP_CIPHER_CTX_free(ctx);
+
+    if (ret > 0) {
+        plaintext_len += len;
+        return plaintext_len;
+    } else {
+        return -1;
+    }
+}
+
+ECDH::ECDH() {
+    _error = false;
+    int ret = 0;
+
+    _group = EC_GROUP_new_by_curve_name(CURVE_ID);
+    if (_group == NULL) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return;
+    }
+
+    _key = EC_KEY_new();
+    if (_key == NULL) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return;
+    }
+
+    ret = EC_KEY_set_group(_key, _group);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return;
+    }
+
+    _remote_key = EC_POINT_new(_group);
+    if (_remote_key == NULL) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return;
+    }
+}
+
+ECDH::~ECDH() {
+    EC_POINT_free(_remote_key);
+    EC_KEY_free(_key);
+    EC_GROUP_free(_group);
+}
+
+std::array<u8, POINT_BUFFER_LEN> ECDH::generate_key() {
+    int ret = 0;
+    std::array<u8, POINT_BUFFER_LEN> output;
+
+    if (_error) {
+        LOG(ERROR) << "internal error";
+        return output;
+    }
+
+    ret = EC_KEY_generate_key(_key);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return output;
+    }
+
+    const EC_POINT *key_point = EC_KEY_get0_public_key(_key);
+    if (key_point == NULL) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return output;
+    }
+
+    ret = EC_POINT_point2oct(_group, key_point, POINT_CONVERSION_COMPRESSED,
+                             output.data(), POINT_BUFFER_LEN, NULL);
+    if (ret == 0) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return output;
+    }
+
+    return output;
+}
+
+std::array<u8, POINT_BUFFER_LEN - 1>
+ECDH::get_shared_secret(const std::array<u8, POINT_BUFFER_LEN> &remote_key) {
+    int ret = 0;
+    std::array<u8, POINT_BUFFER_LEN - 1> secret;
+
+    ret = EC_POINT_oct2point(_group, _remote_key, remote_key.data(),
+                             remote_key.size(), NULL);
+    if (ret != 1) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return secret;
+    }
+
+    int secret_len = POINT_BUFFER_LEN - 1;
+    // compressed flag not included in secret, see
+    // http://www.secg.org/sec1-v2.pdf chapter 2.2.3
+
+    ret = ECDH_compute_key(secret.data(), secret_len, _remote_key, _key, NULL);
+
+    if (ret <= 0) {
+        LOG(ERROR) << "openssl error";
+        _error = true;
+        return secret;
+    }
+
+    return secret;
+}
+} // namespace psi
diff --git a/core/privc/crypto.h b/core/privc/crypto.h
new file mode 100644
index 0000000000000000000000000000000000000000..08f6d209f66e5feb58315bccce0542e1d7d6f56a
--- /dev/null
+++ b/core/privc/crypto.h
@@ -0,0 +1,154 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <array>
+#include <cstring>
+#include <algorithm>
+
+#include <openssl/ec.h>
+#include <openssl/evp.h>
+#include <openssl/sha.h>
+
+#include "../psi/aes.h"
+
+namespace psi {
+
+typedef unsigned char u8;
+typedef unsigned long long u64;
+const block ZeroBlock = _mm_set_epi64x(0, 0);
+const block OneBlock = _mm_set_epi64x(-1, -1);
+
+const int CURVE_ID = NID_secp160k1;
+
+const int POINT_BUFFER_LEN = 21;
+// only applies to a 160-bit curve
+// for the specification of the point buffer length, see
+// http://www.secg.org/sec1-v2.pdf chapter 2.2.3
+
+const int HASH_DIGEST_LEN = SHA_DIGEST_LENGTH;
+
+const int GCM_IV_LEN = 12;
+
+const int GCM_TAG_LEN = 16;
+
+u8 *hash(const void *d, u64 n, void *md);
+
+static block double_block(block bl);
+
+static inline block hash_block(const block& x, const block& i = ZeroBlock) {
+    static AES pi(ZeroBlock);
+    block k = double_block(x) ^ i;
+    return pi.ecb_enc_block(k) ^ k;
+}
+
+static inline std::pair<block, block> hash_blocks(const std::pair<block, block>& x,
+                                                  const std::pair<block, block>& i = {ZeroBlock, ZeroBlock}) {
+    static AES pi(ZeroBlock);
+    block k[2] = {double_block(x.first) ^ i.first, double_block(x.second) ^ i.second};
+    block c[2];
+    pi.ecb_enc_blocks(k, 2, c);
+    return {c[0] ^ k[0], c[1] ^ k[1]};
+}
+
+template <typename T>
+static inline block to_block(const T& val) {
+    block ret = ZeroBlock;
+    std::memcpy(&ret, &val, std::min(sizeof ret, sizeof val));
+    return ret;
+}
+
+// ciphertext = iv || aes_ciphertext || gcm_tag
+// allocate buffer before call
+int encrypt(const unsigned char *plaintext, int plaintext_len,
+            const unsigned char *key, const unsigned char *iv,
+            unsigned char *ciphertext);
+
+int decrypt(const unsigned char *ciphertext, int ciphertext_len,
+            const unsigned char *key, unsigned char *plaintext);
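A usage note on the layout comment above: since the ciphertext is iv || aes_ciphertext || gcm_tag, the caller must allocate GCM_IV_LEN + plaintext_len + GCM_TAG_LEN bytes before calling encrypt. A minimal caller sketch (the key and IV values here are demo placeholders, not part of the patch):

    #include <string>
    #include <vector>
    #include "crypto.h"

    int gcm_roundtrip_demo() {
        std::string msg = "hello";
        unsigned char key[16] = {0};              // demo AES-128 key
        unsigned char iv[psi::GCM_IV_LEN] = {0};  // demo 12-byte GCM IV

        // output layout: iv || aes_ciphertext || gcm_tag
        std::vector<unsigned char> ct(psi::GCM_IV_LEN + msg.size() + psi::GCM_TAG_LEN);
        int ct_len = psi::encrypt(
            reinterpret_cast<const unsigned char *>(msg.data()), msg.size(),
            key, iv, ct.data());                  // returns 0 on OpenSSL error

        std::vector<unsigned char> pt(msg.size());
        int pt_len = psi::decrypt(ct.data(), ct_len, key, pt.data());  // -1 on error
        return pt_len == static_cast<int>(msg.size()) ? 0 : 1;
    }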
+
+class ECDH {
+private:
+    EC_GROUP *_group;
+    EC_KEY *_key;
+    EC_POINT *_remote_key;
+    bool _error;
+
+public:
+    ECDH();
+    ~ECDH();
+
+    inline bool error() { return _error; }
+
+    ECDH(const ECDH &other) = delete;
+    ECDH operator=(const ECDH &other) = delete;
+
+    std::array<u8, POINT_BUFFER_LEN> generate_key();
+
+    std::array<u8, POINT_BUFFER_LEN - 1>
+    get_shared_secret(const std::array<u8, POINT_BUFFER_LEN> &remote_key);
+};
+
+/*
+    This file is part of JustGarble.
+    JustGarble is free software: you can redistribute it and/or modify
+    it under the terms of the GNU General Public License as published by
+    the Free Software Foundation, either version 3 of the License, or
+    (at your option) any later version.
+    JustGarble is distributed in the hope that it will be useful,
+    but WITHOUT ANY WARRANTY; without even the implied warranty of
+    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+    GNU General Public License for more details.
+    You should have received a copy of the GNU General Public License
+    along with JustGarble.  If not, see <http://www.gnu.org/licenses/>.
+ */
+
+
+/*------------------------------------------------------------------------
+ / OCB Version 3 Reference Code (Optimized C)     Last modified 08-SEP-2012
+ /-------------------------------------------------------------------------
+ / Copyright (c) 2012 Ted Krovetz.
+ /
+ / Permission to use, copy, modify, and/or distribute this software for any
+ / purpose with or without fee is hereby granted, provided that the above
+ / copyright notice and this permission notice appear in all copies.
+ /
+ / THE SOFTWARE IS PROVIDED "AS IS" AND THE AUTHOR DISCLAIMS ALL WARRANTIES
+ / WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
+ / MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL THE AUTHOR BE LIABLE FOR
+ / ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
+ / WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
+ / ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT OF
+ / OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.
+ /
+ / Phillip Rogaway holds patents relevant to OCB. See the following for
+ / his patent grant: http://www.cs.ucdavis.edu/~rogaway/ocb/grant.htm
+ /
+ / Special thanks to Keegan McAllister for suggesting several good improvements
+ /
+ / Comments are welcome: Ted Krovetz <ted@krovetz.net> - Dedicated to Laurel K
+ /------------------------------------------------------------------------- */
+
+static inline block double_block(block bl) {
+    const __m128i mask = _mm_set_epi32(135,1,1,1);
+    __m128i tmp = _mm_srai_epi32(bl, 31);
+    tmp = _mm_and_si128(tmp, mask);
+    tmp = _mm_shuffle_epi32(tmp, _MM_SHUFFLE(2,1,0,3));
+    bl = _mm_slli_epi32(bl, 1);
+    return _mm_xor_si128(bl,tmp);
+}
+
+} // namespace psi
+
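double_block is the standard doubling in GF(2^128) used by OCB-style constructions: shift the 128-bit block left by one and, if the top bit was set, XOR in the reduction constant 0x87. The SSE version above does this lane-wise with a shuffle to carry bits between lanes; a scalar sketch of the same operation on a pair of 64-bit words (illustrative, not part of the patch):

    #include <cstdint>

    // GF(2^128) doubling on (hi, lo); the scalar analogue of double_block.
    void gf128_double(uint64_t &hi, uint64_t &lo) {
        uint64_t carry = hi >> 63;      // top bit decides whether to reduce
        hi = (hi << 1) | (lo >> 63);    // 128-bit left shift by one
        lo = lo << 1;
        lo ^= carry * 0x87;             // reduction by x^128 = x^7 + x^2 + x + 1
    }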
diff --git a/core/privc/crypto_test.cc b/core/privc/crypto_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d76f0fb740451fd2d6f4d3104566629be63cc4aa
--- /dev/null
+++ b/core/privc/crypto_test.cc
@@ -0,0 +1,100 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "crypto.h"
+
+#include <string>
+#include <cstring>
+
+#include "gtest/gtest.h"
+
+namespace psi {
+
+TEST(crypto, hash) {
+    std::string input = "abc";
+    char output[HASH_DIGEST_LEN + 1];
+
+    output[HASH_DIGEST_LEN] = '\0';
+
+    const char *standard_vec =
+        "\xa9\x99\x3e\x36\x47\x06\x81\x6a\xba\x3e"
+        "\x25\x71\x78\x50\xc2\x6c\x9c\xd0\xd8\x9d";
+
+    hash(input.data(), input.size(), output);
+
+    EXPECT_STREQ(standard_vec, output);
+
+}
+
+TEST(crypto, enc) {
+    std::string input = "abc";
+    std::string iv = "0123456789ab";
+    std::string key = "0123456789abcdef"; // aes_128_gcm, key_len = 128 bit
+
+    unsigned int cipher_len = GCM_IV_LEN + GCM_TAG_LEN + input.size();
+    auto *output = new unsigned char[cipher_len];
+
+    int enc_ret = encrypt((unsigned char *)input.data(), input.size(),
+                          (unsigned char *)key.data(),
+                          (unsigned char *)iv.data(), output);
+
+    ASSERT_EQ(cipher_len, (size_t)enc_ret);
+
+    char *plaintext = new char[input.size() + 1];
+    plaintext[input.size()] = '\0';
+    int dec_ret = decrypt(output, enc_ret, (unsigned char *)key.data(),
+                          (unsigned char *)plaintext);
+
+    ASSERT_EQ(input.size(), (size_t)dec_ret);
+
+    EXPECT_STREQ(input.c_str(), plaintext);
+
+    delete[] output;
+    delete[] plaintext;
+}
+
+
+TEST(crypto, ecdh) {
+    ECDH alice;
+    ECDH bob;
+
+    auto ga = alice.generate_key();
+    auto gb = bob.generate_key();
+
+    auto ka = alice.get_shared_secret(gb);
+    auto kb = bob.get_shared_secret(ga);
+
+    ASSERT_EQ(ka.size(), kb.size());
+    EXPECT_TRUE(0 == std::memcmp(ka.data(), kb.data(), ka.size()));
+}
+
+TEST(crypto, hash_block) {
+
+    block in = ZeroBlock;
+
+    for (size_t i = 0; i < 1e6; ++i) {
+        hash_block(in);
+    }
+}
+
+TEST(crypto, hash_blocks) {
+
+    block in = ZeroBlock;
+
+    for (size_t i = 0; i < 1e6; ++i) {
+        hash_blocks({in, in});
+    }
+}
+
+}  // namespace psi
diff --git a/core/privc/fixedpoint_tensor.h b/core/privc/fixedpoint_tensor.h
new file mode 100644
index 0000000000000000000000000000000000000000..29dab9efc9a257f9e0dc01192404e62104334c43
--- /dev/null
+++ b/core/privc/fixedpoint_tensor.h
@@ -0,0 +1,205 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <algorithm>
+
+#include "privc_context.h"
+#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
+#include "../privc3/paddle_tensor.h"
+#include "./triplet_generator.h"
+
+namespace privc {
+
+template<typename T>
+using TensorAdapter = aby3::TensorAdapter<T>;
+using TensorAdapterFactory = aby3::TensorAdapterFactory;
+
+template<size_t N>
+inline void fixed64_tensor_mult(const TensorAdapter<int64_t>* lhs,
+                                const TensorAdapter<int64_t>* rhs,
+                                TensorAdapter<int64_t>* ret) {
+    std::transform(lhs->data(), lhs->data() + lhs->numel(),
+                   rhs->data(), ret->data(),
+                   [] (const int64_t& lhs, const int64_t& rhs) -> int64_t {
+                       return fixed64_mult<N>(lhs, rhs);
+                   });
+}
+
+template<typename T, size_t N>
+class FixedPointTensor {
+
+public:
+    explicit FixedPointTensor(TensorAdapter<T>* share_tensor);
+
+    ~FixedPointTensor() {};
+
+    template<typename T_>
+    class Type2Type {
+        typedef T_ type;
+    };
+
+    // get mutable share of tensor
+    TensorAdapter<T>* mutable_share();
+
+    const TensorAdapter<T>* share() const;
+
+    size_t numel() const {
+        return _share->numel();
+    }
+
+    // reveal fixedpointtensor to one party
+    void reveal_to_one(size_t party, TensorAdapter<T>* ret) const;
+
+    // reveal fixedpointtensor to all parties
+    void reveal(TensorAdapter<T>* ret) const;
+
+    const std::vector<size_t> shape() const;
+
+    // convert TensorAdapter to shares
+    static void share(const TensorAdapter<T>* input,
+                      TensorAdapter<T>* output_shares[2],
+                      block seed = psi::g_zero_block);
+
+    // element-wise add with FixedPointTensor
+    void add(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
+
+    // element-wise add with TensorAdapter
+    void add(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
+
+    // element-wise sub with FixedPointTensor
+    void sub(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
+
+    // element-wise sub with TensorAdapter
+    void sub(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
+
+    // negative
+    void negative(FixedPointTensor* ret) const;
+
+    // element-wise mul with FixedPointTensor using truncate1
+    void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const {
+        mul_impl<T>(rhs, ret, Type2Type<T>());
+    }
+
+    template<typename T_>
+    void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
+        PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
+    }
+
+    template<typename T_>
+    void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
+
+    // element-wise mul with TensorAdapter
+    void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
+
+    // div by TensorAdapter
+    void div(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
+
+    // sum all elements
+    void sum(FixedPointTensor* ret) const;
+
+    // mat_mul with FixedPointTensor
+    void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
+
+    // mat_mul with TensorAdapter
+    void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
+
+    // exp approximate: exp(x) = \lim_{n->inf} (1+x/n)^n
+    // where n = 2^iter
+    // void exp(FixedPointTensor* ret, size_t iter = 8) const;
+
+    // element-wise relu
+    void relu(FixedPointTensor* ret) const;
+
+    // element-wise relu with relu'
+    // void relu_with_derivative(FixedPointTensor* ret, BooleanTensor<T>* derivative) const;
+
+    // element-wise sigmoid using 3 piecewise polynomials
+    void sigmoid(FixedPointTensor* ret) const;
+
+    // softmax axis = -1
+    // void softmax(FixedPointTensor* ret) const;
+
+    // argmax
+    void argmax(FixedPointTensor* ret) const;
+
+    // element-wise compare
+    // <
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void lt(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
+
+    // <=
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void leq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
+
+    // >
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void gt(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
+
+    // >=
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void geq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
+
+    // ==
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void eq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
+
+    // !=
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void neq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
+
+    // element-wise max
+    // if not null, cmp stores true if rhs is bigger
+    template<template<typename U, size_t...> class CTensor,
+             size_t... N1>
+    void max(const CTensor<T, N1...>* rhs,
+             FixedPointTensor* ret,
+             CTensor<T, N1...>* cmp = nullptr) const;
+
+private:
+    static inline std::shared_ptr<AbstractContext> privc_ctx() {
+        return paddle::mpc::ContextHolder::mpc_ctx();
+    }
+
+    static inline std::shared_ptr<TensorAdapterFactory> tensor_factory() {
+        return paddle::mpc::ContextHolder::tensor_factory();
+    }
+
+    static inline std::shared_ptr<TripletGenerator<T, N>> tripletor() {
+        return std::dynamic_pointer_cast<PrivCContext>(privc_ctx())->triplet_generator();
+    }
+
+    static size_t party() {
+        return privc_ctx()->party();
+    }
+
+    static size_t next_party() {
+        return privc_ctx()->next_party();
+    }
+
+    TensorAdapter<T>* _share;
+
+};
+
+} //namespace privc
+
+#include "fixedpoint_tensor_imp.h"
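FixedPointTensor<int64_t, N> stores real numbers as 64-bit integers scaled by 2^N, so the product of two encodings carries a factor of 2^2N and must be shifted back down by N bits. A scalar sketch of the encoding and of the widened multiply that fixed64_mult is expected to perform (assuming a 128-bit intermediate; the helper names here are illustrative, not from the patch):

    #include <cstdint>

    constexpr unsigned kN = 32;  // scaling factor, as in Fix64N32

    int64_t encode(double x) { return static_cast<int64_t>(x * (1ULL << kN)); }
    double  decode(int64_t v) { return static_cast<double>(v) / (1ULL << kN); }

    // widened multiply-then-truncate, the scalar analogue of fixed64_mult
    int64_t fixed_mult(int64_t a, int64_t b) {
        return static_cast<int64_t>((static_cast<__int128>(a) * b) >> kN);
    }
    // e.g. decode(fixed_mult(encode(1.5), encode(2.0))) == 3.0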
diff --git a/core/privc/fixedpoint_tensor_imp.h b/core/privc/fixedpoint_tensor_imp.h
new file mode 100644
index 0000000000000000000000000000000000000000..1bf282af2993bafe0eba18078423f5b64e71b304
--- /dev/null
+++ b/core/privc/fixedpoint_tensor_imp.h
@@ -0,0 +1,1313 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <vector>
+
+#include "paddle/fluid/platform/enforce.h"
+#include "../privc3/prng.h"
+
+namespace privc {
+
+template<typename T, size_t N>
+FixedPointTensor<T, N>::FixedPointTensor(TensorAdapter<T>* share_tensor) {
+    _share = share_tensor;
+}
+
+template<typename T, size_t N>
+TensorAdapter<T>* FixedPointTensor<T, N>::mutable_share() {
+    return _share;
+}
+
+template<typename T, size_t N>
+const TensorAdapter<T>* FixedPointTensor<T, N>::share() const {
+    return _share;
+}
+
+// reveal fixedpointtensor to one party
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::reveal_to_one(size_t party,
+                                           TensorAdapter<T>* ret) const {
+
+    if (party == this->party()) {
+        auto buffer = tensor_factory()->template create<T>(ret->shape());
+        privc_ctx()->network()->template recv<T>(next_party(), *buffer);
+
+        share()->add(buffer.get(), ret);
+        ret->scaling_factor() = N;
+    } else {
+        privc_ctx()->network()->template send<T>(party, *share());
+    }
+}
+
+// reveal fixedpointtensor to all parties
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::reveal(TensorAdapter<T>* ret) const {
+    for (size_t i = 0; i < 2; ++i) {
+        reveal_to_one(i, ret);
+    }
+}
+
+template<typename T, size_t N>
+const std::vector<size_t> FixedPointTensor<T, N>::shape() const {
+    return _share->shape();
+}
+
+// convert TensorAdapter to shares
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::share(const TensorAdapter<T>* input,
+                                   TensorAdapter<T>* output_shares[2],
+                                   block seed) {
+
+    if (psi::equals(seed, psi::g_zero_block)) {
+        seed = psi::block_from_dev_urandom();
+    }
+    // set seed of prng[2]
+    privc_ctx()->set_random_seed(seed, 2);
+
+    privc_ctx()->template gen_random_private<T>(*output_shares[0]);
+
+    input->sub(output_shares[0], output_shares[1]);
+    for (int i = 0; i < 2; ++i) {
+        output_shares[i]->scaling_factor() = input->scaling_factor();
+    }
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::add(const FixedPointTensor<T, N>* rhs,
+                                 FixedPointTensor<T, N>* ret) const {
+    _share->add(rhs->_share, ret->_share);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::add(const TensorAdapter<T>* rhs,
+                                 FixedPointTensor<T, N>* ret) const {
+    PADDLE_ENFORCE_EQ(N, rhs->scaling_factor(),
+                      "no match scaling factor");
+    if (party() == 0) {
+        _share->add(rhs, ret->_share);
+    } else {
+        _share->copy(ret->_share);
+    }
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::sub(const FixedPointTensor<T, N>* rhs,
+                                 FixedPointTensor<T, N>* ret) const {
+    _share->sub(rhs->_share, ret->_share);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::sub(const TensorAdapter<T>* rhs,
+                                 FixedPointTensor<T, N>* ret) const {
+    PADDLE_ENFORCE_EQ(N, rhs->scaling_factor(),
+                      "no match scaling factor");
+    if (party() == 0) {
+        _share->sub(rhs, ret->_share);
+    } else {
+        _share->copy(ret->_share);
+    }
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::negative(FixedPointTensor<T, N>* ret) const {
+    _share->negative(ret->_share);
+}
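In this 2-party scheme a value x is held as x = x0 + x1 (mod 2^64): share() samples x0 pseudorandomly and sets x1 = x - x0, reveal() just exchanges and re-adds the shares, and adding a public constant only touches party 0's share. A scalar sketch of those invariants (plain integers standing in for tensors; illustrative only):

    #include <cstdint>
    #include <random>

    int main() {
        std::mt19937_64 prng(42);            // stand-in for the seeded PRNG
        int64_t x = 123456;                  // plaintext, already fixed-point encoded

        // share: x0 random, x1 = x - x0; wrap-around gives arithmetic mod 2^64
        int64_t x0 = static_cast<int64_t>(prng());
        int64_t x1 = x - x0;

        // reveal: add the shares back together
        int64_t revealed = x0 + x1;          // == x (mod 2^64)

        // adding a public constant c: only party 0 adjusts its share
        int64_t c = 42;
        x0 += c;
        return (x0 + x1 == x + c && revealed == x) ? 0 : 1;
    }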
+
+template<typename T, size_t N>
+template<typename T_>
+void FixedPointTensor<T, N>::mul_impl(const FixedPointTensor<T, N>* rhs,
+                                      FixedPointTensor<T, N>* ret,
+                                      const Type2Type<int64_t>) const {
+    auto triplet_shape = shape();
+    triplet_shape.insert(triplet_shape.begin(), 3);
+    auto triplet = tensor_factory()->template create<T>(triplet_shape);
+    tripletor()->get_triplet(triplet.get());
+
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 8; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(ret->shape()));
+    }
+    FixedPointTensor<T, N> a(temp[0].get());
+    FixedPointTensor<T, N> b(temp[1].get());
+    FixedPointTensor<T, N> c(temp[2].get());
+    auto parse_triplet = [&triplet](int idx, FixedPointTensor<T, N>& ret) {
+        triplet->slice(idx, idx + 1, ret.mutable_share());
+        auto shape = ret.shape();
+        shape.erase(shape.begin());
+        ret.mutable_share()->reshape(shape);
+    };
+
+    parse_triplet(0, a);
+    parse_triplet(1, b);
+    parse_triplet(2, c);
+
+    FixedPointTensor<T, N> e(temp[3].get());
+    FixedPointTensor<T, N> f(temp[4].get());
+    this->sub(&a, &e);
+    rhs->sub(&b, &f);
+
+    auto& reveal_e = temp[5];
+    auto& reveal_f = temp[6];
+
+    e.reveal(reveal_e.get());
+    f.reveal(reveal_f.get());
+
+    FixedPointTensor<T, N> ft_temp(temp[7].get());
+    fixed64_tensor_mult<N>(reveal_f.get(), a.share(), ret->mutable_share());
+    fixed64_tensor_mult<N>(reveal_e.get(), b.share(), ft_temp.mutable_share());
+
+    ret->add(&ft_temp, ret);
+    ret->add(&c, ret);
+
+    if (party() == 1) {
+        auto& ef = temp[7];
+        ef->scaling_factor() = N;
+        fixed64_tensor_mult<N>(reveal_e.get(), reveal_f.get(), ef.get());
+        ret->share()->add(ef.get(), ret->mutable_share());
+    }
+}
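mul_impl is the classic Beaver-triple multiplication: given a shared random triple (a, b, c) with c = a*b, the parties open e = x - a and f = y - b and assemble shares of x*y as f*a + e*b + c, with one party adding the public product e*f. A scalar sketch with plain integers (fixed-point truncation omitted for clarity; illustrative only):

    #include <cstdint>
    #include <random>

    int main() {
        std::mt19937_64 prng(7);
        auto rnd = [&]() { return static_cast<int64_t>(prng()); };

        int64_t x = 6, y = 7;                       // the two secrets
        // triple: c = a * b, everything additively shared mod 2^64
        int64_t a = rnd(), b = rnd(), c = a * b;
        int64_t a0 = rnd(), a1 = a - a0;
        int64_t b0 = rnd(), b1 = b - b0;
        int64_t c0 = rnd(), c1 = c - c0;
        int64_t x0 = rnd(), x1 = x - x0;
        int64_t y0 = rnd(), y1 = y - y0;

        // both parties open e = x - a and f = y - b
        int64_t e = (x0 - a0) + (x1 - a1);
        int64_t f = (y0 - b0) + (y1 - b1);

        // each party: z_i = f*a_i + e*b_i + c_i; one party adds public e*f
        int64_t z0 = f * a0 + e * b0 + c0;
        int64_t z1 = f * a1 + e * b1 + c1 + e * f;
        return (z0 + z1 == x * y) ? 0 : 1;          // shares reconstruct x*y
    }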
+
+/*
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
+                                      FixedPointTensor<T, N>* ret,
+                                      size_t scaling_factor) {
+    if (scaling_factor == 0) {
+        op->share(0)->copy(ret->mutable_share(0));
+        op->share(1)->copy(ret->mutable_share(1));
+    }
+    // implement ABY3's truncate1 algorithm
+    if (party() == 0) {
+        // party0
+        op->_share[0]->rshift(scaling_factor, ret->_share[0]);
+        privc_ctx()->network()->template recv<T>(1, *(ret->_share[1]));
+
+    } else if (party() == 1) {
+        // party1
+        auto r_12 = tensor_factory()->template create<T>(op->shape());
+        privc_ctx()->template gen_random<T>(*r_12.get(), true);
+
+        op->_share[0]->add(op->_share[1], ret->_share[0]);
+        // trunc from [SecureML, Thm.1]
+        ret->_share[0]->negative(ret->_share[0]);
+        ret->_share[0]->rshift(scaling_factor, ret->_share[0]);
+        ret->_share[0]->negative(ret->_share[0]);
+        ret->_share[0]->sub(r_12.get(), ret->_share[0]);
+
+        privc_ctx()->network()->template send<T>(0, *(ret->_share[0]));
+        r_12->copy(ret->_share[1]);
+
+    } else {
+        // party2
+        op->_share[1]->rshift(scaling_factor, ret->_share[1]);
+
+        auto r_21 = tensor_factory()->template create<T>(op->shape());
+        privc_ctx()->template gen_random<T>(*r_21.get(), false);
+
+        r_21->copy(ret->_share[0]);
+    }
+
+    return;
+}
+
+// Protocol. `truncate3`
+// P2 randomly generates r' \in (-2^62, 2^62), randomly generates r'_0, r_0, r_1 in Z_{2^64},
+// P2 computes r'_1 = r' - r'_0, r_2 = r'/2^N - r_0 - r_1, lets x2 = r_2
+// P2 sends r_0, r'_0 to P0, sends r_1, r'_1 to P1
+// P1 and P0 execute "reveal x - r' to P1"
+// P1 computes x1 = (x - r') / 2^N + r_1
+// P0 sets x0 = r_0
+// P0, P1, P2 invoke reshare() with inputs x0, x1, x2 respectively.
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::truncate3(const FixedPointTensor<T, N>* op,
+                                       FixedPointTensor<T, N>* ret,
+                                       size_t scaling_factor) {
+    if (scaling_factor == 0) {
+        op->share(0)->copy(ret->mutable_share(0));
+        op->share(1)->copy(ret->mutable_share(1));
+        return;
+    }
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    if (party() == 2) {
+        for (int i = 0; i < 7; ++i) {
+            temp.emplace_back(
+                tensor_factory()->template create<T>(op->shape()));
+        }
+        // r', constrained in (-2^62, 2^62)
+        // notice: when r' is constrained in (-2^62, 2^62),
+        // the SD (statistical distance) of x - r' between this
+        // and r' in Z_{2^64} is equal to |X| / (2^63 + |X|)
+        // according to http://yuyu.hk/files/ho2.pdf
+        privc_ctx()->template gen_random_private<T>(*temp[0]);
+        int64_t contraint_upper = ~((uint64_t) 1 << 62);
+        int64_t contraint_low = (uint64_t) 1 << 62;
+        std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(),
+                      [&contraint_upper, &contraint_low] (T& a) {
+                          // constrain -2^62 < a < 2^62
+                          if (a >= 0) {
+                              a &= contraint_upper;
+                          } else {
+                              a |= contraint_low;
+                          }
+                      });
+
+        // r'_0, r'_1
+        privc_ctx()->template gen_random_private<T>(*temp[1]);
+        temp[0]->sub(temp[1].get(), temp[2].get());
+        // r, r_0, r_1
+        temp[0]->rshift(scaling_factor, temp[3].get());
+        privc_ctx()->template gen_random_private<T>(*temp[4]);
+        privc_ctx()->template gen_random_private<T>(*temp[5]);
+        // r_2
+        temp[3]->sub(temp[4].get(), temp[6].get());
+        temp[6]->sub(temp[5].get(), temp[6].get());
+
+        privc_ctx()->network()->template send<T>(1, *temp[2]);
+        privc_ctx()->network()->template send<T>(1, *temp[5]);
+        privc_ctx()->network()->template send<T>(0, *temp[1]);
+        privc_ctx()->network()->template send<T>(0, *temp[4]);
+
+        temp[6]->copy(ret->mutable_share(0));
+
+    } else if (party() == 1) {
+        for (int i = 0; i < 4; ++i) {
+            temp.emplace_back(
+                tensor_factory()->template create<T>(op->shape()));
+        }
+        // r'_1, r_1
+        privc_ctx()->network()->template recv<T>(2, *temp[0]);
+        privc_ctx()->network()->template recv<T>(2, *temp[1]);
+        // recv x0 - r'_0 from party 0
+        privc_ctx()->network()->template recv<T>(0, *temp[2]);
+        // reveal x - r' to party 1
+        op->share(0)->add(op->share(1), temp[3].get());
+        temp[3]->add(temp[2].get(), temp[3].get());
+        temp[3]->sub(temp[0].get(), temp[3].get());
+        // truncate x - r'
+        temp[3]->rshift(scaling_factor, temp[3].get());
+
+        temp[3]->add(temp[1].get(), ret->mutable_share(0));
+    } else {
+        for (int i = 0; i < 3; ++i) {
+            temp.emplace_back(
+                tensor_factory()->template create<T>(op->shape()));
+        }
+        // r'_0, r_0
+        privc_ctx()->network()->template recv<T>(2, *temp[0]);
+        privc_ctx()->network()->template recv<T>(2, *temp[1]);
+        // send x0 - r'_0 to party 1
+        op->share(0)->sub(temp[0].get(), temp[2].get());
+        privc_ctx()->network()->template send<T>(1, *temp[2]);
+        temp[1]->copy(ret->mutable_share(0));
+    }
+
+    reshare(ret->share(0), ret->mutable_share(1));
+
+    // compensation for carry in
+    auto tensor_carry_in = tensor_factory()->template create<T>(ret->shape());
+    assign_to_tensor(tensor_carry_in.get(), (T)1);
+    tensor_carry_in->scaling_factor() = N;
+    ret->add(tensor_carry_in.get(), ret);
+}
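The block above (kept commented out in this patch) carries over ABY3-style truncation. The core trick in its 2-party form, cited in truncate as [SecureML, Thm. 1], is that each party may right-shift its own share locally at the cost of a small additive error; a scalar illustration (assuming arithmetic right shift on int64_t; the error bound fails only with probability on the order of |x|/2^63):

    #include <cstdint>
    #include <random>

    int main() {
        constexpr unsigned kN = 32;
        std::mt19937_64 prng(1);
        int64_t x = int64_t(3) << kN;             // 3.0 in fixed point
        int64_t x0 = static_cast<int64_t>(prng());
        int64_t x1 = x - x0;

        int64_t t0 = x0 >> kN;                    // party 0 shifts its share
        int64_t t1 = -((-x1) >> kN);              // party 1: negate, shift, negate
        int64_t diff = (t0 + t1) - (x >> kN);
        return (diff >= -1 && diff <= 1) ? 0 : 1; // off by at most one ULP
    }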
+
+template<typename T, size_t N>
+template<typename MulFunc>
+void FixedPointTensor<T, N>::mul_trunc(const FixedPointTensor<T, N>* lhs,
+                                       const FixedPointTensor<T, N>* rhs,
+                                       FixedPointTensor<T, N>* ret,
+                                       MulFunc mul_func) {
+
+    auto r_zero = tensor_factory()->template create<T>(ret->shape());
+    privc_ctx()->gen_zero_sharing_arithmetic(*r_zero.get());
+
+    // temp = _share[0]->mul(rhs->_share[0]) +
+    //        _share[0]->mul(rhs->_share[1]) +
+    //        _share[1]->mul(rhs->_share[0]) +
+    //        r_zero
+    auto temp = tensor_factory()->template create<T>(ret->shape());
+    auto temp1 = tensor_factory()->template create<T>(ret->shape());
+
+    // use mul_func to fit both element_wise mul and mat mul
+    (lhs->share(0)->*mul_func)(rhs->share(0), temp.get());
+    (lhs->share(0)->*mul_func)(rhs->share(1), temp1.get());
+    temp1->add(temp.get(), temp1.get());
+
+    (lhs->share(1)->*mul_func)(rhs->share(0), temp.get());
+    temp1->add(r_zero.get(), temp1.get());
+    temp->add(temp1.get(), temp.get());
+
+    auto temp2 = tensor_factory()->template create<T>(ret->shape());
+    auto temp3 = tensor_factory()->template create<T>(ret->shape());
+
+    TensorAdapter<T>* temp_array[2] = {temp2.get(), temp3.get()};
+
+    std::shared_ptr<FixedPointTensor<T, N>> ret_no_trunc =
+        std::make_shared<FixedPointTensor<T, N>>(temp_array);
+
+    temp->copy(ret_no_trunc->_share[0]);
+    reshare(temp.get(), ret_no_trunc->_share[1]);
+
+    truncate3(ret_no_trunc.get(), ret, N);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::mul(const TensorAdapter<T>* rhs,
+                                 FixedPointTensor<T, N>* ret) const {
+    // PADDLE_ENFORCE_EQ(N, rhs->scaling_factor(),
+    //                   "no match scaling factor");
+    auto temp0 = tensor_factory()->template create<T>(this->shape());
+    auto temp1 = tensor_factory()->template create<T>(this->shape());
+    std::shared_ptr<FixedPointTensor<T, N>> temp =
+        std::make_shared<FixedPointTensor<T, N>>(temp0.get(), temp1.get());
+
+    _share[0]->mul(rhs, temp->_share[0]);
+    _share[1]->mul(rhs, temp->_share[1]);
+    truncate3(temp.get(), ret, rhs->scaling_factor());
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::sum(FixedPointTensor<T, N>* ret) const {
+    PADDLE_ENFORCE_EQ(ret->numel(), 1, "output size should be 1.");
+    T sum1 = (T) 0;
+    T sum2 = (T) 0;
+    T* iter_0 = _share[0]->data();
+    T* iter_1 = _share[1]->data();
+    for (int i = 0; i < this->numel(); ++i) {
+        sum1 += *(iter_0 + i);
+        sum2 += *(iter_1 + i);
+    }
+    assign_to_tensor(ret->_share[0], sum1);
+    assign_to_tensor(ret->_share[1], sum2);
+}
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::dot_mul(const CTensor<T, N1...>* rhs,
+                                     FixedPointTensor<T, N>* ret) const {
+    PADDLE_ENFORCE_EQ(ret->numel(), 1, "output size should be 1.");
+
+    auto temp0 = tensor_factory()->template create<T>(this->shape());
+    auto temp1 = tensor_factory()->template create<T>(this->shape());
+    std::shared_ptr<FixedPointTensor<T, N>> temp =
+        std::make_shared<FixedPointTensor<T, N>>(temp0.get(), temp1.get());
+    this->mul(rhs, temp.get());
+    temp->sum(ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::mat_mul(const FixedPointTensor<T, N>* rhs,
+                                     FixedPointTensor<T, N>* ret) const {
+    mul_trunc(this, rhs, ret, &TensorAdapter<T>::mat_mul);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::mat_mul(const TensorAdapter<T>* rhs,
+                                     FixedPointTensor<T, N>* ret) const {
+    _share[0]->mat_mul(rhs, ret->_share[0]);
+    _share[1]->mat_mul(rhs, ret->_share[1]);
+    truncate3(ret, ret, rhs->scaling_factor());
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::div(const TensorAdapter<T>* rhs,
+                                 FixedPointTensor<T, N>* ret) const {
+    PADDLE_ENFORCE_EQ(N, rhs->scaling_factor(),
+                      "no match scaling factor");
+
+    auto temp = tensor_factory()->template create<T>(this->shape());
+
+    double scale = std::pow(2, rhs->scaling_factor());
+    auto inverse = [scale](T d) -> T {
+        return 1.0 * scale / d * scale; };
+    std::transform(rhs->data(), rhs->data() + rhs->numel(),
+                   temp->data(), inverse);
+    temp->scaling_factor() = rhs->scaling_factor();
+
+    this->mul(temp.get(), ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::div(const FixedPointTensor<T, N>* rhs,
+                                 FixedPointTensor<T, N>* ret,
+                                 size_t iter, double x0) const {
+    auto temp0 = tensor_factory()->template create<T>(ret->shape());
+    auto temp1 = tensor_factory()->template create<T>(ret->shape());
+    std::shared_ptr<FixedPointTensor<T, N>> temp =
+        std::make_shared<FixedPointTensor<T, N>>(temp0.get(), temp1.get());
+    reciprocal(rhs, temp.get(), iter, x0);
+    this->mul(temp.get(), ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::exp(FixedPointTensor<T, N>* ret,
+                                 size_t iter) const {
+    // exp approximate: exp(x) = \lim_{n->inf} (1+x/n)^n
+    // where n = 2^iter
+    auto pow_iter = tensor_factory()->template create<T>(this->shape());
+    assign_to_tensor(pow_iter.get(), (T) (pow(2, N - iter)));
+    pow_iter->scaling_factor() = N;
+
+    auto tensor_one = tensor_factory()->template create<T>(this->shape());
+    assign_to_tensor(tensor_one.get(), (T) 1 << N);
+    tensor_one->scaling_factor() = N;
+
+    this->mul(pow_iter.get(), ret);
+
+    ret->add(tensor_one.get(), ret);
+
+    for (int i = 0; i < iter; ++i) {
+        ret->mul(ret, ret);
+    }
+}
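The exp approximation computes 1 + x/2^k once and then squares k times, which is exactly the limit formula in the comment with n = 2^k. A plain-double sketch of the same scheme (illustrative; the tensor version additionally keeps everything in fixed point):

    #include <cmath>

    // exp(x) ~= (1 + x/2^k)^(2^k): one multiply to seed, then k squarings.
    double exp_approx(double x, unsigned k) {
        double y = 1.0 + x / std::pow(2.0, k);  // 1 + x/n with n = 2^k
        for (unsigned i = 0; i < k; ++i) {
            y *= y;                             // repeated squaring
        }
        return y;                               // exp_approx(1.0, 8) ~= 2.713 vs e ~= 2.71828
    }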
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::relu(FixedPointTensor<T, N>* ret) const {
+    // utilize polynomial_piecewise
+    // break_point = {0}, coeff[0] = {0, 0}, coeff[1] = {0, 1}
+    // break_point.shape = {1, this->shape}, coeff.shape = {2, 2, this->shape}
+
+    auto shape_ = shape();
+    // construct break_point
+    auto b_shape = shape_;
+    b_shape.insert(b_shape.begin(), 1);
+
+    auto break_point = tensor_factory()->template create<T>(b_shape);
+
+    T* b_ptr = break_point->data();
+    for (size_t i = 0; i < break_point->numel(); ++i) {
+        b_ptr[i] = 0;
+    }
+    break_point->scaling_factor() = N;
+
+    // construct coeff
+    std::vector<size_t> c_shape = {2, 2};
+    c_shape.insert(c_shape.end(), shape_.begin(), shape_.end());
+
+    auto coeff = tensor_factory()->template create<T>(c_shape);
+
+    T* c_ptr = coeff->data();
+
+    for (size_t i = 0; i < 3 * this->numel(); ++i) {
+        c_ptr[i] = 0;
+    }
+    for (size_t i = 3 * this->numel(); i < 4 * this->numel(); ++i) {
+        c_ptr[i] = (T) 1 << N;
+    }
+    coeff->scaling_factor() = N;
+
+    this->polynomial_piecewise(coeff.get(), break_point.get(), ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::relu_with_derivative(
+        FixedPointTensor<T, N>* ret, BooleanTensor<T>* derivative) const {
+
+    auto shape_ = shape();
+    auto zero = tensor_factory()->template create<T>(shape_);
+
+    assign_to_tensor(zero.get(), (T)0);
+    zero->scaling_factor() = N;
+
+    auto tmp0 = tensor_factory()->template create<T>(shape_);
+    auto tmp1 = tensor_factory()->template create<T>(shape_);
+
+    BooleanTensor<T> der(tmp0.get(), tmp1.get());
+
+    gt(zero.get(), &der);
+
+    der.mul(this, ret);
+
+    if (derivative) {
+        der.share(0)->copy(derivative->share(0));
+        der.share(1)->copy(derivative->share(1));
+    }
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::sigmoid_chebyshev(FixedPointTensor<T, N>* ret) const {
+    // utilize Chebyshev polynomial approximation,
+    // more accurate in a small range such as [-4, 4]
+    auto shape = ret->shape();
+    std::vector<size_t> shape_ = shape;
+    shape_.insert(shape_.begin(), 10);
+    auto numel = ret->numel();
+    auto coeff = tensor_factory()->template create<T>(shape_);
+    std::vector<float> w;
+    w.resize(10, 0.0f);
+    w[0] = 0.5;
+    w[1] = 0.2159198015;
+    w[3] = -0.0082176259;
+    w[5] = 0.0001825597;
+    w[7] = -0.0000018848;
+    w[9] = 0.0000000072;
+    for (int i = 0; i < 10; ++i) {
+        for (int j = 0; j < numel; ++j) {
+            *(coeff->data() + i * numel + j) = (T) (w[i] * pow(2, N));
+        }
+    }
+    coeff->scaling_factor() = N;
+    polynomial(coeff.get(), ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::sigmoid(FixedPointTensor<T, N>* ret) const {
+    // utilize polynomial_piecewise
+    // break_point = {-2.5, 2.5}
+    // coeff[0] = {10^-4, 0}, coeff[1] = {0.5, 0.17}
+    // coeff[2] = {1 - 10^-4, 0}
+    // break_point.shape = {2, this->shape}, coeff.shape = {3, 2, this->shape}
+
+    auto shape_ = shape();
+    // construct break_point
+    auto b_shape = shape_;
+    b_shape.insert(b_shape.begin(), 2);
+
+    auto break_point = tensor_factory()->template create<T>(b_shape);
+
+    T* b_ptr = break_point->data();
+    for (size_t i = 0; i < break_point->numel(); ++i) {
+        b_ptr[i] = 0;
+    }
+    for (size_t i = 0; i < break_point->numel() / 2; ++i) {
+        b_ptr[i] = (T) (-2.5 * pow(2, N));
+    }
+    for (size_t i = break_point->numel() / 2; i < break_point->numel(); ++i) {
+        b_ptr[i] = (T) (2.5 * pow(2, N));
+    }
+    break_point->scaling_factor() = N;
+
+    // construct coeff
+    std::vector<size_t> c_shape = {3, 2};
+    c_shape.insert(c_shape.end(), shape_.begin(), shape_.end());
+
+    auto coeff = tensor_factory()->template create<T>(c_shape);
+
+    T* c_ptr = coeff->data();
+
+    size_t numel = this->numel();
+    double scale = std::pow(2, N);
+    for (size_t i = 0; i < numel; ++i) {
+        c_ptr[i] = 0.0001 * scale;
+        c_ptr[i + numel] = 0;
+        c_ptr[i + 2 * numel] = 0.5 * scale;
+        c_ptr[i + 3 * numel] = 0.17 * scale;
+        c_ptr[i + 4 * numel] = (1 - 0.0001) * scale;
+        c_ptr[i + 5 * numel] = 0;
+    }
+    coeff->scaling_factor() = N;
+
+    this->polynomial_piecewise(coeff.get(), break_point.get(), ret);
+}
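The plaintext function the 3-piece sigmoid above approximates, using the same break points (-2.5, 2.5) and the same per-piece coefficients that fill coeff (illustrative only):

    // Plaintext view of the piecewise sigmoid approximation.
    double sigmoid_piecewise(double x) {
        if (x < -2.5) return 0.0001;          // coeff[0] = {1e-4, 0}
        if (x <  2.5) return 0.5 + 0.17 * x;  // coeff[1] = {0.5, 0.17}
        return 0.9999;                        // coeff[2] = {1 - 1e-4, 0}
    }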
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::sigmoid_enhanced(FixedPointTensor<T, N>* ret) const {
+    // utilize polynomial_piecewise
+    // break_point = {-5, -2.5, 2.5, 5}
+    // coeff[0] = {10^-4, 0}, coeff[1] = {0.145, 0.02776}
+    // coeff[2] = {0.5, 0.17}, coeff[3] = {0.85498, 0.02776}, coeff[4] = {0.9999, 0}
+    // break_point.shape = {4, this->shape}, coeff.shape = {5, 2, this->shape}
+
+    auto shape_ = shape();
+    // construct break_point
+    auto b_shape = shape_;
+    b_shape.insert(b_shape.begin(), 4);
+
+    auto break_point = tensor_factory()->template create<T>(b_shape);
+
+    T* b_ptr = break_point->data();
+    auto numel = ret->numel();
+    double scale = std::pow(2, N);
+    for (size_t i = 0; i < numel; ++i) {
+        b_ptr[i] = (T) (-5 * scale);
+        b_ptr[i + numel] = (T) (-2.5 * scale);
+        b_ptr[i + 2 * numel] = (T) (2.5 * scale);
+        b_ptr[i + 3 * numel] = (T) (5 * scale);
+    }
+    break_point->scaling_factor() = N;
+
+    // construct coeff
+    std::vector<size_t> c_shape = {5, 2};
+    c_shape.insert(c_shape.end(), shape_.begin(), shape_.end());
+    auto coeff = tensor_factory()->template create<T>(c_shape);
+    T* c_ptr = coeff->data();
+    for (size_t i = 0; i < numel; ++i) {
+        c_ptr[i] = 0.0001 * scale;
+        c_ptr[i + numel] = 0;
+        c_ptr[i + 2 * numel] = 0.145 * scale;
+        c_ptr[i + 3 * numel] = 0.02776 * scale;
+        c_ptr[i + 4 * numel] = 0.5 * scale;
+        c_ptr[i + 5 * numel] = 0.17 * scale;
+        c_ptr[i + 6 * numel] = 0.85498 * scale;
+        c_ptr[i + 7 * numel] = 0.02776 * scale;
+        c_ptr[i + 8 * numel] = 0.9999 * scale;
+        c_ptr[i + 9 * numel] = 0 * scale;
+    }
+    coeff->scaling_factor() = N;
+
+    this->polynomial_piecewise(coeff.get(), break_point.get(), ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::softmax(FixedPointTensor<T, N>* ret,
+                                     bool use_relu, bool use_long_div) const {
+    // softmax axis = -1
+    const size_t col = *(shape().end() - 1);
+    const size_t row = numel() / col;
+
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    // 11 for allocating temp tensor
+    for (size_t i = 0; i < 11; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>());
+    }
+
+    temp[0]->reshape({row, col});
+    temp[1]->reshape({row, col});
+    FixedPointTensor<T, N> x(temp[0].get(), temp[1].get());
+
+    if (!use_relu) {
+        temp[2]->reshape({col, row});
+        temp[3]->reshape({col, row});
+
+        temp[4]->reshape({1, row});
+        temp[5]->reshape({1, row});
+    }
+    FixedPointTensor<T, N> x_t(temp[2].get(), temp[3].get());
+    FixedPointTensor<T, N> max_x_t(temp[4].get(), temp[5].get());
+
+    temp[6]->reshape({row, 1});
+    temp[7]->reshape({row, 1});
+    FixedPointTensor<T, N> max_x(temp[6].get(), temp[7].get());
+
+    temp[8]->reshape({row, col});
+    temp[9]->reshape({row, col});
+    FixedPointTensor<T, N> max_x_broadcast(temp[8].get(), temp[9].get());
+
+    temp[10]->reshape({row, col});
+    auto exp_lower_bound = temp[10].get();
+
+    auto transpose = [](const TensorAdapter<T>* in, TensorAdapter<T>* out) {
+        // suppose input dims = 2
+        const size_t col = in->shape()[1];
+        const size_t row = in->shape()[0];
+        const size_t numel = in->numel();
+
+        for (size_t k = 0; k < numel; ++k) {
+            size_t i = k / row;
+            size_t j = k % row;
+            out->data()[k] = in->data()[j * col + i];
+        }
+    };
+
+    auto broadcast = [](const TensorAdapter<T>* in, TensorAdapter<T>* out) {
+        // suppose input dims = 2
+        // in shape = [row, 1]
+        const size_t col = out->shape()[1];
+        const size_t row = out->shape()[0];
+        for (size_t k = 0; k < out->numel(); ++k) {
+            size_t i = k / col;
+            out->data()[k] = in->data()[i];
+        }
+    };
+
+    share(0)->copy(x.mutable_share(0));
+    share(1)->copy(x.mutable_share(1));
+
+    if (use_relu) {
+
+        x.relu(&x);
+
+    } else { // use exp
+        transpose(x.share(0), x_t.mutable_share(0));
+        transpose(x.share(1), x_t.mutable_share(1));
+
+        // x = max(input - max(input), exp_lower_bound)
+        x_t.max_pooling(&max_x_t);
+
+        transpose(max_x_t.share(0), max_x.mutable_share(0));
+        transpose(max_x_t.share(1), max_x.mutable_share(1));
+
+        broadcast(max_x.share(0), max_x_broadcast.mutable_share(0));
+        broadcast(max_x.share(1), max_x_broadcast.mutable_share(1));
+
+        x.sub(&max_x_broadcast, &x);
+
+        // n = 64, see exp
+        assign_to_tensor(exp_lower_bound, (T)(-64 * (1 << N)));
+        exp_lower_bound->scaling_factor() = N;
+
+        x.sub(exp_lower_bound, &x);
+        x.relu(&x);
+        x.add(exp_lower_bound, &x);
+
+        x.exp(&x);
+    }
+
+    // reuse max_x as sum
+    reduce(&x, &max_x);
+
+    if (!use_long_div) { // invert sum by Newton's method
+        // divisor range = [1/col, 1.0]
+        // TODO: find better iter num & init val
+        reciprocal(&max_x, &max_x, 16, 0.5 / col);
+    }
+
+    broadcast(max_x.share(0), max_x_broadcast.mutable_share(0));
+    broadcast(max_x.share(1), max_x_broadcast.mutable_share(1));
+
+    if (use_long_div) {
+        x.long_div(&max_x_broadcast, &x, 1);
+    } else {
+        x.mul(&max_x_broadcast, &x);
+    }
+
+    x.share(0)->copy(ret->mutable_share(0));
+    x.share(1)->copy(ret->mutable_share(1));
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::long_div(const FixedPointTensor<T, N>* rhs,
+                                      FixedPointTensor<T, N>* ret,
+                                      size_t int_len) const {
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 16; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(ret->shape()));
+    }
+
+    BooleanTensor<T> sign_lhs(temp[0].get(), temp[1].get());
+    BooleanTensor<T> sign_rhs(temp[2].get(), temp[3].get());
+    BooleanTensor<T> sign_ret(temp[4].get(), temp[5].get());
+    FixedPointTensor<T, N> abs_lhs(temp[6].get(), temp[7].get());
+    FixedPointTensor<T, N> abs_rhs(temp[8].get(), temp[9].get());
+    FixedPointTensor<T, N> sub_rhs(temp[10].get(), temp[11].get());
+    BooleanTensor<T> cmp_res(temp[12].get(), temp[13].get());
+    BooleanTensor<T> cmp_res_all(temp[14].get(), temp[15].get());
+
+    assign_to_tensor(cmp_res_all.share(0), (T)0);
+    assign_to_tensor(cmp_res_all.share(1), (T)0);
+
+    const size_t msb = sizeof(T) * 8 - 1;
+    sign_lhs.bit_extract(msb, this);
+    sign_rhs.bit_extract(msb, rhs);
+    sign_lhs.bitwise_xor(&sign_rhs, &sign_ret);
+
+    auto lshift = [] (const FixedPointTensor<T, N>* in,
+                      size_t rhs,
+                      FixedPointTensor<T, N>* out) {
+        in->share(0)->lshift(rhs, out->mutable_share(0));
+        in->share(1)->lshift(rhs, out->mutable_share(1));
+    };
+
+    // abs = val - 2 * sign * val
+    auto abs = [lshift] (const FixedPointTensor<T, N>* in,
+                         const BooleanTensor<T>* sign,
+                         FixedPointTensor<T, N>* out) {
+        lshift(in, 1, out);
+        sign->mul(out, out);
+        in->sub(out, out);
+    };
+
+    auto out0 = tensor_factory()->template create<T>(ret->shape());
+
+    abs(this, &sign_lhs, &abs_lhs);
+
+    abs(rhs, &sign_rhs, &abs_rhs);
+
+    for (ssize_t i = int_len - 1; i >= 0; --i) {
+        lshift(&abs_rhs, i, &sub_rhs);
+
+        abs_lhs.gt(&sub_rhs, &cmp_res);
+
+        cmp_res.mul(&sub_rhs, &sub_rhs);
+        cmp_res.lshift(N + i, &cmp_res);
+        abs_lhs.sub(&sub_rhs, &abs_lhs);
+        cmp_res.bitwise_xor(&cmp_res_all, &cmp_res_all);
+    }
+
+    for (size_t i = 1; i <= N; ++i) {
+        truncate3(&abs_rhs, &sub_rhs, i);
+        abs_lhs.gt(&sub_rhs, &cmp_res);
+        cmp_res.mul(&sub_rhs, &sub_rhs);
+        cmp_res.lshift(N - i, &cmp_res);
+        abs_lhs.sub(&sub_rhs, &abs_lhs);
+        cmp_res.bitwise_xor(&cmp_res_all, &cmp_res_all);
+    }
+
+    // use abs_lhs as buffer
+    cmp_res_all.b2a(&abs_lhs);
+
+    abs(&abs_lhs, &sign_ret, ret);
+}
+
+// reduce last dim
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::reduce(FixedPointTensor<T, N>* input,
+                                    FixedPointTensor<T, N>* ret) {
+    // enforce shape: input->shape[0 ... (n-2)] == ret shape
+    auto& shape = input->shape();
+    size_t ite_size = shape[shape.size() - 1];
+
+    T* ret_begin_ptr_0 = ret->_share[0]->data();
+    T* ret_begin_ptr_1 = ret->_share[1]->data();
+
+    T* input_begin_ptr_0 = input->_share[0]->data();
+    T* input_begin_ptr_1 = input->_share[1]->data();
+
+    for (int j = 0; j < ret->numel(); ++j) {
+        *(ret_begin_ptr_0 + j) = *(input_begin_ptr_0 + j * ite_size);
+        *(ret_begin_ptr_1 + j) = *(input_begin_ptr_1 + j * ite_size);
+        for (int i = 1; i < ite_size; ++i) {
+            *(ret_begin_ptr_0 + j) +=
+                *(input_begin_ptr_0 + j * ite_size + i);
+            *(ret_begin_ptr_1 + j) +=
+                *(input_begin_ptr_1 + j * ite_size + i);
+        }
+    }
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::polynomial(const TensorAdapter<T>* coeff,
+                                        FixedPointTensor<T, N>* ret) const {
+
+    // e.g., x.shape = {2, 3}, coeff.shape = {n, 2, 3} (n: polynomial power)
+
+    // TODO: improve performance: [ABY3]
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 7; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(this->shape()));
+    }
+    std::shared_ptr<FixedPointTensor<T, N>> x_pow_i =
+        std::make_shared<FixedPointTensor<T, N>>(
+            temp[0].get(), temp[1].get());
+    std::shared_ptr<FixedPointTensor<T, N>> temp_fixed =
+        std::make_shared<FixedPointTensor<T, N>>(
+            temp[2].get(), temp[3].get());
+    std::shared_ptr<FixedPointTensor<T, N>> result =
+        std::make_shared<FixedPointTensor<T, N>>(
+            temp[5].get(), temp[6].get());
+    assign_to_tensor(result->_share[0], (T) 0);
+    assign_to_tensor(result->_share[1], (T) 0);
+
+    // x_pow_i.get() = 1;
+    assign_to_tensor(x_pow_i.get()->_share[0], (T) 0);
+    assign_to_tensor(x_pow_i.get()->_share[1], (T) 0);
+    assign_to_tensor(temp[4].get(), (T) 1 << N);
+    temp[4]->scaling_factor() = N;
+    x_pow_i->add(temp[4].get(), x_pow_i.get());
+
+    for (int i = 0; i < coeff->shape()[0]; ++i) {
+        auto t = tensor_factory()->template create<T>();
+        coeff->slice(i, i + 1, t.get());
+        auto t_shape = t->shape();
+        // remove leading 1
+        t_shape.erase(t_shape.begin());
+        t->reshape(t_shape);
+        x_pow_i->mul(t.get(), temp_fixed.get());
+        result->add(temp_fixed.get(), result.get());
+        x_pow_i->mul(this, x_pow_i.get());
+    }
+    result->share(0)->copy(ret->mutable_share(0));
+    result->share(1)->copy(ret->mutable_share(1));
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::polynomial_piecewise(
+        const TensorAdapter<T>* coeff,
+        const TensorAdapter<T>* break_point,
+        FixedPointTensor<T, N>* ret) const {
+
+    // e.g., x.shape = {2, 3},
+    //       break_point.shape = {k, 2, 3} (k: num of break point)
+    //       coeff.shape = {k + 1, n, 2, 3} (n: poly power)
+
+    // copy ret
+    auto ret_cpy_s0 = tensor_factory()->create_int64_t(ret->share(0)->shape());
+    ret->share(0)->copy(ret_cpy_s0.get());
+    auto ret_cpy_s1 = tensor_factory()->create_int64_t(ret->share(1)->shape());
+    ret->share(1)->copy(ret_cpy_s1.get());
+    std::shared_ptr<FixedPointTensor<T, N>> ret_cpy{
+        new FixedPointTensor<T, N>(ret_cpy_s0.get(), ret_cpy_s1.get())};
+
+    std::vector<std::shared_ptr<BooleanTensor<T>>> msb;
+
+    int len_break_point = break_point->shape()[0];
+    int len_coeff = coeff->shape()[0];
+
+    // number of temp tensors used
+    int temp_total = 4 * len_break_point + 2 +
+                     2 * (len_break_point - 1) + 2 + 4 * len_coeff;
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < temp_total; ++i) {
+        temp.emplace_back(tensor_factory()->
+                              template create<T>(this->shape()));
+    }
+    int temp_index = 0;
+
+    // std::vector<std::shared_ptr<TensorAdapter<T>>> paddle_t_break;
+    std::vector<std::shared_ptr<FixedPointTensor<T, N>>> temp1;
+
+    for (int i = 0; i < break_point->shape()[0]; ++i) {
+        // msb[i] = msb(x - break_point[i])
+        auto t_break = tensor_factory()->template create<T>();
+        break_point->slice(i, i + 1, t_break.get());
+
+        auto t_shape = t_break->shape();
+        // remove leading 1
+        t_shape.erase(t_shape.begin());
+        t_break->reshape(t_shape);
+
+        temp1.emplace_back(
+            std::make_shared<FixedPointTensor<T, N>>(
+                temp[temp_index++].get(),
+                temp[temp_index++].get()));
+        this->sub(t_break.get(), temp1[i].get());
+        msb.emplace_back(std::make_shared<BooleanTensor<T>>(
+            temp[temp_index++].get(),
+            temp[temp_index++].get()));
+        msb[i]->bit_extract(sizeof(T) * 8 - 1, temp1[i].get());
+    }
+
+    // b[0] = msb[0], b[i + 1] = ~ msb[i] & msb[i + 1]
+    std::vector<std::shared_ptr<BooleanTensor<T>>> b;
+    b.emplace_back(std::make_shared<BooleanTensor<T>>(
+        temp[temp_index++].get(),
+        temp[temp_index++].get()));
+    b[0] = msb[0];
+
+    for (int i = 0; i < len_break_point - 1; ++i) {
+        b.emplace_back(std::make_shared<BooleanTensor<T>>(
+            temp[temp_index++].get(),
+            temp[temp_index++].get()));
+
+        msb[i]->bitwise_not(b[i + 1].get());
+        b[i + 1]->bitwise_and(msb[i + 1].get(), b[i + 1].get());
+    }
+
+    b.emplace_back(std::make_shared<BooleanTensor<T>>(
+        temp[temp_index++].get(),
+        temp[temp_index++].get()));
+    msb[len_break_point - 1]->bitwise_not(b[len_break_point].get());
+
+    // ret += b[i].mul(polynomial(coeff[i]))
+    std::vector<std::shared_ptr<FixedPointTensor<T, N>>> temp_fixed;
+    std::vector<std::shared_ptr<FixedPointTensor<T, N>>> temp_fixed1;
+
+    assign_to_tensor(ret_cpy->_share[0], (T) 0);
+    assign_to_tensor(ret_cpy->_share[1], (T) 0);
+
+    for (int i = 0; i < len_coeff; ++i) {
+        temp_fixed.emplace_back(
+            std::make_shared<FixedPointTensor<T, N>>(
+                temp[temp_index++].get(),
+                temp[temp_index++].get()));
+        temp_fixed1.emplace_back(
+            std::make_shared<FixedPointTensor<T, N>>(
+                temp[temp_index++].get(),
+                temp[temp_index++].get()));
+        auto t = tensor_factory()->template create<T>();
+        coeff->slice(i, i + 1, t.get());
+        auto t_shape = t->shape();
+        // remove leading 1
+        t_shape.erase(t_shape.begin());
+        t->reshape(t_shape);
+        this->polynomial(t.get(), temp_fixed[i].get());
+        b[i]->bit_extract(0, b[i].get());
+        b[i]->mul(temp_fixed[i].get(), temp_fixed1[i].get());
+        ret_cpy->add(temp_fixed1[i].get(), ret_cpy.get());
+    }
+    ret_cpy->share(0)->copy(ret->mutable_share(0));
+    ret_cpy->share(1)->copy(ret->mutable_share(1));
+}
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::lt(const CTensor<T, N1...>* rhs,
+                                BooleanTensor<T>* ret) const {
+
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 2; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(this->shape()));
+    }
+    std::shared_ptr<FixedPointTensor<T, N>> sub_result =
+        std::make_shared<FixedPointTensor<T, N>>(
+            temp[0].get(), temp[1].get());
+    this->sub(rhs, sub_result.get());
+    ret->bit_extract(sizeof(T) * 8 - 1, sub_result.get());
+}
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::leq(const CTensor<T, N1...>* rhs,
+                                 BooleanTensor<T>* ret) const {
+
+    this->gt(rhs, ret);
+    auto tensor_one = tensor_factory()->
+        template create<T>(this->shape());
+
+    assign_to_tensor(tensor_one.get(), (T) 1);
+    ret->bitwise_xor(tensor_one.get(), ret);
+}
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::gt(const CTensor<T, N1...>* rhs,
+                                BooleanTensor<T>* ret) const {
+
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 2; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(this->shape()));
+    }
+    std::shared_ptr<FixedPointTensor<T, N>> sub_result =
+        std::make_shared<FixedPointTensor<T, N>>(
+            temp[0].get(), temp[1].get());
+    this->sub(rhs, sub_result.get());
+    sub_result->negative(sub_result.get());
+    ret->template bit_extract(sizeof(T) * 8 - 1, sub_result.get());
+}
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::geq(const CTensor<T, N1...>* rhs,
+                                 BooleanTensor<T>* ret) const {
+
+    this->lt(rhs, ret);
+    auto tensor_one = tensor_factory()->
+        template create<T>(this->shape());
+
+    assign_to_tensor(tensor_one.get(), (T) 1);
+    ret->bitwise_xor(tensor_one.get(), ret);
+}
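All of these comparisons reduce to one sign test: lt extracts the most significant bit of the two's-complement difference x - y, and leq/geq are just the opposite comparison XORed with 1. A scalar restatement (illustrative; like the shared version, it ignores overflow at the extremes of the range):

    #include <cstdint>

    // x < y iff the sign bit of x - y is set; bit_extract(63, x - y) in the code.
    bool lt_plain(int64_t x, int64_t y) {
        uint64_t diff = static_cast<uint64_t>(x) - static_cast<uint64_t>(y);
        return (diff >> 63) & 1;
    }
    // leq_plain(x, y) == !gt_plain(x, y): the XOR-with-1 trick used above.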
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::eq(const CTensor<T, N1...>* rhs,
+                                BooleanTensor<T>* ret) const {
+
+    this->neq(rhs, ret);
+    auto tensor_one = tensor_factory()->template create<T>(this->shape());
+    assign_to_tensor(tensor_one.get(), (T) 1);
+    ret->bitwise_xor(tensor_one.get(), ret);
+}
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::neq(const CTensor<T, N1...>* rhs,
+                                 BooleanTensor<T>* ret) const {
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 4; i++) {
+        temp.emplace_back(tensor_factory()->
+                              template create<T>(this->shape()));
+    }
+    std::shared_ptr<BooleanTensor<T>> lt =
+        std::make_shared<BooleanTensor<T>>(
+            temp[0].get(), temp[1].get());
+    std::shared_ptr<BooleanTensor<T>> gt =
+        std::make_shared<BooleanTensor<T>>(
+            temp[2].get(), temp[3].get());
+
+    this->lt(rhs, lt.get());
+    this->gt(rhs, gt.get());
+    lt->bitwise_or(gt.get(), ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::reciprocal(const FixedPointTensor<T, N>* op,
+                                        FixedPointTensor<T, N>* ret,
+                                        size_t iter, double x0) {
+    auto temp0 = tensor_factory()->template create<T>(ret->shape());
+    auto temp1 = tensor_factory()->template create<T>(ret->shape());
+    auto temp2 = tensor_factory()->template create<T>(ret->shape());
+    auto temp3 = tensor_factory()->template create<T>(ret->shape());
+    std::shared_ptr<FixedPointTensor<T, N>> result =
+        std::make_shared<FixedPointTensor<T, N>>(temp0.get(), temp1.get());
+    std::shared_ptr<FixedPointTensor<T, N>> x_copy =
+        std::make_shared<FixedPointTensor<T, N>>(temp2.get(), temp3.get());
+    assign_to_tensor(result->mutable_share(0), (T) 0);
+    assign_to_tensor(result->mutable_share(1), (T) 0);
+    auto tensor_x0 = tensor_factory()->template create<T>(op->shape());
+    assign_to_tensor(tensor_x0.get(), (T)(x0 * pow(2, N)));
+    tensor_x0->scaling_factor() = N;
+    result->add(tensor_x0.get(), result.get());
+    auto tensor_2 = tensor_factory()->template create<T>(op->shape());
+    tensor_2->scaling_factor() = N;
+    assign_to_tensor(tensor_2.get(), ((T) 2 << N));
+    for (int i = 0; i < iter; ++i) {
+        result->share(0)->copy(x_copy->mutable_share(0));
+        result->share(1)->copy(x_copy->mutable_share(1));
+        auto res_ptr = result.get();
+        op->mul(res_ptr, res_ptr);
+        result->negative(res_ptr);
+        result->add(tensor_2.get(), res_ptr);
+        x_copy->mul(res_ptr, res_ptr);
+    }
+    result->share(0)->copy(ret->mutable_share(0));
+    result->share(1)->copy(ret->mutable_share(1));
+}
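reciprocal runs the standard Newton iteration for 1/d, x_{k+1} = x_k * (2 - d * x_k), which converges quadratically whenever the seed x0 lies in (0, 2/d); that is why softmax seeds it with 0.5/col for a divisor in [1/col, 1]. The plain-double version (illustrative only):

    // Newton's method for 1/d, as used by reciprocal() above.
    double reciprocal_newton(double d, double x0, unsigned iters) {
        double x = x0;
        for (unsigned k = 0; k < iters; ++k) {
            x = x * (2.0 - d * x);
        }
        return x;   // e.g. reciprocal_newton(4.0, 0.1, 16) -> 0.25
    }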
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::inverse_square_root(FixedPointTensor* ret,
+                                                 size_t iter,
+                                                 double x0) const {
+    inverse_square_root(this, ret, iter, x0);
+}
+
+// Newton's method, var naming from Quake III Arena: Q_rsqrt
+// float threehalfs = 1.5F;
+// x2 = number * 0.5F;
+// y = x0; // since 0x5f3759df does not fit fixed-point arithmetic
+// y = y * ( threehalfs - ( x2 * y * y ) ); // iteration of Newton's method
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::inverse_square_root(const FixedPointTensor* op,
+                                                 FixedPointTensor* ret,
+                                                 size_t iter,
+                                                 double x0) {
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    for (int i = 0; i < 7; ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(op->shape()));
+    }
+    std::shared_ptr<FixedPointTensor<T, N>> y =
+        std::make_shared<FixedPointTensor<T, N>>(temp[0].get(), temp[1].get());
+    std::shared_ptr<FixedPointTensor<T, N>> x2 =
+        std::make_shared<FixedPointTensor<T, N>>(temp[2].get(), temp[3].get());
+    // x2 = 0.5 * op
+    truncate3(op, x2.get(), 1);
+
+    assign_to_tensor(y->mutable_share(0), (T)(x0 * pow(2, N)));
+    assign_to_tensor(y->mutable_share(1), (T)(x0 * pow(2, N)));
+
+    // threehalfs
+    temp[4]->scaling_factor() = N;
+    assign_to_tensor(temp[4].get(), T(1.5 * pow(2, N)));
+
+    std::shared_ptr<FixedPointTensor<T, N>> y_copy =
+        std::make_shared<FixedPointTensor<T, N>>(temp[5].get(), temp[6].get());
+
+    for (size_t i = 0; i < iter; ++i) {
+        y->share(0)->copy(y_copy->mutable_share(0));
+        y->share(1)->copy(y_copy->mutable_share(1));
+        y->mul(y.get(), y.get());
+        y->mul(x2.get(), y.get());
+        y->negative(y.get());
+        y->add(temp[4].get(), y.get());
+        y_copy->mul(y.get(), y.get());
+    }
+    y->share(0)->copy(ret->mutable_share(0));
+    y->share(1)->copy(ret->mutable_share(1));
+}
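+
+// Plaintext sketch of the iteration above (plain_rsqrt is an illustrative
+// helper, not part of this API); it is the Q_rsqrt recurrence with a fixed
+// starting guess instead of the 0x5f3759df bit trick:
+//
+//   double plain_rsqrt(double number, size_t iter, double x0) {
+//       const double threehalfs = 1.5;
+//       double x2 = number * 0.5;
+//       double y = x0;
+//       for (size_t i = 0; i < iter; ++i) {
+//           y = y * (threehalfs - x2 * y * y);
+//       }
+//       return y;  // number = 4, x0 = 0.6: 0.468, 0.497, -> 0.5
+//   }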
+
+template<typename T, size_t N>
+template<template<typename U, size_t...> class CTensor,
+         size_t... N1>
+void FixedPointTensor<T, N>::max(const CTensor<T, N1...>* rhs,
+                                 FixedPointTensor* ret,
+                                 BooleanTensor<T>* cmp) const {
+    // max = lhs + (rhs - lhs) if rhs > lhs else lhs
+    std::vector<std::shared_ptr<TensorAdapter<T>>> temp;
+    bool output_cmp = cmp != nullptr;
+    // if cmp is not null, store cmp results in cmp
+    // else, store them in tmp tensors
+    for (int i = 0; i < 2 + 2 * (!output_cmp); ++i) {
+        temp.emplace_back(
+            tensor_factory()->template create<T>(this->shape()));
+    }
+    FixedPointTensor<T, N> delta(temp[0].get(), temp[1].get());
+    sub(rhs, &delta);
+    BooleanTensor<T> sign;
+    if (output_cmp) {
+        sign = *cmp;
+    } else {
+        sign = BooleanTensor<T>(temp[2].get(), temp[3].get());
+    }
+    sign.template bit_extract(sizeof(T) * 8 - 1, &delta);
+    delta.negative(&delta);
+    sign.mul(&delta, &delta);
+    add(&delta, ret);
+}
+
+template<typename T, size_t N>
+void FixedPointTensor<T, N>::max_pooling(FixedPointTensor* ret,
+                                         BooleanTensor<T>* pos) const {
+    size_t k = shape()[0];
+    std::vector<std::shared_ptr<TensorAdapter<T>>> tmp;
+    for (int i = 0; i < 4; ++i) {
+        tmp.emplace_back(
+            tensor_factory()->template create<T>());
+    }
+
+    FixedPointTensor<T, N> now(tmp[0].get(), tmp[1].get());
+    BooleanTensor<T> cmp(tmp[2].get(), tmp[3].get());
+    auto cmp_ptr = pos ? &cmp : nullptr;
+
+    share(0)->slice(0, 1, tmp[0].get());
+    share(1)->slice(0, 1, tmp[1].get());
+
+    tmp[0]->copy(ret->mutable_share(0));
+    tmp[1]->copy(ret->mutable_share(1));
+
+    if (pos) {
+        pos->share(0)->slice(0, 1, tmp[2].get());
+        pos->share(1)->slice(0, 1, tmp[3].get());
+
+        // set init 1, slice_0 is larger than null
+        if (party() == 0 || party() == 2) {
+            size_t idx = 2 + (party() == 2);
+            assign_to_tensor(tmp[idx].get(), T(1));
+            assign_to_tensor(tmp[5 - idx].get(), T(0));
+        } else {
+            assign_to_tensor(tmp[2].get(), T(0));
+            assign_to_tensor(tmp[3].get(), T(0));
+        }
+    }
+
+    for (size_t i = 1; i < k; ++i) {
+        share(0)->slice(i, i + 1, tmp[0].get());
+        share(1)->slice(i, i + 1, tmp[1].get());
+
+        if (pos) {
+            pos->share(0)->slice(i, i + 1, tmp[2].get());
+            pos->share(1)->slice(i, i + 1, tmp[3].get());
+        }
+
+        ret->max(&now, ret, cmp_ptr);
+    }
+
+    if (pos) {
+        pos->onehot_from_cmp();
+    }
+}
+*/
+} // namespace privc
diff --git a/core/privc/fixedpoint_tensor_test.cc b/core/privc/fixedpoint_tensor_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..60848f9365d1ff5f704037575399cefe500f6fa7
--- /dev/null
+++ b/core/privc/fixedpoint_tensor_test.cc
@@ -0,0 +1,478 @@
+/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cmath>
+#include <thread>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/tensor.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/scope.h"
+
+#include "./privc_context.h"
+#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
+#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
+#include "fixedpoint_tensor.h"
+#include "core/privc/triplet_generator.h"
+#include "core/privc3/paddle_tensor.h"
+
+namespace privc {
+
+using g_ctx_holder = paddle::mpc::ContextHolder;
+using Fix64N32 = FixedPointTensor<int64_t, SCALING_N>;
+using AbstractContext = paddle::mpc::AbstractContext;
+
+class FixedTensorTest : public ::testing::Test {
+public:
+
+    paddle::platform::CPUDeviceContext _cpu_ctx;
+    std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
+    std::shared_ptr<AbstractContext> _mpc_ctx[2];
+    std::shared_ptr<gloo::rendezvous::HashStore> _store;
+    std::thread _t[2];
+    std::shared_ptr<aby3::TensorAdapterFactory> _s_tensor_factory;
+
+    virtual ~FixedTensorTest() noexcept {}
+
+    void SetUp() {
+
+        paddle::framework::OperatorBase* op = nullptr;
+        paddle::framework::Scope scope;
+        paddle::framework::RuntimeContext ctx({}, {});
+        // only device_ctx is needed
+        _exec_ctx = std::make_shared<paddle::framework::ExecutionContext>(
+            *op, scope, _cpu_ctx, ctx);
+
+        _store = std::make_shared<gloo::rendezvous::HashStore>();
+
+        for (size_t i = 0; i < 2; ++i) {
+            _t[i] = std::thread(&FixedTensorTest::gen_mpc_ctx, this, i);
+        }
+        for (auto& ti : _t) {
+            ti.join();
+        }
+
+        for (size_t i = 0; i < 2; ++i) {
+            _t[i] = std::thread(&FixedTensorTest::init_triplet, this, i);
+        }
+        for (auto& ti : _t) {
+            ti.join();
+        }
+        _s_tensor_factory = std::make_shared<aby3::PaddleTensorFactory>(&_cpu_ctx);
+    }
+    std::shared_ptr<paddle::mpc::MeshNetwork> gen_network(size_t idx) {
+        return std::make_shared<paddle::mpc::MeshNetwork>(idx,
+                                                          "127.0.0.1",
+                                                          2,
+                                                          "test_prefix_privc",
+                                                          _store);
+    }
+    void gen_mpc_ctx(size_t idx) {
+        auto net = gen_network(idx);
+        net->init();
+        _mpc_ctx[idx] = std::make_shared<PrivCContext>(idx, net);
+    }
+
+    void init_triplet(size_t idx) {
+        std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> tripletor
+            = std::make_shared<TripletGenerator<int64_t, SCALING_N>>(_mpc_ctx[idx]);
+        tripletor->init();
+        std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])
+            ->set_triplet_generator(tripletor);
+    }
+
+    std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
+        return _s_tensor_factory->template create<int64_t>(shape);
+    }
+};
+
+std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
+    return g_ctx_holder::tensor_factory()->template create<int64_t>(shape);
+}
+
+TEST_F(FixedTensorTest, share) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl = gen(shape);
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+
+    TensorAdapter<int64_t>* output_share[2] = { ret[0].get(), ret[1].get() };
+    sl->data()[0] = (int64_t)1 << SCALING_N;
+    sl->scaling_factor() = SCALING_N;
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                Fix64N32::share(sl.get(), output_share);
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){});
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+
+    auto p = gen(shape);
+    output_share[0]->add(output_share[1], p.get());
+    EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
+}
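+
+// The test above checks the additive-sharing invariant of this 2-party
+// protocol: share() splits a plaintext x into x0 + x1 == x (mod 2^64), so
+// adding the two output shares reconstructs the encoding of 1. In plaintext
+// terms (random_share is a hypothetical helper, not part of the tested API):
+//
+//   int64_t x = (int64_t)1 << SCALING_N;  // fixed-point encoding of 1.0
+//   int64_t x0 = random_share();          // uniformly random share
+//   int64_t x1 = x - x0;                  // wraps mod 2^64
+//   // x0 + x1 == x, while each share alone reveals nothing about x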
+
+TEST_F(FixedTensorTest, reveal) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    // lhs = 3 = 1 + 2
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)2 << SCALING_N;
+
+    auto p0 = gen(shape);
+    auto p1 = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.reveal(p0.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.reveal(p1.get());
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ(p0->data()[0], p1->data()[0]);
+    EXPECT_EQ(3, p0->data()[0] / std::pow(2, SCALING_N));
+}
+
+TEST_F(FixedTensorTest, addplain) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sr = gen(shape);
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+    // lhs = 3 = 1 + 2
+    // rhs = 3
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)2 << SCALING_N;
+    sr->data()[0] = (int64_t)3 << SCALING_N;
+
+    sr->scaling_factor() = SCALING_N;
+    auto p = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+    Fix64N32 fout0(ret[0].get());
+    Fix64N32 fout1(ret[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.add(sr.get(), &fout0);
+                fout0.reveal_to_one(0, p.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.add(sr.get(), &fout1);
+                fout1.reveal_to_one(0, nullptr);
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ(6, p->data()[0] / std::pow(2, SCALING_N));
+}
+
+TEST_F(FixedTensorTest, addfixed) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+    // lhs = 3 = 1 + 2
+    // rhs = 3 = 1 + 2
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)2 << SCALING_N;
+    sr[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sr[1]->data()[0] = (int64_t)2 << SCALING_N;
+
+    auto p = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+    Fix64N32 fr0(sr[0].get());
+    Fix64N32 fr1(sr[1].get());
+    Fix64N32 fout0(ret[0].get());
+    Fix64N32 fout1(ret[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.add(&fr0, &fout0);
+                fout0.reveal_to_one(0, p.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.add(&fr1, &fout1);
+                fout1.reveal_to_one(0, nullptr);
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ(6, p->data()[0] / std::pow(2, SCALING_N));
+}
+
+TEST_F(FixedTensorTest, subplain) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sr = gen(shape);
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+    // lhs = 3 = 1 + 2
+    // rhs = 2
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)2 << SCALING_N;
+    sr->data()[0] = (int64_t)2 << SCALING_N;
+
+    sr->scaling_factor() = SCALING_N;
+    auto p = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+    Fix64N32 fout0(ret[0].get());
+    Fix64N32 fout1(ret[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.sub(sr.get(), &fout0);
+                fout0.reveal_to_one(0, p.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.sub(sr.get(), &fout1);
+                fout1.reveal_to_one(0, nullptr);
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
+}
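+
+// All of these tests use the same encoding: a real value v is stored as
+// v * 2^SCALING_N rounded to int64_t, so integer add/sub on encodings match
+// real add/sub exactly; that is why expectations divide by pow(2, SCALING_N).
+// Sketch of the encode/decode pair implied by the assertions (illustrative
+// helpers, not part of the tested API):
+//
+//   int64_t encode(double v) { return (int64_t)(v * std::pow(2, SCALING_N)); }
+//   double  decode(int64_t f) { return f / std::pow(2, SCALING_N); }
+//   // decode(encode(3.0) - encode(2.0)) == 1.0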
+
+TEST_F(FixedTensorTest, subfixed) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+    // lhs = 3 = 1 + 2
+    // rhs = 2 = 1 + 1
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)2 << SCALING_N;
+    sr[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sr[1]->data()[0] = (int64_t)1 << SCALING_N;
+
+    auto p = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+    Fix64N32 fr0(sr[0].get());
+    Fix64N32 fr1(sr[1].get());
+    Fix64N32 fout0(ret[0].get());
+    Fix64N32 fout1(ret[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.sub(&fr0, &fout0);
+                fout0.reveal_to_one(0, p.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.sub(&fr1, &fout1);
+                fout1.reveal_to_one(0, nullptr);
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
+}
+
+TEST_F(FixedTensorTest, negative) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+    // lhs = 3 = 1 + 2
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)2 << SCALING_N;
+
+    auto p = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+    Fix64N32 fout0(ret[0].get());
+    Fix64N32 fout1(ret[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.negative(&fout0);
+                fout0.reveal_to_one(0, p.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.negative(&fout1);
+                fout1.reveal_to_one(0, nullptr);
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ(-3, p->data()[0] / std::pow(2, SCALING_N));
+}
+
+TEST_F(FixedTensorTest, triplet) {
+    std::vector<size_t> shape = { 1 };
+
+    auto shape_triplet = shape;
+    shape_triplet.insert(shape_triplet.begin(), 3);
+
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape_triplet),
+                                                       gen(shape_triplet) };
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
+                    ->triplet_generator()->get_triplet(ret[0].get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
+                    ->triplet_generator()->get_triplet(ret[1].get());
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+
+    auto num_triplet = ret[0]->numel() / 3;
+    for (uint64_t i = 0; i < num_triplet; ++i) {
+        auto ret0_ptr = ret[0]->data();
+        auto ret1_ptr = ret[1]->data();
+
+        uint64_t a_idx = i;
+        uint64_t b_idx = num_triplet + i;
+        uint64_t c_idx = 2 * num_triplet + i;
+        int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx)) +
+                    fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx)) +
+                    fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx)) +
+                    fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
+
+        EXPECT_NEAR(c, (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), 20);
+    }
+}
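+
+// The loop above verifies the Beaver-triplet invariant for additively shared
+// values: with a = a0 + a1 and b = b0 + b1, the c-shares must satisfy
+// c0 + c1 ~= a * b in fixed point, expanded into the four cross products
+// a0*b0 + a0*b1 + a1*b0 + a1*b1 via fixed64_mult. EXPECT_NEAR with a
+// tolerance of 20 absorbs the rounding that each >> SCALING_N truncation drops.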
+
+TEST_F(FixedTensorTest, mulfixed) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
+    // lhs = 2 = 1 + 1
+    // rhs = 2 = 1 + 1
+    sl[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sl[1]->data()[0] = (int64_t)1 << SCALING_N;
+    sr[0]->data()[0] = (int64_t)1 << SCALING_N;
+    sr[1]->data()[0] = (int64_t)1 << SCALING_N;
+
+    auto p = gen(shape);
+
+    Fix64N32 fl0(sl[0].get());
+    Fix64N32 fl1(sl[1].get());
+    Fix64N32 fr0(sr[0].get());
+    Fix64N32 fr1(sr[1].get());
+    Fix64N32 fout0(ret[0].get());
+    Fix64N32 fout1(ret[1].get());
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+                fl0.mul(&fr0, &fout0);
+                fout0.reveal_to_one(0, p.get());
+            });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+                fl1.mul(&fr1, &fout1);
+                fout1.reveal_to_one(0, nullptr);
+            });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_NEAR(4, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
+}
+
+} // namespace privc
diff --git a/core/privc/privc_context.cc b/core/privc/privc_context.cc
new file mode 100644
index 0000000000000000000000000000000000000000..29e6bedc760bf0d5532c2ac804d8c13c60e3d59b
--- /dev/null
+++ b/core/privc/privc_context.cc
@@ -0,0 +1,30 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+#include <memory>
+#include <utility>
+
+#include "core/privc/triplet_generator.h"
+#include "core/privc/privc_context.h"
+namespace privc {
+
+void PrivCContext::set_triplet_generator(
+        std::shared_ptr<TripletGenerator<int64_t, SCALING_N>>& tripletor) {
+    _tripletor = tripletor;
+}
+
+std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> PrivCContext::triplet_generator() {
+    PADDLE_ENFORCE_NE(_tripletor, nullptr, "must set triplet generator first.");
+    return _tripletor;
+}
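+
+// Expected wiring, mirroring the unit tests (a sketch, not an additional
+// API): the context does not build its own generator; callers construct one,
+// run init() to execute the base OTs, then register it before any triplet
+// is consumed:
+//
+//   std::shared_ptr<AbstractContext> ctx =
+//       std::make_shared<PrivCContext>(party, network);
+//   auto trip = std::make_shared<TripletGenerator<int64_t, SCALING_N>>(ctx);
+//   trip->init();
+//   std::dynamic_pointer_cast<PrivCContext>(ctx)->set_triplet_generator(trip);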
+
+} // namespace privc
diff --git a/core/privc/privc_context.h b/core/privc/privc_context.h
index 35b67a7e9c948099883e4343c3d5dfe26f8eaff2..8ef743d86fd536d0b9fc48602083e24d974ba4ca 100644
--- a/core/privc/privc_context.h
+++ b/core/privc/privc_context.h
@@ -13,24 +13,28 @@
 // limitations under the License.
 #pragma once
 
-#include <algorithm>
 #include <memory>
 
 #include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
 #include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
-#include "prng_utils.h"
+#include "core/privc3/prng_utils.h"
 
-namespace aby3 {
+namespace privc {
 
 using AbstractNetwork = paddle::mpc::AbstractNetwork;
 using AbstractContext = paddle::mpc::AbstractContext;
+using block = psi::block;
+
+static const size_t SCALING_N = 32;
+template <typename T, size_t N>
+class TripletGenerator;
 
 class PrivCContext : public AbstractContext {
 public:
   PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
-               block seed = g_zero_block):
-      AbstractContext::AbstractContext(party, network) {
+               block seed = psi::g_zero_block):
+      AbstractContext::AbstractContext(party, network), _tripletor{nullptr} {
     set_num_party(2);
 
     if (psi::equals(seed, psi::g_zero_block)) {
@@ -42,13 +46,28 @@ public:
   PrivCContext(const PrivCContext &other) = delete;
 
   PrivCContext &operator=(const PrivCContext &other) = delete;
+
+/*
+  block get_private_block() {
+      std::array<int64_t, 2> ret_block;
+      ret_block[0] = gen_random_private<int64_t>();
+      ret_block[1] = gen_random_private<int64_t>();
+
+      return *(reinterpret_cast<block*>(ret_block.data()));
+  }
+*/
+
+  void set_triplet_generator(std::shared_ptr<TripletGenerator<int64_t, SCALING_N>>& tripletor);
+
+  std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> triplet_generator();
 
 protected:
-  PseudorandomNumberGenerator& get_prng(size_t idx) override {
+  psi::PseudorandomNumberGenerator& get_prng(size_t idx) override {
     return _prng;
   }
+
 private:
-  PseudorandomNumberGenerator _prng;
+  std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> _tripletor;
+  psi::PseudorandomNumberGenerator _prng;
 };
-} // namespace aby3
+} // namespace privc
diff --git a/core/privc/triplet_generator.h b/core/privc/triplet_generator.h
new file mode 100644
index 0000000000000000000000000000000000000000..3cb1a1a2f9547722c0ffb494281ceb35de0ed18b
--- /dev/null
+++ b/core/privc/triplet_generator.h
@@ -0,0 +1,181 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <queue>
+
+#include "paddle/fluid/platform/enforce.h"
+
+#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
+#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
+#include "core/privc3/prng_utils.h"
+#include "core/privc/crypto.h"
+#include "core/psi/naorpinkas_ot.h"
+#include "core/psi/ot_extension.h"
+#include "core/privc3/tensor_adapter.h"
+
+namespace privc {
+
+using AbstractNetwork = paddle::mpc::AbstractNetwork;
+using AbstractContext = paddle::mpc::AbstractContext;
+using block = psi::block;
+using NaorPinkasOTsender = psi::NaorPinkasOTsender;
+using NaorPinkasOTreceiver = psi::NaorPinkasOTreceiver;
+
+template <typename T>
+using OTExtSender = psi::OTExtSender<T>;
+template <typename T>
+using OTExtReceiver = psi::OTExtReceiver<T>;
+
+template <typename T>
+using TensorAdapter = aby3::TensorAdapter<T>;
+
+template<size_t N>
+inline int64_t fixed64_mult(const int64_t a, const int64_t b) {
+
+    __int128_t res = (__int128_t)a * (__int128_t)b;
+
+    return static_cast<int64_t>(res >> N);
+}
+
+template<size_t N>
+inline uint64_t lshift(uint64_t lhs, size_t rhs) {
+    return fixed64_mult<N>(lhs, (uint64_t)1 << rhs);
+}
+
+inline std::string block_to_string(const block &b) {
+    return std::string(reinterpret_cast<const char *>(&b), sizeof(block));
+}
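+
+// fixed64_mult keeps products at scale N: multiplying two scale-N encodings
+// yields a scale-2N product in __int128, and >> N rescales it. A worked
+// example at N = 32 (values are illustrative):
+//
+//   int64_t a = (int64_t)(1.5 * ((int64_t)1 << 32));  // encodes 1.5
+//   int64_t b = (int64_t)2 << 32;                     // encodes 2.0
+//   int64_t c = fixed64_mult<32>(a, b);               // == (int64_t)3 << 32, i.e. 3.0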
+
+inline void gen_ot_masks(OTExtReceiver<block>& ot_ext_recver,
+                         uint64_t input,
+                         std::vector<block>& ot_masks,
+                         std::vector<block>& t0_buffer,
+                         size_t word_width = 8 * sizeof(uint64_t)) {
+    for (uint64_t idx = 0; idx < word_width; idx += 1) {
+        auto ot_instance = ot_ext_recver.get_ot_instance();
+        block choice = (input >> idx) & 1 ? psi::OneBlock : psi::ZeroBlock;
+
+        t0_buffer.emplace_back(ot_instance[0]);
+        ot_masks.emplace_back(choice ^ ot_instance[0] ^ ot_instance[1]);
+    }
+}
+
+inline void gen_ot_masks(OTExtReceiver<block>& ot_ext_recver,
+                         const int64_t* input,
+                         size_t size,
+                         std::vector<block>& ot_masks,
+                         std::vector<block>& t0_buffer,
+                         size_t word_width = 8 * sizeof(uint64_t)) {
+    for (size_t i = 0; i < size; ++i) {
+        gen_ot_masks(ot_ext_recver, input[i], ot_masks, t0_buffer, word_width);
+    }
+}
+
+template <typename T>
+inline void gen_ot_masks(OTExtReceiver<block>& ot_ext_recver,
+                         const std::vector<T>& input,
+                         std::vector<block>& ot_masks,
+                         std::vector<block>& t0_buffer,
+                         size_t word_width = 8 * sizeof(uint64_t)) {
+    for (const auto& i: input) {
+        gen_ot_masks(ot_ext_recver, i, ot_masks, t0_buffer, word_width);
+    }
+}
+
+template<typename T, size_t N>
+class TripletGenerator {
+public:
+    TripletGenerator(std::shared_ptr<AbstractContext>& circuit_context) :
+        _base_ot_choices(circuit_context->gen_random_private<block>()),
+        _np_ot_sender(sizeof(block) * 8),
+        _np_ot_recver(sizeof(block) * 8, block_to_string(_base_ot_choices)) {
+        _privc_ctx = circuit_context;
+    }
+
+    void init();
+
+    virtual void get_triplet(TensorAdapter<T>* ret);
+
+    virtual void get_penta_triplet(TensorAdapter<T>* ret) {}
+
+    std::queue<std::array<T, 3>> _triplet_buffer;
+    static const size_t _s_triplet_step = 1 << 8;
+    static constexpr double _s_fixed_point_compensation = 0.3;
+    static const size_t OT_SIZE = sizeof(block) * 8;
+
+protected:
+    // tag-dispatch helper used to specialize methods for T = int64_t
+    template <typename T_>
+    class Type2Type {
+        typedef T_ type;
+    };
+
+    void fill_triplet_buffer() { fill_triplet_buffer_impl<T>(Type2Type<T>()); }
+
+    template <typename T_>
+    void fill_triplet_buffer_impl(const Type2Type<T_>) {
+        PADDLE_THROW("type except `int64_t` for generating triplet is not implemented yet");
+    }
+
+    // specialize template method by overload
+    template <typename T_>
+    void fill_triplet_buffer_impl(const Type2Type<int64_t>);
+
+private:
+
+    std::shared_ptr<AbstractContext> privc_ctx() {
+        return _privc_ctx;
+    }
+    AbstractNetwork* net() {
+        return _privc_ctx->network();
+    }
+
+    size_t party() {
+        return privc_ctx()->party();
+    }
+
+    size_t next_party() {
+        return privc_ctx()->next_party();
+    }
+    // gen product shares for int64_t type
+    std::vector<uint64_t> gen_product(const std::vector<uint64_t> &input) {
+        return gen_product_impl<T>(input, Type2Type<T>());
+    }
+
+    template <typename T_>
+    std::vector<uint64_t> gen_product_impl(const std::vector<uint64_t> &input,
+                                           Type2Type<T_>) {
+        PADDLE_THROW("type except `int64_t` for generating triplet is not implemented yet");
+    }
+
+    template <typename T_>
+    std::vector<uint64_t> gen_product_impl(const std::vector<uint64_t> &input,
+                                           Type2Type<int64_t>);
+
+    const block _base_ot_choices;
+
+    NaorPinkasOTsender _np_ot_sender;
+    NaorPinkasOTreceiver _np_ot_recver;
+
+    OTExtSender<block> _ot_ext_sender;
+    OTExtReceiver<block> _ot_ext_recver;
+    std::shared_ptr<AbstractContext> _privc_ctx;
+};
+
+} // namespace privc
+
+#include "triplet_generator_impl.h"
\ No newline at end of file
diff --git a/core/privc/triplet_generator_impl.h b/core/privc/triplet_generator_impl.h
new file mode 100644
index 0000000000000000000000000000000000000000..bfee397aebae290d63df7a81ffe8620a2b8186a7
--- /dev/null
+++ b/core/privc/triplet_generator_impl.h
@@ -0,0 +1,209 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "core/privc/triplet_generator.h"
+#include "core/privc/crypto.h"
+#include "core/privc/privc_context.h"
+
+namespace privc {
+
+template<typename T, size_t N>
+void TripletGenerator<T, N>::init() {
+    auto np_ot_send_pre = [&]() {
+        std::array<std::array<std::array<psi::u8, psi::POINT_BUFFER_LEN>, 2>,
+                   OT_SIZE> send_buffer;
+
+        for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
+            send_buffer[idx] = _np_ot_sender.send_pre(idx);
+        }
+        net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
+    };
+
+    auto np_ot_send_post = [&]() {
+        std::array<std::array<psi::u8, psi::POINT_BUFFER_LEN>, OT_SIZE> recv_buffer;
+
+        net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
+
+        for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
+            _np_ot_sender.send_post(idx, recv_buffer[idx]);
+        }
+    };
+
+    auto np_ot_recv = [&]() {
+        std::array<std::array<std::array<psi::u8, psi::POINT_BUFFER_LEN>, 2>,
+                   OT_SIZE> recv_buffer;
+
+        std::array<std::array<psi::u8, psi::POINT_BUFFER_LEN>, OT_SIZE> send_buffer;
+
+        net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
+
+        for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
+            send_buffer[idx] = _np_ot_recver.recv(idx, recv_buffer[idx]);
+        }
+
+        net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
+    };
+
+    if (party() == 0) {
+        np_ot_recv();
+
+        np_ot_send_pre();
+        np_ot_send_post();
+
+    } else { // party == Bob
+        np_ot_send_pre();
+        np_ot_send_post();
+
+        np_ot_recv();
+    }
+    _ot_ext_sender.init(_base_ot_choices, _np_ot_recver._msgs);
+    _ot_ext_recver.init(_np_ot_sender._msgs);
+}
+
+template<typename T, size_t N>
+void TripletGenerator<T, N>::get_triplet(TensorAdapter<T>* ret) {
+    size_t num_trip = ret->numel() / 3;
+    // refill until enough triplets are buffered; each fill adds
+    // _s_triplet_step triplets, so a single fill may not suffice
+    while (_triplet_buffer.size() < num_trip) {
+        fill_triplet_buffer();
+    }
+
+    for (size_t i = 0; i < num_trip; ++i) {
+        auto triplet = _triplet_buffer.front();
+        auto ret_ptr = ret->data() + i;
+        *ret_ptr = triplet[0];
+        *(ret_ptr + num_trip) = triplet[1];
+        *(ret_ptr + 2 * num_trip) = triplet[2];
+        _triplet_buffer.pop();
+    }
+}
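+
+// gen_product below is a Gilboa-style OT multiplication: decompose
+// a * b = sum_i b_i * (a << i) over the bits b_i of b. Per bit, the sender
+// derives two masks s0/s1 from an OT instance, keeps s0 as its share term,
+// and transmits s1 ^ ((a << i) - s0); the receiver's OT key unmasks that
+// value only when b_i = 1, otherwise it contributes -s0. Summed over all
+// bits, the two parties end up holding additive shares ab0 + ab1 ~= a * b
+// in fixed point, whence the small compensation term for truncation loss.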
+
+template<typename T, size_t N>
+template<typename T_>
+void TripletGenerator<T, N>::fill_triplet_buffer_impl(const Type2Type<int64_t>) {
+    std::vector<uint64_t> a(_s_triplet_step);
+    std::vector<uint64_t> b(_s_triplet_step);
+
+    std::for_each(a.data(), a.data() + a.size(),
+                  [this](uint64_t& val) {
+                      val = privc_ctx()->template gen_random_private<uint64_t>(); });
+    std::for_each(b.data(), b.data() + b.size(),
+                  [this](uint64_t& val) {
+                      val = privc_ctx()->template gen_random_private<uint64_t>(); });
+
+    std::vector<uint64_t> ab0;
+    std::vector<uint64_t> ab1;
+
+    std::function<std::vector<uint64_t>(const std::vector<uint64_t>&)> gen_p
+        = [this](const std::vector<uint64_t>& v) {
+              return gen_product(v);
+          };
+
+    ab0 = gen_p(privc_ctx()->party() == 0 ? a : b);
+    ab1 = gen_p(privc_ctx()->party() == 0 ? b : a);
+
+    for (uint64_t i = 0; i < a.size(); i += 1) {
+        std::array<T, 3> item = {
+            static_cast<T>(a[i]),
+            static_cast<T>(b[i]),
+            static_cast<T>(fixed64_mult<N>(a[i], b[i]) + ab0[i] + ab1[i])};
+        _triplet_buffer.push(std::move(item));
+    }
+}
+
+template<typename T, size_t N>
+template<typename T_>
+std::vector<uint64_t> TripletGenerator<T, N>::gen_product_impl(
+    const std::vector<uint64_t> &input,
+    Type2Type<int64_t>) {
+    size_t word_width = 8 * sizeof(uint64_t);
+    std::vector<uint64_t> ret;
+
+    if (party() == 0) {
+        std::vector<uint64_t> s1_buffer;
+
+        std::vector<block> ot_mask;
+        ot_mask.resize(input.size() * word_width);
+        net()->recv(next_party(), ot_mask.data(), sizeof(block) * ot_mask.size());
+        size_t ot_mask_idx = 0;
+        for (const auto &a: input) {
+            uint64_t ret_val = 0;
+
+            for (uint64_t idx = 0; idx < word_width; idx += 1) {
+
+                const block& round_ot_mask = ot_mask.at(ot_mask_idx);
+                //net()->recv(next_party(), &round_ot_mask, sizeof(block));
+
+                // bad naming from ot extension
+                block q = _ot_ext_sender.get_ot_instance();
+
+                q ^= (round_ot_mask & _base_ot_choices);
+
+                auto s = psi::hash_blocks({q, q ^ _base_ot_choices});
+                uint64_t s0 = *reinterpret_cast<uint64_t *>(&s.first);
+                uint64_t s1 = *reinterpret_cast<uint64_t *>(&s.second);
+
+                s1 ^= lshift<N>(a, idx) - s0;
+
+                s1_buffer.emplace_back(s1);
+
+                ret_val += s0;
+                ot_mask_idx++;
+            }
+            ret.emplace_back(ret_val);
+        }
+        net()->send(next_party(), s1_buffer.data(),
+                    sizeof(uint64_t) * s1_buffer.size());
+
+    } else { // as ot recver
+
+        std::vector<block> ot_masks;
+        std::vector<block> t0_buffer;
+
+        gen_ot_masks(_ot_ext_recver, input, ot_masks, t0_buffer);
+        net()->send(next_party(), ot_masks.data(), sizeof(block) * ot_masks.size());
+        std::vector<uint64_t> ot_msg;
+        ot_msg.resize(input.size() * word_width);
+        net()->recv(next_party(), ot_msg.data(), sizeof(uint64_t) * ot_msg.size());
+        size_t ot_msg_idx = 0;
+        uint64_t b_idx = 0;
+        for (const auto &b: input) {
+            uint64_t ret_val = 0;
+
+            int b_weight = 0;
+
+            for (size_t idx = 0; idx < word_width; idx += 1) {
+                const uint64_t& round_ot_msg = ot_msg.at(ot_msg_idx);
+
+                auto t0_hash = psi::hash_block(t0_buffer[b_idx * word_width + idx]);
+
+                uint64_t key = *reinterpret_cast<uint64_t *>(&t0_hash);
+
+                bool b_i = (b >> idx) & 1;
+
+                b_weight += b_i;
+
+                ret_val += b_i ? round_ot_msg ^ key : -key;
+                ot_msg_idx++;
+            }
+            // compensation for precision loss
+            ret.emplace_back(ret_val +
+                static_cast<uint64_t>(_s_fixed_point_compensation * b_weight));
+
+            b_idx += 1;
+        }
+    }
+
+    return ret;
+}
+
+} // namespace privc
diff --git a/core/privc3/aby3_context.h b/core/privc3/aby3_context.h
index 8094d9d67640fb8098f6ce97c3e5050e287d6e63..025e87acfd7964a5beedd57ffaacf5409befd462 100644
--- a/core/privc3/aby3_context.h
+++ b/core/privc3/aby3_context.h
@@ -13,7 +13,6 @@
 // limitations under the License.
 #pragma once
 
-#include <algorithm>
 #include <memory>
 #include <set>