// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <cmath>
#include <cstddef>
#include <memory>
#include <vector>

#include "aby3_context.h"
#include "boolean_tensor.h"
#include "paddle_tensor.h"

#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"

namespace aby3 {

J
jhjiangcs 已提交
28 29
template<typename T, size_t N>
class FixedPointTensor {
J
jingqinghe 已提交
30 31

public:
J
jhjiangcs 已提交
32 33 34 35 36 37
    explicit FixedPointTensor(TensorAdapter<T>* share_tensor[2]);

    explicit FixedPointTensor(TensorAdapter<T>* share_tensor_0,
                              TensorAdapter<T>* share_tensor_1);

    ~FixedPointTensor() {};
J
jingqinghe 已提交
38

J
jhjiangcs 已提交
39 40
    //get mutable shape of tensor
    TensorAdapter<T>* mutable_share(size_t idx);
J
jingqinghe 已提交
41

J
jhjiangcs 已提交
42
    const TensorAdapter<T>* share(size_t idx) const;
J
jingqinghe 已提交
43

J
jhjiangcs 已提交
44 45 46 47 48 49
    size_t numel() const {
        return _share[0]->numel();
    }

    // reveal fixedpointtensor to one party
    void reveal_to_one(size_t party, TensorAdapter<T>* ret) const;
J
jingqinghe 已提交
50

J
jhjiangcs 已提交
51 52
    // reveal fixedpointtensor to all parties
    void reveal(TensorAdapter<T>* ret) const;
J
jingqinghe 已提交
53

J
jhjiangcs 已提交
54
    const std::vector<size_t> shape() const;
J
jingqinghe 已提交
55

J
jhjiangcs 已提交
56 57 58 59
    //convert TensorAdapter to shares
    static void share(const TensorAdapter<T>* input,
                      TensorAdapter<T>* output_shares[3],
                      block seed = g_zero_block);
J
jingqinghe 已提交
60

J
jhjiangcs 已提交
61 62
    // element-wise add with FixedPointTensor
    void add(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
63

J
jhjiangcs 已提交
64
    // element-wise add with TensorAdapter
J
jingqinghe 已提交
65

J
jhjiangcs 已提交
66
    void add(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
67

J
jhjiangcs 已提交
68 69
    // element-wise sub with FixedPointTensor
    void sub(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
70

J
jhjiangcs 已提交
71 72
    // element-wise sub with TensorAdapter
    void sub(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
73

J
jhjiangcs 已提交
74 75
    // negative
    void negative(FixedPointTensor* ret) const;
J
jingqinghe 已提交
76

J
jhjiangcs 已提交
77 78
    // element-wise mul with FixedPointTensor using truncate1
    void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
79

J
jhjiangcs 已提交
80 81
    // element-wise mul with TensorAdapter
    void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
82

J
jhjiangcs 已提交
83 84
    // div by TensorAdapter
    void div(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
85

J
jhjiangcs 已提交
86 87 88 89
    // div by FixedPointedTensor
    // TODO@yqy : not surport operator rhs <= 0 now
    void div(const FixedPointTensor* rhs, FixedPointTensor* ret,
             size_t iter = 16, double x0 = pow(2, -15)) const;
J
jingqinghe 已提交
90

J
jhjiangcs 已提交
91 92 93 94
    // long div by boolean circuit
    // res_int_len: estimated bit len of the integer part of result
    void long_div(const FixedPointTensor* rhs,
                  FixedPointTensor* ret, size_t res_int_len = 20) const;
J
jingqinghe 已提交
95

J
jhjiangcs 已提交
96 97
    void inverse_square_root(FixedPointTensor* ret,
                             size_t iter = 16, double x0 = 0x1p-10) const;
J
jingqinghe 已提交
98

J
jhjiangcs 已提交
99 100 101 102
    // dot_mul
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void dot_mul(const CTensor<T, N1...>* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
103

J
jhjiangcs 已提交
104 105
    //sum all element
    void sum(FixedPointTensor* ret) const;
J
jingqinghe 已提交
106

J
jhjiangcs 已提交
107 108
    // mat_mul with FixedPointTensor
    void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
109

J
jhjiangcs 已提交
110 111
    // mat_mul with TensorAdapter
    void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;
J
jingqinghe 已提交
112

J
jhjiangcs 已提交
113 114 115
    // exp approximate: exp(x) = \lim_{n->inf} (1+x/n)^n
    // where n = 2^ite
    void exp(FixedPointTensor* ret, size_t iter = 8) const;
J
jingqinghe 已提交
116

J
jhjiangcs 已提交
117 118
    // element-wise relu
    void relu(FixedPointTensor* ret) const;
J
jingqinghe 已提交
119

J
jhjiangcs 已提交
120 121
    // element-wise relu with relu'
    void relu_with_derivative(FixedPointTensor* ret, BooleanTensor<T>* derivative) const;
J
jingqinghe 已提交
122

J
jhjiangcs 已提交
123 124
    // element-wise sigmoid using 3 piecewise polynomials
    void sigmoid(FixedPointTensor* ret) const;
J
jingqinghe 已提交
125

J
jhjiangcs 已提交
126 127 128 129
    // element-wise sigmoid using 5 pieces polynomial
    // see paper [Privacy-preserving collaborative machine learning
    //            on genomic data using TensorFlow]
    void sigmoid_enhanced(FixedPointTensor* ret) const;
J
jingqinghe 已提交
130

J
jhjiangcs 已提交
131 132 133
    // element-wise sigmoid using Chebyshev polynomial approximation
    // implemented with ref to tfe[https://github.com/tf-encrypted/tf-encrypted]
    void sigmoid_chebyshev(FixedPointTensor* ret) const;
J
jingqinghe 已提交
134

J
jhjiangcs 已提交
135 136 137 138
    // softmax axis = -1
    void softmax(FixedPointTensor* ret,
                 bool use_relu = false,
                 bool use_long_div = true) const;
J
jingqinghe 已提交
139

J
jhjiangcs 已提交
140 141 142
    // element-wise polynomial
    void polynomial(const TensorAdapter<T>* coeff,
                    FixedPointTensor* ret) const;
J
jingqinghe 已提交
143

J
jhjiangcs 已提交
144 145 146 147 148
    // element-wise piecewise polynomial
    void polynomial_piecewise(
                const TensorAdapter<T>* coeff,
                const TensorAdapter<T>* break_point,
                FixedPointTensor* ret) const;
J
jingqinghe 已提交
149

J
jhjiangcs 已提交
150 151 152 153 154
    // element-wise compare
    // <
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void lt(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;
J
jingqinghe 已提交
155

J
jhjiangcs 已提交
156 157 158 159
    // <=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void leq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;
J
jingqinghe 已提交
160

J
jhjiangcs 已提交
161 162 163 164
    // >
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void gt(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;
J
jingqinghe 已提交
165

J
jhjiangcs 已提交
166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194
    // >=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void geq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // ==
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void eq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // !=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void neq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // element-wise max
    // if not null, cmp stores true if rhs is bigger
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void max(const CTensor<T, N1...>* rhs,
             FixedPointTensor* ret,
             BooleanTensor<T>* cmp = nullptr) const;

    // for tensor with shape like [k, n, m, ...]
    // ret shape is [1, n, m, ...], in which every element is largest of k elements
    // pos shape is [k, n, m, ...], each col of pos is an one-hot tensor
    // which indicating the max element's position
    void max_pooling(FixedPointTensor* ret,
                     BooleanTensor<T>* pos = nullptr) const;
J
jingqinghe 已提交
195

196 197 198
    static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
                        size_t scaling_factor);

J
jingqinghe 已提交
199
private:
200 201
    static inline std::shared_ptr<AbstractContext> aby3_ctx() {
      return paddle::mpc::ContextHolder::mpc_ctx();
J
jhjiangcs 已提交
202 203 204 205 206 207 208 209 210 211 212
    }

    static inline std::shared_ptr<TensorAdapterFactory> tensor_factory() {
        return paddle::mpc::ContextHolder::tensor_factory();
    }

    template<typename MulFunc>
    static void mul_trunc(const FixedPointTensor<T, N>* lhs,
                          const FixedPointTensor<T, N>* rhs,
                          FixedPointTensor<T, N>* ret,
                          MulFunc mul_func);
J
jingqinghe 已提交
213

J
jhjiangcs 已提交
214 215 216 217
    // truncate3 protocol can avoid losing msb error when truncate
    // with acceptable security compromise
    static void truncate3(const FixedPointTensor* op, FixedPointTensor* ret,
                          size_t scaling_factor);
J
jingqinghe 已提交
218

J
jhjiangcs 已提交
219 220 221
    // reduce last dim
    static void reduce(FixedPointTensor<T, N>* input,
                       FixedPointTensor<T, N>* ret);
J
jingqinghe 已提交
222

J
jhjiangcs 已提交
223 224 225
    static size_t party() {
        return aby3_ctx()->party();
    }
J
jingqinghe 已提交
226

J
jhjiangcs 已提交
227 228 229
    static size_t pre_party() {
        return aby3_ctx()->pre_party();
    }
J
jingqinghe 已提交
230

J
jhjiangcs 已提交
231 232 233
    static size_t next_party() {
        return aby3_ctx()->next_party();
    }
J
jingqinghe 已提交
234

J
jhjiangcs 已提交
235 236 237 238 239 240 241 242 243
    static void reshare(const TensorAdapter<T>* send_val,
                 TensorAdapter<T>* recv_val) {
        if (party() == 0) {
            aby3_ctx()->network()->template recv(next_party(), *recv_val);
            aby3_ctx()->network()->template send(pre_party(), *send_val);
        } else {
            aby3_ctx()->network()->template send(pre_party(), *send_val);
            aby3_ctx()->network()->template recv(next_party(), *recv_val);
        }
J
jingqinghe 已提交
244 245
    }

J
jhjiangcs 已提交
246 247 248 249 250 251 252 253 254
    static void reciprocal(const FixedPointTensor* op, FixedPointTensor* ret,
                                                      size_t iter, double x0);

    static void inverse_square_root(const FixedPointTensor* op,
                                    FixedPointTensor* ret,
                                    size_t iter, double x0);

    TensorAdapter<T>* _share[2];

J
jingqinghe 已提交
255 256
};

J
jhjiangcs 已提交
257
} //namespace aby3

#include "fixedpoint_tensor_imp.h"