From d829309d1b08a5a087d42110a4ba7b3a813b980b Mon Sep 17 00:00:00 2001 From: yangqingyou Date: Mon, 31 Aug 2020 10:15:20 +0000 Subject: [PATCH] amend according comment --- .../mpc_protocol/abstract_context.h | 85 ++++--------------- core/privc/privc_context.h | 26 +++--- core/privc3/aby3_context.h | 26 +++--- core/privc3/boolean_tensor_impl.h | 19 +++-- core/privc3/ot.h | 67 +++++++++++++++ 5 files changed, 121 insertions(+), 102 deletions(-) create mode 100644 core/privc3/ot.h diff --git a/core/paddlefl_mpc/mpc_protocol/abstract_context.h b/core/paddlefl_mpc/mpc_protocol/abstract_context.h index c945aa5..958b224 100644 --- a/core/paddlefl_mpc/mpc_protocol/abstract_context.h +++ b/core/paddlefl_mpc/mpc_protocol/abstract_context.h @@ -29,24 +29,23 @@ using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator; class AbstractContext { public: - AbstractContext() = default; + AbstractContext(size_t party, std::shared_ptr network) { + init(party, network); + }; AbstractContext(const AbstractContext &other) = delete; AbstractContext &operator=(const AbstractContext &other) = delete; - virtual void init(size_t party, std::shared_ptr network, block seed, - block seed2) = 0; + void init(size_t party, std::shared_ptr network) { + set_party(party); + set_network(network); + } void set_party(size_t party) { - PADDLE_ENFORCE_LT(party, _num_party, - "party idx should less than %d.", - _num_party); _party = party; } void set_num_party(size_t num_party) { - PADDLE_ENFORCE_EQ(num_party == 2 || num_party == 3, true, - "2 or 3 party protocol is supported."); _num_party = num_party; } @@ -60,35 +59,33 @@ public: PADDLE_ENFORCE_LE(idx, _num_party, "prng idx should be less and equal to %d.", _num_party); - _prng[idx].set_seed(seed); + get_prng(idx).set_seed(seed); } size_t party() const { return _party; } size_t pre_party() const { - PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true, - "number of party is not set."); return (_party + _num_party - 1) % _num_party; } size_t next_party() const { - PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true, - "number of party is not set."); return (_party + 1) % _num_party; } - template T gen_random(bool next) { return _prng[next].get(); } + // generate random from prng[0] or prng[1] + // @param next: use bool type for idx 0 or 1 + template T gen_random(bool next) { + return get_prng(next).get(); + } template class Tensor> void gen_random(Tensor &tensor, bool next) { - PADDLE_ENFORCE_EQ(_num_party, 3, - "`gen_random` API is for 3 party protocol."); std::for_each( tensor.data(), tensor.data() + tensor.numel(), [this, next](T &val) { val = this->template gen_random(next); }); } - template T gen_random_private() { return _prng[2].get(); } + template T gen_random_private() { return get_prng(2).get(); } template class Tensor> void gen_random_private(Tensor &tensor) { @@ -98,15 +95,11 @@ public: } template T gen_zero_sharing_arithmetic() { - PADDLE_ENFORCE_EQ(_num_party, 3, - "`gen_zero_sharing_arithmetic` API is for 3 party protocol."); - return _prng[0].get() - _prng[1].get(); + return get_prng(0).get() - get_prng(1).get(); } template class Tensor> void gen_zero_sharing_arithmetic(Tensor &tensor) { - PADDLE_ENFORCE_EQ(_num_party, 3, - "`gen_zero_sharing_arithmetic` API is for 3 party protocol."); std::for_each(tensor.data(), tensor.data() + tensor.numel(), [this](T &val) { val = this->template gen_zero_sharing_arithmetic(); @@ -114,60 +107,18 @@ public: } template T gen_zero_sharing_boolean() { - PADDLE_ENFORCE_EQ(_num_party, 3, - "`gen_zero_sharing_boolean` API is for 3 party protocol."); - return _prng[0].get() ^ _prng[1].get(); + return get_prng(0).get() ^ get_prng(1).get(); } template class Tensor> void gen_zero_sharing_boolean(Tensor &tensor) { - PADDLE_ENFORCE_EQ(_num_party, 3, - "`gen_zero_sharing_boolean` API is for 3 party protocol."); std::for_each( tensor.data(), tensor.data() + tensor.numel(), [this](T &val) { val = this->template gen_zero_sharing_boolean(); }); } - template class Tensor> - void ot(size_t sender, size_t receiver, size_t helper, - const Tensor* choice, const Tensor* m[2], - Tensor* buffer[2], Tensor* ret) { - // TODO: check tensor shape equals - const size_t numel = buffer[0]->numel(); - if (party() == sender) { - bool common = helper == next_party(); - this->template gen_random(*buffer[0], common); - this->template gen_random(*buffer[1], common); - for (size_t i = 0; i < numel; ++i) { - buffer[0]->data()[i] ^= m[0]->data()[i]; - buffer[1]->data()[i] ^= m[1]->data()[i]; - } - network()->template send(receiver, *buffer[0]); - network()->template send(receiver, *buffer[1]); - - } else if (party() == helper) { - bool common = sender == next_party(); - - this->template gen_random(*buffer[0], common); - this->template gen_random(*buffer[1], common); - - for (size_t i = 0; i < numel; ++i) { - buffer[0]->data()[i] = choice->data()[i] & 1 ? - buffer[1]->data()[i] : buffer[0]->data()[i]; - } - network()->template send(receiver, *buffer[0]); - } else if (party() == receiver) { - network()->template recv(sender, *buffer[0]); - network()->template recv(sender, *buffer[1]); - network()->template recv(helper, *ret); - size_t i = 0; - std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { - bool c = choice->data()[i] & 1; - in ^= buffer[c]->data()[i]; - ++i;} - ); - } - } +protected: + virtual PseudorandomNumberGenerator& get_prng(size_t idx) = 0; private: size_t _num_party; diff --git a/core/privc/privc_context.h b/core/privc/privc_context.h index 1108eb8..35b67a7 100644 --- a/core/privc/privc_context.h +++ b/core/privc/privc_context.h @@ -29,26 +29,26 @@ using AbstractContext = paddle::mpc::AbstractContext; class PrivCContext : public AbstractContext { public: PrivCContext(size_t party, std::shared_ptr network, - const block &seed = g_zero_block) { - init(party, network, g_zero_block, seed); + block seed = g_zero_block): + AbstractContext::AbstractContext(party, network) { + set_num_party(2); + + if (psi::equals(seed, psi::g_zero_block)) { + seed = psi::block_from_dev_urandom(); + } + set_random_seed(seed, 0); } PrivCContext(const PrivCContext &other) = delete; PrivCContext &operator=(const PrivCContext &other) = delete; - void init(size_t party, std::shared_ptr network, block seed, - block seed2) override { - set_num_party(2); - set_party(party); - set_network(network); - - if (psi::equals(seed2, psi::g_zero_block)) { - seed2 = psi::block_from_dev_urandom(); - } - // seed2 is private - set_random_seed(seed2, 2); +protected: + PseudorandomNumberGenerator& get_prng(size_t idx) override { + return _prng; } +private: + PseudorandomNumberGenerator _prng; }; } // namespace aby3 diff --git a/core/privc3/aby3_context.h b/core/privc3/aby3_context.h index 08b53ac..8094d9d 100644 --- a/core/privc3/aby3_context.h +++ b/core/privc3/aby3_context.h @@ -29,20 +29,10 @@ using AbstractContext = paddle::mpc::AbstractContext; class ABY3Context : public AbstractContext { public: ABY3Context(size_t party, std::shared_ptr network, - const block &seed = g_zero_block, - const block &seed2 = g_zero_block) { - init(party, network, seed, seed2); - } - - ABY3Context(const ABY3Context &other) = delete; - - ABY3Context &operator=(const ABY3Context &other) = delete; - - void init(size_t party, std::shared_ptr network, block seed, - block seed2) override { + block seed = g_zero_block, + block seed2 = g_zero_block) : + AbstractContext::AbstractContext(party, network) { set_num_party(3); - set_party(party); - set_network(network); if (psi::equals(seed, psi::g_zero_block)) { seed = psi::block_from_dev_urandom(); @@ -70,6 +60,16 @@ public: set_random_seed(seed, 1); } + + ABY3Context(const ABY3Context &other) = delete; + + ABY3Context &operator=(const ABY3Context &other) = delete; +protected: + PseudorandomNumberGenerator& get_prng(size_t idx) override { + return _prng[idx]; + } +private: + PseudorandomNumberGenerator _prng[3]; }; } // namespace aby3 diff --git a/core/privc3/boolean_tensor_impl.h b/core/privc3/boolean_tensor_impl.h index 8bfd4da..74bf270 100644 --- a/core/privc3/boolean_tensor_impl.h +++ b/core/privc3/boolean_tensor_impl.h @@ -15,6 +15,7 @@ #pragma once #include +#include "core/privc3/ot.h" namespace aby3 { @@ -432,9 +433,9 @@ void BooleanTensor::mul(const TensorAdapter* rhs, m[0]->add(tmp[0], m[0]); m[1]->add(tmp[0], m[1]); - aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0], - const_cast**>(m), - tmp, null_arg[0]); + ObliviousTransfer::ot(idx0, idx1, idx2, null_arg[0], + const_cast**>(m), + tmp, null_arg[0]); // ret0 = s2 // ret1 = s1 @@ -445,18 +446,18 @@ void BooleanTensor::mul(const TensorAdapter* rhs, // ret0 = s1 aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0))); // ret1 = a * b + s0 - aby3_ctx()->template ot(idx0, idx1, idx2, share(1), - const_cast**>(null_arg), - tmp, ret->mutable_share(1)); + ObliviousTransfer::ot(idx0, idx1, idx2, share(1), + const_cast**>(null_arg), + tmp, ret->mutable_share(1)); aby3_ctx()->network()->template send(idx0, *(ret->share(0))); aby3_ctx()->network()->template send(idx2, *(ret->share(1))); } else if (party() == idx2) { // ret0 = a * b + s0 aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1))); // ret1 = s2 - aby3_ctx()->template ot(idx0, idx1, idx2, share(0), - const_cast**>(null_arg), - tmp, null_arg[0]); + ObliviousTransfer::ot(idx0, idx1, idx2, share(0), + const_cast**>(null_arg), + tmp, null_arg[0]); aby3_ctx()->network()->template send(idx0, *(ret->share(1))); diff --git a/core/privc3/ot.h b/core/privc3/ot.h new file mode 100644 index 0000000..1a5c17f --- /dev/null +++ b/core/privc3/ot.h @@ -0,0 +1,67 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" + +namespace aby3 { + +class ObliviousTransfer { + public: + template class Tensor> + static inline void ot(size_t sender, size_t receiver, size_t helper, + const Tensor* choice, const Tensor* m[2], + Tensor* buffer[2], Tensor* ret) { + // TODO: check tensor shape equals + auto aby3_ctx = paddle::mpc::ContextHolder::mpc_ctx(); + const size_t numel = buffer[0]->numel(); + if (aby3_ctx->party() == sender) { + bool common = helper == aby3_ctx->next_party(); + aby3_ctx->template gen_random(*buffer[0], common); + aby3_ctx->template gen_random(*buffer[1], common); + for (size_t i = 0; i < numel; ++i) { + buffer[0]->data()[i] ^= m[0]->data()[i]; + buffer[1]->data()[i] ^= m[1]->data()[i]; + } + aby3_ctx->network()->template send(receiver, *buffer[0]); + aby3_ctx->network()->template send(receiver, *buffer[1]); + + } else if (aby3_ctx->party() == helper) { + bool common = sender == aby3_ctx->next_party(); + + aby3_ctx->template gen_random(*buffer[0], common); + aby3_ctx->template gen_random(*buffer[1], common); + + for (size_t i = 0; i < numel; ++i) { + buffer[0]->data()[i] = choice->data()[i] & 1 ? + buffer[1]->data()[i] : buffer[0]->data()[i]; + } + aby3_ctx->network()->template send(receiver, *buffer[0]); + } else if (aby3_ctx->party() == receiver) { + aby3_ctx->network()->template recv(sender, *buffer[0]); + aby3_ctx->network()->template recv(sender, *buffer[1]); + aby3_ctx->network()->template recv(helper, *ret); + size_t i = 0; + std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { + bool c = choice->data()[i] & 1; + in ^= buffer[c]->data()[i]; + ++i;} + ); + } + } +}; + +} // namespace aby3 \ No newline at end of file -- GitLab