// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include namespace aby3 { template size_t BooleanTensor::pre_party() const { return aby3_ctx()->pre_party(); } template size_t BooleanTensor::next_party() const { return aby3_ctx()->next_party(); } template size_t BooleanTensor::party() const { return aby3_ctx()->party(); } template BooleanTensor::BooleanTensor(TensorAdapter *tensor[2]) { // TODO: check if tensor shape equal _share[0] = tensor[0]; _share[1] = tensor[1]; } template BooleanTensor::BooleanTensor(TensorAdapter *tensor0, TensorAdapter *tensor1) { // TODO: check if tensor shape equal _share[0] = tensor0; _share[1] = tensor1; } template BooleanTensor::BooleanTensor() {} template TensorAdapter *BooleanTensor::share(size_t idx) { // TODO: check if idx < 2 return _share[idx]; } template const TensorAdapter *BooleanTensor::share(size_t idx) const { // TODO: check if idx < 2 return _share[idx]; } template void BooleanTensor::reveal_to_one(size_t party_num, TensorAdapter *ret) const { if (party_num == party()) { // TODO: check if tensor shape equal // incase of this and ret shares tensor ptr auto buffer = tensor_factory()->template create(ret->shape()); aby3_ctx()->network()->template recv(pre_party(), *buffer); share(0)->bitwise_xor(buffer.get(), ret); share(1)->bitwise_xor(ret, ret); } else if (party_num == next_party()) { aby3_ctx()->network()->template send(party_num, *share(0)); } } template void BooleanTensor::reveal(TensorAdapter *ret) const { for (size_t idx = 0; idx < 3; ++idx) { reveal_to_one(idx, ret); } } template const std::vector BooleanTensor::shape() const { if (share(0)) { return share(0)->shape(); } else { return std::vector(); } } template size_t BooleanTensor::numel() const { if (share(0)) { return share(0)->numel(); } else { 0; } } template void BooleanTensor::bitwise_xor(const BooleanTensor *rhs, BooleanTensor *ret) const { share(0)->bitwise_xor(rhs->share(0), ret->share(0)); share(1)->bitwise_xor(rhs->share(1), ret->share(1)); } template void BooleanTensor::bitwise_xor(const TensorAdapter *rhs, BooleanTensor *ret) const { share(0)->bitwise_xor(rhs, ret->share(0)); share(1)->bitwise_xor(rhs, ret->share(1)); } template void BooleanTensor::bitwise_and(const BooleanTensor *rhs, BooleanTensor *ret) const { auto tmp_zero = tensor_factory()->template create(ret->shape()); auto tmp0 = tensor_factory()->template create(ret->shape()); auto tmp1 = tensor_factory()->template create(ret->shape()); auto tmp2 = tensor_factory()->template create(ret->shape()); aby3_ctx()->template gen_zero_sharing_boolean(*tmp_zero.get()); share(0)->bitwise_and(rhs->share(0), tmp0.get()); share(0)->bitwise_and(rhs->share(1), tmp1.get()); share(1)->bitwise_and(rhs->share(0), tmp2.get()); tmp0->bitwise_xor(tmp1.get(), tmp0.get()); tmp0->bitwise_xor(tmp2.get(), tmp0.get()); tmp0->bitwise_xor(tmp_zero.get(), ret->share(0)); // 3-party msg send recv sequence // p0 p1 p2 // t0: 0->2 2<-0 // t1: 1<-2 2->1 // t2: 0<-1 1->2 if (party() > 0) { aby3_ctx()->network()->template recv(next_party(), *(ret->share(1))); aby3_ctx()->network()->template send(pre_party(), *(ret->share(0))); } else { aby3_ctx()->network()->template send(pre_party(), *(ret->share(0))); aby3_ctx()->network()->template recv(next_party(), *(ret->share(1))); } } template void BooleanTensor::bitwise_and(const TensorAdapter *rhs, BooleanTensor *ret) const { share(0)->bitwise_and(rhs, ret->share(0)); share(1)->bitwise_and(rhs, ret->share(1)); } template void BooleanTensor::bitwise_or(const BooleanTensor *rhs, BooleanTensor *ret) const { // ret = x & y bitwise_and(rhs, ret); // ret = x & y ^ x bitwise_xor(ret, ret); // ret = x & y ^ x ^ y rhs->bitwise_xor(ret, ret); } template void BooleanTensor::bitwise_or(const TensorAdapter *rhs, BooleanTensor *ret) const { // ret = x & y bitwise_and(rhs, ret); // ret = x & y ^ x bitwise_xor(ret, ret); // ret = x & y ^ x ^ y ret->bitwise_xor(rhs, ret); } template void BooleanTensor::bitwise_not(BooleanTensor *ret) const { if (party() == 0) { share(0)->bitwise_not(ret->share(0)); share(1)->copy(ret->share(1)); } else if (party() == 1) { share(0)->copy(ret->share(0)); share(1)->copy(ret->share(1)); } else { share(0)->copy(ret->share(0)); share(1)->bitwise_not(ret->share(1)); } } template void BooleanTensor::lshift(size_t rhs, BooleanTensor *ret) const { share(0)->lshift(rhs, ret->share(0)); share(1)->lshift(rhs, ret->share(1)); } template void BooleanTensor::rshift(size_t rhs, BooleanTensor *ret) const { share(0)->rshift(rhs, ret->share(0)); share(1)->rshift(rhs, ret->share(1)); } template void BooleanTensor::logical_rshift(size_t rhs, BooleanTensor *ret) const { share(0)->logical_rshift(rhs, ret->share(0)); share(1)->logical_rshift(rhs, ret->share(1)); } template void BooleanTensor::ppa(const BooleanTensor *rhs, BooleanTensor *ret, size_t n_bits) const { // kogge stone adder from tfe // https://github.com/tf-encrypted // TODO: check T is int64_t other native type not support yet const size_t k = std::ceil(std::log2(n_bits)); std::vector keep_masks(k); for (size_t i = 0; i < k; ++i) { keep_masks[i] = (T(1) << (T)std::exp2(i)) - 1; } std::shared_ptr> tmp[11]; for (auto &ti : tmp) { ti = tensor_factory()->template create(ret->shape()); } BooleanTensor g(tmp[0].get(), tmp[1].get()); BooleanTensor p(tmp[2].get(), tmp[3].get()); BooleanTensor g1(tmp[4].get(), tmp[5].get()); BooleanTensor p1(tmp[6].get(), tmp[7].get()); BooleanTensor c(tmp[8].get(), tmp[9].get()); auto k_mask = tmp[10].get(); bitwise_and(rhs, &g); bitwise_xor(rhs, &p); for (size_t i = 0; i < k; ++i) { std::transform(k_mask->data(), k_mask->data() + k_mask->numel(), k_mask->data(), [&keep_masks, i](T) -> T { return keep_masks[i]; }); g.lshift(std::exp2(i), &g1); p.lshift(std::exp2(i), &p1); p1.bitwise_xor(k_mask, &p1); g1.bitwise_and(&p, &c); g.bitwise_xor(&c, &g); p.bitwise_and(&p1, &p); } g.lshift(1, &c); bitwise_xor(rhs, &p); c.bitwise_xor(&p, ret); } template void a2b(CircuitContext *aby3_ctx, TensorAdapterFactory *tensor_factory, const FixedPointTensor *a, BooleanTensor *b, size_t n_bits) { std::shared_ptr> tmp[4]; for (auto &ti : tmp) { ti = tensor_factory->template create(a->shape()); // set 0 std::transform(ti->data(), ti->data() + ti->numel(), ti->data(), [](T) -> T { return 0; }); } std::shared_ptr> lhs = std::make_shared>(tmp[0].get(), tmp[1].get()); std::shared_ptr> rhs = std::make_shared>(tmp[2].get(), tmp[3].get()); if (aby3_ctx->party() == 0) { a->share(0)->add(a->share(1), lhs->share(0)); // reshare x0 + x1 aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(1)); lhs->share(0)->bitwise_xor(lhs->share(1), lhs->share(0)); aby3_ctx->network()->template send(2, *(lhs->share(0))); aby3_ctx->network()->template recv(1, *(lhs->share(1))); } else if (aby3_ctx->party() == 1) { aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(0)); aby3_ctx->network()->template send(0, *(lhs->share(0))); aby3_ctx->network()->template recv(2, *(lhs->share(1))); a->share(1)->copy(rhs->share(1)); } else { // party == 2 aby3_ctx->template gen_zero_sharing_boolean(*lhs->share(0)); aby3_ctx->network()->template recv(0, *(lhs->share(1))); aby3_ctx->network()->template send(1, *(lhs->share(0))); a->share(0)->copy(rhs->share(0)); } lhs->ppa(rhs.get(), b, n_bits); } template template BooleanTensor &BooleanTensor:: operator=(const FixedPointTensor *other) { a2b(aby3_ctx().get(), tensor_factory().get(), other, this, sizeof(T) * 8); return *this; } template void tensor_rshift_transform(const TensorAdapter *lhs, size_t rhs, TensorAdapter *ret) { const T *begin = lhs->data(); std::transform(begin, begin + lhs->numel(), ret->data(), [rhs](T in) { return (in >> rhs) & 1; }); }; template template void BooleanTensor::bit_extract(size_t i, const FixedPointTensor *in) { a2b(aby3_ctx().get(), tensor_factory().get(), in, this, i + 1); tensor_rshift_transform(share(0), i, share(0)); tensor_rshift_transform(share(1), i, share(1)); } template void BooleanTensor::bit_extract(size_t i, BooleanTensor *ret) const { tensor_rshift_transform(share(0), i, ret->share(0)); tensor_rshift_transform(share(1), i, ret->share(1)); } template template void BooleanTensor::b2a(FixedPointTensor *ret) const { std::shared_ptr> tmp[2]; for (auto &ti : tmp) { ti = tensor_factory()->template create(shape()); // set 0 std::transform(ti->data(), ti->data() + ti->numel(), ti->data(), [](T) -> T { return 0; }); } BooleanTensor bt(tmp[0].get(), tmp[1].get()); if (party() == 1) { aby3_ctx()->template gen_random(*ret->mutable_share(0), 0); aby3_ctx()->template gen_random(*ret->mutable_share(1), 1); ret->share(0)->add(ret->share(1), tmp[0].get()); tmp[0]->negative(tmp[0].get()); aby3_ctx()->network()->template send(0, *(tmp[0].get())); } else if (party() == 0) { aby3_ctx()->network()->template recv(1, *(tmp[1].get())); // dummy gen random, for prng sync aby3_ctx()->template gen_random(*ret->mutable_share(1), 1); } else { // party == 2 aby3_ctx()->template gen_random(*ret->mutable_share(0), 0); } bt.ppa(this, &bt, sizeof(T) * 8); TensorAdapter *dest = nullptr; if (party() == 0) { dest = ret->mutable_share(0); } bt.reveal_to_one(0, dest); if (party() == 0) { aby3_ctx()->network()->template recv(1, *(ret->mutable_share(1))); aby3_ctx()->network()->template send(2, *(ret->mutable_share(0))); } else if (party() == 1) { aby3_ctx()->network()->template send(0, *(ret->mutable_share(0))); } else { // party == 2 aby3_ctx()->network()->template recv(0, *(ret->mutable_share(1))); } } template template void BooleanTensor::mul(const TensorAdapter *rhs, FixedPointTensor *ret, size_t rhs_party) const { // ot sender size_t idx0 = rhs_party; size_t idx1 = (rhs_party + 1) % 3; size_t idx2 = (rhs_party + 2) % 3; auto tmp0 = tensor_factory()->template create(ret->shape()); auto tmp1 = tensor_factory()->template create(ret->shape()); TensorAdapter *tmp[2] = {tmp0.get(), tmp1.get()}; TensorAdapter *null_arg[2] = {nullptr, nullptr}; if (party() == idx0) { // use ret as buffer TensorAdapter *m[2] = {ret->mutable_share(0), ret->mutable_share(1)}; aby3_ctx()->template gen_zero_sharing_arithmetic(*tmp[0]); // m0 = a * (b0 ^ b1) + s0 // m1 = a * (1 ^ b0 ^ b1) + s0 share(0)->bitwise_xor(share(1), m[0]); std::transform(m[0]->data(), m[0]->data() + m[0]->numel(), m[1]->data(), [](T in) { return 1 ^ in; }); m[0]->mul(rhs, m[0]); m[1]->mul(rhs, m[1]); m[0]->add(tmp[0], m[0]); m[1]->add(tmp[0], m[1]); aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0], const_cast **>(m), tmp, null_arg[0]); // ret0 = s2 // ret1 = s1 aby3_ctx()->network()->template recv(idx2, *(ret->mutable_share(0))); aby3_ctx()->network()->template recv(idx1, *(ret->mutable_share(1))); } else if (party() == idx1) { // ret0 = s1 aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0))); // ret1 = a * b + s0 aby3_ctx()->template ot( idx0, idx1, idx2, share(1), const_cast **>(null_arg), tmp, ret->mutable_share(1)); aby3_ctx()->network()->template send(idx0, *(ret->share(0))); aby3_ctx()->network()->template send(idx2, *(ret->share(1))); } else if (party() == idx2) { // ret0 = a * b + s0 aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1))); // ret1 = s2 aby3_ctx()->template ot( idx0, idx1, idx2, share(0), const_cast **>(null_arg), tmp, null_arg[0]); aby3_ctx()->network()->template send(idx0, *(ret->share(1))); aby3_ctx()->network()->template recv(idx1, *(ret->mutable_share(0))); } } template template void BooleanTensor::mul(const FixedPointTensor *rhs, FixedPointTensor *ret) const { auto tmp0 = tensor_factory()->template create(ret->shape()); auto tmp1 = tensor_factory()->template create(ret->shape()); auto tmp2 = tensor_factory()->template create(ret->shape()); FixedPointTensor tmp(tmp0.get(), tmp1.get()); if (party() == 0) { mul(nullptr, ret, 1); mul(rhs->share(0), &tmp, 0); ret->add(&tmp, ret); } else if (party() == 1) { rhs->share(0)->add(rhs->share(1), tmp2.get()); mul(tmp2.get(), ret, 1); mul(nullptr, &tmp, 0); ret->add(&tmp, ret); } else { // party() == 2 mul(nullptr, ret, 1); mul(nullptr, &tmp, 0); ret->add(&tmp, ret); } } } // namespace aby3