提交 1cb42505 编写于 作者: Z Zhuo Peng 提交者: TensorFlower Gardener

Made the Tensor constructor that takes a TensorBuffer public.

PiperOrigin-RevId: 262940594
上级 5daa70bf
......@@ -54,6 +54,44 @@ Status MaybeMoveSliceToElement(Tensor* parent, Tensor* element, int64 index);
} // namespace batch_util
/// @ingroup core
/// Interface to access the raw ref-counted data buffer.
class TensorBuffer : public core::RefCounted {
public:
explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
~TensorBuffer() override {}
/// \brief data() points to a memory region of size() bytes.
///
/// NOTE(mrry): The `data()` method is not virtual for performance reasons.
/// It can be called multiple times when the contents of a `Tensor` are
/// accessed, and so making it non-virtual allows the body to be inlined.
void* data() const { return data_; }
/// \brief Size (in bytes) of the buffer.
virtual size_t size() const = 0;
/// \brief If this TensorBuffer is sub-buffer of another TensorBuffer,
/// returns that TensorBuffer. Otherwise, returns this.
virtual TensorBuffer* root_buffer() = 0;
/// \brief Fills metadata about the allocation into the proto.
virtual void FillAllocationDescription(
AllocationDescription* proto) const = 0;
/// \brief Helper method to reinterpret the buffer as an array of `T`.
template <typename T>
T* base() const {
return reinterpret_cast<T*>(data());
}
/// \brief Whether this TensorBuffer owns the underlying memory.
virtual bool OwnsMemory() const { return true; }
private:
void* const data_;
};
/// Represents an n-dimensional array of values.
class Tensor {
public:
......@@ -108,6 +146,11 @@ class Tensor {
Tensor(Allocator* a, DataType type, const TensorShape& shape,
const AllocationAttributes& allocation_attr);
/// \brief Creates a tensor with the input datatype, shape and buf.
///
/// Acquires a ref on buf that belongs to this Tensor.
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
/// \brief Creates an empty Tensor of the given data type.
///
/// Like Tensor(), returns a 1-dimensional, 0-element Tensor with
......@@ -606,20 +649,16 @@ class Tensor {
TensorShape shape_;
TensorBuffer* buf_;
friend class DMAHelper;
friend class TensorCApi;
friend class TensorCord; // For access to buf_
friend class TensorReference; // For access to buf_
friend class VariableOp; // For access to set_shape
friend class AutoReloadVariableOp; // For access to set_shape
friend class TensorTestHelper; // For access to set_shape
friend class CastOpBase; // For access to set_dtype;
friend class DMAHelper; // For access to buf_.
friend class TensorCApi; // For access to buf_.
friend class TensorReference; // For access to buf_.
friend class VariableOp; // For access to set_shape.
friend class AutoReloadVariableOp; // For access to set_shape.
friend class TensorTestHelper; // For access to set_shape.
friend class CastOpBase; // For access to set_dtype.
friend class OpKernelContext; // For access to RefCountIsOne().
friend class ScopedAllocator; // For access to buf_.
friend class XlaTensor; // For access to RefCountIsOne().
friend class XlaTensorBuffer; // For access to the private constructor taking
// the buffer
friend class Var;
template <typename Device, typename T>
friend class AssignVariableOp; // For access to RefCountIsOne().
template <typename Device, typename T>
......@@ -636,11 +675,6 @@ class Tensor {
Tensor* parent, Tensor* element,
int64 index); // For access to RefCountIsOne().
// Creates a tensor with the input datatype, shape and buf.
//
// Acquires a ref on buf that belongs to this Tensor.
Tensor(DataType type, const TensorShape& shape, TensorBuffer* buf);
bool CanUseDMA() const;
// Only needed by variable op to set the shape of an uninitialized
......@@ -673,40 +707,6 @@ class Tensor {
// START_SKIP_DOXYGEN
// Interface to access the raw ref-counted data buffer.
class TensorBuffer : public core::RefCounted {
public:
explicit TensorBuffer(void* data_ptr) : data_(data_ptr) {}
~TensorBuffer() override {}
// data() points to a memory region of size() bytes.
//
// NOTE(mrry): The `data()` method is not virtual for performance reasons.
// It can be called multiple times when the contents of a `Tensor` are
// accessed, and so making it non-virtual allows the body to be inlined.
void* data() const { return data_; }
virtual size_t size() const = 0;
// If this TensorBuffer is sub-buffer of another TensorBuffer,
// returns that TensorBuffer. Otherwise, returns this.
virtual TensorBuffer* root_buffer() = 0;
// Fill metadata about the allocation into the proto.
virtual void FillAllocationDescription(
AllocationDescription* proto) const = 0;
template <typename T>
T* base() const {
return reinterpret_cast<T*>(data());
}
// Whether this TensorBuffer owns the underlying memory.
virtual bool OwnsMemory() const { return true; }
private:
void* const data_;
};
template <typename T>
T* Tensor::base() const {
return buf_ == nullptr ? nullptr : buf_->base<T>();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册