提交 9f408dfb 编写于 作者: F fengjiayi

fix some compile error

上级 696ba1d2
......@@ -14,32 +14,39 @@ limitations under the License. */
#pragma once
#include <memory>
#include <type_traits>
#include <typeinfo>
#include "paddle/framework/ddim.h"
#include "paddle/framework/enforce.h"
#include "paddle/memory/memory.h"
#include "paddle/platform/assert.h"
#include "paddle/platform/place.h"
namespace paddle {
namespace framework {
class Tensor {
using paddle::platform::Place;
public:
template <typename T>
const T* data() const {
PADDLE_ENFORCE(holder_ != nullptr,
"Tensor::data must be called after Tensor::mutable_data");
return static_cast<const T*>(holder->Ptr());
return static_cast<const T*>(holder_->Ptr());
}
template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type>
T* mutable_data(DDim dims, Place place) {
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims, paddle::platform::Place place) {
if (holder_ == nullptr || holder_->Place() != place ||
holder_->Size() < dims.product() * sizeof(T)) {
holder_.reset(new PlaceholderImpl(place, dims.product() * sizeof(T)));
holder_->Size() < product(dims) * sizeof(T)) {
holder_.reset(new PlaceholderImpl<T>(place, product(dims) * sizeof(T)));
}
return static_cast<T*>(holder_->Ptr());
}
template <typename T, // must be POD types
typename = std::enable_if<std::is_pod<T>::value>::type>
typename std::enable_if<std::is_pod<T>::value>::type* = nullptr>
T* mutable_data(DDim dims) {
return mutable_data<T>(dims, paddle::platform::get_place());
}
......@@ -50,24 +57,24 @@ class Tensor {
struct Placeholder {
virtual ~Placeholder() {}
virtual void* Ptr() const = 0;
virtual Place Place() const = 0;
virtual paddle::platform::Place Place() const = 0;
virtual size_t Size() const = 0;
};
template <typename T>
struct PlaceholderImpl : public Placeholder {
PlaceholderImpl(Place pl, size_t size)
PlaceholderImpl(paddle::platform::Place pl, size_t size)
: ptr_(paddle::memory::Alloc(pl, size), paddle::memory::Deleter(pl)),
place_(pl),
size_(size) {}
virtual void* Ptr() const { return static_cast<void*>(ptr_.get()); }
virtual size_t Size() const { return size_; }
virtual Place Place() const { return place_; }
virtual paddle::platform::Place Place() const { return place_; }
std::unique_ptr<T, memory::Deleter> ptr_;
Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
paddle::platform::Place place_; // record the place of ptr_.
size_t size_; // size of the memory block.
};
std::unique_ptr<Placeholder> holder_; // holds the memory block if allocated.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册