// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "tensor_adapter.h" namespace aby3 { template class FixedPointTensor; template class BooleanTensor { public: BooleanTensor(TensorAdapter* share_tensor[2]); BooleanTensor(TensorAdapter* tensor0, TensorAdapter* tensor1); BooleanTensor(); // ABY3 a2b template BooleanTensor& operator=(const FixedPointTensor* other); ~BooleanTensor() {} //get share TensorAdapter* share(size_t idx); const TensorAdapter* share(size_t idx) const; // reveal boolean tensor to one party void reveal_to_one(size_t party_num, TensorAdapter* ret) const; // reveal boolean tensor to all parties void reveal(TensorAdapter* ret) const; const std::vector shape() const; size_t numel() const; // //convert TensorAdapter to shares // static void share(const TensorAdapter* input, // TensorAdapter* output_shares[3], // const std::string& rnd_seed = ""); // element-wise xor with BooleanTensor void bitwise_xor(const BooleanTensor* rhs, BooleanTensor* ret) const; // element-wise xor with TensorAdapter void bitwise_xor(const TensorAdapter* rhs, BooleanTensor* ret) const; // element-wise and with BooleanTensor void bitwise_and(const BooleanTensor* rhs, BooleanTensor* ret) const; // element-wise and with TensorAdapter void bitwise_and(const TensorAdapter* rhs, BooleanTensor* ret) const; // element-wise or // for both tensor adapter and boolean tensor template class CTensor> void bitwise_or(const CTensor* rhs, BooleanTensor* ret) const; // element-wise not void bitwise_not(BooleanTensor* ret) const; // element-wise lshift void lshift(size_t rhs, BooleanTensor* ret) const; // element-wise rshift void rshift(size_t rhs, BooleanTensor* ret) const; // element-wise logical_rshift void logical_rshift(size_t rhs, BooleanTensor* ret) const; // element-wise ppa with BooleanTensor void ppa(const BooleanTensor* rhs, BooleanTensor*ret , size_t nbits) const; // ABY3 b2a template void b2a(FixedPointTensor* ret) const; // ABY3 ab mul // this is an one-bit boolean share template void mul(const TensorAdapter* rhs, FixedPointTensor* ret, size_t rhs_party) const; // ABY3 ab mul // this is an one-bit boolean share template void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const; // extract to this template void bit_extract(size_t i, const FixedPointTensor* in); // extract from this to ret void bit_extract(size_t i, BooleanTensor* ret) const; // turn all 1s to 0s except the last 1 in a col // given cmp result from max pooling, generate one hot tensor // indicating which element is max // inplace transform void onehot_from_cmp(); private: static inline std::shared_ptr aby3_ctx() { return paddle::mpc::ContextHolder::mpc_ctx(); } static inline std::shared_ptr tensor_factory() { return paddle::mpc::ContextHolder::tensor_factory(); } size_t pre_party() const; size_t next_party() const; size_t party() const; private: TensorAdapter* _share[2]; }; } //namespace aby3 #include "boolean_tensor_impl.h"