From e52df3b125317f99b5876b5a9d173d62d12d6201 Mon Sep 17 00:00:00 2001 From: arlesniak Date: Tue, 28 Jul 2020 04:28:21 +0200 Subject: [PATCH] Added DNNL cache management for DyGraph (#25624) * Added DNNL cache management for DyGraph * move FLAGS_use_mkldnn to more general CMakeLists, getu use of the flag in ClearGradients * missing file * Fixes after review * Bringing back original idea of place for 'use_mkldnn' flag to be accessible from platform nad imperative. * Removed duplicate and added docs * Fixes for CI --- paddle/fluid/framework/executor.cc | 14 +++++--------- paddle/fluid/framework/ir/pass.cc | 8 ++++---- paddle/fluid/framework/naive_executor.cc | 12 ++++-------- paddle/fluid/imperative/CMakeLists.txt | 2 +- paddle/fluid/imperative/layer.cc | 11 +++++++++++ paddle/fluid/platform/CMakeLists.txt | 2 +- paddle/fluid/platform/flags.cc | 10 ++++++++++ paddle/fluid/platform/mkldnn_helper.h | 12 ++++++++++++ 8 files changed, 48 insertions(+), 23 deletions(-) diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc index 68eca6e328d..8e2e1d38a66 100644 --- a/paddle/fluid/framework/executor.cc +++ b/paddle/fluid/framework/executor.cc @@ -37,9 +37,12 @@ limitations under the License. */ #include "paddle/fluid/operators/distributed/distributed.h" #include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif DECLARE_bool(benchmark); -DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); +DECLARE_bool(use_mkldnn); namespace paddle { namespace framework { @@ -83,14 +86,7 @@ Executor::~Executor() { #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // this is needed to have mkl-dnn unit tests working - if (platform::is_cpu_place(place_)) { - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext* dev_ctx = - (platform::MKLDNNDeviceContext*)pool.Get(place_); - dev_ctx->ResetBlobMap(); - platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( - paddle::framework::DataLayout::kNCHW); - } + ClearMKLDNNCache(place_); #endif } diff --git a/paddle/fluid/framework/ir/pass.cc b/paddle/fluid/framework/ir/pass.cc index fb95504d9a5..a5ca13f1ce2 100644 --- a/paddle/fluid/framework/ir/pass.cc +++ b/paddle/fluid/framework/ir/pass.cc @@ -19,6 +19,9 @@ limitations under the License. */ #include "paddle/fluid/framework/ir/graph_helper.h" #include "paddle/fluid/platform/device_context.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace framework { @@ -57,10 +60,7 @@ Graph* Pass::Apply(Graph* graph) const { #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // Passes can change params, tensors, so caching need to be discarded - platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext* dev_ctx = - (platform::MKLDNNDeviceContext*)pool.Get(paddle::platform::CPUPlace()); - dev_ctx->ResetBlobMap(); + ClearMKLDNNCache(paddle::platform::CPUPlace()); #endif return graph; } diff --git a/paddle/fluid/framework/naive_executor.cc b/paddle/fluid/framework/naive_executor.cc index a5de53e9d07..6544c1aa7b0 100644 --- a/paddle/fluid/framework/naive_executor.cc +++ b/paddle/fluid/framework/naive_executor.cc @@ -25,6 +25,9 @@ #include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/variable_helper.h" #include "paddle/fluid/string/pretty_log.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif namespace paddle { namespace framework { @@ -122,14 +125,7 @@ NaiveExecutor::~NaiveExecutor() { #ifdef PADDLE_WITH_MKLDNN // Clear mkl-dnn cache, // this is needed to have mkl-dnn unit tests working - if (platform::is_cpu_place(place_)) { - platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); - platform::MKLDNNDeviceContext *dev_ctx = - (platform::MKLDNNDeviceContext *)pool.Get(place_); - dev_ctx->ResetBlobMap(); - platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( - paddle::framework::DataLayout::kNCHW); - } + ClearMKLDNNCache(place_); #endif } diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt index e0c2934ab32..4d602d5c021 100644 --- a/paddle/fluid/imperative/CMakeLists.txt +++ b/paddle/fluid/imperative/CMakeLists.txt @@ -1,4 +1,4 @@ -cc_library(imperative_flag SRCS flags.cc DEPS gflags) +cc_library(imperative_flag SRCS flags.cc DEPS gflags) cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform) cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry) diff --git a/paddle/fluid/imperative/layer.cc b/paddle/fluid/imperative/layer.cc index 3e682863795..ec76f58d77e 100644 --- a/paddle/fluid/imperative/layer.cc +++ b/paddle/fluid/imperative/layer.cc @@ -28,6 +28,11 @@ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/enforce.h" #include "paddle/fluid/platform/profiler.h" +#ifdef PADDLE_WITH_MKLDNN +#include "paddle/fluid/platform/mkldnn_helper.h" +#endif + +DECLARE_bool(use_mkldnn); namespace paddle { namespace imperative { @@ -192,6 +197,9 @@ void VarBase::ClearGradient() { auto* grad_t = grad_var_->MutableVar()->GetMutable(); if (grad_t->mutable_value()->IsInitialized()) { +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place()); +#endif grad_t->mutable_rows()->clear(); grad_t->mutable_value()->clear(); } @@ -202,6 +210,9 @@ void VarBase::ClearGradient() { auto* dev_ctx = platform::DeviceContextPool::Instance().Get(grad_t->place()); operators::math::set_constant(*dev_ctx, grad_t, 0.0); +#ifdef PADDLE_WITH_MKLDNN + if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place()); +#endif } } } diff --git a/paddle/fluid/platform/CMakeLists.txt b/paddle/fluid/platform/CMakeLists.txt index dcc3a51e72b..5a100c5746e 100644 --- a/paddle/fluid/platform/CMakeLists.txt +++ b/paddle/fluid/platform/CMakeLists.txt @@ -25,7 +25,7 @@ if (WITH_PYTHON) endif(NOT WIN32) endif() -cc_library(flags SRCS flags.cc DEPS gflags) +cc_library(flags SRCS flags.cc DEPS gflags) cc_library(errors SRCS errors.cc DEPS error_codes_proto) cc_test(errors_test SRCS errors_test.cc DEPS errors enforce) diff --git a/paddle/fluid/platform/flags.cc b/paddle/fluid/platform/flags.cc index c2af3d0e982..98bdf1f8c67 100644 --- a/paddle/fluid/platform/flags.cc +++ b/paddle/fluid/platform/flags.cc @@ -473,3 +473,13 @@ DEFINE_double(local_exe_sub_scope_limit, 256.0, // MBytes "each CUDAPlace. If you don't need to limit the memory, " "you should set FLAGS_local_exe_sub_scope_limit=-1. " "The default value is 256 MBytes."); + +/** + * MKLDNN related FLAG + * Name: use_mkldnn + * Since Version: + * Value Range: bool, default=false + * Example: + * Note: + */ +DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run"); diff --git a/paddle/fluid/platform/mkldnn_helper.h b/paddle/fluid/platform/mkldnn_helper.h index 0fcb2367916..c147bdccbe9 100644 --- a/paddle/fluid/platform/mkldnn_helper.h +++ b/paddle/fluid/platform/mkldnn_helper.h @@ -117,6 +117,18 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) { return use_mkldnn && platform::is_cpu_place(ctx.GetPlace()); } +inline void ClearMKLDNNCache(const platform::Place& place) { + // Clear mkl-dnn cache, + if (platform::is_cpu_place(place)) { + platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); + platform::MKLDNNDeviceContext* dev_ctx = + (platform::MKLDNNDeviceContext*)pool.Get(place); + dev_ctx->ResetBlobMap(); + platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout( + paddle::framework::DataLayout::kNCHW); + } +} + template mkldnn::memory::data_type MKLDNNGetDataType() { return mkldnn::memory::data_type::undef; -- GitLab