提交 d829309d 编写于 作者: Y yangqingyou

amend according comment

上级 5e248249
......@@ -29,24 +29,23 @@ using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator;
class AbstractContext {
public:
AbstractContext() = default;
AbstractContext(size_t party, std::shared_ptr<AbstractNetwork> network) {
init(party, network);
};
AbstractContext(const AbstractContext &other) = delete;
AbstractContext &operator=(const AbstractContext &other) = delete;
virtual void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) = 0;
void init(size_t party, std::shared_ptr<AbstractNetwork> network) {
set_party(party);
set_network(network);
}
void set_party(size_t party) {
PADDLE_ENFORCE_LT(party, _num_party,
"party idx should less than %d.",
_num_party);
_party = party;
}
void set_num_party(size_t num_party) {
PADDLE_ENFORCE_EQ(num_party == 2 || num_party == 3, true,
"2 or 3 party protocol is supported.");
_num_party = num_party;
}
......@@ -60,35 +59,33 @@ public:
PADDLE_ENFORCE_LE(idx, _num_party,
"prng idx should be less and equal to %d.",
_num_party);
_prng[idx].set_seed(seed);
get_prng(idx).set_seed(seed);
}
size_t party() const { return _party; }
size_t pre_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + _num_party - 1) % _num_party;
}
size_t next_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + 1) % _num_party;
}
template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); }
// generate random from prng[0] or prng[1]
// @param next: use bool type for idx 0 or 1
template <typename T> T gen_random(bool next) {
return get_prng(next).get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_random(Tensor<T> &tensor, bool next) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_random` API is for 3 party protocol.");
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this, next](T &val) { val = this->template gen_random<T>(next); });
}
template <typename T> T gen_random_private() { return _prng[2].get<T>(); }
template <typename T> T gen_random_private() { return get_prng(2).get<T>(); }
template <typename T, template <typename> class Tensor>
void gen_random_private(Tensor<T> &tensor) {
......@@ -98,15 +95,11 @@ public:
}
template <typename T> T gen_zero_sharing_arithmetic() {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
return _prng[0].get<T>() - _prng[1].get<T>();
return get_prng(0).get<T>() - get_prng(1).get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_zero_sharing_arithmetic(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) {
val = this->template gen_zero_sharing_arithmetic<T>();
......@@ -114,60 +107,18 @@ public:
}
template <typename T> T gen_zero_sharing_boolean() {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
return _prng[0].get<T>() ^ _prng[1].get<T>();
return get_prng(0).get<T>() ^ get_prng(1).get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_zero_sharing_boolean(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
}
template<typename T, template <typename> class Tensor>
void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
const size_t numel = buffer[0]->numel();
if (party() == sender) {
bool common = helper == next_party();
this->template gen_random(*buffer[0], common);
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] ^= m[0]->data()[i];
buffer[1]->data()[i] ^= m[1]->data()[i];
}
network()->template send(receiver, *buffer[0]);
network()->template send(receiver, *buffer[1]);
} else if (party() == helper) {
bool common = sender == next_party();
this->template gen_random(*buffer[0], common);
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
network()->template send(receiver, *buffer[0]);
} else if (party() == receiver) {
network()->template recv(sender, *buffer[0]);
network()->template recv(sender, *buffer[1]);
network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;}
);
}
}
protected:
virtual PseudorandomNumberGenerator& get_prng(size_t idx) = 0;
private:
size_t _num_party;
......
......@@ -29,26 +29,26 @@ using AbstractContext = paddle::mpc::AbstractContext;
class PrivCContext : public AbstractContext {
public:
PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block) {
init(party, network, g_zero_block, seed);
block seed = g_zero_block):
AbstractContext::AbstractContext(party, network) {
set_num_party(2);
if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom();
}
set_random_seed(seed, 0);
}
PrivCContext(const PrivCContext &other) = delete;
PrivCContext &operator=(const PrivCContext &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) override {
set_num_party(2);
set_party(party);
set_network(network);
if (psi::equals(seed2, psi::g_zero_block)) {
seed2 = psi::block_from_dev_urandom();
}
// seed2 is private
set_random_seed(seed2, 2);
protected:
PseudorandomNumberGenerator& get_prng(size_t idx) override {
return _prng;
}
private:
PseudorandomNumberGenerator _prng;
};
} // namespace aby3
......@@ -29,20 +29,10 @@ using AbstractContext = paddle::mpc::AbstractContext;
class ABY3Context : public AbstractContext {
public:
ABY3Context(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block,
const block &seed2 = g_zero_block) {
init(party, network, seed, seed2);
}
ABY3Context(const ABY3Context &other) = delete;
ABY3Context &operator=(const ABY3Context &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) override {
block seed = g_zero_block,
block seed2 = g_zero_block) :
AbstractContext::AbstractContext(party, network) {
set_num_party(3);
set_party(party);
set_network(network);
if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom();
......@@ -70,6 +60,16 @@ public:
set_random_seed(seed, 1);
}
ABY3Context(const ABY3Context &other) = delete;
ABY3Context &operator=(const ABY3Context &other) = delete;
protected:
PseudorandomNumberGenerator& get_prng(size_t idx) override {
return _prng[idx];
}
private:
PseudorandomNumberGenerator _prng[3];
};
} // namespace aby3
......@@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include "core/privc3/ot.h"
namespace aby3 {
......@@ -432,9 +433,9 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
m[0]->add(tmp[0], m[0]);
m[1]->add(tmp[0], m[1]);
aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0],
const_cast<const aby3::TensorAdapter<T>**>(m),
tmp, null_arg[0]);
ObliviousTransfer::ot(idx0, idx1, idx2, null_arg[0],
const_cast<const aby3::TensorAdapter<T>**>(m),
tmp, null_arg[0]);
// ret0 = s2
// ret1 = s1
......@@ -445,18 +446,18 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = s1
aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0)));
// ret1 = a * b + s0
aby3_ctx()->template ot(idx0, idx1, idx2, share(1),
const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, ret->mutable_share(1));
ObliviousTransfer::ot(idx0, idx1, idx2, share(1),
const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, ret->mutable_share(1));
aby3_ctx()->network()->template send(idx0, *(ret->share(0)));
aby3_ctx()->network()->template send(idx2, *(ret->share(1)));
} else if (party() == idx2) {
// ret0 = a * b + s0
aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1)));
// ret1 = s2
aby3_ctx()->template ot(idx0, idx1, idx2, share(0),
const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, null_arg[0]);
ObliviousTransfer::ot(idx0, idx1, idx2, share(0),
const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, null_arg[0]);
aby3_ctx()->network()->template send(idx0, *(ret->share(1)));
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
namespace aby3 {
class ObliviousTransfer {
public:
template<typename T, template <typename> class Tensor>
static inline void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
auto aby3_ctx = paddle::mpc::ContextHolder::mpc_ctx();
const size_t numel = buffer[0]->numel();
if (aby3_ctx->party() == sender) {
bool common = helper == aby3_ctx->next_party();
aby3_ctx->template gen_random(*buffer[0], common);
aby3_ctx->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] ^= m[0]->data()[i];
buffer[1]->data()[i] ^= m[1]->data()[i];
}
aby3_ctx->network()->template send(receiver, *buffer[0]);
aby3_ctx->network()->template send(receiver, *buffer[1]);
} else if (aby3_ctx->party() == helper) {
bool common = sender == aby3_ctx->next_party();
aby3_ctx->template gen_random(*buffer[0], common);
aby3_ctx->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
aby3_ctx->network()->template send(receiver, *buffer[0]);
} else if (aby3_ctx->party() == receiver) {
aby3_ctx->network()->template recv(sender, *buffer[0]);
aby3_ctx->network()->template recv(sender, *buffer[1]);
aby3_ctx->network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;}
);
}
}
};
} // namespace aby3
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册