提交 4a7f0698 编写于 作者: M Michal Gallus

Add consts to new MKLDNN integration

Also replace memory types from int64_t to size_t
上级 6588d0e0
...@@ -32,7 +32,7 @@ size_t Tensor::memory_size() const { ...@@ -32,7 +32,7 @@ size_t Tensor::memory_size() const {
} }
void* Tensor::mutable_data(platform::Place place, std::type_index type, void* Tensor::mutable_data(platform::Place place, std::type_index type,
int64_t requested_size) { size_t requested_size) {
if (holder_ != nullptr) { if (holder_ != nullptr) {
holder_->set_type(type); holder_->set_type(type);
} }
...@@ -40,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, ...@@ -40,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type,
"When calling this method, the Tensor's numel must be " "When calling this method, the Tensor's numel must be "
"equal or larger than zero. " "equal or larger than zero. "
"Please check Tensor::Resize has been called first."); "Please check Tensor::Resize has been called first.");
int64_t size = requested_size ? requested_size : numel() * SizeOfType(type); size_t size = requested_size ? requested_size : numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */ /* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) || if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) { holder_->size() < size + offset_) {
...@@ -69,7 +69,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type, ...@@ -69,7 +69,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type,
offset_); offset_);
} }
void* Tensor::mutable_data(platform::Place place, int64_t requested_size) { void* Tensor::mutable_data(platform::Place place, size_t requested_size) {
PADDLE_ENFORCE(this->holder_ != nullptr, PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing."); "Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, holder_->type(), requested_size); return mutable_data(place, holder_->type(), requested_size);
......
...@@ -89,12 +89,12 @@ class Tensor { ...@@ -89,12 +89,12 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
T* mutable_data(platform::Place place, int64_t requested_size = 0); T* mutable_data(platform::Place place, size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type, void* mutable_data(platform::Place place, std::type_index type,
int64_t requested_size = 0); size_t requested_size = 0);
void* mutable_data(platform::Place place, int64_t requested_size = 0); void* mutable_data(platform::Place place, size_t requested_size = 0);
/** /**
* @brief Return a pointer to mutable memory block. * @brief Return a pointer to mutable memory block.
...@@ -106,7 +106,7 @@ class Tensor { ...@@ -106,7 +106,7 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
T* mutable_data(DDim dims, platform::Place place, int64_t requested_size = 0); T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0);
/*! Return the dimensions of the memory block. */ /*! Return the dimensions of the memory block. */
const DDim& dims() const; const DDim& dims() const;
......
...@@ -47,14 +47,14 @@ inline T* Tensor::data() { ...@@ -47,14 +47,14 @@ inline T* Tensor::data() {
template <typename T> template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place, inline T* Tensor::mutable_data(DDim dims, platform::Place place,
int64_t requested_size) { size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims); Resize(dims);
return mutable_data<T>(place, requested_size); return mutable_data<T>(place, requested_size);
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(platform::Place place, int64_t requested_size) { inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(mutable_data(place, typeid(T), requested_size)); return reinterpret_cast<T*>(mutable_data(place, typeid(T), requested_size));
} }
......
...@@ -53,15 +53,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -53,15 +53,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
key_ += "-BWD"; key_ += "-BWD";
} }
size_t GetDstMemorySize() { size_t GetDstMemorySize() const {
return conv_pd_->dst_primitive_desc().get_size(); return conv_pd_->dst_primitive_desc().get_size();
} }
size_t GetDiffWeightsMemorySize() { size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size(); return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
} }
size_t GetDiffSourceMemorySize() { size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size(); return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
} }
...@@ -491,7 +491,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -491,7 +491,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromWeightsPrimitive( handler.AcquireDiffDstMemoryFromWeightsPrimitive(
user_diff_dst_memory_p, pipeline); user_diff_dst_memory_p, pipeline);
size_t size = handler.GetDiffWeightsMemorySize(); const size_t size = handler.GetDiffWeightsMemorySize();
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size); filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_weights_memory_p = auto diff_weights_memory_p =
...@@ -516,7 +516,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -516,7 +516,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
pipeline); pipeline);
size_t size = handler.GetDiffSourceMemorySize(); const size_t size = handler.GetDiffSourceMemorySize();
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size); input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册