Unverified commit e9b4d0be, authored by YuanRisheng, committed by GitHub

[Phi]Improve the mechanism for mkldnn kernel in PHI (#43941)

* adapt mkldnn kernel in PHI

* fix ci compile bugs

* fix compile bugs

* fix compile bugs

* fix compile bugs

* fix compile bugs

* delete comment

* fix compile bugs in windows-inference

* delete code for coverage

* modify code by review

* modify code by review

* add todo

* fix compile bugs

* fix compile bugs

* fix compile bugs

* fix unittest bugs
Parent 1bc47c84
@@ -103,6 +103,9 @@ function(kernel_declare TARGET_LIST)
elseif(${kernel_path} MATCHES "./kps\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, KPS, ALL_LAYOUT);\n")
elseif(${kernel_path} MATCHES "./onednn\/")
file(APPEND ${kernel_declare_file}
"PD_DECLARE_KERNEL(${kernel_name}, OneDNN, ALL_LAYOUT);\n")
else()
# deal with device independent kernels; for now we use CPU temporarily
file(APPEND ${kernel_declare_file}
......
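With this branch, any PHI kernel source that lives under an onednn/ directory is declared against the OneDNN backend. For illustration, assuming the log_softmax OneDNN kernel added later in this commit compiles from paddle/phi/kernels/onednn/log_softmax_kernel.cc (the exact path is an assumption), the line appended to the declaration file would be:

PD_DECLARE_KERNEL(log_softmax, OneDNN, ALL_LAYOUT);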
@@ -1276,24 +1276,32 @@ bool OperatorWithKernel::SupportNPU() const {
bool OperatorWithKernel::SupportsMKLDNN(
const proto::VarType::Type data_type) const {
-  auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
-  if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
-    VLOG(6) << "Warning: " << type_
-            << " don't find its MKLDNN Kernel in Fluid "
-               "Registered Kernels. And We don't "
-               "search its kernels in phi lib, "
-               "SupportsMKLDNN() return false.";
-    return false;
+  auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
+      phi::TransToPhiKernelName(type_));
+  auto has_phi_kernel =
+      std::any_of(phi_kernels.begin(),
+                  phi_kernels.end(),
+                  [](phi::KernelKeyMap::const_reference kern_pair) {
+                    return kern_pair.first.backend() == phi::Backend::ONEDNN;
+                  });
+  if (has_phi_kernel) {
+    return true;
+  } else {
+    auto op_kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
+    if (op_kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
+      return false;
+    } else {
+      auto& op_kernels = op_kernel_iter->second;
+      return std::any_of(
+          op_kernels.begin(),
+          op_kernels.end(),
+          [data_type](OpKernelMap::const_reference kern_pair) {
+            return platform::is_cpu_place(kern_pair.first.place_) &&
+                   kern_pair.first.library_type_ == LibraryType::kMKLDNN &&
+                   kern_pair.first.data_type_ == data_type;
+          });
+    }
  }
-  auto& op_kernels = op_kernel_iter->second;
-  return std::any_of(op_kernels.begin(),
-                     op_kernels.end(),
-                     [data_type](OpKernelMap::const_reference kern_pair) {
-                       return platform::is_cpu_place(kern_pair.first.place_) &&
-                              kern_pair.first.library_type_ ==
-                                  LibraryType::kMKLDNN &&
-                              kern_pair.first.data_type_ == data_type;
-                     });
}
bool OperatorWithKernel::SupportsKernelType(
......
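In effect, SupportsMKLDNN() now consults the PHI registry before falling back to Fluid's. A condensed sketch of the new lookup order (the op name "log_softmax" is only an example):

auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
    phi::TransToPhiKernelName("log_softmax"));
bool onednn_in_phi =
    std::any_of(phi_kernels.begin(),
                phi_kernels.end(),
                [](phi::KernelKeyMap::const_reference kv) {
                  return kv.first.backend() == phi::Backend::ONEDNN;
                });
// Only when PHI has no ONEDNN kernel does the old Fluid OpKernelMap search
// (kMKLDNN library type + CPU place + matching data type) still run.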
@@ -66,7 +66,7 @@ OpKernelType TransPhiKernelKeyToOpKernelType(const phi::KernelKey& kernel_key) {
platform::Place place = phi::TransToPhiPlace(kernel_key.backend(), false);
DataLayout data_layout = kernel_key.layout();
LibraryType library_type = LibraryType::kPlain;
-  if (kernel_key.backend() == phi::Backend::MKLDNN) {
+  if (kernel_key.backend() == phi::Backend::ONEDNN) {
library_type = LibraryType::kMKLDNN;
} else if (kernel_key.backend() == phi::Backend::GPUDNN) {
library_type = LibraryType::kCUDNN;
@@ -87,7 +87,7 @@ phi::KernelKey TransOpKernelTypeToPhiKernelKey(
backend = phi::Backend::GPUDNN;
break;
case LibraryType::kMKLDNN:
-      backend = phi::Backend::MKLDNN;
+      backend = phi::Backend::ONEDNN;
break;
case LibraryType::kKP:
backend = phi::Backend::KPS;
......
@@ -32,7 +32,7 @@ TEST(PhiUtils, TransPhiKernelKeyToOpKernelType) {
#ifdef PADDLE_WITH_MKLDNN
phi::KernelKey kernel_key_mkldnn(
-      phi::Backend::MKLDNN, phi::DataLayout::NCHW, phi::DataType::FLOAT32);
+      phi::Backend::ONEDNN, phi::DataLayout::NCHW, phi::DataType::FLOAT32);
op_kernel_type =
paddle::framework::TransPhiKernelKeyToOpKernelType(kernel_key_mkldnn);
ASSERT_EQ(op_kernel_type.data_type_, paddle::framework::proto::VarType::FP32);
@@ -76,7 +76,7 @@ TEST(PhiUtils, TransOpKernelTypeToPhiKernelKey) {
paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn);
ASSERT_EQ(kernel_key_mkldnn.dtype(), phi::DataType::FLOAT32);
ASSERT_EQ(kernel_key_mkldnn.layout(), phi::DataLayout::MKLDNN);
-  ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::MKLDNN);
+  ASSERT_EQ(kernel_key_mkldnn.backend(), phi::Backend::ONEDNN);
#endif
#ifdef PADDLE_WITH_CUDA
......
@@ -20,13 +20,6 @@ namespace phi {
class DenseTensor;
} // namespace phi
-namespace paddle {
-namespace framework {} // namespace framework
-namespace platform {
-class MKLDNNDeviceContext;
-} // namespace platform
-} // namespace paddle
namespace paddle {
namespace operators {
......
@@ -19,13 +19,6 @@ namespace phi {
class DenseTensor;
} // namespace phi
-namespace paddle {
-namespace framework {} // namespace framework
-namespace platform {
-class MKLDNNDeviceContext;
-} // namespace platform
-} // namespace paddle
namespace paddle {
namespace operators {
......
@@ -21,13 +21,6 @@ namespace phi {
class DenseTensor;
} // namespace phi
-namespace paddle {
-namespace framework {} // namespace framework
-namespace platform {
-class MKLDNNDeviceContext;
-} // namespace platform
-} // namespace paddle
namespace paddle {
namespace operators {
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
using framework::Tensor;
template <typename T>
class LogSoftmaxMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward> {
public:
LogSoftmaxMKLDNNHandler(const dnnl::engine mkldnn_engine,
platform::Place cpu_place,
const Tensor* x,
const int axis)
: platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
mkldnn_engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, x->mem_desc(), axis);
}
};
template <typename T>
class LogSoftmaxMKLDNNKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* x = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
axis = axis >= 0 ? axis : x->dims().size() + axis;
LogSoftmaxMKLDNNHandler<T> handler(mkldnn_engine, ctx.GetPlace(), x, axis);
auto src_memory_p = handler.AcquireSrcMemory(x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto logsoftmax_p = handler.AcquireForwardPrimitive();
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
logsoftmax_p->execute(
astream,
{{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_KERNEL(log_softmax,
MKLDNN,
::paddle::platform::CPUPlace,
ops::LogSoftmaxMKLDNNKernel<float>,
ops::LogSoftmaxMKLDNNKernel<paddle::platform::bfloat16>);
@@ -21,13 +21,6 @@ namespace phi {
class DenseTensor;
} // namespace phi
-namespace paddle {
-namespace framework {} // namespace framework
-namespace platform {
-class MKLDNNDeviceContext;
-} // namespace platform
-} // namespace paddle
namespace paddle {
namespace operators {
......
@@ -31,13 +31,6 @@ namespace phi {
class DenseTensor;
} // namespace phi
-namespace paddle {
-namespace framework {} // namespace framework
-namespace platform {
-class MKLDNNDeviceContext;
-} // namespace platform
-} // namespace paddle
namespace paddle {
namespace operators {
......
@@ -195,6 +195,7 @@ cc_library(
# memcpy depends on device_context; here we add deps individually to
# avoid cyclic dependencies
cc_library(
device_context
SRCS device_context.cc
@@ -219,12 +220,17 @@ cc_library(
${XPU_CTX_DEPS}
${MLU_CTX_DEPS}
eigen3
-  cpu_context
generator)
if(WITH_XPU)
target_link_libraries(device_context xpu_context xpu_resource_pool)
endif()
+if(WITH_MKLDNN)
+  target_link_libraries(device_context onednn_context)
+endif()
+target_link_libraries(device_context cpu_context)
cc_library(
collective_helper
SRCS collective_helper.cc gen_comm_id_helper.cc
......
@@ -753,275 +753,6 @@ Eigen::DefaultDevice* CUDAPinnedDeviceContext::eigen_device() const {
const Place& CUDAPinnedDeviceContext::GetPlace() const { return place_; }
#endif
#ifdef PADDLE_WITH_MKLDNN
MKLDNNDeviceContext::MKLDNNDeviceContext(CPUPlace place)
: phi::CPUContext(place), p_blobmap_() {
p_blobmap_.reset(new BlobMap());
p_exec_items_.reset(new ExecShape());
p_mutex_.reset(new std::mutex());
}
MKLDNNDeviceContextThreadLocals::Body::Body()
: cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
cur_mkldnn_session_id = kMKLDNNSessionID_Default;
cur_input_shape_str = "";
cur_input_shape_cache_capacity = 1;
cur_paddle_data_layout = paddle::framework::DataLayout::kNCHW;
}
// When Thread finish we clear oneDNN cache
// This is needed when we have one executor used by many threads
// e.g. test_analyzer_detect. Thread ID is not part of caching key
// (for naive executor) so we need to clear cache when one thread finish
// and other is to start inference
// TODO(jczaja): Ideally it would be good to clear only part of cache
// related to thread that is to be terminated
MKLDNNDeviceContextThreadLocals::Body::~Body() {
auto cpu_place = paddle::platform::CPUPlace();
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(cpu_place);
dev_ctx->ResetBlobMap(exec_ptr_);
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_mkldnn_session_id(
size_t sid) {
cur_mkldnn_session_id = sid;
}
size_t MKLDNNDeviceContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
return cur_mkldnn_session_id;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_str(
std::string input_shape_str) {
cur_input_shape_str = input_shape_str;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
int input_shape_cache_capacity) {
cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
void MKLDNNDeviceContextThreadLocals::Body::set_cur_paddle_data_layout(
framework::DataLayout dl) {
cur_paddle_data_layout = dl;
}
framework::DataLayout
MKLDNNDeviceContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}
void MKLDNNDeviceContextThreadLocals::Body::log_lib_version(void) {
if (!said_once) {
said_once = true;
auto dv = dnnl::version();
LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
<< dv->patch;
}
}
const dnnl::engine& MKLDNNDeviceContextThreadLocals::Body::get_engine(void) {
return cur_engine;
}
dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
return cur_stream;
}
void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
VLOG(4) << tls().get_curr_exec() << " " << ptr;
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (block_next_cache_clearing_ == 0) {
VLOG(3) << "Clearing DNNL cache.";
// If no specific executor pointer then clear
// everything. For executor pointer then clear only
// objects allocated when using given executor
if (ptr == nullptr) {
p_blobmap_->clear();
} else {
// Iterate through all shapes and release
// for each shape and active executor all entries
// of this executor
for (auto& s : *p_exec_items_) {
for (auto& v : (*s.second)[ptr]) {
(v.first)->erase(v.second);
}
s.second->erase(ptr);
}
}
// Reset paddle layout to NCHW
VLOG(3) << "Resetting Paddle data layout to NCHW.";
platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
paddle::framework::DataLayout::kNCHW);
} else {
--block_next_cache_clearing_;
VLOG(3) << "Prevented Clearing DNNL cache. Updated "
"block_next_cache_clearing_ : "
<< block_next_cache_clearing_;
PADDLE_ENFORCE_GE(block_next_cache_clearing_,
0,
platform::errors::InvalidArgument(
"Cache clearing mark should be non-negative "
". But received %d.",
block_next_cache_clearing_));
}
}
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const {
p_exec_items_->erase(p_exec_items_->begin());
}
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
KeyBlob::iterator it) const {
// Take current input shape from TLS
// Take current executor addess from TLS
// and for this executor's items add the one defined with arguments
auto key_it = p_exec_items_
->insert(std::make_pair(tls().cur_input_shape_str,
std::make_shared<ExecMap>()))
.first;
(*key_it->second)[tls().get_curr_exec()].push_back(std::make_pair(pblob, it));
VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
<< " curr exec size: "
<< (*key_it->second)[tls().get_curr_exec()].size() << "\n";
}
void MKLDNNDeviceContext::BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
++block_next_cache_clearing_;
VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
"block_next_cache_clearing_ : "
<< block_next_cache_clearing_;
}
size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(tls().cur_mkldnn_session_id);
if (map_it == pMap->end()) {
PADDLE_THROW(platform::errors::NotFound(
"MKLDNNDeviceContext don't find cur_mkldnn_session_id: %d.",
tls().cur_mkldnn_session_id));
}
return map_it->second->size();
}
void MKLDNNDeviceContext::SetBlob(const std::string& name,
BlobPtr_t<void> data) const {
BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = tls().get_cur_mkldnn_session_id();
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(sid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
sBlob = std::make_shared<ShapeBlob>();
(*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else {
sBlob = map_it->second;
}
// Find KeyBlob for current input shape
auto key_it = sBlob->find(tls().cur_input_shape_str);
if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((static_cast<size_t>(sid) ==
MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
sBlob->size() &&
(sBlob->size() >=
static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
RemoveShapeEntriesWithExecutor();
}
pBlob = std::make_shared<KeyBlob>();
(*sBlob)[tls().cur_input_shape_str] = pBlob;
} else {
pBlob = key_it->second;
}
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
auto el =
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register new element in per executor map
// to have easily erased when executor terminated
LinkEntryWithExecutor(pBlob, el.first);
} else {
blob_it->second = data; // set data to existing blob
}
VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope
return;
}
unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
unsigned int num_entries = 0;
for (auto const& l3 : *p_blobmap_) {
for (auto const& l2 : *(l3.second)) {
num_entries += (l2.second)->size();
}
}
return num_entries;
}
MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const {
BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = tls().get_cur_mkldnn_session_id();
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(sid);
// (jczaja): After first iteration of model's execution we
// should have all elements cached (mostly) so failures are unlikely (less
// likely for dynamic shapes)
if (unlikely(map_it == pMap->end())) {
VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
return nullptr;
}
sBlob = map_it->second;
// Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(tls().cur_input_shape_str);
if (unlikely(sBlob_it == sBlob->end())) {
VLOG(2) << "GetBlob: sid=" << tls().cur_input_shape_str
<< ", miss input_shape_str\n";
return nullptr;
}
pBlob = sBlob_it->second;
// Find Blob via name
auto key_it = pBlob->find(name);
if (unlikely(key_it == pBlob->end())) {
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr;
}
VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope
return key_it->second;
}
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
CustomDeviceContext::CustomDeviceContext(CustomPlace place)
: phi::CustomContext(place) {
......
@@ -59,6 +59,7 @@ limitations under the License. */
#ifdef PADDLE_WITH_MKLDNN
#include "dnnl.hpp" // NOLINT
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#endif
#include <map>
@@ -716,132 +717,8 @@ struct DefaultDeviceContextType<platform::CUDAPinnedPlace> {
#endif
#ifdef PADDLE_WITH_MKLDNN
class MKLDNNDeviceContextThreadLocals {
// default mkldnn session id
typedef MKLDNNDeviceContextThreadLocals self;
struct Body {
bool said_once = false;
size_t cur_mkldnn_session_id;
// Current data input shape string.
// - For fixed-shape, it's a null string in default.
// - For dynamic-shape, it's user specific.
std::string cur_input_shape_str;
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
int cur_input_shape_cache_capacity;
// Recently registered data_format. This is needed to
// know for converting MKL-DNN Tensor to non MKL-DNN
paddle::framework::DataLayout cur_paddle_data_layout;
// MKL-DNN stream used for execution of primitives (per-thread)
dnnl::engine cur_engine;
dnnl::stream cur_stream;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
void* exec_ptr_ = nullptr;
Body();
~Body();
void set_cur_mkldnn_session_id(size_t sid);
size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
void set_cur_paddle_data_layout(framework::DataLayout dl);
framework::DataLayout get_cur_paddle_data_layout(void);
void log_lib_version(void);
const dnnl::engine& get_engine(void);
dnnl::stream& get_stream(void);
void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
void* get_curr_exec(void) const { return exec_ptr_; }
};
MKLDNNDeviceContextThreadLocals() = default;
MKLDNNDeviceContextThreadLocals(const MKLDNNDeviceContextThreadLocals& c) =
delete;
public:
// default mkldnn session id
static constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
static Body& fetch() {
thread_local Body b;
return b;
}
};
class MKLDNNDeviceContext : public phi::CPUContext {
public:
template <class T>
using BlobPtr_t = std::shared_ptr<T>;
template <class P1, class P2>
using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
template <class T>
using umap_key_string_t = umap_value_smart_t<std::string, T>;
// Following three maps are used to cache MKLDNN primitives.
// There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
using KeyBlob = umap_key_string_t<void>;
using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
// Auxillary two-level structure (shape, executor) to easier control
// clearing cache objects related to specific executor
using ExecKey = void*;
using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
using ExecMap =
std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
explicit MKLDNNDeviceContext(CPUPlace place);
/* \brief Get the active engine */
const dnnl::engine& GetEngine() const { return tls().get_engine(); }
// Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const;
void RemoveShapeEntriesWithExecutor(void) const;
// Remove all entries from the blob map
void ResetBlobMap(void* ptr);
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
// Get the ShapeBlob size in cur_mkldnn_session_id.
size_t GetShapeBlobSize() const;
// Set data to blob (i.e. name/data pair). Create blob if not existing
void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
// Calculate number of oneDNN objects cached
unsigned int GetCachedObjectsNumber(void) const;
// Find a saved blob. Return nullptr if not found
std::shared_ptr<void> GetBlob(const std::string& name) const;
static auto tls() -> decltype(MKLDNNDeviceContextThreadLocals::fetch()) {
return MKLDNNDeviceContextThreadLocals::fetch();
}
private:
std::shared_ptr<BlobMap> p_blobmap_;
// Map key is pointer of executor and value is a data(iterator in map) needed
// to erase
std::shared_ptr<ExecShape> p_exec_items_;
std::shared_ptr<std::mutex> p_mutex_;
// 0 - clearing is allowed. x > 0 do not clear.
unsigned int block_next_cache_clearing_ = 0;
};
using MKLDNNDeviceContextThreadLocals = phi::OneDNNContextThreadLocals;
using MKLDNNDeviceContext = phi::OneDNNContext;
#endif
#ifdef PADDLE_WITH_CUSTOM_DEVICE
......
@@ -57,7 +57,7 @@ BackendSet GetTensorBackendSet(const phi::TensorBase& t) {
BackendSet backend_set(phi::TransToPhiBackend(t.place()));
switch (t.layout()) {
case DataLayout::MKLDNN:
-      backend_set = backend_set | BackendSet(Backend::MKLDNN);
+      backend_set = backend_set | BackendSet(Backend::ONEDNN);
break;
default:
// do nothing
......
@@ -12,6 +12,10 @@ if(WITH_XPU)
add_subdirectory(xpu)
endif()
+if(WITH_MKLDNN)
+  add_subdirectory(onednn)
+endif()
cc_library(
phi_context
SRCS all_context.cc
......
@@ -14,8 +14,8 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/api/ext/exception.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
// NOTE: The paddle framework should add a WITH_EIGEN option to support
// compiling without eigen.
@@ -41,7 +41,10 @@ struct CPUContext::Impl {
}
Eigen::DefaultDevice* GetEigenDevice() const {
-    PD_CHECK(eigen_device_ != nullptr, "the cpu eigen_device is nullptr.");
+    PADDLE_ENFORCE_NE(
+        eigen_device_,
+        nullptr,
+        phi::errors::Unavailable("the cpu eigen_device is nullptr."));
return eigen_device_;
}
......
if(WITH_MKLDNN)
cc_library(
onednn_context
SRCS onednn_context.cc
DEPS cpu_context mkldnn)
endif()
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_MKLDNN
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/fluid/framework/expect.h"
#include "paddle/fluid/platform/device_context.h"
namespace phi {
OneDNNContextThreadLocals::Body::Body()
: cur_engine(dnnl::engine::kind::cpu, 0), cur_stream(cur_engine) {
cur_mkldnn_session_id = kMKLDNNSessionID_Default;
cur_input_shape_str = "";
cur_input_shape_cache_capacity = 1;
cur_paddle_data_layout = DataLayout::kNCHW;
}
// When a thread finishes we clear the oneDNN cache.
// This is needed when one executor is used by many threads,
// e.g. test_analyzer_detect. Thread ID is not part of the caching key
// (for the naive executor), so we need to clear the cache when one thread
// finishes and another is about to start inference.
// TODO(jczaja): Ideally it would be good to clear only the part of the cache
// related to the thread that is to be terminated
OneDNNContextThreadLocals::Body::~Body() {
auto cpu_place = phi::CPUPlace();
// TODO(YuanRisheng): we need to remove the dependency on the fluid device
// context here
paddle::platform::DeviceContextPool& pool =
paddle::platform::DeviceContextPool::Instance();
OneDNNContext* dev_ctx = static_cast<OneDNNContext*>(pool.Get(cpu_place));
dev_ctx->ResetBlobMap(exec_ptr_);
}
void OneDNNContextThreadLocals::Body::set_cur_mkldnn_session_id(size_t sid) {
cur_mkldnn_session_id = sid;
}
size_t OneDNNContextThreadLocals::Body::get_cur_mkldnn_session_id(void) {
return cur_mkldnn_session_id;
}
void OneDNNContextThreadLocals::Body::set_cur_input_shape_str(
std::string input_shape_str) {
cur_input_shape_str = input_shape_str;
}
void OneDNNContextThreadLocals::Body::set_cur_input_shape_cache_capacity(
int input_shape_cache_capacity) {
cur_input_shape_cache_capacity = input_shape_cache_capacity;
}
void OneDNNContextThreadLocals::Body::set_cur_paddle_data_layout(
DataLayout dl) {
cur_paddle_data_layout = dl;
}
DataLayout OneDNNContextThreadLocals::Body::get_cur_paddle_data_layout(void) {
return cur_paddle_data_layout;
}
void OneDNNContextThreadLocals::Body::log_lib_version(void) {
if (!said_once) {
said_once = true;
auto dv = dnnl::version();
LOG(INFO) << "oneDNN v" << dv->major << "." << dv->minor << "."
<< dv->patch;
}
}
struct OneDNNContext::Impl {
Impl() : p_blobmap_() {
p_blobmap_.reset(new BlobMap());
p_exec_items_.reset(new ExecShape());
p_mutex_.reset(new std::mutex());
}
~Impl() {}
void ResetBlobMap(void* ptr) {
VLOG(4) << OneDNNContext::tls().get_curr_exec() << " " << ptr;
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
if (block_next_cache_clearing_ == 0) {
VLOG(3) << "Clearing DNNL cache.";
// If no specific executor pointer then clear
// everything. For executor pointer then clear only
// objects allocated when using given executor
if (ptr == nullptr) {
p_blobmap_->clear();
} else {
// Iterate through all shapes and release
// for each shape and active executor all entries
// of this executor
for (auto& s : *p_exec_items_) {
for (auto& v : (*s.second)[ptr]) {
(v.first)->erase(v.second);
}
s.second->erase(ptr);
}
}
// Reset paddle layout to NCHW
VLOG(3) << "Resetting Paddle data layout to NCHW.";
OneDNNContext::tls().set_cur_paddle_data_layout(DataLayout::kNCHW);
} else {
--block_next_cache_clearing_;
VLOG(3) << "Prevented Clearing DNNL cache. Updated "
"block_next_cache_clearing_ : "
<< block_next_cache_clearing_;
PADDLE_ENFORCE_GE(block_next_cache_clearing_,
0,
phi::errors::InvalidArgument(
"Cache clearing mark should be non-negative "
". But received %d.",
block_next_cache_clearing_));
}
}
// Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
KeyBlob::iterator it) const {
// Take current input shape from TLS
// Take current executor address from TLS
// and for this executor's items add the one defined with arguments
auto key_it =
p_exec_items_
->insert(std::make_pair(OneDNNContext::tls().cur_input_shape_str,
std::make_shared<ExecMap>()))
.first;
(*key_it->second)[OneDNNContext::tls().get_curr_exec()].push_back(
std::make_pair(pblob, it));
VLOG(3) << "LinkEntryWithExecutor, shapes: " << p_exec_items_->size()
<< " curr exec size: "
<< (*key_it->second)[OneDNNContext::tls().get_curr_exec()].size()
<< "\n";
}
void RemoveShapeEntriesWithExecutor() const {
p_exec_items_->erase(p_exec_items_->begin());
}
void BlockNextCacheClearing() {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
++block_next_cache_clearing_;
VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
"block_next_cache_clearing_ : "
<< block_next_cache_clearing_;
}
size_t GetShapeBlobSize() const {
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
BlobMap* pMap = p_blobmap_.get();
auto map_it = pMap->find(OneDNNContext::tls().cur_mkldnn_session_id);
if (map_it == pMap->end()) {
PADDLE_THROW(phi::errors::NotFound(
"OneDNNContext don't find cur_mkldnn_session_id: %d.",
OneDNNContext::tls().cur_mkldnn_session_id));
}
return map_it->second->size();
}
void SetBlob(const std::string& name, BlobPtr_t<void> data) const {
BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = OneDNNContext::tls().get_cur_mkldnn_session_id();
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id.
auto map_it = pMap->find(sid);
if (map_it == pMap->end()) {
// 1st time to set blob in current thread
sBlob = std::make_shared<ShapeBlob>();
(*pMap)[sid] = sBlob;
VLOG(2) << "SetBlob: sid=" << sid << ", add new sid\n";
} else {
sBlob = map_it->second;
}
// Find KeyBlob for current input shape
auto key_it = sBlob->find(OneDNNContext::tls().cur_input_shape_str);
if (key_it == sBlob->end()) {
// In cache clearing mode, cur_input_shape_cache_capacity defines
// max pblob capacity
if ((static_cast<size_t>(sid) ==
OneDNNContextThreadLocals::kMKLDNNSessionID_CacheClearing) &&
sBlob->size() &&
(sBlob->size() >=
static_cast<size_t>(
OneDNNContext::tls().cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid
<< ", remove all blobs of shape: " << sBlob->begin()->first;
sBlob->erase(sBlob->begin()->first);
RemoveShapeEntriesWithExecutor();
}
pBlob = std::make_shared<KeyBlob>();
(*sBlob)[OneDNNContext::tls().cur_input_shape_str] = pBlob;
} else {
pBlob = key_it->second;
}
// Find Blob via name
auto blob_it = pBlob->find(name);
if (blob_it == pBlob->end()) {
auto el =
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register the new element in the per-executor map
// so it can be easily erased when the executor terminates
LinkEntryWithExecutor(pBlob, el.first);
} else {
blob_it->second = data; // set data to existing blob
}
VLOG(2) << "SetBlob: sid=" << sid << ", add blob=" << name << "\n";
// lock will be automatically released when out of scope
return;
}
unsigned int GetCachedObjectsNumber(void) const {
unsigned int num_entries = 0;
for (auto const& l3 : *p_blobmap_) {
for (auto const& l2 : *(l3.second)) {
num_entries += (l2.second)->size();
}
}
return num_entries;
}
OneDNNContext::BlobPtr_t<void> GetBlob(const std::string& name) const {
BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr;
int sid = OneDNNContext::tls().get_cur_mkldnn_session_id();
std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
// Find ShapeBlob for current mkldnn session id firstly
auto map_it = pMap->find(sid);
// (jczaja): After the first iteration of the model's execution we
// should have all elements cached (mostly), so failures are unlikely
// (less likely for dynamic shapes)
if (unlikely(map_it == pMap->end())) {
VLOG(2) << "GetBlob: sid=" << sid << ", miss sid\n";
return nullptr;
}
sBlob = map_it->second;
// Find KeyBlob for current input shape secondly
auto sBlob_it = sBlob->find(OneDNNContext::tls().cur_input_shape_str);
if (unlikely(sBlob_it == sBlob->end())) {
VLOG(2) << "GetBlob: sid=" << OneDNNContext::tls().cur_input_shape_str
<< ", miss input_shape_str\n";
return nullptr;
}
pBlob = sBlob_it->second;
// Find Blob via name
auto key_it = pBlob->find(name);
if (unlikely(key_it == pBlob->end())) {
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr;
}
VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope
return key_it->second;
}
std::shared_ptr<BlobMap> p_blobmap_;
// Map key is the executor pointer and value is the data (an iterator in the
// map) needed for erasing
std::shared_ptr<ExecShape> p_exec_items_;
std::shared_ptr<std::mutex> p_mutex_;
// 0 - clearing is allowed. x > 0 do not clear.
unsigned int block_next_cache_clearing_ = 0;
};
OneDNNContext::OneDNNContext(const Place& place)
: CPUContext(place), impl_(std::make_unique<Impl>()) {}
OneDNNContext::~OneDNNContext() = default;
void OneDNNContext::ResetBlobMap(void* ptr) { impl_->ResetBlobMap(ptr); }
void OneDNNContext::BlockNextCacheClearing() {
impl_->BlockNextCacheClearing();
}
size_t OneDNNContext::GetShapeBlobSize() const {
return impl_->GetShapeBlobSize();
}
void OneDNNContext::SetBlob(const std::string& name,
BlobPtr_t<void> data) const {
impl_->SetBlob(name, data);
}
unsigned int OneDNNContext::GetCachedObjectsNumber(void) const {
return impl_->GetCachedObjectsNumber();
}
OneDNNContext::BlobPtr_t<void> OneDNNContext::GetBlob(
const std::string& name) const {
return impl_->GetBlob(name);
}
} // namespace phi
#endif
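One subtlety worth spelling out is the cache-clearing gate: each BlockNextCacheClearing() call increments a counter, and a gated ResetBlobMap() decrements that counter instead of clearing. A small sketch (the call sequence is hypothetical; the API is as implemented above):

phi::OneDNNContext ctx(phi::CPUPlace());
ctx.BlockNextCacheClearing();  // block_next_cache_clearing_: 0 -> 1
ctx.ResetBlobMap(nullptr);     // gated: only decrements back to 0, cache kept
ctx.ResetBlobMap(nullptr);     // not gated: clears the whole blob cache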
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_MKLDNN
#include <memory>
#include <mutex> // NOLINT
#include "dnnl.hpp" // NOLINT
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/common/place.h"
namespace phi {
class OneDNNContextThreadLocals {
// default mkldnn session id
typedef OneDNNContextThreadLocals self;
struct Body {
bool said_once = false;
size_t cur_mkldnn_session_id;
// Current data input shape string.
// - For fixed-shape, it's an empty string by default.
// - For dynamic-shape, it's user-specified.
std::string cur_input_shape_str;
// the cache capacity of different input shapes for MKLDNN.
// Default 1 means fixed input shape, not dynamic shape.
int cur_input_shape_cache_capacity;
// Most recently registered data_format. This is needed to
// know how to convert an MKL-DNN Tensor to a non-MKL-DNN one
DataLayout cur_paddle_data_layout;
// MKL-DNN stream used for execution of primitives (per-thread)
dnnl::engine cur_engine;
dnnl::stream cur_stream;
std::string key_suffix; // Key identifying current Executor
bool key_attach_thread_id = true;
void* exec_ptr_ = nullptr;
Body();
~Body();
void set_cur_mkldnn_session_id(size_t sid);
size_t get_cur_mkldnn_session_id(void);
void set_cur_input_shape_str(std::string input_shape_str);
void set_cur_input_shape_cache_capacity(int input_shape_cache_capacity);
void set_cur_paddle_data_layout(DataLayout dl);
DataLayout get_cur_paddle_data_layout(void);
void log_lib_version(void);
const dnnl::engine& get_engine(void) { return cur_engine; }
dnnl::stream& get_stream(void) { return cur_stream; }
void set_key_suffix(const std::string& suffix) { key_suffix = suffix; }
const std::string& get_key_suffix(void) const { return key_suffix; }
void disable_tid_in_key(void) { key_attach_thread_id = false; }
bool is_tid_used_in_key(void) const { return key_attach_thread_id; }
void set_curr_exec(void* exec_ptr) { exec_ptr_ = exec_ptr; }
void* get_curr_exec(void) const { return exec_ptr_; }
};
OneDNNContextThreadLocals() = default;
OneDNNContextThreadLocals(const OneDNNContextThreadLocals& c) = delete;
public:
// default mkldnn session id
static constexpr size_t kMKLDNNSessionID_Default = 0;
// mkldnn session id for cache clearing mode
static constexpr size_t kMKLDNNSessionID_CacheClearing = -1;
static Body& fetch() {
thread_local Body b;
return b;
}
};
class OneDNNContext : public CPUContext {
public:
template <class T>
using BlobPtr_t = std::shared_ptr<T>;
template <class P1, class P2>
using umap_value_smart_t = std::unordered_map<P1, BlobPtr_t<P2>>;
template <class T>
using umap_key_string_t = umap_value_smart_t<std::string, T>;
// The following three maps are used to cache MKLDNN primitives.
// Their relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob>
// - KeyBlob = Map<blob_name, blob>
using KeyBlob = umap_key_string_t<void>;
using ShapeBlob = umap_key_string_t<KeyBlob>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>;
// Auxiliary two-level structure (shape, executor) to more easily control
// clearing cache objects related to a specific executor
using ExecKey = void*;
using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>;
using ExecMap =
std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
explicit OneDNNContext(const Place& place);
~OneDNNContext();
/* \brief Get the active engine */
const dnnl::engine& GetEngine() const { return tls().get_engine(); }
// Remove all entries from the blob map
void ResetBlobMap(void* ptr);
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
// Get the ShapeBlob size in cur_mkldnn_session_id.
size_t GetShapeBlobSize() const;
// Set data to blob (i.e. name/data pair). Create blob if not existing
void SetBlob(const std::string& name, std::shared_ptr<void> data) const;
// Calculate number of oneDNN objects cached
unsigned int GetCachedObjectsNumber(void) const;
// Find a saved blob. Return nullptr if not found
std::shared_ptr<void> GetBlob(const std::string& name) const;
static auto tls() -> decltype(OneDNNContextThreadLocals::fetch()) {
return OneDNNContextThreadLocals::fetch();
}
private:
struct Impl;
std::unique_ptr<Impl> impl_;
};
} // namespace phi
#endif
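The blob cache is what lets oneDNN kernels reuse primitives and memory objects across runs. A minimal usage sketch against the interface above (the function name, the key string, and the choice of dnnl::memory as the cached object are hypothetical):

std::shared_ptr<dnnl::memory> GetOrCreateSrcMemory(
    const phi::OneDNNContext& dev_ctx,
    const dnnl::memory::desc& md,
    void* data) {
  const std::string key = "src_mem@example";  // hypothetical cache key
  auto mem = std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key));
  if (mem == nullptr) {
    // Cache miss: build the memory object once on the thread-local engine,
    // then store it for later runs with the same shape/session.
    mem = std::make_shared<dnnl::memory>(md, dev_ctx.GetEngine(), data);
    dev_ctx.SetBlob(key, mem);
  } else {
    // Cache hit: keep the cached object and just repoint its data handle.
    mem->set_data_handle(data);
  }
  return mem;
}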
@@ -50,7 +50,7 @@ enum class Backend : uint8_t {
MLU, // MLU currently does not exist at the same time as CUDA
// the third library backend
-  MKLDNN,
+  ONEDNN,
GPUDNN, // cuDNN and hipDNN
// paddle kernel primitives backend
@@ -118,8 +118,8 @@ inline std::ostream& operator<<(std::ostream& os, Backend backend) {
case Backend::MLU:
os << "MLU";
break;
-    case Backend::MKLDNN:
-      os << "MKLDNN";
+    case Backend::ONEDNN:
+      os << "ONEDNN";
break;
case Backend::GPUDNN:
os << "GPUDNN";
@@ -160,8 +160,8 @@ inline Backend StringToBackend(const char* backend_cstr) {
return Backend::NPU;
} else if (s == std::string("MLU")) {
return Backend::MLU;
} else if (s == std::string("MKLDNN")) {
return Backend::MKLDNN;
} else if (s == std::string("OneDNN")) {
return Backend::ONEDNN;
} else if (s == std::string("GPUDNN")) {
return Backend::GPUDNN;
} else if (s == std::string("KPS")) {
......
@@ -66,7 +66,7 @@ phi::Place TransToPhiPlace(const Backend& backend, bool set_device_id) {
set_device_id ? phi::backends::gpu::GetCurrentDeviceId() : 0);
#endif
#ifdef PADDLE_WITH_MKLDNN
-    case phi::Backend::MKLDNN:
+    case phi::Backend::ONEDNN:
return phi::CPUPlace();
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......
@@ -46,9 +46,17 @@ const Kernel& KernelFactory::SelectKernel(const std::string& kernel_name,
return empty_kernel;
}
auto kernel_iter = iter->second.find(kernel_key);
+  if (kernel_iter == iter->second.end() &&
+      kernel_key.layout() != phi::DataLayout::ALL_LAYOUT) {
+    phi::KernelKey any_layout_kernel_key(
+        kernel_key.backend(), phi::DataLayout::ALL_LAYOUT, kernel_key.dtype());
+    kernel_iter = iter->second.find(any_layout_kernel_key);
+  }
if (kernel_iter == iter->second.end()) {
return empty_kernel;
}
return kernel_iter->second;
}
......
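This fallback is what makes kernels registered with ALL_LAYOUT (such as the OneDNN log_softmax later in this commit) resolvable under any concrete requested layout. An illustrative query (hypothetical call site):

phi::KernelKey key(
    phi::Backend::ONEDNN, phi::DataLayout::NCHW, phi::DataType::FLOAT32);
// The direct find() misses (nothing is registered under NCHW), but the
// second lookup with any_layout_kernel_key succeeds.
const phi::Kernel& kernel =
    phi::KernelFactory::Instance().SelectKernel("log_softmax", key);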
@@ -56,6 +56,9 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
auto args_type = ParseArgType(Indices{});
for (auto arg_type : args_type) {
if (arg_type == std::type_index(typeid(const CPUContext&))
+#if defined(PADDLE_WITH_MKLDNN)
+          || arg_type == std::type_index(typeid(const OneDNNContext&))
+#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
|| arg_type == std::type_index(typeid(const GPUContext&))) {
#elif defined(PADDLE_WITH_XPU)
@@ -63,6 +66,7 @@ struct KernelArgsParseFunctor<Return_ (*)(Args_...)> {
#elif defined(PADDLE_WITH_CUSTOM_DEVICE)
|| arg_type == std::type_index(typeid(const CustomContext&))) {
#else
) {
#endif
// do nothing, skip context arg now
......
@@ -17,6 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/custom/custom_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/common/int_array.h"
#include "paddle/phi/common/scalar.h"
@@ -257,7 +258,9 @@ struct KernelImpl<Return (*)(DevCtx, Args...), kernel_fn> {
#ifdef PADDLE_WITH_CUSTOM_DEVICE
PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(CustomContext);
#endif
+#ifdef PADDLE_WITH_MKLDNN
+PD_SPECIALIZE_KernelCallHelper_FOR_DEVICE_CONTEXT(OneDNNContext);
+#endif
/* Input Helpers */
PD_SPECIALIZE_KernelCallHelper_FOR_INPUT(DenseTensor);
......
@@ -113,11 +113,13 @@ file(
# file(GLOB kernel_cudnn "gpudnn/*.cu")
# file(GLOB kernel_kps "kps/*.cu")
file(GLOB kernel_xpu "xpu/*.cc")
+file(GLOB kernel_onednn "onednn/*.cc")
add_library(phi_cpu ${kernel_cc})
kernel_declare("${kernel_cc}")
target_link_libraries(phi_cpu ${COMMON_KERNEL_DEPS})
-set_property(GLOBAL PROPERTY PHI_KERNELS phi_cpu)
+set(ADD_PHI_KERNELS phi_cpu)
if(WITH_GPU OR WITH_ROCM)
if(WITH_GPU)
@@ -127,7 +129,7 @@ if(WITH_GPU OR WITH_ROCM)
endif()
kernel_declare("${kernel_cu}")
target_link_libraries(phi_gpu ${COMMON_KERNEL_DEPS})
-  set_property(GLOBAL PROPERTY PHI_KERNELS phi_cpu phi_gpu)
+  set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_gpu)
endif()
if(WITH_XPU)
@@ -148,5 +150,15 @@ if(WITH_XPU)
kernel_declare("${kernel_xpu}")
kernel_declare("${kernel_xpu_kps}")
target_link_libraries(phi_xpu ${COMMON_KERNEL_DEPS})
-  set_property(GLOBAL PROPERTY PHI_KERNELS phi_cpu phi_xpu)
+  set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_xpu)
endif()
+if(WITH_MKLDNN)
+  add_library(phi_onednn ${kernel_onednn})
+  kernel_declare(${kernel_onednn})
+  set(COMMON_KERNEL_DEPS ${COMMON_KERNEL_DEPS} onednn_context)
+  target_link_libraries(phi_onednn ${COMMON_KERNEL_DEPS})
+  set(ADD_PHI_KERNELS ${ADD_PHI_KERNELS} phi_onednn)
+endif()
+set_property(GLOBAL PROPERTY PHI_KERNELS ${ADD_PHI_KERNELS})
@@ -116,5 +116,7 @@ void LogSoftmaxKernel(const Context& dev_ctx,
} // namespace phi
+// TODO(YuanRisheng): The layout of the mkldnn kernel should be MKLDNN; we
+// should support specifying the exact layout when the kernel is registered.
PD_REGISTER_KERNEL(
log_softmax, CPU, ALL_LAYOUT, phi::LogSoftmaxKernel, float, double) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/log_softmax_kernel.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/phi/backends/onednn/onednn_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T>
class LogSoftmaxMKLDNNHandler
: public paddle::platform::
MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward> {
public:
LogSoftmaxMKLDNNHandler(const dnnl::engine mkldnn_engine,
Place cpu_place,
const DenseTensor& x,
const int axis)
: paddle::platform::MKLDNNHandlerNoCachingT<T, dnnl::logsoftmax_forward>(
mkldnn_engine, cpu_place) {
this->AcquireForwardPrimitiveDescriptor(
dnnl::prop_kind::forward_inference, x.mem_desc(), axis);
}
};
template <typename T, typename Context>
void LogSoftmaxKernel(const Context& dev_ctx,
const DenseTensor& x,
int axis,
DenseTensor* out) {
const auto& mkldnn_engine = dev_ctx.GetEngine();
axis = axis >= 0 ? axis : x.dims().size() + axis;
LogSoftmaxMKLDNNHandler<T> handler(
mkldnn_engine, dev_ctx.GetPlace(), x, axis);
auto src_memory_p = handler.AcquireSrcMemory(&x);
auto dst_memory_p = handler.AcquireDstMemory(out);
auto logsoftmax_p = handler.AcquireForwardPrimitive();
auto& astream = OneDNNContext::tls().get_stream();
logsoftmax_p->execute(
astream, {{DNNL_ARG_SRC, *src_memory_p}, {DNNL_ARG_DST, *dst_memory_p}});
astream.wait();
out->set_mem_desc(dst_memory_p->get_desc());
}
} // namespace phi
PD_REGISTER_KERNEL(log_softmax,
OneDNN,
ALL_LAYOUT,
phi::LogSoftmaxKernel,
float,
phi::dtype::bfloat16) {}
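Together with the LibraryType::kMKLDNN to Backend::ONEDNN mapping in phi_utils.cc above, this registration supersedes the Fluid REGISTER_OP_KERNEL(log_softmax, MKLDNN, ...) shown earlier in this commit. A sketch of the translation step that routes an MKLDNN op kernel type here (the OpKernelType constructor arguments follow the unit test shown earlier and are otherwise an assumption):

paddle::framework::OpKernelType op_kernel_type_mkldnn(
    paddle::framework::proto::VarType::FP32,
    paddle::platform::CPUPlace(),
    phi::DataLayout::MKLDNN,
    paddle::framework::LibraryType::kMKLDNN);
auto kernel_key =
    paddle::framework::TransOpKernelTypeToPhiKernelKey(op_kernel_type_mkldnn);
// kernel_key.backend() == phi::Backend::ONEDNN, so kernel selection lands on
// the PD_REGISTER_KERNEL(log_softmax, OneDNN, ...) entry above.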
@@ -39,8 +39,8 @@ TEST(Backend, OStream) {
oss << phi::Backend::NPU;
EXPECT_EQ(oss.str(), "NPU");
oss.str("");
-  oss << phi::Backend::MKLDNN;
-  EXPECT_EQ(oss.str(), "MKLDNN");
+  oss << phi::Backend::ONEDNN;
+  EXPECT_EQ(oss.str(), "ONEDNN");
oss.str("");
oss << phi::Backend::GPUDNN;
EXPECT_EQ(oss.str(), "GPUDNN");
@@ -63,7 +63,7 @@ TEST(Backend, StringToBackend) {
EXPECT_EQ(phi::Backend::GPU, pexp::StringToBackend("GPU"));
EXPECT_EQ(phi::Backend::XPU, pexp::StringToBackend("XPU"));
EXPECT_EQ(phi::Backend::NPU, pexp::StringToBackend("NPU"));
-  EXPECT_EQ(phi::Backend::MKLDNN, pexp::StringToBackend("MKLDNN"));
+  EXPECT_EQ(phi::Backend::ONEDNN, pexp::StringToBackend("OneDNN"));
EXPECT_EQ(phi::Backend::GPUDNN, pexp::StringToBackend("GPUDNN"));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
EXPECT_EQ(phi::Backend::GPU, pexp::StringToBackend("KPS"));
......