提交 d829309d 编写于 作者: Y yangqingyou

amend according comment

上级 5e248249
...@@ -29,24 +29,23 @@ using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator; ...@@ -29,24 +29,23 @@ using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator;
class AbstractContext { class AbstractContext {
public: public:
AbstractContext() = default; AbstractContext(size_t party, std::shared_ptr<AbstractNetwork> network) {
init(party, network);
};
AbstractContext(const AbstractContext &other) = delete; AbstractContext(const AbstractContext &other) = delete;
AbstractContext &operator=(const AbstractContext &other) = delete; AbstractContext &operator=(const AbstractContext &other) = delete;
virtual void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed, void init(size_t party, std::shared_ptr<AbstractNetwork> network) {
block seed2) = 0; set_party(party);
set_network(network);
}
void set_party(size_t party) { void set_party(size_t party) {
PADDLE_ENFORCE_LT(party, _num_party,
"party idx should less than %d.",
_num_party);
_party = party; _party = party;
} }
void set_num_party(size_t num_party) { void set_num_party(size_t num_party) {
PADDLE_ENFORCE_EQ(num_party == 2 || num_party == 3, true,
"2 or 3 party protocol is supported.");
_num_party = num_party; _num_party = num_party;
} }
...@@ -60,35 +59,33 @@ public: ...@@ -60,35 +59,33 @@ public:
PADDLE_ENFORCE_LE(idx, _num_party, PADDLE_ENFORCE_LE(idx, _num_party,
"prng idx should be less and equal to %d.", "prng idx should be less and equal to %d.",
_num_party); _num_party);
_prng[idx].set_seed(seed); get_prng(idx).set_seed(seed);
} }
size_t party() const { return _party; } size_t party() const { return _party; }
size_t pre_party() const { size_t pre_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + _num_party - 1) % _num_party; return (_party + _num_party - 1) % _num_party;
} }
size_t next_party() const { size_t next_party() const {
PADDLE_ENFORCE_EQ(_num_party == 2 || _num_party == 3, true,
"number of party is not set.");
return (_party + 1) % _num_party; return (_party + 1) % _num_party;
} }
template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); } // generate random from prng[0] or prng[1]
// @param next: use bool type for idx 0 or 1
template <typename T> T gen_random(bool next) {
return get_prng(next).get<T>();
}
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_random(Tensor<T> &tensor, bool next) { void gen_random(Tensor<T> &tensor, bool next) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_random` API is for 3 party protocol.");
std::for_each( std::for_each(
tensor.data(), tensor.data() + tensor.numel(), tensor.data(), tensor.data() + tensor.numel(),
[this, next](T &val) { val = this->template gen_random<T>(next); }); [this, next](T &val) { val = this->template gen_random<T>(next); });
} }
template <typename T> T gen_random_private() { return _prng[2].get<T>(); } template <typename T> T gen_random_private() { return get_prng(2).get<T>(); }
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_random_private(Tensor<T> &tensor) { void gen_random_private(Tensor<T> &tensor) {
...@@ -98,15 +95,11 @@ public: ...@@ -98,15 +95,11 @@ public:
} }
template <typename T> T gen_zero_sharing_arithmetic() { template <typename T> T gen_zero_sharing_arithmetic() {
PADDLE_ENFORCE_EQ(_num_party, 3, return get_prng(0).get<T>() - get_prng(1).get<T>();
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
return _prng[0].get<T>() - _prng[1].get<T>();
} }
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_zero_sharing_arithmetic(Tensor<T> &tensor) { void gen_zero_sharing_arithmetic(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
std::for_each(tensor.data(), tensor.data() + tensor.numel(), std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { [this](T &val) {
val = this->template gen_zero_sharing_arithmetic<T>(); val = this->template gen_zero_sharing_arithmetic<T>();
...@@ -114,60 +107,18 @@ public: ...@@ -114,60 +107,18 @@ public:
} }
template <typename T> T gen_zero_sharing_boolean() { template <typename T> T gen_zero_sharing_boolean() {
PADDLE_ENFORCE_EQ(_num_party, 3, return get_prng(0).get<T>() ^ get_prng(1).get<T>();
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
return _prng[0].get<T>() ^ _prng[1].get<T>();
} }
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_zero_sharing_boolean(Tensor<T> &tensor) { void gen_zero_sharing_boolean(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
std::for_each( std::for_each(
tensor.data(), tensor.data() + tensor.numel(), tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); }); [this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
} }
template<typename T, template <typename> class Tensor> protected:
void ot(size_t sender, size_t receiver, size_t helper, virtual PseudorandomNumberGenerator& get_prng(size_t idx) = 0;
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
const size_t numel = buffer[0]->numel();
if (party() == sender) {
bool common = helper == next_party();
this->template gen_random(*buffer[0], common);
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] ^= m[0]->data()[i];
buffer[1]->data()[i] ^= m[1]->data()[i];
}
network()->template send(receiver, *buffer[0]);
network()->template send(receiver, *buffer[1]);
} else if (party() == helper) {
bool common = sender == next_party();
this->template gen_random(*buffer[0], common);
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
network()->template send(receiver, *buffer[0]);
} else if (party() == receiver) {
network()->template recv(sender, *buffer[0]);
network()->template recv(sender, *buffer[1]);
network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;}
);
}
}
private: private:
size_t _num_party; size_t _num_party;
......
...@@ -29,26 +29,26 @@ using AbstractContext = paddle::mpc::AbstractContext; ...@@ -29,26 +29,26 @@ using AbstractContext = paddle::mpc::AbstractContext;
class PrivCContext : public AbstractContext { class PrivCContext : public AbstractContext {
public: public:
PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network, PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block) { block seed = g_zero_block):
init(party, network, g_zero_block, seed); AbstractContext::AbstractContext(party, network) {
set_num_party(2);
if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom();
}
set_random_seed(seed, 0);
} }
PrivCContext(const PrivCContext &other) = delete; PrivCContext(const PrivCContext &other) = delete;
PrivCContext &operator=(const PrivCContext &other) = delete; PrivCContext &operator=(const PrivCContext &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed, protected:
block seed2) override { PseudorandomNumberGenerator& get_prng(size_t idx) override {
set_num_party(2); return _prng;
set_party(party);
set_network(network);
if (psi::equals(seed2, psi::g_zero_block)) {
seed2 = psi::block_from_dev_urandom();
}
// seed2 is private
set_random_seed(seed2, 2);
} }
private:
PseudorandomNumberGenerator _prng;
}; };
} // namespace aby3 } // namespace aby3
...@@ -29,20 +29,10 @@ using AbstractContext = paddle::mpc::AbstractContext; ...@@ -29,20 +29,10 @@ using AbstractContext = paddle::mpc::AbstractContext;
class ABY3Context : public AbstractContext { class ABY3Context : public AbstractContext {
public: public:
ABY3Context(size_t party, std::shared_ptr<AbstractNetwork> network, ABY3Context(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block, block seed = g_zero_block,
const block &seed2 = g_zero_block) { block seed2 = g_zero_block) :
init(party, network, seed, seed2); AbstractContext::AbstractContext(party, network) {
}
ABY3Context(const ABY3Context &other) = delete;
ABY3Context &operator=(const ABY3Context &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) override {
set_num_party(3); set_num_party(3);
set_party(party);
set_network(network);
if (psi::equals(seed, psi::g_zero_block)) { if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom(); seed = psi::block_from_dev_urandom();
...@@ -70,6 +60,16 @@ public: ...@@ -70,6 +60,16 @@ public:
set_random_seed(seed, 1); set_random_seed(seed, 1);
} }
ABY3Context(const ABY3Context &other) = delete;
ABY3Context &operator=(const ABY3Context &other) = delete;
protected:
PseudorandomNumberGenerator& get_prng(size_t idx) override {
return _prng[idx];
}
private:
PseudorandomNumberGenerator _prng[3];
}; };
} // namespace aby3 } // namespace aby3
...@@ -15,6 +15,7 @@ ...@@ -15,6 +15,7 @@
#pragma once #pragma once
#include <algorithm> #include <algorithm>
#include "core/privc3/ot.h"
namespace aby3 { namespace aby3 {
...@@ -432,7 +433,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs, ...@@ -432,7 +433,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
m[0]->add(tmp[0], m[0]); m[0]->add(tmp[0], m[0]);
m[1]->add(tmp[0], m[1]); m[1]->add(tmp[0], m[1]);
aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0], ObliviousTransfer::ot(idx0, idx1, idx2, null_arg[0],
const_cast<const aby3::TensorAdapter<T>**>(m), const_cast<const aby3::TensorAdapter<T>**>(m),
tmp, null_arg[0]); tmp, null_arg[0]);
...@@ -445,7 +446,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs, ...@@ -445,7 +446,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = s1 // ret0 = s1
aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0))); aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0)));
// ret1 = a * b + s0 // ret1 = a * b + s0
aby3_ctx()->template ot(idx0, idx1, idx2, share(1), ObliviousTransfer::ot(idx0, idx1, idx2, share(1),
const_cast<const aby3::TensorAdapter<T>**>(null_arg), const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, ret->mutable_share(1)); tmp, ret->mutable_share(1));
aby3_ctx()->network()->template send(idx0, *(ret->share(0))); aby3_ctx()->network()->template send(idx0, *(ret->share(0)));
...@@ -454,7 +455,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs, ...@@ -454,7 +455,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = a * b + s0 // ret0 = a * b + s0
aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1))); aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1)));
// ret1 = s2 // ret1 = s2
aby3_ctx()->template ot(idx0, idx1, idx2, share(0), ObliviousTransfer::ot(idx0, idx1, idx2, share(0),
const_cast<const aby3::TensorAdapter<T>**>(null_arg), const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, null_arg[0]); tmp, null_arg[0]);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
namespace aby3 {
class ObliviousTransfer {
public:
template<typename T, template <typename> class Tensor>
static inline void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
auto aby3_ctx = paddle::mpc::ContextHolder::mpc_ctx();
const size_t numel = buffer[0]->numel();
if (aby3_ctx->party() == sender) {
bool common = helper == aby3_ctx->next_party();
aby3_ctx->template gen_random(*buffer[0], common);
aby3_ctx->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] ^= m[0]->data()[i];
buffer[1]->data()[i] ^= m[1]->data()[i];
}
aby3_ctx->network()->template send(receiver, *buffer[0]);
aby3_ctx->network()->template send(receiver, *buffer[1]);
} else if (aby3_ctx->party() == helper) {
bool common = sender == aby3_ctx->next_party();
aby3_ctx->template gen_random(*buffer[0], common);
aby3_ctx->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
aby3_ctx->network()->template send(receiver, *buffer[0]);
} else if (aby3_ctx->party() == receiver) {
aby3_ctx->network()->template recv(sender, *buffer[0]);
aby3_ctx->network()->template recv(sender, *buffer[1]);
aby3_ctx->network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;}
);
}
}
};
} // namespace aby3
\ No newline at end of file
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册