提交 1228e002 编写于 作者: Y yangqingyou

refactor circuit context and add privc context

上级 ab32cf3d
...@@ -18,63 +18,46 @@ ...@@ -18,63 +18,46 @@
#include <memory> #include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" #include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h" #include "core/privc3/prng_utils.h"
#include "paddle/fluid/platform/enforce.h"
namespace aby3 { namespace paddle {
using AbstractNetwork = paddle::mpc::AbstractNetwork; namespace mpc {
class CircuitContext { using block = psi::block;
using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator;
class AbstractContext {
public: public:
CircuitContext(size_t party, std::shared_ptr<AbstractNetwork> network, /*
const block &seed = g_zero_block, AbstractContext(size_t party, std::shared_ptr<AbstractNetwork> network,
const block &seed2 = g_zero_block) { const block &seed = psi::g_zero_block,
const block &seed2 = psi::g_zero_block) {
init(party, network, seed, seed2); init(party, network, seed, seed2);
} }
*/
AbstractContext() = default;
AbstractContext(const AbstractContext &other) = delete;
CircuitContext(const CircuitContext &other) = delete; AbstractContext &operator=(const AbstractContext &other) = delete;
CircuitContext &operator=(const CircuitContext &other) = delete;
void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
block seed2) {
set_party(party);
set_network(network);
if (equals(seed, g_zero_block)) {
seed = block_from_dev_urandom();
}
if (equals(seed2, g_zero_block)) {
seed2 = block_from_dev_urandom();
}
set_random_seed(seed, 0);
// seed2 is private
set_random_seed(seed2, 2);
// 3 for 3-party computation
size_t party_pre = (this->party() - 1 + 3) % 3;
size_t party_next = (this->party() + 1) % 3;
if (party == 1) {
block recv_seed = this->network()->template recv<block>(party_next);
this->network()->template send(party_pre, seed);
seed = recv_seed;
} else {
this->network()->template send(party_pre, seed);
seed = this->network()->template recv<block>(party_next);
}
set_random_seed(seed, 1); virtual void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
} block seed2) = 0;
void set_party(size_t party) { void set_party(size_t party) {
if (party >= 3) { PADDLE_ENFORCE_LT(party, _num_party,
// exception handling "party idx should less than %d.",
} _num_party);
_party = party; _party = party;
} }
void set_num_party(size_t num_party) {
PADDLE_ENFORCE_TRUE(num_party == 2 || num_party == 3,
"2 or 3 party protocol is supported.");
_num_party = num_party;
}
void set_network(std::shared_ptr<AbstractNetwork> network) { void set_network(std::shared_ptr<AbstractNetwork> network) {
_network = network; _network = network;
} }
...@@ -82,22 +65,24 @@ public: ...@@ -82,22 +65,24 @@ public:
AbstractNetwork *network() { return _network.get(); } AbstractNetwork *network() { return _network.get(); }
void set_random_seed(const block &seed, size_t idx) { void set_random_seed(const block &seed, size_t idx) {
if (idx >= 3) { PADDLE_ENFORCE_LE(idx, _num_party,
// exception handling "prng idx should be less and equal to %d.",
} _num_party);
_prng[idx].set_seed(seed); _prng[idx].set_seed(seed);
} }
size_t party() const { return _party; } size_t party() const { return _party; }
size_t pre_party() const { return (_party + 3 - 1) % 3; } size_t pre_party() const { return (_party + _num_party - 1) % _num_party; }
size_t next_party() const { return (_party + 1) % 3; } size_t next_party() const { return (_party + 1) % _num_party; }
template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); } template <typename T> T gen_random(bool next) { return _prng[next].get<T>(); }
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_random(Tensor<T> &tensor, bool next) { void gen_random(Tensor<T> &tensor, bool next) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_random` API is for 3 party protocol.");
std::for_each( std::for_each(
tensor.data(), tensor.data() + tensor.numel(), tensor.data(), tensor.data() + tensor.numel(),
[this, next](T &val) { val = this->template gen_random<T>(next); }); [this, next](T &val) { val = this->template gen_random<T>(next); });
...@@ -113,11 +98,15 @@ public: ...@@ -113,11 +98,15 @@ public:
} }
template <typename T> T gen_zero_sharing_arithmetic() { template <typename T> T gen_zero_sharing_arithmetic() {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
return _prng[0].get<T>() - _prng[1].get<T>(); return _prng[0].get<T>() - _prng[1].get<T>();
} }
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_zero_sharing_arithmetic(Tensor<T> &tensor) { void gen_zero_sharing_arithmetic(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_arithmetic` API is for 3 party protocol.");
std::for_each(tensor.data(), tensor.data() + tensor.numel(), std::for_each(tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { [this](T &val) {
val = this->template gen_zero_sharing_arithmetic<T>(); val = this->template gen_zero_sharing_arithmetic<T>();
...@@ -125,11 +114,15 @@ public: ...@@ -125,11 +114,15 @@ public:
} }
template <typename T> T gen_zero_sharing_boolean() { template <typename T> T gen_zero_sharing_boolean() {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
return _prng[0].get<T>() ^ _prng[1].get<T>(); return _prng[0].get<T>() ^ _prng[1].get<T>();
} }
template <typename T, template <typename> class Tensor> template <typename T, template <typename> class Tensor>
void gen_zero_sharing_boolean(Tensor<T> &tensor) { void gen_zero_sharing_boolean(Tensor<T> &tensor) {
PADDLE_ENFORCE_EQ(_num_party, 3,
"`gen_zero_sharing_boolean` API is for 3 party protocol.");
std::for_each( std::for_each(
tensor.data(), tensor.data() + tensor.numel(), tensor.data(), tensor.data() + tensor.numel(),
[this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); }); [this](T &val) { val = this->template gen_zero_sharing_boolean<T>(); });
...@@ -140,6 +133,8 @@ public: ...@@ -140,6 +133,8 @@ public:
const Tensor<T> *choice, const Tensor<T> *m[2], Tensor<T> *buffer[2], const Tensor<T> *choice, const Tensor<T> *m[2], Tensor<T> *buffer[2],
Tensor<T> *ret) { Tensor<T> *ret) {
// TODO: check tensor shape equals // TODO: check tensor shape equals
PADDLE_ENFORCE_EQ(_num_party, 3,
"`ot` API is for 3 party protocol.");
const size_t numel = buffer[0]->numel(); const size_t numel = buffer[0]->numel();
if (party() == sender) { if (party() == sender) {
bool common = helper == next_party(); bool common = helper == next_party();
...@@ -180,6 +175,7 @@ public: ...@@ -180,6 +175,7 @@ public:
} }
private: private:
size_t _num_party;
size_t _party; size_t _party;
std::shared_ptr<AbstractNetwork> _network; std::shared_ptr<AbstractNetwork> _network;
...@@ -187,4 +183,6 @@ private: ...@@ -187,4 +183,6 @@ private:
PseudorandomNumberGenerator _prng[3]; PseudorandomNumberGenerator _prng[3];
}; };
} // namespace aby3 } // namespace mpc
} //namespace paddle
...@@ -22,7 +22,7 @@ ...@@ -22,7 +22,7 @@
#include "mpc_operators.h" #include "mpc_operators.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "core/privc3/boolean_tensor.h" #include "core/privc3/boolean_tensor.h"
#include "core/privc3/circuit_context.h" #include "core/privc3/aby3_context.h"
#include "core/privc3/fixedpoint_tensor.h" #include "core/privc3/fixedpoint_tensor.h"
#include "core/privc3/paddle_tensor.h" #include "core/privc3/paddle_tensor.h"
...@@ -30,7 +30,7 @@ namespace paddle { ...@@ -30,7 +30,7 @@ namespace paddle {
namespace mpc { namespace mpc {
using paddle::framework::Tensor; using paddle::framework::Tensor;
using aby3::CircuitContext; using aby3::ABY3Context;
// TODO: decide scaling factor // TODO: decide scaling factor
const size_t ABY3_SCALING_FACTOR = 16; const size_t ABY3_SCALING_FACTOR = 16;
using FixedTensor = aby3::FixedPointTensor<int64_t, ABY3_SCALING_FACTOR>; using FixedTensor = aby3::FixedPointTensor<int64_t, ABY3_SCALING_FACTOR>;
......
...@@ -48,7 +48,7 @@ void Aby3Protocol::init_with_store( ...@@ -48,7 +48,7 @@ void Aby3Protocol::init_with_store(
mesh_net->init(); mesh_net->init();
_network = std::move(mesh_net); _network = std::move(mesh_net);
_circuit_ctx = std::make_shared<CircuitContext>(role, _network); _circuit_ctx = std::make_shared<ABY3Context>(role, _network);
_operators = std::make_shared<Aby3OperatorsImpl>(); _operators = std::make_shared<Aby3OperatorsImpl>();
_is_initialized = true; _is_initialized = true;
} }
...@@ -63,7 +63,7 @@ std::shared_ptr<AbstractNetwork> Aby3Protocol::network() { ...@@ -63,7 +63,7 @@ std::shared_ptr<AbstractNetwork> Aby3Protocol::network() {
return _network; return _network;
} }
std::shared_ptr<CircuitContext> Aby3Protocol::mpc_context() { std::shared_ptr<AbstractContext> Aby3Protocol::mpc_context() {
PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR); PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR);
return _circuit_ctx; return _circuit_ctx;
} }
......
...@@ -24,12 +24,13 @@ ...@@ -24,12 +24,13 @@
#include "mesh_network.h" #include "mesh_network.h"
#include "mpc_operators.h" #include "mpc_operators.h"
#include "mpc_protocol.h" #include "mpc_protocol.h"
#include "core/privc3/circuit_context.h" #include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/privc3/aby3_context.h"
namespace paddle { namespace paddle {
namespace mpc { namespace mpc {
using CircuitContext = aby3::CircuitContext; using ABY3Context = aby3::ABY3Context;
class Aby3Protocol : public MpcProtocol { class Aby3Protocol : public MpcProtocol {
public: public:
...@@ -46,14 +47,14 @@ public: ...@@ -46,14 +47,14 @@ public:
std::shared_ptr<AbstractNetwork> network() override; std::shared_ptr<AbstractNetwork> network() override;
std::shared_ptr<CircuitContext> mpc_context() override; std::shared_ptr<AbstractContext> mpc_context() override;
private: private:
bool _is_initialized = false; bool _is_initialized = false;
const std::string PROT_INIT_ERR = "The protocol is not yet initialized."; const std::string PROT_INIT_ERR = "The protocol is not yet initialized.";
std::shared_ptr<MpcOperators> _operators; std::shared_ptr<MpcOperators> _operators;
std::shared_ptr<AbstractNetwork> _network; std::shared_ptr<AbstractNetwork> _network;
std::shared_ptr<CircuitContext> _circuit_ctx; std::shared_ptr<AbstractContext> _circuit_ctx;
}; };
} // mpc } // mpc
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
namespace paddle { namespace paddle {
namespace mpc { namespace mpc {
thread_local std::shared_ptr<CircuitContext> ContextHolder::current_mpc_ctx; thread_local std::shared_ptr<AbstractContext> ContextHolder::current_mpc_ctx;
thread_local const ExecutionContext *ContextHolder::current_exec_ctx; thread_local const ExecutionContext *ContextHolder::current_exec_ctx;
......
...@@ -22,20 +22,20 @@ ...@@ -22,20 +22,20 @@
#pragma once #pragma once
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "core/privc3/circuit_context.h" #include "core/privc3/aby3_context.h"
#include "core/privc3/paddle_tensor.h" #include "core/privc3/paddle_tensor.h"
namespace paddle { namespace paddle {
namespace mpc { namespace mpc {
using CircuitContext = aby3::CircuitContext; using ABY3Context = aby3::ABY3Context;
using ExecutionContext = paddle::framework::ExecutionContext; using ExecutionContext = paddle::framework::ExecutionContext;
class ContextHolder { class ContextHolder {
public: public:
template <typename Operation> template <typename Operation>
static void run_with_context(const ExecutionContext *exec_ctx, static void run_with_context(const ExecutionContext *exec_ctx,
std::shared_ptr<CircuitContext> mpc_ctx, std::shared_ptr<AbstractContext> mpc_ctx,
Operation op) { Operation op) {
// set new ctxs // set new ctxs
...@@ -60,7 +60,7 @@ public: ...@@ -60,7 +60,7 @@ public:
_s_current_tensor_factory = old_factory; _s_current_tensor_factory = old_factory;
} }
static std::shared_ptr<CircuitContext> mpc_ctx() { return current_mpc_ctx; } static std::shared_ptr<AbstractContext> mpc_ctx() { return current_mpc_ctx; }
static const ExecutionContext *exec_ctx() { return current_exec_ctx; } static const ExecutionContext *exec_ctx() { return current_exec_ctx; }
...@@ -77,7 +77,7 @@ public: ...@@ -77,7 +77,7 @@ public:
} }
private: private:
thread_local static std::shared_ptr<CircuitContext> current_mpc_ctx; thread_local static std::shared_ptr<AbstractContext> current_mpc_ctx;
thread_local static const ExecutionContext *current_exec_ctx; thread_local static const ExecutionContext *current_exec_ctx;
......
...@@ -19,7 +19,6 @@ ...@@ -19,7 +19,6 @@
#include "aby3_protocol.h" #include "aby3_protocol.h"
#include "mpc_protocol_factory.h" #include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
namespace paddle { namespace paddle {
......
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "gloo/rendezvous/hash_store.h" #include "gloo/rendezvous/hash_store.h"
#include "mpc_config.h" #include "mpc_config.h"
#include "mpc_operators.h" #include "mpc_operators.h"
#include "core/privc3/circuit_context.h" #include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
namespace paddle { namespace paddle {
namespace mpc { namespace mpc {
...@@ -44,7 +44,7 @@ public: ...@@ -44,7 +44,7 @@ public:
virtual std::shared_ptr<AbstractNetwork> network() = 0; virtual std::shared_ptr<AbstractNetwork> network() = 0;
virtual std::shared_ptr<aby3::CircuitContext> mpc_context() = 0; virtual std::shared_ptr<AbstractContext> mpc_context() = 0;
private: private:
const std::string _name; const std::string _name;
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
#include "aby3_protocol.h" #include "aby3_protocol.h"
#include "mpc_config.h" #include "mpc_config.h"
#include "mpc_protocol_factory.h" #include "mpc_protocol_factory.h"
#include "core/privc3/circuit_context.h"
#include "gtest/gtest.h" #include "gtest/gtest.h"
namespace paddle { namespace paddle {
......
...@@ -19,7 +19,7 @@ limitations under the License. */ ...@@ -19,7 +19,7 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "core/privc3/circuit_context.h" #include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -32,7 +32,7 @@ public: ...@@ -32,7 +32,7 @@ public:
PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(),
"Mpc protocol is not yet initialized in executor"); "Mpc protocol is not yet initialized in executor");
std::shared_ptr<aby3::CircuitContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); std::shared_ptr<mpc::AbstractContext> mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context());
mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx,
[&] { ComputeImpl(ctx); }); [&] { ComputeImpl(ctx); });
} }
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
// Context for the 2-party PrivC protocol.
// Only the private PRNG seed (slot 2) is initialized; PrivC does not use the
// pairwise-shared seeds of the 3-party (ABY3) setting.
class PrivCContext : public AbstractContext {
public:
  // party:   this party's index (0 or 1)
  // network: transport used to reach the peer
  // seed:    this party's private seed; a fresh one is drawn from
  //          /dev/urandom when left as the zero block
  PrivCContext(size_t party, std::shared_ptr<AbstractNetwork> network,
               const block &seed = g_zero_block) {
    // The first (shared) seed slot is unused by PrivC, hence g_zero_block.
    init(party, network, g_zero_block, seed);
  }

  PrivCContext(const PrivCContext &other) = delete;

  PrivCContext &operator=(const PrivCContext &other) = delete;

  // Initializes party count, party id, network and the private PRNG seed.
  // `seed` is ignored in the 2-party protocol; `seed2` is this party's
  // private seed.
  void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
            block seed2) override {
    // set_num_party must run before set_party: set_party validates the
    // party index against _num_party, which is otherwise uninitialized.
    set_num_party(2);
    set_party(party);
    set_network(network);

    if (psi::equals(seed2, psi::g_zero_block)) {
      seed2 = psi::block_from_dev_urandom();
    }

    // seed2 is private
    set_random_seed(seed2, 2);
  }
};
} // namespace aby3
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <algorithm>
#include <memory>
#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h"
#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h"
#include "prng_utils.h"
namespace aby3 {
using AbstractNetwork = paddle::mpc::AbstractNetwork;
using AbstractContext = paddle::mpc::AbstractContext;
// Context for the 3-party ABY3 protocol.
// Seats three PRNGs: slot 0 is shared with the previous party, slot 1 with
// the next party (exchanged over the network during init), and slot 2 is
// this party's private seed.
class ABY3Context : public AbstractContext {
public:
  // party:   this party's index (0, 1 or 2)
  // network: transport used to reach the other two parties
  // seed:    seed shared with the neighbors; drawn from /dev/urandom when
  //          left as the zero block
  // seed2:   this party's private seed; drawn from /dev/urandom when left
  //          as the zero block
  ABY3Context(size_t party, std::shared_ptr<AbstractNetwork> network,
              const block &seed = g_zero_block,
              const block &seed2 = g_zero_block) {
    init(party, network, seed, seed2);
  }

  ABY3Context(const ABY3Context &other) = delete;

  ABY3Context &operator=(const ABY3Context &other) = delete;

  // Initializes party count, party id, network, and the three PRNG seeds,
  // exchanging the shared seed with the neighboring parties.
  void init(size_t party, std::shared_ptr<AbstractNetwork> network, block seed,
            block seed2) override {
    // set_num_party must run before set_party: set_party validates the
    // party index against _num_party, which is otherwise uninitialized.
    set_num_party(3);
    set_party(party);
    set_network(network);

    if (psi::equals(seed, psi::g_zero_block)) {
      seed = psi::block_from_dev_urandom();
    }
    if (psi::equals(seed2, psi::g_zero_block)) {
      seed2 = psi::block_from_dev_urandom();
    }
    set_random_seed(seed, 0);
    // seed2 is private
    set_random_seed(seed2, 2);

    // Ring-exchange the shared seed so each pair of adjacent parties ends
    // up with a common seed in slot 1.
    size_t party_pre = pre_party();
    size_t party_next = next_party();
    if (party == 1) {
      // Party 1 receives first to break the send/recv cycle and avoid
      // deadlock among the three parties.
      block recv_seed = this->network()->template recv<block>(party_next);
      this->network()->template send(party_pre, seed);
      seed = recv_seed;
    } else {
      this->network()->template send(party_pre, seed);
      seed = this->network()->template recv<block>(party_next);
    }
    set_random_seed(seed, 1);
  }
};
} // namespace aby3
...@@ -116,7 +116,7 @@ public: ...@@ -116,7 +116,7 @@ public:
void bit_extract(size_t i, BooleanTensor *ret) const; void bit_extract(size_t i, BooleanTensor *ret) const;
private: private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() { static inline std::shared_ptr<AbstractContext> aby3_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx(); return paddle::mpc::ContextHolder::mpc_ctx();
} }
......
...@@ -258,7 +258,7 @@ void BooleanTensor<T>::ppa(const BooleanTensor *rhs, BooleanTensor *ret, ...@@ -258,7 +258,7 @@ void BooleanTensor<T>::ppa(const BooleanTensor *rhs, BooleanTensor *ret,
} }
template <typename T, size_t N> template <typename T, size_t N>
void a2b(CircuitContext *aby3_ctx, TensorAdapterFactory *tensor_factory, void a2b(AbstractContext *aby3_ctx, TensorAdapterFactory *tensor_factory,
const FixedPointTensor<T, N> *a, BooleanTensor<T> *b, size_t n_bits) { const FixedPointTensor<T, N> *a, BooleanTensor<T> *b, size_t n_bits) {
std::shared_ptr<TensorAdapter<T>> tmp[4]; std::shared_ptr<TensorAdapter<T>> tmp[4];
......
...@@ -27,19 +27,20 @@ ...@@ -27,19 +27,20 @@
#include "boolean_tensor.h" #include "boolean_tensor.h"
#include "fixedpoint_tensor.h" #include "fixedpoint_tensor.h"
#include "paddle_tensor.h" #include "paddle_tensor.h"
#include "circuit_context.h" #include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
namespace aby3 { namespace aby3 {
using paddle::framework::Tensor; using paddle::framework::Tensor;
using AbstractContext = paddle::mpc::AbstractContext;
class BooleanTensorTest : public ::testing::Test { class BooleanTensorTest : public ::testing::Test {
public: public:
paddle::platform::CPUDeviceContext _cpu_ctx; paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx; std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<CircuitContext> _mpc_ctx[3]; std::shared_ptr<AbstractContext> _mpc_ctx[3];
std::shared_ptr<gloo::rendezvous::HashStore> _store; std::shared_ptr<gloo::rendezvous::HashStore> _store;
...@@ -83,7 +84,7 @@ public: ...@@ -83,7 +84,7 @@ public:
void gen_mpc_ctx(size_t idx) { void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx); auto net = gen_network(idx);
net->init(); net->init();
_mpc_ctx[idx] = std::make_shared<CircuitContext>(idx, net); _mpc_ctx[idx] = std::make_shared<ABY3Context>(idx, net);
} }
std::shared_ptr<TensorAdapter<int64_t>> gen1() { std::shared_ptr<TensorAdapter<int64_t>> gen1() {
......
...@@ -17,7 +17,7 @@ ...@@ -17,7 +17,7 @@
#include <vector> #include <vector>
#include "boolean_tensor.h" #include "boolean_tensor.h"
#include "circuit_context.h" #include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "paddle_tensor.h" #include "paddle_tensor.h"
...@@ -139,7 +139,7 @@ public: ...@@ -139,7 +139,7 @@ public:
void neq(const CTensor<T, N1...> *rhs, BooleanTensor<T> *ret) const; void neq(const CTensor<T, N1...> *rhs, BooleanTensor<T> *ret) const;
private: private:
static inline std::shared_ptr<CircuitContext> aby3_ctx() { static inline std::shared_ptr<AbstractContext> aby3_ctx() {
return paddle::mpc::ContextHolder::mpc_ctx(); return paddle::mpc::ContextHolder::mpc_ctx();
} }
......
...@@ -16,21 +16,23 @@ limitations under the License. */ ...@@ -16,21 +16,23 @@ limitations under the License. */
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "fixedpoint_tensor.h" #include "fixedpoint_tensor.h"
namespace aby3 { namespace aby3 {
using g_ctx_holder = paddle::mpc::ContextHolder; using g_ctx_holder = paddle::mpc::ContextHolder;
using Fix64N16 = FixedPointTensor<int64_t, 16>; using Fix64N16 = FixedPointTensor<int64_t, 16>;
using AbstractContext = paddle::mpc::AbstractContext;
class FixedTensorTest : public ::testing::Test { class FixedTensorTest : public ::testing::Test {
public: public:
paddle::platform::CPUDeviceContext _cpu_ctx; paddle::platform::CPUDeviceContext _cpu_ctx;
std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx; std::shared_ptr<paddle::framework::ExecutionContext> _exec_ctx;
std::shared_ptr<CircuitContext> _mpc_ctx[3]; std::shared_ptr<AbstractContext> _mpc_ctx[3];
std::shared_ptr<gloo::rendezvous::HashStore> _store; std::shared_ptr<gloo::rendezvous::HashStore> _store;
std::thread _t[3]; std::thread _t[3];
std::shared_ptr<TensorAdapterFactory> _s_tensor_factory; std::shared_ptr<TensorAdapterFactory> _s_tensor_factory;
...@@ -67,7 +69,7 @@ public: ...@@ -67,7 +69,7 @@ public:
void gen_mpc_ctx(size_t idx) { void gen_mpc_ctx(size_t idx) {
auto net = gen_network(idx); auto net = gen_network(idx);
net->init(); net->init();
_mpc_ctx[idx] = std::make_shared<CircuitContext>(idx, net); _mpc_ctx[idx] = std::make_shared<ABY3Context>(idx, net);
} }
std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) { std::shared_ptr<TensorAdapter<int64_t>> gen(std::vector<size_t> shape) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册