提交 df5bc787 编写于 作者: D dzhwinter

"fix tensor shared_ptr"

上级 a46f3fce
......@@ -5,7 +5,6 @@
#include <string.h>
#include <memory>
#include "paddle/math/MemoryHandle.h"
#include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h"
......@@ -15,17 +14,16 @@ namespace optimizer {
template <class T>
class TensorT {
public:
TensorT(size_t size)
: TensorT(std::make_shared<CpuMemoryHandle>(size * sizeof(float)), size) {
TensorT(size_t size) : height_(1), width_(size) {
data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
data_ = data_ptr_.get();
}
TensorT(CpuMemHandlePtr handle, size_t size)
: height_(1),
width_(size),
data_(reinterpret_cast<T*>(handle->getBuf())) {}
TensorT(T* data, size_t size) : height_(1), width_(size), data_(data) {}
TensorT(T* data, size_t size)
: height_(1), width_(size), data_ptr_(nullptr), data_(data) {}
TensorT(T* data, size_t h, size_t w) : height_(h), width_(w), data_(data) {}
TensorT(T* data, size_t h, size_t w)
: height_(h), width_(w), data_ptr_(nullptr), data_(data) {}
virtual ~TensorT() {}
......@@ -45,6 +43,7 @@ public:
protected:
size_t height_;
size_t width_;
std::shared_ptr<T> data_ptr_;
T* data_;
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册