ext_tensor.h 4.6 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include <memory>
#include <vector>
19

20 21
// gpuStream_t names the native stream type of the GPU runtime this build
// targets (CUDA or HIP), so declarations further down, such as
// Tensor::stream(), can use one vendor-neutral spelling.
#ifdef PADDLE_WITH_CUDA
#include <cuda_runtime.h>
using gpuStream_t = cudaStream_t;
#endif

#ifdef PADDLE_WITH_HIP
#include <hip/hip_runtime.h>
using gpuStream_t = hipStream_t;
#endif

#include "ext_dll_decl.h"  // NOLINT
#include "ext_dtype.h"     // NOLINT
#include "ext_place.h"     // NOLINT

34 35 36 37
namespace paddle {
// Forward declaration only: Tensor grants this utility class access to its
// private members via `friend class framework::CustomTensorUtils;`.
namespace framework { class CustomTensorUtils; }  // namespace framework
38

39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56
/// \brief Holds an untyped stream handle plus a flag recording whether
/// SetStream() has ever been called.
class StreamWrapper {
 public:
  /// Starts with no stream: GetStream() is nullptr, IsStreamSet() is false.
  StreamWrapper() = default;

  /// \brief Store `stream` and mark the stream as set.
  /// \param stream The handle to store; nullptr is accepted and still
  ///               marks the stream as set.
  void SetStream(void* stream) {
    is_stream_set_ = true;
    stream_ = stream;
  }

  /// \return The most recently stored handle, or nullptr if none was set.
  void* GetStream() const { return stream_; }

  /// \return True once SetStream() has been called at least once.
  bool IsStreamSet() const { return is_stream_set_; }

 private:
  // Stored as void* rather than a GPU runtime stream type; callers are
  // presumably expected to cast it back (e.g. to gpuStream_t) - confirm.
  void* stream_{nullptr};
  bool is_stream_set_{false};
};

57
class PD_DLL_DECL Tensor {
58
 public:
59
  /// \brief Construct a Tensor on target Place for CustomOp.
60 61
  /// Generally it's only used for user to create Tensor.
  explicit Tensor(const PlaceType& place);
62 63 64
  /// \brief Construct a Tensor on target Place with shape for CustomOp.
  /// Generally it's only used for user to create Tensor.
  Tensor(const PlaceType& place, const std::vector<int64_t>& shape);
65 66 67
  /// \brief Reset the shape of the tensor.
  /// Generally it's only used for the input tensor.
  /// Reshape must be called before calling
68
  /// mutable_data() or copy_to(const PlaceType& place)
69
  /// \param shape The shape to set.
C
Chen Weihang 已提交
70
  void reshape(const std::vector<int64_t>& shape);
71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95

  /// \brief Get the memory pointer in CPU or GPU with
  /// specific data type.
  /// Please Reshape the tensor first before call this.
  /// It's usually used to get input data pointer.
  /// \param place The place of the tensor this will
  /// override the original place of current tensor.
  template <typename T>
  T* mutable_data(const PlaceType& place);

  /// \brief Get the memory pointer in CPU or GPU with
  /// specific data type. Please Reshape the tensor
  /// first before call this.It's usually used to get
  /// input data pointer.
  template <typename T>
  T* mutable_data();

  /// \brief Get the memory pointer directly.
  /// It's usually used to get the output data pointer.
  /// \return The tensor data buffer pointer.
  template <typename T>
  T* data() const;

  /// \brief Copy the host memory to tensor data.
  /// It's usually used to set the input tensor data.
96 97
  /// \param PlaceType of target place, of which
  /// the tensor will copy to.
98
  template <typename T>
99
  Tensor copy_to(const PlaceType& place) const;
100

H
Hao Lin 已提交
101 102 103 104 105 106 107 108 109 110 111
  /// \brief Return a sub-tensor of the given tensor.
  /// It is usually used to extract a sub-tensor (which supports
  /// modifying the data of the original tensor) to perform further
  /// operations.
  /// \param begin_idx The index of the start row (inclusive) to slice.
  ///                  The index number begins from 0.
  /// \param end_idx  The index of the end row (exclusive) to slice.
  ///                 The index number begins from begin_idx + 1.
  /// \return The sliced tensor.
  Tensor slice(const int64_t begin_idx, const int64_t end_idx) const;

112
  /// \brief Return the shape of the Tensor.
C
Chen Weihang 已提交
113
  std::vector<int64_t> shape() const;
114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130

  /// \brief Return the data type of the tensor.
  /// It's usually used to get the output tensor data type.
  /// \return The data type of the tensor.
  DataType type() const;

  /// \brief Get the size of current tensor.
  /// Use this method to get the size of tensor
  /// \return int64_t.
  int64_t size() const;

  /// \brief Get the place of current tensor.
  /// Use this method to get the place of tensor
  /// \return Place.
  const PlaceType& place() const;

  /// \brief Cast datatype from one to another
131
  Tensor cast(const DataType& target_type) const;
132

133 134 135
  /// \brief Check Tensor is initialized
  bool is_initialized() const;

136
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
137
  /// \bref Get current stream of Tensor
138
  gpuStream_t stream() const;
139 140
#endif

141 142 143 144
 private:
  friend class framework::CustomTensorUtils;
  mutable std::shared_ptr<void> tensor_;
  mutable PlaceType place_;
145
  StreamWrapper stream_;
146 147 148
};

}  // namespace paddle