提交 e1123103 编写于 作者: Y yangqingyou

fixed abstract_context bug

上级 b8d8ee2b
...@@ -65,9 +65,17 @@ public: ...@@ -65,9 +65,17 @@ public:
size_t party() const { return _party; } size_t party() const { return _party; }
size_t pre_party() const { return (_party + _num_party - 1) % _num_party; } size_t pre_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + _num_party - 1) % _num_party;
}
size_t next_party() const { return (_party + 1) % _num_party; } size_t next_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + 1) % _num_party;
}
template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); } template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); }
...@@ -120,13 +128,11 @@ public: ...@@ -120,13 +128,11 @@ public:
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); }); [this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
} }
template <typename T, template <typename> class Tensor> template<typename T, template <typename> class Tensor>
void ot(size_t sender, size_t receiver, size_t helper, void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T> *choice, const Tensor<T> *m[2], Tensor<T> *buffer[2], const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T> *ret) { Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals // TODO: check tensor shape equals
PADDLE_ENFORCE_EQ(_num_party, 3,
"`ot` API is for 3 party protocol.");
const size_t numel = buffer[0]->numel(); const size_t numel = buffer[0]->numel();
if (party() == sender) { if (party() == sender) {
bool common = helper == next_party(); bool common = helper == next_party();
...@@ -146,9 +152,8 @@ public: ...@@ -146,9 +152,8 @@ public:
this->template gen_random(*buffer[1], common); this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) { for (size_t i = 0; i < numel; ++i) {
// TODO: check if choice is one bit buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[0]->data()[i] = buffer[1]->data()[i] : buffer[0]->data()[i];
choice->data()[i] ? buffer[1]->data()[i] : buffer[0]->data()[i];
} }
network()->template send(receiver, *buffer[0]); network()->template send(receiver, *buffer[0]);
} else if (party() == receiver) { } else if (party() == receiver) {
...@@ -156,13 +161,11 @@ public: ...@@ -156,13 +161,11 @@ public:
network()->template recv(sender, *buffer[1]); network()->template recv(sender, *buffer[1]);
network()->template recv(helper, *ret); network()->template recv(helper, *ret);
size_t i = 0; size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
[&buffer, &i, choice, ret](T &in) { bool c = choice->data()[i] & 1;
// TODO: check if choice is one bit in ^= buffer[c]->data()[i];
bool c = choice->data()[i]; ++i;}
in ^= buffer[c]->data()[i]; );
++i;
});
} }
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册