diff --git a/core/paddlefl_mpc/mpc_protocol/abstract_context.h b/core/paddlefl_mpc/mpc_protocol/abstract_context.h new file mode 100644 index 0000000000000000000000000000000000000000..1b6deab31154bf647eef605e6156d4d2d3c0970b --- /dev/null +++ b/core/paddlefl_mpc/mpc_protocol/abstract_context.h @@ -0,0 +1,127 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include + +#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" +#include "core/privc3/prng_utils.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { + +namespace mpc { + +using block = psi::block; +using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator; + +class AbstractContext { +public: + AbstractContext(size_t party, std::shared_ptr network) { + set_party(party); + set_network(network); + }; + AbstractContext(const AbstractContext &other) = delete; + + AbstractContext &operator=(const AbstractContext &other) = delete; + + void set_party(size_t party) { + _party = party; + } + + void set_num_party(size_t num_party) { + _num_party = num_party; + } + + void set_network(std::shared_ptr network) { + _network = network; + } + + AbstractNetwork *network() { return _network.get(); } + + void set_random_seed(const block &seed, size_t idx) { + PADDLE_ENFORCE_LE(idx, _num_party, + "prng idx should be less and equal to %d.", + _num_party); + get_prng(idx).set_seed(seed); + } + + size_t party() const { return _party; } + + size_t pre_party() const { + return (_party + _num_party - 1) % _num_party; + } + + size_t next_party() const { + return (_party + 1) % _num_party; + } + + // generate random from prng[0] or prng[1] + // @param next: use bool type for idx 0 or 1 + template T gen_random(bool next) { + return get_prng(next).get(); + } + + template class Tensor> + void gen_random(Tensor &tensor, bool next) { + std::for_each( + tensor.data(), tensor.data() + tensor.numel(), + [this, next](T &val) { val = this->template gen_random(next); }); + } + + template T gen_random_private() { return get_prng(2).get(); } + + template class Tensor> + void gen_random_private(Tensor &tensor) { + std::for_each( + tensor.data(), tensor.data() + tensor.numel(), + [this](T &val) { val = this->template gen_random_private(); }); + } + + template T gen_zero_sharing_arithmetic() { + return get_prng(0).get() - get_prng(1).get(); + } + + template class Tensor> + void gen_zero_sharing_arithmetic(Tensor &tensor) { + std::for_each(tensor.data(), tensor.data() + tensor.numel(), + [this](T &val) { + val = this->template gen_zero_sharing_arithmetic(); + }); + } + + template T gen_zero_sharing_boolean() { + return get_prng(0).get() ^ get_prng(1).get(); + } + + template class Tensor> + void gen_zero_sharing_boolean(Tensor &tensor) { + std::for_each( + tensor.data(), tensor.data() + tensor.numel(), + [this](T &val) { val = this->template gen_zero_sharing_boolean(); }); + } + +protected: + virtual PseudorandomNumberGenerator& get_prng(size_t 
idx) = 0; + +private: + size_t _num_party; + size_t _party; + std::shared_ptr _network; +}; + +} // namespace mpc + +} //namespace paddle diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h index 369b02c03a0829900d9aea722e93a7bdf12a744b..337d1fb581c72839503c195c1b131e0a2cf7d189 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h +++ b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h @@ -21,7 +21,8 @@ limitations under the License. */ #include "context_holder.h" #include "mpc_operators.h" #include "paddle/fluid/framework/tensor.h" -#include "core/privc3/circuit_context.h" +#include "core/privc3/boolean_tensor.h" +#include "core/privc3/aby3_context.h" #include "core/privc3/fixedpoint_tensor.h" #include "core/privc3/boolean_tensor.h" #include "core/privc3/paddle_tensor.h" @@ -30,7 +31,7 @@ namespace paddle { namespace mpc { using paddle::framework::Tensor; -using aby3::CircuitContext; +using aby3::ABY3Context; // TODO: decide scaling factor const size_t ABY3_SCALING_FACTOR = FIXED_POINTER_SCALING_FACTOR; using FixedTensor = aby3::FixedPointTensor; diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc index a4cc939ab00b6d06558f72d8d4f6778ffd221cee..8ec5a1d239a7476beaaf515da1dbd1028a230615 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc +++ b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc @@ -48,7 +48,7 @@ void Aby3Protocol::init_with_store( mesh_net->init(); _network = std::move(mesh_net); - _circuit_ctx = std::make_shared(role, _network); + _circuit_ctx = std::make_shared(role, _network); _operators = std::make_shared(); _is_initialized = true; } @@ -63,7 +63,7 @@ std::shared_ptr Aby3Protocol::network() { return _network; } -std::shared_ptr Aby3Protocol::mpc_context() { +std::shared_ptr Aby3Protocol::mpc_context() { PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR); return _circuit_ctx; } diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h index 6940c558d4bcdf842cae8775ad8fa1f93670b3f2..5d2d7d73cbcbada189c0449ec5bd06b8c735b6e5 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h +++ b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h @@ -24,12 +24,13 @@ #include "mesh_network.h" #include "mpc_operators.h" #include "mpc_protocol.h" -#include "core/privc3/circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/privc3/aby3_context.h" namespace paddle { namespace mpc { -using CircuitContext = aby3::CircuitContext; +using ABY3Context = aby3::ABY3Context; class Aby3Protocol : public MpcProtocol { public: @@ -46,14 +47,14 @@ public: std::shared_ptr network() override; - std::shared_ptr mpc_context() override; + std::shared_ptr mpc_context() override; private: bool _is_initialized = false; const std::string PROT_INIT_ERR = "The protocol is not yet initialized."; std::shared_ptr _operators; std::shared_ptr _network; - std::shared_ptr _circuit_ctx; + std::shared_ptr _circuit_ctx; }; } // mpc diff --git a/core/paddlefl_mpc/mpc_protocol/context_holder.cc b/core/paddlefl_mpc/mpc_protocol/context_holder.cc index 9acf0fb72667c7da8b53b842b3189f26c337f923..3cbd8b8ce4c0979dbd7551581f930c897b0a0d74 100644 --- a/core/paddlefl_mpc/mpc_protocol/context_holder.cc +++ b/core/paddlefl_mpc/mpc_protocol/context_holder.cc @@ -24,7 +24,7 @@ namespace paddle { namespace mpc { -thread_local std::shared_ptr ContextHolder::current_mpc_ctx; +thread_local std::shared_ptr 
ContextHolder::current_mpc_ctx; thread_local const ExecutionContext *ContextHolder::current_exec_ctx; diff --git a/core/paddlefl_mpc/mpc_protocol/context_holder.h b/core/paddlefl_mpc/mpc_protocol/context_holder.h index a8c2d5f15764cc49463790761310d181abd04edf..710de07f189c2c63e926051fe0a38cb0a65212a4 100644 --- a/core/paddlefl_mpc/mpc_protocol/context_holder.h +++ b/core/paddlefl_mpc/mpc_protocol/context_holder.h @@ -22,20 +22,20 @@ #pragma once #include "paddle/fluid/framework/operator.h" -#include "core/privc3/circuit_context.h" +#include "core/privc3/aby3_context.h" #include "core/privc3/paddle_tensor.h" namespace paddle { namespace mpc { -using CircuitContext = aby3::CircuitContext; +using ABY3Context = aby3::ABY3Context; using ExecutionContext = paddle::framework::ExecutionContext; class ContextHolder { public: template static void run_with_context(const ExecutionContext *exec_ctx, - std::shared_ptr mpc_ctx, + std::shared_ptr mpc_ctx, Operation op) { // set new ctxs @@ -60,7 +60,7 @@ public: _s_current_tensor_factory = old_factory; } - static std::shared_ptr mpc_ctx() { return current_mpc_ctx; } + static std::shared_ptr mpc_ctx() { return current_mpc_ctx; } static const ExecutionContext *exec_ctx() { return current_exec_ctx; } @@ -77,7 +77,7 @@ public: } private: - thread_local static std::shared_ptr current_mpc_ctx; + thread_local static std::shared_ptr current_mpc_ctx; thread_local static const ExecutionContext *current_exec_ctx; diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc b/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc index dfb8f4a452dfcbec6ecfc707847cacb4db721558..68daf78b328c1c4624eed6b3aee9272f3bd8126a 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc +++ b/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc @@ -19,7 +19,6 @@ #include "aby3_protocol.h" #include "mpc_protocol_factory.h" -#include "core/privc3/circuit_context.h" #include "gtest/gtest.h" namespace paddle { diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h b/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h index da1aa49de70bbd92dbbdfa60785602ce4397b47f..a5eeae6c26c5dcff53bf2de1bb0aedd505e6b9c0 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h +++ b/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h @@ -21,7 +21,7 @@ #include "gloo/rendezvous/hash_store.h" #include "mpc_config.h" #include "mpc_operators.h" -#include "core/privc3/circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" namespace paddle { namespace mpc { @@ -44,7 +44,7 @@ public: virtual std::shared_ptr network() = 0; - virtual std::shared_ptr mpc_context() = 0; + virtual std::shared_ptr mpc_context() = 0; private: const std::string _name; diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc b/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc index 332536bceb08096b45a63e9f600d81732bfa7b67..199e046ebbc3d48d078dff60001ecc2cc2f8bc83 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc +++ b/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc @@ -17,7 +17,6 @@ #include "aby3_protocol.h" #include "mpc_config.h" #include "mpc_protocol_factory.h" -#include "core/privc3/circuit_context.h" #include "gtest/gtest.h" namespace paddle { diff --git a/core/paddlefl_mpc/operators/mpc_op.h b/core/paddlefl_mpc/operators/mpc_op.h index 6cff543b0dbf03f13a9e8baeb83335986f5fe249..f3a17941463c0beb45b847dbbd3faa680bc3ca53 100644 --- a/core/paddlefl_mpc/operators/mpc_op.h +++ b/core/paddlefl_mpc/operators/mpc_op.h @@ -19,7 +19,7 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/operator.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" -#include "core/privc3/circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" namespace paddle { namespace operators { @@ -32,7 +32,7 @@ public: PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), "Mpc protocol is not yet initialized in executor"); - std::shared_ptr mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); + std::shared_ptr mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, [&] { ComputeImpl(ctx); }); } diff --git a/core/privc/privc_context.h b/core/privc/privc_context.h new file mode 100644 index 0000000000000000000000000000000000000000..35b67a7e9c948099883e4343c3d5dfe26f8eaff2 --- /dev/null +++ b/core/privc/privc_context.h @@ -0,0 +1,54 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" +#include "prng_utils.h" + +namespace aby3 { + +using AbstractNetwork = paddle::mpc::AbstractNetwork; +using AbstractContext = paddle::mpc::AbstractContext; + +class PrivCContext : public AbstractContext { +public: + PrivCContext(size_t party, std::shared_ptr network, + block seed = g_zero_block): + AbstractContext::AbstractContext(party, network) { + set_num_party(2); + + if (psi::equals(seed, psi::g_zero_block)) { + seed = psi::block_from_dev_urandom(); + } + set_random_seed(seed, 0); + } + + PrivCContext(const PrivCContext &other) = delete; + + PrivCContext &operator=(const PrivCContext &other) = delete; + +protected: + PseudorandomNumberGenerator& get_prng(size_t idx) override { + return _prng; + } +private: + PseudorandomNumberGenerator _prng; +}; + +} // namespace aby3 diff --git a/core/privc3/aby3_context.h b/core/privc3/aby3_context.h new file mode 100644 index 0000000000000000000000000000000000000000..8094d9d67640fb8098f6ce97c3e5050e287d6e63 --- /dev/null +++ b/core/privc3/aby3_context.h @@ -0,0 +1,75 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+#pragma once + +#include +#include +#include + +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" +#include "prng_utils.h" + +namespace aby3 { + +using AbstractNetwork = paddle::mpc::AbstractNetwork; +using AbstractContext = paddle::mpc::AbstractContext; + +class ABY3Context : public AbstractContext { +public: + ABY3Context(size_t party, std::shared_ptr network, + block seed = g_zero_block, + block seed2 = g_zero_block) : + AbstractContext::AbstractContext(party, network) { + set_num_party(3); + + if (psi::equals(seed, psi::g_zero_block)) { + seed = psi::block_from_dev_urandom(); + } + + if (psi::equals(seed2, psi::g_zero_block)) { + seed2 = psi::block_from_dev_urandom(); + } + set_random_seed(seed, 0); + // seed2 is private + set_random_seed(seed2, 2); + + // 3 for 3-party computation + size_t party_pre = pre_party(); + size_t party_next = next_party(); + + if (party == 1) { + block recv_seed = this->network()->template recv(party_next); + this->network()->template send(party_pre, seed); + seed = recv_seed; + } else { + this->network()->template send(party_pre, seed); + seed = this->network()->template recv(party_next); + } + + set_random_seed(seed, 1); + } + + ABY3Context(const ABY3Context &other) = delete; + + ABY3Context &operator=(const ABY3Context &other) = delete; +protected: + PseudorandomNumberGenerator& get_prng(size_t idx) override { + return _prng[idx]; + } +private: + PseudorandomNumberGenerator _prng[3]; +}; + +} // namespace aby3 diff --git a/core/privc3/boolean_tensor.h b/core/privc3/boolean_tensor.h index 36418a34dd1d56968aff9fcc7c4e55e03f9fb301..d8c66b22822be4812ff1bd240827f725accfc9c6 100644 --- a/core/privc3/boolean_tensor.h +++ b/core/privc3/boolean_tensor.h @@ -122,9 +122,9 @@ public: void onehot_from_cmp(); private: - static inline std::shared_ptr aby3_ctx() { - return paddle::mpc::ContextHolder::mpc_ctx(); - } + static inline std::shared_ptr aby3_ctx() { + return paddle::mpc::ContextHolder::mpc_ctx(); + } static inline std::shared_ptr tensor_factory() { return paddle::mpc::ContextHolder::tensor_factory(); diff --git a/core/privc3/boolean_tensor_impl.h b/core/privc3/boolean_tensor_impl.h index 012a158f4d8677787f7695035f06cab451e2dea3..74bf27090f17965d973630c9777d410e5087ec33 100644 --- a/core/privc3/boolean_tensor_impl.h +++ b/core/privc3/boolean_tensor_impl.h @@ -15,6 +15,7 @@ #pragma once #include +#include "core/privc3/ot.h" namespace aby3 { @@ -268,7 +269,7 @@ void BooleanTensor::ppa(const BooleanTensor* rhs, } template -void a2b(CircuitContext* aby3_ctx, +void a2b(AbstractContext* aby3_ctx, TensorAdapterFactory* tensor_factory, const FixedPointTensor* a, BooleanTensor* b, @@ -432,9 +433,9 @@ void BooleanTensor::mul(const TensorAdapter* rhs, m[0]->add(tmp[0], m[0]); m[1]->add(tmp[0], m[1]); - aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0], - const_cast**>(m), - tmp, null_arg[0]); + ObliviousTransfer::ot(idx0, idx1, idx2, null_arg[0], + const_cast**>(m), + tmp, null_arg[0]); // ret0 = s2 // ret1 = s1 @@ -445,18 +446,18 @@ void BooleanTensor::mul(const TensorAdapter* rhs, // ret0 = s1 aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0))); // ret1 = a * b + s0 - aby3_ctx()->template ot(idx0, idx1, idx2, share(1), - const_cast**>(null_arg), - tmp, ret->mutable_share(1)); + ObliviousTransfer::ot(idx0, idx1, idx2, share(1), + const_cast**>(null_arg), + tmp, ret->mutable_share(1)); aby3_ctx()->network()->template send(idx0, *(ret->share(0))); 
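The gen_zero_sharing_arithmetic() calls in this hunk rely on the seed layout established in ABY3Context's constructor above: each party seeds prng[0] with its own seed, forwards that seed to the previous party, and seeds prng[1] with the seed received from the next party, so every adjacent pair of parties shares one generator. A standalone check (toy PRNG in place of the repo's generator) that the resulting per-party values telescope to zero modulo 2^64:

    #include <cstdint>
    #include <iostream>
    #include <random>

    // Toy PRNG standing in for psi::PseudorandomNumberGenerator.
    struct ToyPrng {
        std::mt19937_64 eng;
        explicit ToyPrng(uint64_t seed) : eng(seed) {}
        uint64_t get() { return eng(); }
    };

    int main() {
        // s[i] is party i's locally drawn seed. After the ring exchange in
        // ABY3Context, party i holds prng0 = PRNG(s[i]) and prng1 = PRNG(s[(i+1)%3]),
        // i.e. each adjacent pair of parties shares one seed.
        uint64_t s[3] = {111, 222, 333};

        uint64_t share[3];
        for (int i = 0; i < 3; ++i) {
            ToyPrng prng0(s[i]);
            ToyPrng prng1(s[(i + 1) % 3]);
            // gen_zero_sharing_arithmetic(): prng0.get() - prng1.get()
            share[i] = prng0.get() - prng1.get();
        }

        // The three shares cancel pairwise, so the sum wraps to zero.
        uint64_t sum = share[0] + share[1] + share[2];
        std::cout << "sum of zero-shares = " << sum << "\n";  // prints 0
    }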
aby3_ctx()->network()->template send(idx2, *(ret->share(1))); } else if (party() == idx2) { // ret0 = a * b + s0 aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1))); // ret1 = s2 - aby3_ctx()->template ot(idx0, idx1, idx2, share(0), - const_cast**>(null_arg), - tmp, null_arg[0]); + ObliviousTransfer::ot(idx0, idx1, idx2, share(0), + const_cast**>(null_arg), + tmp, null_arg[0]); aby3_ctx()->network()->template send(idx0, *(ret->share(1))); diff --git a/core/privc3/boolean_tensor_test.cc b/core/privc3/boolean_tensor_test.cc index 984fbb5a81e45d6dd1eb8113fecbbd9425319f9b..21ccd31b8a8b828b29790cc6fb181f9cefb9efb2 100644 --- a/core/privc3/boolean_tensor_test.cc +++ b/core/privc3/boolean_tensor_test.cc @@ -27,19 +27,20 @@ #include "boolean_tensor.h" #include "fixedpoint_tensor.h" #include "paddle_tensor.h" -#include "circuit_context.h" +#include "aby3_context.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" namespace aby3 { using paddle::framework::Tensor; +using AbstractContext = paddle::mpc::AbstractContext; class BooleanTensorTest : public ::testing::Test { public: paddle::platform::CPUDeviceContext _cpu_ctx; std::shared_ptr _exec_ctx; - std::shared_ptr _mpc_ctx[3]; + std::shared_ptr _mpc_ctx[3]; std::shared_ptr _store; @@ -83,7 +84,7 @@ public: void gen_mpc_ctx(size_t idx) { auto net = gen_network(idx); net->init(); - _mpc_ctx[idx] = std::make_shared(idx, net); + _mpc_ctx[idx] = std::make_shared(idx, net); } std::shared_ptr> gen1() { diff --git a/core/privc3/circuit_context.h b/core/privc3/circuit_context.h deleted file mode 100644 index ed75e31d8f1770537d7ca7ddde401e31de12246e..0000000000000000000000000000000000000000 --- a/core/privc3/circuit_context.h +++ /dev/null @@ -1,209 +0,0 @@ -// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. 
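As the test fixture changes above show, callers now hold contexts through the AbstractContext interface and pick the concrete protocol only at construction time. A hedged sketch of that wiring; make_mesh_network is a hypothetical, declaration-only stand-in for the MeshNetwork/HashStore setup the tests actually use:

    #include <memory>

    #include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
    #include "core/privc/privc_context.h"
    #include "core/privc3/aby3_context.h"

    namespace example {

    using paddle::mpc::AbstractContext;
    using paddle::mpc::AbstractNetwork;

    // Hypothetical helper: builds a connected network endpoint for `party`
    // (the tests do this with MeshNetwork over a gloo HashStore).
    std::shared_ptr<AbstractNetwork> make_mesh_network(size_t party, size_t num_party);

    std::shared_ptr<AbstractContext> make_context(bool use_aby3, size_t party) {
        if (use_aby3) {
            // Three-party context; seeds default to /dev/urandom-derived blocks.
            return std::make_shared<aby3::ABY3Context>(party, make_mesh_network(party, 3));
        }
        // PrivCContext: two-party, single PRNG behind the same interface.
        return std::make_shared<aby3::PrivCContext>(party, make_mesh_network(party, 2));
    }

    } // namespace example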
-#pragma once - -#include -#include - -#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" -#include "prng_utils.h" - -namespace aby3 { - -using AbstractNetwork = paddle::mpc::AbstractNetwork; - -class CircuitContext { -public: - CircuitContext(size_t party, - std::shared_ptr network, - const block& seed = g_zero_block, - const block& seed2 = g_zero_block) { - init(party, network, seed, seed2); - } - - CircuitContext(const CircuitContext& other) = delete; - - CircuitContext& operator=(const CircuitContext& other) = delete; - - void init(size_t party, - std::shared_ptr network, - block seed, - block seed2) { - set_party(party); - set_network(network); - - if (equals(seed, g_zero_block)) { - seed = block_from_dev_urandom(); - } - - if (equals(seed2, g_zero_block)) { - seed2 = block_from_dev_urandom(); - } - set_random_seed(seed, 0); - // seed2 is private - set_random_seed(seed2, 2); - - // 3 for 3-party computation - size_t party_pre = (this->party() - 1 + 3) % 3; - size_t party_next = (this->party() + 1) % 3; - - if (party == 1) { - block recv_seed = this->network()->template recv(party_next); - this->network()->template send(party_pre, seed); - seed = recv_seed; - } else { - this->network()->template send(party_pre, seed); - seed = this->network()->template recv(party_next); - } - - set_random_seed(seed, 1); - } - - void set_party(size_t party) { - if (party >= 3) { - // exception handling - } - _party = party; - } - - void set_network(std::shared_ptr network) { - _network = network; - } - - AbstractNetwork* network() { - return _network.get(); - } - - void set_random_seed(const block& seed, size_t idx) { - if (idx >= 3) { - // exception handling - } - _prng[idx].set_seed(seed); - } - - size_t party() const { - return _party; - } - - size_t pre_party() const { - return (_party + 3 - 1) % 3; - } - - size_t next_party() const { - return (_party + 1) % 3; - } - - template - T gen_random(bool next) { - return _prng[next].get(); - } - - template class Tensor> - void gen_random(Tensor& tensor, bool next) { - std::for_each(tensor.data(), tensor.data() + tensor.numel(), - [this, next](T& val) { - val = this->template gen_random(next); - }); - } - - template - T gen_random_private() { - return _prng[2].get(); - } - - template class Tensor> - void gen_random_private(Tensor& tensor) { - std::for_each(tensor.data(), tensor.data() + tensor.numel(), - [this](T& val) { - val = this->template gen_random_private(); - }); - } - - template - T gen_zero_sharing_arithmetic() { - return _prng[0].get() - _prng[1].get(); - } - - template class Tensor> - void gen_zero_sharing_arithmetic(Tensor& tensor) { - std::for_each(tensor.data(), tensor.data() + tensor.numel(), - [this](T& val) { - val = this->template gen_zero_sharing_arithmetic(); - }); - } - - template - T gen_zero_sharing_boolean() { - return _prng[0].get() ^ _prng[1].get(); - } - - template class Tensor> - void gen_zero_sharing_boolean(Tensor& tensor) { - std::for_each(tensor.data(), tensor.data() + tensor.numel(), - [this](T& val) { - val = this->template gen_zero_sharing_boolean(); - }); - } - - template class Tensor> - void ot(size_t sender, size_t receiver, size_t helper, - const Tensor* choice, const Tensor* m[2], - Tensor* buffer[2], Tensor* ret) { - // TODO: check tensor shape equals - const size_t numel = buffer[0]->numel(); - if (party() == sender) { - bool common = helper == next_party(); - this->template gen_random(*buffer[0], common); - this->template gen_random(*buffer[1], common); - for (size_t i = 0; i < numel; ++i) { - 
buffer[0]->data()[i] ^= m[0]->data()[i]; - buffer[1]->data()[i] ^= m[1]->data()[i]; - } - network()->template send(receiver, *buffer[0]); - network()->template send(receiver, *buffer[1]); - - } else if (party() == helper) { - bool common = sender == next_party(); - - this->template gen_random(*buffer[0], common); - this->template gen_random(*buffer[1], common); - - for (size_t i = 0; i < numel; ++i) { - buffer[0]->data()[i] = choice->data()[i] & 1 ? - buffer[1]->data()[i] : buffer[0]->data()[i]; - } - network()->template send(receiver, *buffer[0]); - } else if (party() == receiver) { - network()->template recv(sender, *buffer[0]); - network()->template recv(sender, *buffer[1]); - network()->template recv(helper, *ret); - size_t i = 0; - std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { - bool c = choice->data()[i] & 1; - in ^= buffer[c]->data()[i]; - ++i;} - ); - } - } - -private: - size_t _party; - - std::shared_ptr _network; - - PseudorandomNumberGenerator _prng[3]; - -}; - -} // namespace aby3 diff --git a/core/privc3/fixedpoint_tensor.h b/core/privc3/fixedpoint_tensor.h index 3fb2883d76a479af30fee1da67cf65ed980c60da..2346ad1f128823b6a2d4259e57fa249672ac5b53 100644 --- a/core/privc3/fixedpoint_tensor.h +++ b/core/privc3/fixedpoint_tensor.h @@ -16,7 +16,9 @@ #include -#include "circuit_context.h" +#include "boolean_tensor.h" +#include "aby3_context.h" +#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "paddle_tensor.h" #include "boolean_tensor.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" @@ -195,9 +197,8 @@ public: size_t scaling_factor); private: - - static inline std::shared_ptr aby3_ctx() { - return paddle::mpc::ContextHolder::mpc_ctx(); + static inline std::shared_ptr aby3_ctx() { + return paddle::mpc::ContextHolder::mpc_ctx(); } static inline std::shared_ptr tensor_factory() { diff --git a/core/privc3/fixedpoint_tensor_test.cc b/core/privc3/fixedpoint_tensor_test.cc index c2f83189fc46535932fad8048d720aa408013a95..c525205b9a3a6adfcede17c5a6fdbd17944ab9b2 100644 --- a/core/privc3/fixedpoint_tensor_test.cc +++ b/core/privc3/fixedpoint_tensor_test.cc @@ -20,21 +20,23 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" +#include "aby3_context.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "fixedpoint_tensor.h" namespace aby3 { - using g_ctx_holder = paddle::mpc::ContextHolder; - using Fix64N16 = FixedPointTensor; +using g_ctx_holder = paddle::mpc::ContextHolder; +using Fix64N16 = FixedPointTensor; +using AbstractContext = paddle::mpc::AbstractContext; class FixedTensorTest : public ::testing::Test { public: paddle::platform::CPUDeviceContext _cpu_ctx; std::shared_ptr _exec_ctx; - std::shared_ptr _mpc_ctx[3]; + std::shared_ptr _mpc_ctx[3]; std::shared_ptr _store; std::thread _t[3]; std::shared_ptr _s_tensor_factory; @@ -71,7 +73,7 @@ public: void gen_mpc_ctx(size_t idx) { auto net = gen_network(idx); net->init(); - _mpc_ctx[idx] = std::make_shared(idx, net); + _mpc_ctx[idx] = std::make_shared(idx, net); } std::shared_ptr> gen(std::vector shape) { diff --git a/core/privc3/ot.h b/core/privc3/ot.h new file mode 100644 index 0000000000000000000000000000000000000000..7bf33c7e46f0339c915dbbcf182f9bf96e680aa1 --- /dev/null +++ b/core/privc3/ot.h @@ -0,0 +1,67 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" + +namespace aby3 { + +class ObliviousTransfer { + public: + template class Tensor> + static inline void ot(size_t sender, size_t receiver, size_t helper, + const Tensor* choice, const Tensor* m[2], + Tensor* buffer[2], Tensor* ret) { + // TODO: check tensor shape equals + auto aby3_ctx = paddle::mpc::ContextHolder::mpc_ctx(); + const size_t numel = buffer[0]->numel(); + if (aby3_ctx->party() == sender) { + bool common = helper == aby3_ctx->next_party(); + aby3_ctx->template gen_random(*buffer[0], common); + aby3_ctx->template gen_random(*buffer[1], common); + for (size_t i = 0; i < numel; ++i) { + buffer[0]->data()[i] ^= m[0]->data()[i]; + buffer[1]->data()[i] ^= m[1]->data()[i]; + } + aby3_ctx->network()->template send(receiver, *buffer[0]); + aby3_ctx->network()->template send(receiver, *buffer[1]); + + } else if (aby3_ctx->party() == helper) { + bool common = sender == aby3_ctx->next_party(); + + aby3_ctx->template gen_random(*buffer[0], common); + aby3_ctx->template gen_random(*buffer[1], common); + + for (size_t i = 0; i < numel; ++i) { + buffer[0]->data()[i] = choice->data()[i] & 1 ? + buffer[1]->data()[i] : buffer[0]->data()[i]; + } + aby3_ctx->network()->template send(receiver, *buffer[0]); + } else if (aby3_ctx->party() == receiver) { + aby3_ctx->network()->template recv(sender, *buffer[0]); + aby3_ctx->network()->template recv(sender, *buffer[1]); + aby3_ctx->network()->template recv(helper, *ret); + size_t i = 0; + std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { + bool c = choice->data()[i] & 1; + in ^= buffer[c]->data()[i]; + ++i;} + ); + } + } +}; + +} // namespace aby3
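ObliviousTransfer::ot extracts the three-role OT that previously lived in CircuitContext: the sender masks both messages with randomness drawn from the PRNG it shares with the helper, the helper (who also knows the receiver's choice bit) forwards only the mask matching that bit, and the receiver XORs the two to recover m[choice] while the sender never learns the choice. A standalone single-value walk-through of the exchange (toy PRNG, no networking):

    #include <cstdint>
    #include <iostream>
    #include <random>

    // Toy PRNG; sender and helper seed it identically to model their shared PRNG.
    struct ToyPrng {
        std::mt19937_64 eng;
        explicit ToyPrng(uint64_t seed) : eng(seed) {}
        uint64_t get() { return eng(); }
    };

    int main() {
        const uint64_t m[2] = {0xAAAA5555, 0x12345678};  // sender's two messages
        const bool choice = true;                        // known to receiver and helper only
        const uint64_t common_seed = 7;                  // sender/helper shared randomness

        // Sender: mask both messages with correlated randomness, send to receiver.
        ToyPrng sender_prng(common_seed);
        uint64_t masked[2] = {sender_prng.get() ^ m[0], sender_prng.get() ^ m[1]};

        // Helper: regenerates the same randomness, forwards the mask for `choice`.
        ToyPrng helper_prng(common_seed);
        uint64_t r[2] = {helper_prng.get(), helper_prng.get()};
        uint64_t forwarded = choice ? r[1] : r[0];

        // Receiver: unmask the chosen message; the other stays hidden behind its mask.
        uint64_t recovered = forwarded ^ masked[choice];
        std::cout << std::hex << recovered << " == " << m[choice] << "\n";
    }

In the real code the shared randomness comes from gen_random(common), where `common` selects the context PRNG shared with the appropriate neighbor, which is exactly why the seed exchange performed in ABY3Context matters for this primitive.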