提交 4a7f0698 编写于 作者: M Michal Gallus

Add consts to new MKLDNN integration

Also replace memory types from int64_t to size_t
上级 6588d0e0
......@@ -32,7 +32,7 @@ size_t Tensor::memory_size() const {
}
void* Tensor::mutable_data(platform::Place place, std::type_index type,
int64_t requested_size) {
size_t requested_size) {
if (holder_ != nullptr) {
holder_->set_type(type);
}
......@@ -40,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type,
"When calling this method, the Tensor's numel must be "
"equal or larger than zero. "
"Please check Tensor::Resize has been called first.");
int64_t size = requested_size ? requested_size : numel() * SizeOfType(type);
size_t size = requested_size ? requested_size : numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) {
......@@ -69,7 +69,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type,
offset_);
}
void* Tensor::mutable_data(platform::Place place, int64_t requested_size) {
void* Tensor::mutable_data(platform::Place place, size_t requested_size) {
PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, holder_->type(), requested_size);
......
......@@ -89,12 +89,12 @@ class Tensor {
* @note If not exist, then allocation.
*/
template <typename T>
T* mutable_data(platform::Place place, int64_t requested_size = 0);
T* mutable_data(platform::Place place, size_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type,
int64_t requested_size = 0);
size_t requested_size = 0);
void* mutable_data(platform::Place place, int64_t requested_size = 0);
void* mutable_data(platform::Place place, size_t requested_size = 0);
/**
* @brief Return a pointer to mutable memory block.
......@@ -106,7 +106,7 @@ class Tensor {
* @note If not exist, then allocation.
*/
template <typename T>
T* mutable_data(DDim dims, platform::Place place, int64_t requested_size = 0);
T* mutable_data(DDim dims, platform::Place place, size_t requested_size = 0);
/*! Return the dimensions of the memory block. */
const DDim& dims() const;
......
......@@ -47,14 +47,14 @@ inline T* Tensor::data() {
template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place,
int64_t requested_size) {
size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims);
return mutable_data<T>(place, requested_size);
}
template <typename T>
inline T* Tensor::mutable_data(platform::Place place, int64_t requested_size) {
inline T* Tensor::mutable_data(platform::Place place, size_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(mutable_data(place, typeid(T), requested_size));
}
......
......@@ -53,15 +53,15 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
key_ += "-BWD";
}
size_t GetDstMemorySize() {
size_t GetDstMemorySize() const {
return conv_pd_->dst_primitive_desc().get_size();
}
size_t GetDiffWeightsMemorySize() {
size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
}
size_t GetDiffSourceMemorySize() {
size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
}
......@@ -491,7 +491,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromWeightsPrimitive(
user_diff_dst_memory_p, pipeline);
size_t size = handler.GetDiffWeightsMemorySize();
const size_t size = handler.GetDiffWeightsMemorySize();
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_weights_memory_p =
......@@ -516,7 +516,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
pipeline);
size_t size = handler.GetDiffSourceMemorySize();
const size_t size = handler.GetDiffSourceMemorySize();
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册