// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <vector>

#include "circuit_context.h"
#include "paddle_tensor.h"
#include "boolean_tensor.h"
#include "core/paddlefl_mpc/mpc_protocol/context_holder.h"

namespace aby3 {

template<typename T, size_t N>
class FixedPointTensor {

public:
    explicit FixedPointTensor(TensorAdapter<T>* share_tensor[2]);

    explicit FixedPointTensor(TensorAdapter<T>* share_tensor_0,
                              TensorAdapter<T>* share_tensor_1);

    ~FixedPointTensor() {}

    // get mutable share of tensor
    TensorAdapter<T>* mutable_share(size_t idx);

    const TensorAdapter<T>* share(size_t idx) const;

    size_t numel() const {
        return _share[0]->numel();
    }

    // reveal FixedPointTensor to one party
    void reveal_to_one(size_t party, TensorAdapter<T>* ret) const;

    // reveal FixedPointTensor to all parties
    void reveal(TensorAdapter<T>* ret) const;

    const std::vector<size_t> shape() const;

    // convert a plaintext TensorAdapter into three shares
    static void share(const TensorAdapter<T>* input,
                      TensorAdapter<T>* output_shares[3],
                      block seed = g_zero_block);
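
    // A minimal single-process usage sketch (hypothetical; assumes an
    // initialized MPC context and that `input` already holds fixed-point
    // encoded values; in a real deployment each party only holds two of
    // the three shares):
    //
    //   TensorAdapter<int64_t>* shares[3] = {s0.get(), s1.get(), s2.get()};
    //   FixedPointTensor<int64_t, 32>::share(input.get(), shares);
    //   FixedPointTensor<int64_t, 32> x(shares[0], shares[1]); // party 0
    //   x.reveal(plain.get()); // reconstruct the plaintext on all parties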

    // element-wise add with FixedPointTensor
    void add(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // element-wise add with TensorAdapter
    void add(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // element-wise sub with FixedPointTensor
    void sub(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // element-wise sub with TensorAdapter
    void sub(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // element-wise negation
    void negative(FixedPointTensor* ret) const;

    // element-wise mul with FixedPointTensor using truncate1
    void mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // element-wise mul with TensorAdapter
    void mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // div by TensorAdapter
    void div(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // div by FixedPointTensor
    // TODO@yqy: rhs <= 0 is not supported yet
    void div(const FixedPointTensor* rhs, FixedPointTensor* ret,
             size_t iter = 16, double x0 = pow(2, -15)) const;
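
    // A presumed reading of (iter, x0): ret = lhs * reciprocal(rhs), with
    // the reciprocal refined by the Newton-Raphson iteration
    //     x_{k+1} = x_k * (2 - rhs * x_k),
    // run iter times from the initial guess x0. This is an assumption drawn
    // from the parameters, not a confirmed description of the protocol.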

    // long division implemented with a boolean circuit
    // res_int_len: estimated bit length of the integer part of the result
    void long_div(const FixedPointTensor* rhs,
                  FixedPointTensor* ret, size_t res_int_len = 20) const;

    void inverse_square_root(FixedPointTensor* ret,
                             size_t iter = 16, double x0 = 0x1p-10) const;
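
    // The (iter, x0) parameters suggest the Newton iteration for 1/sqrt(x),
    //     y_{k+1} = y_k * (3 - x * y_k * y_k) / 2,
    // run iter times from the initial guess x0; presumed, not confirmed.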

    // dot_mul
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void dot_mul(const CTensor<T, N1...>* rhs, FixedPointTensor* ret) const;

    // sum all elements
    void sum(FixedPointTensor* ret) const;

    // mat_mul with FixedPointTensor
    void mat_mul(const FixedPointTensor* rhs, FixedPointTensor* ret) const;

    // mat_mul with TensorAdapter
    void mat_mul(const TensorAdapter<T>* rhs, FixedPointTensor* ret) const;

    // exp approximate: exp(x) = \lim_{n->inf} (1+x/n)^n
    // where n = 2^iter
    void exp(FixedPointTensor* ret, size_t iter = 8) const;
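
    // With n = 2^iter, the approximation is cheap to evaluate by repeated
    // squaring: compute r = 1 + x / 2^iter once, then square r iter times,
    // i.e. iter fixed-point multiplications instead of 2^iter - 1.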

    // element-wise relu
    void relu(FixedPointTensor* ret) const;

    // element-wise relu together with its derivative relu'
    void relu_with_derivative(FixedPointTensor* ret, BooleanTensor<T>* derivative) const;

    // element-wise sigmoid using 3 piecewise polynomials
    void sigmoid(FixedPointTensor* ret) const;

    // element-wise sigmoid using a five-piece piecewise polynomial
    // see paper [Privacy-preserving collaborative machine learning
    //            on genomic data using TensorFlow]
    void sigmoid_enhanced(FixedPointTensor* ret) const;

    // element-wise sigmoid using Chebyshev polynomial approximation
    // implemented with ref to tfe[https://github.com/tf-encrypted/tf-encrypted]
    void sigmoid_chebyshev(FixedPointTensor* ret) const;

    // softmax axis = -1
    void softmax(FixedPointTensor* ret,
                 bool use_relu = false,
                 bool use_long_div = true) const;
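
    // The flag names suggest (assumption, not confirmed): use_relu replaces
    // the exp-based weights with relu(x) / sum(relu(x)), and use_long_div
    // performs the normalizing division with long_div instead of div.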

    // element-wise polynomial
    void polynomial(const TensorAdapter<T>* coeff,
                    FixedPointTensor* ret) const;
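
    // Presumably coeff lists coefficients in ascending order of degree,
    // e.g. coeff = {c0, c1, c2} evaluates c0 + c1 * x + c2 * x^2 per element.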

    // element-wise piecewise polynomial
    void polynomial_piecewise(
                const TensorAdapter<T>* coeff,
                const TensorAdapter<T>* break_point,
                FixedPointTensor* ret) const;
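
    // Presumed semantics: break_point holds the interval boundaries and
    // coeff one row of coefficients per interval; each element is evaluated
    // with the polynomial of the interval it falls into.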

    // element-wise compare
    // <
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void lt(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // <=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void leq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // >
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void gt(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // >=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void geq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // ==
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void eq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // !=
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void neq(const CTensor<T, N1...>* rhs, BooleanTensor<T>* ret) const;

    // element-wise max
    // if cmp is not null, it stores true where rhs is bigger
    template<template<typename U, size_t...> class CTensor,
            size_t... N1>
    void max(const CTensor<T, N1...>* rhs,
             FixedPointTensor* ret,
             BooleanTensor<T>* cmp = nullptr) const;

    // for a tensor with shape like [k, n, m, ...]
    // ret shape is [1, n, m, ...], in which every element is the largest of the k
    // pos shape is [k, n, m, ...], each column of pos is a one-hot tensor
    // indicating the position of the max element
    void max_pooling(FixedPointTensor* ret,
                     BooleanTensor<T>* pos = nullptr) const;
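
    // e.g. with input shape [4, n, m], ret has shape [1, n, m] holding the
    // maxima over the first axis, and pos has shape [4, n, m] with exactly
    // one 1 per column marking where each maximum came from.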

    static void truncate(const FixedPointTensor* op, FixedPointTensor* ret,
                         size_t scaling_factor);
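
    // Rationale: with N fractional bits, a product of two encoded values
    // carries 2N fractional bits, so shares are shifted right by
    // scaling_factor (= N) to restore the encoding, e.g.
    //   (1.5 * 2^N) * (2.0 * 2^N) = 3.0 * 2^(2N)  --truncate-->  3.0 * 2^N.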

private:

    static inline std::shared_ptr<CircuitContext> aby3_ctx() {
        return paddle::mpc::ContextHolder::mpc_ctx();
    }

    static inline std::shared_ptr<TensorAdapterFactory> tensor_factory() {
        return paddle::mpc::ContextHolder::tensor_factory();
    }

    template<typename MulFunc>
    static void mul_trunc(const FixedPointTensor<T, N>* lhs,
                          const FixedPointTensor<T, N>* rhs,
                          FixedPointTensor<T, N>* ret,
                          MulFunc mul_func);
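
    // Presumed contract: mul_func performs the local share product
    // (element-wise or matrix), while mul_trunc reshares the result and
    // applies truncation, letting mul and mat_mul share one code path.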

    // the truncate3 protocol avoids the lost-msb error of plain truncation,
    // at the cost of an acceptable security compromise
    static void truncate3(const FixedPointTensor* op, FixedPointTensor* ret,
                          size_t scaling_factor);

    // reduce last dim
    static void reduce(FixedPointTensor<T, N>* input,
                       FixedPointTensor<T, N>* ret);

    static size_t party() {
        return aby3_ctx()->party();
    }

    static size_t pre_party() {
        return aby3_ctx()->pre_party();
    }

    static size_t next_party() {
        return aby3_ctx()->next_party();
    }

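    // Pass shares around the 3-party ring: send to the previous party,
    // receive from the next. Party 0 receives before sending while the
    // others send first, presumably so the blocking calls cannot deadlock.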
    static void reshare(const TensorAdapter<T>* send_val,
                 TensorAdapter<T>* recv_val) {
        if (party() == 0) {
            aby3_ctx()->network()->template recv(next_party(), *recv_val);
            aby3_ctx()->network()->template send(pre_party(), *send_val);
        } else {
            aby3_ctx()->network()->template send(pre_party(), *send_val);
            aby3_ctx()->network()->template recv(next_party(), *recv_val);
        }
    }

    static void reciprocal(const FixedPointTensor* op, FixedPointTensor* ret,
                                                      size_t iter, double x0);

    static void inverse_square_root(const FixedPointTensor* op,
                                    FixedPointTensor* ret,
                                    size_t iter, double x0);

    TensorAdapter<T>* _share[2];

};

} // namespace aby3

#include "fixedpoint_tensor_imp.h"