提交 1228e002 编写于 作者: Y yangqingyou

refactor circuit context and add privc context

上级 ab32cf3d
......@@ -18,63 +18,46 @@
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
#include "core/privc3/prng_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace aby3 {
namespace paddle {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
namespace mpc {
class CircuitContext {
using block = psi::block;
using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator;
class AbstractContext {
public:
CircuitContext(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block,
const block &seed2 = g_zero_block) {
/*
AbstractContext(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = psi::g_zero_block,
const block &seed2 = psi::g_zero_block) {
init(party, network, seed, seed2);
}
*/
AbstractContext() = default;
AbstractContext(const AbstractContext &other) = delete;
CircuitContext(const CircuitContext &other) = delete;
CircuitContext &operator=(const CircuitContext &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) {
set_party(party);
set_network(network);
if (equals(seed, g_zero_block)) {
seed = block_from_dev_urandom();
}
if (equals(seed2, g_zero_block)) {
seed2 = block_from_dev_urandom();
}
set_random_seed(seed, 0);
// seed2 is private
set_random_seed(seed2, 2);
// 3 for 3-party computation
size_t party_pre = (this->party() - 1 + 3) % 3;
size_t party_next = (this->party() + 1) % 3;
if (party == 1) {
block recv_seed = this->network()->template recv<block>(party_next);
this->network()->template send(party_pre, seed);
seed = recv_seed;
} else {
this->network()->template send(party_pre, seed);
seed = this->network()->template recv<block>(party_next);
}
AbstractContext &operator=(const AbstractContext &other) = delete;
set_random_seed(seed, 1);
}
virtual void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) = 0;
void set_party(size_t party) {
if (party >= 3) {
// exception handling
}
PADDLE_ENFORCE_LT(party, _num_party,
"party idx should less than %d.",
_num_party);
_party = party;
}
void set_num_party(size_t num_party) {
PADDLE_ENFORCE_TRUE(num_party == 2 || num_party == 3,
"2 or 3 party protocol is supported.");
_num_party = num_party;
}
void set_network(std::shared_ptr<AbstractNetwork> network) {
_network = network;
}
......@@ -82,22 +65,24 @@ public:
AbstractNetwork *network() { return _network.get(); }
void set_random_seed(const block &seed, size_t idx) {
if (idx >= 3) {
// exception handling
}
PADDLE_ENFORCE_LE(idx, _num_party,
"prng idx should be less and equal to %d.",
_num_party);
_prng[idx].set_seed(seed);
}
size_t party() const { return _party; }
size_t pre_party() const { return (_party + 3 - 1) % 3; }
size_t pre_party() const { return (_party + _num_party - 1) % _num_party; }
size_t next_party() const { return (_party + 1) % 3; }
size_t next_party() const { return (_party + 1) % _num_party; }
template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); }
template <typename T, template <typename> class Tensor>
void gen_random(Tensor<T> &tensor, bool next) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_random` API is for 3 party protocol.");
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this, next](T &val) { val = this->template gen_random<T>(next); });
......@@ -113,11 +98,15 @@ public:
}
template <typename T> T gen_zero_sharing_arithmetic() {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
return _prng[0].get<T>() - _prng[1].get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_zero_sharing_arithmetic(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) {
val = this->template gen_zero_sharing_arithmetic<T>();
......@@ -125,11 +114,15 @@ public:
}
template <typename T> T gen_zero_sharing_boolean() {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
return _prng[0].get<T>() ^ _prng[1].get<T>();
}
template <typename T, template <typename> class Tensor>
void gen_zero_sharing_boolean(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
std::for_each(
tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
......@@ -140,6 +133,8 @@ public:
const Tensor<T> *choice, const Tensor<T> *m[2], Tensor<T> *buffer[2],
Tensor<T> *ret) {
// TODO: check tensor shape equals
PADDLE_ENFORCE_EQ(_num_party, 3,
"`ot` API is for 3 party protocol.");
const size_t numel = buffer[0]->numel();
if (party() == sender) {
bool common = helper == next_party();
......@@ -180,6 +175,7 @@ public:
}
private:
size_t _num_party;
size_t _party;
std::shared_ptr<AbstractNetwork> _network;
......@@ -187,4 +183,6 @@ private:
PseudorandomNumberGenerator _prng[3];
};
} // namespace aby3
} // namespace mpc
} //namespace paddle
......@@ -22,7 +22,7 @@
#include "mpc_operators.h"
#include "paddle/fluid/framework/tensor.h"
#include "core/privc3/boolean_tensor.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/aby3_context.h"
#include "core/privc3/fixedpoint_tensor.h"
#include "core/privc3/paddle_tensor.h"
......@@ -30,7 +30,7 @@ namespace paddle {
namespace mpc {
using paddle::framework::Tensor;
using aby3::CircuitContext;
using aby3::ABY3Context;
// TODO: decide scaling factor
const size_t ABY3_SCALING_FACTOR = 16;
using FixedTensor = aby3::FixedPointTensor<int64_t, ABY3_SCALING_FACTOR>;
......
......@@ -48,7 +48,7 @@ void Aby3Protocol::init_with_store(
mesh_net->init();
_network = std::move(mesh_net);
_circuit_ctx = std::make_shared<CircuitContext>(role, _network);
_circuit_ctx = std::make_shared<ABY3Context>(role, _network);
_operators = std::make_shared<Aby3OperatorsImpl>();
_is_initialized = true;
}
......@@ -63,7 +63,7 @@ std::shared_ptr<AbstractNetwork> Aby3Protocol::network() {
return _network;
}
std::shared_ptr<CircuitContext> Aby3Protocol::mpc_context() {
std::shared_ptr<AbstractContext> Aby3Protocol::mpc_context() {
PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR);
return _circuit_ctx;
}
......
......@@ -24,12 +24,13 @@
#include "mesh_network.h"
#include "mpc_operators.h"
#include "mpc_protocol.h"
#include "core/privc3/circuit_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/privc3/aby3_context.h"
namespace paddle {
namespace mpc {
using CircuitContext = aby3::CircuitContext;
using ABY3Context = aby3::ABY3Context;
class Aby3Protocol : public MpcProtocol {
public:
......@@ -46,14 +47,14 @@ public:
std::shared_ptr<AbstractNetwork> network() override;
std::shared_ptr<CircuitContext> mpc_context() override;
std::shared_ptr<AbstractContext> mpc_context() override;
private:
bool _is_initialized = false;
const std::string PROT_INIT_ERR = "The protocol is not yet initialized.";
std::shared_ptr<MpcOperators> _operators;
std::shared_ptr<AbstractNetwork> _network;
std::shared_ptr<CircuitContext> _circuit_ctx;
std::shared_ptr<AbstractContext> _circuit_ctx;
};
} // mpc
......
......@@ -24,7 +24,7 @@
namespace paddle {
namespace mpc {
thread_local std::shared_ptr<CircuitContext> ContextHolder::current_mpc_ctx;
thread_local std::shared_ptr<AbstractContext> ContextHolder::current_mpc_ctx;
thread_local const ExecutionContext *ContextHolder::current_exec_ctx;
......
......@@ -22,20 +22,20 @@
#pragma once
#include "paddle/fluid/framework/operator.h"
#include "core/privc3/circuit_context.h"
#include "core/privc3/aby3_context.h"
#include "core/privc3/paddle_tensor.h"
namespace paddle {
namespace mpc {
using CircuitContext = aby3::CircuitContext;
using ABY3Context = aby3::ABY3Context;
using ExecutionContext = paddle::framework::ExecutionContext;
class ContextHolder {
public:
template <typename Operation>
static void run_with_context(const ExecutionContext *exec_ctx,
std::shared_ptr<CircuitContext> mpc_ctx,
std::shared_ptr<AbstractContext> mpc_ctx,
Operation op) {
// set new ctxs
......@@ -60,7 +60,7 @@ public:
_s_current_tensor_factory = old_factory;
}
static std::shared_ptr<CircuitContext> mpc_ctx() { return current_mpc_ctx; }
static std::shared_ptr<AbstractContext> mpc_ctx() { return current_mpc_ctx; }
static const ExecutionContext *exec_ctx() { return current_exec_ctx; }
......@@ -77,7 +77,7 @@ public:
}
private:
thread_local static std::shared_ptr<CircuitContext> current_mpc_ctx;
thread_local static std::shared_ptr<AbstractContext> current_mpc_ctx;
thread_local static const ExecutionContext *current_exec_ctx;
......
......@@ -19,7 +19,6 @@
#include "aby3_protocol.h"
#include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h"
namespace paddle {
......
......@@ -21,7 +21,7 @@
#include "gloo/rendezvous/hash_store.h"
#include "mpc_config.h"
#include "mpc_operators.h"
#include "core/privc3/circuit_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
namespace paddle {
namespace mpc {
......@@ -44,7 +44,7 @@ public:
virtual std::shared_ptr<AbstractNetwork> network() = 0;
virtual std::shared_ptr<aby3::CircuitContext> mpc_context() = 0;
virtual std::shared_ptr<AbstractContext> mpc_context() = 0;
private:
const std::string _name;
......
......@@ -17,7 +17,6 @@
#include "aby3_protocol.h"
#include "mpc_config.h"
#include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h"
namespace paddle {
......
......@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
namespace paddle {
namespace operators {
......@@ -32,7 +32,7 @@ public:
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
std::shared_ptr<mpc::AbstractContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); });
}
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
class PrivCContext : public AbstractContext {
public:
PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block) {
init(party, network, g_zero_block, seed);
}
PrivCContext(const PrivCContext &other) = delete;
PrivCContext &operator=(const PrivCContext &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) override {
set_party(party);
set_network(network);
set_num_party(2);
if (psi::equals(seed2, psi::g_zero_block)) {
seed2 = psi::block_from_dev_urandom();
}
// seed2 is private
set_random_seed(seed2, 2);
}
};
} // namespace aby3
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
class ABY3Context : public AbstractContext {
public:
ABY3Context(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed = g_zero_block,
const block &seed2 = g_zero_block) {
init(party, network, seed, seed2);
}
ABY3Context(const ABY3Context &other) = delete;
ABY3Context &operator=(const ABY3Context &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) override {
set_party(party);
set_network(network);
set_num_party(3);
if (psi::equals(seed, psi::g_zero_block)) {
seed = psi::block_from_dev_urandom();
}
if (psi::equals(seed2, psi::g_zero_block)) {
seed2 = psi::block_from_dev_urandom();
}
set_random_seed(seed, 0);
// seed2 is private
set_random_seed(seed2, 2);
// 3 for 3-party computation
size_t party_pre = pre_party();
size_t party_next = next_party();
if (party == 1) {
block recv_seed = this->network()->template recv<block>(party_next);
this->network()->template send(party_pre, seed);
seed = recv_seed;
} else {
this->network()->template send(party_pre, seed);
seed = this->network()->template recv<block>(party_next);
}
set_random_seed(seed, 1);
}
};
} // namespace aby3
......@@ -116,7 +116,7 @@ public:
void bit_extract(size_t i, BooleanTensor *ret) const;
private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() {
static inline std::shared_ptr<AbstractContext> aby3_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
}
......
......@@ -258,7 +258,7 @@ void BooleanTensor<T>::ppa(const BooleanTensor *rhs, BooleanTensor *ret,
}
template <typename T, size_t N>
void a2b(CircuitContext *aby3_ctx, TensorAdapterFactory *tensor_factory,
void a2b(AbstractContext *aby3_ctx, TensorAdapterFactory *tensor_factory,
const FixedPointTensor<T, N> *a, BooleanTensor<T> *b, size_t n_bits) {
std::shared_ptr<TensorAdapter<T>> tmp[4];
......
......@@ -27,19 +27,20 @@
#include "boolean_tensor.h"
#include "fixedpoint_tensor.h"
#include "paddle_tensor.h"
#include "circuit_context.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
namespace aby3 {
using paddle::framework::Tensor;
using AbstractContext = paddle::mpc::AbstractContext;
class BooleanTensorTest : public ::testing::Test {
public:
paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<CircuitContext> _mpc_ctx[3];
std::shared_ptr<AbstractContext> _mpc_ctx[3];
std::shared_ptr<gloo::rendezvous::HashStore> _store;
......@@ -83,7 +84,7 @@ public:
void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<CircuitContext>(idx, net);
_mpc_ctx[idx] = std::make_shared<ABY3Context>(idx, net);
}
std::shared_ptr<TensorAdapter<int64_t>> gen1() {
......
......@@ -17,7 +17,7 @@
#include <vector>
#include "boolean_tensor.h"
#include "circuit_context.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "paddle_tensor.h"
......@@ -139,7 +139,7 @@ public:
void neq(const CTensor<T, N1...> *rhs, BooleanTensor<T> *ret) const;
private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() {
static inline std::shared_ptr<AbstractContext> aby3_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx();
}
......
......@@ -16,21 +16,23 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "fixedpoint_tensor.h"
namespace aby3 {
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N16 = FixedPointTensor<int64_t, 16>;
using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N16 = FixedPointTensor<int64_t, 16>;
using AbstractContext = paddle::mpc::AbstractContext;
class FixedTensorTest : public ::testing::Test {
public:
paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<CircuitContext> _mpc_ctx[3];
std::shared_ptr<AbstractContext> _mpc_ctx[3];
std::shared_ptr<gloo::rendezvous::HashStore> _store;
std::thread _t[3];
std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
......@@ -67,7 +69,7 @@ public:
void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx);
net->init();
_mpc_ctx[idx] = std::make_shared<CircuitContext>(idx, net);
_mpc_ctx[idx] = std::make_shared<ABY3Context>(idx, net);
}
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册