提交 6588d0e0 编写于 作者: M Michal Gallus

Update MKLDNN to 0.15, fix conv integration

上级 1c3e66bb
...@@ -54,7 +54,7 @@ ExternalProject_Add( ...@@ -54,7 +54,7 @@ ExternalProject_Add(
${EXTERNAL_PROJECT_LOG_ARGS} ${EXTERNAL_PROJECT_LOG_ARGS}
DEPENDS ${MKLDNN_DEPENDS} DEPENDS ${MKLDNN_DEPENDS}
GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git" GIT_REPOSITORY "https://github.com/01org/mkl-dnn.git"
GIT_TAG "a29d8487a63afca3d5b8c5bbdbb473cf8ccc6e51" GIT_TAG "64e03a1939e0d526aa8e9f2e3f7dc0ad8d372944"
PREFIX ${MKLDNN_SOURCES_DIR} PREFIX ${MKLDNN_SOURCES_DIR}
UPDATE_COMMAND "" UPDATE_COMMAND ""
CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER} CMAKE_ARGS -DCMAKE_CXX_COMPILER=${CMAKE_CXX_COMPILER}
......
...@@ -31,7 +31,8 @@ size_t Tensor::memory_size() const { ...@@ -31,7 +31,8 @@ size_t Tensor::memory_size() const {
return holder_ == nullptr ? 0UL : holder_->size() - offset_; return holder_ == nullptr ? 0UL : holder_->size() - offset_;
} }
void* Tensor::mutable_data(platform::Place place, std::type_index type) { void* Tensor::mutable_data(platform::Place place, std::type_index type,
int64_t requested_size) {
if (holder_ != nullptr) { if (holder_ != nullptr) {
holder_->set_type(type); holder_->set_type(type);
} }
...@@ -39,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) { ...@@ -39,7 +40,7 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) {
"When calling this method, the Tensor's numel must be " "When calling this method, the Tensor's numel must be "
"equal or larger than zero. " "equal or larger than zero. "
"Please check Tensor::Resize has been called first."); "Please check Tensor::Resize has been called first.");
int64_t size = numel() * SizeOfType(type); int64_t size = requested_size ? requested_size : numel() * SizeOfType(type);
/* some versions of boost::variant don't have operator!= */ /* some versions of boost::variant don't have operator!= */
if (holder_ == nullptr || !(holder_->place() == place) || if (holder_ == nullptr || !(holder_->place() == place) ||
holder_->size() < size + offset_) { holder_->size() < size + offset_) {
...@@ -68,10 +69,10 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) { ...@@ -68,10 +69,10 @@ void* Tensor::mutable_data(platform::Place place, std::type_index type) {
offset_); offset_);
} }
void* Tensor::mutable_data(platform::Place place) { void* Tensor::mutable_data(platform::Place place, int64_t requested_size) {
PADDLE_ENFORCE(this->holder_ != nullptr, PADDLE_ENFORCE(this->holder_ != nullptr,
"Cannot invoke mutable data if current hold nothing."); "Cannot invoke mutable data if current hold nothing.");
return mutable_data(place, holder_->type()); return mutable_data(place, holder_->type(), requested_size);
} }
Tensor& Tensor::ShareDataWith(const Tensor& src) { Tensor& Tensor::ShareDataWith(const Tensor& src) {
......
...@@ -89,22 +89,24 @@ class Tensor { ...@@ -89,22 +89,24 @@ class Tensor {
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
T* mutable_data(platform::Place place); T* mutable_data(platform::Place place, int64_t requested_size = 0);
void* mutable_data(platform::Place place, std::type_index type); void* mutable_data(platform::Place place, std::type_index type,
int64_t requested_size = 0);
void* mutable_data(platform::Place place); void* mutable_data(platform::Place place, int64_t requested_size = 0);
/** /**
* @brief Return a pointer to mutable memory block. * @brief Return a pointer to mutable memory block.
* *
* @param[in] dims The dimensions of the memory block. * @param[in] dims The dimensions of the memory block.
* @param[in] place The place of the memory block. * @param[in] place The place of the memory block.
* @param[in] requested_size The size of the block in bytes.
* *
* @note If not exist, then allocation. * @note If not exist, then allocation.
*/ */
template <typename T> template <typename T>
T* mutable_data(DDim dims, platform::Place place); T* mutable_data(DDim dims, platform::Place place, int64_t requested_size = 0);
/*! Return the dimensions of the memory block. */ /*! Return the dimensions of the memory block. */
const DDim& dims() const; const DDim& dims() const;
......
...@@ -46,16 +46,17 @@ inline T* Tensor::data() { ...@@ -46,16 +46,17 @@ inline T* Tensor::data() {
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(DDim dims, platform::Place place) { inline T* Tensor::mutable_data(DDim dims, platform::Place place,
int64_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
Resize(dims); Resize(dims);
return mutable_data<T>(place); return mutable_data<T>(place, requested_size);
} }
template <typename T> template <typename T>
inline T* Tensor::mutable_data(platform::Place place) { inline T* Tensor::mutable_data(platform::Place place, int64_t requested_size) {
static_assert(std::is_pod<T>::value, "T must be POD"); static_assert(std::is_pod<T>::value, "T must be POD");
return reinterpret_cast<T*>(mutable_data(place, typeid(T))); return reinterpret_cast<T*>(mutable_data(place, typeid(T), requested_size));
} }
inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) { inline Tensor ReshapeToMatrix(const Tensor& src, int num_col_dims) {
......
...@@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -53,6 +53,18 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
key_ += "-BWD"; key_ += "-BWD";
} }
size_t GetDstMemorySize() {
return conv_pd_->dst_primitive_desc().get_size();
}
size_t GetDiffWeightsMemorySize() {
return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
}
size_t GetDiffSourceMemorySize() {
return conv_bwd_data_pd_->diff_src_primitive_desc().get_size();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive( std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p, const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
...@@ -251,7 +263,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -251,7 +263,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>(); const T* filter_data = filter->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
PADDLE_ENFORCE(input->dims().size() == 4, PADDLE_ENFORCE(input->dims().size() == 4,
"Input must be with 4 dimensions, i.e. NCHW"); "Input must be with 4 dimensions, i.e. NCHW");
...@@ -306,6 +317,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -306,6 +317,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto user_weights_memory_p = handler.AcquireWeightsMemory( auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data)); user_weights_md, to_void_cast<T>(filter_data));
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto src_memory_p = auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline); handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
...@@ -393,13 +406,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -393,13 +406,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
T* input_grad_data = nullptr; T* input_grad_data = nullptr;
T* filter_grad_data = nullptr; T* filter_grad_data = nullptr;
if (input_grad) {
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace());
}
if (filter_grad) {
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace());
}
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz = std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims()); paddle::framework::vectorize2int(filter->dims());
...@@ -485,6 +491,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -485,6 +491,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromWeightsPrimitive( handler.AcquireDiffDstMemoryFromWeightsPrimitive(
user_diff_dst_memory_p, pipeline); user_diff_dst_memory_p, pipeline);
size_t size = handler.GetDiffWeightsMemorySize();
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_weights_memory_p = auto diff_weights_memory_p =
handler.AcquireDiffWeightsMemoryFromWeightsPrimitive( handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
reinterpret_cast<void*>(filter_grad_data)); reinterpret_cast<void*>(filter_grad_data));
...@@ -507,6 +516,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -507,6 +516,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p, handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
pipeline); pipeline);
size_t size = handler.GetDiffSourceMemorySize();
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive( auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
reinterpret_cast<void*>(input_grad_data)); reinterpret_cast<void*>(input_grad_data));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册