diff --git a/CMakeLists.txt b/CMakeLists.txt
index 539390338f0345be990a76e6b1d161799a045825..d14c3faa9d10b19f060cd468ca55720b8e8894b2 100644
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -61,6 +61,8 @@ option(USE_AES_NI "Compile with AES NI" ON)
 
 option(USE_OPENMP "Compile with OpenMP" ON)
 
+option(USE_ABY3_TRUNC1 "Compile with ABY3 truncate 1 algorithm" OFF)
+
 ########################### the project build part ###############################
 message(STATUS "Using paddlepaddle installation of ${paddle_version}")
 message(STATUS "paddlepaddle include directory: ${PADDLE_INCLUDE}")
@@ -84,6 +86,10 @@ if (USE_OPENMP)
     find_package(OpenMP REQUIRED)
 endif(USE_OPENMP)
 
+if (USE_ABY3_TRUNC1)
+    add_compile_definitions(USE_ABY3_TRUNC1)
+endif(USE_ABY3_TRUNC1)
+
 add_subdirectory(core/privc3)
 add_subdirectory(core/paddlefl_mpc/mpc_protocol)
 add_subdirectory(core/paddlefl_mpc/operators)
diff --git a/core/privc3/fixedpoint_tensor.h b/core/privc3/fixedpoint_tensor.h
index 35e21b6e3e48550fe97eeb64bc6f0d9ed9f6a747..3fb2883d76a479af30fee1da67cf65ed980c60da 100644
--- a/core/privc3/fixedpoint_tensor.h
+++ b/core/privc3/fixedpoint_tensor.h
@@ -191,6 +191,9 @@ public:
     void max_pooling(FixedPointTensor* ret,
                      BooleanTensor<T>* pos = nullptr) const;
 
+    static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
+                         size_t scaling_factor);
+
 private:
 
     static inline std::shared_ptr<CircuitContext> aby3_ctx() {
@@ -201,9 +204,6 @@ private:
         return paddle::mpc::ContextHolder::tensor_factory();
     }
 
-    static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
-                         size_t scaling_factor);
-
     template<typename MulFunc>
     static void mul_trunc(const FixedPointTensor* lhs,
                           const FixedPointTensor* rhs,
diff --git a/core/privc3/fixedpoint_tensor_imp.h b/core/privc3/fixedpoint_tensor_imp.h
index cb508288b01c5d5a6625262d24b7d5f1887c5a65..47fc0bf2be18dac71de3f231efff2e2e912729b5 100644
--- a/core/privc3/fixedpoint_tensor_imp.h
+++ b/core/privc3/fixedpoint_tensor_imp.h
@@ -21,7 +21,6 @@
 #include "prng.h"
 
 namespace aby3 {
-
 template<typename T, size_t N>
 FixedPointTensor<T, N>::FixedPointTensor(TensorAdapter<T>* share_tensor[2]) {
     // TODO: check tensors' shapes
@@ -166,6 +165,7 @@ void FixedPointTensor<T, N>::mul(const FixedPointTensor<T, N>* rhs,
     mul_trunc(this, rhs, ret, &TensorAdapter<T>::mul);
 }
 
+#ifdef USE_ABY3_TRUNC1 //use aby3 trunc1
 template<typename T, size_t N>
 void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
                                       FixedPointTensor<T, N>* ret,
@@ -208,7 +208,20 @@ void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
     return;
 }
 
-// Protocol. `truncate3`
+#else // use truncate3
+
+// Protocol. `truncate3` (illustrated for data type T = int64_t)
+// motivation:
+// the aby3 trunc1 algorithm may cause an msb error with small probability,
+// because the masked value, e.g. x' - r', may overflow in int64_t before the rshift op.
+// therefore, in `truncate3`, we restrict r' to (-2^62, 2^62) to avoid the problem.
+
+// notice:
+// when r' is constrained to (-2^62, 2^62),
+// the SD (statistical distance) between x' - r' computed this way
+// and x' - r' with r' drawn from Z_{2^64} is equal to |X| / (2^63 + |X|)
+
+// detailed protocol:
 // P2 randomly generates r' \in (-2^62, 2^62), randomly generates r'_0, r_0, r_1 in Z_{2^64},
 // P2 compute r'_1 = r' - r'_0, r_2 = r'/2^N - r_0 - r_1, let x2 = r_2
 // P2 send r_0, r'_0 to P0, send r_1, r'_1 to P1
@@ -217,7 +230,7 @@ void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
 // P0 set x0 = r_0
 // P0, P1, P2 invoke reshare() with inputs x0, x1, x2 respectively.
 template<typename T, size_t N>
-void FixedPointTensor<T, N>::truncate3(const FixedPointTensor<T, N>* op,
+void FixedPointTensor<T, N>::truncate(const FixedPointTensor<T, N>* op,
                                        FixedPointTensor<T, N>* ret,
                                        size_t scaling_factor) {
     if (scaling_factor == 0) {
@@ -231,23 +244,9 @@ void FixedPointTensor<T, N>::truncate3(const FixedPointTensor<T, N>* op,
         temp.emplace_back(
             tensor_factory()->template create<T>(op->shape()));
     }
-    // r', contraint in (-2^62, 2^62)
-    // notice : when r' is contrainted in (-2^62, 2^62),
-    // the SD (statistical distance) of x - r' between this
-    // and r' in Z_{2^64} is equal to |X| / (2^63 + |X|)
-    // according to http://yuyu.hk/files/ho2.pdf
+    // r'
     aby3_ctx()->template gen_random_private(*temp[0]);
-    int64_t contraint_upper = ~((uint64_t) 1 << 62);
-    int64_t contraint_low = (uint64_t) 1 << 62;
-    std::for_each(temp[0]->data(), temp[0]->data() + temp[0]->numel(),
-                  [&contraint_upper, &contraint_low] (T& a) {
-                      // contraint -2^62 < a < 2^62
-                      if (a >= 0) {
-                          a &= contraint_upper;
-                      } else {
-                          a |= contraint_low;
-                      }
-                  });
+    temp[0]->rshift(1, temp[0].get());
 
     //r'_0, r'_1
     aby3_ctx()->template gen_random_private(*temp[1]);
@@ -307,6 +306,7 @@ void FixedPointTensor<T, N>::truncate3(const FixedPointTensor<T, N>* op,
     tensor_carry_in->scaling_factor() = N;
     ret->add(tensor_carry_in.get(), ret);
 }
+#endif //USE_ABY3_TRUNC1
 
 template<typename T, size_t N>
 template<typename MulFunc>
@@ -345,7 +345,7 @@ void FixedPointTensor<T, N>::mul_trunc(const FixedPointTensor<T, N>* lhs,
 
     temp->copy(ret_no_trunc->_share[0]);
     reshare(temp.get(), ret_no_trunc->_share[1]);
-    truncate3(ret_no_trunc.get(), ret, N);
+    truncate(ret_no_trunc.get(), ret, N);
 }
 
 template<typename T, size_t N>
@@ -360,7 +360,7 @@ void FixedPointTensor<T, N>::mul(const TensorAdapter<T>* rhs,
 
     _share[0]->mul(rhs, temp->_share[0]);
     _share[1]->mul(rhs, temp->_share[1]);
-    truncate3(temp.get(), ret, rhs->scaling_factor());
+    truncate(temp.get(), ret, rhs->scaling_factor());
 }
 
 template<typename T, size_t N>
@@ -404,7 +404,7 @@ void FixedPointTensor<T, N>::mat_mul(const TensorAdapter<T>* rhs,
                                      FixedPointTensor<T, N>* ret) const {
     _share[0]->mat_mul(rhs, ret->_share[0]);
     _share[1]->mat_mul(rhs, ret->_share[1]);
-    truncate3(ret, ret, rhs->scaling_factor());
+    truncate(ret, ret, rhs->scaling_factor());
 }
 
 template< typename T, size_t N>
@@ -831,7 +831,7 @@ void FixedPointTensor<T, N>::long_div(const FixedPointTensor<T, N>* rhs,
     }
 
     for (size_t i = 1; i <= N; ++i) {
-        truncate3(&abs_rhs, &sub_rhs, i);
+        truncate(&abs_rhs, &sub_rhs, i);
         abs_lhs.gt(&sub_rhs, &cmp_res);
         cmp_res.mul(&sub_rhs, &sub_rhs);
         cmp_res.lshift(N - i, &cmp_res);
@@ -1184,7 +1184,7 @@ void FixedPointTensor<T, N>::inverse_square_root(const FixedPointTensor<T, N>* op,
     std::shared_ptr<FixedPointTensor<T, N>> x2 =
         std::make_shared<FixedPointTensor<T, N>>(temp[2].get(), temp[3].get());
     // x2 = 0.5 * op
-    truncate3(op, x2.get(), 1);
+    truncate(op, x2.get(), 1);
 
     assign_to_tensor(y->mutable_share(0), (T)(x0 * pow(2, N)));
     assign_to_tensor(y->mutable_share(1), (T)(x0 * pow(2, N)));
diff --git a/core/privc3/fixedpoint_tensor_test.cc b/core/privc3/fixedpoint_tensor_test.cc
index 2cf6978294f5542f2a842ad18ea6821609df83e1..c2f83189fc46535932fad8048d720aa408013a95 100644
--- a/core/privc3/fixedpoint_tensor_test.cc
+++ b/core/privc3/fixedpoint_tensor_test.cc
@@ -1267,6 +1267,7 @@ TEST_F(FixedTensorTest, mulfixed) {
     EXPECT_TRUE(test_fixedt_check_tensor_eq(out0.get(), &result));
 }
 
+#ifndef USE_ABY3_TRUNC1 // skipped when aby3 trunc1 is used
 TEST_F(FixedTensorTest, mulfixed_multi_times) {
 
     std::vector<size_t> shape = {100000, 1};
@@ -1327,6 +1328,7 @@
     EXPECT_TRUE(test_fixedt_check_tensor_eq(out1.get(), out2.get()));
     EXPECT_TRUE(test_fixedt_check_tensor_eq(out0.get(), &result));
 }
+#endif
 
 TEST_F(FixedTensorTest, mulfixed_overflow) {
 
@@ -3435,4 +3437,124 @@ TEST_F(FixedTensorTest, inv_sqrt_test) {
 
 }
 
+#ifdef USE_ABY3_TRUNC1 //use aby3 trunc1
+TEST_F(FixedTensorTest, truncate1_msb_incorrect) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
+                                                        gen(shape), gen(shape), gen(shape)};
+    // lhs = 6 = 1 + 2 + 3, share before truncate
+    // zero share 0 = (1 << 62) + (1 << 62) - (1 << 63)
+    sl[0]->data()[0] = ((int64_t) 3 << 32) - ((uint64_t) 1 << 63);
+    sl[1]->data()[0] = ((int64_t) 2 << 32) + ((int64_t) 1 << 62);
+    sl[2]->data()[0] = ((int64_t) 1 << 32) + ((int64_t) 1 << 62);
+
+    auto pr = gen(shape);
+
+    // rhs = 6
+    pr->data()[0] = 6 << 16;
+    pr->scaling_factor() = 16;
+    Fix64N16 fl0(sl[0].get(), sl[1].get());
+    Fix64N16 fl1(sl[1].get(), sl[2].get());
+    Fix64N16 fl2(sl[2].get(), sl[0].get());
+    Fix64N16 fout0(sout[0].get(), sout[1].get());
+    Fix64N16 fout1(sout[2].get(), sout[3].get());
+    Fix64N16 fout2(sout[4].get(), sout[5].get());
+
+    auto p = gen(shape);
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+            Fix64N16::truncate(&fl0, &fout0, 16);
+            fout0.reveal_to_one(0, p.get());
+        });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+            Fix64N16::truncate(&fl1, &fout1, 16);
+            fout1.reveal_to_one(0, nullptr);
+        });
+        }
+    );
+    _t[2] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[2], [&](){
+            Fix64N16::truncate(&fl2, &fout2, 16);
+            fout2.reveal_to_one(0, nullptr);
+        });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    // trunc1 fails here: the revealed result is far from 6
+    EXPECT_GT(std::abs((p->data()[0] >> 16) - 6), 1000);
+}
+
+#else
+TEST_F(FixedTensorTest, truncate3_msb_correct) {
+    std::vector<size_t> shape = { 1 };
+    std::shared_ptr<TensorAdapter<int64_t>> sl[3] = { gen(shape), gen(shape), gen(shape) };
+    std::shared_ptr<TensorAdapter<int64_t>> sout[6] = { gen(shape), gen(shape), gen(shape),
+                                                        gen(shape), gen(shape), gen(shape)};
+    // lhs = 6 = 1 + 2 + 3, share before truncate
+    // zero share 0 = (1 << 62) + (1 << 62) - (1 << 63)
+    sl[0]->data()[0] = ((int64_t) 3 << 32) - ((uint64_t) 1 << 63);
+    sl[1]->data()[0] = ((int64_t) 2 << 32) + ((int64_t) 1 << 62);
+    sl[2]->data()[0] = ((int64_t) 1 << 32) + ((int64_t) 1 << 62);
+
+    auto pr = gen(shape);
+
+    // rhs = 6
+    pr->data()[0] = 6 << 16;
+    pr->scaling_factor() = 16;
+    Fix64N16 fl0(sl[0].get(), sl[1].get());
+    Fix64N16 fl1(sl[1].get(), sl[2].get());
+    Fix64N16 fl2(sl[2].get(), sl[0].get());
+    Fix64N16 fout0(sout[0].get(), sout[1].get());
+    Fix64N16 fout1(sout[2].get(), sout[3].get());
+    Fix64N16 fout2(sout[4].get(), sout[5].get());
+
+    auto p = gen(shape);
+
+    _t[0] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[0], [&](){
+            Fix64N16::truncate(&fl0, &fout0, 16);
+            fout0.reveal_to_one(0, p.get());
+        });
+        }
+    );
+    _t[1] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[1], [&](){
+            Fix64N16::truncate(&fl1, &fout1, 16);
+            fout1.reveal_to_one(0, nullptr);
+        });
+        }
+    );
+    _t[2] = std::thread(
+        [&] () {
+        g_ctx_holder::template run_with_context(
+            _exec_ctx.get(), _mpc_ctx[2], [&](){
+            Fix64N16::truncate(&fl2, &fout2, 16);
+            fout2.reveal_to_one(0, nullptr);
+        });
+        }
+    );
+    for (auto &t: _t) {
+        t.join();
+    }
+    EXPECT_EQ((p->data()[0] >> 16), 6);
+}
+#endif
+
 } // namespace aby3
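The choice between the two algorithms is made at configure time: passing `-DUSE_ABY3_TRUNC1=ON` to cmake keeps the original ABY3 trunc1 code path, while the default `OFF` builds `truncate` from the former `truncate3` body. To make the share arithmetic in the protocol comment above easier to follow, here is a minimal single-process sketch in plain C++. It is illustrative only, not PaddleFL code: names such as `r_prime` are mine, there is no networking or `reshare()`, and the r'_0/r'_1 sub-shares (needed only so that P0 and P1 can open x - r' between themselves) are folded into a direct computation of x - r'.

```cpp
#include <cstdint>
#include <cstdio>
#include <random>

// Local sketch of the `truncate3` share arithmetic for a single value.
// All ring arithmetic is done in Z_{2^64} via uint64_t wrap-around.
int main() {
    std::mt19937_64 prng(42);
    const unsigned N = 16;                         // scaling factor, as in Fix64N16
    const uint64_t x = uint64_t(6) << 32;          // plaintext fixed-point value

    // P2's dealer step. r' is kept in a reduced range, roughly (-2^62, 2^62),
    // mirroring the protocol comment and the rshift(1, ...) in the patch.
    int64_t  r_prime = static_cast<int64_t>(prng()) >> 1;
    uint64_t r_0 = prng();
    uint64_t r_1 = prng();
    uint64_t r_2 = uint64_t(r_prime >> N) - r_0 - r_1;   // r_0 + r_1 + r_2 = r'/2^N

    // P1 learns x - r', truncates it and re-randomizes with r_1;
    // P0 keeps x0 = r_0 and P2 keeps x2 = r_2.
    int64_t  masked = static_cast<int64_t>(x - uint64_t(r_prime));  // x - r' in Z_{2^64}
    uint64_t x1 = uint64_t(masked >> N) + r_1;
    uint64_t x0 = r_0;
    uint64_t x2 = r_2;

    // x0 + x1 + x2 = (x - r')/2^N + r'/2^N, i.e. (x >> N) or (x >> N) - 1.
    uint64_t reconstructed = x0 + x1 + x2;
    std::printf("x >> N = %llu, reconstructed = %llu\n",
                (unsigned long long)(x >> N),
                (unsigned long long)reconstructed);
    return 0;
}
```

The reconstruction can be one unit short of the exact `x >> N` because the carry of `(x - r') + r'` is dropped; the patched `truncate` accounts for that with the `tensor_carry_in` term visible in the hunk around line 307 of fixedpoint_tensor_imp.h.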