未验证 提交 acfda7c1 编写于 作者: J jed 提交者: GitHub

Merge pull request #106 from Yanghello/refactor_context

Refactor context
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "core/privc3/prng_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace mpc {
using block = psi::block;
using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator;
class AbstractContext {
public:
AbstractContext(size_t party, std::shared_ptr<AbstractNetwork> network) {
set_party(party);
set_network(network);
};
AbstractContext(const AbstractContext &other) = delete;
AbstractContext &operator=(const AbstractContext &other) = delete;
void set_party(size_t party) {
_party = party;
}
void set_num_party(size_t num_party) {
_num_party = num_party;
}
void set_network(std::shared_ptr<AbstractNetwork> network) {
_network = network;
}
AbstractNetwork *network() { return _network.get(); }
void set_random_seed(const block &seed, size_t idx) {
PADDLE_ENFORCE_LE(idx, _num_party,
"prng idx should be less and equal to %d.",
_num_party);
get_prng(idx).set_seed(seed);
}
size_t party() const { return _party; }
size_t pre_party() const {
return (_party + _num_party - 1) % _num_party;
}
size_t next_party() const {
return (_party + 1) % _num_party;
}
// generate random from prng[0] or prng[1]
// @param next: use bool type for idx 0 or 1
template <typename T> T gen_random(bool next) {
return get_prng(next).get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_random(Tensor<T> &tensor, bool next) {
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this, next](T &val) { val = this->template gen_random<T>(next); });
}
template <typename T> T gen_random_private() { return get_prng(2).get<T>(); }
template <typename T, template <typename> class Tensor>
void gen_random_private(Tensor<T> &tensor) {
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { val = this->template gen_random_private<T>(); });
}
template <typename T> T gen_zero_sharing_arithmetic() {
return get_prng(0).get<T>() - get_prng(1).get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_zero_sharing_arithmetic(Tensor<T> &tensor) {
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) {
val = this->template gen_zero_sharing_arithmetic<T>();
});
}
template <typename T> T gen_zero_sharing_boolean() {
return get_prng(0).get<T>() ^ get_prng(1).get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_zero_sharing_boolean(Tensor<T> &tensor) {
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
}
protected:
virtual PseudorandomNumberGenerator& get_prng(size_t idx) = 0;
private:
size_t _num_party;
size_t _party;
std::shared_ptr<AbstractNetwork> _network;
};
} // namespace mpc
} //namespace paddle
......@@ -21,7 +21,8 @@ limitations under the License. */
#include "context_holder.h"
#include "mpc_operators.h"
#include "paddle/fluid/framework/tensor.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/boolean_tensor.h"
#include "core/privc3/aby3_context.h"
#include "core/privc3/fixedpoint_tensor.h"
#include "core/privc3/boolean_tensor.h"
#include "core/privc3/paddle_tensor.h"
......@@ -30,7 +31,7 @@ namespace paddle {
namespace mpc {
using paddle::framework::Tensor;
using aby3::CircuitContext;
using aby3::ABY3Context;
// TODO: decide scaling factor
const size_t ABY3_SCALING_FACTOR = FIXED_POINTER_SCALING_FACTOR;
using FixedTensor = aby3::FixedPointTensor<int64_t, ABY3_SCALING_FACTOR>;
......
......@@ -48,7 +48,7 @@ void Aby3Protocol::init_with_store(
mesh_net->init();
_network = std::move(mesh_net);
_circuit_ctx = std::make_shared<CircuitContext>(role, _network);
_circuit_ctx = std::make_shared<ABY3Context>(role, _network);
_operators = std::make_shared<Aby3OperatorsImpl>();
_is_initialized = true;
}
......@@ -63,7 +63,7 @@ std::shared_ptr<AbstractNetwork> Aby3Protocol::network() {
return _network;
}
std::shared_ptr<CircuitContext> Aby3Protocol::mpc_context() {
std::shared_ptr<AbstractContext> Aby3Protocol::mpc_context() {
PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR);
return _circuit_ctx;
}
......
......@@ -24,12 +24,13 @@
#include "mesh_network.h"
#include "mpc_operators.h"
#include "mpc_protocol.h"
#include "core/privc3/circuit_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/privc3/aby3_context.h"
namespace paddle {
namespace mpc {
using CircuitContext = aby3::CircuitContext;
using ABY3Context = aby3::ABY3Context;
class Aby3Protocol : public MpcProtocol {
public:
......@@ -46,14 +47,14 @@ public:
std::shared_ptr<AbstractNetwork> network() override;
std::shared_ptr<CircuitContext> mpc_context() override;
std::shared_ptr<AbstractContext> mpc_context() override;
private:
bool _is_initialized = false;
const std::string PROT_INIT_ERR = "The protocol is not yet initialized.";
std::shared_ptr<MpcOperators> _operators;
std::shared_ptr<AbstractNetwork> _network;
std::shared_ptr<CircuitContext> _circuit_ctx;
std::shared_ptr<AbstractContext> _circuit_ctx;
};
} // mpc
......
......@@ -24,7 +24,7 @@
namespace paddle {
namespace mpc {
thread_local std::shared_ptr<CircuitContext> ContextHolder::current_mpc_ctx;
thread_local std::shared_ptr<AbstractContext> ContextHolder::current_mpc_ctx;
thread_local const ExecutionContext *ContextHolder::current_exec_ctx;
......
......@@ -22,20 +22,20 @@
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/aby3_context.h"
#include "core/privc3/paddle_tensor.h"
namespace paddle {
namespace mpc {
using CircuitContext = aby3::CircuitContext;
using ABY3Context = aby3::ABY3Context;
using ExecutionContext = paddle::framework::ExecutionContext;
class ContextHolder {
public:
template <typename Operation>
static void run_with_context(const ExecutionContext *exec_ctx,
std::shared_ptr<CircuitContext> mpc_ctx,
std::shared_ptr<AbstractContext> mpc_ctx,
Operation op) {
// set new ctxs
......@@ -60,7 +60,7 @@ public:
_s_current_tensor_factory = old_factory;
}
static std::shared_ptr<CircuitContext> mpc_ctx() { return current_mpc_ctx; }
static std::shared_ptr<AbstractContext> mpc_ctx() { return current_mpc_ctx; }
static const ExecutionContext *exec_ctx() { return current_exec_ctx; }
......@@ -77,7 +77,7 @@ public:
}
private:
thread_local static std::shared_ptr<CircuitContext> current_mpc_ctx;
thread_local static std::shared_ptr<AbstractContext> current_mpc_ctx;
thread_local static const ExecutionContext *current_exec_ctx;
......
......@@ -19,7 +19,6 @@
#include "aby3_protocol.h"
#include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h"
namespace paddle {
......
......@@ -21,7 +21,7 @@
#include "gloo/rendezvous/hash_store.h"
#include "mpc_config.h"
#include "mpc_operators.h"
#include "core/privc3/circuit_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
namespace paddle {
namespace mpc {
......@@ -44,7 +44,7 @@ public:
virtual std::shared_ptr<AbstractNetwork> network() = 0;
virtual std::shared_ptr<aby3::CircuitContext> mpc_context() = 0;
virtual std::shared_ptr<AbstractContext> mpc_context() = 0;
private:
const std::string _name;
......
......@@ -17,7 +17,6 @@
#include "aby3_protocol.h"
#include "mpc_config.h"
#include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h"
namespace paddle {
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
namespace paddle {
namespace operators {
......@@ -32,7 +32,7 @@ public:
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
std::shared_ptr<mpc::AbstractContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); });
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
class PrivCContext : public AbstractContext {
public:
PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
block seed = g_zero_block):
AbstractContext::AbstractContext(party, network) {
set_num_party(2);
if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom();
}
set_random_seed(seed, 0);
}
PrivCContext(const PrivCContext &other) = delete;
PrivCContext &operator=(const PrivCContext &other) = delete;
protected:
PseudorandomNumberGenerator& get_prng(size_t idx) override {
return _prng;
}
private:
PseudorandomNumberGenerator _prng;
};
} // namespace aby3
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
class ABY3Context : public AbstractContext {
public:
ABY3Context(size_t party, std::shared_ptr<AbstractNetwork> network,
block seed = g_zero_block,
block seed2 = g_zero_block) :
AbstractContext::AbstractContext(party, network) {
set_num_party(3);
if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom();
}
if (psi::equals(seed2, psi::g_zero_block)) {
seed2 = psi::block_from_dev_urandom();
}
set_random_seed(seed, 0);
// seed2 is private
set_random_seed(seed2, 2);
// 3 for 3-party computation
size_t party_pre = pre_party();
size_t party_next = next_party();
if (party == 1) {
block recv_seed = this->network()->template recv<block>(party_next);
this->network()->template send(party_pre, seed);
seed = recv_seed;
} else {
this->network()->template send(party_pre, seed);
seed = this->network()->template recv<block>(party_next);
}
set_random_seed(seed, 1);
}
ABY3Context(const ABY3Context &other) = delete;
ABY3Context &operator=(const ABY3Context &other) = delete;
protected:
PseudorandomNumberGenerator& get_prng(size_t idx) override {
return _prng[idx];
}
private:
PseudorandomNumberGenerator _prng[3];
};
} // namespace aby3
......@@ -122,7 +122,7 @@ public:
void onehot_from_cmp();
private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() {
static inline std::shared_ptr<AbstractContext> aby3_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
}
......
......@@ -15,6 +15,7 @@
#pragma once
#include <algorithm>
#include "core/privc3/ot.h"
namespace aby3 {
......@@ -268,7 +269,7 @@ void BooleanTensor<T>::ppa(const BooleanTensor* rhs,
}
template<typename T, size_t N>
void a2b(CircuitContext* aby3_ctx,
void a2b(AbstractContext* aby3_ctx,
TensorAdapterFactory* tensor_factory,
const FixedPointTensor<T, N>* a,
BooleanTensor<T>* b,
......@@ -432,7 +433,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
m[0]->add(tmp[0], m[0]);
m[1]->add(tmp[0], m[1]);
aby3_ctx()->template ot(idx0, idx1, idx2, null_arg[0],
ObliviousTransfer::ot(idx0, idx1, idx2, null_arg[0],
const_cast<const aby3::TensorAdapter<T>**>(m),
tmp, null_arg[0]);
......@@ -445,7 +446,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = s1
aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(0)));
// ret1 = a * b + s0
aby3_ctx()->template ot(idx0, idx1, idx2, share(1),
ObliviousTransfer::ot(idx0, idx1, idx2, share(1),
const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, ret->mutable_share(1));
aby3_ctx()->network()->template send(idx0, *(ret->share(0)));
......@@ -454,7 +455,7 @@ void BooleanTensor<T>::mul(const TensorAdapter<T>* rhs,
// ret0 = a * b + s0
aby3_ctx()->template gen_zero_sharing_arithmetic(*(ret->mutable_share(1)));
// ret1 = s2
aby3_ctx()->template ot(idx0, idx1, idx2, share(0),
ObliviousTransfer::ot(idx0, idx1, idx2, share(0),
const_cast<const aby3::TensorAdapter<T>**>(null_arg),
tmp, null_arg[0]);
......
......@@ -27,19 +27,20 @@
#include "boolean_tensor.h"
#include "fixedpoint_tensor.h"
#include "paddle_tensor.h"
#include "circuit_context.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
namespace aby3 {
using paddle::framework::Tensor;
using AbstractContext = paddle::mpc::AbstractContext;
class BooleanTensorTest : public ::testing::Test {
public:
paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<CircuitContext> _mpc_ctx[3];
std::shared_ptr<AbstractContext> _mpc_ctx[3];
std::shared_ptr<gloo::rendezvous::HashStore> _store;
......@@ -83,7 +84,7 @@ public:
void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<CircuitContext>(idx, net);
_mpc_ctx[idx] = std::make_shared<ABY3Context>(idx, net);
}
std::shared_ptr<TensorAdapter<int64_t>> gen1() {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
class CircuitContext {
public:
CircuitContext(size_t party,
std::shared_ptr<AbstractNetwork> network,
const block& seed = g_zero_block,
const block& seed2 = g_zero_block) {
init(party, network, seed, seed2);
}
CircuitContext(const CircuitContext& other) = delete;
CircuitContext& operator=(const CircuitContext& other) = delete;
void init(size_t party,
std::shared_ptr<AbstractNetwork> network,
block seed,
block seed2) {
set_party(party);
set_network(network);
if (equals(seed, g_zero_block)) {
seed = block_from_dev_urandom();
}
if (equals(seed2, g_zero_block)) {
seed2 = block_from_dev_urandom();
}
set_random_seed(seed, 0);
// seed2 is private
set_random_seed(seed2, 2);
// 3 for 3-party computation
size_t party_pre = (this->party() - 1 + 3) % 3;
size_t party_next = (this->party() + 1) % 3;
if (party == 1) {
block recv_seed = this->network()->template recv<block>(party_next);
this->network()->template send(party_pre, seed);
seed = recv_seed;
} else {
this->network()->template send(party_pre, seed);
seed = this->network()->template recv<block>(party_next);
}
set_random_seed(seed, 1);
}
void set_party(size_t party) {
if (party >= 3) {
// exception handling
}
_party = party;
}
void set_network(std::shared_ptr<AbstractNetwork> network) {
_network = network;
}
AbstractNetwork* network() {
return _network.get();
}
void set_random_seed(const block& seed, size_t idx) {
if (idx >= 3) {
// exception handling
}
_prng[idx].set_seed(seed);
}
size_t party() const {
return _party;
}
size_t pre_party() const {
return (_party + 3 - 1) % 3;
}
size_t next_party() const {
return (_party + 1) % 3;
}
template <typename T>
T gen_random(bool next) {
return _prng[next].get<T>();
}
template<typename T, template <typename> class Tensor>
void gen_random(Tensor<T>& tensor, bool next) {
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this, next](T& val) {
val = this->template gen_random<T>(next);
});
}
template <typename T>
T gen_random_private() {
return _prng[2].get<T>();
}
template<typename T, template <typename> class Tensor>
void gen_random_private(Tensor<T>& tensor) {
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T& val) {
val = this->template gen_random_private<T>();
});
}
template <typename T>
T gen_zero_sharing_arithmetic() {
return _prng[0].get<T>() - _prng[1].get<T>();
}
template<typename T, template <typename> class Tensor>
void gen_zero_sharing_arithmetic(Tensor<T>& tensor) {
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T& val) {
val = this->template gen_zero_sharing_arithmetic<T>();
});
}
template <typename T>
T gen_zero_sharing_boolean() {
return _prng[0].get<T>() ^ _prng[1].get<T>();
}
template<typename T, template <typename> class Tensor>
void gen_zero_sharing_boolean(Tensor<T>& tensor) {
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T& val) {
val = this->template gen_zero_sharing_boolean<T>();
});
}
template<typename T, template <typename> class Tensor>
void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
const size_t numel = buffer[0]->numel();
if (party() == sender) {
bool common = helper == next_party();
this->template gen_random(*buffer[0], common);
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] ^= m[0]->data()[i];
buffer[1]->data()[i] ^= m[1]->data()[i];
}
network()->template send(receiver, *buffer[0]);
network()->template send(receiver, *buffer[1]);
} else if (party() == helper) {
bool common = sender == next_party();
this->template gen_random(*buffer[0], common);
this->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
network()->template send(receiver, *buffer[0]);
} else if (party() == receiver) {
network()->template recv(sender, *buffer[0]);
network()->template recv(sender, *buffer[1]);
network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;}
);
}
}
private:
size_t _party;
std::shared_ptr<AbstractNetwork> _network;
PseudorandomNumberGenerator _prng[3];
};
} // namespace aby3
......@@ -16,7 +16,9 @@
#include <vector>
#include "circuit_context.h"
#include "boolean_tensor.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "paddle_tensor.h"
#include "boolean_tensor.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
......@@ -195,8 +197,7 @@ public:
size_t scaling_factor);
private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() {
static inline std::shared_ptr<AbstractContext> aby3_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
}
......
......@@ -20,21 +20,23 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "fixedpoint_tensor.h"
namespace aby3 {
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N16 = FixedPointTensor<int64_t, 16>;
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N16 = FixedPointTensor<int64_t, 16>;
using AbstractContext = paddle::mpc::AbstractContext;
class FixedTensorTest : public ::testing::Test {
public:
paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<CircuitContext> _mpc_ctx[3];
std::shared_ptr<AbstractContext> _mpc_ctx[3];
std::shared_ptr<gloo::rendezvous::HashStore> _store;
std::thread _t[3];
std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
......@@ -71,7 +73,7 @@ public:
void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<CircuitContext>(idx, net);
_mpc_ctx[idx] = std::make_shared<ABY3Context>(idx, net);
}
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
namespace aby3 {
class ObliviousTransfer {
public:
template<typename T, template <typename> class Tensor>
static inline void ot(size_t sender, size_t receiver, size_t helper,
const Tensor<T>* choice, const Tensor<T>* m[2],
Tensor<T>* buffer[2], Tensor<T>* ret) {
// TODO: check tensor shape equals
auto aby3_ctx = paddle::mpc::ContextHolder::mpc_ctx();
const size_t numel = buffer[0]->numel();
if (aby3_ctx->party() == sender) {
bool common = helper == aby3_ctx->next_party();
aby3_ctx->template gen_random(*buffer[0], common);
aby3_ctx->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] ^= m[0]->data()[i];
buffer[1]->data()[i] ^= m[1]->data()[i];
}
aby3_ctx->network()->template send(receiver, *buffer[0]);
aby3_ctx->network()->template send(receiver, *buffer[1]);
} else if (aby3_ctx->party() == helper) {
bool common = sender == aby3_ctx->next_party();
aby3_ctx->template gen_random(*buffer[0], common);
aby3_ctx->template gen_random(*buffer[1], common);
for (size_t i = 0; i < numel; ++i) {
buffer[0]->data()[i] = choice->data()[i] & 1 ?
buffer[1]->data()[i] : buffer[0]->data()[i];
}
aby3_ctx->network()->template send(receiver, *buffer[0]);
} else if (aby3_ctx->party() == receiver) {
aby3_ctx->network()->template recv(sender, *buffer[0]);
aby3_ctx->network()->template recv(sender, *buffer[1]);
aby3_ctx->network()->template recv(helper, *ret);
size_t i = 0;
std::for_each(ret->data(), ret->data() + numel, [&buffer, &i, choice, ret](T& in) {
bool c = choice->data()[i] & 1;
in ^= buffer[c]->data()[i];
++i;}
);
}
}
};
} // namespace aby3
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册