提交 df5bc787 编写于 作者: D dzhwinter

"fix tensor shared_ptr"

上级 a46f3fce
...@@ -5,7 +5,6 @@ ...@@ -5,7 +5,6 @@
#include <string.h> #include <string.h>
#include <memory> #include <memory>
#include "paddle/math/MemoryHandle.h"
#include "paddle/utils/Common.h" #include "paddle/utils/Common.h"
#include "paddle/utils/Logging.h" #include "paddle/utils/Logging.h"
...@@ -15,17 +14,16 @@ namespace optimizer { ...@@ -15,17 +14,16 @@ namespace optimizer {
template <class T> template <class T>
class TensorT { class TensorT {
public: public:
TensorT(size_t size) TensorT(size_t size) : height_(1), width_(size) {
: TensorT(std::make_shared<CpuMemoryHandle>(size * sizeof(float)), size) { data_ptr_ = std::shared_ptr<T>(new T[size], std::default_delete<T[]>());
data_ = data_ptr_.get();
} }
TensorT(CpuMemHandlePtr handle, size_t size)
: height_(1),
width_(size),
data_(reinterpret_cast<T*>(handle->getBuf())) {}
TensorT(T* data, size_t size) : height_(1), width_(size), data_(data) {} TensorT(T* data, size_t size)
: height_(1), width_(size), data_ptr_(nullptr), data_(data) {}
TensorT(T* data, size_t h, size_t w) : height_(h), width_(w), data_(data) {} TensorT(T* data, size_t h, size_t w)
: height_(h), width_(w), data_ptr_(nullptr), data_(data) {}
virtual ~TensorT() {} virtual ~TensorT() {}
...@@ -45,6 +43,7 @@ public: ...@@ -45,6 +43,7 @@ public:
protected: protected:
size_t height_; size_t height_;
size_t width_; size_t width_;
std::shared_ptr<T> data_ptr_;
T* data_; T* data_;
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册