diff --git a/core/paddlefl_mpc/mpc_protocol/abstract_context.h b/core/paddlefl_mpc/mpc_protocol/abstract_context.h index fc5728a80f838668114b148226b94e422af64a55..c945aa572e6b26d8b6f5dfb18ee56280cdca0834 100644 --- a/core/paddlefl_mpc/mpc_protocol/abstract_context.h +++ b/core/paddlefl_mpc/mpc_protocol/abstract_context.h @@ -65,9 +65,17 @@ public: size_t party() const { return _party; } - size_t pre_party() const { return (_party + _num_party - 1) % _num_party; } + size_t pre_party() const { + PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true, + "number of party is not set."); + return (_party + _num_party - 1) % _num_party; + } - size_t next_party() const { return (_party + 1) % _num_party; } + size_t next_party() const { + PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true, + "number of party is not set."); + return (_party + 1) % _num_party; + } template T gen_random(bool next) { return _prng[next].get(); } @@ -120,13 +128,11 @@ public: [this](T &val) { val = this->template gen_zero_sharing_boolean(); }); } - template class Tensor> + template class Tensor> void ot(size_t sender, size_t receiver, size_t helper, - const Tensor *choice, const Tensor *m[2], Tensor *buffer[2], - Tensor *ret) { + const Tensor* choice, const Tensor* m[2], + Tensor* buffer[2], Tensor* ret) { // TODO: check tensor shape equals - PADDLE_ENFORCE_EQ(_num_party, 3, - "`ot` API is for 3 party protocol."); const size_t numel = buffer[0]->numel(); if (party() == sender) { bool common = helper == next_party(); @@ -146,9 +152,8 @@ public: this->template gen_random(*buffer[1], common); for (size_t i = 0; i < numel; ++i) { - // TODO: check if choice is one bit - buffer[0]->data()[i] = - choice->data()[i] ? buffer[1]->data()[i] : buffer[0]->data()[i]; + buffer[0]->data()[i] = choice->data()[i] & 1 ? + buffer[1]->data()[i] : buffer[0]->data()[i]; } network()->template send(receiver, *buffer[0]); } else if (party() == receiver) { @@ -156,13 +161,11 @@ public: network()->template recv(sender, *buffer[1]); network()->template recv(helper, *ret); size_t i = 0; - std::for_each(ret->data(), ret->data() + numel, - [&buffer, &i, choice, ret](T &in) { - // TODO: check if choice is one bit - bool c = choice->data()[i]; - in ^= buffer[c]->data()[i]; - ++i; - }); + std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) { + bool c = choice->data()[i] & 1; + in ^= buffer[c]->data()[i]; + ++i;} + ); } }