提交 9c7cea81 编写于 作者: T tensor-tang

follow comments, use unique_ptr and remove unused file

上级 880b2e80
...@@ -189,11 +189,11 @@ void MKLDNNDeviceContext::AddElement(const std::string& op_key, ...@@ -189,11 +189,11 @@ void MKLDNNDeviceContext::AddElement(const std::string& op_key,
if (GetElement<T>(op_key)) { if (GetElement<T>(op_key)) {
return; return;
} }
GetElementPool<T>().emplace(op_key, value); GetElementPool<T>().emplace(op_key, std::move(value));
} }
template <typename T> template <typename T>
const T MKLDNNDeviceContext::GetElement(const std::string& op_key) const { const T& MKLDNNDeviceContext::GetElement(const std::string& op_key) const {
auto it = GetElementPool<T>().find(op_key); auto it = GetElementPool<T>().find(op_key);
return it == GetElementPool<T>().end() ? nullptr : it->second; return it == GetElementPool<T>().end() ? nullptr : it->second;
} }
......
...@@ -132,7 +132,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -132,7 +132,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
/* \brief Get existed element: memory, primitive or primitive desc */ /* \brief Get existed element: memory, primitive or primitive desc */
template <typename T> template <typename T>
const T GetElement(const std::string& op_key) const; const T& GetElement(const std::string& op_key) const;
/* \brief Get element pool: memory, primitive or primitive desc pool */ /* \brief Get element pool: memory, primitive or primitive desc pool */
template <typename T> template <typename T>
...@@ -140,7 +140,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -140,7 +140,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
GetElementPool() const; GetElementPool() const;
/* \brief Get the active engine */ /* \brief Get the active engine */
const MKLDNNEnginePtr GetEngine() const { return engine_; } const MKLDNNEngine& engine() const { return *engine_; }
/* \brief Submit primitive to pipeline */ /* \brief Submit primitive to pipeline */
void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); } void Submit(const MKLDNNPrimitivePtr& p) { pipeline_.push_back(*p); }
...@@ -163,7 +163,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -163,7 +163,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::hash<std::string>> std::hash<std::string>>
primitive_desc_pool_; primitive_desc_pool_;
std::vector<MKLDNNPrimitive> pipeline_; std::vector<MKLDNNPrimitive> pipeline_;
std::unique_ptr<MKLDNNStream> stream_; MKLDNNStreamPtr stream_;
MKLDNNEnginePtr engine_; MKLDNNEnginePtr engine_;
bool ready_; bool ready_;
}; };
......
...@@ -25,10 +25,11 @@ using MKLDNNMemory = mkldnn::memory; ...@@ -25,10 +25,11 @@ using MKLDNNMemory = mkldnn::memory;
using MKLDNNPrimitive = mkldnn::primitive; using MKLDNNPrimitive = mkldnn::primitive;
using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>; using MKLDNNPrimitiveDesc = mkldnn::handle<mkldnn_primitive_desc_t>;
typedef std::shared_ptr<MKLDNNEngine> MKLDNNEnginePtr; typedef std::unique_ptr<MKLDNNStream> MKLDNNStreamPtr;
typedef std::shared_ptr<MKLDNNMemory> MKLDNNMemoryPtr; typedef std::unique_ptr<MKLDNNEngine> MKLDNNEnginePtr;
typedef std::shared_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr; typedef std::unique_ptr<MKLDNNMemory> MKLDNNMemoryPtr;
typedef std::shared_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr; typedef std::unique_ptr<MKLDNNPrimitive> MKLDNNPrimitivePtr;
typedef std::unique_ptr<MKLDNNPrimitiveDesc> MKLDNNPrimitiveDescPtr;
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册