Unverified commit e52df3b1, authored by arlesniak, committed by GitHub

Added DNNL cache management for DyGraph (#25624)

* Added DNNL cache management for DyGraph

* move FLAGS_use_mkldnn to a more general CMakeLists, make use of the flag in ClearGradients

* missing file

* Fixes after review

* Bringing back the original idea of the place for the 'use_mkldnn' flag, so it is accessible from both platform and imperative.

* Removed duplicate and added docs

* Fixes for CI
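
In short, this commit factors the previously duplicated cache-clearing block out of Executor, NaiveExecutor and Pass into a single ClearMKLDNNCache() helper in mkldnn_helper.h, moves the use_mkldnn flag definition into platform/flags.cc, and gates the new dygraph call sites in VarBase::ClearGradient on that flag. A minimal sketch of the resulting call pattern, assuming a Paddle build where the helper and flag from the diff below are available; the function ReleaseMKLDNNState itself is hypothetical and not part of the commit:

#include "gflags/gflags.h"
#include "paddle/fluid/platform/place.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"  // provides ClearMKLDNNCache()
#endif

// Defined once in paddle/fluid/platform/flags.cc (see the flags.cc hunk below).
DECLARE_bool(use_mkldnn);

// Hypothetical call site illustrating the pattern introduced here: the
// static-graph executors call ClearMKLDNNCache() unconditionally in their
// destructors, while the dygraph path (VarBase::ClearGradient) also checks
// FLAGS_use_mkldnn before clearing.
void ReleaseMKLDNNState(const paddle::platform::Place& place) {
#ifdef PADDLE_WITH_MKLDNN
  if (FLAGS_use_mkldnn) {
    paddle::platform::ClearMKLDNNCache(place);
  }
#endif
}
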
Parent 650d7223
@@ -37,9 +37,12 @@ limitations under the License. */
 #include "paddle/fluid/operators/distributed/distributed.h"
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif

 DECLARE_bool(benchmark);
-DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
+DECLARE_bool(use_mkldnn);

 namespace paddle {
 namespace framework {
@@ -83,14 +86,7 @@ Executor::~Executor() {
 #ifdef PADDLE_WITH_MKLDNN
   // Clear mkl-dnn cache,
   // this is needed to have mkl-dnn unit tests working
-  if (platform::is_cpu_place(place_)) {
-    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-    platform::MKLDNNDeviceContext* dev_ctx =
-        (platform::MKLDNNDeviceContext*)pool.Get(place_);
-    dev_ctx->ResetBlobMap();
-    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
-        paddle::framework::DataLayout::kNCHW);
-  }
+  ClearMKLDNNCache(place_);
 #endif
 }
@@ -19,6 +19,9 @@ limitations under the License. */
 #include "paddle/fluid/framework/ir/graph_helper.h"
 #include "paddle/fluid/platform/device_context.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif

 namespace paddle {
 namespace framework {
@@ -57,10 +60,7 @@ Graph* Pass::Apply(Graph* graph) const {
 #ifdef PADDLE_WITH_MKLDNN
   // Clear mkl-dnn cache,
   // Passes can change params, tensors, so caching need to be discarded
-  platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
-  platform::MKLDNNDeviceContext* dev_ctx =
-      (platform::MKLDNNDeviceContext*)pool.Get(paddle::platform::CPUPlace());
-  dev_ctx->ResetBlobMap();
+  ClearMKLDNNCache(paddle::platform::CPUPlace());
 #endif
   return graph;
 }
@@ -25,6 +25,9 @@
 #include "paddle/fluid/framework/reader.h"
 #include "paddle/fluid/framework/variable_helper.h"
 #include "paddle/fluid/string/pretty_log.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif

 namespace paddle {
 namespace framework {
@@ -122,14 +125,7 @@ NaiveExecutor::~NaiveExecutor() {
 #ifdef PADDLE_WITH_MKLDNN
   // Clear mkl-dnn cache,
   // this is needed to have mkl-dnn unit tests working
-  if (platform::is_cpu_place(place_)) {
-    platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
-    platform::MKLDNNDeviceContext *dev_ctx =
-        (platform::MKLDNNDeviceContext *)pool.Get(place_);
-    dev_ctx->ResetBlobMap();
-    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
-        paddle::framework::DataLayout::kNCHW);
-  }
+  ClearMKLDNNCache(place_);
 #endif
 }
 cc_library(imperative_flag SRCS flags.cc DEPS gflags)
 cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
@@ -28,6 +28,11 @@
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/enforce.h"
 #include "paddle/fluid/platform/profiler.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
+DECLARE_bool(use_mkldnn);

 namespace paddle {
 namespace imperative {
@@ -192,6 +197,9 @@ void VarBase::ClearGradient() {
       auto* grad_t =
           grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
       if (grad_t->mutable_value()->IsInitialized()) {
+#ifdef PADDLE_WITH_MKLDNN
+        if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
+#endif
         grad_t->mutable_rows()->clear();
         grad_t->mutable_value()->clear();
       }
@@ -202,6 +210,9 @@ void VarBase::ClearGradient() {
         auto* dev_ctx =
             platform::DeviceContextPool::Instance().Get(grad_t->place());
         operators::math::set_constant(*dev_ctx, grad_t, 0.0);
+#ifdef PADDLE_WITH_MKLDNN
+        if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
+#endif
       }
     }
   }
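Presumably the reason the dygraph hooks live in VarBase::ClearGradient is that gradients are cleared every training step: clearing the oneDNN blob cache there, and only when FLAGS_use_mkldnn is set, keeps cached primitives for stale gradient tensors from accumulating across iterations, instead of waiting for an executor destructor as in static-graph mode.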
@@ -25,7 +25,7 @@ if (WITH_PYTHON)
   endif(NOT WIN32)
 endif()

 cc_library(flags SRCS flags.cc DEPS gflags)
 cc_library(errors SRCS errors.cc DEPS error_codes_proto)
 cc_test(errors_test SRCS errors_test.cc DEPS errors enforce)
@@ -473,3 +473,13 @@ DEFINE_double(local_exe_sub_scope_limit, 256.0,  // MBytes
     "each CUDAPlace. If you don't need to limit the memory, "
     "you should set FLAGS_local_exe_sub_scope_limit=-1. "
     "The default value is 256 MBytes.");
+
+/**
+ * MKLDNN related FLAG
+ * Name: use_mkldnn
+ * Since Version:
+ * Value Range: bool, default=false
+ * Example:
+ * Note:
+ */
+DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
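
With the definition centralized here, other translation units only declare the flag (as executor.cc and the imperative layer code do above), so both platform and imperative code read the same process-wide value. A standalone sketch of this standard gflags pattern, using a hypothetical demo binary rather than Paddle's actual wiring:

#include "gflags/gflags.h"

// In Paddle the definition lives in paddle/fluid/platform/flags.cc and
// consumers use DECLARE_bool(use_mkldnn); it is defined here only so the
// sketch links on its own.
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");

int main(int argc, char* argv[]) {
  // Standard gflags parsing, e.g. `./flag_demo --use_mkldnn=true`.
  gflags::ParseCommandLineFlags(&argc, &argv, /*remove_flags=*/true);
  if (FLAGS_use_mkldnn) {
    // MKL-DNN kernels and the cache-clearing paths shown above become active.
  }
  return 0;
}
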
@@ -117,6 +117,18 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
   return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
 }

+inline void ClearMKLDNNCache(const platform::Place& place) {
+  // Clear mkl-dnn cache,
+  if (platform::is_cpu_place(place)) {
+    platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
+    platform::MKLDNNDeviceContext* dev_ctx =
+        (platform::MKLDNNDeviceContext*)pool.Get(place);
+    dev_ctx->ResetBlobMap();
+    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
+        paddle::framework::DataLayout::kNCHW);
+  }
+}
+
 template <typename Type>
 mkldnn::memory::data_type MKLDNNGetDataType() {
   return mkldnn::memory::data_type::undef;