提交 9f408dfb 编写于 作者: F fengjiayi

fix some compile error

上级 696ba1d2
...@@ -14,32 +14,39 @@ limitations under the License. */ ...@@ -14,32 +14,39 @@ limitations under the License. */
#pragma once #pragma once
#include <memory>
#include <type_traits>
#include <typeinfo>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/place.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
class Tensor { class Tensor {
using paddle::platform::Place;
public: public:
template <typename T> template <typename T>
const T* data() const { const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr, PADDLE_ENFORCE(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data"); "Tensor::data must be called after Tensor::mutable_data");
return static_cast<const T*>(holder->Ptr()); return static_cast<const T*>(holder_->Ptr());
} }
template <typename T, // must be POD types template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type> typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims, Place place) { T* mutable_data(DDim dims, paddle::platform::Place place) {
if (holder_ == nullptr || holder_->Place() != place || if (holder_ == nullptr || holder_->Place() != place ||
holder_->Size() < dims.product() * sizeof(T)) { holder_->Size() < product(dims) * sizeof(T)) {
holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T))); holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
} }
return static_cast<T*>(holder_->Ptr()); return static_cast<T*>(holder_->Ptr());
} }
template <typename T, // must be POD types template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type> typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims) { T* mutable_data(DDim dims) {
return mutable_data<T>(dims, paddle::platform::get_place()); return mutable_data<T>(dims, paddle::platform::get_place());
} }
...@@ -50,24 +57,24 @@ class Tensor { ...@@ -50,24 +57,24 @@ class Tensor {
struct Placeholder { struct Placeholder {
virtual ~Placeholder() {} virtual ~Placeholder() {}
virtual void* Ptr() const = 0; virtual void* Ptr() const = 0;
virtual Place Place() const = 0; virtual paddle::platform::Place Place() const = 0;
virtual size_t Size() const = 0; virtual size_t Size() const = 0;
}; };
template <typename T> template <typename T>
struct PlaceholderImpl : public Placeholder { struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place pl, size_t size) PlaceholderImpl(paddle::platform::Place pl, size_t size)
: ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)), : ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
place_(pl), place_(pl),
size_(size) {} size_(size) {}
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); } virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t Size() const { return size_; } virtual size_t Size() const { return size_; }
virtual Place Place() const { return place_; } virtual paddle::platform::Place Place() const { return place_; }
std::unique_ptr<T, memory::Deleter> ptr_; std::unique_ptr<T, memory::Deleter> ptr_;
Place place_; // record the place of ptr_. paddle::platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block. size_t size_; // size of the memory block.
}; };
std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated. std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册