diff --git a/core/privc3/circuit_context.h b/core/paddlefl_mpc/mpc_protocol/abstract_context.h similarity index 68% rename from core/privc3/circuit_context.h rename to core/paddlefl_mpc/mpc_protocol/abstract_context.h index 26b7a7a8fd4ad54e625110ce63b265a962dd34ac..e144e205a244b0757b4c429356a7506eb43f7537 100644 --- a/core/privc3/circuit_context.h +++ b/core/paddlefl_mpc/mpc_protocol/abstract_context.h @@ -18,63 +18,46 @@ #include #include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" -#include "prng_utils.h" +#include "core/privc3/prng_utils.h" +#include "paddle/fluid/platform/enforce.h" -namespace aby3 { +namespace paddle { -using AbstractNetwork = paddle::mpc::AbstractNetwork; +namespace mpc { -class CircuitContext { +using block = psi::block; +using PseudorandomNumberGenerator = psi::PseudorandomNumberGenerator; + +class AbstractContext { public: - CircuitContext(size_t party, std::shared_ptr network, - const block &seed = g_zero_block, - const block &seed2 = g_zero_block) { +/* + AbstractContext(size_t party, std::shared_ptr network, + const block &seed = psi::g_zero_block, + const block &seed2 = psi::g_zero_block) { init(party, network, seed, seed2); } +*/ + AbstractContext() = default; + AbstractContext(const AbstractContext &other) = delete; - CircuitContext(const CircuitContext &other) = delete; - - CircuitContext &operator=(const CircuitContext &other) = delete; - - void init(size_t party, std::shared_ptr network, block seed, - block seed2) { - set_party(party); - set_network(network); - - if (equals(seed, g_zero_block)) { - seed = block_from_dev_urandom(); - } - - if (equals(seed2, g_zero_block)) { - seed2 = block_from_dev_urandom(); - } - set_random_seed(seed, 0); - // seed2 is private - set_random_seed(seed2, 2); - - // 3 for 3-party computation - size_t party_pre = (this->party() - 1 + 3) % 3; - size_t party_next = (this->party() + 1) % 3; - - if (party == 1) { - block recv_seed = this->network()->template recv(party_next); - this->network()->template send(party_pre, seed); - seed = recv_seed; - } else { - this->network()->template send(party_pre, seed); - seed = this->network()->template recv(party_next); - } + AbstractContext &operator=(const AbstractContext &other) = delete; - set_random_seed(seed, 1); - } + virtual void init(size_t party, std::shared_ptr network, block seed, + block seed2) = 0; void set_party(size_t party) { - if (party >= 3) { - // exception handling - } + PADDLE_ENFORCE_LT(party, _num_party, + "party idx should less than %d.", + _num_party); _party = party; } + void set_num_party(size_t num_party) { + PADDLE_ENFORCE_TRUE(num_party == 2 || num_party == 3, + "2 or 3 party protocol is supported."); + _num_party = num_party; + } + void set_network(std::shared_ptr network) { _network = network; } @@ -82,22 +65,24 @@ public: AbstractNetwork *network() { return _network.get(); } void set_random_seed(const block &seed, size_t idx) { - if (idx >= 3) { - // exception handling - } + PADDLE_ENFORCE_LE(idx, _num_party, + "prng idx should be less and equal to %d.", + _num_party); _prng[idx].set_seed(seed); } size_t party() const { return _party; } - size_t pre_party() const { return (_party + 3 - 1) % 3; } + size_t pre_party() const { return (_party + _num_party - 1) % _num_party; } - size_t next_party() const { return (_party + 1) % 3; } + size_t next_party() const { return (_party + 1) % _num_party; } template T gen_random(bool next) { return _prng[next].get(); } template class Tensor> void gen_random(Tensor &tensor, bool next) { + PADDLE_ENFORCE_EQ(_num_party, 3, + "`gen_random` API is for 3 party protocol."); std::for_each( tensor.data(), tensor.data() + tensor.numel(), [this, next](T &val) { val = this->template gen_random(next); }); @@ -113,11 +98,15 @@ public: } template T gen_zero_sharing_arithmetic() { + PADDLE_ENFORCE_EQ(_num_party, 3, + "`gen_zero_sharing_arithmetic` API is for 3 party protocol."); return _prng[0].get() - _prng[1].get(); } template class Tensor> void gen_zero_sharing_arithmetic(Tensor &tensor) { + PADDLE_ENFORCE_EQ(_num_party, 3, + "`gen_zero_sharing_arithmetic` API is for 3 party protocol."); std::for_each(tensor.data(), tensor.data() + tensor.numel(), [this](T &val) { val = this->template gen_zero_sharing_arithmetic(); @@ -125,11 +114,15 @@ public: } template T gen_zero_sharing_boolean() { + PADDLE_ENFORCE_EQ(_num_party, 3, + "`gen_zero_sharing_boolean` API is for 3 party protocol."); return _prng[0].get() ^ _prng[1].get(); } template class Tensor> void gen_zero_sharing_boolean(Tensor &tensor) { + PADDLE_ENFORCE_EQ(_num_party, 3, + "`gen_zero_sharing_boolean` API is for 3 party protocol."); std::for_each( tensor.data(), tensor.data() + tensor.numel(), [this](T &val) { val = this->template gen_zero_sharing_boolean(); }); @@ -140,6 +133,8 @@ public: const Tensor *choice, const Tensor *m[2], Tensor *buffer[2], Tensor *ret) { // TODO: check tensor shape equals + PADDLE_ENFORCE_EQ(_num_party, 3, + "`ot` API is for 3 party protocol."); const size_t numel = buffer[0]->numel(); if (party() == sender) { bool common = helper == next_party(); @@ -180,6 +175,7 @@ public: } private: + size_t _num_party; size_t _party; std::shared_ptr _network; @@ -187,4 +183,6 @@ private: PseudorandomNumberGenerator _prng[3]; }; -} // namespace aby3 +} // namespace mpc + +} //namespace paddle diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h index 9981b2a0087d6e7c2915f88d6cb9ad32b9e234a0..b47fec6aad82429c29f8222a9d4935f5e3c7abd2 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_operators.h +++ b/core/paddlefl_mpc/mpc_protocol/aby3_operators.h @@ -22,7 +22,7 @@ #include "mpc_operators.h" #include "paddle/fluid/framework/tensor.h" #include "core/privc3/boolean_tensor.h" -#include "core/privc3/circuit_context.h" +#include "core/privc3/aby3_context.h" #include "core/privc3/fixedpoint_tensor.h" #include "core/privc3/paddle_tensor.h" @@ -30,7 +30,7 @@ namespace paddle { namespace mpc { using paddle::framework::Tensor; -using aby3::CircuitContext; +using aby3::ABY3Context; // TODO: decide scaling factor const size_t ABY3_SCALING_FACTOR = 16; using FixedTensor = aby3::FixedPointTensor; diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc index a4cc939ab00b6d06558f72d8d4f6778ffd221cee..8ec5a1d239a7476beaaf515da1dbd1028a230615 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc +++ b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.cc @@ -48,7 +48,7 @@ void Aby3Protocol::init_with_store( mesh_net->init(); _network = std::move(mesh_net); - _circuit_ctx = std::make_shared(role, _network); + _circuit_ctx = std::make_shared(role, _network); _operators = std::make_shared(); _is_initialized = true; } @@ -63,7 +63,7 @@ std::shared_ptr Aby3Protocol::network() { return _network; } -std::shared_ptr Aby3Protocol::mpc_context() { +std::shared_ptr Aby3Protocol::mpc_context() { PADDLE_ENFORCE(_is_initialized, PROT_INIT_ERR); return _circuit_ctx; } diff --git a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h index 6940c558d4bcdf842cae8775ad8fa1f93670b3f2..5d2d7d73cbcbada189c0449ec5bd06b8c735b6e5 100644 --- a/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h +++ b/core/paddlefl_mpc/mpc_protocol/aby3_protocol.h @@ -24,12 +24,13 @@ #include "mesh_network.h" #include "mpc_operators.h" #include "mpc_protocol.h" -#include "core/privc3/circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/privc3/aby3_context.h" namespace paddle { namespace mpc { -using CircuitContext = aby3::CircuitContext; +using ABY3Context = aby3::ABY3Context; class Aby3Protocol : public MpcProtocol { public: @@ -46,14 +47,14 @@ public: std::shared_ptr network() override; - std::shared_ptr mpc_context() override; + std::shared_ptr mpc_context() override; private: bool _is_initialized = false; const std::string PROT_INIT_ERR = "The protocol is not yet initialized."; std::shared_ptr _operators; std::shared_ptr _network; - std::shared_ptr _circuit_ctx; + std::shared_ptr _circuit_ctx; }; } // mpc diff --git a/core/paddlefl_mpc/mpc_protocol/context_holder.cc b/core/paddlefl_mpc/mpc_protocol/context_holder.cc index 9acf0fb72667c7da8b53b842b3189f26c337f923..3cbd8b8ce4c0979dbd7551581f930c897b0a0d74 100644 --- a/core/paddlefl_mpc/mpc_protocol/context_holder.cc +++ b/core/paddlefl_mpc/mpc_protocol/context_holder.cc @@ -24,7 +24,7 @@ namespace paddle { namespace mpc { -thread_local std::shared_ptr ContextHolder::current_mpc_ctx; +thread_local std::shared_ptr ContextHolder::current_mpc_ctx; thread_local const ExecutionContext *ContextHolder::current_exec_ctx; diff --git a/core/paddlefl_mpc/mpc_protocol/context_holder.h b/core/paddlefl_mpc/mpc_protocol/context_holder.h index a8c2d5f15764cc49463790761310d181abd04edf..710de07f189c2c63e926051fe0a38cb0a65212a4 100644 --- a/core/paddlefl_mpc/mpc_protocol/context_holder.h +++ b/core/paddlefl_mpc/mpc_protocol/context_holder.h @@ -22,20 +22,20 @@ #pragma once #include "paddle/fluid/framework/operator.h" -#include "core/privc3/circuit_context.h" +#include "core/privc3/aby3_context.h" #include "core/privc3/paddle_tensor.h" namespace paddle { namespace mpc { -using CircuitContext = aby3::CircuitContext; +using ABY3Context = aby3::ABY3Context; using ExecutionContext = paddle::framework::ExecutionContext; class ContextHolder { public: template static void run_with_context(const ExecutionContext *exec_ctx, - std::shared_ptr mpc_ctx, + std::shared_ptr mpc_ctx, Operation op) { // set new ctxs @@ -60,7 +60,7 @@ public: _s_current_tensor_factory = old_factory; } - static std::shared_ptr mpc_ctx() { return current_mpc_ctx; } + static std::shared_ptr mpc_ctx() { return current_mpc_ctx; } static const ExecutionContext *exec_ctx() { return current_exec_ctx; } @@ -77,7 +77,7 @@ public: } private: - thread_local static std::shared_ptr current_mpc_ctx; + thread_local static std::shared_ptr current_mpc_ctx; thread_local static const ExecutionContext *current_exec_ctx; diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc b/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc index dfb8f4a452dfcbec6ecfc707847cacb4db721558..68daf78b328c1c4624eed6b3aee9272f3bd8126a 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc +++ b/core/paddlefl_mpc/mpc_protocol/mpc_instance_test.cc @@ -19,7 +19,6 @@ #include "aby3_protocol.h" #include "mpc_protocol_factory.h" -#include "core/privc3/circuit_context.h" #include "gtest/gtest.h" namespace paddle { diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h b/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h index da1aa49de70bbd92dbbdfa60785602ce4397b47f..a5eeae6c26c5dcff53bf2de1bb0aedd505e6b9c0 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h +++ b/core/paddlefl_mpc/mpc_protocol/mpc_protocol.h @@ -21,7 +21,7 @@ #include "gloo/rendezvous/hash_store.h" #include "mpc_config.h" #include "mpc_operators.h" -#include "core/privc3/circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" namespace paddle { namespace mpc { @@ -44,7 +44,7 @@ public: virtual std::shared_ptr network() = 0; - virtual std::shared_ptr mpc_context() = 0; + virtual std::shared_ptr mpc_context() = 0; private: const std::string _name; diff --git a/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc b/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc index 332536bceb08096b45a63e9f600d81732bfa7b67..199e046ebbc3d48d078dff60001ecc2cc2f8bc83 100644 --- a/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc +++ b/core/paddlefl_mpc/mpc_protocol/mpc_protocol_test.cc @@ -17,7 +17,6 @@ #include "aby3_protocol.h" #include "mpc_config.h" #include "mpc_protocol_factory.h" -#include "core/privc3/circuit_context.h" #include "gtest/gtest.h" namespace paddle { diff --git a/core/paddlefl_mpc/operators/mpc_op.h b/core/paddlefl_mpc/operators/mpc_op.h index 6cff543b0dbf03f13a9e8baeb83335986f5fe249..f3a17941463c0beb45b847dbbd3faa680bc3ca53 100644 --- a/core/paddlefl_mpc/operators/mpc_op.h +++ b/core/paddlefl_mpc/operators/mpc_op.h @@ -19,7 +19,7 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "core/paddlefl_mpc/mpc_protocol/mpc_instance.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" -#include "core/privc3/circuit_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" namespace paddle { namespace operators { @@ -32,7 +32,7 @@ public: PADDLE_ENFORCE_NOT_NULL(mpc::MpcInstance::mpc_instance()->mpc_protocol(), "Mpc protocol is not yet initialized in executor"); - std::shared_ptr mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); + std::shared_ptr mpc_ctx(mpc::MpcInstance::mpc_instance()->mpc_protocol()->mpc_context()); mpc::ContextHolder::template run_with_context<>(&ctx, mpc_ctx, [&] { ComputeImpl(ctx); }); } diff --git a/core/privc/privc_context.h b/core/privc/privc_context.h new file mode 100644 index 0000000000000000000000000000000000000000..5803b0fd8a204a6cdef9815bea38b7e7b1876566 --- /dev/null +++ b/core/privc/privc_context.h @@ -0,0 +1,54 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" +#include "prng_utils.h" + +namespace aby3 { + +using AbstractNetwork = paddle::mpc::AbstractNetwork; +using AbstractContext = paddle::mpc::AbstractContext; + +class PrivCContext : public AbstractContext { +public: + PrivCContext(size_t party, std::shared_ptr network, + const block &seed = g_zero_block) { + init(party, network, g_zero_block, seed); + } + + PrivCContext(const PrivCContext &other) = delete; + + PrivCContext &operator=(const PrivCContext &other) = delete; + + void init(size_t party, std::shared_ptr network, block seed, + block seed2) override { + set_party(party); + set_network(network); + set_num_party(2); + + if (psi::equals(seed2, psi::g_zero_block)) { + seed2 = psi::block_from_dev_urandom(); + } + // seed2 is private + set_random_seed(seed2, 2); + } +}; + +} // namespace aby3 diff --git a/core/privc3/aby3_context.h b/core/privc3/aby3_context.h new file mode 100644 index 0000000000000000000000000000000000000000..539535df8035e35a0cf74d73daaf8dfe35a419bc --- /dev/null +++ b/core/privc3/aby3_context.h @@ -0,0 +1,75 @@ +// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +#pragma once + +#include +#include +#include + +#include "core/paddlefl_mpc/mpc_protocol/abstract_context.h" +#include "core/paddlefl_mpc/mpc_protocol/abstract_network.h" +#include "prng_utils.h" + +namespace aby3 { + +using AbstractNetwork = paddle::mpc::AbstractNetwork; +using AbstractContext = paddle::mpc::AbstractContext; + +class ABY3Context : public AbstractContext { +public: + ABY3Context(size_t party, std::shared_ptr network, + const block &seed = g_zero_block, + const block &seed2 = g_zero_block) { + init(party, network, seed, seed2); + } + + ABY3Context(const ABY3Context &other) = delete; + + ABY3Context &operator=(const ABY3Context &other) = delete; + + void init(size_t party, std::shared_ptr network, block seed, + block seed2) override { + set_party(party); + set_network(network); + set_num_party(3); + + if (psi::equals(seed, psi::g_zero_block)) { + seed = psi::block_from_dev_urandom(); + } + + if (psi::equals(seed2, psi::g_zero_block)) { + seed2 = psi::block_from_dev_urandom(); + } + set_random_seed(seed, 0); + // seed2 is private + set_random_seed(seed2, 2); + + // 3 for 3-party computation + size_t party_pre = pre_party(); + size_t party_next = next_party(); + + if (party == 1) { + block recv_seed = this->network()->template recv(party_next); + this->network()->template send(party_pre, seed); + seed = recv_seed; + } else { + this->network()->template send(party_pre, seed); + seed = this->network()->template recv(party_next); + } + + set_random_seed(seed, 1); + } +}; + +} // namespace aby3 diff --git a/core/privc3/boolean_tensor.h b/core/privc3/boolean_tensor.h index 8a3de5595223d1fd043364cbfebaeb07d6d94523..c370531b871801cf3da5d3592ede245843b29e5e 100644 --- a/core/privc3/boolean_tensor.h +++ b/core/privc3/boolean_tensor.h @@ -116,7 +116,7 @@ public: void bit_extract(size_t i, BooleanTensor *ret) const; private: - static inline std::shared_ptr aby3_ctx() { + static inline std::shared_ptr aby3_ctx() { return paddle::mpc::ContextHolder::mpc_ctx(); } diff --git a/core/privc3/boolean_tensor_impl.h b/core/privc3/boolean_tensor_impl.h index 99c6c1f61de284f078c729f17ef87e7b84619f68..fc81826e9c76ecee1c0f9ba359102ad3246627ca 100644 --- a/core/privc3/boolean_tensor_impl.h +++ b/core/privc3/boolean_tensor_impl.h @@ -258,7 +258,7 @@ void BooleanTensor::ppa(const BooleanTensor *rhs, BooleanTensor *ret, } template -void a2b(CircuitContext *aby3_ctx, TensorAdapterFactory *tensor_factory, +void a2b(AbstractContext *aby3_ctx, TensorAdapterFactory *tensor_factory, const FixedPointTensor *a, BooleanTensor *b, size_t n_bits) { std::shared_ptr> tmp[4]; diff --git a/core/privc3/boolean_tensor_test.cc b/core/privc3/boolean_tensor_test.cc index c44d6cd0e2d92a44df4e407427aba9f63b09a304..4414cb4dc17ccc79f297239ab7098a02bd6775ec 100644 --- a/core/privc3/boolean_tensor_test.cc +++ b/core/privc3/boolean_tensor_test.cc @@ -27,19 +27,20 @@ #include "boolean_tensor.h" #include "fixedpoint_tensor.h" #include "paddle_tensor.h" -#include "circuit_context.h" +#include "aby3_context.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" namespace aby3 { using paddle::framework::Tensor; +using AbstractContext = paddle::mpc::AbstractContext; class BooleanTensorTest : public ::testing::Test { public: paddle::platform::CPUDeviceContext _cpu_ctx; std::shared_ptr _exec_ctx; - std::shared_ptr _mpc_ctx[3]; + std::shared_ptr _mpc_ctx[3]; std::shared_ptr _store; @@ -83,7 +84,7 @@ public: void gen_mpc_ctx(size_t idx) { auto net = gen_network(idx); net->init(); - _mpc_ctx[idx] = std::make_shared(idx, net); + _mpc_ctx[idx] = std::make_shared(idx, net); } std::shared_ptr> gen1() { diff --git a/core/privc3/fixedpoint_tensor.h b/core/privc3/fixedpoint_tensor.h index b3ca55f93a6115ebbe3266e343af8297c74b9a7b..73274ae258207d184907af8494f67890a283ce0f 100644 --- a/core/privc3/fixedpoint_tensor.h +++ b/core/privc3/fixedpoint_tensor.h @@ -17,7 +17,7 @@ #include #include "boolean_tensor.h" -#include "circuit_context.h" +#include "aby3_context.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "paddle_tensor.h" @@ -139,7 +139,7 @@ public: void neq(const CTensor *rhs, BooleanTensor *ret) const; private: - static inline std::shared_ptr aby3_ctx() { + static inline std::shared_ptr aby3_ctx() { return paddle::mpc::ContextHolder::mpc_ctx(); } diff --git a/core/privc3/fixedpoint_tensor_test.cc b/core/privc3/fixedpoint_tensor_test.cc index 0828594dc18805a2a5bff568a820d2e151b7f8d5..eea75bee74ffbb60fbadc703fa62f08f02217a31 100644 --- a/core/privc3/fixedpoint_tensor_test.cc +++ b/core/privc3/fixedpoint_tensor_test.cc @@ -16,21 +16,23 @@ limitations under the License. */ #include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/scope.h" +#include "aby3_context.h" #include "core/paddlefl_mpc/mpc_protocol/mesh_network.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "fixedpoint_tensor.h" namespace aby3 { - using g_ctx_holder = paddle::mpc::ContextHolder; - using Fix64N16 = FixedPointTensor; +using g_ctx_holder = paddle::mpc::ContextHolder; +using Fix64N16 = FixedPointTensor; +using AbstractContext = paddle::mpc::AbstractContext; class FixedTensorTest : public ::testing::Test { public: paddle::platform::CPUDeviceContext _cpu_ctx; std::shared_ptr _exec_ctx; - std::shared_ptr _mpc_ctx[3]; + std::shared_ptr _mpc_ctx[3]; std::shared_ptr _store; std::thread _t[3]; std::shared_ptr _s_tensor_factory; @@ -67,7 +69,7 @@ public: void gen_mpc_ctx(size_t idx) { auto net = gen_network(idx); net->init(); - _mpc_ctx[idx] = std::make_shared(idx, net); + _mpc_ctx[idx] = std::make_shared(idx, net); } std::shared_ptr> gen(std::vector shape) {