// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include "boolean_tensor.h" #include "circuit_context.h" #include "core/paddlefl_mpc/mpc_protocol/context_holder.h" #include "paddle_tensor.h" namespace aby3 { template class FixedPointTensor { public: explicit FixedPointTensor(TensorAdapter *share_tensor[2]); explicit FixedPointTensor(TensorAdapter *share_tensor_0, TensorAdapter *share_tensor_1); ~FixedPointTensor(){}; // get mutable share of tensor TensorAdapter *mutable_share(size_t idx); const TensorAdapter *share(size_t idx) const; size_t numel() const { return _share[0]->numel(); } // reveal FixedPointTensor to one party void reveal_to_one(size_t party, TensorAdapter *ret) const; // reveal FixedPointTensor to all parties void reveal(TensorAdapter *ret) const; const std::vector shape() const; // convert TensorAdapter to shares static void share(const TensorAdapter *input, TensorAdapter *output_shares[3], block seed = g_zero_block); // element-wise add with FixedPointTensor void add(const FixedPointTensor *rhs, FixedPointTensor *ret) const; // element-wise add with TensorAdapter void add(const TensorAdapter *rhs, FixedPointTensor *ret) const; // element-wise sub with FixedPointTensor void sub(const FixedPointTensor *rhs, FixedPointTensor *ret) const; // element-wise sub with TensorAdapter void sub(const TensorAdapter *rhs, FixedPointTensor *ret) const; // negative void
negative(FixedPointTensor *ret) const; // element-wise mul with FixedPointTensor using truncate1 void mul(const FixedPointTensor *rhs, FixedPointTensor *ret) const; // element-wise mul with TensorAdapter void mul(const TensorAdapter *rhs, FixedPointTensor *ret) const; // div by TensorAdapter void div(const TensorAdapter *rhs, FixedPointTensor *ret) const; // element-wise mul, use trunc2 void mul2(const FixedPointTensor *rhs, FixedPointTensor *ret) const; // dot_mul template