提交 e1123103 编写于 作者: Y yangqingyou

fixed abstract_context bug

上级 b8d8ee2b
......@@ -65,9 +65,17 @@ public:
size_t party() const { return _party; }
size_t pre_party() const { return (_party + _num_party - 1) % _num_party; }
size_t pre_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + _num_party - 1) % _num_party;
}
size_t next_party() const { return (_party + 1) % _num_party; }
size_t next_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + 1) % _num_party;
}
template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); }
......@@ -120,13 +128,11 @@ public:
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
}
template <typename T, template <typename> class Tensor>
template<typename T, template <typename> class Tensor>
void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T> *choice, const Tensor<T> *m[2], Tensor<T> *buffer[2],
Tensor<T> *ret) {
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
PADDLE_ENFORCE_EQ(_num_party, 3,
"`ot` API is for 3 party protocol.");
const size_t numel = buffer[0]->numel();
if (party() == sender) {
bool common = helper == next_party();
......@@ -146,9 +152,8 @@ public:
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
// TODO: check if choice is one bit
buffer[0]->data()[i] =
choice->data()[i] ? buffer[1]->data()[i] : buffer[0]->data()[i];
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
network()->template send(receiver, *buffer[0]);
} else if (party() == receiver) {
......@@ -156,13 +161,11 @@ public:
network()->template recv(sender, *buffer[1]);
network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel,
[&buffer, &i, choice, ret](T &in) {
// TODO: check if choice is one bit
bool c = choice->data()[i];
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;
});
++i;}
);
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册