diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 68eca6e328da9510552f77760aea915c24292a49..8e2e1d38a66d1039519bab312f77bef6604d8ec1 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,9 +37,12 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif DECLARE_bool(benchmark); -DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); +DECLARE_bool(use_mkldnn); namespace paddle { namespace framework { @@ -83,14 +86,7 @@ Executor::~Executor() { #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // this is needed to have mkl-dnn unit tests working - if (platform::is_cpu_place(place_)) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext* dev_ctx = - (platform::MKLDNNDeviceContext*)pool.Get(place_); - dev_ctx->ResetBlobMap(); - platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( - paddle::framework::DataLayout::kNCHW); - } + ClearMKLDNNCache(place_); #endif } diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index fb95504d9a53a13aea69b0f203e18ddab79c6e66..a5ca13f1ce252d2368e2fc765e49d397356660a7 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,6 +19,9 @@ limitations under the License. 
*/ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace framework { @@ -57,10 +60,7 @@ Graph* Pass::Apply(Graph* graph) const { #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // Passes can change params, tensors, so caching need to be discarded - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext* dev_ctx = - (platform::MKLDNNDeviceContext*)pool.Get(paddle::platform::CPUPlace()); - dev_ctx->ResetBlobMap(); + ClearMKLDNNCache(paddle::platform::CPUPlace()); #endif return graph; } diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index a5de53e9d07d562c32885b1495981757f45cb5f9..6544c1aa7b0591efc96638ab34ab559c8b563b09 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -25,6 +25,9 @@ #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/string/pretty_log.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace framework { @@ -122,14 +125,7 @@ NaiveExecutor::~NaiveExecutor() { #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // this is needed to have mkl-dnn unit tests working - if (platform::is_cpu_place(place_)) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext *dev_ctx = - (platform::MKLDNNDeviceContext *)pool.Get(place_); - dev_ctx->ResetBlobMap(); - platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( - paddle::framework::DataLayout::kNCHW); - } + ClearMKLDNNCache(place_); #endif } diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index 
e0c2934ab32bb8135fcecf4577bae0f48bedf0ba..4d602d5c0211e221a99e0e87a3344c5a9c2a0142 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -1,4 +1,4 @@ -cc_library(imperative_flag SRCS flags.cc DEPS gflags) +cc_library(imperative_flag SRCS flags.cc DEPS gflags) cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform) cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 3e682863795724bcd3d521976c8b061b5602c8eb..ec76f58d77ed5dece46c53795b3cccfe8bfbd902 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -28,6 +28,11 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +DECLARE_bool(use_mkldnn); namespace paddle { namespace imperative { @@ -192,6 +197,9 @@ void VarBase::ClearGradient() { auto* grad_t = grad_var_->MutableVar()->GetMutable(); if (grad_t->mutable_value()->IsInitialized()) { +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place()); +#endif grad_t->mutable_rows()->clear(); grad_t->mutable_value()->clear(); } @@ -202,6 +210,9 @@ void VarBase::ClearGradient() { auto* dev_ctx = platform::DeviceContextPool::Instance().Get(grad_t->place()); operators::math::set_constant(*dev_ctx, grad_t, 0.0); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place()); +#endif } } } diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index dcc3a51e72b3ef5ffc29f7db566840e32b5d43e9..5a100c5746e616e860811dd47da27036ea7355d5 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ 
b/paddle/fluid/platform/CMakeLists.txt @@ -25,7 +25,7 @@ if (WITH_PYTHON) endif(NOT WIN32) endif() -cc_library(flags SRCS flags.cc DEPS gflags) +cc_library(flags SRCS flags.cc DEPS gflags) cc_library(errors SRCS errors.cc DEPS error_codes_proto) cc_test(errors_test SRCS errors_test.cc DEPS errors enforce) diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index c2af3d0e982992fc6bec54aa4f4751378d8e0336..98bdf1f8c675da4e3a272945d605563e35016f8d 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -473,3 +473,13 @@ DEFINE_double(local_exe_sub_scope_limit, 256.0, // MBytes "each CUDAPlace. If you don't need to limit the memory, " "you should set FLAGS_local_exe_sub_scope_limit=-1. " "The default value is 256 MBytes."); + +/** + * MKLDNN related FLAG + * Name: use_mkldnn + * Since Version: + * Value Range: bool, default=false + * Example: + * Note: + */ +DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 0fcb23679164079865947b0b0b539ae344732b58..c147bdccbe99e505a8fd8f1ec75c487b00c02067 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -117,6 +117,18 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); } +// Resets the mkl-dnn blob map of the CPU device context and sets the current +// paddle data layout (tls) back to kNCHW; does nothing for non-CPU places. +inline void ClearMKLDNNCache(const platform::Place& place) { + if (platform::is_cpu_place(place)) { + auto* dev_ctx = static_cast<platform::MKLDNNDeviceContext*>( + platform::DeviceContextPool::Instance().Get(place)); + dev_ctx->ResetBlobMap(); + platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( + paddle::framework::DataLayout::kNCHW); + } +} + template mkldnn::memory::data_type MKLDNNGetDataType() { return mkldnn::memory::data_type::undef;