未验证 提交 e52df3b1 编写于 作者: A arlesniak 提交者: GitHub

Added DNNL cache management for DyGraph (#25624)

* Added DNNL cache management for DyGraph

* move FLAGS_use_mkldnn to more general CMakeLists, getu use of the flag in ClearGradients

* missing file

* Fixes after review

* Bringing back original idea of place for 'use_mkldnn' flag to be accessible from platform nad imperative.

* Removed duplicate and added docs

* Fixes for CI
上级 650d7223
......@@ -37,9 +37,12 @@ limitations under the License. */
#include "paddle/fluid/operators/distributed/distributed.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(benchmark);
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace framework {
......@@ -83,14 +86,7 @@ Executor::~Executor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
if (platform::is_cpu_place(place_)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
ClearMKLDNNCache(place_);
#endif
}
......
......@@ -19,6 +19,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/platform/device_context.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace framework {
......@@ -57,10 +60,7 @@ Graph* Pass::Apply(Graph* graph) const {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// Passes can change params, tensors, so caching need to be discarded
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(paddle::platform::CPUPlace());
dev_ctx->ResetBlobMap();
ClearMKLDNNCache(paddle::platform::CPUPlace());
#endif
return graph;
}
......
......@@ -25,6 +25,9 @@
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/variable_helper.h"
#include "paddle/fluid/string/pretty_log.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
namespace paddle {
namespace framework {
......@@ -122,14 +125,7 @@ NaiveExecutor::~NaiveExecutor() {
#ifdef PADDLE_WITH_MKLDNN
// Clear mkl-dnn cache,
// this is needed to have mkl-dnn unit tests working
if (platform::is_cpu_place(place_)) {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext *dev_ctx =
(platform::MKLDNNDeviceContext *)pool.Get(place_);
dev_ctx->ResetBlobMap();
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
ClearMKLDNNCache(place_);
#endif
}
......
cc_library(imperative_flag SRCS flags.cc DEPS gflags)
cc_library(imperative_flag SRCS flags.cc DEPS gflags)
cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
......
......@@ -28,6 +28,11 @@
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/profiler.h"
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h"
#endif
DECLARE_bool(use_mkldnn);
namespace paddle {
namespace imperative {
......@@ -192,6 +197,9 @@ void VarBase::ClearGradient() {
auto* grad_t =
grad_var_->MutableVar()->GetMutable<framework::SelectedRows>();
if (grad_t->mutable_value()->IsInitialized()) {
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
#endif
grad_t->mutable_rows()->clear();
grad_t->mutable_value()->clear();
}
......@@ -202,6 +210,9 @@ void VarBase::ClearGradient() {
auto* dev_ctx =
platform::DeviceContextPool::Instance().Get(grad_t->place());
operators::math::set_constant(*dev_ctx, grad_t, 0.0);
#ifdef PADDLE_WITH_MKLDNN
if (FLAGS_use_mkldnn) ClearMKLDNNCache(grad_t->place());
#endif
}
}
}
......
......@@ -25,7 +25,7 @@ if (WITH_PYTHON)
endif(NOT WIN32)
endif()
cc_library(flags SRCS flags.cc DEPS gflags)
cc_library(flags SRCS flags.cc DEPS gflags)
cc_library(errors SRCS errors.cc DEPS error_codes_proto)
cc_test(errors_test SRCS errors_test.cc DEPS errors enforce)
......
......@@ -473,3 +473,13 @@ DEFINE_double(local_exe_sub_scope_limit, 256.0, // MBytes
"each CUDAPlace. If you don't need to limit the memory, "
"you should set FLAGS_local_exe_sub_scope_limit=-1. "
"The default value is 256 MBytes.");
/**
* MKLDNN related FLAG
* Name: use_mkldnn
* Since Version:
* Value Range: bool, default=false
* Example:
* Note:
*/
DEFINE_bool(use_mkldnn, false, "Use MKLDNN to run");
......@@ -117,6 +117,18 @@ inline bool CanMKLDNNBeUsed(const framework::ExecutionContext& ctx) {
return use_mkldnn && platform::is_cpu_place(ctx.GetPlace());
}
inline void ClearMKLDNNCache(const platform::Place& place) {
// Clear mkl-dnn cache,
if (platform::is_cpu_place(place)) {
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->ResetBlobMap();
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
}
}
template <typename Type>
mkldnn::memory::data_type MKLDNNGetDataType() {
return mkldnn::memory::data_type::undef;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册