From e11231035ef23f610350911c7ed6159e70145c1e Mon Sep 17 00:00:00 2001 From: yangqingyou Date: Wed, 26 Aug 2020 12:51:05 +0000 Subject: [PATCH] fixed abstract_context bug --- .../mpc_protocol/abstract_context.h | 37 ++++++++++--------- 1 file changed, 20 insertions(+), 17 deletions(-) diff --git a/core/paddlefl_mpc/mpc_protocol/abstract_context.h b/core/paddlefl_mpc/mpc_protocol/abstract_context.h index fc5728a..c945aa5 100644 --- a/core/paddlefl_mpc/mpc_protocol/abstract_context.h +++ b/core/paddlefl_mpc/mpc_protocol/abstract_context.h @@ -65,9 +65,17 @@ public: size_t party() const { return _party; } - size_t pre_party() const { return (_party + _num_party - 1) % _num_party; } + size_t pre_party() const { + PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true, + "number of party is not set."); + return (_party + _num_party - 1) % _num_party; + } - size_t next_party() const { return (_party + 1) % _num_party; } + size_t next_party() const { + PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true, + "number of party is not set."); + return (_party + 1) % _num_party; + } template T gen_random(bool next) { return _prng[next].get(); } @@ -120,13 +128,11 @@ public: [this](T &val) { val = this->template gen_zero_sharing_boolean(); }); } - template class Tensor> + template class Tensor> void ot(size_t sender, size_t receiver, size_t helper, - const Tensor *choice, const Tensor *m[2], Tensor *buffer[2], - Tensor *ret) { + const Tensor* choice, const Tensor* m[2], + Tensor* buffer[2], Tensor* ret) { // TODO: check tensor shape equals - PADDLE_ENFORCE_EQ(_num_party, 3, - "`ot` API is for 3 party protocol."); const size_t numel = buffer[0]->numel(); if (party() == sender) { bool common = helper == next_party(); @@ -146,9 +152,8 @@ public: this->template gen_random(*buffer[1], common); for (size_t i = 0; i < numel; ++i) { - // TODO: check if choice is one bit - buffer[0]->data()[i] = - choice->data()[i] ? buffer[1]->data()[i] : buffer[0]->data()[i]; + buffer[0]->data()[i] = choice->data()[i] & 1 ? + buffer[1]->data()[i] : buffer[0]->data()[i]; } network()->template send(receiver, *buffer[0]); } else if (party() == receiver) { @@ -156,13 +161,11 @@ public: network()->template recv(sender, *buffer[1]); network()->template recv(helper, *ret); size_t i = 0; - std::for_each(ret->data(), ret->data() + numel, - [&buffer, &i, choice, ret](T &in) { - // TODO: check if choice is one bit - bool c = choice->data()[i]; - in ^= buffer[c]->data()[i]; - ++i; - }); + std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { + bool c = choice->data()[i] & 1; + in ^= buffer[c]->data()[i]; + ++i;} + ); } } -- GitLab