tensor_adapter.h 3.3 KB
Newer Older
J
jingqinghe 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <vector>

namespace aby3 {

J
jhjiangcs 已提交
22 23
// Abstract interface over a tensor whose elements have type T.
//
// It exposes raw element access, shape queries, and a set of arithmetic,
// bitwise and shift operations. Every operation is const on the receiver and
// delivers its result through a caller-supplied output adapter (`ret`/`out`).
// Only copy() has a default implementation; everything else must be provided
// by a concrete subclass.
template <typename T>
class TensorAdapter {
public:

    TensorAdapter() = default;

    virtual ~TensorAdapter() = default;

    // Pointer to the first element of the underlying buffer (mutable access).
    virtual T* data() = 0;

    // Pointer to the first element of the underlying buffer (read-only access).
    virtual const T* data() const = 0;

    // Current dimensions of the tensor.
    virtual std::vector<size_t> shape() const = 0;

    // Changes the dimensions of the tensor to `shape`.
    virtual void reshape(const std::vector<size_t>& shape) = 0;

    // Total number of elements; copy() treats data() as numel() contiguous
    // values of type T.
    virtual size_t numel() const = 0;

    // Copies the numel() elements of this tensor into ret's buffer.
    //
    // ret must be non-null and its buffer must be able to hold at least
    // numel() elements; neither condition is verified here.
    virtual void copy(TensorAdapter* ret) const {
        // TODO: check shape equals
        const T* src = data();
        T* dst = ret->data();
        // std::copy requires that the destination does not start inside the
        // source range. Copying a buffer onto itself (self-copy, or two
        // adapters sharing one buffer) is therefore handled as a no-op.
        // Partially overlapping buffers are still the caller's responsibility.
        if (dst == src) {
            return;
        }
        std::copy(src, src + numel(), dst);
    }

    // Element-wise addition; both operands must have identical dimensions.
    virtual void add(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Element-wise subtraction; both operands must have identical dimensions.
    virtual void sub(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Negation of this tensor, written to ret.
    virtual void negative(TensorAdapter* ret) const = 0;

    // Element-wise multiplication; both operands must have identical dimensions.
    virtual void mul(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Element-wise division; both operands must have identical dimensions.
    virtual void div(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // 2-D matrix multiplication; both operands must have rank 2.
    virtual void mat_mul(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Element-wise bitwise XOR; both operands must have identical dimensions.
    virtual void bitwise_xor(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Element-wise bitwise AND; both operands must have identical dimensions.
    virtual void bitwise_and(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Element-wise bitwise OR; both operands must have identical dimensions.
    virtual void bitwise_or(const TensorAdapter* rhs, TensorAdapter* ret) const = 0;

    // Element-wise bitwise NOT, written to ret.
    virtual void bitwise_not(TensorAdapter* ret) const = 0;

    // Left shift by `rhs`, written to ret.
    virtual void lshift(size_t rhs, TensorAdapter* ret) const = 0;

    // Right shift by `rhs`, written to ret.
    // NOTE(review): presumably the sign-preserving counterpart of
    // logical_rshift() -- confirm against the concrete implementations.
    virtual void rshift(size_t rhs, TensorAdapter* ret) const = 0;

    // Logical right shift by `rhs`, written to ret.
    virtual void logical_rshift(size_t rhs, TensorAdapter* ret) const = 0;

    // Fixed-point scaling factor. When an integer type T is used as a
    // fixed-point number, a stored value `val` is interpreted as
    // val / 2 ^ scaling_factor().
    virtual size_t scaling_factor() const = 0;

    // Mutable reference to the scaling factor, allowing callers to set it.
    virtual size_t& scaling_factor() = 0;

    // Slices along the first dimension (shape[0]).
    // e.g. x.shape = [ 2, 3, 4]
    //      x.slice(1, 2, y)
    //      y.shape = [ 1, 3, 4]
    virtual void slice(size_t begin_idx, size_t end_idx, TensorAdapter* out) const = 0;

};
J
jhjiangcs 已提交
92 93 94 95 96 97 98

// Overwrites every one of input->numel() elements, starting at
// input->data(), with the value assign_num.
template<typename T>
inline void assign_to_tensor(TensorAdapter<T>* input, T assign_num) {
    T* const first = input->data();
    T* const last = first + input->numel();
    std::fill(first, last, assign_num);
}

J
jingqinghe 已提交
99
} // namespace aby3