提交 72652845 编写于 作者: T tensor-tang

add MKLDNNDeviceContext

上级 cbe25b33
...@@ -176,5 +176,69 @@ cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; } ...@@ -176,5 +176,69 @@ cudnnHandle_t CUDNNDeviceContext::cudnn_handle() const { return cudnn_handle_; }
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: CPUDeviceContext(place), ready_(false) {
stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
engine_.reset(new mkldnn::engine(mkldnn::engine::cpu, 0));
}
template <typename T>
void MKLDNNDeviceContext::AddElement(const std::string& op_key,
const T& value) {
if (GetElement<T>(op_key)) {
return;
}
GetElementPool<T>().emplace(op_key, value);
}
template <typename T>
const T MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
auto it = GetElementPool<T>().find(op_key);
return it == GetElementPool<T>().end() ? nullptr : it->second;
}
template <>
const std::unordered_map<const std::string, const MKLDNNMemoryPtr,
std::hash<std::string>>&
MKLDNNDeviceContext::GetElementPool<MKLDNNMemoryPtr>() const {
return memory_pool_;
}
template <>
const std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
std::hash<std::string>>&
MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitivePtr>() const {
return primitive_pool_;
}
template <>
const std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
std::hash<std::string>>&
MKLDNNDeviceContext::GetElementPool<MKLDNNPrimitiveDescPtr>() const {
return primitive_desc_pool_;
}
void MKLDNNDeviceContext::Execute(bool block) {
if (pipeline_.empty()) {
return;
}
ResetStream();
stream_->submit(pipeline_).wait(block);
ready_ = false;
pipeline_.clear();
}
void MKLDNNDeviceContext::ResetStream() {
if (ready_) {
return;
}
// TODO(TJ): change me when mkldnn have specific method to reset this state
stream_.reset(new mkldnn::stream(mkldnn::stream::kind::eager));
ready_ = true;
}
#endif
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -21,6 +21,10 @@ limitations under the License. */ ...@@ -21,6 +21,10 @@ limitations under the License. */
#define EIGEN_USE_GPU #define EIGEN_USE_GPU
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
#include "mkldnn.hpp"
#endif
#include "paddle/platform/enforce.h" #include "paddle/platform/enforce.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
...@@ -117,6 +121,65 @@ class CUDNNDeviceContext : public CUDADeviceContext { ...@@ -117,6 +121,65 @@ class CUDNNDeviceContext : public CUDADeviceContext {
#endif #endif
#ifdef PADDLE_WITH_MKLDNN
using MKLDNNStream = mkldnn::stream;
using MKLDNNEngine = mkldnn::engine;
using MKLDNNMemory = mkldnn::memory;
using MKLDNNPrimitive = mkldnn::primitive;
using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
typedef std::shared_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::shared_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::shared_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::shared_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
class MKLDNNDeviceContext : public CPUDeviceContext {
public:
explicit MKLDNNDeviceContext(CPUPlace place);
virtual ~MKLDNNDeviceContext();
/* \brief Add new element: memory, primitive or primitive desc */
template <typename T>
void AddElement(const std::string& op_key, const T& value);
/* \brief Get existed element: memory, primitive or primitive desc */
template <typename T>
const T GetElement(const std::string& op_key) const;
/* \brief Get element pool: memory, primitive or primitive desc pool */
template <typename T>
const std::unordered_map<const std::string, const T, std::hash<std::string>>&
GetElementPool() const;
/* \brief Get the active engine */
const MKLDNNEnginePtr GetEngine() const { return engine_; }
/* \brief Submit primitive to pipeline */
void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
/*! \brief Execute all submitted primitives in pipeline */
void Execute(bool block = true);
protected:
/*! \brief Reset the stream to prepare next exectue */
void ResetStream();
private:
std::unordered_map<const std::string, const MKLDNNMemoryPtr,
std::hash<std::string>>
memory_pool_;
std::unordered_map<const std::string, const MKLDNNPrimitivePtr,
std::hash<std::string>>
primitive_pool_;
std::unordered_map<const std::string, const MKLDNNPrimitiveDescPtr,
std::hash<std::string>>
primitive_desc_pool_;
std::vector<MKLDNNPrimitive> pipeline_;
std::unique_ptr<MKLDNNStream> stream_;
MKLDNNEnginePtr engine_;
bool ready_;
};
#endif
/*! \brief device context pool singleton */ /*! \brief device context pool singleton */
class DeviceContextPool { class DeviceContextPool {
public: public:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册