From 780c8d969e0d2d220df19a672c141ff7c44f53d2 Mon Sep 17 00:00:00 2001 From: tensor-tang Date: Wed, 23 Aug 2017 17:03:16 +0800 Subject: [PATCH] make downSpatial work, and remove hasSpatial_ --- paddle/gserver/layers/MKLDNNFcLayer.cpp | 4 ---- paddle/gserver/layers/MKLDNNFcLayer.h | 5 +---- paddle/math/MKLDNNMatrix.cpp | 7 ++++++- 3 files changed, 7 insertions(+), 9 deletions(-) diff --git a/paddle/gserver/layers/MKLDNNFcLayer.cpp b/paddle/gserver/layers/MKLDNNFcLayer.cpp index a3291e6a8fb..a5555c4618a 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.cpp +++ b/paddle/gserver/layers/MKLDNNFcLayer.cpp @@ -111,10 +111,6 @@ void MKLDNNFcLayer::reshape() { if (iw_ == 0) { iw_ = 1; } - hasSpatial_ = true; - if (ih_ == 1 && iw_ == 1) { - hasSpatial_ = false; - } CHECK_EQ(iLayerSize_, inputLayers_[0]->getSize()); ic_ = iLayerSize_ / (ih_ * iw_); CHECK_EQ(size_t(ic_ * ih_ * iw_), iLayerSize_) << "not divisible"; diff --git a/paddle/gserver/layers/MKLDNNFcLayer.h b/paddle/gserver/layers/MKLDNNFcLayer.h index 7954852a23f..e2657a8d5e9 100644 --- a/paddle/gserver/layers/MKLDNNFcLayer.h +++ b/paddle/gserver/layers/MKLDNNFcLayer.h @@ -32,16 +32,13 @@ protected: // if has already init the weight bool hasInitedWgt_; - // if input layer has image size info (ih>1 && iw>1) - bool hasSpatial_; - // fc weight and bias std::unique_ptr weight_; std::unique_ptr biases_; public: explicit MKLDNNFcLayer(const LayerConfig& config) - : MKLDNNLayer(config), hasInitedWgt_(false), hasSpatial_(true) {} + : MKLDNNLayer(config), hasInitedWgt_(false) {} ~MKLDNNFcLayer() {} diff --git a/paddle/math/MKLDNNMatrix.cpp b/paddle/math/MKLDNNMatrix.cpp index 24d54ec0f73..94df9c15508 100644 --- a/paddle/math/MKLDNNMatrix.cpp +++ b/paddle/math/MKLDNNMatrix.cpp @@ -85,7 +85,12 @@ void MKLDNNMatrix::downSpatial() { memory::desc md = memory::desc(dstDims, getDtype(), dstFmt); memory::primitive_desc pd = memory::primitive_desc(md, getEngine()); void* data = getData(); - memory(pd, data); + mkldnn_primitive_t result; + mkldnn::error::wrap_c_api( + mkldnn_primitive_create(&result, pd.get(), nullptr, nullptr), + "could not create a memory primitive"); + reset(result); + set_data_handle(data); } } // namespace paddle -- GitLab