diff --git a/core/privc3/fixedpoint_tensor.h b/core/privc3/fixedpoint_tensor.h index 35e21b6e3e48550fe97eeb64bc6f0d9ed9f6a747..3fb2883d76a479af30fee1da67cf65ed980c60da 100644 --- a/core/privc3/fixedpoint_tensor.h +++ b/core/privc3/fixedpoint_tensor.h @@ -191,6 +191,9 @@ public: void max_pooling(FixedPointTensor* ret, BooleanTensor* pos = nullptr) const; + static void truncate(const FixedPointTensor* op, FixedPointTensor* ret, + size_t scaling_factor); + private: static inline std::shared_ptr aby3_ctx() { @@ -201,9 +204,6 @@ private: return paddle::mpc::ContextHolder::tensor_factory(); } - static void truncate(const FixedPointTensor* op, FixedPointTensor* ret, - size_t scaling_factor); - template static void mul_trunc(const FixedPointTensor* lhs, const FixedPointTensor* rhs, diff --git a/core/privc3/fixedpoint_tensor_imp.h b/core/privc3/fixedpoint_tensor_imp.h index adf660a7d55d2282b65b35fc41b30d0101513115..404694e787c1326b61dd6b8131099ddec8b6bf50 100644 --- a/core/privc3/fixedpoint_tensor_imp.h +++ b/core/privc3/fixedpoint_tensor_imp.h @@ -14,6 +14,7 @@ #pragma once +#include #include #include @@ -21,7 +22,6 @@ #include "prng.h" namespace aby3 { - template FixedPointTensor::FixedPointTensor(TensorAdapter* share_tensor[2]) { // TODO: check tensors' shapes @@ -245,15 +245,12 @@ void FixedPointTensor::truncate(const FixedPointTensor* op, temp.emplace_back( tensor_factory()->template create(op->shape())); } - // r', contraint in (constraint_low, contraint_upper) + // r' aby3_ctx()->template gen_random_private(*temp[0]); - T contraint_upper = (T) 1 << (sizeof(T) * 8 - 2); - T contraint_low = - contraint_upper; std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(), - [&contraint_upper, &contraint_low] (T& a) { - while ((a > contraint_upper || a < contraint_low)) { - a = aby3_ctx()->template gen_random_private(); - } + [] (T& a) { + a = (T) (a * std::pow(2, sizeof(T) * 8 - 2) + / std::numeric_limits::max()); }); //r'_0, r'_1 diff --git a/core/privc3/fixedpoint_tensor_test.cc b/core/privc3/fixedpoint_tensor_test.cc index 8b09cbd735275e7544bb51dc02866386d478f611..11feb42870533b0fe04139e984dd9ced12e9df9a 100644 --- a/core/privc3/fixedpoint_tensor_test.cc +++ b/core/privc3/fixedpoint_tensor_test.cc @@ -3437,4 +3437,124 @@ TEST_F(FixedTensorTest, inv_sqrt_test) { } +#ifdef USE_ABY3_TRUNC1 //use aby3 trunc1 +TEST_F(FixedTensorTest, truncate1_msb_failed) { + std::vector shape = { 1 }; + std::shared_ptr> sl[3] = { gen(shape), gen(shape), gen(shape) }; + std::shared_ptr> sout[6] = { gen(shape), gen(shape), gen(shape), + gen(shape), gen(shape), gen(shape)}; + // lhs = 6 = 1 + 2 + 3, share before truncate + // zero share 0 = (1 << 62) + (1 << 62) - (1 << 63) + sl[0]->data()[0] = ((int64_t) 3 << 32) - ((uint64_t) 1 << 63); + sl[1]->data()[0] = ((int64_t) 2 << 32) + ((int64_t) 1 << 62); + sl[2]->data()[0] = ((int64_t) 1 << 32) + ((int64_t) 1 << 62); + + auto pr = gen(shape); + + // rhs = 15 + pr->data()[0] = 6 << 16; + pr->scaling_factor() = 16; + Fix64N16 fl0(sl[0].get(), sl[1].get()); + Fix64N16 fl1(sl[1].get(), sl[2].get()); + Fix64N16 fl2(sl[2].get(), sl[0].get()); + Fix64N16 fout0(sout[0].get(), sout[1].get()); + Fix64N16 fout1(sout[2].get(), sout[3].get()); + Fix64N16 fout2(sout[4].get(), sout[5].get()); + + auto p = gen(shape); + + _t[0] = std::thread( + [&] () { + g_ctx_holder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + Fix64N16::truncate(&fl0, &fout0, 16); + fout0.reveal_to_one(0, p.get()); + }); + } + ); + _t[1] = std::thread( + [&] () { + g_ctx_holder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + Fix64N16::truncate(&fl1, &fout1, 16); + fout1.reveal_to_one(0, nullptr); + }); + } + ); + _t[2] = std::thread( + [&] () { + g_ctx_holder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + Fix64N16::truncate(&fl2, &fout2, 16); + fout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + // failed: result is not close to 6 + EXPECT_GT(std::abs((p->data()[0] >> 16) - 6), 1000); +} + +#else +TEST_F(FixedTensorTest, truncate3_msb_not_failed) { + std::vector shape = { 1 }; + std::shared_ptr> sl[3] = { gen(shape), gen(shape), gen(shape) }; + std::shared_ptr> sout[6] = { gen(shape), gen(shape), gen(shape), + gen(shape), gen(shape), gen(shape)}; + // lhs = 6 = 1 + 2 + 3, share before truncate + // zero share 0 = (1 << 62) + (1 << 62) - (1 << 63) + sl[0]->data()[0] = ((int64_t) 3 << 32) - ((uint64_t) 1 << 63); + sl[1]->data()[0] = ((int64_t) 2 << 32) + ((int64_t) 1 << 62); + sl[2]->data()[0] = ((int64_t) 1 << 32) + ((int64_t) 1 << 62); + + auto pr = gen(shape); + + // rhs = 15 + pr->data()[0] = 6 << 16; + pr->scaling_factor() = 16; + Fix64N16 fl0(sl[0].get(), sl[1].get()); + Fix64N16 fl1(sl[1].get(), sl[2].get()); + Fix64N16 fl2(sl[2].get(), sl[0].get()); + Fix64N16 fout0(sout[0].get(), sout[1].get()); + Fix64N16 fout1(sout[2].get(), sout[3].get()); + Fix64N16 fout2(sout[4].get(), sout[5].get()); + + auto p = gen(shape); + + _t[0] = std::thread( + [&] () { + g_ctx_holder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[0], [&](){ + Fix64N16::truncate(&fl0, &fout0, 16); + fout0.reveal_to_one(0, p.get()); + }); + } + ); + _t[1] = std::thread( + [&] () { + g_ctx_holder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[1], [&](){ + Fix64N16::truncate(&fl1, &fout1, 16); + fout1.reveal_to_one(0, nullptr); + }); + } + ); + _t[2] = std::thread( + [&] () { + g_ctx_holder::template run_with_context( + _exec_ctx.get(), _mpc_ctx[2], [&](){ + Fix64N16::truncate(&fl2, &fout2, 16); + fout2.reveal_to_one(0, nullptr); + }); + } + ); + for (auto &t: _t) { + t.join(); + } + EXPECT_EQ((p->data()[0] >> 16), 6); +} +#endif + } // namespace aby3