提交 ae687bb5 编写于 作者: Y yangqingyou

add activation function of fixedpoint

上级 7150cc79
set(PRIVC_SRCS
"privc_context.cc"
"integer.cc"
"bit.cc"
"ot.cc"
)
add_library(privc_o OBJECT ${PRIVC_SRCS})
......@@ -13,4 +16,6 @@ target_link_libraries(privc psi)
cc_test(crypto_test SRCS crypto_test.cc DEPS privc)
cc_test(privc_fixedpoint_tensor_test SRCS fixedpoint_tensor_test.cc DEPS privc)
cc_test(triplet_generator_test SRCS triplet_generator_test.cc DEPS privc)
cc_test(privc_gc_test SRCS privc_gc_test.cc DEPS privc)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "bit.h"
namespace privc {
std::vector<bool> reconstruct(std::vector<Bit> bits,
size_t party_in) {
std::vector<u8> remote;
std::vector<u8> local;
std::vector<bool> ret;
ret.resize(bits.size());
auto party = paddle::mpc::ContextHolder::mpc_ctx()->party();
auto next_party = paddle::mpc::ContextHolder::mpc_ctx()->next_party();
auto net = paddle::mpc::ContextHolder::mpc_ctx()->network();
for (auto& i : bits) {
local.emplace_back(block_lsb(i.share()));
}
// make remote ^ local = 0 if party_in == next_party()
remote = local;
if (party_in == std::numeric_limits<size_t>::max()) {
// reveal to all
if (party == 0) {
net->recv(next_party, remote.data(), remote.size() * sizeof(u8));
net->send(next_party, local.data(), local.size() * sizeof(u8));
} else {
net->send(next_party, local.data(), local.size() * sizeof(u8));
net->recv(next_party, remote.data(), remote.size() * sizeof(u8));
}
} else {
//reveal to one
if (party == party_in) {
net->recv(next_party, remote.data(), remote.size() * sizeof(u8));
} else {
net->send(next_party, local.data(), local.size() * sizeof(u8));
}
}
std::transform(local.begin(), local.end(),
remote.begin(), ret.begin(),
[] (u8& lhs, u8& rhs){ return lhs ^ rhs; });
return ret;
}
block garbled_and(block a, block b) {
auto& garbled_and_ctr = ot()->garbled_and_ctr();
if (party() == 0) {
u8 pa = block_lsb(a);
u8 pb = block_lsb(b);
u64 j0 = garbled_and_ctr += 2;
u64 j1 = j0 + 1;
auto j0_ = psi::to_block(j0);
auto j1_ = psi::to_block(j1);
auto& garbled_delta = ot()->garbled_delta();
auto t = psi::hash_blocks({a, a ^ garbled_delta}, {j0_, j0_});
block tg = t.first;
block wg = tg;
tg ^= t.second;
t = psi::hash_blocks({b, b ^ garbled_delta}, {j1_, j1_});
block te = t.first;
block we = te;
te ^= t.second;
te ^= a;
if (pb) {
tg ^= garbled_delta;
we ^= te ^ a;
}
if (pa) {
wg ^= tg;
}
/* TODO: add gc delay
if (_gc_delay) {
send_to_buffer(tg);
send_to_buffer(te);
} else {
send_val(tg);
send_val(te);
}
*/
net()->send(next_party(), tg);
net()->send(next_party(), te);
return we ^ wg;
} else {
u8 sa = block_lsb(a);
u8 sb = block_lsb(b);
u64 j0 = garbled_and_ctr += 2;
u64 j1 = j0 + 1;
auto j0_ = psi::to_block(j0);
auto j1_ = psi::to_block(j1);
block tg = net()->template recv<block>(next_party());
block te = net()->template recv<block>(next_party());
auto t = psi::hash_blocks({a, b}, {j0_, j1_});
block wg = t.first;
block we = t.second;
if (sa) {
wg ^= tg;
}
if (sb) {
we ^= te ^ a;
}
return wg ^ we;
}
}
block garbled_share(bool val) {
if (party() == 0) {
block ot_mask = net()->template recv<block>(next_party());
block q = ot()->ot_sender().get_ot_instance();
q ^= ot_mask & ot()->base_ot_choice();
auto ret_ = psi::hash_blocks({q, q ^ ot()->base_ot_choice()});
auto& garbled_delta = ot()->garbled_delta();
block to_send =
ret_.first ^ ret_.second ^ garbled_delta;
net()->send(next_party(), to_send);
return ret_.first;
} else {
auto ot_ins = ot()->ot_receiver().get_ot_instance();
block choice = val ? psi::OneBlock : psi::ZeroBlock;
block ot_mask = ot_ins[0] ^ ot_ins[1] ^ choice;
net()->send(next_party(), ot_mask);
block ot_recv = net()->template recv<block>(next_party());
block ret = psi::hash_block(ot_ins[0]);
if (val) {
ret ^= ot_recv;
}
return ret;
}
}
std::vector<block> garbled_share_internal(const int64_t* val, size_t size) {
std::vector<block> ret(sizeof(int64_t) * 8 * size); // 8 bit for 1 byte
std::vector<block> send_buffer;
std::vector<block> recv_buffer;
recv_buffer.resize(sizeof(int64_t) * 8 * size);
if (party() == 0) {
net()->recv(next_party(), recv_buffer.data(), recv_buffer.size() * sizeof(recv_buffer));
for (size_t idx = 0; idx < 8 * sizeof(int64_t) * size; ++idx) {
const block& ot_mask = recv_buffer.at(idx);
block q = ot()->ot_sender().get_ot_instance();
q ^= ot_mask & ot()->base_ot_choice();
auto ret_ = psi::hash_blocks({q, q ^ ot()->base_ot_choice()});
ret[idx] = ret_.first;
auto& garbled_delta = ot()->garbled_delta();
block to_send =
ret_.second ^ ret[idx] ^ garbled_delta;
//send_to_buffer(to_send);
send_buffer.emplace_back(to_send);
}
//flush_buffer();
net()->send(next_party(), send_buffer.data(), send_buffer.size() * sizeof(block));
return ret;
} else {
for (size_t idx = 0; idx < 8 * sizeof(int64_t) * size; ++idx) {
auto ot_ins = ot()->ot_receiver().get_ot_instance();
ret[idx] = psi::hash_block(ot_ins[0]);
size_t i = idx / (sizeof(int64_t) * 8);
size_t j = idx % (sizeof(int64_t) * 8);
block choice = (val[i] >> j) & 1 ? psi::OneBlock : psi::ZeroBlock;
block ot_mask = ot_ins[0] ^ ot_ins[1] ^ choice;
//send_to_buffer(ot_mask);
send_buffer.emplace_back(ot_mask);
}
//flush_buffer();
net()->send(next_party(), send_buffer.data(), send_buffer.size() * sizeof(block));
net()->recv(next_party(), recv_buffer.data(), recv_buffer.size() * sizeof(recv_buffer));
for (size_t idx = 0; idx < 8 * sizeof(int64_t) * size; ++idx) {
const block& ot_recv = recv_buffer.at(idx);
size_t i = idx / (sizeof(int64_t) * 8);
size_t j = idx % (sizeof(int64_t) * 8);
ret[idx] ^= (val[i] >> j) & 1 ? ot_recv : psi::ZeroBlock;
}
return ret;
}
}
std::vector<block> garbled_share(int64_t val) {
return garbled_share_internal(&val, 1);
}
std::vector<block> garbled_share(const std::vector<int64_t>& val) {
return garbled_share_internal(val.data(), val.size());
}
} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include <limits>
#include "core/privc/privc_context.h"
#include "core/privc/crypto.h"
#include "core/privc/triplet_generator.h"
#include "core/privc/common.h"
#include "core/privc/ot.h"
namespace privc {
block garbled_and(block a, block b);
block garbled_share(bool val);
std::vector<block> garbled_share(int64_t val);
std::vector<block> garbled_share(const std::vector<int64_t>& val);
class Bit {
public:
block _share;
public:
Bit() : _share(psi::ZeroBlock) {}
Bit(bool val, size_t party_in) {
if (party_in == 0) {
if (party() == 0) {
_share = privc_ctx()->gen_random_private<block>();
block to_send = _share;
if (val) {
to_send ^= ot()->garbled_delta();
}
net()->send(next_party(), to_send);
} else {
_share = net()->recv<block>(next_party());
}
} else {
_share = garbled_share(val);
}
}
~Bit() {}
Bit operator^(const Bit &rhs) const {
Bit ret;
ret._share = _share ^ rhs._share;
return ret;
}
block& share() {
return _share;
}
const block& share() const {
return _share;
}
Bit operator&(const Bit &rhs) const {
Bit ret;
ret._share = garbled_and(_share, rhs._share);
return ret;
}
Bit operator|(const Bit &rhs) const { return *this ^ rhs ^ (*this & rhs); }
Bit operator~() const {
Bit ret;
ret._share = _share;
if (party() == 0) {
ret._share ^= ot()->garbled_delta();
}
return ret;
}
Bit operator&&(const Bit &rhs) const {
return *this & rhs;
}
Bit operator||(const Bit &rhs) const {
return *this | rhs;
}
Bit operator!() const {
return ~*this;
}
u8 lsb() const {
u8 ret = block_lsb(_share);
return ret & (u8)1;
}
bool reconstruct() const {
u8 remote;
u8 local = block_lsb(_share);
if (party() == 0) {
remote = net()->recv<u8>(next_party());
net()->send(next_party(), local);
} else {
net()->send(next_party(), local);
remote = net()->recv<u8>(next_party());
}
return remote ^ local;
}
};
std::vector<bool> reconstruct(std::vector<Bit> bits,
size_t party_in);
using Bool = Bit;
} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc/privc_context.h"
#include "core/privc3/tensor_adapter_factory.h"
namespace privc {
static inline std::shared_ptr<AbstractContext> privc_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
}
static inline size_t party() {
return privc_ctx()->party();
}
static inline std::shared_ptr<OT> ot() {
return std::dynamic_pointer_cast<PrivCContext>(privc_ctx())->ot();
}
static inline size_t next_party() {
return privc_ctx()->next_party();
}
static inline AbstractNetwork* net() {
return privc_ctx()->network();
}
static inline std::shared_ptr<aby3::TensorAdapterFactory> tensor_factory() {
return paddle::mpc::ContextHolder::tensor_factory();
}
static inline std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> tripletor() {
return std::dynamic_pointer_cast<PrivCContext>(privc_ctx())->triplet_generator();
}
} // namespace privc
\ No newline at end of file
......@@ -30,6 +30,7 @@ typedef unsigned char u8;
typedef unsigned long long u64;
const block ZeroBlock = _mm_set_epi64x(0, 0);
const block OneBlock = _mm_set_epi64x(-1, -1);
const int POINT_BUFFER_LEN = 21;
static block double_block(block bl);
......@@ -48,6 +49,13 @@ static inline std::pair<block, block> hash_blocks(const std::pair<block, block>&
return {c[0] ^ k[0], c[1] ^ k[1]};
}
template <typename T>
static inline block to_block(const T& val) {
block ret = ZeroBlock;
std::memcpy(&ret, &val, std::min(sizeof ret, sizeof val));
return ret;
}
static inline block double_block(block bl) {
const __m128i mask = _mm_set_epi32(135,1,1,1);
__m128i tmp = _mm_srai_epi32(bl, 31);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <cmath>
#include <stdexcept>
#include <vector>
#include "bit.h"
#include "core/privc/integer.h"
namespace privc {
const unsigned int taylor_n = 6;
template<size_t N>
int64_t double_to_fix64(double in) {
return (int64_t) (in * std::pow(2, N));
}
template<size_t N>
double fix64_to_double(int64_t in) {
return in / std::pow(2, N);
}
inline int64_t factorial(unsigned int i) {
int64_t ret = 1;
for (; i > 0; i -= 1) {
ret *= i;
}
return ret;
}
template<size_t N>
class FixedPoint : public privc::Integer {
public:
FixedPoint(Integer &&in) : Integer(in) {}
FixedPoint(const Integer &in) : Integer(in) {}
FixedPoint(double in) {
_length = sizeof(int64_t) * 8;
_bits.resize(_length);
int64_t in_ = double_to_fix64<N>(in);
for (int i = 0; i < _length; i += 1) {
if (party() == 0 && in_ >> i & 1) {
_bits[i]._share = ot()->garbled_delta();
}
}
}
FixedPoint(double input, size_t party)
: Integer(double_to_fix64<N>(input), party) {
}
double reconstruct() const {
return fix64_to_double<N>(Integer::reconstruct());
}
FixedPoint decimal() const {
FixedPoint res = abs();
for (int i = N; i < res.size(); i += 1) {
res[i]._share = psi::ZeroBlock;
}
cond_neg(_bits[size() - 1], res.bits(), res.cbits(), size());
return res;
}
FixedPoint operator*(const FixedPoint &rhs) const {
if (size() != rhs.size()) {
throw std::logic_error("op len not match");
}
FixedPoint res(*this);
const unsigned int full_size = size() + N;
std::vector<Bit> l_vec;
std::vector<Bit> r_vec;
std::vector<Bit> res_vec(full_size, Bit());
for (int i = 0; i < size(); i += 1) {
l_vec.emplace_back(_bits[i]);
r_vec.emplace_back(rhs[i]);
}
for (int i = 0; (unsigned)i < N; i += 1) {
l_vec.emplace_back(_bits[size() - 1]);
r_vec.emplace_back(rhs[size() - 1]);
}
mul_full(res_vec.data(), l_vec.data(), r_vec.data(), full_size);
for (int i = 0; i < size(); i += 1) {
res[i] = std::move(res_vec[i + N]);
}
return res;
}
FixedPoint operator/(const FixedPoint &rhs) const {
if (size() != rhs.size()) {
throw std::logic_error("op len not match");
}
FixedPoint res(*this);
FixedPoint i1 = abs();
FixedPoint i2 = rhs.abs();
Bit sign = _bits[size() - 1] ^ rhs[size() - 1];
const unsigned int full_size = size() + N;
std::vector<Bit> l_vec(N, Bit());
std::vector<Bit> r_vec;
std::vector<Bit> res_vec(full_size);
for (int i = 0; i < size(); i += 1) {
l_vec.emplace_back(std::move(i1[i]));
r_vec.emplace_back(std::move(i2[i]));
}
for (int i = 0; (unsigned)i < N; i += 1) {
r_vec.emplace_back(Bit());
}
div_full(res_vec.data(), nullptr, l_vec.data(), r_vec.data(),
full_size);
Bit q_sign = res_vec[size() - 1];
std::vector<Bit> nan(size(), q_sign);
nan[size() - 1] = ~q_sign;
privc::if_then_else(res_vec.data(), nan.data(), res_vec.data(), size(), q_sign);
cond_neg(sign, res_vec.data(), res_vec.data(), full_size);
res_vec[0] = res_vec[0] ^ (q_sign & sign);
for (int i = 0; i < size(); i += 1) {
res[i] = std::move(res_vec[i]);
}
return res;
}
FixedPoint exp_int() const {
// e^22 > 2^31 - 1, 22 = 0x16
// e^-22 = 2.79 * 10 ^ -10 sufficiently precise
return exp(_bits[size() - 1], abs().bits() + N, 5);
}
FixedPoint exp_gc() const {
auto exp_int_ = exp_int();
auto x = decimal() * FixedPoint(0.5);
auto x_n = FixedPoint(1.0);
std::vector<FixedPoint> var;
for (int i = 0; (unsigned)i <= taylor_n; i += 1) {
var.emplace_back(x_n);
x_n = x_n * x;
}
auto exp_dec = var[0];
for (unsigned int i = 1; i <= taylor_n; i += 1) {
exp_dec = exp_dec + var[i] * FixedPoint(1.0 / factorial(i));
}
return exp_int_ * exp_dec * exp_dec;
}
FixedPoint relu() const {
FixedPoint zero(0.0);
return if_then_else(zero.geq(*this), zero, *this);
}
int64_t relu_bc() const {
FixedPoint zero(0.0);
return if_then_else_bc(zero.geq(*this), zero, *this);
}
FixedPoint logistic() const {
FixedPoint one(1.0);
FixedPoint half(0.5);
FixedPoint t_option = FixedPoint(operator+(half)).relu();
return if_then_else(one.geq(t_option), t_option, one);
}
static std::vector<FixedPoint> softmax(std::vector<FixedPoint> &&in) {
if (in.size() == 0) {
throw std::logic_error("zero input vector size");
}
FixedPoint sum(0.0);
for (auto &x: in) {
x = x.relu();
sum = sum + x;
}
auto sum_zero = sum.is_zero();
FixedPoint avg(1.0 / in.size());
for (auto &x: in) {
x = if_then_else(sum_zero, avg, x / sum);
}
return in;
}
private:
static FixedPoint exp(const Bit &neg, const Bit *in, int size) {
FixedPoint res(1.0);
FixedPoint base = Integer::if_then_else(neg, FixedPoint(1.0 / M_E),
FixedPoint(M_E));
FixedPoint one = res;
for (int i = size - 1; i >= 0; i -= 1) {
FixedPoint round = Integer::if_then_else(in[i], base, one);
res = res * round;
if (i) {
res = res * res;
}
}
return res;
}
};
template<size_t N>
using Fix64gc = FixedPoint<N>;
} // namespace privc
......@@ -16,10 +16,11 @@
#include <vector>
#include "privc_context.h"
#include "core/privc/privc_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "../privc3/paddle_tensor.h"
#include "./triplet_generator.h"
#include "core/privc3/paddle_tensor.h"
#include "core/privc/triplet_generator.h"
#include "core/privc/common.h"
namespace privc {
......@@ -113,88 +114,22 @@ public:
// mat_mul with TensorAdapter
void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
// exp approximate: exp(x) = \lim_{n->inf} (1+x/n)^n
// where n = 2^ite
// void exp(FixedPointTensor* ret, size_t iter = 8) const;
// element-wise relu
void relu(FixedPointTensor* ret) const;
// element-wise relu with relu'
// void relu_with_derivative(FixedPointTensor* ret, BooleanTensor<T>* derivative) const;
// element-wise sigmoid using 3 piecewise polynomials
void sigmoid(FixedPointTensor* ret) const;
// softmax axis = -1
//void softmax(FixedPointTensor* ret) const;
// element-wise sigmoid using 3 piecewise polynomials
void argmax(FixedPointTensor* ret) const;
// element-wise compare
// <
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void lt(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// <=
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void leq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// >
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void gt(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// >=
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void geq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// ==
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void eq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// !=
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void neq(const CTensor<T, N1...>* rhs, CTensor<T, N1...>* ret) const;
// element-wise max
// if not null, cmp stores true if rhs is bigger
template<template<typename U, size_t...> class CTensor,
size_t... N1>
void max(const CTensor<T, N1...>* rhs,
FixedPointTensor* ret,
CTensor<T, N1...>* cmp = nullptr) const;
private:
static inline std::shared_ptr<AbstractContext> privc_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
void relu(FixedPointTensor* ret) const {
relu_impl<T>(ret, Type2Type<T>());
}
static inline std::shared_ptr<TensorAdapterFactory> tensor_factory() {
return paddle::mpc::ContextHolder::tensor_factory();
// element-wise sigmoid
void sigmoid(FixedPointTensor* ret) const {
sigmoid_impl<T>(ret, Type2Type<T>());
}
static inline std::shared_ptr<TripletGenerator<T, N>> tripletor() {
return std::dynamic_pointer_cast<PrivCContext>(privc_ctx())->triplet_generator();
}
static size_t party() {
return privc_ctx()->party();
}
static size_t next_party() {
return privc_ctx()->next_party();
}
static inline AbstractNetwork* net() {
return privc_ctx()->network();
// matrix argmax
void argmax(FixedPointTensor<T, N>* ret) const {
argmax_impl<T>(ret, Type2Type<T>());
}
private:
// mul_impl with FixedPointTensor
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
......@@ -214,11 +149,35 @@ private:
// mat_mul_impl with FixedPointTensor
template<typename T_>
void mat_mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
PADDLE_THROW("type except `int64_t` for fixedtensor mat mul is not implemented yet");
}
template<typename T_>
void mat_mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
// relu_impl with FixedPointTensor
template<typename T_>
void relu_impl(FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor relu is not implemented yet");
}
template<typename T_>
void relu_impl(FixedPointTensor* ret, Type2Type<int64_t>) const;
// sigmoid_impl with FixedPointTensor
template<typename T_>
void sigmoid_impl(FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor sigmoid is not implemented yet");
}
template<typename T_>
void sigmoid_impl(FixedPointTensor* ret, Type2Type<int64_t>) const;
// argmax_impl with FixedPointTensor
template<typename T_>
void argmax_impl(FixedPointTensor<T, N>* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor sigmoid is not implemented yet");
}
template<typename T_>
void argmax_impl(FixedPointTensor<T, N>* ret, Type2Type<int64_t>) const;
TensorAdapter<T>* _share;
};
......
......@@ -18,8 +18,10 @@
#include <algorithm>
#include "paddle/fluid/platform/enforce.h"
#include "../privc3/prng.h"
#include "../privc3/paddle_tensor.h"
#include "core/privc3/prng.h"
#include "core/privc3/paddle_tensor.h"
#include "core/privc3/paddle_tensor_util.h"
#include "core/privc/fixed_point.h"
namespace privc {
......@@ -43,7 +45,7 @@ template<typename T, size_t N>
void FixedPointTensor<T, N>::reveal_to_one(size_t party,
TensorAdapter<T>* ret) const {
if (party == this->party()) {
if (party == privc::party()) {
auto buffer = tensor_factory()->template create<T>(ret->shape());
privc_ctx()->network()->template recv(next_party(), *buffer);
......@@ -405,4 +407,70 @@ void FixedPointTensor<T, N>::mat_mul_impl(const FixedPointTensor<T, N>* rhs,
result->copy(ret->mutable_share());
}
template<typename T, size_t N>
template<typename T_>
void FixedPointTensor<T, N>::relu_impl(FixedPointTensor<T, N>* ret,
const Type2Type<int64_t>) const {
std::vector<T> op_v;
aby3::TensorToVector<T>(share(), &op_v);
// ac to gc
auto x_v = Integer::vector(op_v, 0);
auto y_v = Integer::vector(op_v, 1);
std::transform(x_v.begin(), x_v.end(),
y_v.begin(), ret->mutable_share()->data(),
[](const Integer& x, const Integer& y) -> int64_t {
FixedPoint<N> gc = (FixedPoint<N>) (x + y);
auto ret_bc = gc.relu_bc();
return to_ac_num(ret_bc);
});
}
template<typename T, size_t N>
template<typename T_>
void FixedPointTensor<T, N>::sigmoid_impl(FixedPointTensor<T, N>* ret,
const Type2Type<int64_t>) const {
std::vector<T> op_v;
aby3::TensorToVector<T>(share(), &op_v);
// ac to gc
auto x_v = Integer::vector(op_v, 0);
auto y_v = Integer::vector(op_v, 1);
std::transform(x_v.begin(), x_v.end(),
y_v.begin(), ret->mutable_share()->data(),
[](const Integer& x, const Integer& y) -> int64_t {
FixedPoint<N> gc = (FixedPoint<N>) (x + y);
auto ret_gc = gc.logistic();
return to_ac_num(ret_gc.lsb());
});
}
template<typename T, size_t N>
template<typename T_>
void FixedPointTensor<T, N>::argmax_impl(FixedPointTensor<T, N>* ret,
const Type2Type<int64_t>) const {
PADDLE_ENFORCE_EQ(ret->shape()[1], shape()[1], "shape mot match.");
for ( int i = 0; i < shape()[0]; ++i) {
std::vector<T> vec;
aby3::TensorToVector<T>(share(), &vec, i);
// ac to gc
auto x_v = Integer::vector(vec, 0);
auto y_v = Integer::vector(vec, 1);
std::vector<Integer> gc_v;
gc_v.resize(x_v.size());
std::transform(x_v.begin(), x_v.end(),
y_v.begin(), gc_v.begin(),
std::plus<Integer>());
std::vector<int64_t> one_hot_index = Integer::argmax_one_hot(gc_v);
// gc to ac
auto ac_one_hot = to_ac_num(one_hot_index);
T* ret_ptr = ret->mutable_share()->data() + i * shape()[1];
std::transform(ac_one_hot.begin(), ac_one_hot.end(), ret_ptr,
[] (const int64_t& op) {
// int to fixedpoint
return op << N;
});
}
}
} // namespace privc
......@@ -35,59 +35,61 @@ using AbstractContext = paddle::mpc::AbstractContext;
class FixedTensorTest : public ::testing::Test {
public:
paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<AbstractContext> _mpc_ctx[2];
std::shared_ptr<gloo::rendezvous::HashStore> _store;
std::thread _t[2];
std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
static paddle::platform::CPUDeviceContext _cpu_ctx;
static std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
static std::shared_ptr<AbstractContext> _mpc_ctx[2];
static std::shared_ptr<gloo::rendezvous::HashStore> _store;
static std::thread _t[2];
static std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
virtual ~FixedTensorTest() noexcept {}
void SetUp() {
static void SetUpTestCase() {
paddle::framework::OperatorBase* op = nullptr;
paddle::framework::Scope scope;
paddle::framework::RuntimeContext ctx({}, {});
// only device_ctx is needed
_exec_ctx = std::make_shared<paddle::framework::ExecutionContext>(
*op, scope, _cpu_ctx, ctx);
_store = std::make_shared<gloo::rendezvous::HashStore>();
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&FixedTensorTest::gen_mpc_ctx, this, i);
_t[i] = std::thread(&FixedTensorTest::gen_mpc_ctx, i);
}
for (auto& ti : _t) {
ti.join();
}
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&FixedTensorTest::init_triplet, this, i);
_t[i] = std::thread(&FixedTensorTest::init_ot_and_triplet, i);
}
for (auto& ti : _t) {
ti.join();
}
_s_tensor_factory = std::make_shared<aby3::PaddleTensorFactory>(&_cpu_ctx);
}
std::shared_ptr<paddle::mpc::MeshNetwork> gen_network(size_t idx) {
static inline std::shared_ptr<paddle::mpc::MeshNetwork> gen_network(size_t idx) {
return std::make_shared<paddle::mpc::MeshNetwork>(idx,
"127.0.0.1",
2,
"test_prefix_privc",
_store);
}
void gen_mpc_ctx(size_t idx) {
static inline void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<PrivCContext>(idx, net);
}
void init_triplet(size_t idx) {
static inline void init_ot_and_triplet(size_t idx) {
std::shared_ptr<OT> ot = std::make_shared<OT>(_mpc_ctx[idx]);
ot->init();
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_ot(ot);
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> tripletor
= std::make_shared<TripletGenerator<int64_t, SCALING_N>>(_mpc_ctx[idx]);
tripletor->init();
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_triplet_generator(tripletor);
}
......@@ -383,108 +385,6 @@ TEST_F(FixedTensorTest, negative) {
EXPECT_EQ(-3, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, triplet) {
std::vector<size_t> shape = { 1 };
auto shape_triplet = shape;
shape_triplet.insert(shape_triplet.begin(), 3);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = {gen(shape_triplet), gen(shape_triplet)};
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
->triplet_generator()->get_triplet(ret[0].get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
->triplet_generator()->get_triplet(ret[1].get());
});
}
);
for (auto &t: _t) {
t.join();
}
auto num_triplet = ret[0]->numel() / 3;
for (int i = 0; i < ret[0]->numel() / 3; ++i) {
auto ret0_ptr = ret[0]->data();
auto ret1_ptr = ret[1]->data();
uint64_t a_idx = i;
uint64_t b_idx = num_triplet + i;
uint64_t c_idx = 2 * num_triplet + i;
int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), std::pow(2, SCALING_N * 0.00001));
}
}
TEST_F(FixedTensorTest, penta_triplet) {
std::vector<size_t> shape = { 1 };
auto shape_triplet = shape;
shape_triplet.insert(shape_triplet.begin(), 5);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = {gen(shape_triplet), gen(shape_triplet)};
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
->triplet_generator()->get_penta_triplet(ret[0].get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
->triplet_generator()->get_penta_triplet(ret[1].get());
});
}
);
for (auto &t: _t) {
t.join();
}
auto num_triplet = ret[0]->numel() / 5;
for (int i = 0; i < ret[0]->numel() / 5; ++i) {
auto ret0_ptr = ret[0]->data();
auto ret1_ptr = ret[1]->data();
uint64_t a_idx = i;
uint64_t alpha_idx = num_triplet + i;
uint64_t b_idx = 2 * num_triplet + i;
uint64_t c_idx = 3 * num_triplet + i;
uint64_t alpha_c_idx = 4 * num_triplet + i;
int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
int64_t alpha_c = fixed64_mult<SCALING_N>(*(ret0_ptr + alpha_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + alpha_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + alpha_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + alpha_idx), *(ret1_ptr + b_idx));
// sometimes the difference big than 200
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), std::pow(2, SCALING_N * 0.00001));
EXPECT_NEAR(alpha_c , (*(ret0_ptr + alpha_c_idx) + *(ret1_ptr + alpha_c_idx)), std::pow(2, SCALING_N * 0.00001));
}
}
TEST_F(FixedTensorTest, mulfixed) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
......@@ -765,4 +665,190 @@ TEST_F(FixedTensorTest, mat_mulfixed) {
EXPECT_NEAR(-19, p->data()[3] / std::pow(2, SCALING_N), 0.00001);
}
TEST_F(FixedTensorTest, relu) {
std::vector<size_t> shape = { 2 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = [-2, 2]
sl[0]->data()[0] = (int64_t)-1 << SCALING_N;
sl[0]->data()[1] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)-1 << SCALING_N;
sl[1]->data()[1] = (int64_t)1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.relu(&fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.relu(&fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(0, p->data()[0] / std::pow(2, SCALING_N));
EXPECT_EQ(2, p->data()[1] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, sigmoid) {
std::vector<size_t> shape = { 3 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = [-1, 0, 1]
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[0]->data()[1] = (int64_t)1 << SCALING_N;
sl[0]->data()[2] = (int64_t)2 << SCALING_N;
sl[1]->data()[0] = (int64_t)-2 << SCALING_N;
sl[1]->data()[1] = (int64_t)-1 << SCALING_N;
sl[1]->data()[2] = (int64_t)-1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.sigmoid(&fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.sigmoid(&fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(0, p->data()[0] / std::pow(2, SCALING_N));
EXPECT_EQ(0.5, p->data()[1] / std::pow(2, SCALING_N));
EXPECT_EQ(1, p->data()[2] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, argmax) {
std::vector<size_t> shape = { 2, 2 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = [[-2, 2], [1, 0]]
sl[0]->data()[0] = (int64_t)-1 << SCALING_N;
sl[0]->data()[1] = (int64_t)1 << SCALING_N;
sl[0]->data()[2] = (int64_t)-1 << SCALING_N;
sl[0]->data()[3] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)-1 << SCALING_N;
sl[1]->data()[1] = (int64_t)1 << SCALING_N;
sl[1]->data()[2] = (int64_t)2 << SCALING_N;
sl[1]->data()[3] = (int64_t)-1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.argmax(&fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.argmax(&fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(0, p->data()[0] / std::pow(2, SCALING_N));
EXPECT_EQ(1, p->data()[1] / std::pow(2, SCALING_N));
EXPECT_EQ(1, p->data()[2] / std::pow(2, SCALING_N));
EXPECT_EQ(0, p->data()[3] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, argmax_size_one) {
std::vector<size_t> shape = { 2, 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = [[-2], [2]]
sl[0]->data()[0] = (int64_t)-1 << SCALING_N;
sl[0]->data()[1] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)-1 << SCALING_N;
sl[1]->data()[1] = (int64_t)1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.argmax(&fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.argmax(&fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(1, p->data()[0] / std::pow(2, SCALING_N));
EXPECT_EQ(1, p->data()[1] / std::pow(2, SCALING_N));
}
paddle::platform::CPUDeviceContext privc::FixedTensorTest::_cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> privc::FixedTensorTest::_exec_ctx;
std::shared_ptr<AbstractContext> privc::FixedTensorTest::_mpc_ctx[2];
std::shared_ptr<gloo::rendezvous::HashStore> privc::FixedTensorTest::_store;
std::thread privc::FixedTensorTest::_t[2];
std::shared_ptr<TensorAdapterFactory> privc::FixedTensorTest::_s_tensor_factory;
} // namespace privc
此差异已折叠。
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <stdexcept>
#include <vector>
#include "bit.h"
namespace privc {
int64_t to_ac_num(int64_t val);
std::vector<int64_t> to_ac_num(const std::vector<int64_t>& input);
int64_t bc_mux(bool choice, int64_t val_t, int64_t val_f);
std::vector<int64_t> bc_mux(const std::vector<uint8_t>& choice,
const std::vector<int64_t>& val_t,
const std::vector<int64_t>& val_f);
class Integer {
protected:
int _length;
std::vector<Bit> _bits;
public:
Integer(Integer &&in) : _length(in._length) {
std::swap(_bits, in._bits);
}
Integer(const Integer &in)
: _length(in._length), _bits(in._bits) {}
Integer &operator=(Integer &&rhs) {
_length = rhs._length;
std::swap(_bits, rhs._bits);
return *this;
}
Integer &operator=(const Integer &rhs) {
_length = rhs._length;
_bits = rhs._bits;
return *this;
}
~Integer() {}
Integer(int64_t input, size_t party);
static std::vector<Integer> vector(const std::vector<int64_t>& input,
size_t party);
Integer() : _length(0), _bits() {}
// Comparable
Bit geq(const Integer &rhs) const;
Bit equal(const Integer &rhs) const;
inline int size() const { return _length; }
Bit* bits() { return _bits.data(); }
const Bit* cbits() const { return _bits.data(); }
std::vector<Bit>& share() { return _bits; }
const std::vector<Bit>& share() const { return _bits; }
Integer operator+(const Integer &rhs) const;
Integer operator-(const Integer &rhs) const;
Integer operator-() const;
Integer operator*(const Integer &rhs) const;
Integer operator/(const Integer &rhs) const;
Integer operator^(const Integer &rhs) const;
Integer abs() const;
Bit& operator[](int index);
const Bit& operator[](int index) const;
int64_t reconstruct() const {
int64_t ret = lsb();
if (party() == 0) {
net()->send(next_party(), ret);
ret ^= net()->recv<int64_t>(next_party());
} else {
auto remote = net()->recv<int64_t>(next_party());
net()->send(next_party(), ret);
ret ^= remote;
}
return ret;
}
bool reconstruct(u64 idx) const {
if (idx >= (unsigned)size()) {
throw std::logic_error("vector range exceed");
}
auto bit = _bits[idx].reconstruct();
return bit ? 1 : 0;
}
Bit is_zero() const;
int64_t lsb() const {
int64_t ret = 0;
for (int idx = 0; idx < size(); idx += 1) {
ret |= (int64_t)block_lsb(_bits[idx]._share) << idx;
}
return ret;
}
static Integer if_then_else(Bit cond, const Integer &t_int,
const Integer &f_int);
static int64_t if_then_else_bc(Bit cond, const Integer &t_int,
const Integer &f_int);
// input one dimension, return plaintext
static int64_t argmax(const std::vector<Integer>& op,
size_t party = std::numeric_limits<size_t>::max());
// with return ciphertext of one-hot
static std::vector<int64_t> argmax_one_hot(
const std::vector<Integer>& op);
};
void if_then_else(Bit *dest, const Bit *tsrc, const Bit *fsrc, int size,
Bit cond);
void cond_neg(Bit cond, Bit *dest, const Bit *src, int size);
void mul_full(Bit *dest, const Bit *op1, const Bit *op2, int size);
void div_full(Bit *vquot, Bit *vrem, const Bit *op1, const Bit *op2, int size);
typedef Integer Int64gc;
} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "core/privc/crypto.h"
#include "core/privc/privc_context.h"
#include "core/privc/ot.h"
namespace privc {
void ObliviousTransfer::init() {
auto np_ot_send_pre = [&]() {
std::array<std::array<std::array<unsigned char,
psi::POINT_BUFFER_LEN>, 2>, OT_SIZE> send_buffer;
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
send_buffer[idx] = _np_ot_sender.send_pre(idx);
}
net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
};
auto np_ot_send_post = [&]() {
std::array<std::array<unsigned char, psi::POINT_BUFFER_LEN>, OT_SIZE> recv_buffer;
net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
_np_ot_sender.send_post(idx, recv_buffer[idx]);
}
};
auto np_ot_recv = [&]() {
std::array<std::array<std::array<unsigned char,
psi::POINT_BUFFER_LEN>, 2>, OT_SIZE> recv_buffer;
std::array<std::array<unsigned char, psi::POINT_BUFFER_LEN>, OT_SIZE> send_buffer;
net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
send_buffer[idx] = _np_ot_recver.recv(idx, recv_buffer[idx]);
}
net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
};
_garbled_delta = privc_ctx()->template gen_random_private<block>();
reinterpret_cast<u8 *>(&_garbled_delta)[0] |= (u8)1;
_garbled_and_ctr = 0;
if (party() == 0) {
np_ot_recv();
np_ot_send_pre();
np_ot_send_post();
} else { // party == Bob
np_ot_send_pre();
np_ot_send_post();
np_ot_recv();
}
_ot_ext_sender.init(_base_ot_choices, _np_ot_recver._msgs);
_ot_ext_recver.init(_np_ot_sender._msgs);
}
} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <queue>
#include <array>
#include "paddle/fluid/platform/enforce.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/prng_utils.h"
#include "core/privc/crypto.h"
#include "core/psi/naorpinkas_ot.h"
#include "core/psi/ot_extension.h"
namespace privc {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
using block = psi::block;
using NaorPinkasOTsender = psi::NaorPinkasOTsender;
using NaorPinkasOTreceiver = psi::NaorPinkasOTreceiver;
using u64 = psi::u64;
using u8 = psi::u8;
template<typename T>
using OTExtSender = psi::OTExtSender<T>;
template<typename T>
using OTExtReceiver = psi::OTExtReceiver<T>;
inline std::string block_to_string(const block &b) {
return std::string(reinterpret_cast<const char *>(&b), sizeof(block));
}
inline u8 block_lsb(const block &val) {
const u8 *view = reinterpret_cast<const u8 *>(&val);
return view[0] & (u8)1;
};
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
uint64_t input,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (uint64_t idx = 0; idx < word_width; idx += 1) {
auto ot_instance = ot_ext_recver.get_ot_instance();
block choice = (input >> idx) & 1 ? psi::OneBlock : psi::ZeroBlock;
t0_buffer.emplace_back(ot_instance[0]);
ot_masks.emplace_back(choice ^ ot_instance[0] ^ ot_instance[1]);
}
}
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
const int64_t* input,
size_t size,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (size_t i = 0; i < size; ++i) {
gen_ot_masks(ot_ext_recver, input[i], ot_masks, t0_buffer, word_width);
}
}
template <typename T>
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
const std::vector<T>& input,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (const auto& i: input) {
gen_ot_masks(ot_ext_recver, i, ot_masks, t0_buffer, word_width);
}
}
class ObliviousTransfer {
public:
ObliviousTransfer(std::shared_ptr<AbstractContext>& circuit_context) :
_base_ot_choices(circuit_context->gen_random_private<block>()),
_np_ot_sender(sizeof(block) * 8),
_np_ot_recver(sizeof(block) * 8, block_to_string(_base_ot_choices)) {
_privc_ctx = circuit_context;
};
OTExtReceiver<block>& ot_receiver() { return _ot_ext_recver; }
OTExtSender<block>& ot_sender() { return _ot_ext_sender; }
const block& base_ot_choice() const { return _base_ot_choices; }
const block& garbled_delta() const { return _garbled_delta; }
u64& garbled_and_ctr() { return _garbled_and_ctr; }
void init();
static const size_t OT_SIZE = sizeof(block) * 8;
private:
std::shared_ptr<AbstractContext> privc_ctx() {
return _privc_ctx;
}
AbstractNetwork* net() {
return _privc_ctx->network();
}
size_t party() {
return privc_ctx()->party();
}
size_t next_party() {
return privc_ctx()->next_party();
}
const block _base_ot_choices;
block _garbled_delta;
u64 _garbled_and_ctr;
NaorPinkasOTsender _np_ot_sender;
NaorPinkasOTreceiver _np_ot_recver;
OTExtSender<block> _ot_ext_sender;
OTExtReceiver<block> _ot_ext_recver;
std::shared_ptr<AbstractContext> _privc_ctx;
};
using OT = ObliviousTransfer;
} // namespace privc
......@@ -16,6 +16,7 @@
#include "core/privc/triplet_generator.h"
#include "core/privc/privc_context.h"
#include "core/privc/ot.h"
namespace privc {
void PrivCContext::set_triplet_generator(std::shared_ptr<TripletGenerator<int64_t, SCALING_N>>& tripletor) {
......@@ -27,4 +28,12 @@ std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> PrivCContext::triplet_gene
return _tripletor;
}
void PrivCContext::set_ot(std::shared_ptr<OT>& ot) {
_ot = ot;
}
std::shared_ptr<OT>& PrivCContext::ot() {
return _ot;
}
} // namespace privc
......@@ -27,14 +27,20 @@ using AbstractContext = paddle::mpc::AbstractContext;
using block = psi::block;
static const size_t SCALING_N = 32;
// forward declare
template <typename T, size_t N>
class TripletGenerator;
class ObliviousTransfer;
class PrivCContext : public AbstractContext {
public:
PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
block seed = psi::g_zero_block):
AbstractContext::AbstractContext(party, network), _tripletor{nullptr} {
AbstractContext::AbstractContext(party, network),
_tripletor{nullptr},
_ot{nullptr} {
set_num_party(2);
if (psi::equals(seed, psi::g_zero_block)) {
......@@ -51,6 +57,10 @@ public:
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> triplet_generator();
void set_ot(std::shared_ptr<ObliviousTransfer>& ot);
std::shared_ptr<ObliviousTransfer>& ot();
protected:
psi::PseudorandomNumberGenerator& get_prng(size_t idx) override {
return _prng;
......@@ -59,6 +69,7 @@ protected:
private:
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> _tripletor;
psi::PseudorandomNumberGenerator _prng;
std::shared_ptr<ObliviousTransfer> _ot;
};
} // namespace privc
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <cmath>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "./privc_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "fixedpoint_tensor.h"
#include "core/privc/triplet_generator.h"
#include "core/privc3/paddle_tensor.h"
namespace privc {
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N32 = FixedPointTensor<int64_t, SCALING_N>;
using AbstractContext = paddle::mpc::AbstractContext;
class GCTest : public ::testing::Test {
public:
static paddle::platform::CPUDeviceContext _cpu_ctx;
static std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
static std::shared_ptr<AbstractContext> _mpc_ctx[2];
static std::shared_ptr<gloo::rendezvous::HashStore> _store;
static std::thread _t[2];
static std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
virtual ~GCTest() noexcept {}
static void SetUpTestCase() {
paddle::framework::OperatorBase* op = nullptr;
paddle::framework::Scope scope;
paddle::framework::RuntimeContext ctx({}, {});
_exec_ctx = std::make_shared<paddle::framework::ExecutionContext>(
*op, scope, _cpu_ctx, ctx);
_store = std::make_shared<gloo::rendezvous::HashStore>();
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&GCTest::gen_mpc_ctx, i);
}
for (auto& ti : _t) {
ti.join();
}
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&GCTest::init_ot_and_triplet, i);
}
for (auto& ti : _t) {
ti.join();
}
_s_tensor_factory = std::make_shared<aby3::PaddleTensorFactory>(&_cpu_ctx);
}
static inline std::shared_ptr<paddle::mpc::MeshNetwork> gen_network(size_t idx) {
return std::make_shared<paddle::mpc::MeshNetwork>(idx,
"127.0.0.1",
2,
"test_prefix_privc",
_store);
}
static inline void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<PrivCContext>(idx, net);
}
static inline void init_ot_and_triplet(size_t idx) {
std::shared_ptr<OT> ot = std::make_shared<OT>(_mpc_ctx[idx]);
ot->init();
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_ot(ot);
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> tripletor
= std::make_shared<TripletGenerator<int64_t, SCALING_N>>(_mpc_ctx[idx]);
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_triplet_generator(tripletor);
}
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
return _s_tensor_factory->template create<int64_t>(shape);
}
};
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
return g_ctx_holder::tensor_factory()->template create<int64_t>(shape);
}
void test_closure_gc(int party_flag, std::vector<double>& output) {
Fix64gc<SCALING_N> af(7.0, 0);
Fix64gc<SCALING_N> bf(3.0, 1);
Fix64gc<SCALING_N> cf = af;
Fix64gc<SCALING_N> df(std::move(cf));
df = bf;
output.emplace_back(af.geq(bf).reconstruct());
output.emplace_back(af.equal(bf).reconstruct());
Fix64gc<SCALING_N> res = af + bf;
output.emplace_back(res.reconstruct());
res = af - bf;
output.emplace_back(res.reconstruct());
res = -af;
output.emplace_back(res.reconstruct());
res = af * bf;
output.emplace_back(res.reconstruct());
res = af / bf;
output.emplace_back(res.reconstruct());
res = af ^ bf;
output.emplace_back(res.reconstruct());
res = af.abs();
output.emplace_back(res.reconstruct());
output.emplace_back(af[34].reconstruct());
}
TEST_F(GCTest, gc_closure) {
std::vector<double> output[2];
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
test_closure_gc(0, output[0]);
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
test_closure_gc(1, output[1]);
});
}
);
for (auto &t: _t) {
t.join();
}
int i = 0;
auto check_output_equal = [&] () -> bool {
if (output[0].size() != output[1].size()) {
return false;
}
for (unsigned int i = 0; i < output[0].size(); i += 1) {
if (output[0][i] != output[1][i]) {
return false;
}
}
return true;
};
ASSERT_TRUE(check_output_equal());
ASSERT_FLOAT_EQ(1, output[0][i++]);
ASSERT_FLOAT_EQ(0, output[0][i++]);
ASSERT_FLOAT_EQ(10, output[0][i++]);
ASSERT_FLOAT_EQ(4, output[0][i++]);
ASSERT_FLOAT_EQ(-7, output[0][i++]);
ASSERT_FLOAT_EQ(21, output[0][i++]);
ASSERT_FLOAT_EQ(7.0 / 3, output[0][i++]);
ASSERT_FLOAT_EQ(4, output[0][i++]);
ASSERT_FLOAT_EQ(7, output[0][i++]);
ASSERT_FLOAT_EQ(1, output[0][i++]);
}
paddle::platform::CPUDeviceContext privc::GCTest::_cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> privc::GCTest::_exec_ctx;
std::shared_ptr<AbstractContext> privc::GCTest::_mpc_ctx[2];
std::shared_ptr<gloo::rendezvous::HashStore> privc::GCTest::_store;
std::thread privc::GCTest::_t[2];
std::shared_ptr<TensorAdapterFactory> privc::GCTest::_s_tensor_factory;
} // namespace privc
......@@ -26,14 +26,13 @@
#include "core/psi/naorpinkas_ot.h"
#include "core/psi/ot_extension.h"
#include "core/privc3/tensor_adapter.h"
#include "core/privc/ot.h"
namespace privc {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
using block = psi::block;
using NaorPinkasOTsender = psi::NaorPinkasOTsender;
using NaorPinkasOTreceiver = psi::NaorPinkasOTreceiver;
template<typename T>
using OTExtSender = psi::OTExtSender<T>;
......@@ -56,57 +55,11 @@ inline uint64_t lshift(uint64_t lhs, size_t rhs) {
return fixed64_mult<N>(lhs, (uint64_t)1 << rhs);
}
inline std::string block_to_string(const block &b) {
return std::string(reinterpret_cast<const char *>(&b), sizeof(block));
}
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
uint64_t input,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (uint64_t idx = 0; idx < word_width; idx += 1) {
auto ot_instance = ot_ext_recver.get_ot_instance();
block choice = (input >> idx) & 1 ? psi::OneBlock : psi::ZeroBlock;
t0_buffer.emplace_back(ot_instance[0]);
ot_masks.emplace_back(choice ^ ot_instance[0] ^ ot_instance[1]);
}
}
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
const int64_t* input,
size_t size,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (size_t i = 0; i < size; ++i) {
gen_ot_masks(ot_ext_recver, input[i], ot_masks, t0_buffer, word_width);
}
}
template <typename T>
inline void gen_ot_masks(OTExtReceiver<block> & ot_ext_recver,
const std::vector<T>& input,
std::vector<block>& ot_masks,
std::vector<block>& t0_buffer,
size_t word_width = 8 * sizeof(uint64_t)) {
for (const auto& i: input) {
gen_ot_masks(ot_ext_recver, i, ot_masks, t0_buffer, word_width);
}
}
template<typename T, size_t N>
class TripletGenerator {
public:
TripletGenerator(std::shared_ptr<AbstractContext>& circuit_context) :
_base_ot_choices(circuit_context->gen_random_private<block>()),
_np_ot_sender(sizeof(block) * 8),
_np_ot_recver(sizeof(block) * 8, block_to_string(_base_ot_choices)) {
_privc_ctx = circuit_context;
};
void init();
_privc_ctx(circuit_context) {}
virtual void get_triplet(TensorAdapter<T>* ret);
......@@ -118,7 +71,6 @@ public:
static const size_t _s_triplet_step = 1 << 8;
static constexpr double _s_fixed_point_compensation = 0.3;
static const size_t OT_SIZE = sizeof(block) * 8;
protected:
// dummy type for specilize template method
......@@ -165,6 +117,9 @@ private:
size_t next_party() {
return privc_ctx()->next_party();
}
std::shared_ptr<OT> ot() {
return std::dynamic_pointer_cast<PrivCContext>(privc_ctx())->ot();
}
// gen triplet for int64_t type
std::vector<uint64_t> gen_product(const std::vector<uint64_t> &input);
std::vector<std::pair<uint64_t, uint64_t>> gen_product(size_t ot_sender,
......@@ -172,13 +127,6 @@ private:
const std::vector<uint64_t> &input1
= std::vector<uint64_t>());
const block _base_ot_choices;
NaorPinkasOTsender _np_ot_sender;
NaorPinkasOTreceiver _np_ot_recver;
OTExtSender<block> _ot_ext_sender;
OTExtReceiver<block> _ot_ext_recver;
std::shared_ptr<AbstractContext> _privc_ctx;
};
......
......@@ -18,59 +18,6 @@
namespace privc {
template<typename T, size_t N>
void TripletGenerator<T, N>::init() {
auto np_ot_send_pre = [&]() {
std::array<std::array<std::array<unsigned char,
psi::POINT_BUFFER_LEN>, 2>, OT_SIZE> send_buffer;
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
send_buffer[idx] = _np_ot_sender.send_pre(idx);
}
net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
};
auto np_ot_send_post = [&]() {
std::array<std::array<unsigned char, psi::POINT_BUFFER_LEN>, OT_SIZE> recv_buffer;
net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
_np_ot_sender.send_post(idx, recv_buffer[idx]);
}
};
auto np_ot_recv = [&]() {
std::array<std::array<std::array<unsigned char,
psi::POINT_BUFFER_LEN>, 2>, OT_SIZE> recv_buffer;
std::array<std::array<unsigned char, psi::POINT_BUFFER_LEN>, OT_SIZE> send_buffer;
net()->recv(next_party(), recv_buffer.data(), sizeof(recv_buffer));
for (uint64_t idx = 0; idx < OT_SIZE; idx += 1) {
send_buffer[idx] = _np_ot_recver.recv(idx, recv_buffer[idx]);
}
net()->send(next_party(), send_buffer.data(), sizeof(send_buffer));
};
if (party() == 0) {
np_ot_recv();
np_ot_send_pre();
np_ot_send_post();
} else { // party == Bob
np_ot_send_pre();
np_ot_send_post();
np_ot_recv();
}
_ot_ext_sender.init(_base_ot_choices, _np_ot_recver._msgs);
_ot_ext_recver.init(_np_ot_sender._msgs);
}
template<typename T, size_t N>
void TripletGenerator<T, N>::get_triplet(TensorAdapter<T>* ret) {
size_t num_trip = ret->numel() / 3;
......@@ -208,11 +155,11 @@ std::vector<uint64_t> TripletGenerator<T, N>::gen_product(
const block& round_ot_mask = ot_mask.at(ot_mask_idx);
// bad naming from ot extention
block q = _ot_ext_sender.get_ot_instance();
block q = ot()->ot_sender().get_ot_instance();
q ^= (round_ot_mask & _base_ot_choices);
q ^= (round_ot_mask & ot()->base_ot_choice());
auto s = psi::hash_blocks({q, q ^ _base_ot_choices});
auto s = psi::hash_blocks({q, q ^ ot()->base_ot_choice()});
uint64_t s0 = *reinterpret_cast<uint64_t *>(&s.first);
uint64_t s1 = *reinterpret_cast<uint64_t *>(&s.second);
......@@ -231,8 +178,8 @@ std::vector<uint64_t> TripletGenerator<T, N>::gen_product(
std::vector<block> ot_masks;
std::vector<block> t0_buffer;
gen_ot_masks(_ot_ext_recver, input, ot_masks, t0_buffer);
auto& ot_ext_recver = ot()->ot_receiver();
gen_ot_masks(ot_ext_recver, input, ot_masks, t0_buffer);
net()->send(next_party(), ot_masks.data(), sizeof(block) * ot_masks.size());
std::vector<uint64_t> ot_msg;
ot_msg.resize(input.size() * word_width);
......@@ -298,11 +245,11 @@ TripletGenerator<T, N>::gen_product(size_t ot_sender,
const block& round_ot_mask = ot_mask.at(ot_mask_idx);
// bad naming from ot extention
block q = _ot_ext_sender.get_ot_instance();
block q = ot()->ot_sender().get_ot_instance();
q ^= (round_ot_mask & _base_ot_choices);
q ^= (round_ot_mask & ot()->base_ot_choice());
auto s = psi::hash_blocks({q, q ^ _base_ot_choices});
auto s = psi::hash_blocks({q, q ^ ot()->base_ot_choice()});
uint64_t* s0 = reinterpret_cast<uint64_t *>(&s.first);
uint64_t* s1 = reinterpret_cast<uint64_t *>(&s.second);
......@@ -323,8 +270,8 @@ TripletGenerator<T, N>::gen_product(size_t ot_sender,
std::vector<block> ot_masks;
std::vector<block> t0_buffer;
gen_ot_masks(_ot_ext_recver, input0, ot_masks, t0_buffer);
auto& ot_ext_recver = ot()->ot_receiver();
gen_ot_masks(ot_ext_recver, input0, ot_masks, t0_buffer);
net()->send(next_party(), ot_masks.data(), sizeof(block) * ot_masks.size());
std::vector<std::pair<uint64_t, uint64_t>> ot_msg;
ot_msg.resize(input0.size() * word_width);
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <string>
#include <cmath>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "./privc_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "fixedpoint_tensor.h"
#include "core/privc/triplet_generator.h"
#include "core/privc3/paddle_tensor.h"
namespace privc {
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N32 = FixedPointTensor<int64_t, SCALING_N>;
using AbstractContext = paddle::mpc::AbstractContext;
class TripletGeneratorTest : public ::testing::Test {
public:
static paddle::platform::CPUDeviceContext _cpu_ctx;
static std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
static std::shared_ptr<AbstractContext> _mpc_ctx[2];
static std::shared_ptr<gloo::rendezvous::HashStore> _store;
static std::thread _t[2];
static std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
virtual ~TripletGeneratorTest() noexcept {}
static void SetUpTestCase() {
paddle::framework::OperatorBase* op = nullptr;
paddle::framework::Scope scope;
paddle::framework::RuntimeContext ctx({}, {});
_exec_ctx = std::make_shared<paddle::framework::ExecutionContext>(
*op, scope, _cpu_ctx, ctx);
_store = std::make_shared<gloo::rendezvous::HashStore>();
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&TripletGeneratorTest::gen_mpc_ctx, i);
}
for (auto& ti : _t) {
ti.join();
}
for (size_t i = 0; i < 2; ++i) {
_t[i] = std::thread(&TripletGeneratorTest::init_ot_and_triplet, i);
}
for (auto& ti : _t) {
ti.join();
}
_s_tensor_factory = std::make_shared<aby3::PaddleTensorFactory>(&_cpu_ctx);
}
static inline std::shared_ptr<paddle::mpc::MeshNetwork> gen_network(size_t idx) {
return std::make_shared<paddle::mpc::MeshNetwork>(idx,
"127.0.0.1",
2,
"test_prefix_privc",
_store);
}
static inline void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<PrivCContext>(idx, net);
}
static inline void init_ot_and_triplet(size_t idx) {
std::shared_ptr<OT> ot = std::make_shared<OT>(_mpc_ctx[idx]);
ot->init();
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_ot(ot);
std::shared_ptr<TripletGenerator<int64_t, SCALING_N>> tripletor
= std::make_shared<TripletGenerator<int64_t, SCALING_N>>(_mpc_ctx[idx]);
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[idx])->set_triplet_generator(tripletor);
}
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
return _s_tensor_factory->template create<int64_t>(shape);
}
};
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
return g_ctx_holder::tensor_factory()->template create<int64_t>(shape);
}
TEST_F(TripletGeneratorTest, triplet) {
std::vector<size_t> shape = { 1 };
auto shape_triplet = shape;
shape_triplet.insert(shape_triplet.begin(), 3);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = {gen(shape_triplet), gen(shape_triplet)};
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
->triplet_generator()->get_triplet(ret[0].get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
->triplet_generator()->get_triplet(ret[1].get());
});
}
);
for (auto &t: _t) {
t.join();
}
auto num_triplet = ret[0]->numel() / 3;
for (int i = 0; i < ret[0]->numel() / 3; ++i) {
auto ret0_ptr = ret[0]->data();
auto ret1_ptr = ret[1]->data();
uint64_t a_idx = i;
uint64_t b_idx = num_triplet + i;
uint64_t c_idx = 2 * num_triplet + i;
int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), std::pow(2, SCALING_N * 0.00001));
}
}
TEST_F(TripletGeneratorTest, penta_triplet) {
std::vector<size_t> shape = { 1 };
auto shape_triplet = shape;
shape_triplet.insert(shape_triplet.begin(), 5);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = {gen(shape_triplet), gen(shape_triplet)};
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
->triplet_generator()->get_penta_triplet(ret[0].get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
->triplet_generator()->get_penta_triplet(ret[1].get());
});
}
);
for (auto &t: _t) {
t.join();
}
auto num_triplet = ret[0]->numel() / 5;
for (int i = 0; i < ret[0]->numel() / 5; ++i) {
auto ret0_ptr = ret[0]->data();
auto ret1_ptr = ret[1]->data();
uint64_t a_idx = i;
uint64_t alpha_idx = num_triplet + i;
uint64_t b_idx = 2 * num_triplet + i;
uint64_t c_idx = 3 * num_triplet + i;
uint64_t alpha_c_idx = 4 * num_triplet + i;
int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
int64_t alpha_c = fixed64_mult<SCALING_N>(*(ret0_ptr + alpha_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + alpha_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + alpha_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + alpha_idx), *(ret1_ptr + b_idx));
// sometimes the difference big than 200
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), std::pow(2, SCALING_N * 0.00001));
EXPECT_NEAR(alpha_c , (*(ret0_ptr + alpha_c_idx) + *(ret1_ptr + alpha_c_idx)), std::pow(2, SCALING_N * 0.00001));
}
}
paddle::platform::CPUDeviceContext privc::TripletGeneratorTest::_cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> privc::TripletGeneratorTest::_exec_ctx;
std::shared_ptr<AbstractContext> privc::TripletGeneratorTest::_mpc_ctx[2];
std::shared_ptr<gloo::rendezvous::HashStore> privc::TripletGeneratorTest::_store;
std::thread privc::TripletGeneratorTest::_t[2];
std::shared_ptr<TensorAdapterFactory> privc::TripletGeneratorTest::_s_tensor_factory;
} // namespace privc
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle_tensor.h"
namespace aby3 {
// slice_i = -1 indicate do not slice tensor
// otherwise vectorize tensor.Slice(slice_i, slice_i + 1)
template<typename T>
void TensorToVector(const TensorAdapter<T>* src, std::vector<T>* dst, int slice_i = -1) {
auto& t = dynamic_cast<const PaddleTensor<T>*>(src)->tensor();
if (slice_i == -1) {
paddle::framework::TensorToVector(t, dst);
} else {
auto t_slice = t.Slice(slice_i, slice_i + 1);
paddle::framework::TensorToVector(t_slice, dst);
}
}
} // namespace aby3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册