Commit 53efcf70 authored by Y yangqingyou

add penta triplet and mat mul

Parent a8576fd5
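A penta triplet here is the five-tuple (a, alpha, b, <a*b>, <alpha*b>) of additively shared fixed-point values; mat_mul consumes one tuple per row and pair of output columns. Sketched from the code below (a reading of the diff, not a formal specification), each masked product follows the usual Beaver identity

\[
\langle x\,y\rangle \;=\; e\,f \;+\; e\,\langle b\rangle \;+\; f\,\langle a\rangle \;+\; \langle a\,b\rangle,
\qquad e = x - a,\quad f = y - b,
\]

where e and f are revealed to both parties, only party 0 adds the public term e*f, and for odd output columns (alpha, <alpha*b>) take the place of (a, <a*b>).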
......@@ -94,16 +94,10 @@ public:
mul_impl<T>(rhs, ret, Type2Type<T>());
}
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
}
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
// element-wise mul with TensorAdapter
void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const {
mul_impl<T>(rhs, ret, Type2Type<T>());
}
// div by TensorAdapter
void div(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
......@@ -112,7 +106,9 @@ public:
void sum(FixedPointTensor* ret) const;
// mat_mul with FixedPointTensor
void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const {
mat_mul_impl<T>(rhs, ret, Type2Type<T>());
}
// mat_mul with TensorAdapter
void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
......@@ -195,6 +191,33 @@ private:
static size_t next_party() {
return privc_ctx()->next_party();
}
static inline AbstractNetwork* net() {
return privc_ctx()->network();
}
// mul_impl with FixedPointTensor
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
}
template<typename T_>
void mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
// mul_impl with TensorAdapter
template<typename T_>
void mul_impl(const TensorAdapter<T>* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
}
template<typename T_>
void mul_impl(const TensorAdapter<T>* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
// mat_mul_impl with FixedPointTensor
template<typename T_>
void mat_mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<T_>) const {
PADDLE_THROW("type except `int64_t` for fixedtensor mul is not implemented yet");
}
template<typename T_>
void mat_mul_impl(const FixedPointTensor* rhs, FixedPointTensor* ret, Type2Type<int64_t>) const;
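// Note on the Type2Type tag: explicit specialization of a member function
// template is not allowed without specializing the enclosing class template,
// so the int64_t implementations are selected by overloading on
// Type2Type<int64_t>, while the generic overloads above throw for
// unsupported share types.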
TensorAdapter<T>* _share;
......
......@@ -19,6 +19,7 @@
#include "paddle/fluid/platform/enforce.h"
#include "../privc3/prng.h"
#include "../privc3/paddle_tensor.h"
namespace privc {
......@@ -178,4 +179,230 @@ void FixedPointTensor<T, N>::mul_impl(const FixedPointTensor<T, N>* rhs,
}
}
template<typename T, size_t N>
template<typename T_>
void FixedPointTensor<T, N>::mul_impl(const TensorAdapter<T>* rhs,
FixedPointTensor<T, N>* ret,
const Type2Type<int64_t>) const {
fixed64_tensor_mult<N>(share(), rhs, ret->mutable_share());
}
template< typename T, size_t N>
void FixedPointTensor<T, N>::div(const TensorAdapter<T>* rhs,
FixedPointTensor<T, N>* ret) const {
auto temp = tensor_factory()->template create<T>(this->shape());
double scale = std::pow(2, N);
auto inverse = [scale](T d) -> T {
return 1.0 * scale / d * scale; };
std::transform(rhs->data(), rhs->data() + rhs->numel(),
temp->data(), inverse);
this->mul(temp.get(), ret);
}
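// Worked example for the reciprocal above (a sketch, assuming N = 32):
// a plaintext entry d encodes v = d / 2^N, and the lambda computes
// scale * scale / d = 2^(2N) / d = (1 / v) * 2^N, i.e. the fixed-point
// encoding of 1/v. For v = 4, d = 4 * 2^32 and temp becomes 2^30 (0.25),
// so the mul() call yields shares of x / 4.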
template<typename T, size_t N>
void FixedPointTensor<T, N>::sum(FixedPointTensor* ret) const {
PADDLE_ENFORCE_EQ(ret->numel(), 1, "output size should be 1.");
T sum = (T) 0;
for (int i = 0; i < numel(); ++i) {
sum += *(share()->data() + i);
}
*(ret->mutable_share()->data()) = sum;
}
template<typename T, size_t N>
template<typename T_>
void FixedPointTensor<T, N>::mat_mul_impl(const FixedPointTensor<T, N>* rhs,
FixedPointTensor<T, N>* ret,
const Type2Type<int64_t>) const {
// A * B, assume A.shape = [a, b], B.shape = [b, c]
size_t row = ret->shape()[0];
size_t col = ret->shape()[1];
PADDLE_ENFORCE_EQ(row, shape()[0], "invalid result shape for mat mul");
PADDLE_ENFORCE_EQ(col, rhs->shape()[1], "invalid result shape for mat mul");
PADDLE_ENFORCE_EQ(shape()[1], rhs->shape()[0], "invalid input shape for mat mul");
// transpose rhs
auto shape_trans = rhs->shape();
std::swap(shape_trans[0], shape_trans[1]);
auto trans_rhs = tensor_factory()->template create<T>(shape_trans);
const aby3::PaddleTensor<T>* p_rhs = dynamic_cast<const aby3::PaddleTensor<T>*>(rhs->share());
const_cast<aby3::PaddleTensor<T>*>(p_rhs)->template Transpose<2>(std::vector<int>({1, 0}), trans_rhs.get());
// get penta triplet, shape = [5, a, c, b]
std::vector<size_t> penta_triplet_shape{5, shape()[0], shape_trans[0], shape_trans[1]};
auto penta_triplet = tensor_factory()->template create<T>(penta_triplet_shape);
tripletor()->get_penta_triplet(penta_triplet.get());
// get triplet[idx0][idx1][idx2], shape = [b]
auto access_triplet = [&penta_triplet, &penta_triplet_shape](size_t idx0,
size_t idx1,
size_t idx2,
TensorAdapter<T>* ret) {
size_t numel = penta_triplet->numel();
auto& shape = penta_triplet_shape;
int64_t* tripl_ptr = penta_triplet->data();
size_t cal_idx_begin = idx0 * numel / shape[0]
+ idx1 * numel / (shape[0] * shape[1])
+ idx2 * numel / (shape[0] * shape[1] * shape[2]);
std::copy(tripl_ptr + cal_idx_begin,
tripl_ptr + cal_idx_begin + shape[3],
ret->data());
};
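// Equivalently, cal_idx_begin is the row-major flat offset of
// penta_triplet[idx0][idx1][idx2][0]:
// ((idx0 * shape[1] + idx1) * shape[2] + idx2) * shape[3].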
auto slice_and_reshape = [](const TensorAdapter<T>* input, int idx, TensorAdapter<T>* ret) {
input->slice(idx, idx + 1, ret);
auto shape = ret->shape();
shape.erase(shape.begin());
ret->reshape(shape);
};
std::vector<int64_t> buffer_e;
std::vector<int64_t> buffer_f;
buffer_e.resize(col * row * shape()[1]);
buffer_f.resize(col * row * shape()[1]);
int64_t* buffer_e_ptr = buffer_e.data();
int64_t* buffer_f_ptr = buffer_f.data();
// compute local shares of e = x - a (or x - alpha for odd columns) and f = y - b
for (int i = 0; i < row; ++i) {
auto lhs_v = tensor_factory()->template create<T>({shape()[1]});
slice_and_reshape(share(), i, lhs_v.get());
for (int j = 0; j < col; ++j) {
std::vector<size_t> shape_v{ shape()[1] };
std::vector<std::shared_ptr<TensorAdapter<T>>> temp_triplet_i_j;
for (int k = 0; k < 5; ++k) {
temp_triplet_i_j.emplace_back(
tensor_factory()->template create<T>(shape_v));
}
auto& a_i_j = temp_triplet_i_j[0];
auto& alpha_i_j = temp_triplet_i_j[1];
auto& b_i_j = temp_triplet_i_j[2];
auto& c_i_j = temp_triplet_i_j[3];
auto& alpha_c_i_j = temp_triplet_i_j[4];
access_triplet(0, i, j / 2, a_i_j.get());
access_triplet(1, i, j / 2, alpha_i_j.get());
access_triplet(2, i, j / 2, b_i_j.get());
access_triplet(3, i, j / 2, c_i_j.get());
access_triplet(4, i, j / 2, alpha_c_i_j.get());
auto e_v = tensor_factory()->template create<T>(shape_v);
auto f_v = tensor_factory()->template create<T>(shape_v);
auto rhs_v = tensor_factory()->template create<T>(shape_v);
slice_and_reshape(trans_rhs.get(), j, rhs_v.get());
if (j % 2 == 0) {
lhs_v->sub(a_i_j.get(), e_v.get());
} else {
lhs_v->sub(alpha_i_j.get(), e_v.get());
}
rhs_v->sub(b_i_j.get(), f_v.get());
std::copy(e_v->data(), e_v->data() + shape_v[0], buffer_e_ptr);
std::copy(f_v->data(), f_v->data() + shape_v[0], buffer_f_ptr);
buffer_e_ptr += shape_v[0];
buffer_f_ptr += shape_v[0];
}
}
// reveal all e and f
std::vector<int64_t> remote_buffer_e;
std::vector<int64_t> remote_buffer_f;
remote_buffer_e.resize(col * row * shape()[1]);
remote_buffer_f.resize(col * row * shape()[1]);
if (party() == 0) {
net()->send(next_party(), buffer_e.data(), buffer_e.size() * sizeof(int64_t));
net()->send(next_party(), buffer_f.data(), buffer_f.size() * sizeof(int64_t));
net()->recv(next_party(), remote_buffer_e.data(), remote_buffer_e.size() * sizeof(int64_t));
net()->recv(next_party(), remote_buffer_f.data(), remote_buffer_f.size() * sizeof(int64_t));
} else {
net()->recv(next_party(), remote_buffer_e.data(), remote_buffer_e.size() * sizeof(int64_t));
net()->recv(next_party(), remote_buffer_f.data(), remote_buffer_f.size() * sizeof(int64_t));
net()->send(next_party(), buffer_e.data(), buffer_e.size() * sizeof(int64_t));
net()->send(next_party(), buffer_f.data(), buffer_f.size() * sizeof(int64_t));
}
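// Party 0 sends its e/f buffers before receiving and party 1 receives
// before sending, so the blocking exchange cannot deadlock.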
std::vector<int64_t> e;
std::vector<int64_t> f;
e.resize(col * row * shape()[1]);
f.resize(col * row * shape()[1]);
std::transform(buffer_e.begin(), buffer_e.end(),
remote_buffer_e.begin(), e.begin(),
std::plus<int64_t>());
std::transform(buffer_f.begin(), buffer_f.end(),
remote_buffer_f.begin(), f.begin(),
std::plus<int64_t>());
int64_t* e_ptr = e.data();
int64_t* f_ptr = f.data();
auto result = tensor_factory()->template create<T>(ret->shape());
int64_t* res_ptr = result->data();
// compute z = f*<a> + e*<b> + <c> (party 1) or z = e*f + f*<a> + e*<b> + <c> (party 0);
// for odd columns, <alpha> and <alpha_c> replace <a> and <c>
for (int i = 0; i < row; ++i) {
for (int j = 0; j < col; ++j) {
std::vector<size_t> shape_v{ shape()[1] };
std::vector<std::shared_ptr<TensorAdapter<T>>> temp_triplet_i_j;
for (int k = 0; k < 5; ++k) {
temp_triplet_i_j.emplace_back(
tensor_factory()->template create<T>(shape_v));
}
auto& a_i_j = temp_triplet_i_j[0];
auto& alpha_i_j = temp_triplet_i_j[1];
auto& b_i_j = temp_triplet_i_j[2];
auto& c_i_j = temp_triplet_i_j[3];
auto& alpha_c_i_j = temp_triplet_i_j[4];
access_triplet(0, i, j / 2, a_i_j.get());
access_triplet(1, i, j / 2, alpha_i_j.get());
access_triplet(2, i, j / 2, b_i_j.get());
access_triplet(3, i, j / 2, c_i_j.get());
access_triplet(4, i, j / 2, alpha_c_i_j.get());
auto e_v = tensor_factory()->template create<T>(shape_v);
auto f_v = tensor_factory()->template create<T>(shape_v);
std::copy(e_ptr, e_ptr + shape_v[0], e_v->data());
std::copy(f_ptr, f_ptr + shape_v[0], f_v->data());
e_ptr += shape_v[0];
f_ptr += shape_v[0];
auto z_v = tensor_factory()->template create<T>(shape_v);
fixed64_tensor_mult<N>(e_v.get(), b_i_j.get(), z_v.get());
if (party() == 0) {
auto ef = tensor_factory()->template create<T>(shape_v);
fixed64_tensor_mult<N>(e_v.get(), f_v.get(), ef.get());
z_v->add(ef.get(), z_v.get());
}
auto fa = tensor_factory()->template create<T>(shape_v);
if (j % 2 == 0) {
fixed64_tensor_mult<N>(f_v.get(), a_i_j.get(), fa.get());
z_v->add(c_i_j.get(), z_v.get());
} else {
fixed64_tensor_mult<N>(f_v.get(), alpha_i_j.get(), fa.get());
z_v->add(alpha_c_i_j.get(), z_v.get());
}
z_v->add(fa.get(), z_v.get());
auto sum_v = [&z_v] () -> int64_t {
int64_t sum = 0;
for (int i = 0; i < z_v->numel(); ++i) {
sum += *(z_v->data() + i);
}
return sum;
};
*res_ptr = sum_v();
++res_ptr;
}
}
result->copy(ret->mutable_share());
}
} // namespace privc
......@@ -426,7 +426,62 @@ TEST_F(FixedTensorTest, triplet) {
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), 20);
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), std::pow(2, SCALING_N * 0.00001));
}
}
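// A penta triplet packs (a, alpha, b, <a*b>, <alpha*b>) as five blocks of
// equal length. The test below reconstructs a*b and alpha*b from both
// parties' shares and compares them with the fixed-point product of the
// reconstructed factors.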
TEST_F(FixedTensorTest, penta_triplet) {
std::vector<size_t> shape = { 1 };
auto shape_triplet = shape;
shape_triplet.insert(shape_triplet.begin(), 5);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = {gen(shape_triplet), gen(shape_triplet)};
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[0])
->triplet_generator()->get_penta_triplet(ret[0].get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
std::dynamic_pointer_cast<PrivCContext>(_mpc_ctx[1])
->triplet_generator()->get_penta_triplet(ret[1].get());
});
}
);
for (auto &t: _t) {
t.join();
}
auto num_triplet = ret[0]->numel() / 5;
for (int i = 0; i < num_triplet; ++i) {
auto ret0_ptr = ret[0]->data();
auto ret1_ptr = ret[1]->data();
uint64_t a_idx = i;
uint64_t alpha_idx = num_triplet + i;
uint64_t b_idx = 2 * num_triplet + i;
uint64_t c_idx = 3 * num_triplet + i;
uint64_t alpha_c_idx = 4 * num_triplet + i;
int64_t c = fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + a_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + a_idx), *(ret1_ptr + b_idx));
int64_t alpha_c = fixed64_mult<SCALING_N>(*(ret0_ptr + alpha_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret0_ptr + alpha_idx), *(ret1_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + alpha_idx), *(ret0_ptr + b_idx))
+ fixed64_mult<SCALING_N>(*(ret1_ptr + alpha_idx), *(ret1_ptr + b_idx));
// sometimes the difference is bigger than 200
EXPECT_NEAR(c , (*(ret0_ptr + c_idx) + *(ret1_ptr + c_idx)), std::pow(2, SCALING_N * 0.00001));
EXPECT_NEAR(alpha_c , (*(ret0_ptr + alpha_c_idx) + *(ret1_ptr + alpha_c_idx)), std::pow(2, SCALING_N * 0.00001));
}
}
......@@ -475,4 +530,239 @@ TEST_F(FixedTensorTest, mulfixed) {
EXPECT_NEAR(4, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
}
TEST_F(FixedTensorTest, mulfixed_upper_bound) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 2^16
// rhs = 2^16
sl[0]->data()[0] = (int64_t)1 << (SCALING_N + 15);
sl[1]->data()[0] = (int64_t)1 << (SCALING_N + 15);
sr[0]->data()[0] = (int64_t)1 << (SCALING_N + 15);
sr[1]->data()[0] = (int64_t)1 << (SCALING_N + 15);
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fr0(sr[0].get());
Fix64N32 fr1(sr[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.mul(&fr0, &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.mul(&fr1, &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_NEAR(0, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
}
TEST_F(FixedTensorTest, mulplain) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr = gen(shape);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 2 = 1 + 1
// rhs = 2 (plaintext)
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)1 << SCALING_N;
sr->data()[0] = (int64_t)2 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.mul(sr.get(), &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.mul(sr.get(), &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_NEAR(4, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
}
TEST_F(FixedTensorTest, sum) {
std::vector<size_t> shape = { 2 };
std::vector<size_t> shape_ret = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape_ret), gen(shape_ret) };
// lhs = (3, 3)
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[0]->data()[1] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)2 << SCALING_N;
sl[1]->data()[1] = (int64_t)2 << SCALING_N;
auto p = gen(shape_ret);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.sum(&fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.sum(&fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_EQ(6, p->data()[0] / std::pow(2, SCALING_N));
}
TEST_F(FixedTensorTest, divplain) {
std::vector<size_t> shape = { 1 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr = gen(shape);
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = 2 = 1 + 1
// rhs = 4
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[0] = (int64_t)1 << SCALING_N;
sr->data()[0] = (int64_t)4 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.div(sr.get(), &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.div(sr.get(), &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_NEAR(0.5, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
}
TEST_F(FixedTensorTest, mat_mulfixed) {
std::vector<size_t> shape = { 2, 2 };
std::shared_ptr<TensorAdapter<int64_t>> sl[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> sr[2] = { gen(shape), gen(shape) };
std::shared_ptr<TensorAdapter<int64_t>> ret[2] = { gen(shape), gen(shape) };
// lhs = [2, 3, 4, 5]
sl[0]->data()[0] = (int64_t)1 << SCALING_N;
sl[0]->data()[1] = (int64_t)2 << SCALING_N;
sl[0]->data()[2] = (int64_t)3 << SCALING_N;
sl[0]->data()[3] = (int64_t)4 << SCALING_N;
sl[1]->data()[0] = (int64_t)1 << SCALING_N;
sl[1]->data()[1] = (int64_t)1 << SCALING_N;
sl[1]->data()[2] = (int64_t)1 << SCALING_N;
sl[1]->data()[3] = (int64_t)1 << SCALING_N;
// rhs = [0, -1, -2, -3]
sr[0]->data()[0] = (int64_t)-1 << SCALING_N;
sr[0]->data()[1] = (int64_t)-2 << SCALING_N;
sr[0]->data()[2] = (int64_t)-3 << SCALING_N;
sr[0]->data()[3] = (int64_t)-4 << SCALING_N;
sr[1]->data()[0] = (int64_t)1 << SCALING_N;
sr[1]->data()[1] = (int64_t)1 << SCALING_N;
sr[1]->data()[2] = (int64_t)1 << SCALING_N;
sr[1]->data()[3] = (int64_t)1 << SCALING_N;
auto p = gen(shape);
Fix64N32 fl0(sl[0].get());
Fix64N32 fl1(sl[1].get());
Fix64N32 fr0(sr[0].get());
Fix64N32 fr1(sr[1].get());
Fix64N32 fout0(ret[0].get());
Fix64N32 fout1(ret[1].get());
_t[0] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[0], [&](){
fl0.mat_mul(&fr0, &fout0);
fout0.reveal_to_one(0, p.get());
});
}
);
_t[1] = std::thread(
[&] () {
g_ctx_holder::template run_with_context(
_exec_ctx.get(), _mpc_ctx[1], [&](){
fl1.mat_mul(&fr1, &fout1);
fout1.reveal_to_one(0, nullptr);
});
}
);
for (auto &t: _t) {
t.join();
}
EXPECT_NEAR(-6, p->data()[0] / std::pow(2, SCALING_N), 0.00001);
EXPECT_NEAR(-11, p->data()[1] / std::pow(2, SCALING_N), 0.00001);
EXPECT_NEAR(-10, p->data()[2] / std::pow(2, SCALING_N), 0.00001);
EXPECT_NEAR(-19, p->data()[3] / std::pow(2, SCALING_N), 0.00001);
}
} // namespace privc
......@@ -46,15 +46,6 @@ public:
PrivCContext(const PrivCContext &other) = delete;
PrivCContext &operator=(const PrivCContext &other) = delete;
/*
block get_private_block() {
std::array<int64_t, 2> ret_block;
ret_block[0] = gen_random_private<int64_t>();
ret_block[1] = gen_random_private<int64_t>();
return *(reinterpret_cast<block*>(ret_block.data()));
}
*/
void set_triplet_generator(std::shared_ptr<TripletGenerator<int64_t, SCALING_N>>& tripletor);
......
......@@ -110,15 +110,17 @@ public:
virtual void get_triplet(TensorAdapter<T>* ret);
virtual void get_penta_triplet(TensorAdapter<T>* ret) {}
virtual void get_penta_triplet(TensorAdapter<T>* ret);
std::queue<std::array<T, 3>> _triplet_buffer;
std::queue<std::array<T, 5>> _penta_triplet_buffer;
static const size_t _s_triplet_step = 1 << 8;
static constexpr double _s_fixed_point_compensation = 0.3;
static const size_t OT_SIZE = sizeof(block) * 8;
protected:
// T = int64
// dummy type for specializing the template method
template<typename T_>
class Type2Type {
typedef T_ type;
......@@ -135,6 +137,17 @@ protected:
template<typename T__>
void fill_triplet_buffer_impl(const Type2Type<int64_t>);
void fill_penta_triplet_buffer() { fill_penta_triplet_buffer_impl<T>(Type2Type<T>()); }
template<typename T__>
void fill_penta_triplet_buffer_impl(const Type2Type<T__>) {
PADDLE_THROW("type except `int64_t` for generating triplet is not implemented yet");
}
// specialize template method by overload
template<typename T__>
void fill_penta_triplet_buffer_impl(const Type2Type<int64_t>);
private:
std::shared_ptr<AbstractContext> privc_ctx() {
......@@ -153,6 +166,10 @@ private:
}
// gen triplet for int64_t type
std::vector<uint64_t> gen_product(const std::vector<uint64_t> &input);
std::vector<std::pair<uint64_t, uint64_t>> gen_product(size_t ot_sender,
const std::vector<uint64_t> &input0,
const std::vector<uint64_t> &input1
= std::vector<uint64_t>());
const block _base_ot_choices;
......
......@@ -86,7 +86,25 @@ void TripletGenerator<T, N>::get_triplet(TensorAdapter<T>* ret) {
*(ret_ptr + 2 * num_trip) = triplet[2];
_triplet_buffer.pop();
}
}
template<typename T, size_t N>
void TripletGenerator<T, N>::get_penta_triplet(TensorAdapter<T>* ret) {
size_t num_trip = ret->numel() / 5;
if (_penta_triplet_buffer.size() < num_trip) {
fill_penta_triplet_buffer();
}
for (int i = 0; i < num_trip; ++i) {
auto triplet = _penta_triplet_buffer.front();
auto ret_ptr = ret->data() + i;
*ret_ptr = triplet[0];
*(ret_ptr + num_trip) = triplet[1];
*(ret_ptr + 2 * num_trip) = triplet[2];
*(ret_ptr + 3 * num_trip) = triplet[3];
*(ret_ptr + 4 * num_trip) = triplet[4];
_penta_triplet_buffer.pop();
}
}
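// Layout of ret: five contiguous blocks of num_trip elements each,
// [a | alpha | b | c | alpha_c], matching the {5, ...} shaped tensors
// used by mat_mul_impl and the penta_triplet test.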
template<typename T, size_t N>
......@@ -122,6 +140,53 @@ void TripletGenerator<T, N>::fill_triplet_buffer_impl(const Type2Type<int64_t>)
}
}
template<typename T, size_t N>
template<typename T_>
void TripletGenerator<T, N>::fill_penta_triplet_buffer_impl(const Type2Type<int64_t>) {
std::vector<uint64_t> a(_s_triplet_step);
std::vector<uint64_t> b(_s_triplet_step);
std::vector<uint64_t> alpha(_s_triplet_step);
std::for_each(a.data(), a.data() + a.size(),
[this](uint64_t& val) {
val = privc_ctx()-> template gen_random_private<uint64_t>(); });
std::for_each(b.data(), b.data() + b.size(),
[this](uint64_t& val) {
val = privc_ctx()-> template gen_random_private<uint64_t>(); });
std::for_each(alpha.data(), alpha.data() + alpha.size(),
[this](uint64_t& val) {
val = privc_ctx()-> template gen_random_private<uint64_t>(); });
std::vector<std::pair<uint64_t, uint64_t>> ab0;
std::vector<std::pair<uint64_t, uint64_t>> ab1;
std::function<std::vector<std::pair<uint64_t, uint64_t>>(size_t, const std::vector<uint64_t>&, const std::vector<uint64_t>&)> gen_p_2arg
= [this](size_t p, const std::vector<uint64_t>& v0, const std::vector<uint64_t>& v1) {
return gen_product(p, v0, v1); };
std::function<std::vector<std::pair<uint64_t, uint64_t>>(size_t, const std::vector<uint64_t>&)> gen_p_1arg
= [this](size_t p, const std::vector<uint64_t>& v) {
return gen_product(p, v); };
if (party() == 0) {
ab0 = gen_p_2arg(0, a, alpha);
ab1 = gen_p_1arg(1, b);
} else {
ab0 = gen_p_1arg(0, b);
ab1 = gen_p_2arg(1, a, alpha);
}
for (uint64_t i = 0; i < a.size(); i += 1) {
std::array<int64_t, 5> item = {
static_cast<int64_t>(a[i]),
static_cast<int64_t>(alpha[i]),
static_cast<int64_t>(b[i]),
static_cast<int64_t>(fixed64_mult<N>(a[i], b[i]) + ab0[i].first + ab1[i].first),
static_cast<int64_t>(fixed64_mult<N>(alpha[i], b[i]) + ab0[i].second + ab1[i].second)};
_penta_triplet_buffer.push(std::move(item));
}
}
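// Correlation produced above: each party draws random a_i, alpha_i, b_i and
// pushes its additive share of
//   c = (a_0 + a_1) * (b_0 + b_1)   and   alpha_c = (alpha_0 + alpha_1) * (b_0 + b_1),
// where the local term is fixed64_mult<N>(a_i, b_i) (resp. alpha_i) and the
// cross terms a_i * b_(1-i) come from the two OT-based gen_product calls
// (ab0, ab1), as exercised by the penta_triplet test.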
template<typename T, size_t N>
std::vector<uint64_t> TripletGenerator<T, N>::gen_product(
const std::vector<uint64_t> &input) {
......@@ -141,7 +206,6 @@ std::vector<uint64_t> TripletGenerator<T, N>::gen_product(
for (uint64_t idx = 0; idx < word_width; idx += 1) {
const block& round_ot_mask = ot_mask.at(ot_mask_idx);
//net()->recv(next_party(), &round_ot_mask, sizeof(block));
// bad naming from ot extension
block q = _ot_ext_sender.get_ot_instance();
......@@ -204,4 +268,98 @@ std::vector<uint64_t> TripletGenerator<T, N>::gen_product(
return ret;
}
template<typename T, size_t N>
std::vector<std::pair<uint64_t, uint64_t>>
TripletGenerator<T, N>::gen_product(size_t ot_sender,
const std::vector<uint64_t> &input0,
const std::vector<uint64_t> &input1) {
size_t word_width = 8 * sizeof(uint64_t);
std::vector<std::pair<uint64_t, uint64_t>> ret;
if (party() == ot_sender) {
std::vector<std::pair<uint64_t, uint64_t>> s1_buffer;
std::vector<block> ot_mask;
auto size = std::min(input0.size(), input1.size());
ot_mask.resize(size * word_width);
net()->recv(next_party(), ot_mask.data(), sizeof(block) * ot_mask.size());
size_t ot_mask_idx = 0;
for (auto a_iter = input0.cbegin(), alpha_iter = input1.cbegin();
a_iter < input0.cend() && alpha_iter < input1.cend();
++a_iter, ++alpha_iter) {
uint64_t ret_val[2] = {0};
for (uint64_t idx = 0; idx < word_width; idx += 1) {
const block& round_ot_mask = ot_mask.at(ot_mask_idx);
// bad naming from ot extension
block q = _ot_ext_sender.get_ot_instance();
q ^= (round_ot_mask & _base_ot_choices);
auto s = psi::hash_blocks({q, q ^ _base_ot_choices});
uint64_t* s0 = reinterpret_cast<uint64_t *>(&s.first);
uint64_t* s1 = reinterpret_cast<uint64_t *>(&s.second);
s1[0] ^= lshift<N>(*a_iter, idx) - s0[0];
s1[1] ^= lshift<N>(*alpha_iter, idx) - s0[1];
s1_buffer.emplace_back(std::make_pair(s1[0], s1[1]));
ret_val[0] += s0[0];
ret_val[1] += s0[1];
ot_mask_idx++;
}
ret.emplace_back(std::make_pair(ret_val[0], ret_val[1]));
}
net()->send(next_party(), s1_buffer.data(), sizeof(std::pair<uint64_t, uint64_t>) * s1_buffer.size());
} else { // as ot recver
std::vector<block> ot_masks;
std::vector<block> t0_buffer;
gen_ot_masks(_ot_ext_recver, input0, ot_masks, t0_buffer);
net()->send(next_party(), ot_masks.data(), sizeof(block) * ot_masks.size());
std::vector<std::pair<uint64_t, uint64_t>> ot_msg;
ot_msg.resize(input0.size() * word_width);
net()->recv(next_party(), ot_msg.data(), sizeof(std::pair<uint64_t, uint64_t>) * ot_msg.size());
size_t ot_msg_idx = 0;
uint64_t b_idx = 0;
for (const auto &b: input0) {
uint64_t ret_val[2] = {0};
int b_weight = 0;
for (size_t idx = 0; idx < word_width; idx += 1) {
const std::pair<uint64_t, uint64_t>& round_ot_msg = ot_msg.at(ot_msg_idx);
auto t0_hash = psi::hash_block(t0_buffer[b_idx * word_width + idx]);
uint64_t* key = reinterpret_cast<uint64_t *>(&t0_hash);
bool b_i = (b >> idx) & 1;
b_weight += b_i;
ret_val[0] += b_i ? round_ot_msg.first ^ key[0] : -key[0];
ret_val[1] += b_i ? round_ot_msg.second ^ key[1] : -key[1];
ot_msg_idx++;
}
// compensation for precision loss
uint64_t loss = _s_fixed_point_compensation * b_weight;
ret.emplace_back(std::make_pair(ret_val[0] + loss, ret_val[1] + loss));
b_idx += 1;
}
}
return ret;
}
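// Sketch of the protocol above (an outline inferred from the code, not a
// spec): the receiver bit-decomposes each b and runs one correlated OT per
// bit; for bit i the sender adds s0 to its running share while the receiver
// adds either lshift<N>(a, i) - s0 (if b_i = 1) or -s0 (if b_i = 0), and
// likewise for alpha via the second pair component. Summing over bits leaves
// the two parties with additive shares of the fixed-point products a*b and
// alpha*b, and _s_fixed_point_compensation offsets the expected truncation
// error of the bitwise shifts.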
} // namespace privc
......@@ -21,6 +21,8 @@
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
#include "core/paddlefl_mpc/operators/math/math_function.h"
#include "tensor_adapter.h"
#include "tensor_adapter_factory.h"
......@@ -47,6 +49,14 @@ public:
const T *data() const override { return _tensor.data<T>(); }
const paddle::framework::Tensor* paddle_tensor() const {
return &_tensor;
}
paddle::framework::Tensor* paddle_tensor() {
return &_tensor;
}
std::vector<size_t> shape() const override {
return paddle::framework::vectorize<size_t>(_tensor.dims());
}
......@@ -109,6 +119,15 @@ public:
const std::vector<size_t> &shape,
size_t scaling_factor);
template<int Rank>
void Transpose(const std::vector<int> axis, TensorAdapter<T>* ret) {
paddle::operators::math::Transpose<paddle::platform::CPUDeviceContext, T, Rank> trans;
trans(*(dynamic_cast<const paddle::platform::CPUDeviceContext*>(_device_ctx)),
_tensor,
dynamic_cast<PaddleTensor<T>*>(ret)->paddle_tensor(),
axis);
}
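// Usage sketch (hypothetical shapes, mirroring FixedPointTensor::mat_mul_impl):
//   auto trans = tensor_factory()->template create<int64_t>({cols, rows});
//   paddle_tensor->template Transpose<2>(std::vector<int>({1, 0}), trans.get());
// Rank must equal the tensor's number of dimensions and axis is a
// permutation of {0, ..., Rank - 1}.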
private:
paddle::platform::Place place() const { return _device_ctx->GetPlace(); }
......