diff --git a/paddle/platform/device_context.cc b/paddle/platform/device_context.cc
index ea07f2e002cb76d09a11f7a5305c2d45b780e7bd..9d72569f51ac4b1ae722fcced8e0fada849974f0 100644
--- a/paddle/platform/device_context.cc
+++ b/paddle/platform/device_context.cc
@@ -176,5 +176,82 @@ cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
 
 #endif
 
+#ifdef PADDLE_WITH_MKLDNN
+// Creates an eager stream and a CPU engine (index 0). ready_ starts false, so
+// the first Execute() replaces this stream through ResetStream().
+MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
+    : CPUDeviceContext(place), ready_(false) {
+  stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+  engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
+}
+
+// Stores value under op_key unless GetElement<T>() already finds a non-null
+// entry there.
+// NOTE(review): GetElementPool<T>() returns a const reference, so emplace()
+// will be rejected once this template is instantiated -- confirm how the pool
+// is meant to be mutated.
+template <typename T>
+void MKLDNNDeviceContext::AddElement(const std::string& op_key,
+                                     const T& value) {
+  if (GetElement<T>(op_key)) {
+    return;
+  }
+  GetElementPool<T>().emplace(op_key, value);
+}
+
+// Returns the element cached under op_key, or nullptr when there is none.
+template <typename T>
+const T MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
+  auto it = GetElementPool<T>().find(op_key);
+  return it == GetElementPool<T>().end() ? nullptr : it->second;
+}
+
+// One GetElementPool<T>() specialization per cached element type.
+template <>
+const std::unordered_map<const std::string, const MKLDNNMemoryPtr,
+                         std::hash<std::string>>&
+MKLDNNDeviceContext::GetElementPool<MKLDNNMemoryPtr>() const {
+  return memory_pool_;
+}
+
+template <>
+const std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
+                         std::hash<std::string>>&
+MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitivePtr>() const {
+  return primitive_pool_;
+}
+
+template <>
+const std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
+                         std::hash<std::string>>&
+MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitiveDescPtr>() const {
+  return primitive_desc_pool_;
+}
+
+// Submits the queued primitives to the stream and clears the pipeline; does
+// nothing when the pipeline is empty.
+void MKLDNNDeviceContext::Execute(bool block) {
+  if (pipeline_.empty()) {
+    return;
+  }
+  ResetStream();
+  stream_->submit(pipeline_).wait(block);
+  // Force ResetStream() to build a new stream before the next submit.
+  ready_ = false;
+  pipeline_.clear();
+}
+
+// Replaces stream_ with a new eager stream unless it is already marked ready.
+void MKLDNNDeviceContext::ResetStream() {
+  if (ready_) {
+    return;
+  }
+  // TODO(TJ): change me when mkldnn have specific method to reset this state
+  stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
+  ready_ = true;
+}
+
+#endif
+
 }  // namespace platform
 }  // namespace paddle
diff --git a/paddle/platform/device_context.h b/paddle/platform/device_context.h
index 2b366e6383d23e2d31a194edd04412892a8311eb..faabb8575e057ad3f6fb9b1223e649be25b7ec6a 100644
--- a/paddle/platform/device_context.h
+++ b/paddle/platform/device_context.h
@@ -21,6 +21,10 @@ limitations under the License. */
 #define EIGEN_USE_GPU
 #endif
 
+#ifdef PADDLE_WITH_MKLDNN
+#include "mkldnn.hpp"
+#endif
+
 #include "paddle/platform/enforce.h"
 #include "paddle/platform/place.h"
 #include "unsupported/Eigen/CXX11/Tensor"
@@ -117,6 +121,67 @@ class CUDNNDeviceContext : public CUDADeviceContext {
 
 #endif
 
+#ifdef PADDLE_WITH_MKLDNN
+using MKLDNNStream = mkldnn::stream;
+using MKLDNNEngine = mkldnn::engine;
+using MKLDNNMemory = mkldnn::memory;
+using MKLDNNPrimitive = mkldnn::primitive;
+using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
+
+typedef std::shared_ptr<MKLDNNEngine> MKLDNNEnginePtr;
+typedef std::shared_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
+typedef std::shared_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
+typedef std::shared_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
+/*! \brief CPU device context that keeps keyed pools of MKLDNN memories,
+ *  primitives and primitive descs, and runs queued primitives on a stream. */
+class MKLDNNDeviceContext : public CPUDeviceContext {
+ public:
+  explicit MKLDNNDeviceContext(CPUPlace place);
+  virtual ~MKLDNNDeviceContext() = default;
+
+  /* \brief Add new element: memory, primitive or primitive desc */
+  template <typename T>
+  void AddElement(const std::string& op_key, const T& value);
+
+  /* \brief Get existing element: memory, primitive or primitive desc */
+  template <typename T>
+  const T GetElement(const std::string& op_key) const;
+
+  /* \brief Get element pool: memory, primitive or primitive desc pool */
+  template <typename T>
+  const std::unordered_map<const std::string, const T, std::hash<std::string>>&
+  GetElementPool() const;
+
+  /* \brief Get the active engine */
+  const MKLDNNEnginePtr GetEngine() const { return engine_; }
+
+  /* \brief Submit primitive to pipeline */
+  void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
+
+  /*! \brief Execute all submitted primitives in pipeline */
+  void Execute(bool block = true);
+
+ protected:
+  /*! \brief Reset the stream to prepare next execute */
+  void ResetStream();
+
+ private:
+  std::unordered_map<const std::string, const MKLDNNMemoryPtr,
+                     std::hash<std::string>>
+      memory_pool_;
+  std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
+                     std::hash<std::string>>
+      primitive_pool_;
+  std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
+                     std::hash<std::string>>
+      primitive_desc_pool_;
+  std::vector<MKLDNNPrimitive> pipeline_;
+  std::unique_ptr<MKLDNNStream> stream_;
+  MKLDNNEnginePtr engine_;
+  bool ready_;
+};
+#endif
+
 /*! \brief device context pool singleton */
 class DeviceContextPool {
  public: