// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "aby3_context.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"
#include "paddle_tensor.h"
#include "boolean_tensor.h"

namespace aby3 {

template<typename T, size_t N>
class FixedPointTensor {

public:
    explicit FixedPointTensor(TensorAdapter<T>* share_tensor[2]);

    explicit FixedPointTensor(TensorAdapter<T>* share_tensor_0,
                              TensorAdapter<T>* share_tensor_1);

    ~FixedPointTensor() {}

    // get mutable share of the tensor
    TensorAdapter<T>* mutable_share(size_t idx);

    const TensorAdapter<T>* share(size_t idx) const;

    size_t numel() const {
        return _share[0]->numel();
    }

    // reveal the FixedPointTensor to one party
    void reveal_to_one(size_t party, TensorAdapter<T>* ret) const;

    // reveal the FixedPointTensor to all parties
    void reveal(TensorAdapter<T>* ret) const;

    const std::vector<size_t> shape() const;

    // convert a plaintext TensorAdapter into three shares
    static void share(const TensorAdapter<T>* input,
                      TensorAdapter<T>* output_shares[3],
                      block seed = g_zero_block);
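    // a minimal usage sketch (hypothetical: assumes replicated 2-out-of-3
    // sharing as in ABY3, party i holding shares i and (i + 1) % 3, an
    // initialized mpc context, and plaintexts pre-scaled by 2^N):
    //
    //   TensorAdapter<int64_t>* out[3] = {s0.get(), s1.get(), s2.get()};
    //   FixedPointTensor<int64_t, 16>::share(plain.get(), out);
    //   FixedPointTensor<int64_t, 16> x(out[party_id], out[(party_id + 1) % 3]);
    //   x.reveal(result.get());  // reconstructs the plaintext for all parties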

    // element-wise add with FixedPointTensor
    void add(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // element-wise add with TensorAdapter
    void add(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // element-wise sub with FixedPointTensor
    void sub(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // element-wise sub with TensorAdapter
    void sub(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // element-wise negation
    void negative(FixedPointTensor* ret) const;

    // element-wise mul with FixedPointTensor using truncate1
    void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
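    // why mul needs truncation: with N fractional bits, a real x is encoded
    // as x * 2^N, so the raw product of two encodings carries a factor of
    // 2^(2N) and must be shifted right by N once; e.g. for N = 16:
    // 1.5 -> 98304, 2.0 -> 131072, and 98304 * 131072 = 3.0 * 2^32,
    // which truncates back to 3.0 * 2^16 = 196608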

    // element-wise mul with TensorAdapter
    void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // div by TensorAdapter
    void div(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // div by FixedPointTensor
    // TODO@yqy: operand rhs <= 0 is not supported yet
    void div(const FixedPointTensor* rhs, FixedPointTensor* ret,
             size_t iter = 16, double x0 = pow(2, -15)) const;
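    // sketch of the idea (inferred from the iter/x0 parameters and the
    // private reciprocal() helper below): compute lhs * (1 / rhs), refining
    // an initial guess x0 of the reciprocal, presumably with the standard
    // Newton-Raphson update y_{k+1} = y_k * (2 - rhs * y_k), which roughly
    // doubles the number of correct bits per iteration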

    // long division by boolean circuit
    // res_int_len: estimated bit length of the integer part of the result
    void long_div(const FixedPointTensor* rhs,
                  FixedPointTensor* ret, size_t res_int_len = 20) const;

    void inverse_square_root(FixedPointTensor* ret,
                             size_t iter = 16, double x0 = 0x1p-10) const;
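    // presumably a Newton iteration for 1 / sqrt(op) starting from x0,
    // i.e. y_{k+1} = y_k * (3 - op * y_k^2) / 2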

    // dot product
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void dot_mul(const CTensor<T, N1...>* rhs, FixedPointTensor* ret) const;

    // sum of all elements
    void sum(FixedPointTensor* ret) const;

    // mat_mul with FixedPointTensor
    void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // mat_mul with TensorAdapter
    void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // exp approximation: exp(x) = \lim_{n->inf} (1 + x/n)^n
    // where n = 2^iter
    void exp(FixedPointTensor* ret, size_t iter = 8) const;
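    // worked example for iter = 8, n = 2^8 = 256: exp(x) ~ (1 + x/256)^256,
    // i.e. one scaling of x followed by 8 repeated squarings; for x = 1
    // this yields about 2.713 versus e ~ 2.71828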

    // element-wise relu
    void relu(FixedPointTensor* ret) const;

    // element-wise relu together with its derivative relu'
    void relu_with_derivative(FixedPointTensor* ret, BooleanTensor<T>* derivative) const;

    // element-wise sigmoid using a 3-piece piecewise polynomial
    void sigmoid(FixedPointTensor* ret) const;
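    // a common 3-piece choice in MPC (e.g. SecureML); the exact break points
    // and coefficients used here are an assumption, the implementation is
    // authoritative:
    //   sigmoid(x) ~ 0          for x < -0.5
    //   sigmoid(x) ~ x + 0.5    for -0.5 <= x <= 0.5
    //   sigmoid(x) ~ 1          for x > 0.5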

    // element-wise sigmoid using a 5-piece piecewise polynomial,
    // see the paper [Privacy-preserving collaborative machine learning
    //            on genomic data using TensorFlow]
    void sigmoid_enhanced(FixedPointTensor* ret) const;

    // element-wise sigmoid using Chebyshev polynomial approximation,
    // implemented with reference to tf-encrypted [https://github.com/tf-encrypted/tf-encrypted]
    void sigmoid_chebyshev(FixedPointTensor* ret) const;

    // softmax axis = -1
    void softmax(FixedPointTensor* ret,
                 bool use_relu = false,
                 bool use_long_div = true) const;
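    // judging by the flags (an assumption): softmax_i = f(x_i) / sum_j f(x_j),
    // with f = exp by default or f = relu as a cheaper approximation when
    // use_relu is set; use_long_div selects long_div for the normalization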

    // element-wise polynomial
    void polynomial(const TensorAdapter<T>* coeff,
                    FixedPointTensor* ret) const;

    // element-wise piecewise polynomial
    void polynomial_piecewise(
                const TensorAdapter<T>* coeff,
                const TensorAdapter<T>* break_point,
                FixedPointTensor* ret) const;
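    // illustrative layout (an assumption; see the implementation for the
    // authoritative one): k break points split the number line into k + 1
    // intervals, and row i of coeff holds the polynomial used on interval i;
    // e.g. the 3-piece sigmoid above would use break_point = {-0.5, 0.5} and
    // coeff = {{0, 0}, {0.5, 1}, {1, 0}}, i.e. 0, then 0.5 + x, then 1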

    // element-wise compare
    // <
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void lt(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // <=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void leq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // >
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void gt(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // >=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void geq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // ==
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void eq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // !=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void neq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // element-wise max
    // if cmp is not null, it stores true where rhs is bigger
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void max(const CTensor<T, N1...>* rhs,
             FixedPointTensor* ret,
             BooleanTensor<T>* cmp = nullptr) const;

    // for a tensor with shape [k, n, m, ...],
    // ret has shape [1, n, m, ...], where each element is the largest of the k candidates;
    // pos has shape [k, n, m, ...], and each column of pos is a one-hot tensor
    // indicating the position of the max element
    void max_pooling(FixedPointTensor* ret,
                     BooleanTensor<T>* pos = nullptr) const;
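    // shape example: an input of shape [4, 2] (k = 4 candidates per column)
    // yields ret of shape [1, 2] holding the column maxima and pos of shape
    // [4, 2] with exactly one 1 per column marking the argmax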

    // only supports predictions for class 1 in binary classification for now
    static void preds_to_indices(const FixedPointTensor* preds,
                                 FixedPointTensor* indices,
                                 float threshold = 0.5);

    static void calc_tp_fp_fn(const FixedPointTensor* indices,
                              const FixedPointTensor* labels,
                              FixedPointTensor* tp_fp_fn);

    // calc precision, recall and f1_score
    // result is a plaintext fixed-point tensor, shape is [3]
    static void calc_precision_recall(const FixedPointTensor* tp_fp_fn,
                                      TensorAdapter<T>* ret);

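    // truncate divides the encoded values by 2^scaling_factor, e.g. to
    // restore the fixed-point scale after a multiplication (see mul above)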
    static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
                        size_t scaling_factor);

private:
    static inline std::shared_ptr<AbstractContext> aby3_ctx() {
        return paddle::mpc::ContextHolder::mpc_ctx();
    }

    static inline std::shared_ptr<TensorAdapterFactory> tensor_factory() {
        return paddle::mpc::ContextHolder::tensor_factory();
    }

    template<typename MulFunc>
    static void mul_trunc(const FixedPointTensor<T, N>* lhs,
                          const FixedPointTensor<T, N>* rhs,
                          FixedPointTensor<T, N>* ret,
                          MulFunc mul_func);

    // the truncate3 protocol avoids the lost-msb error of plain truncation,
    // at the cost of an acceptable security compromise
    static void truncate3(const FixedPointTensor* op, FixedPointTensor* ret,
                          size_t scaling_factor);

    // reduce last dim
    static void reduce(const FixedPointTensor<T, N>* input,
                       FixedPointTensor<T, N>* ret);

    static size_t party() {
        return aby3_ctx()->party();
    }

    static size_t pre_party() {
        return aby3_ctx()->pre_party();
    }

    static size_t next_party() {
        return aby3_ctx()->next_party();
    }

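    // every party sends send_val to its previous party and receives recv_val
    // from its next party; party 0 reverses the send/recv order, presumably
    // so that the blocking ring of transfers cannot deadlock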
    static void reshare(const TensorAdapter<T>* send_val,
                 TensorAdapter<T>* recv_val) {
        if (party() == 0) {
            aby3_ctx()->network()->template recv(next_party(), *recv_val);
            aby3_ctx()->network()->template send(pre_party(), *send_val);
        } else {
            aby3_ctx()->network()->template send(pre_party(), *send_val);
            aby3_ctx()->network()->template recv(next_party(), *recv_val);
        }
    }

    static void reciprocal(const FixedPointTensor* op, FixedPointTensor* ret,
                                                      size_t iter, double x0);

    static void inverse_square_root(const FixedPointTensor* op,
                                    FixedPointTensor* ret,
                                    size_t iter, double x0);

    TensorAdapter<T>* _share[2];

};

} // namespace aby3

#include "fixedpoint_tensor_imp.h"