tensor_adapter.h
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#include <algorithm>
#include <cstddef>
#include <vector>

namespace aby3 {

template <typename T> class TensorAdapter {
public:
  TensorAdapter() = default;

  virtual ~TensorAdapter() = default;

  virtual T *data() = 0;

  virtual const T *data() const = 0;

  virtual std::vector<size_t> shape() const = 0;

  virtual void reshape(const std::vector<size_t> &shape) = 0;

  virtual size_t numel() const = 0;

  virtual void copy(TensorAdapter *ret) const {
    // TODO: check that ret has the same shape / numel as *this
    std::copy(data(), data() + numel(), ret->data());
  }

  // element-wise op; operands must have the same shape
  virtual void add(const TensorAdapter *rhs, TensorAdapter *ret) const = 0;

  // element-wise op; operands must have the same shape
  virtual void sub(const TensorAdapter *rhs, TensorAdapter *ret) const = 0;

  virtual void negative(TensorAdapter *ret) const = 0;

  // element-wise op; operands must have the same shape
  virtual void mul(const TensorAdapter *rhs, TensorAdapter *ret) const = 0;

  // element-wise op; operands must have the same shape
  virtual void div(const TensorAdapter *rhs, TensorAdapter *ret) const = 0;

  // 2-D matrix multiply; both operands must have rank 2
  virtual void mat_mul(const TensorAdapter *rhs, TensorAdapter *ret) const = 0;
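  // (assumed from standard 2-D matrix multiplication semantics:
  //  lhs.shape = [m, k], rhs.shape = [k, n] -> ret.shape = [m, n])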

  // element-wise op; operands must have the same shape
  virtual void bitwise_xor(const TensorAdapter *rhs,
                           TensorAdapter *ret) const = 0;

  // element-wise op; operands must have the same shape
  virtual void bitwise_and(const TensorAdapter *rhs,
                           TensorAdapter *ret) const = 0;

  // element-wise op; operands must have the same shape
  virtual void bitwise_or(const TensorAdapter *rhs,
                          TensorAdapter *ret) const = 0;

  // element-wise (unary) op; output must have the same shape
  virtual void bitwise_not(TensorAdapter *ret) const = 0;

  virtual void lshift(size_t rhs, TensorAdapter *ret) const = 0;

  virtual void rshift(size_t rhs, TensorAdapter *ret) const = 0;

  virtual void logical_rshift(size_t rhs, TensorAdapter *ret) const = 0;
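  // note: rshift is presumably an arithmetic (sign-preserving) right shift,
  //       while logical_rshift fills the vacated high bits with zeros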

  // when an integer type T is used as a fixed-point number,
  // a stored value val of type T is interpreted as val / 2^scaling_factor()
  virtual size_t scaling_factor() const = 0;

  virtual size_t &scaling_factor() = 0;
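
  // e.g. with T = int64_t and scaling_factor() == 16,
  //      a stored value of (3 << 16) represents the real number 3.0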

  // slice along the first dimension (shape[0])
  // e.g. x.shape = [ 2, 3, 4]
  //      x.slice(1, 2, y)
  //      y.shape = [ 1, 3, 4]
  virtual void slice(size_t begin_idx, size_t end_idx,
                     TensorAdapter *out) const = 0;
};
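
// A minimal illustrative sketch, not part of this interface: it shows how the
// fixed-point convention above composes with the element-wise ops. Multiplying
// two fixed-point values doubles the scale, so the raw product is rescaled by
// an arithmetic right shift of scaling_factor() bits. It assumes both operands
// share one scaling factor and that an implementation permits in-place rshift
// (ret used as both input and output); the library itself may rescale elsewhere.
template <typename T>
void fixed_point_mul_sketch(const TensorAdapter<T> *lhs,
                            const TensorAdapter<T> *rhs,
                            TensorAdapter<T> *ret) {
  // raw integer product carries scale 2^(2 * scaling_factor())
  lhs->mul(rhs, ret);
  // truncate back to scale 2^scaling_factor()
  ret->rshift(lhs->scaling_factor(), ret);
  ret->scaling_factor() = lhs->scaling_factor();
}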
} // namespace aby3